Flux - Batch data loop in callback causing GPU Memory Error

I’m putting together a general tutorial for building a basic image classifier in Flux on a Colab notebook. The dataset is Fashion MNIST. As I’ve been building it out, I’ve run into a problem with my code.

Recently I added a function that batches the data and a loop over minibatches, the idea being that this should help cut down on GPU memory use if I need it in the future. I understand this is unnecessary for the current learning task, but it becomes generally applicable as the input dataset grows larger. Batching the training data in the training loop works, but if I also use minibatches in the callback function I get an out-of-memory error. As I see it, I can try one of two things:

A. Don’t turn the validation set into minibatches (the commented-out code below).
B. Fix whatever is causing the out-of-memory error.

I would prefer B. What, if anything, can I do to improve this example and pass batches of data over the validation and test sets?

Code:

using Flux, Statistics, Random
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: partition

# number of training images
N = size(train_x)[end]
# randomly permute the training indexes
ixs = collect(1:N)
shuffle!(ixs)
n = Int(floor(.8 * N))
# batch size and the shape of one batch
bs = 100
sz = (28, 28, 1, bs)

# 80/20 hold out split
train_split, val_split = ixs[1:n], ixs[n + 1:end]
# data = train_x[:, :, train_split], train_y[train_split]
train_data = train_x[:, :, train_split], train_y[train_split]
val_data = train_x[:, :, val_split], train_y[val_split]
"""
Make batches of (x, y) data.

Returns a vector of (x, y) tuples, one per minibatch, moved to the GPU.
"""
function make_batches(data, bs=100)
    n = size(data[1])[end]
    sz = (28, 28, 1, bs)
    return gpu.([(reshape(Float32.(data[1][:, :, i]), sz), onehotbatch(data[2][i], 0:9)) for i in partition(1:n, bs)])
end

train = make_batches(train_data)
val = make_batches(val_data)

# train = gpu.([(reshape(Float32.(data[1][:, :, i]), sz), onehotbatch(data[2][i], 0:9)) for i in partition(1:n, bs)])
# val_x = reshape(Float32.(train_x[:, :, val_split]), (28, 28, 1, length(val_split))) |> gpu;
# val_y = onehotbatch(train_y[val_split], 0:9) |> gpu;
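
One direction for option B, sketched under the same data layout as above (the names `make_cpu_batches` and `accuracy_batched` are purely illustrative, not from the original notebook): keep the minibatches in host memory and move only the batch currently being evaluated to the GPU, so the full validation set never occupies GPU memory at once. `onecold` is applied to CPU copies of the outputs here to avoid the scalar-indexing path on CuArrays.

# Sketch only: same batching as make_batches, but the batches stay on the CPU.
function make_cpu_batches(data, bs=100)
    n = size(data[1])[end]  # assumes n is divisible by bs, as above
    sz = (28, 28, 1, bs)
    return [(reshape(Float32.(data[1][:, :, i]), sz), onehotbatch(data[2][i], 0:9))
            for i in partition(1:n, bs)]
end

val_cpu = make_cpu_batches(val_data)

# Evaluate one batch at a time, transferring each batch to the GPU on demand
# and bringing the predictions back to the CPU before onecold.
function accuracy_batched(data)
    accs = Float32[]
    for (x, y) in data
        ŷ = cpu(m(gpu(x)))
        push!(accs, mean(onecold(ŷ, 0:9) .== onecold(y, 0:9)))
    end
    return mean(accs)
end

# e.g. evalcb = throttle(() -> @show(accuracy_batched(val_cpu)), 1)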

Testing Model

In:

img = reshape(train[1][1][:, :, :, 1], (28, 28, 1, 1))
# img = reshape(val_x[:, :, :, 1], (28, 28, 1, 1))

Out:

[:, :, 1, 1] =
 0.0        0.0       0.0         …  0.0       0.0        0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     
 0.0        0.0       0.0         …  0.501961  0.529412   0.180392
 0.0        0.0       0.0            0.454902  0.529412   0.490196
 0.0        0.403922  0.447059       0.721569  0.490196   0.341176
 0.0        0.576471  0.733333       0.462745  0.231373   0.0     
 0.0705882  0.658824  0.768628       0.211765  0.0        0.0     
 0.447059   0.713726  0.764706    …  0.682353  0.0        0.0     
 0.372549   0.803922  0.701961       0.478431  0.0        0.0     
 0.0        0.968628  0.701961       0.521569  0.0        0.0     
 ⋮                                ⋱  ⋮                            
 0.313726   0.882353  0.729412       0.580392  0.0196078  0.0     
 0.317647   0.721569  0.686275       0.670588  0.0313726  0.0     
 0.258824   0.701961  0.862745       0.517647  0.0        0.0     
 0.184314   0.639216  0.819608       0.623529  0.0        0.0     
 0.0        0.572549  0.12549     …  0.717647  0.643137   0.0     
 0.0        0.0       0.0            0.454902  0.580392   0.435294
 0.0        0.0       0.00784314     0.811765  0.745098   0.294118
 0.0        0.0       0.0            0.352941  0.227451   0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     
 0.0        0.0       0.0         …  0.0       0.0        0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     

In:

# model() is defined earlier in the notebook; it returns the Chain shown below
m = model()

Out:

Chain(Conv((5, 5), 1=>64, elu), BatchNorm(64), MaxPool((3, 3), pad = (2, 2), stride = (2, 2)), Dropout(0.25), Conv((5, 5), 64=>128, elu), BatchNorm(128), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Dropout(0.25), Conv((5, 5), 128=>256, elu), BatchNorm(256), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Dropout(0.25), #10, Dense(2304, 256, elu), Dropout(0.5), Dense(256, 10), softmax)

In:

m(img)

Out:

10×1 CuArray{Float32,2,Nothing}:
 0.09464545 
 0.09338491 
 0.09972584 
 0.107280605
 0.10379242 
 0.104369685
 0.11624282 
 0.10718765 
 0.07197585 
 0.101394825

Train with the loss function for one epoch

In:

loss(x, y) = crossentropy(m(x), y)

# accuracy(x, y) = mean(onecold(m(x), 0:9) .== onecold(y, 0:9))

# Mean accuracy over all minibatches in `data`
function accuracy(data)
    N = length(data)
    acc = sum([mean(onecold(m(x), 0:9) .== onecold(y, 0:9)) for (x, y) in data]) / N
end


# Defining the callback and the optimizer
# evalcb = throttle(() -> @show(accuracy(val_x, val_y)), 0.1)
evalcb = throttle(() -> @show(accuracy(val)), 1)

opt = ADAM()

# Starting to train models
Flux.train!(loss, params(m), train, opt, cb = evalcb)

Out:

┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays /root/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:16
accuracy(val) = 0.1685
CUDA error: out of memory (code 2, ERROR_OUT_OF_MEMORY)

Stacktrace:
 [1] throw_api_error(::CUDAdrv.cudaError_enum) at /root/.julia/packages/CUDAdrv/mCr0O/src/error.jl:136
 [2] macro expansion at /root/.julia/packages/CUDAdrv/mCr0O/src/error.jl:149 [inlined]
 [3] cuModuleLoadDataEx(::Base.RefValue{Ptr{Nothing}}, ::Ptr{UInt8}, ::Int64, ::Array{CUDAdrv.CUjit_option_enum,1}, ::Array{Ptr{Nothing},1}) at /root/.julia/packages/CUDAdrv/mCr0O/src/libcuda.jl:232
 [4] macro expansion at ./gcutils.jl:91 [inlined]
 [5] CuModule(::String, ::Dict{CUDAdrv.CUjit_option_enum,Any}) at /root/.julia/packages/CUDAdrv/mCr0O/src/module.jl:34
 [6] macro expansion at /root/.julia/packages/CUDAnative/Phjco/src/execution.jl:422 [inlined]
 [7] #cufunction#200(::String, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(cufunction), ::CuArrays.var"#kernel#34"{CuArrays.var"#48#49"{Float32}}, ::Type{Tuple{CuDeviceArray{Int64,1,CUDAnative.AS.Global},CuDeviceArray{Float32,1,CUDAnative.AS.Global}}}) at /root/.julia/packages/CUDAnative/Phjco/src/execution.jl:359
 [8] (::CUDAnative.var"#kw##cufunction")(::NamedTuple{(:name,),Tuple{String}}, ::typeof(cufunction), ::Function, ::Type) at ./none:0
 [9] findfirst(::CuArrays.var"#48#49"{Float32}, ::CuArray{Float32,1,Nothing}) at /root/.julia/packages/CUDAnative/Phjco/src/execution.jl:176
 [10] #findmax#47(::Function, ::typeof(findmax), ::CuArray{Float32,1,Nothing}) at /root/.julia/packages/CuArrays/rNxse/src/indexing.jl:240
 [11] #findmax at ./none:0 [inlined]
 [12] #argmax#616 at ./reducedim.jl:882 [inlined]
 [13] argmax at ./reducedim.jl:882 [inlined]
 [14] onecold at /root/.julia/packages/Flux/oX9Pi/src/onehot.jl:120 [inlined]
 [15] #25 at /root/.julia/packages/Flux/oX9Pi/src/onehot.jl:123 [inlined]
 [16] inner_mapslices!(::Bool, ::Base.Iterators.Drop{CartesianIndices{1,Tuple{Base.OneTo{Int64}}}}, ::Int64, ::Array{Any,1}, ::Array{Int64,1}, ::Array{Any,1}, ::CuArray{Float32,1,Nothing}, ::CuArray{Float32,2,Nothing}, ::Flux.var"#25#26"{Tuple{UnitRange{Int64}}}, ::CuArray{Int64,2,Nothing}) at ./abstractarray.jl:2040
 [17] #mapslices#109(::Int64, ::typeof(mapslices), ::Flux.var"#25#26"{Tuple{UnitRange{Int64}}}, ::CuArray{Float32,2,Nothing}) at ./abstractarray.jl:2030
 [18] #mapslices at ./none:0 [inlined]
 [19] onecold(::CuArray{Float32,2,Nothing}, ::UnitRange{Int64}) at /root/.julia/packages/Flux/oX9Pi/src/onehot.jl:122
 [20] (::var"#16#17")(::Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}}) at ./none:0
 [21] iterate at ./generator.jl:47 [inlined]
 [22] collect_to!(::Array{Float64,1}, ::Base.Generator{Array{Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}},1},var"#16#17"}, ::Int64, ::Int64) at ./array.jl:667
 [23] collect_to_with_first!(::Array{Float64,1}, ::Float64, ::Base.Generator{Array{Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}},1},var"#16#17"}, ::Int64) at ./array.jl:646
 [24] collect(::Base.Generator{Array{Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}},1},var"#16#17"}) at ./array.jl:627
 [25] accuracy(::Array{Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}},1}) at ./In[34]:7
 [26] macro expansion at ./show.jl:562 [inlined]
 [27] (::var"#18#19")() at ./In[34]:13
 [28] (::Flux.var"#throttled#10#15"{Bool,Bool,var"#18#19",Int64})(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Flux.var"#throttled#14"{Flux.var"#throttled#10#15"{Bool,Bool,var"#18#19",Int64}}) at /root/.julia/packages/Flux/oX9Pi/src/utils.jl:125
 [29] throttled at /root/.julia/packages/Flux/oX9Pi/src/utils.jl:121 [inlined]
 [30] macro expansion at /root/.julia/packages/Flux/oX9Pi/src/optimise/train.jl:72 [inlined]
 [31] macro expansion at /root/.julia/packages/Juno/oLB1d/src/progress.jl:134 [inlined]
 [32] #train!#12(::Flux.var"#throttled#14"{Flux.var"#throttled#10#15"{Bool,Bool,var"#18#19",Int64}}, ::typeof(Flux.Optimise.train!), ::Function, ::Zygote.Params, ::Array{Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}},1}, ::ADAM) at /root/.julia/packages/Flux/oX9Pi/src/optimise/train.jl:66
 [33] (::Flux.Optimise.var"#kw##train!")(::NamedTuple{(:cb,),Tuple{Flux.var"#throttled#14"{Flux.var"#throttled#10#15"{Bool,Bool,var"#18#19",Int64}}}}, ::typeof(Flux.Optimise.train!), ::Function, ::Zygote.Params, ::Array{Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}},1}, ::ADAM) at ./none:0
 [34] top-level scope at In[34]:16
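
Reading the trace, the error is thrown while onecold processes the CuArray output slice by slice via mapslices, which also appears to be the source of the scalar-indexing warning above. As that warning itself suggests, one quick way to make such fallbacks fail loudly at the offending call (a sketch; it assumes CuArrays is already loaded in the notebook) is:

using CuArrays

# Turn scalar indexing of GPU arrays into an error instead of a slow fallback,
# so the offending call is easy to pinpoint.
CuArrays.allowscalar(false)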

Closing this out as I answered my own question.

Now when the accuracy function is called, I get the correct value and no scalar iteration is used. The problem was the onecold function I was using from another tutorial; it’s terrible and no one should use it. The following code is what I came up with, and it seems to be performing well.

# Accuracy history and a global batch counter
eval_acc = []
batch_idx = 0

"""
Compute the accuracy metric on a data set.
Push the computed value to the eval_acc array.
Increment the batch index.
"""

function calc_metrics(data)
    global batch_idx
    acc = 0
    for batch in data
        x, y = batch
        pred = m(x) .> 0.5
        # confusion counts computed element-wise against the one-hot targets
        tp = Float32(sum((pred .+ y) .== Int16(2)))
        fp = Float32(sum((pred .- y) .== Int16(1)))
        fn = Float32(sum((pred .- y) .== Int16(-1)))
        tn = Float32(sum((pred .+ y) .== Int16(0)))
        acc += (tp + tn) / (tp + tn + fp + fn)
    end
    acc /= length(data)
    push!(eval_acc, acc)
    if batch_idx % 100 == 0
        @show(batch_idx)
    end
    
    batch_idx += 1
end

# Define the loss, callback and optimizer
loss(x, y) = crossentropy(m(x), y)
evalcb = () -> calc_metrics(val)
opt = ADAM()

# Start model training
Flux.train!(loss, params(m), train, opt, cb = evalcb)

The accuracy is computed for each mini-batch and pushed to an array for plotting. This could be smarter, and for more advanced examples I will improve on it, but I’m just glad it’s working as intended now.
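
For the plotting step, something like the following is enough (a sketch assuming Plots.jl is available; the axis labels are illustrative):

using Plots

# eval_acc holds one validation accuracy value per callback invocation
plot(eval_acc, xlabel = "callback call", ylabel = "validation accuracy", legend = false)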

Did you figure out where the memory leak was?

I had not as of yesterday’s post, but I fixed it today and amended my post above.

You mean the onecold function from Flux? Care to be more specific then, and/or open an issue or pull request to improve its definition (or else the tutorial)?