Solving non-linear coupled ODEs backwards in time vs reverse-mode autodiff

I have a very silly question. I am trying to understand how reverse-mode autodiff is able to actually work on coupled non-linear ODE systems that are very hard to solve backwards in time.

Take the Lorenz system (I am new to Julia so this is all Python code, but should be very simple to understand/adapt to Julia):

import numpy as np
from scipy.integrate import solve_ivp 
import matplotlib.pyplot as plt

def lorenz(t, state, args):
    x, y, z = state
    rho, sigma, beta = args 
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]

# set up initial conditions, timespan for integration, and fiducial parameter values
y0 = np.array([5., 5., 5.])
tarr = np.linspace(0, 10., 1000)
params = [28., 10., 8/3.] # rho, sigma, beta 

# solve the system forward in time
ys = solve_ivp(lorenz,[tarr[0],tarr[-1]],y0,dense_output=True,method='RK23',
                atol=1e-12,rtol=1e-12,args=(params,),t_eval=tarr)

# now solve the system backward in time starting from the final state above and reversing time array
ys2 = solve_ivp(lorenz,[tarr[-1],tarr[0]],ys.y.T[-1],dense_output=True,method='BDF',
                atol=1e-12,rtol=1e-12,args=(params,),t_eval=tarr[::-1])

# overplot the forward and backward solutions
fig, ax = plt.subplots(1,figsize=(6,4),dpi=150,subplot_kw={'projection':'3d'})
ax.plot(ys.y[0],ys.y[1],ys.y[2],'b-',lw=0.5)

end = -1 # or change this to -30 to disregard the last 30 chaotic backward steps
ax.plot(ys2.y[0][:end],ys2.y[1][:end],ys2.y[2][:end],'r-',lw=0.5,alpha=0.8)

With end=-1, the plot looks like this (blue is the forward solution, and red is the backward solution):

You can see that the backward solution has gone horribly wrong. In fact the above solve_ivp call doesn’t even finish all the way to t=0 – it stops at t~7.42 with the error/warning “Required step size is less than spacing between numbers.” If we disregard the last 30 steps of the backward solution (end=-30), we see that the solution was doing “decent” until it started spiraling out of control, which presumably led to the step-size error:

This seems like chaotic dynamical system behavior. Notice that I am using relatively tight tolerances (atol=rtol=1e-12) and the BDF implicit/stiff solver for the backwards-in-time solve (none of the other solvers work at all for the backward solve – solve_ivp just seems to hang).

Now my questions:

  1. Would any of the Julia solvers be able to solve the original system backwards in time? Am I doing anything wrong above? Would the backward solve become easier for different parameters (rho,sigma,beta) and/or ICs (x,y,z)?
  2. Reverse-mode autodiff involves solving the sensitivity/adjoint ODEs alongside the original ODE system backwards in time. Given that I can’t even solve the original Lorenz system by itself backwards in time, how is reverse-mode autodiff able to do the backward ODE solve for the full augmented/adjoint system and propagate perturbations in the final state to corresponding linear perturbations in the initial input parameters?
  3. How can we trust the gradients found by reverse-mode autodiff for these kinds of systems that are difficult to solve backwards in time? (For simplicity, restrict ourselves to just the partials of the final state variable values wrt free parameters and ICs rather than the full time series of the sensitivities.)

It doesn’t work. You need different methods. See:


One thing that is worth being clear on – this isn’t necessarily the case.

There’s actually like 10 different ways to autodiff through an ODE! :smiley: It looks like you’re thinking specifically of the “continuous adjoint method”/“optimise then discretise”, which is the version that involves (a) reconstructing the original dy/dt, and also (b) solving the adjoint ODE da/dt (where a(t) = dL/dy(t)).
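
(For concreteness, the adjoint ODE in question is the usual one, with f the vector field from dy/dt = f(t, y):

$$\frac{\mathrm{d}a}{\mathrm{d}t} = -a(t)^\top \frac{\partial f}{\partial y}\bigl(t, y(t)\bigr), \qquad a(t_1) = \frac{\partial L}{\partial y(t_1)},$$

integrated backwards from t1 to t0. Evaluating ∂f/∂y along the way is exactly why y(t) also has to be reconstructed backwards in time.)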

For this particular method: y is indeed basically impossible to solve backwards in time, as you’ve found. But the backwards-in-time system for a actually exhibits the same stability properties as the forwards-in-time system for y. This is part of why the continuous adjoint method is a terrible method that should almost never be used in practice: the autodiff is unconditionally unstable (either the y piece or the a piece will explode when you run both backwards).

That doesn’t mean you can’t autodiff through your system! The usual best thing to do here (and indeed 99% of the time) is just to differentiate through the internals of the ODE solver. (“Discretise then optimise”.) This doesn’t involve solving an adjoint ODE at all – it entirely ignores the ODE structure and just directly differentiates all the additions and multiplications and so on that are happening inside your solver.

This works because it “solves” for a (discretely and without setting up an ODE, but that’s not important), but it uses the values of y found on the forward pass – and does not attempt to also reconstruct y, which is the problematic bit. This is just like normal autodiff through a neural network or whatever.
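
To make that concrete, here’s a minimal sketch of what “differentiate through the solver” means (a toy fixed-step RK4 loop written directly in JAX, not Diffrax’s internals):

import jax
import jax.numpy as jnp

def lorenz(y, args):
    x, y_, z = y
    rho, sigma, beta = args
    return jnp.array([sigma * (y_ - x), x * (rho - z) - y_, x * y_ - beta * z])

def rk4_solve(y0, args, t1=10.0, n_steps=10_000):
    # plain fixed-step RK4; every operation below is ordinary JAX arithmetic
    dt = t1 / n_steps
    def step(y, _):
        k1 = lorenz(y, args)
        k2 = lorenz(y + 0.5 * dt * k1, args)
        k3 = lorenz(y + 0.5 * dt * k2, args)
        k4 = lorenz(y + dt * k3, args)
        return y + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4), None
    yT, _ = jax.lax.scan(step, y0, None, length=n_steps)
    return yT

y0 = jnp.array([5.0, 5.0, 5.0])
params = jnp.array([28.0, 10.0, 8.0 / 3.0])

# reverse-mode AD just backpropagates through the arithmetic of each step,
# reusing the forward-pass values of y; no backward ODE solve is ever set up
dxT_dparams = jax.grad(lambda p: rk4_solve(y0, p)[0])(params)

An adaptive solver is the same idea, just with more bookkeeping (and, in Diffrax, with checkpointing so that memory doesn’t blow up).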

So to answer your questions:

  1. Almost no solver can reasonably tackle this problem backwards in time, whether that’s in Julia or JAX or anything else.
  2. and 3. for most applications we can trust (good choices of) autodiff because it has the same stability as the forward evolution. If you were able to solve your system forwards in time successfully, then you’ll be able to autodiff it successfully.

Also, I believe Chris misspoke a little above. You wrote “I am trying to understand how reverse-mode autodiff is able to actually work on coupled non-linear ODE systems that are very hard to solve backwards in time”, and as we’ve seen above, the ability to solve the original ODE backwards in time is irrelevant.

The shadowing methods he links are for chaotic systems for which the original forward-in-time IVP is poorly behaved. It just so happens that the example you chose (the Lorenz system) is both chaotic in the forward direction and unstable in the backward direction.

(Chris – I’d welcome a correction if you think I’ve got you wrong.)

Finally, if you’re coming from Python – shameless plug for Diffrax, which you might like.


I thought he was just asking about chaotic systems. Indeed there are plenty of systems which are non-chaotic but don’t reverse well. This is dependent on the Lipschitz constant, which generally would mean stiffness (though not necessarily). More details:

Though there are cases where even methods which don’t reverse the ODE are unstable. Chaotic systems like the one mentioned above are one clear case, because the forward solution is not well-defined in a non-probabilistic way for any numerical method (and thus you cannot even assume that the forward trajectory is correct; you only have a shadow trajectory). But there are also cases where the forward solution is stable but the derivative calculation is unstable if the stepping behavior is unchanged (this mostly shows up on the most trivial of problems), even with forward-mode AD of an ODE solver.
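
To make the non-chaotic case concrete with a toy example (mine, not from the links above): a single stiff linear ODE is perfectly stable forwards but hopeless backwards, because any absolute error in the final state gets amplified by exp(lambda*T) on the reverse pass.

from scipy.integrate import solve_ivp

lam = 50.0
f = lambda t, y: -lam * y

# forward: harmless exponential decay from 1 down to ~e^-100
fwd = solve_ivp(f, [0, 2], [1.0], rtol=1e-12, atol=1e-50)

# backward from the final state, but with a tiny absolute error of 1e-30
# (think storage roundoff, solver tolerance, noise, ...)
bwd = solve_ivp(f, [2, 0], fwd.y[:, -1] + 1e-30, rtol=1e-12, atol=1e-50)

print(fwd.y[0, -1])   # ~3.7e-44
print(bwd.y[0, -1])   # should be 1.0, but the injected 1e-30 error is blown up by ~e^100

No chaos anywhere here; the Lipschitz constant (just lambda = 50) is what sets how fast information is lost in one direction and amplified in the other.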


Thank you so much @ChrisRackauckas and @patrick-kidger ! I was actually hoping to hear from one or both of you! What you said makes a lot of sense.

I went ahead and used @patrick-kidger 's amazing diffrax package to confirm that diffrax’s forward-mode autodiff (called DirectAdjoint) and reverse-mode autodiff through the internals of the solver (called RecursiveCheckpointAdjoint) successfully return the derivatives of the final state wrt parameter variations (both sets of Jacobians are similar to each other). And that the reverse-mode based on solving the augmented ODE system backwards in time (called BacksolveAdjoint in diffrax) indeed fails with a “reached max steps” error even if I allow for up to 16**5 steps (the other 2 diffrax autodiff methods solve the forward system in only ~950 steps with dopri5 and atol=rtol=1.4e-8).

However, the default jax.experimental.ode.odeint IS successful in returning the same derivatives, but I was under the impression that it was solving the augmented ODE system backwards in time (based on comments in its code – @patrick-kidger, can you please confirm?). I get the same derivatives from jax’s odeint as from diffrax’s DirectAdjoint and RecursiveCheckpointAdjoint, and it is very fast.

Furthermore, when I compute the gradients using finite differences, that also works and agrees with the other methods.

Here is my code for reproducibility (I will try to port this to Julia soon):

from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, jvp, vjp
from jax.experimental.ode import odeint
from diffrax import diffeqsolve, ODETerm, PIDController, SaveAt, Dopri5, DirectAdjoint, RecursiveCheckpointAdjoint, BacksolveAdjoint

def lorenz_diffrax(t, state, args): 
    x, y, z = state 
    rho, sigma, beta = args 
    return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z]) 

# set up diffrax inputs 
terms = ODETerm(lorenz_diffrax)
y0 = jnp.array([5., 5., 5.])
rho, sigma, beta = 28., 10., 8/3. 
t0 = 0.0
t1 = 10.0 
dt0 = None
tsave = SaveAt(ts=jnp.linspace(t0, t1, 1000),dense=True)

# function that solves Lorenz system with DirectAdjoint fwd-mode method
def evolve_diffrax_DirectAdjoint(y0, rho, sigma, beta): # returns only final state
    return diffeqsolve(terms,Dopri5(),t0,t1,dt0,y0,jnp.array([rho,sigma,beta]),
                       stepsize_controller=PIDController(rtol=1.4e-8,atol=1.4e-8),adjoint=DirectAdjoint()).ys[-1] 

# set up perturbations in input parameters -- for now we only vary the beta parameter
delta_y0 = jnp.array([0., 0., 0.])
delta_rho = 0.
delta_sigma = 0.
delta_beta = 1.

diffrax_ys, diffrax_delta_ys = jvp(evolve_diffrax_DirectAdjoint, (y0,rho,sigma,beta),(delta_y0,delta_rho,delta_sigma,delta_beta))

print(diffrax_ys, diffrax_delta_ys,sep='\n')
# [ 2.11236828  3.72088662 11.39477902]
# [2527.58786396 4462.56417013  724.80455082]

# another function that instead uses RecursiveCheckpointAdjoint through internals of ODE solver
def evolve_diffrax_RecursiveCheckpointAdjoint(y0, rho, sigma, beta): 
    return diffeqsolve(terms,Dopri5(),t0,t1,dt0,y0,jnp.array([rho,sigma,beta]),
                       stepsize_controller=PIDController(rtol=1.4e-8,atol=1.4e-8),adjoint=RecursiveCheckpointAdjoint()).ys[-1] 

# define perturbations on final output state for rev-mode autodiff -- here we want gradients of 1st output
delta_ys = jnp.array([1.0,0.0,0.0])

vjp_ys_diffrax, vjp_func_diffrax = vjp(evolve_diffrax_RecursiveCheckpointAdjoint,y0,rho,sigma,beta)
vjp_partials = vjp_func_diffrax(delta_ys)

print(vjp_ys_diffrax, vjp_partials, sep='\n')
# [ 2.11236828  3.72088662 11.39477902] # same as above 
# Array([  97.38204096, -104.75567762,  989.0743617 ], dtype=float64), 
# Array(-251.93713359, dtype=float64, weak_type=True), 
# Array(-313.86799075, dtype=float64, weak_type=True), Array(2527.58786396, dtype=float64, weak_type=True)

# now use diffrax's BacksolveAdjoint to solve augmented ODE system backwards in time
def evolve_diffrax_BacksolveAdjoint(y0, rho, sigma, beta): 
    return diffeqsolve(terms,Dopri5(),t0,t1,dt0,y0,jnp.array([rho,sigma,beta]), # cannot use saveat 
                       stepsize_controller=PIDController(rtol=1.4e-8,atol=1.4e-8),adjoint=BacksolveAdjoint(),max_steps=16**5).ys[-1] 

vjp_ys_diffrax2, vjp_func_diffrax2 = vjp(evolve_diffrax_BacksolveAdjoint,y0,rho,sigma,beta)

### THIS FAILS 
vjp_func_diffrax2(delta_ys)


### but jax-odeint works, and I think it's solving the augmented ODEs backwards in time???

# new Lorenz function with order of inputs reversed for jax-odeint convention
def lorenz_odeint(state, t, args):
    x, y, z = state
    rho, sigma, beta = args 
    return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

tarr = jnp.linspace(0, 10., 1000)

def evolve_odeint(y0, rho, sigma, beta): 
    return odeint(lorenz_odeint, y0, tarr, (rho, sigma, beta))[-1]

vjp_ys_odeint, vjp_func_odeint = vjp(evolve_odeint,y0,rho,sigma,beta)

print(vjp_ys_odeint,vjp_func_odeint(delta_ys),sep='\n')
# [ 2.11236071  3.72087324 11.39477673]
#(Array([  97.38263992, -104.75600712,  989.07885387], dtype=float64), 
#Array(-251.93927603, dtype=float64), 
#Array(-313.86969568, dtype=float64), Array(2527.59583685, dtype=float64))

### finally compute dy/dbeta with finite differences (choice of adjoint method is irrelevant for this)
eps = 1e-4
ys_low = odeint(lorenz_odeint,y0,tarr,(rho,sigma,beta-eps/2))[-1]
ys_high = odeint(lorenz_odeint,y0,tarr,(rho,sigma,beta+eps/2))[-1]
dy_dbeta = (ys_high-ys_low)/eps

print(dy_dbeta)
# 2527.39344683, 4462.07627433,  724.82886168 # agrees with others


Also @patrick-kidger: if diffrax’s DirectAdjoint and RecursiveCheckpointAdjoint are not solving the adjoint/sensitivity ODEs, why do you still call these adjoint methods? Just curious if I’m misunderstanding the word adjoint in this context.

It does! If that’s getting the same results then that’s very mysterious :smiley: I’ll take a look at your example.

The literature has been wildly inconsistent with the terminology on this point. For some folks, “the adjoint method” has referred to specifically optimise-then-discretise (BacksolveAdjoint), and for others it has referred to specifically discretise-then-optimise (RecursiveCheckpointAdjoint).

Some people take it to mean anything to do with reverse-mode autodifferentiation. And indeed this is where the word is derived from! Forward mode autodiff is a linear transformation tangent_in->tangent_out, and reverse-mode autodiff is the adjoint of this transform – in the same sense as the adjoint of a matrix – to get the linear transformation cotangent_out->cotangent_in.

Anyway, I basically subscribe to that latter camp. I actually deliberately used the keyword adjoint=... in Diffrax to help emphasise that “adjoint” doesn’t necessarily refer to any one specific method.
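
That sense of “adjoint” is easy to check numerically (just an illustrative check, nothing Diffrax-specific): for any function, the jvp and vjp maps satisfy <J v, w> = <v, Jᵀ w>.

import jax
import jax.numpy as jnp

def f(x):                         # any smooth map R^3 -> R^2
    return jnp.array([jnp.sin(x[0]) * x[1], x[2] ** 2 + x[0]])

x = jnp.array([0.3, -1.2, 2.0])
v = jnp.array([1.0, 0.5, -0.7])   # input tangent
w = jnp.array([0.9, -0.4])        # output cotangent

_, Jv = jax.jvp(f, (x,), (v,))    # forward mode: J @ v
_, f_vjp = jax.vjp(f, x)
(JTw,) = f_vjp(w)                 # reverse mode: J^T @ w

print(jnp.dot(Jv, w), jnp.dot(v, JTw))   # the two inner products agree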


Diffrax can AD through stiff solvers with Newton solves in the middle? Which linear solver is allowed? A comparison with DE would be interesting

For the chaotic properties to kick in you need a long enough time span. Try making the final time 100.

Alright, I figured this one out. It’s because for odeint you are asking for output (on the forward pass) at jnp.linspace(0, 10., 1000). This means that the backward solve doesn’t need to go all the way from 10->0 in one go; it only needs to move between adjacent checkpoints.

In contrast with Diffrax you were only getting the output at the final time, and the backward solve has no checkpoints it can use: it has to do the unstable backward solve for the whole interval 10->0. If you pass saveat=SaveAt(ts=jnp.linspace(0., 10., 1000)) then you’ll find that Diffrax with BacksolveAdjoint is able to successfully compute gradients as well.
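
Concretely, the change is just (a sketch, reusing the definitions from your snippet above):

# same as your BacksolveAdjoint function, but saving 1000 intermediate outputs
# so the backward solve only has to run between adjacent checkpoints
def evolve_diffrax_BacksolveAdjoint_saveat(y0, rho, sigma, beta):
    return diffeqsolve(terms, Dopri5(), t0, t1, dt0, y0, jnp.array([rho, sigma, beta]),
                       saveat=SaveAt(ts=jnp.linspace(t0, t1, 1000)),
                       stepsize_controller=PIDController(rtol=1.4e-8, atol=1.4e-8),
                       adjoint=BacksolveAdjoint(), max_steps=16**5).ys[-1]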

Yes, it can! It backprops through the Newton solves via the implicit function theorem.
Right now the linear solver is just hardcoded to an LU solve (see here), but this is actually about to be switched out! Sneak peek – we’ll be announcing a slew of new linear and nonlinear solvers in the next few weeks :slight_smile:
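
As a toy illustration of the implicit-function-theorem idea (not Diffrax’s actual internals): if a Newton iteration converges to a root z*(theta) of g(z, theta) = 0, then dz*/dtheta = -(∂g/∂z)^(-1) ∂g/∂theta at the converged point, so the backward pass needs one linear solve rather than backpropagation through every Newton iteration.

import jax

def g(z, theta):                  # toy residual; root is z* = 1 when theta = 2
    return z ** 3 + z - theta

def newton_solve(theta, z=0.0, iters=20):
    for _ in range(iters):
        z = z - g(z, theta) / jax.grad(g, argnums=0)(z, theta)
    return z

theta = 2.0
z_star = newton_solve(theta)

# implicit function theorem: dz*/dtheta = -(dg/dz)^(-1) * dg/dtheta at z*
dg_dz = jax.grad(g, argnums=0)(z_star, theta)
dg_dtheta = jax.grad(g, argnums=1)(z_star, theta)
print(-dg_dtheta / dg_dz)         # 1 / (3 z*^2 + 1) = 0.25

For a system of equations the division becomes a linear solve, which is where the choice of linear solver (currently LU) comes in.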


Thank you so much @patrick-kidger ! Your fix to provide SaveAt to diffrax’s BacksolveAdjoint works. Now it also gives the same gradients as odeint and the other two fwd/rev-mode diffrax autodiff methods through the internals of the ODE solver.

I also tried @ChrisRackauckas ’ suggestion to increase t_final to 100 instead of 10 to let the chaos really kick in:

Indeed, when I do this, all 4 autodiff methods give the same derivatives of the final state wrt free parameter variations, and it looks like the derivatives are exploding, presumably due to the chaos. For example, all 4 methods give similar [dx/dbeta, dy/dbeta, dz/dbeta] = [-1.62510744e+39, -2.93277731e+39, -1.04332235e+38].

I have to say it’s interesting that even though this is chaotic, clearly the state is following the attractor solution and isn’t fully occupying the “phase space”. And yet the gradients explode like crazy despite the forward solution trajectory being confined to the attractor. I suppose it would be worse for other combinations of free parameters (rho,sigma,beta), and I suppose if I increased t_final to >> 100, then eventually the gradients would just become NaN’s.

I have 2 last questions if/when you have time:

  1. Now that I have gradients/jacobians from autodiff’ing through the internals of the ODE solver, can I use those for inference problems? Specifically, say I generated a trajectory of the Lorenz system assuming some initial (x0,y0,z0) and parameters (rho,sigma,beta) and only evolved to t_final ~ 1, so before chaos really kicks in. And then I want to infer (x0,y0,z0,rho,sigma,beta) using something like Hamiltonian Monte Carlo with gradients provided by autodiff. My impression from your posts above is that this is not possible since the Lorenz system is chaotic in both the forward and reverse directions, and instead we need novel shadow methods like the ones @ChrisRackauckas described.

  2. What about doing that kind of inference for non-linear but not necessarily chaotic ODE systems such as Lotka-Volterra or the SIR disease model that are still difficult to solve backwards in time (e.g., see this)? @patrick-kidger said that it’s irrelevant whether you can solve the original/augmented ODEs backward in time, but how non-linear/chaotic can an ODE system be before parameter inference with autodiff becomes infeasible? Is there a practical way to compute the Lipschitz constant of some arbitrary non-linear ODE system as @ChrisRackauckas mentioned above to know whether inference is possible?

  1. If I understand you correctly – I think you should be able to use these gradients (without shadow methods); see the sketch after these answers. If you’re working over short time scales then chaos isn’t an issue.

  2. LV or SIR should be pretty simple to use alongside autodiff. (They’re just not that complicated.) I don’t think there’s really a precise answer to the question “how non-linear/chaotic can an ODE system be before parameter inference with autodiff becomes infeasible?” – this depends on things like how accurately you’re stepping with your solver, your choice of floating point precision, your time horizon, etc. As for getting a Lipschitz constant – for arbitrary systems this can be quite finicky, but for many practical systems you can derive this analytically. (Or at least get bounds on it analytically.)
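
For 1., here is a minimal sketch of what “use these gradients for inference” might look like (it reuses terms, y0, etc. from your code above; the least-squares loss and the starting guess are made up purely for illustration):

import jax

ts_obs = jnp.linspace(0.0, 1.0, 50)   # short horizon, before chaos matters

def solve_short(params):
    return diffeqsolve(terms, Dopri5(), 0.0, 1.0, None, y0, params,
                       saveat=SaveAt(ts=ts_obs),
                       stepsize_controller=PIDController(rtol=1e-8, atol=1e-8),
                       adjoint=RecursiveCheckpointAdjoint()).ys

true_params = jnp.array([28.0, 10.0, 8.0 / 3.0])
data = solve_short(true_params)       # stand-in for an observed trajectory

def loss(params):                     # least-squares misfit to the data
    return jnp.sum((solve_short(params) - data) ** 2)

# gradient of the loss wrt (rho, sigma, beta), via AD through the solver internals;
# this is the quantity you would hand to HMC, gradient descent, etc.
print(jax.grad(loss)(jnp.array([27.0, 9.5, 2.5])))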
