# MethodError in Custom Flux model

Hey everyone,

I have been trying to get my custom model running, which should be able to optimize both for a set of paramters and some initial conditions of an ODE model. The ODE’s dynamics are defined in `ode()`, the model is `AutoODE`.

I am new to Julia and have a hard time interpreting the error, or rather locating the cause of it. Any help to find the culprit would be much appreciated

``````
module myModule

using Flux
using DifferentialEquations

abstract type AbstractAutoODEModel end

function create_initial_conditions(N_states::Int=56)
# parameters
β = rand(N_states, 1)
γ = rand(N_states, 1)
μ = rand(N_states, 1)
σ = rand(N_states, 1)
a = rand(N_states, 1)
b = rand(N_states, 1)
A = rand(N_states, N_states)

θ = [β, γ, μ, σ, a, b, A]

# initial conditions
S₀ = fill(0.5, N_states, 1)
E₀ = fill(0.5, N_states, 1)
U₀ = @. ((1 - μ) * σ) * E₀

u₀ = [S₀ E₀ U₀]

return u₀, θ
end

function ode(u, p, t)
β, γ, μ, σ, a, b, A = p
S, E, U, I, R, D = eachcol(u)

transm = A * (I + E)
dS = @. - β * transm * S
dE = @. β * transm * S - σ*E
dU = @. (1-γ)*σ*E

dI = @. μ*σ*E - γ*I
dR = @. γ*I
r = @. a*t + b
dD = @. r * dR

du = [dS dE dU dI dR dD]
return du
end

struct AutoODE <: AbstractAutoODEModel
y₀      # constant initial conditions
u₀      # learnable initial conditions

θ       # learnable parameters
q       # constant parameters (optional)

f::Function

# constructor
function AutoODE(y₀; q=nothing)
N_states = size(y₀, 1)
u₀, θ = create_initial_conditions(N_states)

new(y₀, u₀, θ, q, ode)
end

end

@Flux.layer AbstractAutoODEModel trainable=(u₀, θ,)

get_p(a::AbstractAutoODEModel) = isnothing(a.q) ? a.θ : [a.θ..., a.q...]
get_x₀(a::AbstractAutoODEModel) = isnothing(a.u₀) ? a.y₀ : [a.u₀ a.y₀]

function (a::AbstractAutoODEModel)(t; alg=RK4(), kwargs...)
p = get_p(a)
x₀ = get_x₀(a)

tspan = (t[1], t[end])
prob = ODEProblem(a.f, x₀, tspan, p)
return Array(solve(prob, alg; saveat=t, kwargs...))
end

end

using .myModule
using Flux

T = 7
y = rand(5, 3, T)
t = collect(1:T)
model = myModule.AutoODE(y[:, :, 1])

Flux.gradient(model -> Flux.mse(model(t)[:, 4:6, :], y), model)

``````
Full error

``````ERROR: MethodError: no method matching Float64(::Matrix{Float64})

