# EnsembleProblem with SteadyStateProblem in Zygote

Is there a premade parallel version of SteadyStateProblem solving (I'm currently using a for loop)? I assume there isn't, because EnsembleProblem expects a tspan, which I don't provide. My goal is to obtain the steady-state solutions, with adjoints, for training a neural ODE. Thanks!

It should work. Can you share an MWE?

``````julia
using DifferentialEquations
using SciMLSensitivity
using Plots
using Zygote

# Some model with multiple steady states (every point with y = 0 is one)
function some_model(u, p, t)
    dxdt = p[1]*u[1]*u[2]
    dydt = -p[2]*u[2]
    return [dxdt, dydt]
end

# Specify the initial conditions for each simulation
function prob_func(prob, i, repeat)
    remake(prob, u0=x0[i, :])
end

x0 = hcat(rand(10), ones(10)) # Initialize some random initial conditions
foop = [1., 1.] # Parameters
tspan = LinRange(0, 50, 25)

# Do some random gradients just to make sure code works for parametrizing stuff
Zygote.gradient(foop) do p
    foo_prob = ODEProblem(some_model, x0[1, :], (tspan[1], tspan[end]), p) # Works
    foo_ensemble = EnsembleProblem(foo_prob, prob_func=prob_func)
    foo_sol = solve(
        foo_ensemble,
        Tsit5(),
        trajectories=size(x0, 1);
        saveat = tspan,
    )
    Zygote.ignore_derivatives() do
        plot(foo_sol) # Plot to see what is going on
    end

    # Some random function
    foo_val = 0.
    for some_sol in foo_sol.u
        _some_sol_mat = mapreduce(permutedims, vcat, some_sol.u)
        foo_val += sum(abs2, _some_sol_mat)
    end
    @show foo_val
    return foo_val
end

# The following doesn't work
Zygote.gradient(foop) do p
    foo_SS_prob = SteadyStateProblem(some_model, x0[1, :], p)
    foo_SS_ensemble = EnsembleProblem(foo_SS_prob, prob_func=prob_func)
    foo_SS_sol = solve(
        foo_SS_ensemble,
        DynamicSS(Tsit5(), abstol=1e-4, reltol=1e-3, tspan=tspan[end]),
        trajectories=size(x0, 1);
    )
    # @show Array(foo_SS_sol.u[1])

    # Some random function
    foo_val = 0.
    for some_sol in foo_SS_sol.u
        foo_val += sum(abs2, some_sol)
    end
    @show foo_val
    return foo_val
end
``````

I'm not too familiar with the internals of AD packages, but using Zygote seems to break the ensemble problem. Without gradients, it works just fine. The error message is "ERROR: type SteadyStateProblem has no field tspan".

Thanks! I see. In general Zygote works with ensemble problems (it’s tested with ODEs, SDEs, DDEs, DAEs), but this combination seems to require something extra. Can you open an issue on SciMLSensitivity.jl? I can solve this by the end of the week.

Thanks! I opened an issue on SciMLSensitivity.jl. I have a few more questions (maybe these should be in a new thread).

1. The solver can speed up or slow down depending on the parametrization of the neural network. That's fine, but it sometimes hangs for a long time (on CPU, not GPU) without showing an error message (hard to make an MWE for this). With the standard Tsit5 solver, this can be alleviated by switching to lower-order solvers; the cause could be stiffness, divergence, or user error. I don't think it's stiffness, because AutoTsit5(Rosenbrock23()) hangs at the same iteration (I used fixed RNG seeds in my actual implementation). Tightening the tolerances (lower abstol and reltol values) also helps, but doing so slows down training. If I want to trade off accuracy for training speed, should I use maxiters to prevent hanging in this situation?

2. I want to simplify the dynamics to avoid stiffness and extra function evaluations, so I tried implementing the method from *Opening the Blackbox: Accelerating Neural Differential Equations by Regularizing Internal Solver Heuristics*, but it fails with the stacktrace below when running the MWE below. Maybe I didn't read the documentation carefully enough. How do I implement the method from the paper?
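A minimal sketch of the `maxiters` idea from question 1, assuming a SciMLBase recent enough to have the `ReturnCode` enum (older releases compared `sol.retcode` against symbols such as `:Success`). `rhs` is a hypothetical stand-in for the neural network right-hand side; the point is that capping the step count turns a hang into an early exit whose return code can be checked before the solution is used.

```julia
using DifferentialEquations

# Hypothetical stand-in for a neural ODE right-hand side (simple decay)
rhs(u, p, t) = p[1] .* u

prob = ODEProblem(rhs, [1.0], (0.0, 100.0), [-1.0])

# Cap the number of solver steps so a bad parametrization cannot stall training;
# tight tolerances over a long tspan guarantee this solve needs more than 5 steps
sol = solve(prob, Tsit5(); abstol = 1e-10, reltol = 1e-10, maxiters = 5)

# Check the return code instead of trusting a truncated solution
converged = sol.retcode == ReturnCode.Success
```

In a training loop, one option is to skip the batch or return a large loss when the solve is truncated; whether that beats tightening tolerances depends on the model.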

Thanks!
Note: I use Zygote for all of this since I have a custom training loop, and I need the gradients for a customized AdaBelief optimization with decoupled weight decay (AdaBelief in Optim.jl doesn’t have decoupled weight decay in the function signature). If there is a better option, I can explore that too.
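For the custom optimizer mentioned in the note, here is a plain-Julia sketch of one AdaBelief step with decoupled (AdamW-style) weight decay. The names `AdaBeliefW` and `adabeliefw_step!` are hypothetical, not an API from Optim.jl or any other package. The update tracks the second moment of `g - m` rather than of `g` (the AdaBelief rule), and applies the decay term `η * λ * θ` directly to the parameters instead of adding `λ * θ` to the gradient, which is what "decoupled" means here.

```julia
# Hypothetical optimizer state; names are illustrative, not a library API.
mutable struct AdaBeliefW
    η::Float64            # learning rate
    β1::Float64           # decay for the first moment m
    β2::Float64           # decay for the "belief" s = EMA of (g - m)^2
    ϵ::Float64            # numerical stabilizer
    λ::Float64            # decoupled weight-decay coefficient
    m::Vector{Float64}
    s::Vector{Float64}
    t::Int                # step counter for bias correction
end

AdaBeliefW(n::Integer; η = 1e-3, β1 = 0.9, β2 = 0.999, ϵ = 1e-16, λ = 1e-2) =
    AdaBeliefW(η, β1, β2, ϵ, λ, zeros(n), zeros(n), 0)

# One in-place update of θ given the loss gradient g
function adabeliefw_step!(opt::AdaBeliefW, θ::Vector{Float64}, g::Vector{Float64})
    η, β1, β2, ϵ, λ = opt.η, opt.β1, opt.β2, opt.ϵ, opt.λ
    m, s = opt.m, opt.s           # aliases: broadcasts below update opt.m / opt.s in place
    opt.t += 1
    @. m = β1 * m + (1 - β1) * g
    @. s = β2 * s + (1 - β2) * (g - m)^2 + ϵ
    bc1 = 1 - β1^opt.t            # bias corrections
    bc2 = 1 - β2^opt.t
    # Adaptive step uses only g; λ * θ is applied to θ directly (decoupled decay)
    @. θ -= η * ((m / bc1) / (sqrt(s / bc2) + ϵ) + λ * θ)
    return θ
end
```

In a custom loop, `g` would come from something like `Zygote.gradient(loss, θ)[1]`.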

Stacktrace

``````
ERROR: ArgumentError: new: too few arguments (expected 48)
Stacktrace:
[1] __new__(::Type, ::ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing,
typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, ::Vararg{Any})
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\tools\builtins.jl:9
@ C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:294 [inlined]
[3] adjoint(::Zygote.Context, ::typeof(Zygote.__new__), ::Type, ::ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, ::Vector{Float64}, ::Nothing, ::Vector{Vector{Float64}}, ::Vararg{Any})
@ Zygote .\none:0
[4] _pullback
[5] _pullback
@ C:\Users\kevin\.julia\packages\OrdinaryDiffEq\FhKjw\src\integrators\type.jl:144 [inlined]
[6] _pullback(::Zygote.Context, ::Type{OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Vector{Float64}, Nothing, Float64, Vector{Float64}, Float64, Float64, Float64, Float64, Vector{Vector{Float64}}, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, OrdinaryDiffEq.DEOptions{Float64, Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#30#31", DiffEqCallbacks.SavingAffect{var"#43#47", Float64, Float64, DataStructures.BinaryMinHeap{Float64}, 
Vector{Float64}}, typeof(DiffEqCallbacks.saving_initialize), typeof(SciMLBase.FINALIZE_DEFAULT)}}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, LinRange{Float64, Int64}, Tuple{}}, Vector{Float64}, Float64, Nothing, OrdinaryDiffEq.DefaultInit}}, ::ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, ::Vector{Float64}, ::Nothing, ::Vector{Vector{Float64}}, ::Float64, ::Float64, ::ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46",
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, ::Vector{Float64}, ::Vector{Float64}, ::Vector{Float64}, ::Nothing, ::Float64, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Float64, ::Bool, ::Float64, ::Float64, ::Float64, ::Float64, ::Float64, ::Float64, ::Float64, ::Float64, ::Int64, ::Int64, ::Int64, ::Int64, ::OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, ::Nothing, ::Int64, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::Int64, ::Float64, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::OrdinaryDiffEq.DEOptions{Float64, Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#30#31", DiffEqCallbacks.SavingAffect{var"#43#47", Float64, Float64, DataStructures.BinaryMinHeap{Float64}, Vector{Float64}}, typeof(DiffEqCallbacks.saving_initialize), typeof(SciMLBase.FINALIZE_DEFAULT)}}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, LinRange{Float64, Int64}, Tuple{}}, ::DiffEqBase.DEStats, ::OrdinaryDiffEq.DefaultInit)
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[7] _pullback
@ C:\Users\kevin\.julia\packages\OrdinaryDiffEq\FhKjw\src\solve.jl:413 [inlined]
[8] _pullback(::Zygote.Context, ::OrdinaryDiffEq.var"##__init#503", ::LinRange{Float64, Int64}, ::Tuple{}, ::Tuple{}, ::Nothing, ::Bool, ::Bool, ::Bool, ::Nothing, ::DiscreteCallback{DiffEqCallbacks.var"#30#31", DiffEqCallbacks.SavingAffect{var"#43#47", Float64, Float64, DataStructures.BinaryMinHeap{Float64}, Vector{Float64}}, typeof(DiffEqCallbacks.saving_initialize), typeof(SciMLBase.FINALIZE_DEFAULT)}, ::Bool, ::Bool, ::Float64, ::Nothing, ::Float64, ::Bool, ::Bool, ::Rational{Int64}, ::Float64, ::Float64, ::Rational{Int64}, ::Int64, ::Int64, ::Int64, ::Nothing, ::Nothing, ::Rational{Int64}, ::Nothing, ::Bool, ::Int64, ::Int64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(LinearAlgebra.opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::OrdinaryDiffEq.DefaultInit, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(SciMLBase.__init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{Val{true}})    @ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[9] _pullback
@ C:\Users\kevin\.julia\packages\OrdinaryDiffEq\FhKjw\src\solve.jl:67 [inlined]
[10] _pullback(::Zygote.Context, ::SciMLBase.var"#__init##kw", ::NamedTuple{(:abstol, :reltol, :saveat, :callback), Tuple{Float64, Float64, LinRange{Float64, Int64}, DiscreteCallback{DiffEqCallbacks.var"#30#31", DiffEqCallbacks.SavingAffect{var"#43#47", Float64, Float64, DataStructures.BinaryMinHeap{Float64}, Vector{Float64}}, typeof(DiffEqCallbacks.saving_initialize), typeof(SciMLBase.FINALIZE_DEFAULT)}}}, ::typeof(SciMLBase.__init), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{Val{true}})
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
--- the last 2 lines are repeated 1 more time ---
[13] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:814
@ C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:204 [inlined]
[15] _pullback
[16] _pullback
@ C:\Users\kevin\.julia\packages\OrdinaryDiffEq\FhKjw\src\solve.jl:4 [inlined]
[17] _pullback(::Zygote.Context, ::OrdinaryDiffEq.var"##__solve#502", ::Base.Pairs{Symbol, Any, NTuple{4, Symbol}, NamedTuple{(:abstol, :reltol, :saveat, :callback), Tuple{Float64, Float64, LinRange{Float64,
Int64}, DiscreteCallback{DiffEqCallbacks.var"#30#31", DiffEqCallbacks.SavingAffect{var"#43#47", Float64, Float64, DataStructures.BinaryMinHeap{Float64}, Vector{Float64}}, typeof(DiffEqCallbacks.saving_initialize), typeof(SciMLBase.FINALIZE_DEFAULT)}}}}, ::typeof(SciMLBase.__solve), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol,
Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
--- the last 5 lines are repeated 1 more time ---
[23] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:814
@ C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:204 [inlined]
[25] _pullback
[26] _pullback
@ C:\Users\kevin\.julia\packages\DiffEqBase\KouNZ\src\solve.jl:437 [inlined]
[27] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve_call#28", ::Bool, ::KeywordArgError, ::Base.Pairs{Symbol, Any, NTuple{4, Symbol}, NamedTuple{(:abstol, :reltol, :saveat, :callback), Tuple{Float64, Float64, LinRange{Float64, Int64}, DiscreteCallback{DiffEqCallbacks.var"#30#31", DiffEqCallbacks.SavingAffect{var"#43#47", Float64, Float64, DataStructures.BinaryMinHeap{Float64}, Vector{Float64}}, typeof(DiffEqCallbacks.saving_initialize), typeof(SciMLBase.FINALIZE_DEFAULT)}}}}, ::typeof(DiffEqBase.solve_call), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED),
Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[28] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:814
@ C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:204 [inlined]
[30] _pullback
[31] _pullback
@ C:\Users\kevin\.julia\packages\DiffEqBase\KouNZ\src\solve.jl:409 [inlined]
[32] _pullback(::Zygote.Context, ::DiffEqBase.var"#solve_call##kw", ::NamedTuple{(:abstol, :reltol, :saveat, :callback), Tuple{Float64, Float64, LinRange{Float64, Int64}, DiscreteCallback{DiffEqCallbacks.var"#30#31", DiffEqCallbacks.SavingAffect{var"#43#47", Float64, Float64, DataStructures.BinaryMinHeap{Float64}, Vector{Float64}}, typeof(DiffEqCallbacks.saving_initialize), typeof(SciMLBase.FINALIZE_DEFAULT)}}}, ::typeof(DiffEqBase.solve_call), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[33] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:814
@ C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:204 [inlined]
[35] _pullback
[36] _pullback
@ C:\Users\kevin\.julia\packages\DiffEqBase\KouNZ\src\solve.jl:780 [inlined]
[37] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve_up#34", ::Base.Pairs{Symbol, Any, NTuple{4, Symbol}, NamedTuple{(:abstol, :reltol, :saveat, :callback), Tuple{Float64, Float64, LinRange{Float64, Int64}, DiscreteCallback{DiffEqCallbacks.var"#30#31", DiffEqCallbacks.SavingAffect{var"#43#47", Float64, Float64, DataStructures.BinaryMinHeap{Float64}, Vector{Float64}}, typeof(DiffEqCallbacks.saving_initialize), typeof(SciMLBase.FINALIZE_DEFAULT)}}}}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::SensitivityADPassThrough, ::Vector{Float64}, ::Vector{Float64}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[38] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:814
@ C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:204 [inlined]
[40] _pullback
[41] _pullback
@ C:\Users\kevin\.julia\packages\DiffEqBase\KouNZ\src\solve.jl:765 [inlined]
[42] _pullback(::Zygote.Context, ::DiffEqBase.var"#solve_up##kw", ::NamedTuple{(:abstol, :reltol, :saveat, :callback), Tuple{Float64, Float64, LinRange{Float64, Int64}, DiscreteCallback{DiffEqCallbacks.var"#30#31", DiffEqCallbacks.SavingAffect{var"#43#47", Float64, Float64, DataStructures.BinaryMinHeap{Float64}, Vector{Float64}}, typeof(DiffEqCallbacks.saving_initialize), typeof(SciMLBase.FINALIZE_DEFAULT)}}}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::SensitivityADPassThrough, ::Vector{Float64}, ::Vector{Float64}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[43] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:814
@ C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:204 [inlined]
[45] _pullback
[46] _pullback
@ C:\Users\kevin\.julia\packages\DiffEqBase\KouNZ\src\solve.jl:760 [inlined]
[47] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve#33", ::SensitivityADPassThrough, ::Nothing, ::Nothing, ::Base.Pairs{Symbol, Any, NTuple{4, Symbol}, NamedTuple{(:abstol, :reltol, :saveat, :callback), Tuple{Float64, Float64, LinRange{Float64, Int64}, DiscreteCallback{DiffEqCallbacks.var"#30#31", DiffEqCallbacks.SavingAffect{var"#43#47", Float64, Float64, DataStructures.BinaryMinHeap{Float64}, Vector{Float64}}, typeof(DiffEqCallbacks.saving_initialize), typeof(SciMLBase.FINALIZE_DEFAULT)}}}}, ::typeof(solve), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[48] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:814
@ C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:204 [inlined]
[50] _pullback
[51] _pullback
@ C:\Users\kevin\.julia\packages\DiffEqBase\KouNZ\src\solve.jl:755 [inlined]
[52] _pullback(::Zygote.Context, ::CommonSolve.var"#solve##kw", ::NamedTuple{(:abstol, :reltol, :saveat, :sensealg, :callback), Tuple{Float64, Float64, LinRange{Float64, Int64}, SensitivityADPassThrough, DiscreteCallback{DiffEqCallbacks.var"#30#31", DiffEqCallbacks.SavingAffect{var"#43#47", Float64, Float64, DataStructures.BinaryMinHeap{Float64}, Vector{Float64}}, typeof(DiffEqCallbacks.saving_initialize), typeof(SciMLBase.FINALIZE_DEFAULT)}}}, ::typeof(solve), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, typeof(some_model), LinearAlgebra.UniformScaling{Bool}, Nothing, var"#foo_basic_tgrad#46", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[53] _pullback
@ c:\Users\kevin\.julia\dev\MiNN\examples\foo.jl:29 [inlined]
[54] _pullback(ctx::Zygote.Context, f::var"#42#45", args::Vector{Float64})
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
[55] _pullback(f::Function, args::Vector{Float64})
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:34
[56] pullback(f::Function, args::Vector{Float64})
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:40
@ Zygote C:\Users\kevin\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:75
``````

MWE

``````julia
using DifferentialEquations
using SciMLSensitivity
using Plots
using Zygote

# Some model with multiple steady states
function some_model(u, p, t)
    dxdt = p[1]*u[1]*u[2]
    dydt = -p[2]*u[2]
    return [dxdt, dydt]
end

x0 = hcat(rand(10), ones(10)) # Initialize some random initial conditions
foop = [1., 1.] # Parameters
tspan = LinRange(0, 50, 25)

# From paper Opening the Blackbox: Accelerating Neural Differential Equations by Regularizing Internal Solver Heuristics
foo_basic_tgrad(u,p,t) = zero(u) # From NODE code in DiffEqFlux
foo_ff = ODEFunction{false}(some_model; tgrad=foo_basic_tgrad)

# Do some random gradients just to make sure code works for parametrizing stuff
Zygote.gradient(foop) do p
    foo_sv = SavedValues(eltype(tspan), eltype(p))
    foo_func = (u, t, integrator) -> integrator.EEst * integrator.dt # Error estimate times time step
    foo_svcb = SavingCallback(foo_func, foo_sv)
    foo_prob = ODEProblem{false}(foo_ff, x0[1, :], (tspan[1], tspan[end]), p)
    foo_sol = solve(
        foo_prob,
        Tsit5(),
        abstol=1e-6,
        reltol=1e-6,
        saveat=tspan;
        callback = foo_svcb
    )
    Zygote.ignore_derivatives() do
        plot(foo_sol) # Plot to see what is going on
    end

    # Some random function
    _some_sol_mat = mapreduce(permutedims, vcat, foo_sol.u)
    foo_val = sum(abs2, _some_sol_mat) + sum(foo_sv.saveval)
    @show foo_val
    return foo_val
end
``````

Solved, in a way. The issue is that for steady-state problems you should use `SteadyStateAdjoint`. PR #705 on SciML/SciMLSensitivity.jl ("Throw a better error for time-based adjoint on no-time problem", by ChrisRackauckas) addresses this by throwing an explicit error message that says so.
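A minimal sketch of that fix, assuming current SciMLSensitivity, SteadyStateDiffEq and Zygote versions in which an explicit `sensealg` on an ensemble solve is forwarded to each trajectory. The `decay_model` right-hand side and `p0` values are hypothetical: they give each trajectory a single, isolated steady state `u* = p[1] / p[2]`. That matters because `SteadyStateAdjoint` solves a linear system with the Jacobian at the steady state, so that Jacobian must be nonsingular; in `some_model` from the MWE, every point with `y = 0` is a steady state and the Jacobian there is singular.

```julia
using DifferentialEquations, SciMLSensitivity, Zygote

# Hypothetical model with one isolated, stable steady state u* = p[1] / p[2]
decay_model(u, p, t) = p[1] .- p[2] .* u

x0 = hcat(rand(10), ones(10))  # one initial condition per row
p0 = [1.0, 2.0]

function steady_state_loss(p)
    prob = SteadyStateProblem(decay_model, x0[1, :], p)
    # Each trajectory starts from a different row of x0
    prob_func(prob, i, repeat) = remake(prob, u0 = x0[i, :])
    ensemble = EnsembleProblem(prob, prob_func = prob_func)
    sol = solve(ensemble, DynamicSS(Tsit5()), EnsembleThreads();
                trajectories = size(x0, 1),
                sensealg = SteadyStateAdjoint())  # implicit-function adjoint at u*
    # Sum of squared steady states over all trajectories
    return sum(s -> sum(abs2, s.u), sol.u)
end

loss = steady_state_loss(p0)
grad, = Zygote.gradient(steady_state_loss, p0)
```

For this model the loss is 20 (p[1]/p[2])^2, so at `p0` it should be close to 5 and the gradient close to [10, -5], which makes a quick sanity check on the adjoint. If the threaded version misbehaves, `EnsembleSerial()` is a simpler baseline for debugging.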