Automatic differentiation and optimization with respect to callback times

From the Training Neural Networks in Hybrid Differential Equations example, I know that SciMLSensitivity is compatible with PresetTimeCallback. I was wondering whether it is possible to compute derivatives (and then optimize) with respect to the preset callback times (dosetimes in the linked example)?
I tried the following but it did not seem to work. It returns nothing instead of the gradient:

using OrdinaryDiffEq, DiffEqCallbacks, SciMLSensitivity, Zygote

function f(du,u,p,t)
    du[1] = -p[1]*u[1]
end

u0 = [10.0]
p = [1.0]
x = [1.0]
tspan = (0.0, 2.0)
prob = ODEProblem(f,u0,tspan,p)

affect!(integrator) = integrator.u[1] += 10
dose(x) = PresetTimeCallback(x, affect!)

loss1(p) = solve(prob,Tsit5(),u0=u0,p=p,callback=dose(1.0),saveat=0.1, sensealg=ReverseDiffAdjoint())[1,end]
loss2(x) = solve(prob,Tsit5(),u0=u0,p=p,callback=dose(x),saveat=0.1, sensealg=ReverseDiffAdjoint())[1,end]
@show Zygote.gradient(loss1, [1.0])    # works
@show Zygote.gradient(loss2, [1.0])    # returns `nothing`

The output is:

Zygote.gradient(loss1, [1.0]) = ([-6.385494404700164],)
Zygote.gradient(loss2, [1.0]) = (nothing,)

That derivative w.r.t. the time point is not defined in the adjoint system right now. Please open an issue and we can add its derivative calculation. For now, you can use forward mode via ForwardDiff, which does calculate that derivative. If you're calling it dosetimes for pharmacometrics models, forward mode is likely to be faster anyway, since its cost scales with the number of inputs and there are only a few dose times :sweat_smile:

Thank you! Using ForwardDiff and making the necessary changes to u0, tspan and p (as per this section of the docs) did the trick:

using ForwardDiff

u0 = [10.0]
p = [1.0]
x = [1.0]
tspan = (0.0, 2.0)
prob = ODEProblem(f,u0,tspan,p)

affect!(integrator) = integrator.u[1] += 10
dose(x) = PresetTimeCallback(x, affect!)

loss(x) = solve(prob,Tsit5(),
                u0=eltype(x).(u0), tspan=eltype(x).(tspan), p=eltype(x).(p),
                reltol=1e-9, abstol=1e-9, callback=dose(x))[1,end]

@show (loss(1.0+1e-6)-loss(1.0))/1e-6   # finite-difference check of the derivative w.r.t. the dose time
@show ForwardDiff.gradient(loss, [1.0])
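
Since the dose time in this toy example is a single scalar, the same value should also be obtainable via ForwardDiff.derivative, as a sanity check against the finite difference above (just a usage note, not part of the original post):

@show ForwardDiff.derivative(loss, 1.0)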

My actual application is solving a time-optimal control problem where the optimal solution is bang-bang, so I only need to optimize with respect to a few switching times. For anyone interested, here is a toy example using the approach above to compute gradients with respect to the switching times:

using OrdinaryDiffEq, DiffEqCallbacks, ForwardDiff, Optimisers
# Spring dynamics x'' + x = f, with time rescaled by the total time T
function eom!(du, u, p, t)
    x, dx, f = u
    T = p[1]

    du[1] = T * dx
    du[2] = T * (f - x)
    du[3] = T * 0.0    # the control f is piecewise constant between switches
end

# Bang-bang control
affect!(integrator) = (integrator.u[3] = -integrator.u[3])

# ps[1] and ps[2] are the switching times, ps[3] is the total time T.
ps = [0.2, 0.7, 4.0*pi]

u0 = [-6.0, 0.0, -1.0]
tspan = (0.0, 1.0)
prob = ODEProblem(eom!, u0, tspan, ps[[3]])

# Minimize the total time T = ps[3] subject to x(1) = dx(1) = 0,
# implemented as a penalty: endpoint error plus a small weight on T^2.
function loss(ps)
    u_end = solve(prob, Tsit5(),
                  callback=PresetTimeCallback(ps[1:2], affect!, save_positions=(false,false)),
                  u0=eltype(ps).(u0), tspan=eltype(ps).(tspan), p=ps[[3]],
                  saveat = [1.0],
                  save_everystep=false,
                  reltol=1e-6, abstol=1e-6)[end]    # state at the final (rescaled) time t = 1
    return u_end[1]^2 + u_end[2]^2 + 1.0e-6*ps[3]^2
end

# Optimization
opt = Optimisers.Adam()
opt_state = Optimisers.setup(opt, ps)

for epoch in 1:100_000
    global ps, opt_state    # needed at top level when run as a script
    gs = ForwardDiff.gradient(loss, ps)
    opt_state, ps = Optimisers.update(opt_state, ps, gs)
end
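
After training, a quick way to check the result (not part of the original script) is to print the optimized parameters, where ps[1:2] are the switching times and ps[3] is the total time T, along with the remaining loss:

@show ps
@show loss(ps)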

# Plot the result
using Plots
gr()
sol = solve(prob, Tsit5(),
            callback=PresetTimeCallback(ps[1:2], affect!), p=ps[[3]],
            reltol=1e-8, abstol=1e-8)
pl = plot(sol, idxs=(1,2), xlims=(-7.0,7.0), ylims=(-7.0,7.0), size=(350,350), label="", grid=true)