Closest candidates are:
(::Type{T})(::VectorizationBase.Double{T}) where T<:Union{Float16, Float32, Float64, VectorizationBase.Vec{<:Any, <:Union{Float16, Float32, Float64}}, VectorizationBase.VecUnroll{var"#s45", var"#s44", var"#s43", V} where {var"#s45", var"#s44", var"#s43"<:Union{Float16, Float32, Float64}, V<:Union{Bool, Float16, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, SIMDTypes.Bit, VectorizationBase.AbstractSIMD{var"#s44", var"#s43"}}}}
@ VectorizationBase C:\Users\colin\.julia\packages\VectorizationBase\xE5Tx\src\special\double.jl:111
(::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
@ Base char.jl:50
(::Type{T})(::Base.TwicePrecision) where T<:Number
@ Base twiceprecision.jl:266
...

Stacktrace:
[1] convert
@ C:\Users\colin\.julia\packages\ForwardDiff\PcZ48\src\dual.jl:435 [inlined]
[4] getindex
[5] copy
[6] materialize(bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(convert), Tuple{Base.RefValue{Type{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Float64}, Float64, 12}}}, Vector{Matrix{Float64}}}})
[7] (::SciMLSensitivity.var"#333#342"{0, Array{Float64, 3}, Vector{Int64}, Base.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:verbose,), Tuple{Bool}}}, SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, SciMLSensitivity.ForwardDiffSensitivity{0, nothing}, Matrix{Float64}, Vector{Matrix{Float64}}, Tuple{}, Vector{Float64}})()
@ SciMLSensitivity C:\Users\colin\.julia\packages\SciMLSensitivity\4Ah3r\src\concrete_solve.jl:888
[8] unthunk
@ C:\Users\colin\.julia\packages\ChainRulesCore\zgT0R\src\tangent_types\thunks.jl:204 [inlined]
[9] wrap_chainrules_output
@ C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:110 [inlined]
[10] map
@ .\tuple.jl:275 [inlined]
[11] map (repeats 3 times)
@ .\tuple.jl:276 [inlined]
[12] wrap_chainrules_output
@ C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:111 [inlined]
[13] ZBack
@ C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
[14] kw_zpullback
@ C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:237 [inlined]
[15] #291
@ C:\Users\colin\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
[16] (::Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.var"#kw_zpullback#53"{SciMLSensitivity.var"#forward_sensitivity_backpass#338"{0, Vector{Int64}, Base.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:verbose,), Tuple{Bool}}}, SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, SciMLSensitivity.ForwardDiffSensitivity{0, nothing}, Matrix{Float64}, Vector{Matrix{Float64}}, SciMLBase.ChainRulesOriginator, Tuple{}, Vector{Float64}}}}})(Δ::Array{Float64, 3})
[17] Pullback
@ C:\Users\colin\.julia\packages\DiffEqBase\eTCPy\src\solve.jl:980 [inlined]
[18] (::Zygote.Pullback{Tuple{DiffEqBase.var"##solve#40", Nothing, Nothing, Nothing, Val{true}, Base.Pairs{Symbol, Vector{Int64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{Vector{Int64}}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any})(Δ::Array{Float64, 3})
@ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[19] #291
@ C:\Users\colin\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
[20] (::Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{NTuple{7, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{DiffEqBase.var"##solve#40", Nothing, Nothing, Nothing, Val{true}, Base.Pairs{Symbol, Vector{Int64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{Vector{Int64}}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any}}})(Δ::Array{Float64, 3})
[21] Pullback
@ C:\Users\colin\.julia\packages\DiffEqBase\eTCPy\src\solve.jl:970 [inlined]
[22] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any})(Δ::Array{Float64, 3})
@ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[23] Pullback
@ c:\Users\colin\OneDrive-TUM\Code\Julia\MMDS.jl\scripts\mwe2.jl:86 [inlined]
[24] (::Zygote.Pullback{Tuple{Main.myModule.var"##_#3", OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(lastindex), Vector{Int64}}, Tuple{Zygote.ZBack{ChainRules.var"#eachindex_pullback#378"{Tuple{IndexLinear, Vector{Int64}}}}, Zygote.Pullback{Tuple{typeof(last), Base.OneTo{Int64}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:stop, Zygote.Context{false}, Base.OneTo{Int64}, Int64}}, Zygote.ZBack{Zygote.var"#convert_pullback#330"}}}, Zygote.Pullback{Tuple{Type{IndexLinear}}, Tuple{}}}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:f, Zygote.Context{false}, Main.myModule.AutoODE, typeof(Main.myModule.ode)}}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_x₀), Main.myModule.AutoODE}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:data, Zygote.Context{false}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, NamedTuple{(), Tuple{}}}, Tuple{}}}}, RecursiveArrayToolsZygoteExt.var"#182#back#103"{RecursiveArrayToolsZygoteExt.var"#99#102"}, Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,)}}, Tuple{Vector{Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}}, Tuple{Vector{Int64}}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_p), Main.myModule.AutoODE}, Any}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{SciMLBaseChainRulesCoreExt.var"#ODEProblemAdjoint#14"}}})(Δ::Array{Float64, 3})
@ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[25] Pullback
@ c:\Users\colin\OneDrive-TUM\Code\Julia\MMDS.jl\scripts\mwe2.jl:80 [inlined]
[26] (::Zygote.Pullback{Tuple{Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{OrdinaryDiffEq.RK4}}, Tuple{}}, Zygote.Pullback{Tuple{Main.myModule.var"##_#3", OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(lastindex), Vector{Int64}}, Tuple{Zygote.ZBack{ChainRules.var"#eachindex_pullback#378"{Tuple{IndexLinear, Vector{Int64}}}}, Zygote.Pullback{Tuple{typeof(last), Base.OneTo{Int64}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:stop, Zygote.Context{false}, Base.OneTo{Int64}, Int64}}, Zygote.ZBack{Zygote.var"#convert_pullback#330"}}}, Zygote.Pullback{Tuple{Type{IndexLinear}}, Tuple{}}}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:f, Zygote.Context{false}, Main.myModule.AutoODE, typeof(Main.myModule.ode)}}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_x₀), Main.myModule.AutoODE}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:data, Zygote.Context{false}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, NamedTuple{(), Tuple{}}}, Tuple{}}}}, RecursiveArrayToolsZygoteExt.var"#182#back#103"{RecursiveArrayToolsZygoteExt.var"#99#102"}, Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,)}}, Tuple{Vector{Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}}, Tuple{Vector{Int64}}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_p), Main.myModule.AutoODE}, Any}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{SciMLBaseChainRulesCoreExt.var"#ODEProblemAdjoint#14"}}}}})(Δ::Array{Float64, 3})
@ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[27] Pullback
@ c:\Users\colin\OneDrive-TUM\Code\Julia\MMDS.jl\scripts\mwe2.jl:100 [inlined]
[28] (::Zygote.Pullback{Tuple{var"#21#22", Main.myModule.AutoODE}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.Losses.mse), Array{Float64, 3}, Array{Float64, 3}}, Tuple{Zygote.Pullback{Tuple{Flux.Losses.var"##mse#14", typeof(Statistics.mean), typeof(Flux.Losses.mse), Array{Float64, 3}, Array{Float64, 3}}, Tuple{Zygote.ZBack{Flux.Losses.var"#_check_sizes_pullback#12"}, Zygote.var"#3976#back#1289"{Zygote.var"#1285#1288"{Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Array{Float64, 3}}, Tuple{}}, Zygote.var"#3768#back#1191"{Zygote.var"#1187#1190"{Array{Float64, 3}, Array{Float64, 3}}}, Zygote.ZBack{ChainRules.var"#mean_pullback#1820"{Int64, ChainRules.var"#sum_pullback#1632"{Colon, Array{Float64, 3}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}}}}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Vector{Int64}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Array{Float64, 3}, Tuple{Colon, UnitRange{Int64}, Colon}, Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, ChainRulesCore.NoTangent}}}, Zygote.ZBack{ChainRules.var"#:_pullback#278"{Tuple{Int64, Int64}}}, Zygote.Pullback{Tuple{Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{OrdinaryDiffEq.RK4}}, Tuple{}}, Zygote.Pullback{Tuple{Main.myModule.var"##_#3", OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(lastindex), Vector{Int64}}, Tuple{Zygote.ZBack{ChainRules.var"#eachindex_pullback#378"{Tuple{IndexLinear, Vector{Int64}}}}, Zygote.Pullback{Tuple{typeof(last), Base.OneTo{Int64}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:stop, Zygote.Context{false}, Base.OneTo{Int64}, Int64}}, Zygote.ZBack{Zygote.var"#convert_pullback#330"}}}, Zygote.Pullback{Tuple{Type{IndexLinear}}, Tuple{}}}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:f, Zygote.Context{false}, Main.myModule.AutoODE, typeof(Main.myModule.ode)}}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_x₀), Main.myModule.AutoODE}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:data, Zygote.Context{false}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, NamedTuple{(), Tuple{}}}, Tuple{}}}}, RecursiveArrayToolsZygoteExt.var"#182#back#103"{RecursiveArrayToolsZygoteExt.var"#99#102"}, Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,)}}, Tuple{Vector{Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}}, Tuple{Vector{Int64}}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_p), Main.myModule.AutoODE}, Any}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{SciMLBaseChainRulesCoreExt.var"#ODEProblemAdjoint#14"}}}}}}})(Δ::Float64)
@ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
[29] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#21#22", Main.myModule.AutoODE}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.Losses.mse), Array{Float64, 3}, Array{Float64, 3}}, Tuple{Zygote.Pullback{Tuple{Flux.Losses.var"##mse#14", typeof(Statistics.mean), typeof(Flux.Losses.mse), Array{Float64, 3}, Array{Float64, 3}}, Tuple{Zygote.ZBack{Flux.Losses.var"#_check_sizes_pullback#12"}, Zygote.var"#3976#back#1289"{Zygote.var"#1285#1288"{Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Array{Float64, 3}}, Tuple{}}, Zygote.var"#3768#back#1191"{Zygote.var"#1187#1190"{Array{Float64, 3}, Array{Float64, 3}}}, Zygote.ZBack{ChainRules.var"#mean_pullback#1820"{Int64, ChainRules.var"#sum_pullback#1632"{Colon, Array{Float64, 3}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}}}}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Vector{Int64}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Array{Float64, 3}, Tuple{Colon, UnitRange{Int64}, Colon}, Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, ChainRulesCore.NoTangent}}}, Zygote.ZBack{ChainRules.var"#:_pullback#278"{Tuple{Int64, Int64}}}, Zygote.Pullback{Tuple{Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{OrdinaryDiffEq.RK4}}, Tuple{}}, Zygote.Pullback{Tuple{Main.myModule.var"##_#3", OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(lastindex), Vector{Int64}}, Tuple{Zygote.ZBack{ChainRules.var"#eachindex_pullback#378"{Tuple{IndexLinear, Vector{Int64}}}}, Zygote.Pullback{Tuple{typeof(last), Base.OneTo{Int64}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:stop, Zygote.Context{false}, Base.OneTo{Int64}, Int64}}, Zygote.ZBack{Zygote.var"#convert_pullback#330"}}}, Zygote.Pullback{Tuple{Type{IndexLinear}}, Tuple{}}}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:f, Zygote.Context{false}, Main.myModule.AutoODE, typeof(Main.myModule.ode)}}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_x₀), Main.myModule.AutoODE}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:data, Zygote.Context{false}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, NamedTuple{(), Tuple{}}}, Tuple{}}}}, RecursiveArrayToolsZygoteExt.var"#182#back#103"{RecursiveArrayToolsZygoteExt.var"#99#102"}, Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,)}}, Tuple{Vector{Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}}, Tuple{Vector{Int64}}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_p), Main.myModule.AutoODE}, Any}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{SciMLBaseChainRulesCoreExt.var"#ODEProblemAdjoint#14"}}}}}}}})(Δ::Float64)
@ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91
@ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:148
[31] top-level scope
@ c:\Users\colin\OneDrive-TUM\Code\Julia\MMDS.jl\scripts\mwe2.jl:100
``````

