Flux runs out of memory

A couple of months ago, I started training a network (the generator of a GAN) that kept causing GPU OOMs. It has ResNet blocks, long skip connections and squeeze-and-excitation blocks, with a total of more than 400M params. After some searching (CUDA.jl docs, issue 137, issue 149, PR 427, PR 33448, and this very discussion), I found this post, which suggested setting the environment variable ENV["JULIA_CUDA_MEMORY_POOL"] = "none" before loading CUDA.jl. In my case, this was the only way to get past the first batch (of 64 samples). Then, while testing other configurations, I found that also calling GC.gc(); CUDA.reclaim() between batches reduced epoch time by about 15%.
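A minimal sketch of that setup, assuming a CUDA.jl release that still reads JULIA_CUDA_MEMORY_POOL and the implicit-parameter Flux API of that period (Flux 0.12 style, with ADAM and Flux.params). The model, loss and random batches below are hypothetical stand-ins for the real generator and data; only the environment variable and the GC.gc(); CUDA.reclaim() calls come from the setup described above.

```julia
# Must be set before CUDA.jl is loaded.
# Assumption: a CUDA.jl version that still honours JULIA_CUDA_MEMORY_POOL.
ENV["JULIA_CUDA_MEMORY_POOL"] = "none"

using CUDA, Flux

# Hypothetical stand-ins for the real generator, loss and data.
model = Chain(Dense(128, 256, relu), Dense(256, 128)) |> gpu
opt = ADAM(1e-4)
ps = Flux.params(model)
loss(x) = Flux.Losses.mse(model(x), x)

# Ten hypothetical batches of 64 samples each, already on the GPU.
batches = [CUDA.randn(Float32, 128, 64) for _ in 1:10]

for x in batches
    gs = gradient(() -> loss(x), ps)
    Flux.Optimise.update!(opt, ps, gs)
    # Between batches: collect unreachable arrays (including GPU buffers),
    # then ask CUDA.jl to release memory it is still holding.
    GC.gc()
    CUDA.reclaim()
end
```

A full GC.gc() on every batch has a cost of its own, so whether it speeds up or slows down an epoch depends on the model; timing an epoch with and without it is the simplest check.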

That’s the setup I’ve been using since. As a side note, I recently replicated my NN in PyTorch to see whether performance would be better. In the forward pass through this NN, more than 10GB are allocated on my GPU (RTX 2060, 12GB) and I get a GPU OOM. I tried the same changes I used with Flux.jl (explicit GC and no GPU memory cache), but there was no improvement.

Other adjustments to my PyTorch implementation don’t seem simple enough to be worth the effort, considering I can train my model in Flux. Still, it would have been helpful to have all this GPU OOM discussion (and news about it) concentrated in one place. The interactions between Julia’s GC, CUDA.jl and Flux.jl come up from time to time, but as a user, I’d have liked to know the common pitfalls and best practices when I first ran into this issue.
