Memory usage seems extremely high in standard Flux, is it any better with Zygote?

Hi, I’m running some tests on CIFAR10 with a residual CNN, and so far I’m getting very discouraging results for memory consumption. I’m doing something rather atypical: I regularize each layer of my network separately with a graph Laplacian. Even so, it takes up far more memory than I would have expected based on what I have seen previously in PyTorch.
With a batch size of 100 it quickly eats up all 32 GB of memory before I even reach the backpropagation!

For comparison:
In PyTorch I run with 1.5M parameters and a batch size of 200, and it consumes about 8 GB of memory.
In Flux I had to downscale severely and I’m still having trouble. I run with the last couple of layers cut off (otherwise the same network), so I have 350k parameters. With a batch size of 10 images, it consumes ~30 GB of memory.
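
As a rough sanity check on those numbers, here is a back-of-envelope sketch of how much memory the parameters and one input batch should need by themselves. It assumes everything is stored as 32-bit floats (an assumption on my part, the element type may differ) and uses the standard CIFAR10 image size of 32×32×3. The helper name `megabytes` is just for illustration.

```python
# Back-of-envelope memory estimate for parameters and one input batch.
# Assumption: every value is stored as a 32-bit float (4 bytes).
BYTES_PER_FLOAT32 = 4


def megabytes(n_values, bytes_per_value=BYTES_PER_FLOAT32):
    """Size in MB (10^6 bytes) of n_values stored densely."""
    return n_values * bytes_per_value / 1e6


# Parameter counts quoted above.
pytorch_params_mb = megabytes(1_500_000)
flux_params_mb = megabytes(350_000)

# One Flux batch: 10 CIFAR10 images of 32 x 32 pixels x 3 channels.
flux_batch_mb = megabytes(10 * 32 * 32 * 3)

print(f"PyTorch parameters: {pytorch_params_mb:.2f} MB")
print(f"Flux parameters:    {flux_params_mb:.2f} MB")
print(f"Flux input batch:   {flux_batch_mb:.2f} MB")
```

Under these assumptions the parameters and inputs together come to only a few megabytes, so nearly all of the ~30 GB would have to come from somewhere else, for example intermediate arrays kept around during the forward pass.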

Should I expect the same behaviour with Zygote?

Not directly answering the question and I’m not sure if it is related, but for reference:

Mine happens even in the first iteration, and furthermore I don’t use the GPU/CuArrays at all at this point because of the large memory consumption, so I doubt it is related. But I’m guessing there is a difference in how automatic differentiation is implemented in PyTorch and Flux?