Flux - Batch data loop in callback causing GPU Memory Error

I’m trying to build a general tutorial for the basic build of an image classifier in flux on a Colab notebook… The dataset is Fashion MNIST. Building it out more and more I have found a problem with my code,

Recently I’ve tried a batch data function and loop minibatches function that my idea is should help cut down on GPU memory if I need it in the future. I understand that this is unnecessary for the current learning task but it can be generally applicable if the size of the input dataset grows larger. The training loop batching in training data will work but if I use mini batches in the callback function I seem to get an out of Memory error. I can try one of 2 things as I see it:

A. I don’t include turning val set into minibatches (commented code in the code below).
B. Fix the issue causing out of memory error.

I would prefer B. What can I do if anything to improve this example and pass over batches of data in the validation and testing set?

Code:

N = size(train_x)[end]
# random permute train indexes
ixs = collect(1:N)
shuffle!(ixs)
n = Int(floor(.8 * N))
# batch size
bs = 100
sz = (28, 28, 1, bs)

# 80/20 hold out split
train_split, val_split = ixs[1:n], ixs[n + 1:end]
# data = train_x[:, :, train_split], train_y[train_split]
train_data = train_x[:, :, train_split], train_y[train_split]
val_data = train_x[:, :, val_split], train_y[val_split]
'''
Make batches of x, y data

returns: list of tuples
'''
function make_batches(data, bs=100)
    n = size(data[1])[end]
    sz = (28, 28, 1, bs)
    iter = gpu.([(reshape(Float32.(data[1][:, :, i]), sz), onehotbatch(data[2][i], 0:9)) for i in partition(1:n, bs)])
end

train = make_batches(train_data)
val = make_batches(val_data)

# train = gpu.([(reshape(Float32.(data[1][:, :, i]), sz), onehotbatch(data[2][i], 0:9)) for i in partition(1:n, bs)])
# val_x = reshape(Float32.(train_x[:, :, val_split]), (28, 28, 1, length(val_split))) |> gpu;
# val_y = onehotbatch(train_y[val_split], 0:9) |> gpu;

Testing Model

In:

img = reshape(train[1][1][:, :, :, 1], (28, 28, 1, 1))
# img = reshape(val_x[:, :, :, 1], (28, 28, 1, 1))

out:

[:, :, 1, 1] =
 0.0        0.0       0.0         …  0.0       0.0        0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     
 0.0        0.0       0.0         …  0.501961  0.529412   0.180392
 0.0        0.0       0.0            0.454902  0.529412   0.490196
 0.0        0.403922  0.447059       0.721569  0.490196   0.341176
 0.0        0.576471  0.733333       0.462745  0.231373   0.0     
 0.0705882  0.658824  0.768628       0.211765  0.0        0.0     
 0.447059   0.713726  0.764706    …  0.682353  0.0        0.0     
 0.372549   0.803922  0.701961       0.478431  0.0        0.0     
 0.0        0.968628  0.701961       0.521569  0.0        0.0     
 ⋮                                ⋱  ⋮                            
 0.313726   0.882353  0.729412       0.580392  0.0196078  0.0     
 0.317647   0.721569  0.686275       0.670588  0.0313726  0.0     
 0.258824   0.701961  0.862745       0.517647  0.0        0.0     
 0.184314   0.639216  0.819608       0.623529  0.0        0.0     
 0.0        0.572549  0.12549     …  0.717647  0.643137   0.0     
 0.0        0.0       0.0            0.454902  0.580392   0.435294
 0.0        0.0       0.00784314     0.811765  0.745098   0.294118
 0.0        0.0       0.0            0.352941  0.227451   0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     
 0.0        0.0       0.0         …  0.0       0.0        0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     
 0.0        0.0       0.0            0.0       0.0        0.0     

In:

m = model()

Out:

Chain(Conv((5, 5), 1=>64, elu), BatchNorm(64), MaxPool((3, 3), pad = (2, 2), stride = (2, 2)), Dropout(0.25), Conv((5, 5), 64=>128, elu), BatchNorm(128), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Dropout(0.25), Conv((5, 5), 128=>256, elu), BatchNorm(256), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), Dropout(0.25), #10, Dense(2304, 256, elu), Dropout(0.5), Dense(256, 10), softmax)

In:

m(img)

Out:

10×1 CuArray{Float32,2,Nothing}:
 0.09464545 
 0.09338491 
 0.09972584 
 0.107280605
 0.10379242 
 0.104369685
 0.11624282 
 0.10718765 
 0.07197585 
 0.101394825

Train on loss function for one Epoch

In:

loss(x, y) = crossentropy(m(x), y)

# accuracy(x, y) = mean(onecold(m(x), 0:9) .== onecold(y, 0:9))

function accuracy(data)
    N = length(data)
    acc = sum([mean(onecold(m(x), 0:9) .== onecold(y, 0:9)) for (x, y) in data]) / N
end


# Defining the callback and the optimizer
# evalcb = throttle(() -> @show(accuracy(val_x, val_y)), 0.1)
evalcb = throttle(() -> @show(accuracy(val)), 1)

opt = ADAM()

# Starting to train models
Flux.train!(loss, params(m), train, opt, cb = evalcb)

Out:

┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays /root/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:16
accuracy(val) = 0.1685
CUDA error: out of memory (code 2, ERROR_OUT_OF_MEMORY)

