Automatic differentiation and optimization with respect to callback times

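Thank you! Using ForwardDiff, and converting `u0`, `tspan`, and `p` to the element type of the argument being differentiated (as described in the DifferentialEquations.jl documentation on automatic differentiation), did the trick: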

# Problem setup. `f` is the right-hand side defined earlier in the thread
# (not shown in this snippet).
tspan = (0.0, 2.0)
u0 = [10.0]
p = [1.0]
x = [1.0]   # presumably the dose time(s); `loss`/`dose` take their own argument
prob = ODEProblem(f, u0, tspan, p)

# Dose event: add 10 to the first state component of the integrator.
function affect!(integrator)
    dosed = integrator.u[1] + 10
    integrator.u[1] = dosed
    return dosed
end
# Build a callback that applies `affect!` at each preset time in `x`.
function dose(x)
    return PresetTimeCallback(x, affect!)
end

# First state component at the final time, as a function of the dose time(s) `x`.
# `u0`, `tspan` and `p` are converted to `eltype(x)` so that ForwardDiff dual
# numbers propagate through the solve.
function loss(x)
    T = eltype(x)
    sol = solve(prob, Tsit5();
                u0 = T.(u0), tspan = T.(tspan), p = T.(p),
                reltol = 1e-9, abstol = 1e-9, callback = dose(x))
    return sol[1, end]
end

# Compare a forward finite-difference derivative of `loss` with respect to the
# dose time (step 1e-6) against the ForwardDiff gradient.
# NOTE(review): assumes ForwardDiff is already loaded; this snippet has no `using`.
@show (loss(1.0+1e-6)-loss(1.0))/1e-6
@show ForwardDiff.gradient(loss, [1.0])

My actual application is a time-optimal control problem whose optimal solution is bang-bang, so I only need to optimize with respect to a few switching times. For anyone interested, here is a toy example that uses the approach above to compute gradients with respect to the switching times:

using OrdinaryDiffEq, DiffEqCallbacks, ForwardDiff, Optimisers
"""
    eom!(du, u, p, t)

In-place right-hand side of the forced spring `x'' + x = f`, with every
derivative multiplied by the total time `T = p[1]`. The integration interval
is therefore fixed, and `T` enters only as a parameter.

State `u = [x, dx, f]`: position, velocity, and the control `f`. The ODE keeps
`f` constant; it changes only through callbacks such as a bang-bang switch.
`t` is unused. Returns `nothing`.
"""
function eom!(du, u, p, t)
    x, dx, f = u
    T = p[1]

    du[1] = T * dx
    du[2] = T * (f - x)
    # The control does not evolve between switches. `zero(T)` keeps T's element
    # type (e.g. a ForwardDiff.Dual) and avoids the NaN that `T * 0.0` gives for
    # an infinite T.
    du[3] = zero(T)
    return nothing
end

# Bang-bang control: flip the sign of the control stored in the third state
# component, and return the new value.
function affect!(integrator)
    flipped = -integrator.u[3]
    integrator.u[3] = flipped
    return flipped
end

# Decision variables: ps[1] and ps[2] are the switching times (in the rescaled
# time of `tspan`), and ps[3] is the total time T.
ps = [0.2, 0.7, 4π]

# Initial state [x, dx, f]: at rest at x = -6, with control f = -1.
u0 = [-6.0, 0.0, -1.0]
# Fixed integration interval; the physical duration enters through T = p[1].
tspan = (0.0, 1.0)
prob = ODEProblem(eom!, u0, tspan, ps[3:3])

# Minimize T = ps[3] subject to x(1) = dx(1) = 0.
"""
    loss(ps)

Return `x(end)^2 + dx(end)^2 + 1e-6 * T^2` for the switching problem.

`ps[1:2]` are the times at which `affect!` flips the control, and `ps[3]` is
the total time `T` (passed as the ODE parameter). `u0` and `tspan` are
converted to `eltype(ps)` so that ForwardDiff dual numbers propagate through
the solve.
"""
function loss(ps)
    T = eltype(ps)
    switches = PresetTimeCallback(ps[1:2], affect!, save_positions = (false, false))
    sol = solve(prob, Tsit5();
                callback = switches,
                u0 = T.(u0), tspan = T.(tspan), p = ps[[3]],
                # Save only the final state. Tie it to `tspan` rather than a
                # hard-coded 1.0 so that changing the interval cannot desync it.
                saveat = [last(tspan)],
                save_everystep = false,
                reltol = 1e-6, abstol = 1e-6)
    # Use `sol.u[end]` for the last saved state: single-integer indexing
    # `sol[end]` is deprecated in recent SciMLBase/RecursiveArrayTools.
    u_end = sol.u[end]
    return abs2(u_end[1]) + abs2(u_end[2]) + 1.0e-6 * abs2(ps[3])
end

# Optimization: plain Adam on the ForwardDiff gradient of `loss`.
opt = Optimisers.Adam()
opt_state = Optimisers.setup(opt, ps)

for _ in 1:100_000
    # Required when this file is run as a script. Without `global`, the
    # non-interactive soft-scope rules make `ps` and `opt_state` new loop-local
    # variables. The first `ForwardDiff.gradient(loss, ps)` then throws
    # UndefVarError, and the optimized parameters never reach the code below.
    global opt_state, ps
    gs = ForwardDiff.gradient(loss, ps)
    opt_state, ps = Optimisers.update(opt_state, ps, gs)
end

# Plot the result: phase portrait (state 1 = x against state 2 = dx) of the
# trajectory with the optimized switching times and total time.
using Plots
gr()
sol = solve(prob, Tsit5();
            p = ps[[3]],
            callback = PresetTimeCallback(ps[1:2], affect!),
            abstol = 1e-8, reltol = 1e-8)
pl = plot(sol;
          idxs = (1, 2),
          xlims = (-7.0, 7.0), ylims = (-7.0, 7.0),
          size = (350, 350),
          label = "", grid = true)
1 Like