Error trying to ForwardDiff through an ODE solver

I am trying to use AD to differentiate the solution of an ODE. Various docs seem to indicate this is very well supported, so I was surprised to get an error.

For some additional color, I got a different error before adding
“ODEProblem{true, SciMLBase.FullSpecialize}”. I don’t quite understand what this does, but it was recommended in another discourse topic. I also get a slightly different error when switching between Tsit5() and Vern9().

Copy-pastable MWE:

using ModelingToolkit, DifferentialEquations
using ForwardDiff

"""
    ADTest()

Minimal reproduction: build a two-state linear ODE with ModelingToolkit, solve it
once, then ask ForwardDiff for the derivative of the solution with respect to the
end point of the integration span.

Returns the result of the `ForwardDiff.derivative` call.
"""
function ADTest()
    @parameters a b
    @variables t x1(t) x2(t) y1(t) y2(t)
    D = Differential(t)
    states = [x1, x2]
    parameters = [a, b]

    @named model = ODESystem([
            D(x1) ~ a * x1,
            D(x2) ~ b * x2,
        ], t, states, parameters)
    model = structural_simplify(model)
    # Declared but not used anywhere below in this MWE.
    measured_quantities = [
        y1 ~ x1,
        y2 ~ x2]

    ic = Dict(x1 => 1.0, x2 => 2.0)
    p_true = Dict(a => 2.0, b => 3.0)

    problem = ODEProblem(model, ic, [0.0, 1e-5], p_true)
    soln = ModelingToolkit.solve(problem, Tsit5(), abstol = 1e-14, reltol = 1e-14)
    display(soln(1e-5, idxs = [x1, x2]))

    # Re-solve on [0.0, new_t] and return [x1, x2] evaluated at new_t.
    function different_time(new_ic, new_params, new_t)
        # Builds a fresh problem from the symbolic model on every call.
        newprob = ODEProblem{true, SciMLBase.FullSpecialize}(
            model, new_ic, [0.0, new_t], new_params)
        new_soln = ModelingToolkit.solve(newprob, Tsit5(), abstol = 1e-14, reltol = 1e-14)
        # Evaluate the solution just computed; the outer `soln` only covers [0.0, 1e-5].
        return new_soln(new_t, idxs = [x1, x2])
    end
    display(different_time(ic, p_true, 2e-5))

    temp = ForwardDiff.derivative(s -> different_time(ic, p_true, s), 4e-5)
    return temp
end



