Adjoint sensitivities for non-numeric types in ModelingToolkit.jl

I’d like to use non-numeric parameter types in ModelingToolkit.jl together with adjoint sensitivities. DifferentialEquations.jl supports this through SciMLStructures.jl, as described here, so I assumed it would work with ModelingToolkit.jl as well, but so far I have been unsuccessful. Defining the parameter as a non-numeric type matters to me because my type has many fields that I would rather not expose individually as parameters or structural parameters, for code clarity.
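
For context, the plain DifferentialEquations.jl pattern I mean looks roughly like this (a sketch from memory of the docs example; the LV struct and all names here are made up for illustration, and I’m not certain every interface method below is strictly required):

using OrdinaryDiffEq, SciMLSensitivity, SciMLStructures, Zygote

struct LV{T}
    a::T
    b::T
end
SciMLStructures.isscimlstructure(::LV) = true
SciMLStructures.ismutablescimlstructure(::LV) = false
SciMLStructures.hasportion(::SciMLStructures.Tunable, ::LV) = true
function SciMLStructures.canonicalize(::SciMLStructures.Tunable, p::LV{T}) where {T}
    # flatten to a vector, plus a closure to rebuild the struct
    return T[p.a, p.b], x -> LV(x[1], x[2]), false
end
SciMLStructures.replace(::SciMLStructures.Tunable, p::LV, x) = LV(x[1], x[2])

f(u, p, t) = p.a .* u .+ p.b
prob_plain = ODEProblem(f, [1.0], (0.0, 1.0), LV(-2.0, 3.0))

function loss_plain(x)
    _, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), prob_plain.p)
    sol = solve(remake(prob_plain; p = repack(x)), Tsit5();
                saveat = [1.0], sensealg = GaussAdjoint())
    return sum(abs2, sol.u[end])
end
Zygote.gradient(loss_plain, [-2.0, 3.0])   # gradient w.r.t. [a, b]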

Here’s a MWE:

using ModelingToolkit, OrdinaryDiffEq, SciMLSensitivity, Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D
using SciMLStructures
using SymbolicIndexingInterface: parameter_values

# parametric so AD duals work
struct MyType{T <: Real}
    a::T
    b::T
end

# accessor functions, registered for symbolic use
geta(θ::MyType) = θ.a
getb(θ::MyType) = θ.b
@register_symbolic geta(mytype::MyType)
@register_symbolic getb(mytype::MyType)

# tell SciMLStructures how to "flatten" and "repack" MyType
SciMLStructures.isscimlstructure(::MyType) = true
function SciMLStructures.canonicalize(::SciMLStructures.Tunable, θ::MyType{T}) where {T}
    vals = T[θ.a, θ.b]                    # flatten to a Vector
    repack = x -> (@assert length(x) == 2; MyType(x[1], x[2]))
    return vals, repack, false            # false => no aliasing
end
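
# The SciMLStructures interface also lists hasportion, ismutablescimlstructure,
# and replace/replace!; I'm not sure which of these MTK actually needs, but
# defining the immutable subset for completeness (same result either way):
SciMLStructures.ismutablescimlstructure(::MyType) = false
SciMLStructures.hasportion(::SciMLStructures.Tunable, ::MyType) = true
SciMLStructures.replace(::SciMLStructures.Tunable, ::MyType, x) = MyType(x[1], x[2])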

@mtkmodel Toy begin
    @parameters begin
        custom::MyType = MyType(2.0, 3.0), [tunable = true]
        # a = 2.0
        # b = 3.0
        c = -1.0
    end
    @variables begin
        x(t) = 1.0
    end
    @equations begin
        # Use the accessors instead of `custom.a` / `custom.b` directly
        D(x) ~ c * geta(custom) * x + getb(custom) # this doesn't work
        # D(x) ~ c * a * x + b # this works
    end
end

@mtkcompile sys = Toy()
tspan = (0.0, 1.0)
prob = ODEProblem(sys, [], tspan)

x_target = 0.5

# Canonicalize the tunables of the MTK parameter object to get an initial vector [a,b]
ps0 = parameter_values(prob)
x0, repack_ps, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), ps0)  # I want x0 to contain [2.0, 3.0], but the MyType values are missing when using MyType
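# Diagnostic sketch: list the parameters MTK treats as tunable
# (tunable_parameters is from ModelingToolkit); if `custom` is absent here,
# that would explain why its values never reach x0.
@show ModelingToolkit.tunable_parameters(sys)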

function loss(ab_vec)
    # repack the entire MTKParameters object, *including* custom::MyType, from the vector
    new_ps = repack_ps(ab_vec)
    newprob = remake(prob; p = new_ps)
    sol = solve(newprob, Tsit5(); saveat = [last(tspan)], sensealg = GaussAdjoint())
    xT = sol[sys.x][end]
    return (xT - x_target)^2
end

# Gradient w.r.t. [a,b]
∇loss = Zygote.gradient(loss, x0)[1]
println("∂loss/∂a = ", ∇loss[1], ",  ∂loss/∂b = ", ∇loss[2])

When using custom::MyType in the Toy system, the MyType values never make it into x0 (it only picks up c), which (I believe) then leads to this error:

┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/7zKgz/src/concrete_solve.jl:68

ERROR: LoadError: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
Stacktrace:
 [1] throw_boundserror(A::Vector{Float64}, I::Tuple{Int64})
   @ Base ./essentials.jl:14
 [2] getindex(A::Vector{Float64}, i::Int64)
   @ Base ./essentials.jl:916
 [3] top-level scope
   @ ~/test_adjoint_mtk.jl:64
 [4] include(fname::String)
   @ Main ./sysimg.jl:38
 [5] top-level scope
   @ REPL[15]:1
in expression starting at /test_adjoint_mtk.jl:64

It does work when I set a = 2.0 and b = 3.0 as plain numeric parameters directly (without MyType).
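
For reference, that working variant in full (matching the commented-out lines in the MWE; only the names ToyDirect / sys2 / prob2 are new):

@mtkmodel ToyDirect begin
    @parameters begin
        a = 2.0
        b = 3.0
        c = -1.0
    end
    @variables begin
        x(t) = 1.0
    end
    @equations begin
        D(x) ~ c * a * x + b
    end
end

@mtkcompile sys2 = ToyDirect()
prob2 = ODEProblem(sys2, [], tspan)
x0_2, repack2, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), parameter_values(prob2))
# x0_2 now contains a, b, and c (ordering may vary), and the Zygote gradient
# above goes through without the BoundsError.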