Flux.jl vanilla ANN loss goes to NaN with mini-batching

I’m building a vanilla NN with 1 or 2 hidden layers for a regression problem that’s basically as follows:
I take N sets of known values (x_1, y_1), ..., (x_k, y_k) and feed them both to an approximation solver, which returns N sets of (a_1, b_1), ..., (a_k, b_k), and to an exact solver, which returns N sets of true values (c_1, d_1), ..., (c_k, d_k). Here a_i \approx c_i and b_i \approx d_i for i = 1, 2, ..., k, but the approximations can have large inaccuracies. I’m building an ANN that takes an N x (k*4) input, i.e. N samples of the form [x_1, y_1, a_1, b_1, x_2, y_2, a_2, b_2, ...], and outputs [c_1', d_1', c_2', d_2', ...] that are better approximations to the true c, d values than the original a, b approximations.

I’m using the ADAM optimizer and, naturally, my loss function is MSE; the model is simply two Dense layers chained together. Everything worked well until I tried a large input size, with k*4 \approx 6000 and N = 10000: training the full batch on the GPU (an NVIDIA GTX 960M, pretty trash, but that’s all I have for now) gave me an OutOfMemoryError(). The largest N for which full-batch training runs without the error is about 2500. So I decided to try mini-batching the input and feeding one mini-batch to the GPU per iteration, as follows:

using Flux, Random   # Tracker (for Tracker.data) comes with this pre-Zygote version of Flux

randIdx = collect(1:size(trainData, 2))    # trainData has shape k*4 x N (one sample per column)
numBatches = floor(Int, size(trainData, 2) / batch_size)
epochTrainLoss, epochTrainAcc = 0.0, 0.0   # per-epoch accumulators (reset at the end of each epoch)
for epoch = 1:epochs
	println("epoch: ", epoch)
	Random.shuffle!(randIdx)  # to shuffle training set
	i = 1  
	for j = 1:numBatches
		println(j)
		batchData = trainData[:, randIdx[i:batch_size]] |> gpu
		batchTarget = trainTarget[:, randIdx[i:batch_size]] |> gpu
		Flux.train!(loss, Flux.params(model), [(batchData, batchTarget)], opt)
		epochTrainLoss += Tracker.data(loss(batchData, batchTarget))
		epochTrainAcc += Tracker.data(accuracy(batchData, batchTarget))
		i += batch_size
	end
	push!(trainLoss, epochTrainLoss / numBatches)
	push!(trainAcc, epochTrainAcc / numBatches)
	push!(valLoss, Tracker.data(loss(valData, valTarget)))
	push!(valAcc, Tracker.data(accuracy(valData, valTarget)))
	epochTrainLoss, epochTrainAcc = 0.0, 0.0   # reset values
end
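
For context, the model, loss function, and optimizer used above are set up roughly like this; the hidden width and learning rate below are placeholders rather than my exact values, and accuracy is just a custom regression metric that I’ve omitted:

# Sketch of the setup; the layer width (512) and learning rate are placeholders.
nIn  = size(trainData, 1)      # k*4 input features
nOut = size(trainTarget, 1)    # 2k outputs
model = Chain(Dense(nIn, 512, relu), Dense(512, nOut)) |> gpu
loss(x, y) = Flux.mse(model(x), y)
opt = ADAM(1e-3)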

However, now, at the second iteration of the first epoch, I’m getting a "Loss is NaN" error:

julia> include("case_general.jl")
epoch: 1
1
2.11837100982666
2
ERROR: LoadError: Loss is NaN
Stacktrace:
 [1] losscheck(::Tracker.TrackedReal{Float32}) at C:\Users\me\.julia\packages\Tracker\RRYy6\src\back.jl:155
 [2] gradient_(::getfield(Flux.Optimise, Symbol("##14#20")){getfield(Main, Symbol("#loss#256"))}, ::Tracker.Params) at C:\Users\me\.julia\packages\Tracker\RRYy6\src\back.jl:98
 [3] #gradient#24(::Bool, ::Function, ::Function, ::Tracker.Params) at C:\Users\me\.julia\packages\Tracker\RRYy6\src\back.jl:164
 [4] gradient at C:\Users\me\.julia\packages\Tracker\RRYy6\src\back.jl:164 [inlined]
 [5] macro expansion at C:\Users\me\.julia\packages\Flux\qXNjB\src\optimise\train.jl:71 [inlined]
 [6] macro expansion at C:\Users\me\.julia\packages\Juno\TfNYn\src\progress.jl:124 [inlined]
 [7] #train!#12(::getfield(Flux.Optimise, Symbol("##16#22")), ::Function, ::Function, ::Tracker.Params, ::Array{Tuple{CuArray{Float32,2},CuArray{Float32,2}},1}, ::ADAM) at C:\Users\me\.julia\package\train.jl:69
 [8] train!(::Function, ::Tracker.Params, ::Array{Tuple{CuArray{Float32,2},CuArray{Float32,2}},1}, ::ADAM) at C:\Users\me\.julia\packages\Flux\qXNjB\src\optimise\train.jl:67
 [9] #mlp#254(::Int64, ::Function, ::String, ::Array{Float64,2}, ::Array{Float64,2}, ::Float64, ::Int64, ::Int64, ::Int64) at F:\work\large cases\case_general.jl:156
 [10] (::getfield(Main, Symbol("#kw##mlp")))(::NamedTuple{(:K2,),Tuple{Int64}}, ::typeof(mlp), ::String, ::Array{Float64,2}, ::Array{Float64,2}, ::Float64, ::Int64, ::Int64, ::Int64) at .\none:0
 [11] macro expansion at .\util.jl:156 [inlined]
 [12] main(::Int64) at F:\work\large cases\case_general.jl:221
 [13] top-level scope at none:0
 [14] include at .\boot.jl:326 [inlined]
 [15] include_relative(::Module, ::String) at .\loading.jl:1038
 [16] include(::Module, ::String) at .\sysimg.jl:29
 [17] include(::String) at .\client.jl:403
 [18] top-level scope at none:0
in expression starting at F:\work\large cases\case_general.jl:232

I looked at this link on NaN losses and tried setting my learning rate to something extremely small (1e-15) and increasing my batch size to 2500 (so that it should behave the same as full-batch training with 2500 samples), but neither helped; only the first-iteration loss changed slightly (it stays around 2.1). I also checked my input and target arrays and made sure none of their elements is NaN or Inf. Any advice on how I can track down where the loss value is exploding?
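
For what it’s worth, the check on the raw arrays was roughly this (run on the CPU arrays, before any training):

# Sanity check: none of these should (and none did) return true.
@show any(x -> isnan(x) || isinf(x), trainData)
@show any(x -> isnan(x) || isinf(x), trainTarget)
@show any(x -> isnan(x) || isinf(x), valData)
@show any(x -> isnan(x) || isinf(x), valTarget)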

Update
I tried retraining much smaller models (k*4 \approx 80) with the mini-batching code above and I’m still getting the exact same error (loss is NaN at the second iteration of the first epoch). I’m going to investigate further tomorrow, but if anyone spots an obvious error, please do let me know. Thanks!

Update 2
Yep, I was an idiot and forgot to offset the slice by i: trainData[:, randIdx[i:batch_size]] should have been trainData[:, randIdx[i:i+batch_size-1]] (and likewise for trainTarget). With the buggy version, every batch after the first slices the empty range i:batch_size (since i > batch_size by then), so the MSE is computed over an empty batch and 0/0 produces the NaN. Now I just need to deal with the OutOfMemoryError().
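
For anyone who hits the same thing, the corrected slicing in the loop looks like this:

# Fixed batch slicing: i advances by batch_size each iteration, as in the loop above.
cols = randIdx[i:i+batch_size-1]
batchData   = trainData[:, cols]   |> gpu
batchTarget = trainTarget[:, cols] |> gpu

An alternative would be to iterate over Iterators.partition(Random.shuffle(1:size(trainData, 2)), batch_size) and skip the manual index bookkeeping entirely.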