How to differentiate through the ODE solver (Discretize-then-Optimize)?

Does anyone know how to differentiate through the solver? I want to implement the technique from Opening the Blackbox: Accelerating Neural Differential Equations by Regularizing Internal Solver Heuristics, but I'm running into an issue with the AD part. Can someone help me out? I've pasted an MWE and the corresponding error message below.

MWE

using DifferentialEquations, SciMLSensitivity, Tracker, Zygote, ReverseDiff

function randode(u, p, t)
    return [u[1]*p[1]*sin(t)]
end

function simulate_randode(u0, tspan)
    my_integrator = init(ODEProblem(randode, u0, (tspan[1], tspan[end])), Tsit5(); sensealg=TrackerAdjoint)
    errs = 0.0
    sol = zeros(size(u0)[1])
    for (u, t) in TimeChoiceIterator(my_integrator, tspan)
        errs = errs + my_integrator.EEst
        sol = vcat(sol, u)
    end
    return sol, errs
end

x0 = [0.1]
tspan = collect(range(0, 24, 50))
mysol, myerrs = simulate_randode(x0, tspan)
Tracker.gradient([0.1]) do u0
    _, _errs = simulate_randode(u0, tspan)
    return sum(abs2, _errs)
end

Error

ERROR: MethodError: no method matching Float64(::Tracker.TrackedReal{Float64})
Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
  (::Type{T})(::T) where T<:Number at boot.jl:772
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at char.jl:50
  ...

I just recently improved the error messages, and the newer error message is much clearer about what's going on here:

ERROR: An indexing operation was performed on a NullParameters object. This means no parameters were passed
into the AbstractSciMLProblem (e.x.: ODEProblem) but the parameters object `p` was used in an indexing
expression (e.x. `p[i]`, or `x .+ p`). Two common reasons for this issue are:

1. Forgetting to pass parameters into the problem constructor. For example, `ODEProblem(f,u0,tspan)` should
be `ODEProblem(f,u0,tspan,p)` in order to use parameters.

2. Using the wrong function signature. For example, with `ODEProblem`s the function signature is always
`f(du,u,p,t)` for the in-place form or `f(u,p,t)` for the out-of-place form. Note that the `p` argument
will always be in the function signature reguardless of if the problem is defined with parameters!


Stacktrace:
  [1] getindex(#unused#::SciMLBase.NullParameters, i::Int64)
    @ SciMLBase c:\Users\accou\.julia\dev\SciMLBase\src\problems\problem_utils.jl:166
  [2] randode(u::Vector{Float64}, p::SciMLBase.NullParameters, t::Float64)
    @ Main c:\Users\accou\OneDrive\Computer\Desktop\test.jl:54
  [3] (::ODEFunction{false,SciMLBase.AutoSpecialize,…})(::Vector{Float64}, ::Vararg{Any})
    @ SciMLBase C:\Users\accou\.julia\dev\SciMLBase\src\scimlfunctions.jl:2404
  [4] initialize!(integrator::ODEIntegrator{false, Tsit5{Static.False,…}, Vector{Float64}, Float64,…}, cache::OrdinaryDiffEq.Tsit5ConstantCache)
    @ OrdinaryDiffEq C:\Users\accou\.julia\dev\OrdinaryDiffEq\src\perform_step\low_order_rk_perform_step.jl:700
  [5] __init(prob::ODEProblem{false,Vector{Float64},Tuple{Float64, Float64},…}, alg::Tsit5{Static.False,…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{Val{true}}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Nothing, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Nothing, reltol::Nothing, qmin::Rational{Int64}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{Int64}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OrdinaryDiffEq C:\Users\accou\.julia\dev\OrdinaryDiffEq\src\solve.jl:499
  [6] __init (repeats 5 times)
    @ C:\Users\accou\.julia\dev\OrdinaryDiffEq\src\solve.jl:10 [inlined]
  [7] #init_call#18
    @ C:\Users\accou\.julia\packages\DiffEqBase\jSoyI\src\solve.jl:412 [inlined]
  [8] init_call
    @ C:\Users\accou\.julia\packages\DiffEqBase\jSoyI\src\solve.jl:392 [inlined]
  [9] #init_up#21
    @ C:\Users\accou\.julia\packages\DiffEqBase\jSoyI\src\solve.jl:459 [inlined]
 [10] init_up
    @ C:\Users\accou\.julia\packages\DiffEqBase\jSoyI\src\solve.jl:432 [inlined]
 [11] #init#19
    @ C:\Users\accou\.julia\packages\DiffEqBase\jSoyI\src\solve.jl:425 [inlined]
 [12] init(prob::ODEProblem{false,Vector{Float64},Tuple{Float64, Float64},…}, args::Tsit5{Static.False,…})
    @ DiffEqBase C:\Users\accou\.julia\packages\DiffEqBase\jSoyI\src\solve.jl:416
 [13] simulate_randode(u0::Vector{Float64}, tspan::Vector{Float64})
    @ Main c:\Users\accou\OneDrive\Computer\Desktop\test.jl:58
 [14] top-level scope
    @ c:\Users\accou\OneDrive\Computer\Desktop\test.jl:70

So I assume you made a mistake in the randode that you posted: it indexes p, but the ODEProblem was constructed without parameters. The other issues are quirks of having to deal with Tracker.jl.
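If you did mean to use a parameter there, the fix is the first one the error message suggests: pass p into the problem constructor. A minimal sketch with a placeholder value (your post never says what p should be):

p = [1.0]  # placeholder value, just to make p[1] defined
prob = ODEProblem(randode, x0, (tspan[1], tspan[end]), p)  # now p[1] works inside randode

Since the parameter isn't actually needed for this demonstration, I'll just drop it below. The working code is: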

using DifferentialEquations, Tracker

function randode(u, p, t)
    return u .* sin(t)  # array operation: a TrackedArray stays a TrackedArray
end

function simulate_randode(u0, tspan)
    # eltype(u0).(...) promotes the time endpoints to tracked values when u0 is tracked
    my_integrator = init(ODEProblem(randode, u0, eltype(u0).((tspan[1], tspan[end]))), Tsit5(), dt = 0.1)
    errs = zero(eltype(u0))  # tracked accumulator for the error estimates
    sol = zeros(size(u0)[1])
    for (u, t) in TimeChoiceIterator(my_integrator, tspan)
        errs = errs + my_integrator.EEst  # accumulate the solver's internal error estimate
        sol = vcat(sol, u)
    end
    return sol, errs
end

x0 = [0.1]
tspan = collect(range(0, 24, 50))
mysol, myerrs = simulate_randode(x0, tspan)
Tracker.gradient([0.1]) do u0
    _, _errs = simulate_randode(u0, tspan)
    return sum(abs2, _errs)
end

which gives:

julia> Tracker.gradient([0.1]) do u0
           _, _errs = simulate_randode(u0, tspan)
           return sum(abs2, _errs)
       end
([1514.1961857601827] (tracked),)
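If you want a sanity check on that gradient, you can compare it against a central finite difference of the same loss. Note the loss is only piecewise smooth in u0 (the adaptive dt sequence can change discretely), so keep h small:

loss(u0) = sum(abs2, simulate_randode(u0, tspan)[2])
h = 1e-6
fd = (loss([0.1 + h]) - loss([0.1 - h])) / (2h)  # should land near the tracked gradient above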

The things involved here are all little Tracker details, but to note:

  1. eltype(u0).((tspan[1], tspan[end])) changes the integration to also differentiate time by making the tspan endpoints tracked values (see the sketch after this list). You need time to be differentiated since in this method you want to differentiate with respect to the dt choices, and thus EEst; otherwise you would only be differentiating with respect to the state.
  2. errs = zero(eltype(u0)) makes sure the accumulation value is a tracked value.
  3. u .* sin(t) keeps the state a TrackedArray instead of turning it into an array of tracked scalars.
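
To make (1) and (2) concrete, here is what those expressions produce inside a gradient call (a small sketch; Tracker.param builds the same kind of TrackedArray that Tracker.gradient hands to the closure):

using Tracker
u0 = Tracker.param([0.1])   # what u0 looks like inside Tracker.gradient
eltype(u0)                  # Tracker.TrackedReal{Float64}
eltype(u0).((0.0, 24.0))    # tracked time endpoints, so t and dt arithmetic is recorded
zero(eltype(u0))            # tracked zero to accumulate EEst into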

(3) is a deeper point with tracked array systems. Essentially, it boils down to:

julia> u = TrackedArray([0.1])
Tracked 1-element Vector{Float64}:
 0.1

julia> [u[1] * sin(0.0)]
1-element Vector{Tracker.TrackedReal{Float64}}:
 0.0

julia> u .* sin(0.0)
Tracked 1-element Vector{Float64}:
 0.0

If you put print statements inside the simulate_randode call (or look at the stack trace), you'll see that this is precisely what was going on in your case. When Tracker starts taking gradients, u0 is a TrackedArray, i.e. an array that is being differentiated. When you do u[1], you get a tracked scalar, so [u[1]] is an array of tracked scalars (Array{TrackedReal}), which is a different type from the TrackedArray ("an array being differentiated"). Not only is this less efficient (it differentiates per scalar op instead of by whole-array operations), it's also a type change, and the ODE solver errors because it expects randode(u, p, t) to return the same type it was given. That's the source of the weird (incomprehensible) MethodError: no method matching Float64(::Tracker.TrackedReal{Float64}).
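You can see the type change directly (continuing from the u defined above):

typeof(u)                  # Tracker.TrackedArray{...}: one array node on the tape
typeof([u[1] * sin(0.0)])  # Vector{Tracker.TrackedReal{Float64}}: many scalar nodes
typeof(u .* sin(0.0))      # Tracker.TrackedArray{...}: same type as the input again

The solver assumes randode(u, p, t) returns the same array type it was given, which is why the scalar-indexing form errors while the broadcasted form is fine.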

So there are three things you can do. One, you can always differentiate in scalar mode by converting the state to an array of tracked scalars before doing the ODE solve:

init(ODEProblem(randode, eltype(u0).(u0), eltype(u0).((tspan[1], tspan[end]))), Tsit5(), dt = 0.1)

Looks weird, but it works. Or two, you can just use array operations inside the ODE function, like u .* sin(t). This is the reason why neural networks just work: they only use these kinds of array operations, so they completely avoid this issue. Or three, you can convert the Array{TrackedReal} into a TrackedArray manually, which is done via Tracker.collect(x). So the following would also work:

using DifferentialEquations, Tracker

function randode(u, p, t)
    out = [u[1] * sin(t)]  # scalar indexing: under Tracker this builds an Array{TrackedReal}
    if eltype(out) <: Tracker.TrackedReal
        return Tracker.collect(out)  # rebuild it as a TrackedArray so the solver's types match
    else
        return out
    end
end

function simulate_randode(u0, tspan)
    my_integrator = init(ODEProblem(randode, u0, eltype(u0).((tspan[1], tspan[end]))), Tsit5(), dt = 0.1)
    errs = zero(eltype(u0))
    sol = zeros(size(u0)[1])
    for (u, t) in TimeChoiceIterator(my_integrator, tspan)
        errs = errs + my_integrator.EEst
        sol = vcat(sol, u)
    end
    return sol, errs
end

x0 = [0.1]
tspan = collect(range(0, 24, 50))
mysol, myerrs = simulate_randode(x0, tspan)
Tracker.gradient([0.1]) do u0
    _, _errs = simulate_randode(u0, tspan)
    return sum(abs2, _errs)
end
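
And if you want to convince yourself that gradients really do flow through Tracker.collect, a tiny isolated check:

Tracker.gradient([0.1]) do x
    y = Tracker.collect([x[1] * 2.0])  # Array{TrackedReal} -> TrackedArray
    return sum(y)
end  # should give ([2.0] (tracked),)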

(And these are the reasons why not many people use Tracker anymore, but it’s still pretty cool for hacking odd things together :sweat_smile: )
