Flux Transformer Out of Memory

So you mean PyTorch only uses 3% of your GPU (so that's 180 MB) during backpropagation?

PyTorch uses that amount to store the model and all its gradients. I am not sure exactly how much is used during the calculations as some of it may be freed. PyTorch is about 4x faster though.

Something to consider is that 80% is not necessarily the amount that backpropagation needs; instead, it’s more like how much memory the GC felt like holding onto. Manual usage of CUDA.unsafe_free! would probably bring this number down significantly, but doing so in Flux directly isn’t really feasible in most cases (because who knows if the array might be accessed somewhere else).
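
For reference, this is roughly what manual freeing looks like in plain CUDA.jl (a minimal sketch; the array and sizes are made up):

```julia
using CUDA

x = CUDA.rand(Float32, 1024, 1024)  # temporary GPU buffer
y = sum(x .^ 2)                     # last use of x
CUDA.unsafe_free!(x)                # return its memory to the pool immediately
# Any access to x after this point is undefined behaviour, which is exactly
# why Flux can't sprinkle unsafe_free! over intermediates automatically.
```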

I think it would be helpful to first figure out where these allocations are happening, and then determine if any of those allocations can be quickly unsafe_free!'d after usage.
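
One way to locate them (a sketch with a stand-in toy model, assuming a recent Flux/CUDA.jl) is CUDA.@time, which reports GPU allocation counts and sizes alongside runtime:

```julia
using CUDA, Flux

model = Chain(Dense(512 => 512, relu), Dense(512 => 512)) |> gpu
x = CUDA.rand(Float32, 512, 64)

# Wrapping the forward pass and the gradient call separately shows
# how many GPU allocations each side is responsible for.
CUDA.@time model(x)
CUDA.@time gradient(m -> sum(abs2, m(x)), model)
```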

Unfortunately, this currently can't be done, because ChainRules is designed to support gradients, Jacobians, and higher-order AD. Memory management becomes tricky under those assumptions: a pullback may be invoked more than once (for a Jacobian it runs once per output slice), so the intermediates it closes over can't simply be freed after the first call.
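
To make that concrete, here is a toy rrule (hypothetical, not from any package) whose pullback closes over its input; the commented-out free marks exactly where the temptation, and the danger, lies:

```julia
using ChainRulesCore, CUDA

square(x) = x .^ 2

function ChainRulesCore.rrule(::typeof(square), x::CuArray)
    y = x .^ 2
    function square_pullback(ȳ)
        dx = 2 .* x .* ȳ
        # CUDA.unsafe_free!(x)  # UNSAFE: a Jacobian or higher-order AD may
        #                       # call square_pullback again, and it still
        #                       # needs x to be alive.
        return NoTangent(), dx
    end
    return y, square_pullback
end
```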

Not sure if it's relevant in this specific context, but I'm wondering if @mcabbott's WIP improvement to Flux's ChainRules-level CUDA memory management couldn't be useful here: Flux.jl/cuda.jl at 494d351c15bec709af7049ccc6a2ffec4580f682 · mcabbott/Flux.jl · GitHub

The above rrule for Chain brought a significant improvement to the memory management of ResNet models (it more than doubled the maximum batch size, IIRC).

I am not sure if this is relevant, but CUDA.memory_status() reports "No memory pool is in use" (the trailing nothing is just its return value being printed). My understanding is that a memory pool is normally used by default, and I never turned it off.
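
If it helps anyone check the same thing: as far as I know, CUDA.jl picks the pool from the JULIA_CUDA_MEMORY_POOL environment variable at package load time, so toggling it looks something like this (accepted values depend on the CUDA.jl version):

```julia
# Assumption: this variable must be set before `using CUDA`, since it is
# read when the package initializes; "cuda" selects the stream-ordered
# pool and "none" disables pooling.
ENV["JULIA_CUDA_MEMORY_POOL"] = "cuda"

using CUDA
CUDA.memory_status()  # should now report an active pool
```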

Also, I just did a more thorough benchmark of the PyTorch implementation, and it never goes over 2.6 GB of GPU usage. With a batch size of 64, the PyTorch version never goes over 4.4 GB. The Flux version crashes with a batch size of 58, so that would be a memory usage of around 6 GB. It seems that Flux uses about 36% more memory than PyTorch in this case. Is this relative memory-usage difference common, or is Flux usually closer to PyTorch in terms of performance/memory usage?
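
For anyone who wants to probe the same boundary on the Flux side, a rough sketch (with a stand-in toy model, not the transformer from this thread):

```julia
using CUDA, Flux

model = Chain(Dense(512 => 512, relu), Dense(512 => 1)) |> gpu

for bs in (16, 32, 48, 56, 64)
    try
        x = CUDA.rand(Float32, 512, bs)
        gradient(m -> sum(m(x)), model)
        used = CUDA.total_memory() - CUDA.available_memory()
        @info "batch size $bs fits" used_GB = round(used / 2^30; digits = 2)
    catch err
        err isa CUDA.OutOfGPUMemoryError || rethrow()
        @warn "OOM at batch size $bs"
    finally
        CUDA.reclaim()  # hand cached pool memory back to the driver between trials
    end
end
```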
