Here you go: this is a modified version of the original MWE with Lux and ComponentArrays added. It works without them.
#%% imports
using DifferentialEquations
using Statistics
using Plots
using Random
using Zygote
using SciMLSensitivity
using NNlib
using Lux
using ComponentArrays: ComponentArray, ComponentVector, getaxes
#%% system
rng = Xoshiro(2)

# Controller network: 2 inputs (θ, ω) -> 1 output. The trailing anonymous function
# doubles the output of the last (relu, hence non-negative) Dense layer.
net = Lux.Chain(
    Lux.Dense(2, 50, elu),
    Lux.Dense(50, 1, relu),
    x -> x * 2,
)
pinit, state = Lux.setup(rng, net)
params = ComponentArray(pinit)

# Layer structure (`layer_1`, `layer_2`, ...) of `params`, kept so that a flat parameter
# vector can be given the same structure again.
const param_axes = getaxes(params)

# Return `p` as a ComponentVector laid out like `params`.
# Lux's layers read their weights by field (`p.layer_1`, ...), which a plain vector does
# not have; that is exactly the "type Array has no field layer_1" failure. It shows up
# when the parameters reach the controller flattened (presumably done by SciMLSensitivity's
# callback tracking during the adjoint pass -- TODO confirm), so re-attach the axes then.
as_structured(p::ComponentVector) = p
as_structured(p::AbstractVector) = ComponentArray(p, param_axes)

# Evaluate the network on input `x` with parameters `p` (structured or flat) and Lux
# state `st`; only the network output is returned, the updated state is discarded.
@inline net_controller(x, p; st=state) = net(x, as_structured(p), st)[1]
# Physical parameters of the pendulum. Declared `const` because `pendulum!` reads them as
# globals on every right-hand-side evaluation; non-const globals would make it type-unstable.
const l = 1.0  # length [m]
const m = 1.0  # mass [kg]
const g = 9.81 # gravitational acceleration [m/s²]

# In-place ODE right-hand side for the state u = [θ, ω, M]:
#   θ' = ω
#   ω' = -3g/(2l)·sin(θ) + 3/(m·l²)·M
#   M' = 0   (M only changes when the controller callback overwrites u[3])
# `p` and `t` are not used here; the network parameters are consumed by the callback.
function pendulum!(du, u, p, t)
    du[1] = u[2]
    du[2] = -3g / (2l) * sin(u[1]) + 3 / (m * l^2) * u[3]
    du[3] = 0
    return nothing
end
# Initial conditions and simulation horizon.
θ₀ = 0.01               # initial angular deflection [rad]
ω₀ = 0.0                # initial angular velocity [rad/s]
u₀ = Float64[θ₀, ω₀, 0] # initial state [θ, ω, M]; the control input M starts at zero
tspan = (0.0, 10.0)     # time interval
Ts = 0.5                # period at which the controller callback fires
# Affect! of the periodic callback. It feeds the current angle and angular velocity
# (u[1:2]) to the network and stores its scalar output as the control input M in u[3].
# `pendulum!` reads M there and keeps it constant until the next call.
function controller(integrator)
    θω = integrator.u[1:2]
    integrator.u[3] = net_controller(θω, integrator.p)[1]
    return nothing
end
#%% test system
# Forward simulation of the closed loop with the untrained controller, saved every 0.1.
prob = ODEProblem(pendulum!, u₀, tspan, params)
sol = solve(prob; callback = PeriodicCallback(controller, Ts), saveat = 0.1)
state_labels = ["θ [rad]" "ω [rad/s]" "M [rad/s^2]"]
plot(sol; xaxis = "t", label = state_labels, layout = (3, 1))
#%% setup callbacks
# Adjoint method for the sensitivity computation; vector-Jacobian products via ReverseDiff.
sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())
# Alternative: sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP())
cb = PeriodicCallback(controller, Ts)
# Tracked version of the callback set, passed to both the forward solve and
# `adjoint_sensitivities` below (module-qualified call; presumably internal API).
cb_tracked = SciMLSensitivity.track_callbacks(
    CallbackSet(cb), prob.tspan[1], prob.u0, prob.p, sensealg,
);
#%% solve
# Running cost g(u, p, t) = sum(u) whose sensitivity is requested.
running_cost(u, p, t) = sum(u)
# Forward solve with the tracked callbacks, then the adjoint pass; only the
# parameter sensitivity (second return value) is kept.
sol = solve(prob; callback = cb_tracked);
_, dp = adjoint_sensitivities(
    sol, Rosenbrock23(autodiff = false);
    callback = cb_tracked, sensealg = sensealg, g = running_cost,
);
dp
#%% comparison
# Objective for the Zygote comparison. It is the mean over the saved points of sum(u),
# multiplied by the length of the problem's time span. This roughly approximates
# ∫ sum(u(t)) dt, the running cost used with `adjoint_sensitivities` above.
#
# `sensitivity` is forwarded to `solve` explicitly. With the automatic choice, the solve
# picked InterpolatingAdjoint with EnzymeVJP and then failed with
# "no method matching setvjp(...)", as reported below the script.
# The keyword is deliberately not named `sensealg`: a keyword default cannot refer
# to a global of the same name.
function loss(prob, p; sensitivity = sensealg)
    newprob = remake(prob; p = p)
    sol = solve(newprob; callback = cb, saveat = 0.1, sensealg = sensitivity)
    # Use the problem's own time span rather than the global `tspan`.
    t0, t1 = newprob.tspan
    return mean(sum(Array(sol); dims = 1)) * (t1 - t0)
end
# Loss value and its gradient with respect to the network parameters, via Zygote.
Zygote.withgradient(params) do p
    loss(prob, p)
end
I get this error from the `adjoint_sensitivities` function: `ERROR: type Array has no field layer_1`. I suppose the callback tracking wraps the array in a way that is not compatible with ComponentArrays.
I also get an error from the `Zygote.withgradient` call: `ERROR: MethodError: no method matching setvjp(::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}, ::ReverseDiffVJP{false})`.
This is supposed to work without my intervention, since the sensitivity algorithm is chosen automatically. But if I pass `sensealg` explicitly, I get the same error as in `adjoint_sensitivities`.