Hey, I’m trying to implement minibatching in a Neural ODE system. I calculate the gradient
for each batch with the following code:
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots, LinearAlgebra, Distributions
# Reference trajectory: solve `prob_trueode` with Tsit5, saving at `tsteps`, and convert
# the solution to a plain Array. `prob_trueode`, `tspan` and `tsteps` are defined outside
# this snippet.
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
# Network used as the ODE right-hand side: 2 -> 50 (tanh) -> 2.
dudt2 = FastChain(FastDense(2, 50, tanh), FastDense(50, 2))
# Neural ODE over the full `tspan`, saved at every entry of `tsteps`; its `.p` field is
# the parameter vector passed to `gradient` further down.
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
# Neural ODE over (0., 1.) saved at the first 40 entries of `tsteps`.
# It is not referenced again in this snippet.
train_node = NeuralODE(dudt2, (0., 1.), Tsit5(), saveat = tsteps[1:40])
# Training data: the first 250 columns of the reference trajectory and the matching
# first 250 time points (assumes `tsteps` has at least 250 entries -- TODO confirm).
y_train = ode_data[:, 1:250]
training_span = tsteps[1:250]
"""
    predict(p, time_batch)

Build a `NeuralODE` around the global network `dudt2`, spanning `tsteps[1]` to
`tsteps[end]`, solved with `Tsit5()` and saved at `time_batch`, then evaluate it from the
global initial state `u0` with parameters `p`.

`u0` and `tsteps` are globals defined outside this function. A new `NeuralODE` is
constructed on every call. The value returned is whatever the `NeuralODE` call returns.
"""
function predict(p, time_batch)
    full_span = (tsteps[1], tsteps[end])
    node = NeuralODE(dudt2, full_span, Tsit5(); saveat = time_batch)
    return node(u0, p)
end
"""
    loss(p, batch, time_batch)

Sum of squared differences between `predict(p, time_batch)` and `batch`.

The difference is taken with broadcasting (`.-`), so `batch` must be
broadcast-compatible with the prediction.
"""
function loss(p, batch, time_batch)
    residual = predict(p, time_batch) .- batch
    return sum(abs2, residual)
end
# Data loader over the pair (y_train, training_span) with batchsize = 5 and
# shuffle = true. It is not consumed anywhere in this snippet: estimategrad is later
# called with the full y_train / training_span instead of batches drawn from it.
train_loader = Flux.Data.DataLoader(y_train, training_span, batchsize = 5, shuffle = true)
"""
    estimategrad(θ, y, t_batch)

Return the batch average of the gradient of `loss` with respect to the parameters,
evaluated at `θ`.

For every index `i` of `t_batch`, the gradient of `p -> loss(p, y[:, i], [t_batch[i]])`
is taken at `θ`. The per-sample gradients are summed and the sum is divided by
`length(t_batch)`.

# Arguments
- `θ`: parameter vector at which the gradients are evaluated.
- `y`: data indexed as `y[:, i]`; column `i` is paired with `t_batch[i]`.
- `t_batch`: time points, one per column of `y` that is used.

# Returns
A newly allocated vector of length `length(θ)`. The accumulator starts from
`zeros(length(θ))`, i.e. `Float64`, so a `Float32` `θ` yields a `Float64` result.
An empty `t_batch` gives a division by zero (a vector of `NaN`).

# Notes
The accumulator is rebound (`grad = grad .+ ...`) rather than updated in place with
`.+=` / `./=`, because Zygote's reverse mode does not support array mutation and this
function is meant to be differentiated itself.

NOTE(review): differentiating this function is still a nested (second-order)
differentiation through the ODE solve. Removing the mutation removes one obstacle but
may not be sufficient on its own -- confirm against the installed Zygote / DiffEqFlux
versions.
"""
function estimategrad(θ, y, t_batch)
    grad = zeros(length(θ))
    for i in eachindex(t_batch)
        # A one-element time vector makes `predict` save only at this sample's time.
        sample_grad = first(gradient(p -> loss(p, y[:, i], [t_batch[i]]), θ))
        grad = grad .+ sample_grad
    end
    return grad ./ length(t_batch)
end
I need to calculate the gradient of this function with respect to θ, but when I run the following call I get a BoundsError:
# Differentiate estimategrad with respect to its parameter argument, evaluated at
# prob_neuralode.p, using the full y_train / training_span rather than a mini-batch.
# estimategrad itself calls `gradient` on `loss`, so this is a gradient of a gradient.
# NOTE(review): estimategrad returns a vector, not a scalar -- confirm that `gradient`
# (rather than a Jacobian) is what is intended here.
gradient(p -> estimategrad(p, y_train, training_span), prob_neuralode.p)
ERROR: BoundsError
Stacktrace:
[1] macro expansion at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0 [inlined]
[2] _pullback(::Zygote.Context, ::typeof(throw), ::BoundsError) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:12
[3] getindex at ./number.jl:78 [inlined]
[4] literal_getindex at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/tools/builtins.jl:15 [inlined]
[5] _pullback(::Zygote.Context, ::typeof(Zygote.literal_getindex), ::Float32, ::Val{2}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[6] loss at ./REPL[15]:3 [inlined]
[7] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::typeof(loss), ::Array{Float32,1}, ::Array{Float32,1}, ::Array{Float64,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[8] #11 at ./REPL[23]:4 [inlined]
[9] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::var"#11#12"{Array{Float32,2},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Int64}, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[10] adjoint at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/lib/lib.jl:172 [inlined]
[11] _pullback at /home/aslan_garcia/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
[12] _pullback at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:38 [inlined]
[13] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::var"#11#12"{Array{Float32,2},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Int64}, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[14] adjoint at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/lib/lib.jl:172 [inlined]
[15] _pullback at /home/aslan_garcia/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
[16] pullback at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:44 [inlined]
[17] _pullback(::Zygote.Context, ::typeof(ZygoteRules.pullback), ::var"#11#12"{Array{Float32,2},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Int64}, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[18] adjoint at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/lib/lib.jl:172 [inlined]
[19] _pullback at /home/aslan_garcia/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
[20] gradient at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:53 [inlined]
[21] _pullback(::Zygote.Context, ::typeof(gradient), ::var"#11#12"{Array{Float32,2},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Int64}, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[22] estimategrad at ./REPL[23]:4 [inlined]
[23] _pullback(::Zygote.Context, ::typeof(estimategrad), ::Array{Float32,1}, ::Array{Float32,2}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[24] #17 at ./REPL[31]:1 [inlined]
[25] _pullback(::Zygote.Context, ::var"#17#18", ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
[26] _pullback(::Function, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:38
[27] pullback(::Function, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:44
[28] gradient(::Function, ::Array{Float32,1}) at /home/aslan_garcia/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:53
[29] top-level scope at REPL[31]:1
Any ideas on how to work this out? Cheers