ForwardDiff with ModelingToolkit: Type conversion error in MTKParameters with DAE systems

Hi everyone,

I’m encountering a MethodError when trying to follow this example from the ModelingToolkit documentation. The problem seems to be related to dual-number propagation in ODEProblems that report `Non-trivial mass matrix: true`. Problems in which the algebraic equations are eliminated by mtkcompile() work for me.

MWE which reproduces the error:

using ForwardDiff
using ModelingToolkit
using OrdinaryDiffEq
using PreallocationTools
using SymbolicIndexingInterface

using ModelingToolkit: t_nounits as t, D_nounits as D
using SciMLStructures: Tunable, canonicalize, replace
using SymbolicIndexingInterface: parameter_values

"""
    mixture_boiling(; name)

Build the `System` for the boiling-mixture model over the independent variable `t`.

The model has four time-dependent variables (`x`, `y`, `V`, `N`) and one parameter
(`Φ`). Two equations are differential (`D(N)` and `D(x)`); the remaining two
(`y ~ ...` and `0 ~ ...`) are algebraic.

# Keywords
- `name`: forwarded unchanged to `System`.
"""
function mixture_boiling(; name)
    @variables x(t) y(t) V(t) N(t)
    @parameters Φ
    model_eqs = [
        y ~ x * 1.5,
        D(N) ~ -V,
        D(x) ~ 1 / N * (V * (x - y)),
        0 ~ 1 / N * (Φ - 5000.0 * V)
    ]
    return System(model_eqs, t; name)
end

# Build the named system, then pass it through `mtkcompile`.
@named mix_sys = mixture_boiling()
mix_sys = mtkcompile(mix_sys)

# Numeric values assigned to N, x and the parameter Φ below.
N0 = 20.0
x0 = 0.2
Heat_flow = 10.0

# N, x and Φ get explicit values; V is not assigned here and only receives a guess.
inits_v = [mix_sys.N => N0, mix_sys.x => x0, mix_sys.Φ => Heat_flow]
dae_problem = ODEProblem(mix_sys, inits_v, (0.0, 100.0), guesses=[mix_sys.V => 0.01])

# Copy of the problem's tunable parameter values; it backs the DiffCache that
# `loss` reads through `get_tmp`.
ps = parameter_values(dae_problem)
tunable_buffer = copy(canonicalize(Tunable(), ps)[1])
diff_cache = DiffCache(tunable_buffer)
# Setter for the single parameter Φ; `loss` calls it with the control vector `u`.
setter! = setp(dae_problem, [mix_sys.Φ])
# `obj_params[1]` is the value the final N is compared against in `loss`.
obj_params = (18.5,)

# Everything `loss` needs, bundled in the order in which `loss` unpacks it.
p = (dae_problem, mix_sys, setter!, diff_cache, obj_params)
"""
    loss(u, p)

Return `(N_end - obj_params[1])^2`, where `N_end` is the last saved value of
`mix_sys.N` after solving the problem with `u` written into the parameter `Φ`.

# Arguments
- `u`: control values passed to `setter!` (a setter for `[mix_sys.Φ]`).
- `p`: the tuple `(dae_problem, mix_sys, setter!, diff_cache, obj_params)`.
"""
function loss(u, p)
    (dae_problem, mix_sys, setter!, diff_cache, obj_params) = p
    ps = parameter_values(dae_problem)
    # Buffer obtained from the DiffCache for this `u`.
    # NOTE(review): when `u` holds Duals this is presumably the ReinterpretArray-backed
    # buffer that appears as the tunable type in the reported MethodError - confirm.
    buffer = get_tmp(diff_cache, u)
    # Fill the buffer with the problem's current tunable values ...
    copyto!(buffer, canonicalize(Tunable(), ps)[1])
    # ... and build a parameter object whose tunable portion is that buffer.
    ps = replace(Tunable(), ps, buffer)
    # Write the controls into Φ.
    setter!(ps, u)
    newprob = remake(dae_problem; p=ps)
    # Under ForwardDiff, the reported MethodError is raised from inside this call
    # (in `_initialize_dae!`, per the stack trace).
    sol = solve(newprob, Rodas5P())
    N_end = sol[mix_sys.N, end]
    obj = (N_end - obj_params[1])^2
    return obj
end

"""
    loss_grad(u)

Evaluate `loss` at `u` using the global tuple `p`, giving a function of `u` alone
that can be handed to `ForwardDiff.gradient`.
"""
function loss_grad(u)
    return loss(u, p)
end

# Control vector; `loss` writes it into Φ.
controls_0 = [12.0]
# Evaluate the loss once with plain Float64 input.
loss_grad(controls_0)
# Differentiate the same function with ForwardDiff; this is the top-level call
# at the bottom of the reported stack trace.
grad_ = ForwardDiff.gradient(loss_grad, controls_0)

Error message and stacktrace:

ERROR: MethodError: Cannot `convert` an object of type
  MTKParameters{Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(loss_grad), Float64}, Float64, 1}},Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(loss_grad),Float64},Float64,1},1},Tuple,Tuple,Tuple,Tuple} to an object of type
  MTKParameters{Base.ReinterpretArray{ForwardDiff.Dual{ForwardDiff.Tag{typeof(loss_grad), Float64}, Float64, 1}, 1, Float64, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, false},Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(loss_grad),Float64},Float64,1},1},Tuple,Tuple,Tuple,Tuple}
