Flux + GPU memory problems

I was trying to train some of the models available in Metalhead (0.7) with Flux (0.13) on an NVIDIA GTX 1080 Ti (11 GB), using some toy data (512 images of size 224x224x3). With the code shown at the end of this post, I could train `Metalhead.MobileNetv3(nclasses=2)` and `Metalhead.ResNet34(pretrain=false, nclasses=2)`. However, I was not able to train `Metalhead.ResNet50(pretrain=false, nclasses=2)` or `Metalhead.ConvNeXt(:tiny, nclasses=2)` because of GPU memory problems.

**With other training platforms, I was able to train models like ResNet50.**
Is there a problem with my training approach? Any advice?

Thank you in advance!

```julia
using Flux, Metalhead, CUDA
using Flux: onehotbatch
using Flux.Losses: logitcrossentropy

device = gpu
# model = Metalhead.MobileNetv3(nclasses=2) |> device     # OK: ~58% of GPU memory
# model = ResNet34(pretrain=false, nclasses=2) |> device    # OK: ~99% of GPU memory
model = ResNet50(pretrain=false, nclasses=2) |> device      # not OK (out of memory)
# model = Metalhead.ConvNeXt(:tiny, nclasses=2) |> device   # not OK (out of memory)

loss(ŷ, y) = logitcrossentropy(ŷ, y)

ps = Flux.params(model)

η = 3e-4    # learning rate
λ = 1e-6    # weight decay
# opt = ADAM(η)
opt = AdaBelief(η)
if λ > 0    # add weight decay, equivalent to L2 regularization
    opt = Optimiser(WeightDecay(λ), opt)
end

# GC.gc(true); CUDA.reclaim()

# `traindl` (defined elsewhere) yields (images, string labels) batches
for epoch in 1:10
    @info epoch
    for (x, y) in traindl
        y = onehotbatch(y, ["normal", "abnormal"])
        x, y = x |> device, y |> device
        gs = Flux.gradient(ps) do
            ŷ = model(x)
            loss(ŷ, y)
        end
        Flux.Optimise.update!(opt, ps, gs)
    end
end
```
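
For reference, a short derivation of why the weight-decay comment in the code holds, assuming Flux 0.13's `WeightDecay(λ)`, which adds `λ .* θ` to each gradient before the next optimiser in the `Optimiser` chain (here `AdaBelief`) sees it. With $\mathcal{L}$ the training loss and $\theta$ the parameters, that modified gradient is exactly the gradient of the loss plus an L2 penalty:

$$
\nabla_\theta\left(\mathcal{L}(\theta) + \frac{\lambda}{2}\lVert\theta\rVert_2^2\right) = \nabla_\theta\mathcal{L}(\theta) + \lambda\,\theta
$$

Because the term is added before AdaBelief rescales the gradient, it is rescaled too, which is what makes this L2 regularization rather than decoupled weight decay.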

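As a first step, you could wrap `traindl` in `CUDA.CuIterator`. That should relieve some of the memory pressure from data loading, since `CuIterator` frees the GPU copy of each batch as soon as the next batch is requested instead of leaving it for the garbage collector.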
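A minimal sketch of that suggestion, assuming `traindl` yields `(images, labels)` batches with string labels, as in the original post. `prepare_batch` and `train_epoch!` are hypothetical helper names, not part of Flux or CUDA.jl. The labels are one-hot encoded on the CPU and converted to a dense `Float32` matrix first, because `CuIterator` moves every array in a batch to the GPU and a `Vector{String}` cannot be stored in a `CuArray`.

```julia
using Flux, CUDA
using Flux: onehotbatch
using Flux.Losses: logitcrossentropy

# Class names taken from the original post.
const LABELS = ["normal", "abnormal"]

# Hypothetical helper: keep the images as they are, but turn the string labels
# into a dense Float32 one-hot matrix on the CPU, so every element of the
# returned tuple has a GPU-compatible element type.
prepare_batch((x, y)) = (x, Float32.(onehotbatch(y, LABELS)))

# Hypothetical helper: one epoch of the same training loop as in the post,
# with the data loader wrapped in CUDA.CuIterator.
function train_epoch!(model, opt, ps, traindl)
    # CuIterator uploads each batch to the GPU and frees the previous batch's
    # device arrays when the next batch is requested, instead of waiting
    # for the garbage collector to reclaim them.
    for (x, y) in CUDA.CuIterator(prepare_batch(b) for b in traindl)
        gs = Flux.gradient(ps) do
            logitcrossentropy(model(x), y)
        end
        Flux.Optimise.update!(opt, ps, gs)
    end
    return nothing
end
```

Usage, with `model`, `opt` and `ps` set up as in the original post (model already moved to the GPU): `for epoch in 1:10; train_epoch!(model, opt, ps, traindl); end`.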


It seems that the ImageNet example I found in ExplicitFluxLayers.jl (`examples/ImageNet` on the `main` branch of avik-pal/ExplicitFluxLayers.jl on GitHub) addresses some of the GPU memory problems with Flux (it also uses MPI, and it uses your suggestion too). Its source code is also very well structured.
