Neural ODE minibatch error with multi-dimension input condition

Hello -

I’m trying to add minibatching to the problem discussed in https://discourse.julialang.org/t/neuralode-layer-strategy/71297/4. I’ve successfully refactored the code so that it can predict from either a single u0 or a matrix of u0 columns, but training fails once I add minibatching. The code:

# Simple Pendulum Problem
using OrdinaryDiffEq, Plots, DiffEqFlux, Flux
using IterTools: ncycle
using Debugger

#Constants
const g = 9.81
L = 1.0

#Define the problem
function simplependulum(du,u,p,t)
	if size(u) == (2,)
		# single state vector
		θ = u[1]
		dθ = u[2]
		du[1] = dθ
		du[2] = -(g/L)*sin(θ)
	else
		# batch of states, one column per initial condition
		for i = 1:size(u)[2]
			θ = u[1,i]
			dθ = u[2,i]
			du[1,i] = dθ
			du[2,i] = -(g/L)*sin(θ)
		end
	end
end

#Initial Conditions
u0 = [0,π/2]
u0 = reshape(u0,(2,1))
tspan = (0.0,6.0)

prob = ODEProblem(simplependulum, u0, tspan)
soln = solve(prob,Tsit5(),saveat=0.1)
ode_data = Array(soln)

function batchTraj(traj; predtime=10, batchsize=20, shuffle=false)
	# build (initial state, target trajectory) pairs from the solution array
	maxidx = size(traj)[3] - predtime
	s = 1:maxidx
	xt = traj[:,1,s]                                 # starting states, 2 × length(s)
	yt = Array{Float64,3}(undef, 2, predtime+1, length(s))
	for (i,idx) in enumerate(s)
		yt[:,:,i] = traj[:,1,idx:idx+predtime]   # the next predtime+1 saved states
	end
	loader = Flux.Data.DataLoader((xt,yt), batchsize=batchsize, shuffle=shuffle)
	return loader
end

function pend(u)
	# network prediction
	val = [u[2,:], -(g/L).*sin.(u[1,:])]
	return vcat(val'...)
end

dudt2 = Chain(x-> pend(x),
		Dense(2,20,tanh),
		Dense(20,2))

p,re = Flux.destructure(dudt2)
dudt(u,p,t) = re(p)(u)
prob = ODEProblem(dudt,u0,tspan)

function predict(x=u0; in_t=(0.0,1.0))
	_prob = remake(prob,u0=x,p=p,tspan=in_t)
	soln = Array(solve(_prob, Tsit5(), saveat=0.1))
	if length(size(x)) == 2
		soln = reshape(soln,2,size(soln)[3],size(soln)[2])
	end
	return soln
end

function loss(x,target)
	pred = predict(x)
	loss = sum(abs2, target .- pred)
end

# test data
predtime = 10
maxidx = size(ode_data)[3] - predtime
st = 1:8:maxidx
xt = ode_data[:,1,st]
yt = Array{Float64,3}(undef,2,predtime+1,length(st));
for (i,idx) in enumerate(st)
	yt[:,:,i] = ode_data[:,1,idx:idx+predtime]
end

train_loader = batchTraj(ode_data)

The code successfully produces a prediction for a single u0 with predict(u0) and for the matrix xt with predict(xt), but running Flux.train! results in a weird error:

train_loader = batchTraj(ode_data);
Flux.train!(loss, p, train_loader, ADAM(0.05))
ERROR: MethodError: no method matching adjoint(::Tuple{Matrix{Float64}, Matrix{Float64}})
Closest candidates are:
  adjoint(::LinearAlgebra.Hessenberg) at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\LinearAlgebra\src\hessenberg.jl:344
  adjoint(::ChainRulesCore.AbstractZero) at C:\Users\lindblot\.julia\packages\ChainRulesCore\BYuIz\src\differentials\abstract_zero.jl:23
  adjoint(::ReverseDiff.TrackedReal) at C:\Users\lindblot\.julia\packages\ReverseDiff\E4Tzn\src\derivatives\scalars.jl:7
  ...
Stacktrace:
  [1] (::Zygote.var"#back#742")(Δ::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\lib\array.jl:408
  [2] (::Zygote.var"#2991#back#743"{Zygote.var"#back#742"})(Δ::Tuple{Matrix{Float64}, Matrix{Float64}})
    @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
  [3] Pullback
    @ .\REPL[17]:3 [inlined]
  [4] (::typeof(∂(pend)))(Δ::Matrix{Float64})
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
  [5] Pullback
    @ .\REPL[18]:1 [inlined]
  [6] (::typeof(∂(#2)))(Δ::Matrix{Float64})
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
  [7] Pullback
    @ ~\.julia\packages\Flux\Zz9RI\src\layers\basic.jl:37 [inlined]
  [8] (::typeof(∂(applychain)))(Δ::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
  [9] Pullback
    @ ~\.julia\packages\Flux\Zz9RI\src\layers\basic.jl:39 [inlined]
 [10] (::typeof(∂(λ)))(Δ::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
 [11] Pullback
    @ .\REPL[20]:1 [inlined]
 [12] (::typeof(∂(dudt)))(Δ::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
 [13] (::DiffEqBase.var"#204#back#170"{typeof(∂(dudt))})(Δ::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
    @ DiffEqBase ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [14] Pullback
    @ ~\.julia\packages\DiffEqSensitivity\cLl5o\src\derivative_wrappers.jl:454 [inlined]
 [15] (::typeof(∂(λ)))(Δ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
 [16] (::Zygote.var"#46#47"{typeof(∂(λ))})(Δ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true})
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface.jl:41
 [17] _vecjacobian!(dλ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, y::Matrix{Float64}, λ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, p::Vector{Float32}, t::Float64, S::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float64}, ODESolution{Float64, 3, Vector{Matrix{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Matrix{Float64}}}, ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!)}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Matrix{Float64}}, Vector{Float64}, Vector{Vector{Matrix{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, isautojacvec::ZygoteVJP, dgrad::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, dy::Nothing, W::Nothing)
    @ DiffEqSensitivity ~\.julia\packages\DiffEqSensitivity\cLl5o\src\derivative_wrappers.jl:461
 [18] #vecjacobian!#36
    @ ~\.julia\packages\DiffEqSensitivity\cLl5o\src\derivative_wrappers.jl:224 [inlined]
 [19] (::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float64}, ODESolution{Float64, 3, Vector{Matrix{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Matrix{Float64}}}, ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!)}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Matrix{Float64}}, Vector{Float64}, Vector{Vector{Matrix{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}})(du::Vector{Float64}, u::Vector{Float64}, p::Vector{Float32}, t::Float64)
    @ DiffEqSensitivity ~\.julia\packages\DiffEqSensitivity\cLl5o\src\interpolating_adjoint.jl:116
 [20] ODEFunction
    @ ~\.julia\packages\SciMLBase\UIp7W\src\scimlfunctions.jl:334 [inlined]
... (cut for post limits)

After some debugging, the error comes from the “gradient(ps) do” line inside Flux.train!, but I haven’t gotten any further than that.

Any ideas on what’s wrong? Is there a better way to go about this?

@dhairyagandhi96 do you know what to do here?

The issue is with the pend function (or rather its backward pass). I would first confirm that the batches the data loader is producing are correct. I can take a look later today.
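
For that check, iterating the loader once and looking at the shapes is enough (just a sketch; with the batchTraj defaults I’d expect size(x) == (2, 20) and size(y) == (2, 11, 20)):

# iterate the loader once and inspect the batch shapes
for (x, y) in train_loader
	@show size(x) size(y)
	break
end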

On Monday I spent some time and got an example of this working in Python with the torchdiffeq library. Based on that, I rewrote the pend function as:

function pend(u)
	vcat([[0] [1]] * u, -(g/L) .* sin.([[1] [0]] * u))
end

This gives a more useful error and seems to confirm that the problem is in the backward pass:

Flux.train!(loss, p, train_loader, ADAM(0.05))
ERROR: Only reference types can be differentiated with `Params`.
Stacktrace:
 [1] error(s::String)
   @ Base .\error.jl:33
 [2] getindex
   @ ~\.julia\packages\Zygote\FPUm3\src\compiler\interface.jl:274 [inlined]
 [3] update!(opt::ADAM, xs::Params, gs::Zygote.Grads)
   @ Flux.Optimise ~\.julia\packages\Flux\qAdFM\src\optimise\train.jl:31
 [4] macro expansion
   @ ~\.julia\packages\Flux\qAdFM\src\optimise\train.jl:112 [inlined]
 [5] macro expansion
   @ ~\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
 [6] train!(loss::Function, ps::Vector{Float32}, data::Flux.Data.DataLoader{Tuple{Matrix{Float64}, Array{Float64, 3}}, Random._GLOBAL_RNG}, opt::ADAM; cb::Flux.Optimise.var"#40#46")
   @ Flux.Optimise ~\.julia\packages\Flux\qAdFM\src\optimise\train.jl:107
 [7] train!(loss::Function, ps::Vector{Float32}, data::Flux.Data.DataLoader{Tuple{Matrix{Float64}, Array{Float64, 3}}, Random._GLOBAL_RNG}, opt::ADAM)
   @ Flux.Optimise ~\.julia\packages\Flux\qAdFM\src\optimise\train.jl:105
 [8] top-level scope
   @ REPL[33]:1
 [9] top-level scope
   @ ~\.julia\packages\CUDA\bki2w\src\initialization.jl:52

I’m confused about the data loader being the culprit, though, because if you use the debugger and pass an (x,y) pair into the loss function, it returns the expected answer:

function test(dl)
           for (x,y) in dl
               @show x
               break
           end
       end
test (generic function with 1 method)

julia> @enter test(train_loader)
In test(dl) at REPL[32]:1
 1  function test(dl)
>2      for (x,y) in dl
 3          @show x
 4          break
 5      end
 6  end

About to run: <(iterate)(Flux.Data.DataLoader{Tuple{Matrix{Float64}, Array{Float64, 3}}, Random._GLOBAL_RNG}(([0.0 0...>
1|debug> n
In test(dl) at REPL[32]:1
 1  function test(dl)
 2      for (x,y) in dl
>3          @show x
 4          break
 5      end
 6  end

About to run: <(repr)([0.0 0.15452704373769013 0.29410409037829427 0.4055027387309618 0.4785203439712781 0.506705434...>
1|julia> loss(x,y)
801.382935069844

Solved. I think there were two issues here. One was changing the pend function to

function pend(u)
	vcat([[0] [1]] * u, -(g/L) .* sin.([[1] [0]] * u))
end

and the remaining error after that change was because I hadn’t wrapped the parameters in Flux.params(p), as described in https://discourse.julialang.org/t/zygote-error-only-reference-types-can-be-differentiated-with-params/38224.
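
In other words, the training call that works just wraps the destructured parameter vector (p is the flat Float32 vector returned by Flux.destructure above):

# Flux.params(p) puts the array in a Params collection, so Zygote can look up its
# gradient by object identity instead of trying to differentiate raw Float32 entries
Flux.train!(loss, Flux.params(p), train_loader, ADAM(0.05))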

I think the real question here is why the new matrix multiplication + vcat formulation works and the original pend function does not. @dhairyagandhi96, if you have the reasoning for that, I’ll mark it as the answer.

It’s basically that the splat + adjoint combination produces a tuple of gradients in the backward pass, which is not the same type as the primal.
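
To illustrate (just a sketch reusing the g and L defined earlier; pend_splat and pend_matmul are hypothetical names for the two formulations, and the commented-out line is the failing case):

using Zygote

# original formulation: build a Vector of Vectors, take its adjoint, splat into vcat
pend_splat(u)  = vcat([u[2,:], -(g/L) .* sin.(u[1,:])]'...)

# reformulated: select rows with constant matrices, no container adjoint/splat involved
pend_matmul(u) = vcat([[0] [1]] * u, -(g/L) .* sin.([[1] [0]] * u))

u = rand(2, 4)
pend_splat(u) ≈ pend_matmul(u)                 # the forward passes agree

Zygote.gradient(x -> sum(pend_matmul(x)), u)   # cotangent is a plain Matrix, works
# Zygote.gradient(x -> sum(pend_splat(x)), u)  # the splat's pullback hands a Tuple of
#   per-row gradients to the adjoint's pullback, which is the adjoint(::Tuple) MethodError above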