The function `convert` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  MTKParameters{T, I, D, C, N, H}(::T, ::I, ::D, ::C, ::N, ::H) where {T, I, D, C, N, H}
   @ ModelingToolkit C:\Users\faksH\.julia\packages\ModelingToolkit\KTR1R\src\systems\parameter_buffer.jl:14
  convert(::Type{T}, ::T) where T
   @ Base Base_compiler.jl:133

Stacktrace:
  [1] setproperty!(x::OrdinaryDiffEqCore.ODEIntegrator{…}, f::Symbol, v::MTKParameters{…})
    @ Base .\Base_compiler.jl:57
  [2] _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, prob::ODEProblem{…}, alg::SciMLBase.OverrideInit{…}, isinplace::Val{…})
    @ OrdinaryDiffEqCore C:\Users\faksH\.julia\packages\OrdinaryDiffEqCore\GMkz9\src\initialize_dae.jl:142
  [3] _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, prob::ODEProblem{…}, alg::DiffEqBase.DefaultInit, x::Val{…})
    @ OrdinaryDiffEqCore C:\Users\faksH\.julia\packages\OrdinaryDiffEqCore\GMkz9\src\initialize_dae.jl:28
  [4] initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, initializealg::DiffEqBase.DefaultInit)
    @ OrdinaryDiffEqCore C:\Users\faksH\.julia\packages\OrdinaryDiffEqCore\GMkz9\src\initialize_dae.jl:18
  [5] __init(prob::ODEProblem{…}, alg::Rodas5P{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_discretes::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Rational{…}, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias::ODEAliasSpecifier, initializealg::DiffEqBase.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEqCore C:\Users\faksH\.julia\packages\OrdinaryDiffEqCore\GMkz9\src\solve.jl:556
  [6] __init (repeats 2 times)
    @ C:\Users\faksH\.julia\packages\OrdinaryDiffEqCore\GMkz9\src\solve.jl:11 [inlined]
  [7] __solve(::ODEProblem{…}, ::Rodas5P{…}; kwargs::@Kwargs{})
    @ OrdinaryDiffEqCore C:\Users\faksH\.julia\packages\OrdinaryDiffEqCore\GMkz9\src\solve.jl:6
  [8] __solve
    @ C:\Users\faksH\.julia\packages\OrdinaryDiffEqCore\GMkz9\src\solve.jl:1 [inlined]
  [9] #solve_call#24
    @ C:\Users\faksH\.julia\packages\DiffEqBase\6sydR\src\solve.jl:142 [inlined]
 [10] solve_call
    @ C:\Users\faksH\.julia\packages\DiffEqBase\6sydR\src\solve.jl:109 [inlined]
 [11] #solve_up#31
    @ C:\Users\faksH\.julia\packages\DiffEqBase\6sydR\src\solve.jl:578 [inlined]
 [12] solve_up
    @ C:\Users\faksH\.julia\packages\DiffEqBase\6sydR\src\solve.jl:555 [inlined]
 [13] #solve#30
    @ C:\Users\faksH\.julia\packages\DiffEqBase\6sydR\src\solve.jl:545 [inlined]
 [14] solve(prob::ODEProblem{…}, args::Rodas5P{…})
    @ DiffEqBase C:\Users\faksH\.julia\packages\DiffEqBase\6sydR\src\solve.jl:535
 [15] loss(u::Vector{…}, p::Tuple{…})
    @ Main c:\Users\faksH\.julia\dev\Forwarddiff_bug\dae_bug.jl:53
 [16] loss_grad(u::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(loss_grad), Float64}, Float64, 1}})
    @ Main c:\Users\faksH\.julia\dev\Forwarddiff_bug\dae_bug.jl:60
 [17] vector_mode_dual_eval!
    @ C:\Users\faksH\.julia\packages\ForwardDiff\kQBw9\src\apiutils.jl:24 [inlined]
 [18] vector_mode_gradient(f::typeof(loss_grad), x::Vector{…}, cfg::ForwardDiff.GradientConfig{…})
    @ ForwardDiff C:\Users\faksH\.julia\packages\ForwardDiff\kQBw9\src\gradient.jl:98
 [19] gradient
    @ C:\Users\faksH\.julia\packages\ForwardDiff\kQBw9\src\gradient.jl:20 [inlined]
 [20] gradient(f::typeof(loss_grad), x::Vector{…}, cfg::ForwardDiff.GradientConfig{…})
    @ ForwardDiff C:\Users\faksH\.julia\packages\ForwardDiff\kQBw9\src\gradient.jl:17
 [21] gradient(f::typeof(loss_grad), x::Vector{Float64})
    @ ForwardDiff C:\Users\faksH\.julia\packages\ForwardDiff\kQBw9\src\gradient.jl:17
 [22] top-level scope
    @ c:\Users\faksH\.julia\dev\Forwarddiff_bug\dae_bug.jl:65

Any guidance would be greatly appreciated! Thanks in advance.