GPU memory usage increasing on each epoch (Flux)

Unfortunately, the memory management heuristic in CUDA.jl performs poorly under some workloads. Forcing garbage collection by inserting `GC.gc(false); CUDA.reclaim()` at the end of each epoch helps a lot.
You may also want to run the GC every few mini-batch iterations, in which case you have to roll your own training loop:

```julia
using Flux, CUDA

for (i, (x, y)) in enumerate(train_data)
  g = gradient(m -> loss(m, x, y), model)[1]
  Flux.update!(opt_state, model, g)
  if i % 10 == 0
    # @info "Batch $(i)"
    # CUDA.memory_status()
    GC.gc(false)  # incremental GC; frees unreachable GPU arrays between batches
  end
end
```
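For the end-of-epoch variant mentioned above, a minimal sketch (assuming the explicit-gradient API with `opt_state = Flux.setup(...)`, a `loss(model, x, y)` function, a `train_data` iterator, and a placeholder epoch count `n_epochs`):

```julia
using Flux, CUDA

for epoch in 1:n_epochs                          # n_epochs is a placeholder
  Flux.train!(loss, model, train_data, opt_state)
  GC.gc(false)     # incremental collection of unreachable arrays
  CUDA.reclaim()   # return freed GPU memory to the CUDA pool
end
```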

A related issue is

Hopefully, an improvement will come from