Stacktrace:
 [1] throw_api_error(::CUDAdrv.cudaError_enum) at /root/.julia/packages/CUDAdrv/mCr0O/src/error.jl:136
 [2] macro expansion at /root/.julia/packages/CUDAdrv/mCr0O/src/error.jl:149 [inlined]
 [3] cuModuleLoadDataEx(::Base.RefValue{Ptr{Nothing}}, ::Ptr{UInt8}, ::Int64, ::Array{CUDAdrv.CUjit_option_enum,1}, ::Array{Ptr{Nothing},1}) at /root/.julia/packages/CUDAdrv/mCr0O/src/libcuda.jl:232
 [4] macro expansion at ./gcutils.jl:91 [inlined]
 [5] CuModule(::String, ::Dict{CUDAdrv.CUjit_option_enum,Any}) at /root/.julia/packages/CUDAdrv/mCr0O/src/module.jl:34
 [6] macro expansion at /root/.julia/packages/CUDAnative/Phjco/src/execution.jl:422 [inlined]
 [7] #cufunction#200(::String, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(cufunction), ::CuArrays.var"#kernel#34"{CuArrays.var"#48#49"{Float32}}, ::Type{Tuple{CuDeviceArray{Int64,1,CUDAnative.AS.Global},CuDeviceArray{Float32,1,CUDAnative.AS.Global}}}) at /root/.julia/packages/CUDAnative/Phjco/src/execution.jl:359
 [8] (::CUDAnative.var"#kw##cufunction")(::NamedTuple{(:name,),Tuple{String}}, ::typeof(cufunction), ::Function, ::Type) at ./none:0
 [9] findfirst(::CuArrays.var"#48#49"{Float32}, ::CuArray{Float32,1,Nothing}) at /root/.julia/packages/CUDAnative/Phjco/src/execution.jl:176
 [10] #findmax#47(::Function, ::typeof(findmax), ::CuArray{Float32,1,Nothing}) at /root/.julia/packages/CuArrays/rNxse/src/indexing.jl:240
 [11] #findmax at ./none:0 [inlined]
 [12] #argmax#616 at ./reducedim.jl:882 [inlined]
 [13] argmax at ./reducedim.jl:882 [inlined]
 [14] onecold at /root/.julia/packages/Flux/oX9Pi/src/onehot.jl:120 [inlined]
 [15] #25 at /root/.julia/packages/Flux/oX9Pi/src/onehot.jl:123 [inlined]
 [16] inner_mapslices!(::Bool, ::Base.Iterators.Drop{CartesianIndices{1,Tuple{Base.OneTo{Int64}}}}, ::Int64, ::Array{Any,1}, ::Array{Int64,1}, ::Array{Any,1}, ::CuArray{Float32,1,Nothing}, ::CuArray{Float32,2,Nothing}, ::Flux.var"#25#26"{Tuple{UnitRange{Int64}}}, ::CuArray{Int64,2,Nothing}) at ./abstractarray.jl:2040
 [17] #mapslices#109(::Int64, ::typeof(mapslices), ::Flux.var"#25#26"{Tuple{UnitRange{Int64}}}, ::CuArray{Float32,2,Nothing}) at ./abstractarray.jl:2030
 [18] #mapslices at ./none:0 [inlined]
 [19] onecold(::CuArray{Float32,2,Nothing}, ::UnitRange{Int64}) at /root/.julia/packages/Flux/oX9Pi/src/onehot.jl:122
 [20] (::var"#16#17")(::Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}}) at ./none:0
 [21] iterate at ./generator.jl:47 [inlined]
 [22] collect_to!(::Array{Float64,1}, ::Base.Generator{Array{Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}},1},var"#16#17"}, ::Int64, ::Int64) at ./array.jl:667
 [23] collect_to_with_first!(::Array{Float64,1}, ::Float64, ::Base.Generator{Array{Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}},1},var"#16#17"}, ::Int64) at ./array.jl:646
 [24] collect(::Base.Generator{Array{Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}},1},var"#16#17"}) at ./array.jl:627
 [25] accuracy(::Array{Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}},1}) at ./In[34]:7
 [26] macro expansion at ./show.jl:562 [inlined]
 [27] (::var"#18#19")() at ./In[34]:13
 [28] (::Flux.var"#throttled#10#15"{Bool,Bool,var"#18#19",Int64})(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Flux.var"#throttled#14"{Flux.var"#throttled#10#15"{Bool,Bool,var"#18#19",Int64}}) at /root/.julia/packages/Flux/oX9Pi/src/utils.jl:125
 [29] throttled at /root/.julia/packages/Flux/oX9Pi/src/utils.jl:121 [inlined]
 [30] macro expansion at /root/.julia/packages/Flux/oX9Pi/src/optimise/train.jl:72 [inlined]
 [31] macro expansion at /root/.julia/packages/Juno/oLB1d/src/progress.jl:134 [inlined]
 [32] #train!#12(::Flux.var"#throttled#14"{Flux.var"#throttled#10#15"{Bool,Bool,var"#18#19",Int64}}, ::typeof(Flux.Optimise.train!), ::Function, ::Zygote.Params, ::Array{Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}},1}, ::ADAM) at /root/.julia/packages/Flux/oX9Pi/src/optimise/train.jl:66
 [33] (::Flux.Optimise.var"#kw##train!")(::NamedTuple{(:cb,),Tuple{Flux.var"#throttled#14"{Flux.var"#throttled#10#15"{Bool,Bool,var"#18#19",Int64}}}}, ::typeof(Flux.Optimise.train!), ::Function, ::Zygote.Params, ::Array{Tuple{CuArray{Float32,4,Nothing},Flux.OneHotMatrix{CuArray{Flux.OneHotVector,1,Nothing}}},1}, ::ADAM) at ./none:0
 [34] top-level scope at In[34]:16