I’d like to use non-numeric parameter types in ModelingToolkit.jl together with adjoint sensitivities. DifferentialEquations.jl supports this through SciMLStructures.jl, as described here, so I assumed it would work with ModelingToolkit.jl as well, but so far I have been unsuccessful. Defining the parameter as a non-numeric type matters to me because my type has a large number of fields, and for code clarity I would rather not expose each of them as a separate parameter or structural parameter.
Here’s an MWE:
using ModelingToolkit, OrdinaryDiffEq, SciMLSensitivity, Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D
using SciMLStructures
using SymbolicIndexingInterface: parameter_values
"""
    MyType{T<:Real}

Container for the two scalar values `a` and `b`.

The element type is a parameter (rather than a fixed `Float64`) so that AD number types
such as duals can be stored in the fields.
"""
struct MyType{T <: Real}
    a::T
    b::T
end

# Accept arguments of different Real types (e.g. `MyType(2, 3.0)`) by promoting them to a
# common type first. Same-typed arguments still hit the default constructor, which is
# more specific than this method, so existing calls behave exactly as before.
MyType(a::Real, b::Real) = MyType(promote(a, b)...)
# Field accessors for MyType. They are plain functions so that the model equations can
# call them on the struct-valued parameter instead of using `custom.a` / `custom.b`.
function geta(θ::MyType)
    return θ.a
end

function getb(θ::MyType)
    return θ.b
end

# Register both accessors for symbolic use with a MyType argument.
@register_symbolic geta(mytype::MyType)
@register_symbolic getb(mytype::MyType)
# Tell SciMLStructures how to "flatten" and "repack" MyType.
# MyType is an immutable struct, so it is declared as a non-mutable SciMLStructure whose
# Tunable portion consists of both fields, `a` and `b`.
# NOTE(review): these methods describe MyType on its own. Whether ModelingToolkit's
# parameter object forwards the Tunable portion of a struct-valued parameter into its own
# tunable buffer is a separate question — confirm against the MTK documentation.
SciMLStructures.isscimlstructure(::MyType) = true
SciMLStructures.ismutablescimlstructure(::MyType) = false
SciMLStructures.hasportion(::SciMLStructures.Tunable, ::MyType) = true

"""
    SciMLStructures.canonicalize(::SciMLStructures.Tunable, θ::MyType)

Return `(vals, repack, aliases)` where `vals == [θ.a, θ.b]`, `repack(x)` builds a new
`MyType` from a 2-element buffer `x`, and `aliases` is `false` because `vals` is a fresh
vector that shares no memory with `θ`.
"""
function SciMLStructures.canonicalize(::SciMLStructures.Tunable, θ::MyType{T}) where {T}
    vals = T[θ.a, θ.b] # flatten to a Vector
    repack = x -> SciMLStructures.replace(SciMLStructures.Tunable(), θ, x)
    return vals, repack, false # false => no aliasing
end

"""
    SciMLStructures.replace(::SciMLStructures.Tunable, θ::MyType, x)

Return a new `MyType` whose fields are taken from the 2-element buffer `x`.

# Throws
- `DimensionMismatch` if `length(x) != 2`.
"""
function SciMLStructures.replace(::SciMLStructures.Tunable, ::MyType, x)
    # Explicit check instead of `@assert`, which may be disabled and is not meant for
    # validating caller-supplied input.
    length(x) == 2 ||
        throw(DimensionMismatch("expected 2 tunable values for MyType, got $(length(x))"))
    return MyType(x[1], x[2])
end
# Toy model with a single state x(t) obeying
#     D(x) ~ c * geta(custom) * x + getb(custom)
# where `custom` is one MyType-valued parameter standing in for two scalar parameters.
@mtkmodel Toy begin
@parameters begin
# Struct-valued parameter carrying a and b; marked tunable so that (as intended by this
# example) its values take part in the tunable vector.
custom::MyType = MyType(2.0, 3.0), [tunable = true]
# Scalar alternative to `custom` (kept for comparison with the variant that works):
# a = 2.0
# b = 3.0
c = -1.0
end
@variables begin
x(t) = 1.0
end
@equations begin
# Use the accessors instead of `custom.a` / `custom.b` directly
D(x) ~ c * geta(custom) * x + getb(custom) # this doesn't work
# D(x) ~ c * a * x + b # this works
end
end
# Compile the model and build the ODE problem on the time span (0.0, 1.0); the empty
# list passes no overrides, so the values declared in the model are used.
@mtkcompile sys = Toy()
tspan = (0.0, 1.0)
prob = ODEProblem(sys, [], tspan)
# Target for x at the final time; the loss below penalizes the squared deviation from it.
x_target = 0.5
# Canonicalize the tunables of the MTK parameter object to get an initial vector [a,b]
ps0 = parameter_values(prob)
x0, repack_ps, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), ps0) # I want this to be x0 == [2.0, 3.0], but x0 is empty when using MyType
# NOTE(review): the BoundsError in the trace below is raised on a 1-element
# Vector{Float64}, so x0 presumably holds one value (likely `c`) rather than being
# empty — confirm by printing x0.
"""
    loss(ab_vec)

Squared deviation of `x` at the end of `tspan` from `x_target`, as a function of the flat
tunable vector `ab_vec`.
"""
function loss(ab_vec)
    # Rebuild the entire MTKParameters object (custom::MyType included) from the flat
    # vector and swap it into a copy of the problem.
    candidate = remake(prob; p = repack_ps(ab_vec))
    # Only the final time point is saved, since it is the only one the loss reads.
    trajectory = solve(
        candidate, Tsit5(); saveat = [last(tspan)], sensealg = GaussAdjoint()
    )
    x_final = trajectory[sys.x][end]
    return (x_final - x_target)^2
end
# Gradient of the loss with respect to the flat tunable vector (intended to be [a, b]).
# `loss` takes a single argument, so the first entry of the result is the one we want.
∇loss = first(Zygote.gradient(loss, x0))
println("∂loss/∂a = ", ∇loss[1], ", ∂loss/∂b = ", ∇loss[2])
When using `custom::MyType` in the `Toy` system, `x0` is empty (or at least does not contain `a` and `b` — the error below reports a 1-element vector), which (I believe) then leads to this error:
┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/7zKgz/src/concrete_solve.jl:68
ERROR: LoadError: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
Stacktrace:
[1] throw_boundserror(A::Vector{Float64}, I::Tuple{Int64})
@ Base ./essentials.jl:14
[2] getindex(A::Vector{Float64}, i::Int64)
@ Base ./essentials.jl:916
[3] top-level scope
@ ~/test_adjoint_mtk.jl:64
[4] include(fname::String)
@ Main ./sysimg.jl:38
[5] top-level scope
@ REPL[15]:1
in expression starting at /test_adjoint_mtk.jl:64
It does work when I set `a = 2.0` and `b = 3.0` as parameters directly (without `MyType`).