ModelingToolkit.jl ODESystems with a forcing function: how to autodiff?

Hi!

Lots of differential equation models incorporate an external input function (e.g. a control system doing reference tracking). When I build models of this form in ModelingToolkit.jl, `Zygote.gradient` fails on them. What am I doing wrong?

MWE:

A linear time-invariant system forced by a sine-wave input. The A matrix holds the parameters of the ODE.
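Written out, the model in the code below is the following forced linear system. This is a sketch assuming Julia's column-major `reshape`, so that `p` fills `A` column by column; the numbers are the defaults set in `simple_sys`, and $u(t) = \sin t$ comes from `connect_to_input(od, sin)`:

$$
\dot{x}(t) = A\,x(t) + B\,u(t), \qquad
A = \begin{pmatrix} p_1 & p_3 \\ p_2 & p_4 \end{pmatrix}, \qquad
B = \begin{pmatrix} 2 \\ 3 \end{pmatrix}, \qquad
u(t) = \sin t
$$

$$
p = (1,\,-2,\,3,\,-0.5) \;\Rightarrow\; A = \begin{pmatrix} 1 & 3 \\ -2 & -0.5 \end{pmatrix}, \qquad
x(0) = \begin{pmatrix} 1 \\ 1 \end{pmatrix}
$$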

```julia
using ModelingToolkit, Zygote, OrdinaryDiffEq, DiffEqSensitivity, LinearAlgebra

function simple_sys() #u is input
    @variables t x[1:2](t) u(t)
    @parameters p[1:4]
    D = Differential(t)
    A = reshape(p,2,2)
    B = reshape([2.,3], 2, 1)

    eqs = D.(x) .~ A*x .+ B*u
    defaults = Dict(vcat(p .=> [1., -2., 3., -0.5], x .=> ones(2), u => 0.))
    return ODESystem(eqs; name=:lp, defaults)
end
od = simple_sys()

function connect_to_input(od, input; name)
    @variables t f(t)
    eqs = [
        f ~ input(t),
        f ~ od.u
    ]
    ODESystem(eqs; systems = [od], name = name, defaults = Dict(f .=> input(0.)))
end

@named con = connect_to_input(od, sin)
final = structural_simplify(con)
prob = ODEProblem(final, [], (0., 10.),[])
tsteps = 0:0.5:10
nom_sol = solve(prob, Tsit5(), saveat = tsteps, sensealg = InterpolatingAdjoint())

function loss(p)
    nprob = remake(prob; p=p)
    sol = solve(nprob, Tsit5(), saveat = tsteps, sensealg = InterpolatingAdjoint())
    return sum(abs2, sol[Symbol("lp₊x₁(t)")] .- sin.(tsteps))
end

Zygote.gradient(loss, prob.p)
```
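For reference, the scalar being differentiated is the following, assuming `prob.p` holds exactly the four entries of `p` (the only parameters in the system) and writing $t_k = 0.5k$ for the 21 points of `tsteps = 0:0.5:10`:

$$
\mathcal{L}(p) = \sum_{k=0}^{20} \bigl( x_1(t_k;\,p) - \sin t_k \bigr)^2,
\qquad
\nabla_p \mathcal{L} = \Bigl( \frac{\partial \mathcal{L}}{\partial p_1}, \dots, \frac{\partial \mathcal{L}}{\partial p_4} \Bigr)
$$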

The stacktrace is:

```
ERROR: Can't differentiate foreigncall expression
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] Pullback
   @ ./boot.jl:480 [inlined]
 [3] (::typeof(∂(Symbol)))(Δ::Nothing)
   @ Zygote ~/.julia/packages/Zygote/lwmfx/src/compiler/interface2.jl:0
 [4] Pullback
   @ ~/.julia/dev/FeedbackLearning/scripts/undifferentiable.jl:35 [inlined]
 [5] (::typeof(∂(loss)))(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/lwmfx/src/compiler/interface2.jl:0
 [6] (::Zygote.var"#41#42"{typeof(∂(loss))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/lwmfx/src/compiler/interface.jl:41
 [7] gradient(f::Function, args::Vector{Float32})
   @ Zygote ~/.julia/packages/Zygote/lwmfx/src/compiler/interface.jl:59
 [8] top-level scope
   @ REPL[97]:1
```
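Frame [3] of the trace is the pullback of `Symbol`, so the failure appears to come from the `Symbol("lp₊x₁(t)")` call inside `loss` rather than from the forcing term itself. One possible workaround, sketched here under assumptions, is to resolve which state you want *outside* the function Zygote differentiates and then index the saved solution by an integer. To keep the sketch self-contained it bypasses ModelingToolkit and writes the same right-hand side by hand; `lti`, `input`, `P0`, `TSTEPS` and `IDX` are hypothetical names, and the state order `[x₁, x₂]` is an assumption. In the MTK script, the analogous change would presumably be to compute the symbol (or an integer index into the states) once, before `loss`. The package names follow the question; newer releases rename DiffEqSensitivity to SciMLSensitivity.

```julia
using OrdinaryDiffEq, DiffEqSensitivity, Zygote  # newer releases: SciMLSensitivity

# Hypothetical hand-written version of the MWE's model:
# dx/dt = A x + B u(t), with A = reshape(p, 2, 2), B = [2, 3], u(t) = sin(t).
const B = [2.0, 3.0]
input(t) = sin(t)                                      # the forcing function
lti(x, p, t) = reshape(p, 2, 2) * x .+ B .* input(t)   # out-of-place right-hand side

const P0 = [1.0, -2.0, 3.0, -0.5]                      # same defaults as simple_sys
const TSTEPS = 0:0.5:10
prob = ODEProblem(lti, ones(2), (0.0, 10.0), P0)

# Resolve the state index once, outside the differentiated function.
# Assumption: x₁ is the first entry of the state vector.
const IDX = 1

function loss(p)
    sol = solve(remake(prob; p = p), Tsit5(); saveat = TSTEPS,
                sensealg = InterpolatingAdjoint())
    x1 = Array(sol)[IDX, :]          # plain integer indexing into the saved states
    return sum(abs2, x1 .- sin.(TSTEPS))
end

grad = Zygote.gradient(loss, P0)[1]
```

If this version differentiates in your environment while the MTK one does not, that points at the symbolic lookup inside `loss` rather than at the forcing function.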

How should I set up a forcing function so that autodiff works?

Thanks a lot!