DiffEqSensitivity and DiffEqFlux have a great selection of tools for backpropagating through differential equations in various ways, to get gradients with respect to DE parameters. The documentation discusses the numerical stability of the different methods, and how extras like checkpointing can help.
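For concreteness, this is the kind of reverse-mode setup I mean, sketched on a toy problem rather than my actual model (the Lotka–Volterra system and the names here are purely illustrative; I’m using InterpolatingAdjoint with checkpointing as one example of those extras):

using OrdinaryDiffEq, DiffEqSensitivity, Zygote

# Toy ODE standing in for a real model (illustrative only).
function lotka!(du, u, p, t)
    du[1] =  p[1]*u[1] - p[2]*u[1]*u[2]
    du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]
end

lv_u0 = [1.0, 1.0]
lv_p  = [1.5, 1.0, 3.0, 1.0]
lv_prob = ODEProblem(lotka!, lv_u0, (0.0, 10.0), lv_p)

function lv_loss(p)
    psol = solve(lv_prob, Tsit5(), p=p, saveat=0.1,
                 sensealg=InterpolatingAdjoint(checkpointing=true))
    return sum(abs2, Array(psol))
end

# Reverse-mode gradient via the continuous adjoint, with checkpointing.
lv_grad = Zygote.gradient(lv_loss, lv_p)[1]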
What about the numerical stability of parameter gradients computed with forward-mode automatic differentiation? I’m interested in knowing, heuristically or otherwise, the regimes in which this is good/bad/horrible. So, things like:
- How (if at all) does the stiffness/solution accuracy of the ODE itself translate to the numerical stability of the forward-mode gradient? (My rough understanding is sketched just after this list.)
- What kinds of functions of the solution are amenable to accurate forward-mode gradients?
- Are there any intuitions one could use to spot situations in which forward-mode AD through a differential equation will result in heartache?
- Are there any resources containing a mathematical treatment of these particular stability issues?
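Sketch referenced in the first bullet: my rough understanding is that pushing ForwardDiff duals through the solver is morally equivalent to solving the forward sensitivity equations ds/dt = (∂f/∂u)s + ∂f/∂p alongside the state, which DiffEqSensitivity also exposes directly. Since those equations share the Jacobian of the original ODE, I’d naively expect stiffness to carry over, but I don’t know how this interacts with adaptive stepping on the dual parts. A minimal sketch on a toy scalar ODE, not my model:

using OrdinaryDiffEq, DiffEqSensitivity

# Toy scalar decay ODE, illustrative only.
function decay!(du, u, p, t)
    du[1] = -p[1] * u[1]
end

u0 = [1.0]
p0 = [0.3]

# Augment the ODE with the forward sensitivity equations and solve both together.
sprob = ODEForwardSensitivityProblem(decay!, u0, (0.0, 10.0), p0)
ssol  = solve(sprob, Tsit5())

# Split the augmented state into the solution and the sensitivities du/dp.
x, dp = extract_local_sensitivities(ssol)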
Why am I asking? I’ll motivate with an example in the ‘horrible’ category. This is a bursting, conductance-based neuron model (i.e. like the Hodgkin-Huxley equations, but more complicated). I’ve attached an image of plot(sol) to show this:
…you can see it’s quite stiff.
In the code below, I calculate the squared l2 difference between the nominal solution (i.e. with fixed parameters) and the parameter-varying solution.
Here, loss(prob.p) is a local minimum, so the true answer is norm(grad) = 0. Instead, I get norm(grad) = 1.45e6.
using Pkg
Pkg.add(url="https://github.com/Dhruva2/MyModelMenagerie.jl.git")

using MyModelMenagerie, OrdinaryDiffEq, ForwardDiff, Trapz

# Model ODE, initial condition, timespan and parameters.
od, ic, tspan, ps = CalciumNeuron(t -> 0.)
tspan = (0., 500.)

prob = ODEProblem(od, ic, tspan, ps)
sol = solve(prob, Tsit5())
tsteps = sol.t

# Re-solve at new parameters p, saving at the nominal solution's time points.
function sol_at(p)
    pprob = remake(prob, p=p)
    return solve(pprob, Tsit5(), saveat=tsteps)
end

# Squared l2 difference between the perturbed and nominal solutions.
function loss(p)
    return sum(abs2, sol_at(p) - sol)
end

# Forward-mode AD gradient of the loss with respect to the parameters.
grad = ForwardDiff.gradient(loss, prob.p)

using LinearAlgebra
norm(grad)
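For what it’s worth, this is the kind of finite-difference cross-check I’m using to convince myself the ForwardDiff gradient really is off (assuming FiniteDiff.jl; of course finite differences come with their own step-size error):

using FiniteDiff

# Finite-difference gradient of the same loss, as an independent check.
fd_grad = FiniteDiff.finite_difference_gradient(loss, prob.p)

norm(fd_grad)           # expect something close to 0 at the nominal parameters
norm(grad .- fd_grad)   # discrepancy between ForwardDiff and finite differences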
Thanks! I realise this is quite involved to answer, but at the very least it can serve as a warning to check the accuracy of your AD gradients!