MethodError in Custom Flux model

Hey everyone,

I have been trying to get my custom model running, which should be able to optimize both for a set of paramters and some initial conditions of an ODE model. The ODE’s dynamics are defined in ode(), the model is AutoODE.

I am new to Julia and have a hard time interpreting the error, or rather locating the cause of it. Any help to find the culprit would be much appreciated :slight_smile:


module myModule

using Flux
using DifferentialEquations


abstract type AbstractAutoODEModel end


function create_initial_conditions(N_states::Int=56)
    # parameters
    β = rand(N_states, 1)
    γ = rand(N_states, 1)
    μ = rand(N_states, 1)
    σ = rand(N_states, 1)
    a = rand(N_states, 1)
    b = rand(N_states, 1)
    A = rand(N_states, N_states)

    θ = [β, γ, μ, σ, a, b, A]

    # initial conditions
    S₀ = fill(0.5, N_states, 1)
    E₀ = fill(0.5, N_states, 1)
    U₀ = @. ((1 - μ) * σ) * E₀

    u₀ = [S₀ E₀ U₀]

    return u₀, θ
end


function ode(u, p, t)
    β, γ, μ, σ, a, b, A = p
    S, E, U, I, R, D = eachcol(u)

    transm = A * (I + E)
    dS = @. - β * transm * S
    dE = @. β * transm * S - σ*E
    dU = @. (1-γ)*σ*E

    dI = @. μ*σ*E - γ*I
    dR = @. γ*I
    r = @. a*t + b
    dD = @. r * dR

    du = [dS dE dU dI dR dD]
    return du
end


struct AutoODE <: AbstractAutoODEModel
    y₀      # constant initial conditions
    u₀      # learnable initial conditions

    θ       # learnable parameters
    q       # constant parameters (optional)

    f::Function

    # constructor
    function AutoODE(y₀; q=nothing)
        N_states = size(y₀, 1)
        u₀, θ = create_initial_conditions(N_states)

        new(y₀, u₀, θ, q, ode)
    end

end


@Flux.layer AbstractAutoODEModel trainable=(u₀, θ,)


get_p(a::AbstractAutoODEModel) = isnothing(a.q) ? a.θ : [a.θ..., a.q...]
get_x₀(a::AbstractAutoODEModel) = isnothing(a.u₀) ? a.y₀ : [a.u₀ a.y₀]


function (a::AbstractAutoODEModel)(t; alg=RK4(), kwargs...)
    p = get_p(a)
    x₀ = get_x₀(a)

    tspan = (t[1], t[end])
    prob = ODEProblem(a.f, x₀, tspan, p)
    return Array(solve(prob, alg; saveat=t, kwargs...))
end

end


using .myModule
using Flux

T = 7
y = rand(5, 3, T)
t = collect(1:T)
model = myModule.AutoODE(y[:, :, 1])

Flux.gradient(model -> Flux.mse(model(t)[:, 4:6, :], y), model)

Full error

ERROR: MethodError: no method matching Float64(::Matrix{Float64})

Closest candidates are:
  (::Type{T})(::VectorizationBase.Double{T}) where T<:Union{Float16, Float32, Float64, VectorizationBase.Vec{<:Any, <:Union{Float16, Float32, Float64}}, VectorizationBase.VecUnroll{var"#s45", var"#s44", var"#s43", V} where {var"#s45", var"#s44", var"#s43"<:Union{Float16, Float32, Float64}, V<:Union{Bool, Float16, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, SIMDTypes.Bit, VectorizationBase.AbstractSIMD{var"#s44", var"#s43"}}}}
   @ VectorizationBase C:\Users\colin\.julia\packages\VectorizationBase\xE5Tx\src\special\double.jl:111
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
   @ Base char.jl:50
  (::Type{T})(::Base.TwicePrecision) where T<:Number
   @ Base twiceprecision.jl:266
  ...

