Hi, I’m running some tests on CIFAR10 with a residual CNN, and so far the memory consumption is very discouraging. I’m doing something rather atypical: I regularize each layer of my network separately with a graph Laplacian. However, it takes up far more memory than I would have expected based on what I have seen previously in PyTorch.
With a batch size of 100, it quickly eats up all 32 GB of memory before I even reach backpropagation!
For comparison, in PyTorch I run with 1.5M parameters and a batch size of 200, and it consumes about 8 GB of memory.
In Flux I had to downscale severely, and I’m still having trouble. I run the same network with the last couple of layers cut off, so I have 350k parameters. With a batch size of 10 images, it consumes ~30 GB of memory.
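To put the two sets of numbers side by side, here is a rough back-of-the-envelope calculation. It is only a sketch: it assumes memory scales roughly linearly with batch size and that parameters are stored as 32-bit floats, neither of which I have verified for my setup. The helper name `gb_per_image` is just something I made up for illustration.

```python
# Rough comparison of the memory figures quoted above.
# Assumption: memory use scales roughly linearly with batch size.
# Assumption: parameters are stored as 32-bit floats (4 bytes each).

BYTES_PER_PARAM = 4  # assumed Float32 storage


def gb_per_image(total_gb, batch_size):
    """Memory attributed to each image in the batch, in GB."""
    return total_gb / batch_size


def param_megabytes(n_params):
    """Memory needed to store the parameters alone, in MB (10^6 bytes)."""
    return n_params * BYTES_PER_PARAM / 1e6


# PyTorch run: 1.5M parameters, batch size 200, about 8 GB.
pytorch_per_image = gb_per_image(8.0, 200)
pytorch_param_mb = param_megabytes(1_500_000)

# Flux run: 350k parameters, batch size 10, about 30 GB.
flux_per_image = gb_per_image(30.0, 10)
flux_param_mb = param_megabytes(350_000)

# How many times more memory per image the Flux run uses.
ratio = flux_per_image / pytorch_per_image

print(f"PyTorch: {pytorch_per_image:.2f} GB/image, params {pytorch_param_mb:.1f} MB")
print(f"Flux:    {flux_per_image:.2f} GB/image, params {flux_param_mb:.1f} MB")
print(f"Flux uses roughly {ratio:.0f}x more memory per image")
```

Under these assumptions the parameters themselves account for only a few megabytes in either case, so nearly all of the memory must be going to something that is stored per image or per batch rather than to the weights.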
Should I expect the same behaviour with Zygote?