# Neural ODE minibatch error with multi-dimension input condition

Hello -

I’m trying to add minibatching to a problem discussed in https://discourse.julialang.org/t/neuralode-layer-strategy/71297/4. I’ve successfully refactored the code to predict from either a single `u0` or a matrix of `u0`s, but it fails when I add minibatching to training. The code:

``````
# Simple Pendulum Problem
using OrdinaryDiffEq, Plots, DiffEqFlux, Flux
using IterTools: ncycle
using Debugger

# Constants
const g = 9.81
const L = 1.0

# Define the problem (in-place form); u is either a single state vector
# or a 2×N matrix holding one state per column
function simplependulum(du, u, p, t)
    if size(u) == (2,)
        θ = u[1]
        dθ = u[2]
        du[1] = dθ
        du[2] = -(g/L) * sin(θ)
    else
        for i = 1:size(u, 2)
            θ = u[1, i]
            dθ = u[2, i]
            du[1, i] = dθ
            du[2, i] = -(g/L) * sin(θ)
        end
    end
end

# Initial conditions
u0 = [0, π/2]
u0 = reshape(u0, (2, 1))
tspan = (0.0, 6.0)

prob = ODEProblem(simplependulum, u0, tspan)
soln = solve(prob, Tsit5(), saveat=0.1)
ode_data = Array(soln)

# Build (start state, trajectory window) pairs and wrap them in a DataLoader
function batchTraj(traj; predtime=10, batchsize=20, shuffle=false)
    maxidx = size(traj, 3) - predtime
    s = 1:maxidx
    xt = traj[:, 1, s]
    yt = Array{Float64,3}(undef, 2, predtime + 1, length(s))
    for (i, idx) in enumerate(s)
        yt[:, :, i] = traj[:, 1, idx:idx+predtime]
    end
    # pass the keyword through instead of hard-coding shuffle=false
    loader = Flux.Data.DataLoader((xt, yt), batchsize=batchsize, shuffle=shuffle)
    return loader
end

# Known pendulum physics applied before the dense layers; this
# splat + adjoint version is the one in use when the error below is raised
function pend(u)
    val = [u[2, :], -(g/L) .* sin.(u[1, :])]
    return vcat(val'...)
end

dudt2 = Chain(x -> pend(x),
              Dense(2, 20, tanh),
              Dense(20, 2))

p, re = Flux.destructure(dudt2)
dudt(u, p, t) = re(p)(u)
prob = ODEProblem(dudt, u0, tspan)

# Solve from the start state(s) x over in_t and return the saved states
function predict(x=u0; in_t=(0.0, 1.0))
    _prob = remake(prob, u0=x, p=p, tspan=in_t)
    soln = Array(solve(_prob, Tsit5(), saveat=0.1))
    if ndims(x) == 2
        # Target layout is (state, time, batch). Note that reshape only
        # relabels the axes; permutedims(soln, (1, 3, 2)) would reorder the data.
        soln = reshape(soln, 2, size(soln)[3], size(soln)[2])
    end
    return soln
end

# Sum of squared errors between the target windows and the prediction
function loss(x, target)
    pred = predict(x)
    return sum(abs2, target .- pred)
end

# Test data
predtime = 10
maxidx = size(ode_data)[3] - predtime
st = 1:8:maxidx
xt = ode_data[:, 1, st]
yt = Array{Float64,3}(undef, 2, predtime + 1, length(st));
for (i, idx) in enumerate(st)
    yt[:, :, i] = ode_data[:, 1, idx:idx+predtime]
end

train_loader = batchTraj(ode_data)
``````

The code successfully returns a prediction for a single `u0` with `predict(u0)` and for the batch `xt` with `predict(xt)`, but running `Flux.train!` results in an error I don’t understand:

``````
train_loader = batchTraj(ode_data);
Flux.train!(loss, p, train_loader, ADAM(0.05))
ERROR: MethodError: no method matching adjoint(::Tuple{Matrix{Float64}, Matrix{Float64}})
Closest candidates are:
adjoint(::LinearAlgebra.Hessenberg) at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\LinearAlgebra\src\hessenberg.jl:344
adjoint(::ChainRulesCore.AbstractZero) at C:\Users\lindblot\.julia\packages\ChainRulesCore\BYuIz\src\differentials\abstract_zero.jl:23
adjoint(::ReverseDiff.TrackedReal) at C:\Users\lindblot\.julia\packages\ReverseDiff\E4Tzn\src\derivatives\scalars.jl:7
...
Stacktrace:
[1] (::Zygote.var"#back#742")(Δ::Tuple{Matrix{Float64}, Matrix{Float64}})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\lib\array.jl:408
[2] (::Zygote.var"#2991#back#743"{Zygote.var"#back#742"})(Δ::Tuple{Matrix{Float64}, Matrix{Float64}})
@ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
[3] Pullback
@ .\REPL[17]:3 [inlined]
[4] (::typeof(∂(pend)))(Δ::Matrix{Float64})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
[5] Pullback
@ .\REPL[18]:1 [inlined]
[6] (::typeof(∂(#2)))(Δ::Matrix{Float64})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
[7] Pullback
@ ~\.julia\packages\Flux\Zz9RI\src\layers\basic.jl:37 [inlined]
[8] (::typeof(∂(applychain)))(Δ::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
[9] Pullback
@ ~\.julia\packages\Flux\Zz9RI\src\layers\basic.jl:39 [inlined]
[10] (::typeof(∂(λ)))(Δ::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
[11] Pullback
@ .\REPL[20]:1 [inlined]
[12] (::typeof(∂(dudt)))(Δ::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
[13] (::DiffEqBase.var"#204#back#170"{typeof(∂(dudt))})(Δ::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
@ DiffEqBase ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
[14] Pullback
@ ~\.julia\packages\DiffEqSensitivity\cLl5o\src\derivative_wrappers.jl:454 [inlined]
[15] (::typeof(∂(λ)))(Δ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
[16] (::Zygote.var"#46#47"{typeof(∂(λ))})(Δ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface.jl:41
[17] _vecjacobian!(dλ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, y::Matrix{Float64}, λ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, p::Vector{Float32}, t::Float64, S::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float64}, ODESolution{Float64, 3, Vector{Matrix{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Matrix{Float64}}}, ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!)}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Matrix{Float64}}, Vector{Float64}, Vector{Vector{Matrix{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, 
SciMLBase.StandardODEProblem}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, isautojacvec::ZygoteVJP, dgrad::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, dy::Nothing, W::Nothing)
@ DiffEqSensitivity ~\.julia\packages\DiffEqSensitivity\cLl5o\src\derivative_wrappers.jl:461
[18] #vecjacobian!#36
@ ~\.julia\packages\DiffEqSensitivity\cLl5o\src\derivative_wrappers.jl:224 [inlined]
[19] (::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float64}, ODESolution{Float64, 3, Vector{Matrix{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Matrix{Float64}}}, ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!)}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Matrix{Float64}}, Vector{Float64}, Vector{Vector{Matrix{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}})(du::Vector{Float64}, u::Vector{Float64}, p::Vector{Float32}, t::Float64)
@ DiffEqSensitivity ~\.julia\packages\DiffEqSensitivity\cLl5o\src\interpolating_adjoint.jl:116
[20] ODEFunction
@ ~\.julia\packages\SciMLBase\UIp7W\src\scimlfunctions.jl:334 [inlined]
... (cut for post limits)
``````

After some debugging, I found that the `gradient(ps) do` line in `Flux.train!` is what raises the error, but I haven’t figured it out any further.

Any ideas on what’s wrong? Is there a better way to go about this?

@dhairyagandhi96 do you know what to do here?

The issue is with the `pend` function (or rather its backward pass). I would first confirm that the batches the data loader is producing are correct. I can take a look later today.
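One way to run that check without Flux in the loop is to rebuild the windows by hand and compare shapes. The sketch below is only an illustration: it uses a random array as a stand-in for `ode_data` (assumed shape 2×1×61, matching the 61 `saveat` points above), and the names `traj`, `batches`, and `fake_soln`, as well as the manual `Iterators.partition` batching, are hypothetical substitutes for the `DataLoader`. It also checks one detail of `predict`: `reshape` keeps the memory order, so it does not swap the batch and time axes the way `permutedims` does.

```julia
# Stand-in for ode_data: 2 states × 1 trajectory × 61 saved time points (assumed shape)
traj = rand(2, 1, 61)
predtime = 10
batchsize = 20

# Same windowing as batchTraj: one start state and one window per start index
maxidx = size(traj, 3) - predtime
xt = traj[:, 1, 1:maxidx]
yt = Array{Float64,3}(undef, 2, predtime + 1, maxidx)
for i in 1:maxidx
    yt[:, :, i] = traj[:, 1, i:i+predtime]
end

# Manual minibatches (hypothetical replacement for the DataLoader)
batches = [(xt[:, idx], yt[:, :, idx]) for idx in Iterators.partition(1:maxidx, batchsize)]
for (x, y) in batches
    # every target window must start at its own initial state
    @assert y[:, 1, :] == x
    println(size(x), " -> ", size(y))
end

# A solver returns (state, batch, time); the targets are (state, time, batch)
fake_soln = reshape(collect(1.0:2*3*4), 2, 3, 4)   # 2 states, 3 batch items, 4 times
reshaped = reshape(fake_soln, 2, 4, 3)             # relabels axes, data order unchanged
permuted = permutedims(fake_soln, (1, 3, 2))       # actually swaps batch and time
println(reshaped == permuted)
```

With a single trajectory (a batch dimension of 1) the two layouts coincide, which may be why `predict(u0)` looks right even when a batched call would not line up with its targets.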

On Monday I spent some time and got an example of this working in Python with the torchdiffeq library. Based on that, I switched up the `pend` function to be:

``````
function pend(u)
    vcat([[0] [1]] * u, -(g/L) .* sin.([[1] [0]] * u))
end
``````

This gives a more useful error, and it seems consistent with the problem being in the backward pass:

``````
Flux.train!(loss, p, train_loader, ADAM(0.05))
ERROR: Only reference types can be differentiated with `Params`.
Stacktrace:
[1] error(s::String)
@ Base .\error.jl:33
[2] getindex
@ ~\.julia\packages\Zygote\FPUm3\src\compiler\interface.jl:274 [inlined]
[3] update!(opt::ADAM, xs::Params, gs::Zygote.Grads)
@ Flux.Optimise ~\.julia\packages\Flux\qAdFM\src\optimise\train.jl:31
[4] macro expansion
@ ~\.julia\packages\Flux\qAdFM\src\optimise\train.jl:112 [inlined]
[5] macro expansion
@ ~\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
[6] train!(loss::Function, ps::Vector{Float32}, data::Flux.Data.DataLoader{Tuple{Matrix{Float64}, Array{Float64, 3}}, Random._GLOBAL_RNG}, opt::ADAM; cb::Flux.Optimise.var"#40#46")
@ Flux.Optimise ~\.julia\packages\Flux\qAdFM\src\optimise\train.jl:107
[7] train!(loss::Function, ps::Vector{Float32}, data::Flux.Data.DataLoader{Tuple{Matrix{Float64}, Array{Float64, 3}}, Random._GLOBAL_RNG}, opt::ADAM)
@ Flux.Optimise ~\.julia\packages\Flux\qAdFM\src\optimise\train.jl:105
[8] top-level scope
@ REPL[33]:1
[9] top-level scope
@ ~\.julia\packages\CUDA\bki2w\src\initialization.jl:52
``````

I’m confused about the data loader being the culprit, though, because if you step through with the debugger and pass an (x, y) batch into the loss function, it returns a value as expected:

``````
function test(dl)
    for (x,y) in dl
        @show x
        break
    end
end
test (generic function with 1 method)

julia> @enter test(train_loader)
In test(dl) at REPL[32]:1
1  function test(dl)
>2      for (x,y) in dl
3          @show x
4          break
5      end
6  end

About to run: <(iterate)(Flux.Data.DataLoader{Tuple{Matrix{Float64}, Array{Float64, 3}}, Random._GLOBAL_RNG}(([0.0 0...>
1|debug> n
In test(dl) at REPL[32]:1
1  function test(dl)
2      for (x,y) in dl
>3          @show x
4          break
5      end
6  end

About to run: <(repr)([0.0 0.15452704373769013 0.29410409037829427 0.4055027387309618 0.4785203439712781 0.506705434...>
1|julia> loss(x,y)
801.382935069844
``````

Solved. I think there were two issues here. One was fixed by changing the `pend` function to

``````
function pend(u)
    vcat([[0] [1]] * u, -(g/L) .* sin.([[1] [0]] * u))
end
``````

The other, behind the last error that appeared after that change, was that I passed `p` directly instead of wrapping the parameters in `Flux.params(p)`, as described in https://discourse.julialang.org/t/zygote-error-only-reference-types-can-be-differentiated-with-params/38224.
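As far as the second stack trace shows, `train!` received the bare `Vector{Float32}` and turned it into a `Params` collection itself, so each scalar entry was treated as a parameter rather than the array as a whole. The sketch below illustrates the wrapped form with Zygote alone; the toy `p` and `toy_loss` are hypothetical stand-ins for the destructured network parameters and the real loss, and it assumes a Zygote version that still supports implicit parameters.

```julia
using Zygote

# Flat parameter vector, standing in for the output of Flux.destructure
p = Float32[1, 2, 3]

# Toy loss that closes over p, the way loss/predict close over the global p
toy_loss() = sum(abs2, p)

# Wrap the array itself so that the array (a reference type) is the tracked parameter
ps = Zygote.Params([p])
gs = Zygote.gradient(toy_loss, ps)

# Gradients are looked up by the parameter array
println(gs[p])
```

In the training script this corresponds to passing `Flux.params(p)` as the second argument of `Flux.train!` instead of `p`.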

I think the real answer here is an explanation of why the matrix multiplication + `vcat` version worked and the original `pend` function did not. @dhairyagandhi96 if you have the reasoning for that, I’ll mark it as the answer.

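It’s basically that combining the splat with the adjoint (`vcat(val'...)`) makes the backward pass produce a tuple of gradients, which is not the same type as the primal. That is the `no method matching adjoint(::Tuple{Matrix{Float64}, Matrix{Float64}})` in the first error.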
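To make the difference concrete, the sketch below compares the original `pend` with two rewrites that avoid the splat and the adjoint. Only the matrix-multiplication form was reported to train in this thread; `pend_slice` is an untested alternative added for illustration, and the function names and the sample `u` are hypothetical. The check covers the forward values only, not the gradients.

```julia
const g = 9.81
const L = 1.0

# Original form: builds a vector of vectors, then splats its adjoint into vcat
function pend_splat(u)
    val = [u[2, :], -(g/L) .* sin.(u[1, :])]
    return vcat(val'...)
end

# Form from the thread: select rows with 1×2 matrices ([0 1] equals [[0] [1]])
function pend_matmul(u)
    return vcat([0 1] * u, -(g/L) .* sin.([1 0] * u))
end

# Hypothetical alternative: range indexing keeps each row as a 1×N matrix
function pend_slice(u)
    return vcat(u[2:2, :], -(g/L) .* sin.(u[1:1, :]))
end

u = [0.0 0.1 0.2; 1.5 1.4 1.3]   # 2 states × 3 batch columns

println(size(pend_matmul(u)))
println(pend_splat(u) ≈ pend_matmul(u))
println(pend_matmul(u) ≈ pend_slice(u))
```

All three return a 2×N matrix with the angular velocity in the first row and the angular acceleration in the second, so swapping one for another leaves `predict` unchanged; what differs is the set of operations the automatic differentiation has to reverse.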
