Fixed-Frequency Control in DiffEq

Here you go: this is a modified version of the original MWE with Lux and ComponentArrays added. Without them, it works.

#%% imports

using DifferentialEquations
using Statistics
using Plots
using Random

using Zygote
using SciMLSensitivity
using NNlib
using Lux
using ComponentArrays: ComponentArray, getaxes

#%% system

# Seeded RNG so that the network initialisation is reproducible.
rng = Xoshiro(2)

# Controller network: input (θ, ω) -> 50 ELU units -> one ReLU unit, scaled by 2.
# NOTE(review): the final `relu` means the controller output is always ≥ 0,
# so only one direction of control input is possible. Confirm this is intended.
net = Lux.Chain(
    Lux.Dense(2 => 50, elu),
    Lux.Dense(50 => 1, relu),
    x -> x * 2,
)

# Initial parameters and layer state. The parameters are then wrapped in a
# ComponentArray so they can serve as the ODE parameter object `p`.
(pinit, state) = Lux.setup(rng, net)
params = ComponentArray(pinit)

# Lux's Chain looks its parameters up by name (`ps.layer_1`, ...). When the
# sensitivity machinery hands the callback a plain vector instead of the
# ComponentArray, re-attach the axes of the template `params` so those lookups
# work again. Anything that is not a vector (e.g. a NamedTuple) passes through.
# NOTE(review): whether gradients then flow correctly through the tracked
# callback still needs to be confirmed against `adjoint_sensitivities`.
_with_param_axes(p::AbstractVector) =
    p isa ComponentArray ? p : ComponentArray(p, getaxes(params))
_with_param_axes(p) = p

"""
    net_controller(x, p; st=state)

Evaluate the Lux network `net` on input `x` with parameters `p` and layer state
`st`. Return only the network output; the updated layer state from Lux is dropped.

If `p` is a vector without ComponentArray axes, the axes of the global `params`
are attached first, because the network accesses its parameters by field name.
"""
@inline net_controller(x, p; st=state) = net(x, _with_param_axes(p), st)[1]

# Physical parameters of the pendulum. `pendulum!` reads these as globals on
# every right-hand-side evaluation. Declaring them `const` lets the compiler
# infer their types instead of treating them as untyped (`Any`) globals.
const l = 1.0                       # length [m]
const m = 1.0                       # mass [kg]
const g = 9.81                      # gravitational acceleration [m/s²]

 """
    pendulum!(du, u, p, t)

In-place right-hand side of the controlled pendulum with state `u = [θ, ω, M]`:

    θ̇ = ω,   ω̇ = -3g/(2l)·sin θ + 3/(m l²)·M,   Ṁ = 0.

Because `Ṁ = 0`, the control input written into `u[3]` by a callback stays
constant between updates. `p` and `t` are unused; the physical constants are
read from the globals `g`, `l` and `m`.
"""
function pendulum!(du, u, p, t)
    θ, ω, M = u[1], u[2], u[3]
    du[1] = ω
    du[2] = -3g / (2l) * sin(θ) + 3 / (m * l^2) * M
    du[3] = 0
end

# Initial conditions and simulation horizon.
θ₀ = 0.01                           # initial angular deflection [rad]
ω₀ = 0.0                            # initial angular velocity [rad/s]
u₀ = [θ₀, ω₀, 0.0]                  # initial state [θ, ω, M]; control input starts at 0
tspan = (0.0, 10.0)                 # time interval [s]

# Period of the fixed-frequency controller updates (used by PeriodicCallback) [s].
Ts = 0.5
"""
    controller(integrator)

Affect function for the `PeriodicCallback`. It feeds the current angle and
angular velocity `u[1:2]` to the network and stores the first network output
in `u[3]` as the control input M. Return `nothing`; the effect is the in-place
update of `integrator.u`.
"""
function controller(integrator)
    # Only θ and ω are network inputs; M itself is not fed back.
    M = net_controller(integrator.u[1:2], integrator.p)[1]
    integrator.u[3] = M
    return nothing
