I’m running into a problem where memory usage gradually increases with each epoch when training large neural networks with Flux (v0.14.22) and CUDA (v5.5.2). Training also gets progressively slower as memory usage grows, and GPU utilization drops along with it: at the start of training I see fairly constant utilization of around 70%, which falls to roughly 40% after a few epochs.
Upon further investigation, I found that this problem only seems to affect convolutional networks, while vision transformers are largely unaffected. In an attempt to track down the issue, I created the following MWE:
using Flux, Metalhead, Random, Statistics, CUDA, cuDNN, Match
using Pipe: @pipe

# Build one of three Metalhead backbones with a small binary-classification head.
function build_model(config::Symbol)
    @match config begin
        :ResNet => Flux.Chain(
            Metalhead.ResNet(18).layers[1],
            Flux.GlobalMeanPool(),
            Flux.MLUtils.flatten,
            Flux.Dense(512 => 1, sigmoid))
        :MobileNet => Flux.Chain(
            Metalhead.MobileNetv3(:small).layers[1],
            Flux.GlobalMeanPool(),
            Flux.MLUtils.flatten,
            Flux.Dense(576 => 1024, hardswish),
            Flux.Dropout(0.2),
            Flux.Dense(1024 => 1, sigmoid))
        :ViT => Flux.Chain(
            Metalhead.ViT(:tiny, pretrain=false, patch_size=(16, 16)).layers[1],
            Flux.LayerNorm(192),
            Flux.Dense(192 => 1, Flux.sigmoid))
    end
end

loss(model, x, y) = @pipe model(x) |> Flux.binarycrossentropy(_, y)

# Free host (CPU) memory in GiB.
free_memory() = round(Sys.free_memory() / 2^30, digits=2)

# Synthetic data: 10,000 random 224x224 RGB images with random binary labels.
imgs = rand(Float32, 224, 224, 3, 10000)
labels = rand([0.0f0, 1.0f0], 1, 10000)

# Channel-wise normalisation.
μ = mean(imgs, dims=(1, 2, 4))
σ = std(imgs, dims=(1, 2, 4))
norm_imgs = (imgs .- μ) ./ σ

data = Flux.DataLoader((norm_imgs, labels), batchsize=16, shuffle=true, buffer=true)

model = build_model(:ViT) |> Flux.gpu  # swap :ViT for :ResNet or :MobileNet to reproduce the other runs
opt = Flux.Optimisers.Adam()
opt_state = Flux.Optimisers.setup(opt, model)

for epoch in 1:10
    for (x, y) in CUDA.CuIterator(data)
        grads = Flux.gradient(m -> loss(m, x, y), model)
        Flux.Optimisers.update!(opt_state, model, grads[1])
    end
    @info free_memory()  # log free host memory at the end of each epoch
end
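Note that free_memory() above tracks free host RAM rather than device memory. If it helps with diagnosis, free GPU memory could be logged as well; a hypothetical addition (not used for the numbers that follow) might look like this:

gpu_free_memory() = round(CUDA.available_memory() / 2^30, digits=2)  # hypothetical helper: free device memory in GiB
# inside the epoch loop:
#     @info "memory (GiB)" host=free_memory() gpu=gpu_free_memory()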
Here are the results for each of the model architectures (free host memory in GiB, logged at the end of each epoch):

Epoch    ViT      ResNet18   MobileNet
1        46.99    42.13      37.89
2        47.0     41.74      37.02
3        46.96    41.44      36.28
4        46.93    41.08      35.42
5        46.91    40.7       34.62
6        46.88    40.38      33.88
7        46.89    40.04      33.04
8        46.94    39.66      32.25
9        46.86    39.36      31.45
10       46.7     38.97      30.65
As the table shows, free memory stays roughly flat across epochs for the ViT, whereas it drops steadily for ResNet18 and MobileNet, i.e. their memory usage keeps growing as training progresses. It looks like I’m not the first person to report this issue, but none of the proposed solutions appear to be working for me. Does anyone have any idea what could be causing this?
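For completeness, the per-epoch cleanup that those reports usually suggest looks roughly like the sketch below (assuming a host GC pass followed by CUDA.reclaim is what is meant); it has not changed the trend for me:

# Commonly suggested workaround (sketch): run a host GC pass and ask CUDA.jl
# to release cached device memory at the end of every epoch.
for epoch in 1:10
    for (x, y) in CUDA.CuIterator(data)
        grads = Flux.gradient(m -> loss(m, x, y), model)
        Flux.Optimisers.update!(opt_state, model, grads[1])
    end
    GC.gc(false)     # incremental garbage collection on the host
    CUDA.reclaim()   # try to release unused memory held by CUDA's memory pool
    @info free_memory()
end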