I am trying to use DifferentialEquations.jl and SciMLSensitivity.jl to build neural-network (NN) controllers for physical systems.
Usually everything works. This time, however, I need to design a fixed-frequency controller, so I reached for `PeriodicCallback`. It does the job, but then the sensitivity analysis fails or, worse, it says there is no gradient.
I am especially worried about the “no gradient” error, which makes me wonder whether `PeriodicCallback` is incompatible with SciMLSensitivity, in the sense that `u` cannot be modified inside the callback.
Here is an MWE for the “no gradient” error, adapted from the docs:
```julia
using DifferentialEquations
using SciMLSensitivity
using Plots

l = 1.0   # length [m]
m = 1.0   # mass [kg]
g = 9.81  # gravitational acceleration [m/s²]

# State: u[1] = θ, u[2] = ω, u[3] = M (control torque, held between callback events)
function pendulum!(du, u, p, t)
    du[1] = u[2]
    du[2] = -3g / (2l) * sin(u[1]) + 3 / (m * l^2) * u[3]
    du[3] = 0.0  # the torque only changes through the callback
end

θ₀ = 0.01            # initial angular deflection [rad]
ω₀ = 0.0             # initial angular velocity [rad/s]
u₀ = [θ₀, ω₀, 0.0]   # initial state vector (torque starts at zero)
tspan = (0.0, 10.0)  # time interval [s]
Ts = 0.5             # controller sampling period [s]

# Every Ts seconds, overwrite the held torque with the control law p[1] * sin(t)
cb = PeriodicCallback(Ts) do integrator
    p = integrator.p
    t = integrator.t
    integrator.u[3] = p[1] * sin(t)
end

prob = ODEProblem(pendulum!, u₀, tspan, [0.1], callback = CallbackSet(cb))
sol = solve(prob)
plot(sol, xaxis = "t", label = ["θ [rad]" "ω [rad/s]" "M [N·m]"], layout = (3, 1))

sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true)
_, dp = adjoint_sensitivities(sol, Rosenbrock23(); sensealg = sensealg, g = (u, p, t) -> sum(u));
```
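The callback is meant to turn `u[3]` into a zero-order-hold input: it stays constant between events and jumps every `Ts` seconds. As a quick check of the timing, here is a sketch that reproduces that piecewise-constant signal in plain Julia. `held_input` is a hypothetical helper written for illustration (not a DiffEqCallbacks function), and it assumes `PeriodicCallback`'s default `initial_affect = false`, so no update happens at `t = 0`.

```julia
# Hypothetical helper (not part of DiffEqCallbacks): the piecewise-constant input u[3](t)
# that the MWE's PeriodicCallback should produce, assuming initial_affect = false.
function held_input(t; p1 = 0.1, Ts = 0.5, t0 = 0.0, u30 = 0.0)
    k = floor(Int, (t - t0) / Ts)  # number of callback events that have fired by time t
    k <= 0 && return u30           # before the first event, u[3] keeps its initial value
    tk = t0 + k * Ts               # time of the most recent event
    return p1 * sin(tk)            # value written by `integrator.u[3] = p[1] * sin(t)`
end

# Sample the held input on a fine grid over the MWE's tspan
ts = 0.0:0.1:10.0
Ms = held_input.(ts)
```

Plotting `Ms` against `ts` next to the third panel of the MWE's plot is a simple way to confirm the callback fires where expected.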
This gives me the error:
```
ERROR: `nothing` returned from a Zygote vector-Jacobian product (vjp) calculation.
This indicates that your function `f` is not a function of `p` or `u`, i.e. that
the derivative is constant zero. In many cases this is due to an error in
the model definition, for example accidentally using a global parameter
instead of the one in the model (`f(u,p,t)= _p .* u`).
```
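The error text suggests that the Zygote pullback of the right-hand side with respect to `p` came back as `nothing`. Below is a minimal sketch of that check with plain Zygote.jl, assuming (not confirmed here) that `ZygoteVJP` computes essentially this pullback of `f` with respect to `(u, p)`. `pendulum` is a hypothetical out-of-place copy of `pendulum!`, since Zygote does not differentiate through array mutation. In the MWE, `p` is used only inside the callback and never inside the right-hand side, so its cotangent is empty.

```julia
using Zygote

l = 1.0   # length [m]
m = 1.0   # mass [kg]
g = 9.81  # gravitational acceleration [m/s²]

# Hypothetical out-of-place copy of the MWE's pendulum!; p is accepted but never used
pendulum(u, p, t) = [u[2], -3g / (2l) * sin(u[1]) + 3 / (m * l^2) * u[3], 0.0]

u = [0.01, 0.0, 0.0]
p = [0.1]
λ = ones(3)  # cotangent, playing the role of the adjoint state

_, back = Zygote.pullback(pendulum, u, p, 0.0)
du_bar, dp_bar, dt_bar = back(λ)  # vector-Jacobian products w.r.t. u, p and t

println(dp_bar === nothing)  # Zygote signals "no dependence on p" with `nothing`
```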
Weirdly enough, I had to wrap the callback in a `CallbackSet` and pass it to the problem instead of to the `solve` call, but anyway.
In the original problem I was working on, I also got:
- `ERROR: More than two occurances of the same time point. Please report this.` I couldn’t reproduce this one in the MWE.
- `ERROR: type PeriodicCallbackAffect has no field event_times`, which appears depending on the `sensealg` (I think when `checkpointing=true`, but it disappears on the second run).
Any idea how to solve this?