When ForwardDiff-ing through an ODE is unstable

DiffEqSensitivity and DiffEqFlux have a selection of great tools to backpropagate in various ways through differential equations, to get gradients with respect to DE parameters. The documentation discusses the numerical stability of different methods, and how extra sauces like checkpointing can help.

What about the numerical stability of parameter gradients computed with forward-mode automatic differentiation? I’d like to know, heuristically or otherwise, the regimes in which this is good/bad/horrible. So things like:

  • How (if at all) do the stiffness and solution accuracy of the ODE itself affect the numerical stability of the forward-mode gradient?
  • What kinds of functions of the solution are amenable to accurate forward-mode gradients?
  • Are there any intuitions for the situations in which forward-mode AD through a differential equation will result in heartache?
  • Are there any resources containing a mathematical treatment of these particular stability issues?

Why am I asking? I’ll motivate with an example in the ‘horrible’ category. This is a bursting, conductance-based neuron model (i.e. like the Hodgkin-Huxley equations, but more complicated). I’ve attached an image of plot(sol) to show this:
[Image: plot(sol) of the calcium neuron model]
…you can see it’s quite stiff.

In the code below, I calculate the l2 difference between the voltage trace of the nominal solution (i.e. fixed parameters), and that of the parameter-varying solution.

Here, prob.p minimizes loss (the perturbed and nominal solutions coincide there), so the true gradient is zero. Instead, ForwardDiff gives norm(grad) = 1.45e6.

```julia
using Pkg
Pkg.add(url="https://github.com/Dhruva2/MyModelMenagerie.jl.git")
using MyModelMenagerie, OrdinaryDiffEq, ForwardDiff, LinearAlgebra

od, ic, tspan, ps = CalciumNeuron(t -> 0.)
tspan = (0., 500.)
prob = ODEProblem(od, ic, tspan, ps)
sol = solve(prob, Tsit5())
tsteps = sol.t   # the reference solution's own adaptive time steps

# Solve with parameters p, saving at the reference time points
function sol_at(p)
    pprob = remake(prob, p=p)
    return solve(pprob, Tsit5(), saveat=tsteps)
end

# Squared L2 distance between the perturbed and nominal solutions
function loss(p)
    return sum(abs2, Array(sol_at(p)) .- Array(sol))
end

grad = ForwardDiff.gradient(loss, prob.p)
norm(grad)
```

Thanks! I realise this is quite involved to answer, but at the very least it can serve as a warning to check the accuracy of your AD gradients!
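One cheap way to act on that warning is to compare the ForwardDiff gradient with a central finite-difference gradient before trusting it. The sketch below assumes the parameters are a plain `Vector{Float64}`; `fd_gradient` and `compare_gradients` are hypothetical helpers, and the step `h = 1e-6` is an untuned heuristic. Finite differences through an adaptive ODE solve can be noisy themselves, so a large disagreement is a red flag rather than a verdict on which gradient is right.

```julia
using ForwardDiff, LinearAlgebra

# Central finite-difference gradient of a scalar function f at p.
# The step h = 1e-6 is a heuristic default, not tuned for any particular model.
function fd_gradient(f, p::Vector{Float64}; h = 1e-6)
    g = zeros(length(p))
    for i in eachindex(p)
        e = zeros(length(p))
        e[i] = h
        g[i] = (f(p + e) - f(p - e)) / (2h)
    end
    return g
end

# Returns the ForwardDiff gradient, the finite-difference gradient,
# and the norm of their difference.
function compare_gradients(f, p::Vector{Float64}; h = 1e-6)
    g_ad = ForwardDiff.gradient(f, p)
    g_fd = fd_gradient(f, p; h = h)
    return g_ad, g_fd, norm(g_ad - g_fd)
end

# Toy check: f has its minimum at p = [1, 1] and gradient 2 .* (p .- 1)
f(p) = sum(abs2, p .- 1)
g_ad, g_fd, err = compare_gradients(f, [2.0, 0.5])
```

For the model above this would be `compare_gradients(loss, prob.p)`, assuming `prob.p` is a `Vector{Float64}`.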

This looks like a chaotic problem because of its bursting behavior, correct? Forward-mode automatic differentiation will diverge on chaotic problems. Forward-mode AD accumulates derivatives in the tangent space, so if the system has a positive Lyapunov exponent, those tangent vectors grow exponentially fast (that’s what the exponent means!)
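As a small illustration of that point (a sketch, with a discrete map standing in for an ODE), take the logistic map x -> r*x*(1 - x) and differentiate the n-th iterate with respect to the parameter r. The parameter derivative obeys d_{n+1} = x_n(1 - x_n) + f'(x_n) d_n, so it is driven by the same Jacobian products as perturbations of the initial condition. With r = 2.5 the map has a stable fixed point and the derivative settles to 1/r^2 = 0.16; with r = 3.9 the map is chaotic and the derivative blows up. The function names are hypothetical.

```julia
using ForwardDiff

# Iterate the logistic map x -> r*x*(1 - x) n times from x0.
function logistic_final(r, x0, n)
    x = x0 * one(r)   # promote x0 to the type of r (a Dual under ForwardDiff)
    for _ in 1:n
        x = r * x * (1 - x)
    end
    return x
end

# Derivative of the n-th iterate with respect to the parameter r.
dxdr(r; x0 = 0.3, n = 100) = ForwardDiff.derivative(s -> logistic_final(s, x0, n), r)

# r = 2.5: stable fixed point x* = 1 - 1/r, so dx*/dr = 1/r^2 = 0.16 (bounded)
d_stable = dxdr(2.5)

# r = 3.9: chaotic regime; the tangent grows roughly like exp(λ n) with λ > 0
d_chaotic = dxdr(3.9)
```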


Hi, that’s useful to know about chaotic systems, thanks. This is due to diverging trajectories for infinitesimally perturbed initial conditions. I can see how that disrupts gradients wrt initial conditions. Is the same true for gradients wrt parameters of the vector field?

This isn’t a chaotic system, however. It’s part of a central pattern generator that co-ordinates rhythmic contractions in the stomach of a lobster. As such, it settles to extremely regular firing over a wide basin of initial conditions. Let me zoom out a little to show the rhythmic dynamics:

[Image: zoomed-out plot(sol) showing the rhythmic bursting]

(Background info; this is my understanding, and I’m not an expert):
In general, bursting models that have an extra, slow-timescale, positive-feedback current (https://journals.physiology.org/doi/full/10.1152/jn.00804.2017; Cellular switches orchestrate rhythmic circuits, SpringerLink) are robust to parameter variation and initial conditions, and can be modulated into tonic spiking. Many bursting models in the literature lack this extra current. They still burst rhythmically, but this relies on exact co-ordination of parameters, and funky things like chaos happen when you slightly perturb them.

I see. Yeah, a lot of models of this sort, such as chains of multiple Hodgkin-Huxley neurons, are chaotic because their activation behavior makes them sensitive to parameters, and one quick test for this is the Inf you’d get from forward-mode AD.

What happens if you do a super accurate computation, like solve(pprob, Vern9(), saveat=tsteps, abstol=1e-14,reltol=1e-14)?

With the super accurate computation, it works! norm(grad) = 0.0009. :slight_smile:

I should also note that

  1. it works with adjoint sensitivity analysis (with a normal Tsit5() solution)
  2. the numerical explosion in ForwardDiff doesn’t happen when I look at the L2 deviation of less ‘stiff’ species (i.e. Calcium, the second state, instead of voltage)

I’m not sure what lesson to take from this. Maybe that forward mode can diverge on a stiff system if your error tolerances are loose?

I think the solution, while not chaotically sensitive, still seems pretty sensitive. The difference is that using just generating the steps will grab the values at the order-5 steps themselves, while saveat will naturally use a 4th-order interpolation, so you’re seeing the accumulated discrepancy between the two in a very sensitive calculation. If you do this same experiment with saveat=1 (so that both solves interpolate), I think the gradient might turn out to be zero.

The time steps are not exactly the same between the two solves because the dual numbers effectively append the sensitivity equations to the system. To make sure the dual parts are correct, the error norm has to take those extra equations into account, and so the adaptive dt’s of the two solves drift apart after some time.
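To make the "appended sensitivity equations" concrete, here is a sketch on a toy problem that is an assumption, not the neuron model: du/dt = -p*u, whose sensitivity s = ∂u/∂p satisfies ds/dt = -p*s - u with s(0) = 0, and exact solution s(T) = -T*u0*exp(-p*T). Solving the explicitly augmented system and differentiating the solver with ForwardDiff should both recover it. The equivalence is only approximate in general, since adaptive step selection differs between the two, which is exactly the point above. The tightened tolerances are an assumption chosen for the comparison.

```julia
using OrdinaryDiffEq, ForwardDiff

# Toy problem (an assumption, not the neuron model): du/dt = -p*u
decay(u, p, t) = -p * u

# Forward sensitivity system: state u plus s = ∂u/∂p, with
# ds/dt = (∂f/∂u) * s + ∂f/∂p = -p*s - u
function decay_with_sensitivity!(dz, z, p, t)
    u, s = z
    dz[1] = -p * u
    dz[2] = -p * s - u
    return nothing
end

u0, p, T = 1.0, 0.5, 4.0
tol = (reltol = 1e-10, abstol = 1e-10)

# 1) Explicitly augmented system: the error norm sees both u and s.
aug = solve(ODEProblem(decay_with_sensitivity!, [u0, 0.0], (0.0, T), p), Tsit5(); tol...)
s_augmented = aug.u[end][2]

# 2) ForwardDiff: the Dual's partial carries the same sensitivity implicitly.
function u_final(q)
    prob = ODEProblem(decay, u0 * one(q), (0.0, T), q)   # Dual-typed initial state
    return solve(prob, Tsit5(); tol...).u[end]
end
s_forwarddiff = ForwardDiff.derivative(u_final, p)

s_exact = -T * u0 * exp(-p * T)
```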


By “using just generating the steps”, I guess you mean allowing sol.t to be generated naturally by the algorithm. In that case it makes sense! Thanks for the explanation.

I changed saveat to 1.0, and the gradient was numerically zero, as you predicted.
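The pattern that worked here, sampling the reference and the perturbed solution through the same saveat so that both are interpolated at identical times, can be sketched on a toy problem. The damped oscillator and the names `oscillator!`, `solver_kw` and `p0` are hypothetical stand-ins for the neuron model, and the tolerances are tightened (an assumption) so that the leftover residual between the Float64 and Dual solves stays tiny.

```julia
using OrdinaryDiffEq, ForwardDiff, LinearAlgebra

# Toy stand-in for the neuron model (an assumption): damped oscillator, p = [ω, ζ]
function oscillator!(du, u, p, t)
    ω, ζ = p
    du[1] = u[2]
    du[2] = -ω^2 * u[1] - 2ζ * ω * u[2]
    return nothing
end

p0 = [2.0, 0.1]
prob = ODEProblem(oscillator!, [1.0, 0.0], (0.0, 20.0), p0)
solver_kw = (saveat = 1.0, reltol = 1e-10, abstol = 1e-10)   # same sampling for both solves

# Reference solution, sampled at t = 0, 1, ..., 20 via the interpolant
ref = Array(solve(prob, Tsit5(); solver_kw...))

function loss(p)
    # Give u0 the same element type as p so Dual numbers propagate through the state
    pprob = remake(prob; u0 = convert.(eltype(p), prob.u0), p = p)
    psol = solve(pprob, Tsit5(); solver_kw...)
    return sum(abs2, Array(psol) .- ref)
end

grad = ForwardDiff.gradient(loss, p0)
norm(grad)
```

Because both solves pass through the interpolant at the same times, any remaining mismatch at p0 comes only from the slightly different adaptive steps of the Dual solve, rather than from comparing step values against interpolated values.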