Parameter sensitivity of ODE with parameter-dependent event

Hey,

I’m trying to get parameter sensitivities of an ODE problem with a parameter-dependent time-event.
Just throwing AD at the problem doesn’t seem to work:

using OrdinaryDiffEq
using SciMLSensitivity
using ForwardDiff

# At t = p[2], we assign p[1] <- p[3]
function rhs!(du, u, p, t)
    du[1] = -u[1] + p[1]
end

u0 = [1.0]
p_start = [1.2, 2.0, 0.1]

prob = ODEProblem(rhs!, u0, (0.0, 10.0), p_start)

function loss(p)
    _prob = remake(prob, p=p)

    function condition_disc(u, t, integrator)
        return t == integrator.p[2]
    end

    function condition_cont(u, t, integrator)
        return t - integrator.p[2]
    end

    function affect!(integrator)
        # Triggered at t = p[2], use p[3] instead of p[1] for the remaining time
        integrator.p[1] = p[3]
    end

#    sol = solve(_prob, Tsit5(), saveat = 0.0:0.1:10.0, tstops=[ForwardDiff.value(p[2])], callback = DiscreteCallback(condition_disc, affect!), sensealg=ForwardDiffSensitivity())
    sol = solve(_prob, Tsit5(), saveat = 0.0:0.1:10.0, tstops=[ForwardDiff.value(p[2])], callback = ContinuousCallback(condition_cont, affect!), sensealg=ForwardDiffSensitivity())
    loss = sum(abs2, sol .- 1)
    return loss
end

ForwardDiff.gradient(loss, p_start)

This yields

3-element Vector{Float64}:
   -2.9178227175008735
    0.0
 -116.18633113248218

But simple finite differences for p[2] give

julia> (loss([1.20, 2.01, 0.1]) - loss([1.2, 1.99, 0.1])) / 0.02
-7.956884892551486
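
For reference, the same kind of central-difference check over all three parameters (plain Julia, reusing loss and p_start from above) would be something like:

function fd_gradient(f, p; h = 1e-3)
    g = similar(p)
    for i in eachindex(p)
        pp, pm = copy(p), copy(p)
        pp[i] += h
        pm[i] -= h
        g[i] = (f(pp) - f(pm)) / (2h)   # central difference in parameter i
    end
    return g
end

fd_gradient(loss, p_start)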

Neither DiscreteCallback nor ContinuousCallback appears to work.
How can I get the derivative w.r.t. p[2] using AD?

Hello @Neodym, ~2.5 years after your question I am also looking at the same problem, and I find myself struggling to get it to work. Have you had any luck getting the derivatives to materialize with any form of AD? If so, would you consider sharing the solution?

And generally speaking, I was trying to verify the claims/results of the article A Comparison of Automatic Differentiation and Continuous Sensitivity Analysis for Derivatives of Differential Equation Solutions about generalizing discrete local sensitivity analysis via AD to hybrid systems, i.e., systems with parametrized events. Unfortunately, I have a hard time reproducing the results stated in Sec. IV for the given system, particularly obtaining the value \frac{\partial y(0)}{\partial a}.

I would be really thankful for any explanation/resource/discussion that could help me verify the results of the paper, and understand the mechanism of injecting the dependence on a into the dataflow of the calculation of y(0).

I think this could be one of the many, many problems (once you stray outside of widely used NN components) where AD needs some manual “help”, and it’s well worth learning how to take the derivatives yourself. For ODEs, this is reviewed in chapter 9 of our course notes from our MIT “Matrix Calculus” course.

Here, you have an ODE of the form

\frac{du}{dt} = f(u,t) = \begin{cases} f_1(u, t) & t \le t_0 \\ f_2(u,t) & t > t_0 \end{cases} = f_1(u,t) + \Theta(t - t_0) \left(f_2(u,t) - f_1(u,t)\right)

where \Theta(t) is the Heaviside step function.

Now, suppose that you want the derivative \frac{\partial u}{\partial t_0} with respect to the time t_0 where the jump occurs in the right-hand side. (This is “forward-mode” sensitivity analysis.) This derivative satisfies the linear differential equation:

\frac{d}{dt} \left( \frac{\partial u}{\partial t_0}\right) = \frac{\partial f}{\partial u} \frac{\partial u}{\partial t_0} + \frac{\partial f}{\partial t_0} \\ = \frac{\partial f}{\partial u} \frac{\partial u}{\partial t_0} - \delta(t - t_0) \left(f_2(u(t_0),t_0) - f_1(u(t_0),t_0)\right)

Now, you should immediately see why AD will typically have a problem here, unless it is specifically taught how to handle such right-hand-sides: you can probably use AD to compute \partial f/\partial u with no problem, but \partial f/\partial t_0 yields a Dirac delta function which AD won’t know what to do with (since it is a distribution rather than an ordinary function).

But if you do it manually, there is no problem with a Dirac delta on the right-hand-side: it just means that \partial u/\partial t_0 gains a jump discontinuity at t_0:

\left. \frac{\partial u}{\partial t_0} \right|_{t=t_0^+} = \left. \frac{\partial u}{\partial t_0} \right|_{t=t_0^-} - \left(f_2(u(t_0),t_0) - f_1(u(t_0),t_0)\right)

which you can easily specify via a continuous callback added to the equation \frac{d}{dt} \left( \frac{\partial u}{\partial t_0}\right) = \frac{\partial f}{\partial u} \frac{\partial u}{\partial t_0} that holds for t \ne t_0 (which can be co-evolved with the \frac{du}{dt} ODE).

Note that the initial condition for the sensitivity is \partial u/\partial t_0 = 0, so this gives a further simplification: the solution of \frac{d}{dt} \left( \frac{\partial u}{\partial t_0}\right) = \frac{\partial f}{\partial u} \frac{\partial u}{\partial t_0} is simply \partial u/\partial t_0 = 0 for t < t_0, so you can simply start the solution at t=t_0 with initial condition \left. \frac{\partial u}{\partial t_0} \right|_{t=t_0^+} = - \left(f_2(u(t_0),t_0) - f_1(u(t_0),t_0)\right).
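
For concreteness, here is a minimal sketch of this recipe applied to the ODE from the original post, where f_1(u,t) = -u + p_1 and f_2(u,t) = -u + p_3 (so f_2 - f_1 = p_3 - p_1 and \partial f/\partial u = -1); the variable names below are purely illustrative:

using OrdinaryDiffEq

# Co-evolve (u, v) with v = ∂u/∂t0, starting at t0 since v ≡ 0 before the event
function post_event_rhs!(du, u, p, t)
    du[1] = -u[1] + p[3]   # u after the event (p[3] is active)
    du[2] = -u[2]          # dv/dt = (∂f/∂u) v = -v
end

p  = [1.2, 2.0, 0.1]
t0 = p[2]

# Pre-event state from the closed-form solution u(t) = p[1] + (u0 - p[1]) exp(-t)
u_t0 = p[1] + (1.0 - p[1]) * exp(-t0)

# Jump condition: v(t0+) = -(f2 - f1) = p[1] - p[3]
v_t0 = p[1] - p[3]

sens_prob = ODEProblem(post_event_rhs!, [u_t0, v_t0], (t0, 10.0), p)
sens_sol  = solve(sens_prob, Tsit5(), saveat = t0:0.1:10.0)

# Sanity check: for this linear problem v(t) = (p[1] - p[3]) * exp(-(t - t0))
maximum(abs, sens_sol[2, :] .- (p[1] - p[3]) .* exp.(-(sens_sol.t .- t0)))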

If you have other parameters besides t_0, you can differentiate them in the usual way. If you have lots of parameters, and are differentiating a scalar function of the solution (e.g. a loss function), then you may want to implement reverse-mode differentiation. I’ll leave this as an exercise following the description in our course notes, but it is straightforward — again, one term just has a Dirac delta function, which will probably confuse AD but is easy to insert analytically.


Reverse Mode

This is already taken into account in the adjoint system. It’s documented here:

with a video about it as well:

So if you use SciMLSensitivity.jl it's all handled. That means any reverse mode that is captured will do it. And there are tests along these lines for reverse mode here:

Forward Mode

But ForwardDiff.gradient doesn't get captured into the adjoint system. As you can see from the derivation, though, the key issue is that you need to differentiate time. This is tested here:

https://github.com/SciML/OrdinaryDiffEq.jl/blob/v6.106.0/test/ad/autodiff_events.jl

Now the tricky thing is that, in order to make direct AD of the solver work in this kind of situation, the time span itself must be upgraded to dual numbers, because what is effectively happening is that you need to differentiate through the change in the time point of the event. So it is not sufficient to make u0 dual valued; you also need to make tspan dual valued. You can force this by doing something like:

function loss(p)
    _prob = remake(prob, p=p, tspan = eltype(p).(prob.tspan))
    # ... rest of `loss` as before

and that should be the fix that forces differentiation through the callback. Note that this should be happening automatically: the DiffEqBase preprocessing pipeline has a tspan promotion that exists to cover this case:

The actual issue here then is a user issue. You can see it in their code:

tstops=[ForwardDiff.value(p[2])] means “drop derivatives on this term”. USERS SHOULD NEVER USE FORWARDDIFF.VALUE, OR SHOULD ONLY DO SO WITH EXTREME CAUTION! It’s not documented for a reason :sweat_smile:. From the analysis above you can see that the derivative with respect to this tstop value is exactly the missing derivative, so setting that dual value to zero is “the bug”. If the user had let tspan be dual valued and kept the dual here, it would differentiate the callback correctly. So please never use ForwardDiff.value unless you know why you’re doing it.
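
Concretely, a fixed version of the loss could look something like this (a sketch reusing prob and p_start from the original post, not taken verbatim from the linked tests): tspan is promoted to the element type of p, and the tstop keeps its dual value.

function loss_fixed(p)
    _prob = remake(prob, p = p, tspan = eltype(p).(prob.tspan))

    condition(u, t, integrator) = t - integrator.p[2]
    affect!(integrator) = (integrator.p[1] = integrator.p[3])

    sol = solve(_prob, Tsit5(), saveat = 0.0:0.1:10.0,
                tstops = [p[2]],   # keep the dual value, do not strip it
                callback = ContinuousCallback(condition, affect!))
    return sum(abs2, sol .- 1)
end

ForwardDiff.gradient(loss_fixed, p_start)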

The examples there are in the linked package tests. They pass in every release. If you have any trouble reproducing that then open an issue with an MWE.


Thanks to both of you, @ChrisRackauckas and @stevengj, for the insight.

Unfortunately I did not manage to find the right test that implements the ODE from the aforementioned manuscript (so if you remember it and could kindly point me to it, that would be great!), but I figured out how the event needs to affect the derivative of y in order to incorporate the correct sensitivity.

Just gonna leave it here for those who might also be interested in a similar question. The system in question is

\begin{align*} & \dot{x} = f_x = -a \\ & \dot{y} = f_y = \begin{cases} b, & x(t) > 0 \\ 0, & \text{otherwise}. \end{cases} \end{align*}

The generic solution for both variables at some time T is as follows:

\begin{align*} & x(T) = x_0 + \int_{0}^{T} f_x dt = x_0 + \int_{0}^{T} (-a) dt \\ & y(T) = y_0 + \int_{0}^{T} f_y dt = y_0 + \int_{0}^{t^*} b dt + \int_{t^*}^{T} 0 dt. \end{align*}

With the resultant derivative wrt a, i.e. \frac{\partial}{\partial a}(\cdot) := \partial_a (\cdot) (including simplifications)

\begin{align*} &\partial_a x(T) = \partial_a x_0 - \partial_a(a \int_{0}^{T} dt) \\ &\partial_a y(T) = \partial_a y_0 + \underline{\partial_a \int_{0}^{t^*} b dt}, \end{align*}

where the Leibniz rule of differentiation applies to the underlined term (this is equivalent to dealing with the Dirac deltas in @stevengj's answer), and yields for \partial_a y(T)

\begin{align*} &\partial_a y(T) = b \cdot \partial_a t^*, \end{align*}

where \partial_a t^* can be obtained by invoking implicit differentiation upon detection of the event. Since the event is controlled by x_a(t) = 0, the implicit function theorem tells us that locally there exists t^*(a) which satisfies x_a(t^*(a)) = 0. Differentiating implicitly w.r.t. a gives

\begin{align*} & \partial_a x + \partial_t x \, \partial_a t^* = 0 \rightarrow \\ & \partial_a t^* = -(\partial_t x)^{-1}\partial_a x = -(\dot{x})^{-1}\partial_a x, \end{align*}

which gives us the prescription for how to handle the event in forward sensitivity analysis (in pseudo code), i.e.,

y.der = -b * x.der / (-a)   # = b * x.der / a

where the x.der term would be picked up by propagating the seeded dual of a.
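
For anyone who wants this end to end: below is a minimal sketch (constants and names are my own choices, only for illustration) of the augmented system [x, y, \partial_a x, \partial_a y] with the event correction applied in a ContinuousCallback.

using OrdinaryDiffEq

function aug_rhs!(du, u, p, t)
    a, b = p
    x, y, dxda, dyda = u
    du[1] = -a                 # dx/dt = -a
    du[2] = x > 0 ? b : 0.0    # dy/dt = b while x > 0, then 0
    du[3] = -1.0               # d/dt ∂x/∂a = ∂(-a)/∂a = -1
    du[4] = 0.0                # ∂y/∂a only changes at the event
end

# Event: x(t*) = 0
condition(u, t, integrator) = u[1]

function affect!(integrator)
    a, b = integrator.p
    # Implicit-function-theorem correction: ∂y/∂a += b * ∂x/∂a / a
    integrator.u[4] += b * integrator.u[3] / a
end

a, b   = 2.0, 3.0
x0, y0 = 1.0, 0.0
u0 = [x0, y0, 0.0, 0.0]        # sensitivities start at zero
aug_prob = ODEProblem(aug_rhs!, u0, (0.0, 1.0), [a, b])
aug_sol  = solve(aug_prob, Tsit5(), callback = ContinuousCallback(condition, affect!))

# Event at t* = x0/a = 0.5; analytic ∂y(T)/∂a = -b*x0/a^2 = -0.75
aug_sol.u[end][4]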

It is spelled out in a fair amount of detail, but for me personally it was confusing when I read through the manuscript. Hope somebody else finds it insightful too.