Suppose I have the following function which takes the solution to an ODE and returns a sum-of-squared errors, making use of the integrator interface (which I need to use for my actual application, not necessary for this MWE):
using DifferentialEquations, ForwardDiff
f(u, p, t) = p * u
u0 = 1 / 2
tspan = (0.0, 1.0)
prob = ODEProblem(f, u0, tspan, 1.0)
integrator = DifferentialEquations.init(prob, Tsit5(), saveat = 0:0.01:1.0)
sol = solve!(integrator)
y = exp.(0:0.01:1.0)
function g(λ, integrator, y)
integrator.p = λ
reinit!(integrator)
sol = solve!(integrator)
return sum((sol .- y).^2)
end
ℓ = g(1.0, integrator, y) # example use
I want to compute \partial g/\partial \lambda using ForwardDiff, using say
ForwardDiff.derivative(λ -> g(λ, integrator, y), 1.0)
This doesn’t work because integrator.p
starts off as a Float64
, whereas integrator.p = λ
tries to change it to ForwardDiff.Dual{T}
with T
representing the associated tags for g
. Is there a way to make this differentiation possible? I’m also interested in (again using this integrator interface) differentiating with respect to the initial condition, but using say reinit!(integrator, λ)
leads to a similar error. The full stacktrace for ForwardDiff.derivative(λ -> g(λ, integrator, y), 1.0)
above is:
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1})
Closest candidates are:
(::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at C:\Users\user\AppData\Local\Programs\Julia-1.7.0\share\julia\base\rounding.jl:200
(::Type{T})(::T) where T<:Number at C:\Users\user\AppData\Local\Programs\Julia-1.7.0\share\julia\base\boot.jl:770
(::Type{T})(::VectorizationBase.Double{T}) where T<:Union{Float16, Float32, Float64, VectorizationBase.Vec{<:Any, <:Union{Float16, Float32, Float64}}, VectorizationBase.VecUnroll{var"#s31", var"#s30", var"#s29", V} where {var"#s31", var"#s30", var"#s29"<:Union{Float16, Float32, Float64}, V<:Union{Bool, Float16, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, SIMDTypes.Bit, VectorizationBase.AbstractSIMD{var"#s30", var"#s29"}}}} at C:\Users\user\.julia\packages\VectorizationBase\9edvL\src\special\double.jl:100
...
Stacktrace:
[1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1})
@ Base .\number.jl:7
[2] setproperty!(x::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Float64, Nothing, Float64, Float64, Float64, Float64, Float64, Float64, Vector{Float64}, ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false,
Float64, ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, OrdinaryDiffEq.DEOptions{Float64, Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Tuple{}}, Float64, Float64, Nothing, OrdinaryDiffEq.DefaultInit}, f::Symbol, v::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1})
@ Base .\Base.jl:43
[3] g(λ::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1}, integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, Float64, Nothing, Float64, Float64, Float64, Float64, Float64, Float64, Vector{Float64}, ODESolution{Float64, 1, Vector{Float64}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Float64}}, ODEProblem{Float64, Tuple{Float64, Float64}, false, Float64, ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing,
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Float64}, Vector{Float64}, Vector{Vector{Float64}}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}, OrdinaryDiffEq.DEOptions{Float64, Float64, Float64, Float64, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float64, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Tuple{}}, Float64, Float64, Nothing, OrdinaryDiffEq.DefaultInit}, y::Vector{Float64})
@ Main c:\Users\user\linear_exponential_ode.jl:75
[4] (::var"#9#10")(λ::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1})
@ Main c:\Users\user\linear_exponential_ode.jl:81
[5] derivative(f::var"#9#10", x::Float64)
@ ForwardDiff C:\Users\user\.julia\packages\ForwardDiff\wAaVJ\src\derivative.jl:14
[6] top-level scope
@ c:\Users\user\linear_exponential_ode.jl:81