I am trying to use AD to differentiate the solution of an ODE. Various docs seem to indicate this is very well supported, so I was surprised to get an error.
For some additional color, I got a different error before adding
“ODEProblem{true, SciMLBase.FullSpecialize}”. I don’t quite understand what this does, but it was recommended in another discourse topic. I also get a slightly different error when switching between Tsit5() and Vern9().
Copy-pastable MWE:
using ModelingToolkit, DifferentialEquations
using ForwardDiff
function ADTest()
@parameters a b
@variables t x1(t) x2(t) y1(t) y2(t)
D = Differential(t)
states = [x1, x2]
parameters = [a, b]
@named model = ODESystem([
D(x1) ~ a * x1,
D(x2) ~ b * x2,
], t, states, parameters)
model = structural_simplify(model)
measured_quantities = [
y1 ~ x1,
y2 ~ x2]
ic = Dict(x1 => 1.0, x2 => 2.0)
p_true = Dict(a => 2.0, b => 3.0)
problem = ODEProblem(model, ic, [0.0, 1e-5], p_true)
soln = ModelingToolkit.solve(problem, Tsit5(), abstol = 1e-14, reltol = 1e-14)
display(soln(1e-5, idxs = [x1, x2]))
function different_time(new_ic, new_params, new_t)
newprob = ODEProblem{true, SciMLBase.FullSpecialize}(model, new_ic, [0.0, new_t], new_params)
new_soln = ModelingToolkit.solve(newprob, Tsit5(), abstol = 1e-14, reltol = 1e-14)
return (soln(new_t, idxs = [x1, x2]))
end
display(different_time(ic,p_true, 2e-5))
temp = ForwardDiff.derivative(s -> different_time(ic,p_true, s),4e-5)
display(temp)
end
ADTest()
Error:
ERROR: LoadError: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{var"#209#211"{var"#different_time#210"{ODESolution{…}, Num, Num}, Dict{Num, Float64}, Dict{Num, Float64}}, Float64}, Float64, 1})
Closest candidates are:
(::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
@ Base rounding.jl:207
(::Type{T})(::T) where T<:Number
@ Core boot.jl:792
Float64(::IrrationalConstants.Fourinvπ)
@ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:112
...
Stacktrace:
[1] convert(::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#209#211"{var"#different_time#210"{ODESolution{…}, Num, Num}, Dict{Num, Float64}, Dict{Num, Float64}}, Float64}, Float64, 1})
@ Base ./number.jl:7
[2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#209#211"{var"#different_time#210"{ODESolution{…}, Num, Num}, Dict{Num, Float64}, Dict{Num, Float64}}, Float64}, Float64, 1}, i1::Int64)
@ Base ./array.jl:1021
[3] macro expansion
@ ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/initdt.jl:119 [inlined]
[4] macro expansion
@ ./simdloop.jl:77 [inlined]
[5] ode_determine_initdt(u0::Vector{…}, t::ForwardDiff.Dual{…}, tdir::ForwardDiff.Dual{…}, dtmax::ForwardDiff.Dual{…}, abstol::Float64, reltol::Float64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), prob::ODEProblem{…}, integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/initdt.jl:118
[6] auto_dt_reset!
@ ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/integrators/integrator_interface.jl:453 [inlined]
[7] handle_dt!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/solve.jl:571
[8] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::ForwardDiff.Dual{…}, dtmin::ForwardDiff.Dual{…}, dtmax::ForwardDiff.Dual{…}, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/solve.jl:533
[9] __init (repeats 5 times)
@ ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/solve.jl:11 [inlined]
[10] #__solve#787
@ ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/solve.jl:6 [inlined]
[11] __solve
@ ~/.julia/packages/OrdinaryDiffEq/GAgjL/src/solve.jl:1 [inlined]
[12] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/WyGjp/src/solve.jl:612
[13] solve_call
@ ~/.julia/packages/DiffEqBase/WyGjp/src/solve.jl:569 [inlined]
[14] #solve_up#53
@ ~/.julia/packages/DiffEqBase/WyGjp/src/solve.jl:1080 [inlined]
[15] solve_up
@ ~/.julia/packages/DiffEqBase/WyGjp/src/solve.jl:1066 [inlined]
[16] #solve#51
@ ~/.julia/packages/DiffEqBase/WyGjp/src/solve.jl:1003 [inlined]
[17] (::var"#different_time#210"{ODESolution{…}, Num, Num})(new_ic::Dict{Num, Float64}, new_params::Dict{Num, Float64}, new_t::ForwardDiff.Dual{ForwardDiff.Tag{…}, Float64, 1})
@ Main ~/learning/ODETests/PLI/MWE2.jl:32
[18] #209
@ ~/learning/ODETests/PLI/MWE2.jl:37 [inlined]
[19] derivative(f::var"#209#211"{var"#different_time#210"{ODESolution{…}, Num, Num}, Dict{Num, Float64}, Dict{Num, Float64}}, x::Float64)
@ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/derivative.jl:14