Stacktrace:
  [1] convert
    @ C:\Users\colin\.julia\packages\ForwardDiff\PcZ48\src\dual.jl:435 [inlined]
  [2] _broadcast_getindex_evalf
    @ .\broadcast.jl:683 [inlined]
  [3] _broadcast_getindex
    @ .\broadcast.jl:666 [inlined]
  [4] getindex
    @ .\broadcast.jl:610 [inlined]
  [5] copy
    @ .\broadcast.jl:912 [inlined]
  [6] materialize(bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(convert), Tuple{Base.RefValue{Type{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Float64}, Float64, 12}}}, Vector{Matrix{Float64}}}})
    @ Base.Broadcast .\broadcast.jl:873
  [7] (::SciMLSensitivity.var"#333#342"{0, Array{Float64, 3}, Vector{Int64}, Base.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:verbose,), Tuple{Bool}}}, SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, SciMLSensitivity.ForwardDiffSensitivity{0, nothing}, Matrix{Float64}, Vector{Matrix{Float64}}, Tuple{}, Vector{Float64}})()
    @ SciMLSensitivity C:\Users\colin\.julia\packages\SciMLSensitivity\4Ah3r\src\concrete_solve.jl:888
  [8] unthunk
    @ C:\Users\colin\.julia\packages\ChainRulesCore\zgT0R\src\tangent_types\thunks.jl:204 [inlined]
  [9] wrap_chainrules_output
    @ C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:110 [inlined]
 [10] map
    @ .\tuple.jl:275 [inlined]
 [11] map (repeats 3 times)
    @ .\tuple.jl:276 [inlined]
 [12] wrap_chainrules_output
    @ C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:111 [inlined]
 [13] ZBack
    @ C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:211 [inlined]
 [14] kw_zpullback
    @ C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\chainrules.jl:237 [inlined]
 [15] #291
    @ C:\Users\colin\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [16] (::Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.var"#kw_zpullback#53"{SciMLSensitivity.var"#forward_sensitivity_backpass#338"{0, Vector{Int64}, Base.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:verbose,), Tuple{Bool}}}, SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, SciMLSensitivity.ForwardDiffSensitivity{0, nothing}, Matrix{Float64}, Vector{Matrix{Float64}}, SciMLBase.ChainRulesOriginator, Tuple{}, Vector{Float64}}}}})(Δ::Array{Float64, 3})
    @ Zygote C:\Users\colin\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [17] Pullback
    @ C:\Users\colin\.julia\packages\DiffEqBase\eTCPy\src\solve.jl:980 [inlined]
 [18] (::Zygote.Pullback{Tuple{DiffEqBase.var"##solve#40", Nothing, Nothing, Nothing, Val{true}, Base.Pairs{Symbol, Vector{Int64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{Vector{Int64}}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any})(Δ::Array{Float64, 3})
    @ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [19] #291
    @ C:\Users\colin\.julia\packages\Zygote\jxHJc\src\lib\lib.jl:206 [inlined]
 [20] (::Zygote.var"#2173#back#293"{Zygote.var"#291#292"{Tuple{NTuple{7, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{DiffEqBase.var"##solve#40", Nothing, Nothing, Nothing, Val{true}, Base.Pairs{Symbol, Vector{Int64}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{Vector{Int64}}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any}}})(Δ::Array{Float64, 3})
    @ Zygote C:\Users\colin\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [21] Pullback
    @ C:\Users\colin\.julia\packages\DiffEqBase\eTCPy\src\solve.jl:970 [inlined]
 [22] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any})(Δ::Array{Float64, 3})
    @ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [23] Pullback
    @ c:\Users\colin\OneDrive-TUM\Code\Julia\MMDS.jl\scripts\mwe2.jl:86 [inlined]
 [24] (::Zygote.Pullback{Tuple{Main.myModule.var"##_#3", OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(lastindex), Vector{Int64}}, Tuple{Zygote.ZBack{ChainRules.var"#eachindex_pullback#378"{Tuple{IndexLinear, Vector{Int64}}}}, Zygote.Pullback{Tuple{typeof(last), Base.OneTo{Int64}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:stop, Zygote.Context{false}, Base.OneTo{Int64}, Int64}}, Zygote.ZBack{Zygote.var"#convert_pullback#330"}}}, Zygote.Pullback{Tuple{Type{IndexLinear}}, Tuple{}}}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:f, Zygote.Context{false}, Main.myModule.AutoODE, typeof(Main.myModule.ode)}}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_x₀), Main.myModule.AutoODE}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:data, Zygote.Context{false}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, NamedTuple{(), Tuple{}}}, Tuple{}}}}, RecursiveArrayToolsZygoteExt.var"#182#back#103"{RecursiveArrayToolsZygoteExt.var"#99#102"}, Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,)}}, Tuple{Vector{Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}}, Tuple{Vector{Int64}}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_p), Main.myModule.AutoODE}, Any}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{SciMLBaseChainRulesCoreExt.var"#ODEProblemAdjoint#14"}}})(Δ::Array{Float64, 3})
    @ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [25] Pullback
    @ c:\Users\colin\OneDrive-TUM\Code\Julia\MMDS.jl\scripts\mwe2.jl:80 [inlined]
 [26] (::Zygote.Pullback{Tuple{Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{OrdinaryDiffEq.RK4}}, Tuple{}}, Zygote.Pullback{Tuple{Main.myModule.var"##_#3", OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(lastindex), Vector{Int64}}, Tuple{Zygote.ZBack{ChainRules.var"#eachindex_pullback#378"{Tuple{IndexLinear, Vector{Int64}}}}, Zygote.Pullback{Tuple{typeof(last), Base.OneTo{Int64}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:stop, Zygote.Context{false}, Base.OneTo{Int64}, Int64}}, Zygote.ZBack{Zygote.var"#convert_pullback#330"}}}, Zygote.Pullback{Tuple{Type{IndexLinear}}, Tuple{}}}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:f, Zygote.Context{false}, Main.myModule.AutoODE, typeof(Main.myModule.ode)}}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_x₀), Main.myModule.AutoODE}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:data, Zygote.Context{false}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, NamedTuple{(), Tuple{}}}, Tuple{}}}}, RecursiveArrayToolsZygoteExt.var"#182#back#103"{RecursiveArrayToolsZygoteExt.var"#99#102"}, Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,)}}, Tuple{Vector{Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}}, Tuple{Vector{Int64}}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_p), Main.myModule.AutoODE}, Any}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{SciMLBaseChainRulesCoreExt.var"#ODEProblemAdjoint#14"}}}}})(Δ::Array{Float64, 3})
    @ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [27] Pullback
    @ c:\Users\colin\OneDrive-TUM\Code\Julia\MMDS.jl\scripts\mwe2.jl:100 [inlined]
 [28] (::Zygote.Pullback{Tuple{var"#21#22", Main.myModule.AutoODE}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.Losses.mse), Array{Float64, 3}, Array{Float64, 3}}, Tuple{Zygote.Pullback{Tuple{Flux.Losses.var"##mse#14", typeof(Statistics.mean), typeof(Flux.Losses.mse), Array{Float64, 3}, Array{Float64, 3}}, Tuple{Zygote.ZBack{Flux.Losses.var"#_check_sizes_pullback#12"}, Zygote.var"#3976#back#1289"{Zygote.var"#1285#1288"{Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Array{Float64, 3}}, Tuple{}}, Zygote.var"#3768#back#1191"{Zygote.var"#1187#1190"{Array{Float64, 3}, Array{Float64, 3}}}, Zygote.ZBack{ChainRules.var"#mean_pullback#1820"{Int64, ChainRules.var"#sum_pullback#1632"{Colon, Array{Float64, 3}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}}}}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Vector{Int64}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Array{Float64, 3}, Tuple{Colon, UnitRange{Int64}, Colon}, Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, ChainRulesCore.NoTangent}}}, Zygote.ZBack{ChainRules.var"#:_pullback#278"{Tuple{Int64, Int64}}}, Zygote.Pullback{Tuple{Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{OrdinaryDiffEq.RK4}}, Tuple{}}, Zygote.Pullback{Tuple{Main.myModule.var"##_#3", OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(lastindex), Vector{Int64}}, Tuple{Zygote.ZBack{ChainRules.var"#eachindex_pullback#378"{Tuple{IndexLinear, Vector{Int64}}}}, Zygote.Pullback{Tuple{typeof(last), Base.OneTo{Int64}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:stop, Zygote.Context{false}, Base.OneTo{Int64}, Int64}}, Zygote.ZBack{Zygote.var"#convert_pullback#330"}}}, Zygote.Pullback{Tuple{Type{IndexLinear}}, Tuple{}}}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:f, Zygote.Context{false}, Main.myModule.AutoODE, typeof(Main.myModule.ode)}}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_x₀), Main.myModule.AutoODE}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:data, Zygote.Context{false}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, NamedTuple{(), Tuple{}}}, Tuple{}}}}, RecursiveArrayToolsZygoteExt.var"#182#back#103"{RecursiveArrayToolsZygoteExt.var"#99#102"}, Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,)}}, Tuple{Vector{Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}}, Tuple{Vector{Int64}}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_p), Main.myModule.AutoODE}, Any}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{SciMLBaseChainRulesCoreExt.var"#ODEProblemAdjoint#14"}}}}}}})(Δ::Float64)
    @ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface2.jl:0
 [29] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#21#22", Main.myModule.AutoODE}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.Losses.mse), Array{Float64, 3}, Array{Float64, 3}}, Tuple{Zygote.Pullback{Tuple{Flux.Losses.var"##mse#14", typeof(Statistics.mean), typeof(Flux.Losses.mse), Array{Float64, 3}, Array{Float64, 3}}, Tuple{Zygote.ZBack{Flux.Losses.var"#_check_sizes_pullback#12"}, Zygote.var"#3976#back#1289"{Zygote.var"#1285#1288"{Array{Float64, 3}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Array{Float64, 3}}, Tuple{}}, Zygote.var"#3768#back#1191"{Zygote.var"#1187#1190"{Array{Float64, 3}, Array{Float64, 3}}}, Zygote.ZBack{ChainRules.var"#mean_pullback#1820"{Int64, ChainRules.var"#sum_pullback#1632"{Colon, Array{Float64, 3}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}}}}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Array{Float64, 3}}}, Zygote.var"#1990#back#194"{Zygote.var"#190#193"{Zygote.Context{false}, GlobalRef, Vector{Int64}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Array{Float64, 3}, Tuple{Colon, UnitRange{Int64}, Colon}, Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, ChainRulesCore.NoTangent}}}, Zygote.ZBack{ChainRules.var"#:_pullback#278"{Tuple{Int64, Int64}}}, Zygote.Pullback{Tuple{Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2370#back#419"{Zygote.var"#pairs_namedtuple_pullback#418"{(), NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Type{OrdinaryDiffEq.RK4}}, Tuple{}}, Zygote.Pullback{Tuple{Main.myModule.var"##_#3", OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Main.myModule.AutoODE, Vector{Int64}}, Tuple{Zygote.Pullback{Tuple{typeof(lastindex), Vector{Int64}}, Tuple{Zygote.ZBack{ChainRules.var"#eachindex_pullback#378"{Tuple{IndexLinear, Vector{Int64}}}}, Zygote.Pullback{Tuple{typeof(last), Base.OneTo{Int64}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:stop, Zygote.Context{false}, Base.OneTo{Int64}, Int64}}, Zygote.ZBack{Zygote.var"#convert_pullback#330"}}}, Zygote.Pullback{Tuple{Type{IndexLinear}}, Tuple{}}}}, Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:f, Zygote.Context{false}, Main.myModule.AutoODE, typeof(Main.myModule.ode)}}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_x₀), Main.myModule.AutoODE}, Any}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{Zygote.var"#2184#back#303"{Zygote.var"#back#302"{:data, Zygote.Context{false}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{typeof(merge), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, NamedTuple{(), Tuple{}}}, Tuple{}}}}, RecursiveArrayToolsZygoteExt.var"#182#back#103"{RecursiveArrayToolsZygoteExt.var"#99#102"}, Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,)}}, Tuple{Vector{Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}}, Tuple{Vector{Int64}}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:saveat,), Tuple{Vector{Int64}}}, typeof(CommonSolve.solve), SciMLBase.ODEProblem{Matrix{Float64}, Tuple{Int64, Int64}, false, Vector{Matrix{Float64}}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.myModule.ode), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.RK4{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Any}, Zygote.Pullback{Tuple{typeof(Main.myModule.get_p), Main.myModule.AutoODE}, Any}, Zygote.ZBack{ChainRules.var"#getindex_pullback#1580"{Vector{Int64}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.ZBack{SciMLBaseChainRulesCoreExt.var"#ODEProblemAdjoint#14"}}}}}}}})(Δ::Float64)
    @ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:91
 [30] gradient(f::Function, args::Main.myModule.AutoODE)
    @ Zygote C:\Users\colin\.julia\packages\Zygote\jxHJc\src\compiler\interface.jl:148
 [31] top-level scope
    @ c:\Users\colin\OneDrive-TUM\Code\Julia\MMDS.jl\scripts\mwe2.jl:100