Not a solution, but with the code above, the error can be reproduced by this code, without Flux:

``````julia> Flux.gradient(model -> sum(model(t)), model)  # same error
ERROR: MethodError: no method matching Float64(::Matrix{Float64})

julia> using DifferentialEquations, SciMLSensitivity, Zygote

julia> p = get_p(model); summary(p)
"7-element Vector{Matrix{Float64}}"

julia> x0 = get_x₀(model); summary(x0)
"5×6 Matrix{Float64}"

julia> function inner(t, p, x₀; alg=RK4(), kwargs...)
# p = get_p(a)
# x₀ = get_x₀(a)
tspan = (t[1], t[end])
prob = ODEProblem(ode, x₀, tspan, p)
solve(prob, alg; saveat=t, kwargs...) |> Array |> sum
# similar error with solve(prob, alg; saveat=t, kwargs...)[2][3]
end
inner (generic function with 1 method)

julia> inner(t, p, x0)  # forward pass runs
116.92583755043069

ERROR: MethodError: no method matching Float64(::Matrix{Float64})
Stacktrace:
[1] convert
@ ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:435 [inlined]
...
[7] (::SciMLSensitivity.var"#333#342"{…})()
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/rXkM4/src/concrete_solve.jl:911
...
[17] #solve#41
@ ~/.julia/packages/DiffEqBase/8vI1R/src/solve.jl:1003 [inlined]
[18] (::Zygote.Pullback{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0

julia> Enzyme.autodiff(Reverse, inner, Active, Duplicated(t, zero(t)), Duplicated(p, zero.(p)), Duplicated(x0, zero(x0)))
ERROR: Enzyme execution failed.
Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate (true, true, iterate, Core.apply_type, 7, 6)
Stacktrace:
[1] signature_type
@ ./reflection.jl:962
[2] _methods
@ ./reflection.jl:1020
Stacktrace:
[1] throwerr(cstr::Cstring)
@ Enzyme.Compiler ~/.julia/compiled/v1.10/Enzyme/G1p5n_ppatl.dylib:-1
``````