Taking the gradient of a minibatch gradient with Flux

Hi, I’m trying to implement minibatching for a neural ODE. I compute the gradient for each batch with the following code:


using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots, LinearAlgebra, Distributions

# Reference trajectory: solve the true ODE problem (`prob_trueode`, defined elsewhere)
# and keep the solution at the `tsteps` time points as a plain array.
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

# Network for the ODE right-hand side: 2 inputs -> 50 tanh units -> 2 outputs.
dudt2 = FastChain(FastDense(2, 50, tanh), FastDense(50, 2))
# Neural ODE over the full `tspan`, saving at every `tsteps` point; its parameter
# vector `prob_neuralode.p` is what the REPL call passes in as θ.
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
# NOTE(review): `train_node` (span (0, 1), first 40 save points) is not used in the
# code shown here — confirm whether it is still needed.
train_node = NeuralODE(dudt2, (0., 1.), Tsit5(), saveat = tsteps[1:40])
# Training set: the first 250 columns of the reference data and their save times.
y_train = ode_data[:, 1:250]
training_span = tsteps[1:250]

"""
    predict(p, time_batch)

Integrate the neural ODE `dudt2` from the global initial state `u0` with parameters
`p` over `(first(tsteps), last(tsteps))`, saving only at the times in `time_batch`.
"""
function predict(p, time_batch)
    span = (first(tsteps), last(tsteps))
    node = NeuralODE(dudt2, span, Tsit5(); saveat = time_batch)
    return node(u0, p)
end

"""
    loss(p, batch, time_batch)

Sum of squared differences between the model prediction at `time_batch` (with
parameters `p`) and the observed `batch`.
"""
function loss(p, batch, time_batch)
    residual = predict(p, time_batch) .- batch
    return sum(abs2, residual)
end

# Shuffled minibatches of size 5, presumably pairing state columns of `y_train` with
# their times in `training_span`.
# NOTE(review): `train_loader` is not used in the code shown (the `estimategrad` call
# receives the whole training set). Newer Flux versions also expect the data as a
# tuple, `DataLoader((y_train, training_span); ...)` — confirm for the installed version.
train_loader = Flux.Data.DataLoader(y_train, training_span, batchsize = 5, shuffle = true)

"""
    estimategrad(θ, y, t_batch)

Return the average, over every time point in `t_batch`, of the gradient of `loss`
with respect to `θ`. Sample `i` pairs the column `y[:, i]` with the single time
`t_batch[i]`.

The per-sample gradients are combined with a non-mutating reduction rather than an
in-place accumulator (`grad .+= ...`). Zygote cannot differentiate through array
mutation, so the mutating form breaks any outer `gradient` call over this function.
The result's element type follows the gradients (e.g. `Float32` for `Float32`
parameters) instead of always being `Float64`.

Throws `ArgumentError` if `t_batch` is empty, since the average is then undefined.

NOTE(review): the result is a vector, so differentiating it calls for a Jacobian
rather than `gradient`. Doing so also nests reverse-mode AD through the ODE solve,
which may still fail or be slow with nested Zygote; forward-over-reverse may be
needed — confirm for your package versions.
"""
function estimategrad(θ, y, t_batch)
    isempty(t_batch) &&
        throw(ArgumentError("t_batch must contain at least one time point"))
    n = length(t_batch)
    # Gradient of the loss for sample i alone: one column of `y` at one time point.
    sample_grad(i) = first(gradient(p -> loss(p, y[:, i], [t_batch[i]]), θ))
    # `map` + `sum` build new arrays instead of mutating one, so Zygote can trace it.
    grads = map(sample_grad, eachindex(t_batch))
    return sum(grads) ./ n
end

I need to compute the gradient of `estimategrad` with respect to θ, but when I run:


 # Differentiate the averaged per-sample gradient with respect to the network parameters.
 # NOTE(review): `estimategrad` returns a vector, but Zygote's `gradient` expects a
 # scalar-valued function; a Jacobian (or a Hessian of the mean loss) is probably
 # what is needed here — confirm.
 gradient(p -> estimategrad(p, y_train, training_span), prob_neuralode.p)
ERROR: BoundsError
Stacktrace:
 [1] macro expansion at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(::Zygote.Context, ::typeof(throw), ::BoundsError) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:12
 [3] getindex at ./number.jl:78 [inlined]
 [4] literal_getindex at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/tools/builtins.jl:15 [inlined]
 [5] _pullback(::Zygote.Context, ::typeof(Zygote.literal_getindex), ::Float32, ::Val{2}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [6] loss at ./REPL[15]:3 [inlined]
 [7] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::typeof(loss), ::Array{Float32,1}, ::Array{Float32,1}, ::Array{Float64,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [8] #11 at ./REPL[23]:4 [inlined]
 [9] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::var"#11#12"{Array{Float32,2},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Int64}, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [10] adjoint at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/lib/lib.jl:172 [inlined]
 [11] _pullback at /home/aslan_garcia/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [12] _pullback at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:38 [inlined]
 [13] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::var"#11#12"{Array{Float32,2},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Int64}, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [14] adjoint at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/lib/lib.jl:172 [inlined]
 [15] _pullback at /home/aslan_garcia/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [16] pullback at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:44 [inlined]
 [17] _pullback(::Zygote.Context, ::typeof(ZygoteRules.pullback), ::var"#11#12"{Array{Float32,2},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Int64}, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [18] adjoint at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/lib/lib.jl:172 [inlined]
 [19] _pullback at /home/aslan_garcia/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [20] gradient at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:53 [inlined]
 [21] _pullback(::Zygote.Context, ::typeof(gradient), ::var"#11#12"{Array{Float32,2},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Int64}, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [22] estimategrad at ./REPL[23]:4 [inlined]
 [23] _pullback(::Zygote.Context, ::typeof(estimategrad), ::Array{Float32,1}, ::Array{Float32,2}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [24] #17 at ./REPL[31]:1 [inlined]
 [25] _pullback(::Zygote.Context, ::var"#17#18", ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [26] _pullback(::Function, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:38
 [27] pullback(::Function, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:44
 [28] gradient(::Function, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:53
 [29] top-level scope at REPL[31]:1

Any ideas on how to fix this? Cheers