Hello -
I’m trying to add minibatching to a problem discussed in https://discourse.julialang.org/t/neuralode-layer-strategy/71297/4. I’ve successfully refactored the code to predict either a single u0 or matrix of u0s, but it fails when adding minibatching to training. The code:
# Simple Pendulum Problem
using OrdinaryDiffEq, Plots, DiffEqFlux, Flux
using IterTools: ncycle
using Debugger
#Constants
const g = 9.81
L = 1.0
#Define the problem
function simplependulum(du,u,p,t)
if size(u) == (2,)
θ = u[1]
dθ = u[2]
du[1] = dθ
du[2] = -(g/L)*sin.(θ)
else
for i = 1:size(u)[2]
θ = u[1,i]
dθ = u[2,i]
du[1,i] = dθ
du[2,i] = -(g/L).*sin.(θ)
end
end
end
#Initial Conditions
u0 = [0,π/2]
u0 = reshape(u0,(2,1))
tspan = (0.0,6.0)
prob = ODEProblem(simplependulum, u0, tspan)
soln = solve(prob,Tsit5(),saveat=0.1)
ode_data = Array(soln)
function batchTraj(traj;predtime=10,batchsize=20,shuffle=false)
maxidx = size(traj)[3] - predtime
s = 1:maxidx
xt = traj[:,1,s]
yt = Array{Float64,3}(undef,2,predtime+1,length(s))
for (i,idx) in enumerate(s)
yt[:,:,i] = traj[:,1,idx:idx+predtime]
end
loader = Flux.Data.DataLoader((xt,yt),batchsize=batchsize,shuffle=false)
return loader
end
function pend(u)
# network prediction
val = [u[2,:], -(g/L).*sin.(u[1,:])]
return vcat(val'...)
end
dudt2 = Chain(x-> pend(x),
Dense(2,20,tanh),
Dense(20,2))
p,re = Flux.destructure(dudt2)
dudt(u,p,t) = re(p)(u)
prob = ODEProblem(dudt,u0,tspan)
function predict(x=u0; in_t=(0.0,1.0))
_prob = remake(prob,u0=x,p=p,tspan=in_t)
soln = Array(solve(_prob, Tsit5(), saveat=0.1))
if length(size(x)) == 2
soln = reshape(soln,2,size(soln)[3],size(soln)[2])
end
return soln
end
function loss(x,target)
pred = predict(x)
loss = sum(abs2, target .- pred)
end
# test data
predtime = 10
maxidx = size(ode_data)[3] - predtime
st = 1:8:maxidx
xt = ode_data[:,1,st]
yt = Array{Float64,3}(undef,2,predtime+1,length(st));
for (i,idx) in enumerate(st)
yt[:,:,i] = ode_data[:,1,idx:idx+predtime]
end
train_loader = batchTraj(ode_data)
The code successfully outputs a u0 with “predict(u0)” and xt with “predict(xt)”, but running Flux.train! results in a weird error:
train_loader = batchTraj(ode_data);
Flux.train!(loss, p, train_loader, ADAM(0.05))
ERROR: MethodError: no method matching adjoint(::Tuple{Matrix{Float64}, Matrix{Float64}})
Closest candidates are:
adjoint(::LinearAlgebra.Hessenberg) at C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.6\LinearAlgebra\src\hessenberg.jl:344
adjoint(::ChainRulesCore.AbstractZero) at C:\Users\lindblot\.julia\packages\ChainRulesCore\BYuIz\src\differentials\abstract_zero.jl:23
adjoint(::ReverseDiff.TrackedReal) at C:\Users\lindblot\.julia\packages\ReverseDiff\E4Tzn\src\derivatives\scalars.jl:7
...
Stacktrace:
[1] (::Zygote.var"#back#742")(Δ::Tuple{Matrix{Float64}, Matrix{Float64}})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\lib\array.jl:408
[2] (::Zygote.var"#2991#back#743"{Zygote.var"#back#742"})(Δ::Tuple{Matrix{Float64}, Matrix{Float64}})
@ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
[3] Pullback
@ .\REPL[17]:3 [inlined]
[4] (::typeof(∂(pend)))(Δ::Matrix{Float64})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
[5] Pullback
@ .\REPL[18]:1 [inlined]
[6] (::typeof(∂(#2)))(Δ::Matrix{Float64})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
[7] Pullback
@ ~\.julia\packages\Flux\Zz9RI\src\layers\basic.jl:37 [inlined]
[8] (::typeof(∂(applychain)))(Δ::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
[9] Pullback
@ ~\.julia\packages\Flux\Zz9RI\src\layers\basic.jl:39 [inlined]
[10] (::typeof(∂(λ)))(Δ::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
[11] Pullback
@ .\REPL[20]:1 [inlined]
[12] (::typeof(∂(dudt)))(Δ::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
[13] (::DiffEqBase.var"#204#back#170"{typeof(∂(dudt))})(Δ::Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
@ DiffEqBase ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
[14] Pullback
@ ~\.julia\packages\DiffEqSensitivity\cLl5o\src\derivative_wrappers.jl:454 [inlined]
[15] (::typeof(∂(λ)))(Δ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
[16] (::Zygote.var"#46#47"{typeof(∂(λ))})(Δ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true})
@ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface.jl:41
[17] _vecjacobian!(dλ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, y::Matrix{Float64}, λ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, p::Vector{Float32}, t::Float64, S::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float64}, ODESolution{Float64, 3, Vector{Matrix{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Matrix{Float64}}}, ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!)}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Matrix{Float64}}, Vector{Float64}, Vector{Vector{Matrix{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, isautojacvec::ZygoteVJP, dgrad::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, dy::Nothing, W::Nothing)
@ DiffEqSensitivity ~\.julia\packages\DiffEqSensitivity\cLl5o\src\derivative_wrappers.jl:461
[18] #vecjacobian!#36
@ ~\.julia\packages\DiffEqSensitivity\cLl5o\src\derivative_wrappers.jl:224 [inlined]
[19] (::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float64}, ODESolution{Float64, 3, Vector{Matrix{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Matrix{Float64}}}, ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!)}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Matrix{Float64}}, Vector{Float64}, Vector{Vector{Matrix{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}})(du::Vector{Float64}, u::Vector{Float64}, p::Vector{Float32}, t::Float64)
@ DiffEqSensitivity ~\.julia\packages\DiffEqSensitivity\cLl5o\src\interpolating_adjoint.jl:116
[20] ODEFunction
@ ~\.julia\packages\SciMLBase\UIp7W\src\scimlfunctions.jl:334 [inlined]
... (cut for post limits)
After some debugging investigation, it’s the “gradient(ps) do” line in Flux.train that causes the error, but I haven’t figured it out any farther.
Any ideas on what’s wrong? Is there a better way to go about this?