end

#%% test system

# Forward simulation with the network controller applied every `Ts` seconds,
# saving the trajectory on a 0.1 s grid for plotting.
prob = ODEProblem(pendulum!, u₀, tspan, params)
sol = solve(prob, callback=PeriodicCallback(controller, Ts), saveat=0.1)
# u[3] enters `pendulum!` through the term 3 / (m * l^2) * u[3], so it is a
# torque [N·m], not an angular acceleration.
plot(sol;
    xaxis = "t",
    label = ["θ [rad]" "ω [rad/s]" "M [N·m]"],
    layout = (3, 1),
)

#%% setup callbacks

# Adjoint method used for the sensitivity computation below, with
# vector-Jacobian products computed via ReverseDiff.
sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())
# Commented-out alternative adjoint method:
# sensealg = BacksolveAdjoint(autojacvec=ReverseDiffVJP())

# Plain periodic controller callback (same controller and period as the test cell).
cb = PeriodicCallback(controller, Ts)
# `track_callbacks` is called with a module qualifier, presumably because it is
# not exported. Given the initial time, state, parameters and `sensealg`, it
# produces the callback set that is later passed to both `solve` and
# `adjoint_sensitivities`.
# NOTE(review): the reported "type Array has no field layer_1" error suggests the
# tracked callback ends up calling the controller with `p` as a plain Array
# rather than the ComponentArray `params`. Confirm.
cb_tracked = SciMLSensitivity.track_callbacks(CallbackSet(cb), prob.tspan[1], prob.u0, prob.p, sensealg);

#%% solve

# Forward solve using the tracked callbacks, which the adjoint below also receives.
sol = solve(prob; callback = cb_tracked);
# Adjoint sensitivities of the continuous cost ∫ sum(u) dt with respect to the
# parameters. The backward pass is solved with Rosenbrock23 (autodiff disabled).
_, dp = adjoint_sensitivities(
    sol, Rosenbrock23(autodiff = false);
    callback = cb_tracked,
    sensealg = sensealg,
    g = (u, p, t) -> sum(u),
);
dp

#%% comparison

"""
    loss(prob, p; sensealg=nothing)

Discrete approximation of ∫ sum(u(t)) dt over `prob.tspan` for parameters `p`.
It re-solves the problem with the periodic callback `cb`, sampling every 0.1,
and multiplies the mean over time of the per-sample state sums by the length of
the time interval.

`sensealg` is forwarded to `solve`. The default `nothing` keeps the automatic
selection; passing an explicit adjoint (e.g.
`InterpolatingAdjoint(autojacvec=ReverseDiffVJP())`) makes Zygote use that one.
"""
function loss(prob, p; sensealg=nothing)
    remade = remake(prob, p=p)
    sol = solve(remade, callback=cb, saveat=0.1, sensealg=sensealg)
    t0, t1 = remade.tspan
    # Rows are state components, columns are save points.
    states = Array(sol)
    return mean(sum(states, dims=1)) * (t1 - t0)
end
# Value and Zygote gradient of `loss` with respect to `params`, for comparison
# with `dp` from `adjoint_sensitivities`.
Zygote.withgradient(params) do p
    loss(prob, p)
end

The `adjoint_sensitivities` call fails with `ERROR: type Array has no field layer_1`. I suspect the callback tracking wraps or converts the parameter array in a way that is not compatible with ComponentArrays: Lux looks up `layer_1` by name, and a plain `Array` has no such field.

I also get an error from the `Zygote.withgradient` call: `ERROR: MethodError: no method matching setvjp(::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}, ::ReverseDiffVJP{false})`.
That call is only meant as a cross-check that should work without my intervention (i.e. with the automatically chosen sensitivity algorithm). If I pass `sensealg` explicitly, I get the same error as from `adjoint_sensitivities`.