Updating integrator parameter/initial condition with ForwardDiff.Dual types

Suppose I have the following function, which takes the solution to an ODE and returns a sum of squared errors. It uses the integrator interface, which I need for my actual application, though it is not necessary for this MWE:

using DifferentialEquations, ForwardDiff
f(u, p, t) = p * u
u0 = 1 / 2
tspan = (0.0, 1.0)
prob = ODEProblem(f, u0, tspan, 1.0)
integrator = DifferentialEquations.init(prob, Tsit5(), saveat = 0:0.01:1.0)
sol = solve!(integrator)
y = exp.(0:0.01:1.0)
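function g(λ, integrator, y)
    integrator.p = λ          # overwrite the parameter in place
    reinit!(integrator)       # reset to the initial condition and clear the saved solution
    sol = solve!(integrator)  # re-solve with the new parameter
    return sum((sol .- y).^2) # sum of squared errors against the data y
end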
ℓ = g(1.0, integrator, y) # example use

I want to compute $\partial g/\partial \lambda$ with ForwardDiff, for example via

ForwardDiff.derivative(λ -> g(λ, integrator, y), 1.0)

This doesn’t work because integrator.p starts off as a Float64, whereas integrator.p = λ tries to store a ForwardDiff.Dual{T} in it, with T the tag that ForwardDiff associates with g. Is there a way to make this differentiation possible? I’m also interested in differentiating with respect to the initial condition (again through the integrator interface), but using, say, reinit!(integrator, λ) leads to a similar error. The full stacktrace for ForwardDiff.derivative(λ -> g(λ, integrator, y), 1.0) above is:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1})
Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at C:\Users\user\AppData\Local\Programs\Julia-1.7.0\share\julia\base\rounding.jl:200
  (::Type{T})(::T) where T<:Number at C:\Users\user\AppData\Local\Programs\Julia-1.7.0\share\julia\base\boot.jl:770
  (::Type{T})(::VectorizationBase.Double{T}) where T<:Union{Float16, Float32, Float64, VectorizationBase.Vec{<:Any, <:Union{Float16, Float32, Float64}}, VectorizationBase.VecUnroll{var"#s31", var"#s30", var"#s29", V} where {var"#s31", var"#s30", var"#s29"<:Union{Float16, Float32, Float64}, V<:Union{Bool, Float16, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, SIMDTypes.Bit, VectorizationBase.AbstractSIMD{var"#s30", var"#s29"}}}} at C:\Users\user\.julia\packages\VectorizationBase\9edvL\src\special\double.jl:100
  ...
Stacktrace:
 [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1})
   @ Base .\number.jl:7
 [2] setproperty!(x::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Float64, Nothing, Float64, Float64, Float64, Float64, Float64, Float64, Vector{Float64}, ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, 
Float64, ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, OrdinaryDiffEq.DEOptions{Float64, Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Tuple{}}, Float64, Float64, Nothing, OrdinaryDiffEq.DefaultInit}, f::Symbol, v::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1})
   @ Base .\Base.jl:43
 [3] g(λ::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1}, integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Float64, Nothing, Float64, Float64, Float64, Float64, Float64, Float64, Vector{Float64}, ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, Float64, ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, OrdinaryDiffEq.DEOptions{Float64, Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Tuple{}}, Float64, Float64, Nothing, OrdinaryDiffEq.DefaultInit}, y::Vector{Float64})
   @ Main c:\Users\user\linear_exponential_ode.jl:75
 [4] (::var"#9#10")(λ::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1})
   @ Main c:\Users\user\linear_exponential_ode.jl:81
 [5] derivative(f::var"#9#10", x::Float64)
   @ ForwardDiff C:\Users\user\.julia\packages\ForwardDiff\wAaVJ\src\derivative.jl:14
 [6] top-level scope
   @ c:\Users\user\linear_exponential_ode.jl:81
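
Frames [1] and [2] of the trace show that the failure happens in setproperty!, which tries to convert the Dual to the field's declared Float64 type. The sketch below reproduces that mechanism without any ODE solver; `Holder` is a hypothetical stand-in for the integrator and is not part of any library.

```julia
using ForwardDiff

# Hypothetical stand-in for the integrator: one field with a concrete Float64 type.
mutable struct Holder
    p::Float64
end

h = Holder(1.0)
d = ForwardDiff.Dual(1.0, 1.0)   # value 1.0 carrying one partial derivative

# Assigning to a typed field calls convert(Float64, d), which has no method for a Dual.
err = try
    h.p = d
    nothing
catch e
    e
end

println(err isa MethodError)            # whether the assignment failed as in frame [1]
println((sizeof(Float64), sizeof(d)))   # storage size of a Float64 vs. this Dual
```

The size comparison is the point: a Dual with one partial stores two Float64 values, so it cannot be written into a slot that was allocated for one.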

reinit! is designed to be non-allocating: it refills the values already stored in the integrator so you can go again. You cannot fill an array of Float64s with an array of Dual numbers; that’s just not possible, because there isn’t even enough memory space there (a Dual number has a larger bit size because it holds several floating point numbers: the value plus its partials). So it’s not that the library doesn’t support it; this is something that in principle isn’t doable. You’d have to create a dual valued integrator which you then reinit! in order for the differentiation to be non-allocating, and to do that the easiest way would be to directly define the dual tag and the chunk size so that you know the type in advance, init on that type, and then reinit with that same dual type.


That makes sense. I suspected that to be the case.

You’d have to create a dual valued integrator which you then reinit! in order for the differentiation to be non-allocating, and to do that the easiest way would be to directly define the dual tag and the chunk size so that you know the type in advance, init on that type, and then reinit with that same dual type.

Do you have an example of how this might be set up? I know that Dual numbers require the tags you mention, though I’ve never defined them manually (and in my actual use this function may be nested as a closure inside an optimization problem from Optimization.jl, which I think complicates the tags further). Could PreallocationTools.jl be used here somehow to make it easier to switch between finite difference and automatic differentiation methods?
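
One way the dual valued integrator described above might be set up is sketched below. This is a sketch under stated assumptions, not a confirmed recipe from the library authors. The names `MyTag` and `DualT` are hypothetical, the chunk size is fixed at 1, and the sketch assumes that ODEProblem and init accept a scalar Dual initial condition and parameter. The Dual is seeded by hand instead of going through ForwardDiff.derivative.

```julia
using DifferentialEquations, ForwardDiff

f(u, p, t) = p * u
tspan = (0.0, 1.0)
ts = 0:0.01:1.0
y = exp.(ts)

# Hypothetical tag type, defined up front so the Dual type is known in advance.
struct MyTag end
const DualT = ForwardDiff.Dual{ForwardDiff.Tag{MyTag, Float64}, Float64, 1}  # chunk size 1

# Build the problem on Dual values (zero partials) so every field of the
# integrator is allocated with the Dual element type.
prob = ODEProblem(f, DualT(0.5), tspan, DualT(1.0))
integrator = DifferentialEquations.init(prob, Tsit5(), saveat = ts)

function g(λ::DualT, integrator, y)
    integrator.p = λ           # same type as the field, so no conversion is needed
    reinit!(integrator)        # reset to the problem's (Dual) initial condition
    sol = solve!(integrator)
    return sum((sol.u .- y) .^ 2)
end

# Seed the derivative direction by hand: value 1.0, partial 1.0.
λ = ForwardDiff.Dual{ForwardDiff.Tag{MyTag, Float64}}(1.0, 1.0)
out = g(λ, integrator, y)

ℓ  = ForwardDiff.value(out)        # g(1.0)
dℓ = ForwardDiff.partials(out, 1)  # ∂g/∂λ at λ = 1.0
```

Because the tag is fixed ahead of time, g is called here with a hand-seeded Dual rather than through ForwardDiff.derivative, which generates its own tag from the closure it is given. Whether this carries over to a closure nested inside Optimization.jl is not addressed by this sketch.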