Hey,
I’m trying to get parameter sensitivities of an ODE problem whose event time depends on a parameter.
Just throwing AD at the problem doesn’t seem to work:
```julia
using OrdinaryDiffEq
using SciMLSensitivity
using ForwardDiff

# At t = p[2], we assign p[1] <- p[3]
function rhs!(du, u, p, t)
    du[1] = -u[1] + p[1]
end

u0 = [1.0]
p_start = [1.2, 2.0, 0.1]
prob = ODEProblem(rhs!, u0, (0.0, 10.0), p_start)

function loss(p)
    _prob = remake(prob, p = p)

    # DiscreteCallback version: fires only if a step lands exactly on t == p[2] (hence the tstops)
    function condition_disc(u, t, integrator)
        return t == integrator.p[2]
    end

    # ContinuousCallback version: the zero crossing of t - p[2] is located by root finding
    function condition_cont(u, t, integrator)
        return t - integrator.p[2]
    end

    function affect!(integrator)
        # Triggered at t = p[2], use p[3] instead of p[1] for the remaining time
        integrator.p[1] = p[3]
    end

    # sol = solve(_prob, Tsit5(), saveat = 0.0:0.1:10.0, tstops = [ForwardDiff.value(p[2])],
    #             callback = DiscreteCallback(condition_disc, affect!), sensealg = ForwardDiffSensitivity())
    sol = solve(_prob, Tsit5(), saveat = 0.0:0.1:10.0, tstops = [ForwardDiff.value(p[2])],
                callback = ContinuousCallback(condition_cont, affect!), sensealg = ForwardDiffSensitivity())
    l = sum(abs2, sol .- 1)  # renamed from `loss` so it does not shadow the function name
    return l
end

ForwardDiff.gradient(loss, p_start)
```
This yields

```
3-element Vector{Float64}:
 -2.9178227175008735
 0.0
 -116.18633113248218
```
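For reference, this linear model can be solved by hand, and the solution shows that the p[2] component of the gradient should not be zero. The sketch below assumes a single switch at t = p_2 and u(0) = u_0. It writes L = \sum_i (u(t_i) - 1)^2 over the save times t_i. The expression for the loss gradient holds as long as no save time coincides exactly with p_2, because u(t_i) has a kink as a function of p_2 at that point.

```latex
u(t) =
\begin{cases}
p_1 + (u_0 - p_1)\, e^{-t}, & t < p_2,\\
p_3 + (p_1 - p_3)\, e^{-(t - p_2)} + (u_0 - p_1)\, e^{-t}, & t > p_2,
\end{cases}
\qquad
\frac{\partial u(t)}{\partial p_2} =
\begin{cases}
0, & t < p_2,\\
(p_1 - p_3)\, e^{-(t - p_2)}, & t > p_2,
\end{cases}

\frac{\partial L}{\partial p_2} = \sum_{t_i > p_2} 2\,\bigl(u(t_i) - 1\bigr)\,(p_1 - p_3)\, e^{-(t_i - p_2)}
```

Both branches agree at t = p_2, and each one satisfies du/dt = -u + p_1 or du/dt = -u + p_3, respectively.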
But simple finite differences for p[2] give

```
julia> (loss([1.20, 2.01, 0.1]) - loss([1.2, 1.99, 0.1])) / 0.02
-7.956884892551486
```
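To compare all three components rather than only p[2], the same check can be wrapped in a small helper. `central_fd` is a hypothetical name, not a function from any of the packages above. This is a plain-Julia sketch that assumes `f` takes a real vector and returns a real scalar:

```julia
# Hypothetical helper (not part of OrdinaryDiffEq/SciMLSensitivity/ForwardDiff):
# central finite differences for every component of p, for comparison with
# ForwardDiff.gradient. The default h = 0.01 matches the manual check above.
function central_fd(f, p::AbstractVector{<:Real}; h::Real = 1e-2)
    g = similar(p, Float64)
    for i in eachindex(p)
        # float.(p) allocates fresh arrays on every evaluation, so the
        # estimate stays correct even if f mutates its argument.
        pp = float.(p)
        pm = float.(p)
        pp[i] += h
        pm[i] -= h
        g[i] = (f(pp) - f(pm)) / (2h)
    end
    return g
end
```

With the definitions from the post, `central_fd(loss, p_start)` gives a central-difference estimate of all three components that can be set against `ForwardDiff.gradient(loss, p_start)`.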
Neither `DiscreteCallback` nor `ContinuousCallback` appears to work.
How can I get the derivative w.r.t. p[2] using AD?