Fixed Frequency Control in DiffEq

I am trying to use DiffEq and SciMLSensitivity to build a fixed-frequency neural-network (NN) controller for physical systems.

I’ve tried PeriodicCallback and it works (see the original issue), but not with neural networks (Lux), because only ReverseDiffVJP and EnzymeVJP support callbacks (the library errors otherwise). However, ReverseDiff seems to be incompatible with Lux / ComponentArrays (as others have also noted). EnzymeVJP works with NNs, but it emits many BLAS warnings and then fails, printing low-level code that is unreadable to me. I suppose it’s too experimental for this combination of features.

As a workaround I tried delay differential equations, whose access to past states would also let me feed windowed inputs to the NN. However, they seem to allow only ForwardDiff, which is too slow for the NN use case. I can provide an MWE for that too. It works but doesn’t scale well with the number of parameters: at 50 parameters, Zygote.gradient says it’s not compatible, and forcing ForwardDiff works but is too slow to be a viable option.

Any other ideas on how a fixed-frequency controller could be implemented and optimized, possibly with windowed inputs? Any workaround is welcome.

Thanks for your patience.
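One way to picture the fixed-frequency requirement is as a zero-order hold: the controller is evaluated only every Ts seconds and its output is held constant in between. Below is a minimal, package-free sketch of that idea, not a DiffEq solution. It assumes a hand-written RK4 step, and `pd_law` is a hypothetical proportional-derivative law standing in for the neural network; the gains and the `substeps` count are arbitrary illustrative choices.

```julia
# Zero-order-hold control of a pendulum, package-free sketch.
# `pd_law` is a hypothetical stand-in for the neural-network controller.

const g = 9.81   # gravitational acceleration [m/s²]
const l = 1.0    # length [m]
const m = 1.0    # mass [kg]

# Pendulum dynamics with a constant applied torque M.
pendulum(x, M) = [x[2], -3g / (2l) * sin(x[1]) + 3 / (m * l^2) * M]

# One classical Runge-Kutta step of size h with the torque held fixed.
function rk4_step(x, M, h)
    k1 = pendulum(x, M)
    k2 = pendulum(x .+ h / 2 .* k1, M)
    k3 = pendulum(x .+ h / 2 .* k2, M)
    k4 = pendulum(x .+ h .* k3, M)
    return x .+ h / 6 .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Hypothetical control law: proportional-derivative feedback with gains p.
pd_law(x, p) = -p[1] * x[1] - p[2] * x[2]

# Sample the controller every Ts and hold its output between samples.
function simulate(x0, p; Ts=0.5, tend=10.0, substeps=50)
    x = copy(x0)
    nsamples = round(Int, tend / Ts)
    torques = Float64[]
    for _ in 1:nsamples
        M = pd_law(x, p)          # controller runs only at sample instants
        push!(torques, M)
        for _ in 1:substeps       # dynamics integrate with M frozen
            x = rk4_step(x, M, Ts / substeps)
        end
    end
    return x, torques
end

xfinal, torques = simulate([0.01, 0.0], [0.5, 0.2])
```

Because the loop is plain Julia, a windowed input could in principle be added by keeping the last few sampled states in a buffer and passing them to the control law. Whether a given AD backend differentiates such a loop efficiently is not something this sketch establishes.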

@avikpal can you look into the ReverseDiffVJP issue here? I don’t think that would be too difficult.

@Guido_Ballabio can you share your code that doesn’t work? Lux is currently compatible with ReverseDiff.jl and the linked issue also works.

Here you go. It’s a modified version of the original MWE with Lux and ComponentArrays added; it works without them.

#%% imports

using DifferentialEquations
using Statistics
using Plots
using Random

using Zygote
using SciMLSensitivity
using NNlib
using Lux
using ComponentArrays: ComponentArray

#%% system

rng = Xoshiro(2)

net = Lux.Chain(
    Lux.Dense(2, 50, elu),
    Lux.Dense(50, 1, relu),
    x -> x * 2
)

pinit, state = Lux.setup(rng, net)
params = ComponentArray(pinit)

@inline net_controller(x, p; st=state) = net(x, p, st)[1]

l = 1.0                             # length [m]
m = 1.0                             # mass [kg]
g = 9.81                            # gravitational acceleration [m/s²]

function pendulum!(du, u, p, t)
    du[1] = u[2]
    du[2] = -3g / (2l) * sin(u[1]) + 3 / (m * l^2) * u[3]
    du[3] = 0                       # control input is held constant between callbacks
end

θ₀ = 0.01                           # initial angular deflection [rad]
ω₀ = 0.0                            # initial angular velocity [rad/s]
u₀ = [θ₀, ω₀, 0]                       # initial state vector
tspan = (0.0, 10.0)                  # time interval

Ts = 0.5
function controller(integrator)
    u = integrator.u
    p = integrator.p
    t = integrator.t
    M = net_controller(u[1:2], p)[1]
    integrator.u[3] = M             # overwrite the held control input
end

#%% test system

prob = ODEProblem(pendulum!, u₀, tspan, params)
sol = solve(prob, callback=PeriodicCallback(controller, Ts), saveat=0.1)
plot(sol, xaxis = "t", label = ["θ [rad]" "ω [rad/s]" "M [rad/s^2]"], layout = (3, 1))

#%% setup callbacks

sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())
# sensealg = BacksolveAdjoint(autojacvec=ReverseDiffVJP())

cb = PeriodicCallback(controller, Ts)
cb_tracked = SciMLSensitivity.track_callbacks(CallbackSet(cb), prob.tspan[1], prob.u0, prob.p, sensealg);

#%% solve

sol = solve(prob, callback=cb_tracked);
_, dp = adjoint_sensitivities(sol, Rosenbrock23(autodiff=false); callback=cb_tracked, sensealg=sensealg, g=(u, p, t)->sum(u));

#%% comparison

function loss(prob, p)
    prob = remake(prob, p=p)
    sol = solve(prob, callback=cb, saveat=0.1)
    mean(sum(sol, dims=1)) * tspan[end]
end

Zygote.withgradient((p) -> loss(prob, p), params)

I get this error from the adjoint_sensitivities function: ERROR: type Array has no field layer_1. I suppose the tracking callback wraps the array in a way not compatible with ComponentArrays.

I also get an error from the Zygote.withgradient call: ERROR: MethodError: no method matching setvjp(::InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}, ::ReverseDiffVJP{false}).
This is supposed to be an internal check that needs no intervention from me, but if I add sensealg I get the same error as in adjoint_sensitivities.

@ChrisRackauckas callback_tracking.jl (SciMLSensitivity.jl/callback_tracking.jl at master · SciML/SciMLSensitivity.jl · GitHub) messes up a ReverseDiff tracked array. Why is that list comprehension needed?

For now, pass the sensitivity algorithm inside the loss function:

function loss(prob, p)
    prob = remake(prob; p)
    sol = solve(prob, Tsit5(); sensealg, callback=cb, saveat=0.1)
    return mean(sum(sol; dims=1)) * tspan[end]
end

Zygote.withgradient((p) -> loss(prob, p), params)

@ChrisRackauckas @frankschae Changing the get_FakeIntegrator definition to

function get_FakeIntegrator(autojacvec::ReverseDiffVJP, u, p, t, tprev)
    FakeIntegrator([x for x in u], p, t, tprev)
end

fixes the reversediff problem.
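As a package-free illustration of what a comprehension like `[x for x in u]` does in general: it iterates the input and collects the elements into a fresh plain `Vector`, dropping the wrapper type while leaving the original untouched. `Wrapped` below is a hypothetical read-only wrapper used only for this sketch. It is not a ReverseDiff or ComponentArrays type, and the thread does not confirm that this is the exact mechanism inside SciMLSensitivity.

```julia
# `Wrapped` is a hypothetical read-only array wrapper used only for illustration.
struct Wrapped{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped, i::Int) = w.data[i]
# No setindex! is defined, so `u[3] = ...` on a Wrapped would throw.

u = Wrapped([0.01, 0.0, 0.0])

# The comprehension iterates the wrapper and collects into a plain Vector.
v = [x for x in u]

# The copy is an ordinary mutable Vector; writing to it leaves `u` unchanged.
v[3] = 1.5
```

The same property cuts both ways: a container that callers later access through its wrapper interface would lose that interface if it were passed through a comprehension, which is why the proposed definition only copies `u` and passes `p` through as-is.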

Can you open a PR? That looks like a reasonable fix.
