Hi!
Many differential equation models include an external input function (e.g. a control system doing reference tracking). When I build models of this form in ModelingToolkit.jl, they don’t seem to be differentiable. What am I doing wrong?
MWE:
A linear time-invariant (LTI) system forced by a sine-wave input. The entries of the `A` matrix are the parameters of the ODE.
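Written out, this is the system the code below is meant to encode. Because `reshape` fills `A` column by column, the default parameters `p = [1, -2, 3, -0.5]` give:

$$
\dot{x}(t) = A\,x(t) + B\,u(t), \qquad u(t) = \sin t, \qquad x(0) = \begin{bmatrix} 1 \\ 1 \end{bmatrix},
$$

$$
A = \begin{bmatrix} p_1 & p_3 \\ p_2 & p_4 \end{bmatrix} = \begin{bmatrix} 1 & 3 \\ -2 & -0.5 \end{bmatrix}, \qquad B = \begin{bmatrix} 2 \\ 3 \end{bmatrix},
$$

and the loss is the squared tracking error of the first state at the save points $t_k = 0.5k$:

$$
L(p) = \sum_{k=0}^{20} \bigl(x_1(t_k) - \sin t_k\bigr)^2 .
$$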
```julia
using ModelingToolkit, Zygote, OrdinaryDiffEq, DiffEqSensitivity, LinearAlgebra

# LTI plant dx/dt = A*x + B*u, where u(t) is the external input
function simple_sys()
    @variables t x[1:2](t) u(t)
    @parameters p[1:4]
    D = Differential(t)
    A = reshape(p, 2, 2)         # the parameters are the entries of A
    B = reshape([2., 3], 2, 1)
    eqs = D.(x) .~ A*x .+ B*u
    defaults = Dict(vcat(p .=> [1., -2., 3., -0.5], x .=> ones(2), u => 0.))
    return ODESystem(eqs; name=:lp, defaults)
end

od = simple_sys()

# Wrap `od` in a parent system whose variable f is driven by input(t) and tied to od.u
function connect_to_input(od, input; name)
    @variables t f(t)
    eqs = [
        f ~ input(t),
        f ~ od.u
    ]
    ODESystem(eqs; systems = [od], name = name, defaults = Dict(f .=> input(0.)))
end

@named con = connect_to_input(od, sin)
final = structural_simplify(con)
prob = ODEProblem(final, [], (0., 10.), [])
tsteps = 0:0.5:10
nom_sol = solve(prob, Tsit5(), saveat = tsteps, sensealg = InterpolatingAdjoint())

# Squared error between x₁(t) and sin(t) at the save points
function loss(p)
    nprob = remake(prob; p=p)
    sol = solve(nprob, Tsit5(), saveat = tsteps, sensealg = InterpolatingAdjoint())
    return sum(abs2, sol[Symbol("lp₊x₁(t)")] .- sin.(tsteps))
end

Zygote.gradient(loss, prob.p)
```
The stacktrace is:
```
ERROR: Can't differentiate foreigncall expression
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] Pullback
    @ ./boot.jl:480 [inlined]
  [3] (::typeof(∂(Symbol)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/lwmfx/src/compiler/interface2.jl:0
  [4] Pullback
    @ ~/.julia/dev/FeedbackLearning/scripts/undifferentiable.jl:35 [inlined]
  [5] (::typeof(∂(loss)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/lwmfx/src/compiler/interface2.jl:0
  [6] (::Zygote.var"#41#42"{typeof(∂(loss))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/lwmfx/src/compiler/interface.jl:41
  [7] gradient(f::Function, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/lwmfx/src/compiler/interface.jl:59
  [8] top-level scope
    @ REPL[97]:1
```
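Looking at the trace again, the innermost Zygote frame is the pullback of `Symbol` (frame [3]), called from `loss`, so the failure may come from building `Symbol("lp₊x₁(t)")` inside the differentiated function rather than from the forcing function itself. Below is an unverified sketch of a variant to test. It assumes `final`, `prob` and `tsteps` from the MWE above, and it assumes `states(final)` is available in this ModelingToolkit version and that `Symbol` of a state matches the name used in the original `sol[...]` lookup. It resolves the state index once, outside `loss`, and then indexes the numeric solution array instead:

```julia
using ModelingToolkit, Zygote, OrdinaryDiffEq, DiffEqSensitivity

# Assumes `final`, `prob` and `tsteps` from the MWE are already defined.
# Look up the row of lp₊x₁(t) once, outside the code Zygote differentiates,
# so the Symbol(...) construction never appears inside `loss`.
x1_idx = findfirst(isequal(Symbol("lp₊x₁(t)")), Symbol.(states(final)))

function loss_hoisted(p)
    nprob = remake(prob; p = p)
    sol = solve(nprob, Tsit5(), saveat = tsteps, sensealg = InterpolatingAdjoint())
    # Array(sol) is a (number of states) × (number of save points) matrix
    return sum(abs2, Array(sol)[x1_idx, :] .- sin.(tsteps))
end

grad = Zygote.gradient(loss_hoisted, prob.p)
```

This only moves the name lookup out of the differentiated code; whether the rest of the pipeline (`remake` with the forcing input, the adjoint solve) then differentiates cleanly is exactly what this would check.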
How should I set up a forcing function so that autodiff works?
Thanks a lot!