Thank you! Using ForwardDiff and converting u0, tspan and p to the element type of the differentiated input (as described in the DifferentialEquations.jl documentation on automatic differentiation) did the trick:
# Dosing example. NOTE(review): `f` (the ODE right-hand side) is not defined in
# this snippet — presumably it comes from earlier in the discussion.
u0 = [10.0]
p = [1.0]
# Not referenced below: `loss` and `dose` take their own `x` argument.
x = [1.0]
tspan = (0.0, 2.0)
prob = ODEProblem(f, u0, tspan, p)

# Each dose event adds 10 to the first state component.
function affect!(integrator)
    integrator.u[1] += 10
end

# Callback that fires `affect!` at the preset dosing time(s) `x`.
function dose(x)
    return PresetTimeCallback(x, affect!)
end
# Value of the first state component at the final saved time, as a function of
# the dose time(s) `x`. Casting u0, tspan and p to `eltype(x)` lets ForwardDiff
# dual numbers propagate through the solver when `x` carries them.
function loss(x)
    elty = eltype(x)
    sol = solve(prob, Tsit5();
        u0 = elty.(u0), tspan = elty.(tspan), p = elty.(p),
        reltol = 1e-9, abstol = 1e-9, callback = dose(x))
    return sol[1, end]
end
# Sanity check: a one-sided finite-difference estimate of d(loss)/dx at x = 1
# (step 1e-6) next to the ForwardDiff gradient; they should approximately agree.
@show (loss(1.0 + 1e-6) - loss(1.0)) / 1e-6
@show ForwardDiff.gradient(loss, [1.0])
My actual application is a time-optimal control problem whose optimal solution is bang-bang, so I only need to optimize with respect to a few switching times. For anyone interested, here is a toy example that uses the approach above to compute gradients with respect to the switching times:
using OrdinaryDiffEq, DiffEqCallbacks, ForwardDiff, Optimisers
# Spring dynamics x'' + x = f, with time rescaled by the total time T, so every
# derivative carries a factor of T. State layout: u = (x, dx, f).
# The control f has zero derivative: it is held constant between switching
# events and only changed by the callback.
function eom!(du, u, p, t)
    pos, vel, force = u
    horizon = first(p)  # total time T
    du[1] = horizon * vel
    du[2] = horizon * (force - pos)
    du[3] = horizon * 0.0
end
# Bang-bang control: each switching event flips the sign of the control,
# which is stored in the third state component.
function affect!(integrator)
    state = integrator.u
    state[3] = -state[3]
end
# Decision variables: ps[1] and ps[2] are the two switching times (in the
# rescaled time of `tspan`), and ps[3] is the total time T, initialised to 4π.
ps = [0.2, 0.7, 4.0 * pi]

# Start at x = -6, at rest, with control f = -1.
u0 = [-6.0, 0.0, -1.0]
tspan = (0.0, 1.0)

# The ODE parameter vector holds only T (a fresh 1-element copy of ps[3]).
prob = ODEProblem(eom!, u0, tspan, ps[[3]])
# Minimize T = ps[3] subject to x(1) = dx(1) = 0.
"""
    loss(ps)

Penalty objective for the bang-bang problem.

Integrate `prob` with the control sign flipped at the switching times `ps[1]`
and `ps[2]`, using total time `T = ps[3]`. Return
`x(1)^2 + dx(1)^2 + 1e-6 * T^2`.

`u0` and `tspan` are converted to `eltype(ps)`, so ForwardDiff dual numbers can
flow through the solver when differentiating with respect to `ps`.
"""
function loss(ps)
    sol = solve(prob, Tsit5(),
        callback = PresetTimeCallback(ps[1:2], affect!, save_positions = (false, false)),
        u0 = eltype(ps).(u0), tspan = eltype(ps).(tspan), p = ps[[3]],
        saveat = [1.0],
        save_everystep = false,
        reltol = 1e-6, abstol = 1e-6)
    # `sol.u[end]` is the last saved state (t = 1.0, given saveat = [1.0]).
    # Integer indexing `sol[end]` on a solution is deprecated in recent
    # RecursiveArrayTools/SciMLBase releases, so read through `sol.u` instead.
    u_end = sol.u[end]
    return u_end[1]^2 + u_end[2]^2 + 1.0e-6 * ps[3]^2
end
# Optimization: Adam descent on `loss`. `ps` has only three entries, so
# forward-mode ForwardDiff.gradient is a reasonable choice here.
opt = Optimisers.Adam()
opt_state = Optimisers.setup(opt, ps)
for _ in 1:100_000
    # `opt_state` and `ps` are globals reassigned inside a top-level loop.
    # Without `global`, running this as a script (non-interactive soft scope)
    # makes them new loop-locals, and the first read throws UndefVarError.
    global opt_state, ps
    gs = ForwardDiff.gradient(loss, ps)
    # Optimisers.update returns (new_state, new_params).
    opt_state, ps = Optimisers.update(opt_state, ps, gs)
end
# Plot the result
using Plots
gr()
# Re-solve with the optimised switching times ps[1:2] and total time ps[3] at
# tighter tolerances, then draw the phase portrait (state 1 vs state 2).
sol = solve(prob, Tsit5();
    callback = PresetTimeCallback(ps[1:2], affect!),
    p = ps[[3]],
    reltol = 1e-8, abstol = 1e-8)
pl = plot(sol; idxs = (1, 2),
    xlims = (-7.0, 7.0), ylims = (-7.0, 7.0),
    size = (350, 350), label = "", grid = true)