Not a solution, but with the code above, the error can be reproduced by this code, without Flux:

julia> Flux.gradient(model -> sum(model(t)), model)  # same error
ERROR: MethodError: no method matching Float64(::Matrix{Float64})

julia> using DifferentialEquations, SciMLSensitivity, Zygote

julia> p = get_p(model); summary(p)
"7-element Vector{Matrix{Float64}}"

julia> x0 = get_x₀(model); summary(x0)
"5×6 Matrix{Float64}"

julia> function inner(t, p, x₀; alg=RK4(), kwargs...)
           # p = get_p(a)
           # x₀ = get_x₀(a)
           tspan = (t[1], t[end])
           prob = ODEProblem(ode, x₀, tspan, p)
           solve(prob, alg; saveat=t, kwargs...) |> Array |> sum
           # similar error with solve(prob, alg; saveat=t, kwargs...)[2][3]
       end
inner (generic function with 1 method)

julia> inner(t, p, x0)  # forward pass runs
116.92583755043069

julia> Zygote.gradient(inner, t, p, x0)
ERROR: MethodError: no method matching Float64(::Matrix{Float64})
Stacktrace:
  [1] convert
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:435 [inlined]
...
  [6] materialize(bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{…}, Nothing, typeof(convert), Tuple{…}})
    @ Base.Broadcast ./broadcast.jl:903
  [7] (::SciMLSensitivity.var"#333#342"{…})()
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/rXkM4/src/concrete_solve.jl:911
...
 [17] #solve#41
    @ ~/.julia/packages/DiffEqBase/8vI1R/src/solve.jl:1003 [inlined]
 [18] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0

julia> Enzyme.autodiff(Reverse, inner, Active, Duplicated(t, zero(t)), Duplicated(p, zero.(p)), Duplicated(x0, zero(x0)))
ERROR: Enzyme execution failed.
Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate (true, true, iterate, Core.apply_type, 7, 6)
Stacktrace:
 [1] signature_type
   @ ./reflection.jl:962
 [2] _methods
   @ ./reflection.jl:1020
Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/compiled/v1.10/Enzyme/G1p5n_ppatl.dylib:-1