An implementation of ResNet-18 uses a lot of GPU memory

I did a couple of experiments and it seems that ConvTranspose is the culprit here. Just compare the performance of these two models:

using Flux

# Model 1: with a ConvTranspose upsampling layer.
# n is the upsampling factor; it is left undefined here, so bind it
# (e.g. n = 8) before constructing the model.
m = Chain(
  ConvTranspose((n, n), 3 => 3, stride = n),
  Conv((7, 7), 3 => 64, pad = (3, 3), stride = (2, 2)),
  MeanPool((7, 7)),
  x -> reshape(x, :, size(x, 4)),
  Dense(512 * 32, 10),
  softmax,
) |> gpu
# Model 2: no ConvTranspose; the extra Dense layer keeps the
# parameter count comparable.
m = Chain(
  Conv((7, 7), 3 => 64, pad = (3, 3), stride = (2, 2)),
  MeanPool((7, 7)),
  x -> reshape(x, :, size(x, 4)),
  Dense(256, 512 * 32),
  Dense(512 * 32, 10),
  softmax,
) |> gpu
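One way to compare the two models is to measure GPU allocations directly rather than eyeballing overall usage. Below is a minimal sketch assuming CUDA.jl is the GPU backend; the 28×28 input size and batch size of 16 are my guesses for illustration, not taken from the original post:

```julia
using Flux, CUDA

# Hypothetical input batch: 28x28 RGB images, batch size 16.
x = rand(Float32, 28, 28, 3, 16) |> gpu

# CUDA.@time reports GPU memory allocated during the forward pass,
# alongside the elapsed time.
CUDA.@time m(x)

# Prints the currently used / total device memory, useful for
# spotting which model leaves more memory resident.
CUDA.memory_status()
```

Running this once per model (after a warm-up call to exclude compilation) should make the ConvTranspose overhead visible in the allocation numbers.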