ERROR: LoadError: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{var"#209#211"{var"#different_time#210"{ODESolution{…}, Num, Num}, Dict{Num, Float64}, Dict{Num, Float64}}, Float64}, Float64, 1})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:112

  [1] convert(::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#209#211"{var"#different_time#210"{ODESolution{…}, Num, Num}, Dict{Num, Float64}, Dict{Num, Float64}}, Float64}, Float64, 1})
    @ Base ./number.jl:7
  [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#209#211"{var"#different_time#210"{ODESolution{…}, Num, Num}, Dict{Num, Float64}, Dict{Num, Float64}}, Float64}, Float64, 1}, i1::Int64)
    @ Base ./array.jl:1021
  [3] macro expansion
    @ ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/initdt.jl:119 [inlined]
  [4] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [5] ode_determine_initdt(u0::Vector{…}, t::ForwardDiff.Dual{…}, tdir::ForwardDiff.Dual{…}, dtmax::ForwardDiff.Dual{…}, abstol::Float64, reltol::Float64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), prob::ODEProblem{…}, integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/initdt.jl:118
  [6] auto_dt_reset!
    @ ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/integrators/integrator_interface.jl:453 [inlined]
  [7] handle_dt!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/solve.jl:571
  [8] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::ForwardDiff.Dual{…}, dtmin::ForwardDiff.Dual{…}, dtmax::ForwardDiff.Dual{…}, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/solve.jl:533
  [9] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/solve.jl:11 [inlined]
 [10] #__solve#787
    @ ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/solve.jl:6 [inlined]
 [11] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/solve.jl:1 [inlined]
 [12] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/WyGjp/src/solve.jl:612
 [13] solve_call
    @ ~/.julia/packages/DiffEqBase/WyGjp/src/solve.jl:569 [inlined]
 [14] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/WyGjp/src/solve.jl:1080 [inlined]
 [15] solve_up
    @ ~/.julia/packages/DiffEqBase/WyGjp/src/solve.jl:1066 [inlined]
 [16] #solve#51
    @ ~/.julia/packages/DiffEqBase/WyGjp/src/solve.jl:1003 [inlined]
 [17] (::var"#different_time#210"{ODESolution{…}, Num, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::ForwardDiff.Dual{ForwardDiff.Tag{…}, Float64, 1})
    @ Main ~/learning/ODETests/PLI/MWE2.jl:32
 [18] #209
    @ ~/learning/ODETests/PLI/MWE2.jl:37 [inlined]
 [19] derivative(f::var"#209#211"{var"#different_time#210"{ODESolution{…}, Num, Num}, Dict{Num, Float64}, Dict{Num, Float64}}, x::Float64)
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/derivative.jl:14

We should probably update the title. This has nothing to do with diffing through the solver. The issue is that you are doing a symbolic code generation process in your loss function and trying to differentiate the symbolic codegen. This is a known issue with ModelingToolkit and unrelated to the differential equation solver:

Though of course, it doesn’t really make sense to do this anyway because, just as the issue describes at the top, you shouldn’t be doing symbolic codegen in the loss function in the first place. So if you did something standard like:

	# Re-solve on (0.0, new_t), reusing the already-built `problem`, and return
	# [x1, x2] evaluated at new_t.
	function different_time(new_ic, new_params, new_t)
		# `remake` takes its replacement fields as keyword arguments.
		newprob = remake(problem; u0 = new_ic, tspan = (0.0, new_t), p = new_params)
		new_soln = ModelingToolkit.solve(newprob, Tsit5(), abstol = 1e-14, reltol = 1e-14)
		# Interpolate the solution just computed, not the outer `soln`.
		return new_soln(new_t, idxs = [x1, x2])
	end

Then it should work fine, and it wouldn’t need to codegen new models every step. Note that the documentation specifically has a guide on the symbolic tooling which is helpful for further optimizing this code via setp and setu:

Building new models every step is useful for things like genetic algorithms which are trying to learn what the equations are via some evolution, but any change to the equations is generally a discrete change to the gradient (or adding new parameters), in which case it’s hard to think of a legitimate use of doing the codegen within the solving process itself. So it hasn’t gotten the highest priority to fix this, but since it is a workflow thing some newcomers may run into, we will get around to fixing it. I expect we will throw a warning though, i.e. if we notice you’re diffing this, we warn you by default (with an option to turn it off), mentioning that you likely want to reuse generated models via remake, setp, etc. (see that page), as a way to help users find the right solution.

Hi. I had actually been using remake before, and switched to ODEProblem to try and debug. Below is a different copy-pastable code. It uses remake() as you recommend, and I tried hooking it up to 7 different AD methods. They all fail, but they give slightly different errors.

All of the errors are around typing. My guess is that there is some problem with specifying the timespan with anything not a Float64 (i.e. the Dual numbers are an issue), but I’m surprised all of these AD systems make the same assumption. For instance, here is the error from ReverseDiff:

ERROR: LoadError: ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64, Float64, Nothing} to Float64 is not defined. Please use `ReverseDiff.value` instead.

New MWE:

using ModelingToolkit, DifferentialEquations
using TaylorDiff, ForwardDiff
using DifferentiationInterface, Enzyme, Zygote, ReverseDiff

"""
    ADTest()

Second reproduction: build a one-state linear ODE with ModelingToolkit, then try to
differentiate the solution with respect to the end point of the time span through
several AD backends, using `remake` to rebuild the problem on each call.

Returns the result of the last `value_and_gradient` call.
"""
function ADTest()
    @parameters a
    @variables t x1(t)
    D = Differential(t)
    states = [x1]
    parameters = [a]

    @named pre_model = ODESystem([D(x1) ~ a * x1], t, states, parameters)
    model = structural_simplify(pre_model)

    ic = Dict(x1 => 1.0)
    p_true = Dict(a => 2.0)

    problem = ODEProblem{true, SciMLBase.FullSpecialize}(model, ic, [0.0, 1.0], p_true)
    soln = ModelingToolkit.solve(problem, Tsit5(), abstol = 1e-12, reltol = 1e-12)
    display(soln(0.5, idxs = [x1]))

    # Re-solve on [0.0, new_t] and return [x1] evaluated at new_t.
    function different_time(new_ic, new_params, new_t)
        # Reuse the already-built `problem` instead of constructing a new one from `model`.
        newprob = remake(problem, u0 = new_ic, tspan = [0.0, new_t], p = new_params)
        new_soln = ModelingToolkit.solve(newprob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
        # Evaluate the solution just computed; the outer `soln` only covers [0.0, 1.0].
        return new_soln(new_t, idxs = [x1])
    end

    # Scalar-in, scalar-out wrapper so the AD backends see a function of new_t only.
    function just_t(new_t)
        return different_time(ic, p_true, new_t)[1]
    end
    display(different_time(ic, p_true, 2e-5))

    ForwardDiff.derivative(just_t, 4e-5)
    TaylorDiff.derivative(just_t, 4e-5, 1)
    value_and_gradient(just_t, AutoForwardDiff(), 1.0)
    value_and_gradient(just_t, AutoReverseDiff(), 1.0)
    value_and_gradient(just_t, AutoEnzyme(Enzyme.Reverse), 1.0)
    value_and_gradient(just_t, AutoEnzyme(Enzyme.Forward), 1.0)
    return value_and_gradient(just_t, AutoZygote(), 1.0)
end


Remake should promote the type of u0 to match, i.e. be equivalent to:

# First rebuild the problem with the new initial condition, time span and parameters.
remade = remake(problem; u0 = new_ic, tspan = [0.0, new_t], p = new_params)
# Then broadcast the type of `new_t` over u0 so each entry is constructed as that type.
T = typeof(new_t)
newprob = remake(remade; u0 = T.(remade.u0))

I’m traveling right now, but double check that works.

Can you open an issue in ModelingToolkit.jl? Please note in it that this might be a promotion case that was missed.

The use of DifferentiationInterface requires creating closures which can both prevent Enzyme from differentiating code, as well as hinder performance.

What happens if you just use Enzyme.autodiff on different_time?

I can confirm that your fix makes ForwardDiff work (and the value is sane).

What’s your opinion on whether I need to put FullSpecialize vs NoSpecialize vs nothing? Which is most likely to help AD packages succeed?

ReverseDiff, TaylorDiff, Enzyme, and Zygote all fail for different reasons. I’ll try and notify various places about… 5 different issues or so. (I haven’t tried Enzyme except through DifferentiationInterface.)

You shouldn’t need to do any of that.

ReverseDiff and TaylorDiff should be the same fix. This just needs a ModelingToolkit issue.

For bookkeeping, here’s the issue at MTK, feel free to add comments

I also filed an issue with TaylorDiff.jl

It is possible, as you say, that the solution for TaylorDiff lies in MTK, it’s hard for me to tell. But it looks like at a minimum it needs to handle isnan() on their type.

The error with ReverseDiff is quite complicated and I’m not 100% sure if it’s in ReverseDiff or the DifferentiationInterface layer. I’ll post it in the next comment.

Here’s the error from ReverseDiff, after your workaround to cast u0.

ERROR: LoadError: MethodError: no method matching length(::ModelingToolkit.MTKParameters{Tuple{Vector{Float64}}, Tuple{}, Tuple{}, Tuple{}, Tuple{}, Nothing, Nothing})

Closest candidates are:
   @ SymbolicUtils ~/.julia/packages/SymbolicUtils/JhFWV/src/utils.jl:225
   @ RecursiveArrayTools ~/.julia/packages/RecursiveArrayTools/plzpk/src/utils.jl:326
   @ MutableArithmetics ~/.julia/packages/MutableArithmetics/iovKe/src/rewrite.jl:104

  [1] automatic_sensealg_choice(prob::ODEProblem{…}, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, verbose::Bool)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/rXkM4/src/concrete_solve.jl:84
  [2] _concrete_solve_adjoint(::ODEProblem{…}, ::Tsit5{…}, ::Nothing, ::Vector{…}, ::ModelingToolkit.MTKParameters{…}, ::SciMLBase.ReverseDiffOriginator; verbose::Bool, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/rXkM4/src/concrete_solve.jl:218
  [3] _solve_adjoint(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, originator::SciMLBase.ReverseDiffOriginator, args::Tsit5{…}; merge_callbacks::Bool, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1537
  [4] (::DiffEqBaseReverseDiffExt.var"##solve_up#225#23"{…})(prob::ODEProblem{…}, sensealg::Nothing, u0::ReverseDiff.TrackedArray{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
    @ DiffEqBaseReverseDiffExt ~/.julia/packages/DiffEqBase/X5SZr/ext/DiffEqBaseReverseDiffExt.jl:159
  [5] track(::typeof(DiffEqBase.solve_up), prob::ODEProblem{…}, sensealg::Nothing, u0::ReverseDiff.TrackedArray{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
    @ DiffEqBaseReverseDiffExt ~/.julia/packages/ReverseDiff/p1MzG/src/macros.jl:195
  [6] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::ReverseDiff.TrackedArray{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
    @ DiffEqBaseReverseDiffExt ~/.julia/packages/DiffEqBase/X5SZr/ext/DiffEqBaseReverseDiffExt.jl:100
  [7] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ModelingToolkit.MTKParameters{…}, args::Tsit5{…}; kwargs::@Kwargs{…})
    @ DiffEqBaseReverseDiffExt ~/.julia/packages/DiffEqBase/X5SZr/ext/DiffEqBaseReverseDiffExt.jl:142
  [8] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/X5SZr/src/solve.jl:1003
  [9] (::var"#different_time#9"{ODESolution{…}, ODEProblem{…}, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{…}})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:35
 [10] (::var"#just_t#10"{var"#different_time#9"{…}, Dict{…}, Dict{…}})(new_t::ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{…}})
    @ Main ~/learning/ODETests/PLI/MWE3.jl:40
 [11] call_composed
    @ ./operators.jl:1044 [inlined]
 [12] (::ComposedFunction{var"#just_t#10"{var"#different_time#9"{…}, Dict{…}, Dict{…}}, typeof(only)})(x::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}; kw::@Kwargs{})
    @ Base ./operators.jl:1041
 [13] ComposedFunction
    @ ./operators.jl:1041 [inlined]
 [14] ReverseDiff.GradientTape(f::ComposedFunction{var"#just_t#10"{…}, typeof(only)}, input::Vector{Float64}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{…}})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/p1MzG/src/api/tape.jl:199
 [15] gradient(f::Function, input::Vector{Float64}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/p1MzG/src/api/gradients.jl:22
 [16] gradient
    @ ~/.julia/packages/ReverseDiff/p1MzG/src/api/gradients.jl:22 [inlined]
 [17] value_and_pullback(f::ComposedFunction{var"#just_t#10"{…}, typeof(only)}, ::AutoReverseDiff, x::Vector{Float64}, dy::Float64, ::DifferentiationInterface.NoPullbackExtras)
    @ DifferentiationInterfaceReverseDiffExt ~/.julia/packages/DifferentiationInterface/9POaB/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl:10
 [18] value_and_pullback(f::Function, backend::AutoReverseDiff, x::Vector{Float64}, dy::Float64)
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/9POaB/src/pullback.jl:96
 [19] value_and_pullback
    @ ~/.julia/packages/DifferentiationInterface/9POaB/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl:35 [inlined]
 [20] value_and_gradient
    @ ~/.julia/packages/DifferentiationInterface/9POaB/src/gradient.jl:57 [inlined]
 [21] value_and_gradient(f::Function, backend::AutoReverseDiff, x::Float64)
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/9POaB/src/gradient.jl:57
 [22] ADTest()
    @ Main ~/learning/ODETests/PLI/MWE3.jl:49
 [23] top-level scope
    @ ~/learning/ODETests/PLI/MWE3.jl:56
 [24] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [25] top-level scope
    @ REPL[1]:1
in expression starting at /home/orebas/learning/ODETests/PLI/MWE3.jl:56
Some type information was truncated. Use `show(err)` to see complete types.

Thanks, the two forward ones are easy.

The ReverseDiff one is just an MTK v9 thing that’s known. It’ll be solved by:

Which hopefully should be in a few weeks. Zygote and Enzyme reverse will also hit similar issues to this one.