Periodic Callback and Sensitivity

I am trying to use DifferentialEquations.jl and SciMLSensitivity.jl to build NN controllers for physical systems.

Usually everything works. This time, however, I need to design a fixed-frequency controller, so I reached for PeriodicCallback. It does the job, but then the sensitivity analysis fails or, worse, reports that there is no gradient.

I am especially worried about the “no gradient” error, which makes me wonder whether PeriodicCallback is incompatible with the sensitivity package, in the sense that u cannot be changed inside callbacks.

Here is an MWE for the “no gradient” error, adapted from the docs.

using DifferentialEquations
using SciMLSensitivity
using Plots

l = 1.0                             # length [m]
m = 1.0                             # mass [kg]
g = 9.81                            # gravitational acceleration [m/s²]

function pendulum!(du, u, p, t)
    du[1] = u[2]
    du[2] = -3g / (2l) * sin(u[1]) + 3 / (m * l^2) * u[3]
end

θ₀ = 0.01                           # initial angular deflection [rad]
ω₀ = 0.0                            # initial angular velocity [rad/s]
u₀ = [θ₀, ω₀, 0]                       # initial state vector
tspan = (0.0, 10.0)                  # time interval

Ts = 0.5
cb = PeriodicCallback(Ts) do integrator
    p = integrator.p
    t = integrator.t
    integrator.u[3] = p[1] * sin(t)
end

prob = ODEProblem(pendulum!, u₀, tspan, [0.1], callback=CallbackSet(cb))
sol = solve(prob)

plot(sol, xaxis = "t", label = ["θ [rad]" "ω [rad/s]" "M [rad/s^2]"], layout = (3, 1))

sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
_, dp = adjoint_sensitivities(sol, Rosenbrock23(); sensealg=sensealg, g=(u, p, t)->sum(u));

This gives me the error:

ERROR: `nothing` returned from a Zygote vector-Jacobian product (vjp) calculation.
This indicates that your function `f` is not a function of `p` or `u`, i.e. that
the derivative is constant zero. In many cases this is due to an error in
the model definition, for example accidentally using a global parameter
instead of the one in the model (`f(u,p,t)= _p .* u`).

Weirdly enough, I had to wrap the callback in a CallbackSet and pass it to the problem instead of to the solve call, but anyway.

On the original problem I was working on I also got:

  • “ERROR: More than two occurances of the same time point. Please report this.” I couldn’t reproduce this one in the MWE.
  • “ERROR: type PeriodicCallbackAffect has no field event_times”, which appears depending on the sensealg (I think when checkpointing=true, but it disappears on the second run).

Any idea on how to solve this issue of mine?

This function is not a function of your parameters p, so the gradient is zero, which is why it throws the error. Did you intend to differentiate with respect to l, m, and g as well? If so, those need to be in p.

This error is a safety check that protects against cases where someone accidentally makes all parameters global, which seems to be the case here. However, if this is what you intended, i.e. none of the ODE values depend on the parameters and only the callbacks do, then I think we’ll need to offer a way to turn this error off.
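If differentiating with respect to the physical constants is the goal, one way to move them into p could look like the sketch below. The name pendulum_p! and the layout p = [l, m, g, k] (with k the controller gain used by the callback) are assumptions made for illustration, not something from the original MWE.

```julia
# Sketch: the physical constants live in p instead of being globals.
# Assumed (hypothetical) layout: p = [l, m, g, k], k being the controller gain.
function pendulum_p!(du, u, p, t)
    l, m, g = p[1], p[2], p[3]
    du[1] = u[2]
    du[2] = -3g / (2l) * sin(u[1]) + 3 / (m * l^2) * u[3]
    du[3] = 0.0   # the control input only changes inside the callback
    return nothing
end

# Evaluate the right-hand side once to check that it depends on p
p = [1.0, 1.0, 9.81, 0.1]
u = [0.01, 0.0, 0.0]
du = similar(u)
pendulum_p!(du, u, p, 0.0)
```

With this layout the callback would read the gain as p[4] instead of p[1].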

For the adjoint part it needs to be wrapped in a way that tracks the forward solution. We should throw a more explicit error on this. @frankschae is there an example of a manual adjoint with callbacks? I think all of the examples show it implicitly through the AD interface. The AD interface does the TrackedAffect! wrappers and such, but we need to document that this should be done explicitly for doing adjoints with callbacks and needs to be done for the forward solve as well.

> This function is not a function of your parameters p, so the gradient is zero, which is why it throws the error. Did you intend to differentiate with respect to l, m, and g as well? If so, those need to be in p.

> This error is a safety check that protects against cases where someone accidentally makes all parameters global, which seems to be the case here. However, if this is what you intended, i.e. none of the ODE values depend on the parameters and only the callbacks do, then I think we’ll need to offer a way to turn this error off.

I think we do support this already. In the vjp computation code, we check

sensealg.autojacvec.allow_nothing

So a user can specify it when selecting a VJP backend (e.g., ZygoteVJP(allow_nothing = true)).

> For the adjoint part it needs to be wrapped in a way that tracks the forward solution. We should throw a more explicit error on this. @frankschae is there an example of a manual adjoint with callbacks? I think all of the examples show it implicitly through the AD interface. The AD interface does the TrackedAffect! wrappers and such, but we need to document that this should be done explicitly for doing adjoints with callbacks and needs to be done for the forward solve as well.

We have an example for the manual construction in the tests: https://github.com/SciML/SciMLSensitivity.jl/blob/ead6d82f0c1520750e5008f193f40b529c17079c/test/callbacks/discrete_callbacks.jl#L133-L146

    cb2 = SciMLSensitivity.track_callbacks(CallbackSet(cb), prob.tspan[1], prob.u0, prob.p,
                                           BacksolveAdjoint(autojacvec = ReverseDiffVJP()))
    sol_track = solve(prob, Tsit5(), u0 = u0, p = p, callback = cb2, tstops = tstops,
                      abstol = abstol, reltol = reltol, saveat = savingtimes)
    #cb_adj = SciMLSensitivity.setup_reverse_callbacks(cb2,BacksolveAdjoint())

    adj_prob = ODEAdjointProblem(sol_track, BacksolveAdjoint(autojacvec = ReverseDiffVJP()),
                                 Tsit5(),
                                 sol_track.t, dg!,
                                 callback = cb2,
                                 abstol = abstol, reltol = reltol)
    adj_sol = solve(adj_prob, Tsit5(), abstol = abstol, reltol = reltol)
    @test du01 ≈ adj_sol[1:2, end]
    @test dp1 ≈ adj_sol[3:6, end]

(Yes, all examples in the docs only show the high-level interface…)

The fact that the parameters p are used only in the callback is intended.
I can turn off the error with

sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP(allow_nothing=true))

but then I get

MethodError: no method matching vec(::Nothing)

Indeed, I tried it through AD and it works, so I suppose I just need a way to wrap the forward solution and the callbacks.

If the callbacks and the forward solution need to be wrapped somehow, my bad. If you can at least point me to where to look in the code, I’ll try to correct the MWE above and leave it as an example.

Thanks for the suggestion. The MWE seems to work this way:

#%% system

using DifferentialEquations
using SciMLSensitivity
using Statistics
using Zygote
using Plots

l = 1.0                             # length [m]
m = 1.0                             # mass [kg]
g = 9.81                            # gravitational acceleration [m/s²]

function pendulum!(du, u, p, t)
    du[1] = u[2]
    du[2] = -3g / (2l) * sin(u[1]) + 3 / (m * l^2) * u[3]
end

θ₀ = 0.01                           # initial angular deflection [rad]
ω₀ = 0.0                            # initial angular velocity [rad/s]
u₀ = [θ₀, ω₀, 0]                       # initial state vector
tspan = (0.0, 10.0)                  # time interval

Ts = 0.5
function controller(integrator)
    p = integrator.p
    t = integrator.t
    integrator.u[3] = p[1] * sin(t)
end

#%% test system

prob = ODEProblem(pendulum!, u₀, tspan, [0.1])
sol = solve(prob, callback=PeriodicCallback(controller, Ts))
plot(sol, xaxis = "t", label = ["θ [rad]" "ω [rad/s]" "M [rad/s^2]"], layout = (3, 1))

#%% setup callbacks

sensealg = InterpolatingAdjoint(autojacvec=EnzymeVJP())

cb = PeriodicCallback(controller, Ts)
cb_tracked = SciMLSensitivity.track_callbacks(CallbackSet(cb), prob.tspan[1], prob.u0, prob.p, sensealg);

#%% solve
sol = solve(prob, callback=cb_tracked);
_, dp = adjoint_sensitivities(sol, Rosenbrock23(autodiff=false); callback=cb_tracked, sensealg=sensealg, g=(u, p, t)->sum(u));
dp

#%% comparison

function loss(prob, p)
    prob = remake(prob, p=p)
    sol = solve(prob, callback=cb)
    mean(sum.(sol.u)) * tspan[end]
    # sum(diff(sol.t) .* sum.(sol.u)[2:end])
end
grad = Zygote.gradient((p) -> loss(prob, p), [0.1])[1]  # named grad so the global g (gravity) used by pendulum! is not overwritten

It works, even though I get the warning: warning: didn't implement memmove, using memcpy as fallback which can result in errors.
The results are dp = 2.3 and grad = 1.5, which seems OK to me considering the second loss is not an actual integral.
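As a rough illustration of why a mean-based loss and a time integral can disagree on an adaptive (non-uniform) time grid, here is a small sketch in plain Julia. The helper trapz, the time grid, and the stand-in signal y are all hypothetical; nothing here comes from the solver output above.

```julia
# Trapezoidal rule on a non-uniform grid:
# ∫ y dt ≈ Σ (t[i+1] - t[i]) * (y[i] + y[i+1]) / 2
trapz(t, y) = sum(diff(t) .* (y[1:end-1] .+ y[2:end]) ./ 2)

# Hypothetical non-uniform grid, similar to what an adaptive solver might save
t = [0.0, 0.1, 0.2, 0.5, 1.0, 2.0]
y = t .^ 2                      # stand-in for sum.(sol.u)

# What mean(...) * tspan[end] computes: every saved point gets equal weight
mean_estimate = sum(y) / length(y) * t[end]

# Time-weighted estimate of the integral
trapz_estimate = trapz(t, y)

exact = t[end]^3 / 3            # exact integral of t^2 on [0, t[end]]
println((mean_estimate, trapz_estimate, exact))
```

On this grid the saved points cluster near t = 0, so the equal-weight mean lands further from the exact integral than the trapezoidal estimate does. A difference of similar origin is plausible between dp and the Zygote gradient above.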

I would note though that there are some issues I couldn’t explain myself when changing sensealg:

  • In order to enable checkpointing I had to redefine the problem prob = ODEProblem(pendulum!, u₀, tspan, [0.1], callback=cb_tracked)

  • I needed to set Rosenbrock23(autodiff=false)

  • sensealg = BacksolveAdjoint(autojacvec=EnzymeVJP()) gets me ERROR: MethodError: no method matching EnzymeCore.DuplicatedNoNeed(::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, ::Vector{Float64}) unless I redefine the problem for checkpointing.

  • sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP()) gets me ERROR: UndefRefError: access to undefined reference

  • sensealg = BacksolveAdjoint(autojacvec=ReverseDiffVJP()) gets me ERROR: UndefRefError: access to undefined reference

  • I couldn’t use ZygoteVJP as it says it’s not compatible with hybrid ODEs and that’s ok.

Finally, an unrelated issue that I have hit on other occasions as well: I cannot use sol.t inside the loss, as if it weren’t compatible with AD. For example:

function loss(prob, p)
    prob = remake(prob, p=p)
    sol = solve(prob, callback=cb)
    sum(diff(sol.t) .* sum.(sol.u)[2:end])
end

ERROR: MethodError: no method matching +(::ODESolution{Float64, 2, Vector....g

I am just leaving it here for everyone else, but if you have suggestions I’ll be thankful.

The undefined reference error seems to be a ReverseDiff initialization thing… it works when I add du[3] = 0.0 to your ODE function

function pendulum!(du, u, p, t)
    du[1] = u[2]  # p has a single entry in the MWE, so p[2] would be out of bounds
    du[2] = -3g / (2l) * sin(u[1]) + 3 / (m * l^2) * u[3]
    du[3] = 0.0
end