# Solving non-linear coupled ODEs backwards in time vs reverse-mode autodiff

I have a very silly question. I am trying to understand how reverse-mode autodiff is able to actually work on coupled non-linear ODE systems that are very hard to solve backwards in time.

Take the Lorenz system (I am new to Julia so this is all Python code, but should be very simple to understand/adapt to Julia):

```python
import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

def lorenz(t, state, args):
    x, y, z = state
    rho, sigma, beta = args
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]

# set up initial conditions, timespan for integration, and fiducial parameter values
y0 = np.array([5., 5., 5.])
tarr = np.linspace(0, 10., 1000)
params = [28., 10., 8/3.] # rho, sigma, beta

# solve the system forward in time
ys = solve_ivp(lorenz, [tarr[0], tarr[-1]], y0, dense_output=True, method='RK23',
               atol=1e-12, rtol=1e-12, args=(params,), t_eval=tarr)

# now solve the system backward in time starting from the final state above and reversing time array
ys2 = solve_ivp(lorenz, [tarr[-1], tarr[0]], ys.y.T[-1], dense_output=True, method='BDF',
                atol=1e-12, rtol=1e-12, args=(params,), t_eval=tarr[::-1])

# overplot the forward and backward solutions
fig, ax = plt.subplots(1, figsize=(6, 4), dpi=150, subplot_kw={'projection': '3d'})
ax.plot(ys.y[0], ys.y[1], ys.y[2], 'b-', lw=0.5)

end = -1 # or change this to -30 to disregard the last 30 chaotic backward steps
ax.plot(ys2.y[0][:end], ys2.y[1][:end], ys2.y[2][:end], 'r-', lw=0.5, alpha=0.8)
plt.show()
```

With `end=-1`, the plot looks like this (blue is the forward solution, and red is the backward solution):

You can see that the backward solution has gone horribly wrong. In fact, the above solve_ivp call doesn't even finish all the way to t=0: it stops at t~7.42 with the error/warning "Required step size is less than spacing between numbers." If we disregard the last 30 steps of the backward solution (`end=-30`), we see that the solution was doing "decent" until it started spiraling out of control, which probably then somehow led to the step size error:

This seems like chaotic dynamical system behavior. Notice that I am using relatively tight tolerances (atol=rtol=1e-12) and the BDF implicit/stiff solver for the backward-in-time solve (none of the other solvers work at all for the backward solve; solve_ivp just seems to hang).

Now my questions:

1. Would any of the Julia solvers be able to solve the original system backwards in time? Am I doing anything wrong above? Would the backward solve become easier for different parameters (rho,sigma,beta) and/or ICs (x,y,z)?
2. Reverse-mode autodiff involves solving the sensitivity/adjoint ODEs alongside the original ODE system backwards in time. Given that I can't even solve the original Lorenz system by itself backwards in time, how is reverse-mode autodiff able to do the backward ODE solve for the full augmented/adjoint system and propagate perturbations in the final state to corresponding linear perturbations in the initial input parameters?
3. How can we trust the gradients found by reverse-mode autodiff for these kinds of systems that are difficult to solve backwards in time? (For simplicity, restrict ourselves to just the partials of the final state variable values wrt free parameters and ICs rather than the full time series of the sensitivities.)

It doesn't work. You need different methods. See:


One thing that is worth being clear on: reverse-mode autodiff of an ODE does not necessarily involve solving the sensitivity/adjoint ODEs backwards in time.

There's actually like 10 different ways to autodiff through an ODE! It looks like you're thinking specifically of the "continuous adjoint method"/"optimise then discretise", which is the version that involves (a) reconstructing the original `dy/dt`, and also (b) solving the adjoint ODE `da/dt` (where `a(t) = dL/dy(t)`).
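For reference, here is the textbook form of those continuous adjoint equations, assuming for simplicity a loss `L` that depends only on the final state `y(t1)`, with `θ` standing for the parameters (here `rho, sigma, beta`). Note that both `y` and `a` have to be integrated from `t1` back to `t0`:

$$
\begin{aligned}
\frac{dy}{dt} &= f(y, \theta), & y(t_0) &= y_0, \\
\frac{da}{dt} &= -\left(\frac{\partial f}{\partial y}\right)^{\top} a, & a(t_1) &= \frac{\partial L}{\partial y(t_1)}, \\
\frac{dL}{d\theta} &= \int_{t_0}^{t_1} \left(\frac{\partial f}{\partial \theta}\right)^{\top} a \, dt, & \frac{dL}{dy_0} &= a(t_0).
\end{aligned}
$$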

For this particular method: the backwards evolution of `y` is indeed basically impossible to solve backwards in time. But the backwards-in-time system for `a` actually exhibits the same stability properties as the forwards-in-time system for `y`. This is part of why the continuous adjoint method is a terrible method that should almost never be used in practice, because the autodiff is unconditionally unstable (either the `y` or `a` piece will explode when you run both backward).
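A quick way to see that stability claim, as a local heuristic only (eigenvalues of a Jacobian frozen at one point, here the initial condition `(5, 5, 5)`; nothing rigorous for a nonlinear system): running `y` backwards means integrating `dy/ds = -f(y)` with `s = t1 - t`, whose Jacobian is `-J`, while the adjoint equation `da/dt = -J^T a` becomes `da/ds = J^T a`, whose Jacobian has the same eigenvalues as `J`. This plain-NumPy sketch (`lorenz_jacobian` is a helper written just for the example) checks those relations:

```python
import numpy as np

def lorenz_jacobian(state, rho=28.0, sigma=10.0, beta=8 / 3):
    """Jacobian df/dy of the Lorenz right-hand side at a given state."""
    x, y, z = state
    return np.array([[-sigma, sigma, 0.0],
                     [rho - z, -1.0, -x],
                     [y, x, -beta]])

J = lorenz_jacobian([5.0, 5.0, 5.0])

eig_y_forward = np.linalg.eigvals(J)      # dy/dt = f(y): Jacobian J
eig_y_backward = np.linalg.eigvals(-J)    # dy/ds = -f(y), s = t1 - t: Jacobian -J
eig_a_backward = np.linalg.eigvals(J.T)   # da/ds = +J^T a (from da/dt = -J^T a)

print("y forward :", np.sort_complex(eig_y_forward))
print("y backward:", np.sort_complex(eig_y_backward))
print("a backward:", np.sort_complex(eig_a_backward))
```

The backward `y` problem has every eigenvalue's sign flipped, so directions that contract forward (negative real part) expand backward, whereas the backward `a` problem inherits the forward eigenvalues unchanged.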

That doesn't mean you can't autodiff through your system! The usual best thing to do here (and indeed 99% of the time) is just to differentiate through the internals of the ODE solver ("discretise then optimise"). This doesn't involve solving an adjoint ODE at all: it entirely ignores the ODE structure and just directly differentiates all the additions, multiplications, and so on that are happening inside your solver.

This works because it "solves" for `a` (discretely and without setting up an ODE, but that's not important), but it uses the values of `y` found on the forward pass, and does not attempt to also reconstruct `y`, which is the problematic bit. This is just like normal autodiff through a neural network or whatever.
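To make "differentiate through the solver, reusing the forward values" concrete, here is a minimal hand-rolled sketch in plain NumPy. It assumes a fixed-step explicit Euler solver (not what Diffrax or `odeint` actually use) and a loss `L = x(t1)`; the helper names (`euler_forward`, `euler_reverse`, etc.) are made up for this example. The reverse pass walks the stored states backwards and never re-integrates `y`:

```python
import numpy as np

def f(y, p):
    """Lorenz right-hand side, p = (rho, sigma, beta)."""
    x, yy, z = y
    rho, sigma, beta = p
    return np.array([sigma * (yy - x), x * (rho - z) - yy, x * yy - beta * z])

def df_dy(y, p):
    """Jacobian of f with respect to the state."""
    x, yy, z = y
    rho, sigma, beta = p
    return np.array([[-sigma, sigma, 0.0],
                     [rho - z, -1.0, -x],
                     [yy, x, -beta]])

def df_dp(y, p):
    """Jacobian of f with respect to (rho, sigma, beta)."""
    x, yy, z = y
    return np.array([[0.0, yy - x, 0.0],
                     [x, 0.0, 0.0],
                     [0.0, 0.0, -z]])

def euler_forward(y0, p, t1, n_steps):
    """Explicit Euler from t=0 to t1; keeps every state (the 'tape')."""
    h = t1 / n_steps
    ys = [np.asarray(y0, dtype=float)]
    for _ in range(n_steps):
        ys.append(ys[-1] + h * f(ys[-1], p))
    return ys, h

def euler_reverse(ys, p, h, cotangent):
    """Backpropagate dL/dy_N through y_{n+1} = y_n + h f(y_n, p).

    Only the stored forward states are used; y is never re-integrated backwards.
    """
    a = np.asarray(cotangent, dtype=float)   # a = dL/dy_{n+1}
    grad_p = np.zeros(3)
    for y_n in reversed(ys[:-1]):
        grad_p += h * df_dp(y_n, p).T @ a
        a = a + h * df_dy(y_n, p).T @ a      # a becomes dL/dy_n
    return a, grad_p                          # dL/dy0, dL/dp

y0 = np.array([5.0, 5.0, 5.0])
p = np.array([28.0, 10.0, 8 / 3])
c = np.array([1.0, 0.0, 0.0])                 # L = x(t1), like delta_ys later in the thread
ys, h = euler_forward(y0, p, t1=1.0, n_steps=1000)
grad_y0, grad_p = euler_reverse(ys, p, h, c)

# Central finite differences of the same discrete map, for comparison.
def loss(y_init, params):
    return c @ euler_forward(y_init, params, 1.0, 1000)[0][-1]

eps = 1e-6
e = np.eye(3) * eps
fd_p = np.array([(loss(y0, p + e[i]) - loss(y0, p - e[i])) / (2 * eps) for i in range(3)])
fd_y0 = np.array([(loss(y0 + e[i], p) - loss(y0 - e[i], p)) / (2 * eps) for i in range(3)])
print(grad_p, fd_p)
print(grad_y0, fd_y0)
```

The gradient is exact for the discrete Euler map, which is why it can be checked directly against finite differences of that same map.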

1. Almost no solver can reasonably tackle this problem backwards in time, whether that's in Julia or JAX or anything else.
2. and 3. For most applications we can trust (good choices of) autodiff because it has the same stability as the forward evolution. If you were able to solve your system forwards in time successfully, then you'll be able to autodiff it successfully.

Also, I believe Chris misspoke a little above. You wrote "I am trying to understand how reverse-mode autodiff is able to actually work on coupled non-linear ODE systems that are very hard to solve backwards in time." but, as we've seen above, the ability to solve the original ODE backwards in time is irrelevant.

The shadowing methods he links are for chaotic systems, for which the original forward-in-time IVP is poorly behaved. It just so happens that the example you chose (the Lorenz system) is both chaotic in the forward direction and unstable in the backward direction.
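One standard property of the Lorenz equations makes the backward instability easy to quantify: the divergence (trace of the Jacobian) of the vector field is the constant `-(sigma + 1 + beta)`, so phase-space volumes shrink by `exp(-(sigma + 1 + beta) t)` going forward and grow by the reciprocal going backward. The arithmetic for the question's `t = 10`:

```python
import math

rho, sigma, beta = 28.0, 10.0, 8 / 3
divergence = -(sigma + 1.0 + beta)    # trace of df/dy; independent of (x, y, z) and of rho
t_span = 10.0

forward_volume_factor = math.exp(divergence * t_span)    # shrinkage forward in time
backward_volume_factor = math.exp(-divergence * t_span)  # growth backward in time
# For the linearised flow, volume growth V is the product of the stretch factors,
# so at least one direction must stretch by V**(1/3) or more.
min_worst_stretch = backward_volume_factor ** (1.0 / 3.0)

print(divergence)
print(math.log10(backward_volume_factor), math.log10(min_worst_stretch))
```

A volume growth of about 1e59 forces at least one direction to stretch by roughly 1e20 or more, already well beyond the ~1e-16 relative precision of float64, so tightening tolerances or switching to an implicit solver cannot be expected to help much.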

(Chris, I'd welcome a correction if you think I've got you wrong.)

Finally, if you're coming from Python: shameless plug for Diffrax, which you might like.


I thought he was just asking about chaotic systems. Indeed, there are plenty of systems which are non-chaotic but don't reverse well. This depends on the Lipschitz constant, which generally would mean stiffness (though not necessarily). More details:

Though there are cases where even methods which don't reverse the ODE are unstable. Chaotic systems like the one mentioned above are one clear case, because the forward solution is not well-defined in a non-probabilistic way for any numerical method (and thus you cannot even assume that the forward trajectory is correct; you just have a shadow trajectory). But there are also cases where the forward solution is stable but the derivative calculation is unstable if the stepping behavior is unchanged (this mostly shows up on the most trivial of problems), even with forward-mode AD of an ODE solver.
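For reference, the textbook statement of how the Lipschitz constant enters is Grönwall's inequality: if `f` is Lipschitz in `y` with constant `L_f`, any two solutions of `dy/dt = f(y)` satisfy the bound below, in either time direction. The scalar example `dy/dt = -λy` with `λ > 0` (so `L_f = λ`) shows how asymmetric the two directions can be: forward in time, perturbations decay like `exp(-λt)`, far below the bound, while backward in time they grow like `exp(λt)`, exactly at the bound. Stiff, strongly dissipative systems behave like this in their fast directions.

$$
\|y(t) - \tilde{y}(t)\| \le \|y(t_0) - \tilde{y}(t_0)\| \, e^{L_f |t - t_0|}
$$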


Thank you so much @ChrisRackauckas and @patrick-kidger ! I was actually hoping to hear from one or both of you! What you said makes a lot of sense.

I went ahead and used @patrick-kidger's amazing `diffrax` package to confirm that diffrax's forward-mode autodiff (called DirectAdjoint) and reverse-mode autodiff through the internals of the solver (called RecursiveCheckpointAdjoint) successfully return the derivatives of the final state wrt parameter variations (both sets of Jacobians are similar to each other). I also confirmed that the reverse mode based on solving the augmented ODE system backwards in time (called BacksolveAdjoint in diffrax) indeed fails with a "reached max steps" error, even if I allow for up to 16**5 steps (the other two diffrax autodiff methods solve the forward system in only ~950 steps with dopri5 and atol=rtol=1.4e-8).

However, the default jax.experimental.ode.odeint IS successful in returning the same derivatives, but I was under the impression that it was solving the augmented ODE system backwards in time (based on comments in its code; @patrick-kidger can you please confirm?). I get the same derivatives from jax's odeint as from diffrax's DirectAdjoint and RecursiveCheckpointAdjoint, and it is very fast.

Furthermore, when I compute the gradients using finite differences, that also works and agrees with the other methods.

Here is my code for reproducibility (I will try to port this to Julia soon):

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jvp, vjp
from jax.experimental.ode import odeint
from diffrax import (diffeqsolve, ODETerm, Dopri5, SaveAt, PIDController,
                     DirectAdjoint, RecursiveCheckpointAdjoint, BacksolveAdjoint)

def lorenz_diffrax(t, state, args):
    x, y, z = state
    rho, sigma, beta = args
    return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

# set up diffrax inputs
terms = ODETerm(lorenz_diffrax)
y0 = jnp.array([5., 5., 5.])
rho, sigma, beta = 28., 10., 8/3.
t0 = 0.0
t1 = 10.0
dt0 = None  # let the adaptive step-size controller choose the initial step
tsave = SaveAt(ts=jnp.linspace(t0, t1, 1000), dense=True)
controller = PIDController(rtol=1.4e-8, atol=1.4e-8)

# function that solves Lorenz system with DirectAdjoint fwd-mode method
def evolve_diffrax_DirectAdjoint(y0, rho, sigma, beta):  # returns only final state
    return diffeqsolve(terms, Dopri5(), t0, t1, dt0, y0, jnp.array([rho, sigma, beta]),
                       saveat=tsave, stepsize_controller=controller,
                       adjoint=DirectAdjoint()).ys[-1]

# set up perturbations in input parameters -- for now we only vary the beta parameter
delta_y0 = jnp.array([0., 0., 0.])
delta_rho = 0.
delta_sigma = 0.
delta_beta = 1.

diffrax_ys, diffrax_delta_ys = jvp(evolve_diffrax_DirectAdjoint, (y0, rho, sigma, beta),
                                   (delta_y0, delta_rho, delta_sigma, delta_beta))
print(diffrax_ys, diffrax_delta_ys, sep='\n')
# [ 2.11236828  3.72088662 11.39477902]
# [2527.58786396 4462.56417013  724.80455082]

# another function that instead uses RecursiveCheckpointAdjoint through internals of ODE solver
def evolve_diffrax_RecursiveCheckpointAdjoint(y0, rho, sigma, beta):
    return diffeqsolve(terms, Dopri5(), t0, t1, dt0, y0, jnp.array([rho, sigma, beta]),
                       saveat=tsave, stepsize_controller=controller,
                       adjoint=RecursiveCheckpointAdjoint()).ys[-1]

vjp_ys, vjp_func_diffrax = vjp(evolve_diffrax_RecursiveCheckpointAdjoint, y0, rho, sigma, beta)

# define perturbations on final output state for rev-mode autodiff -- here we want gradients of 1st output
delta_ys = jnp.array([1.0, 0.0, 0.0])

vjp_partials = vjp_func_diffrax(delta_ys)

print(vjp_ys, vjp_partials, sep='\n')
# [ 2.11236828  3.72088662 11.39477902] # same as above
# Array([  97.38204096, -104.75567762,  989.0743617 ], dtype=float64),
# Array(-251.93713359, dtype=float64, weak_type=True),
# Array(-313.86799075, dtype=float64, weak_type=True), Array(2527.58786396, dtype=float64, weak_type=True)

# now use diffrax's BacksolveAdjoint to solve augmented ODE system backwards in time
def evolve_diffrax_BacksolveAdjoint(y0, rho, sigma, beta):
    return diffeqsolve(terms, Dopri5(), t0, t1, dt0, y0, jnp.array([rho, sigma, beta]),  # cannot use saveat
                       stepsize_controller=controller, max_steps=16**5,
                       adjoint=BacksolveAdjoint()).ys[-1]

vjp_ys2, vjp_func_diffrax2 = vjp(evolve_diffrax_BacksolveAdjoint, y0, rho, sigma, beta)

### THIS FAILS
vjp_func_diffrax2(delta_ys)

### but jax-odeint works, and I think it's solving the augmented ODEs backwards in time???

# new Lorenz function with order of inputs reversed for jax-odeint convention
def lorenz_odeint(state, t, args):
    x, y, z = state
    rho, sigma, beta = args
    return jnp.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])

tarr = jnp.linspace(0, 10., 1000)

def evolve_odeint(y0, rho, sigma, beta):
    return odeint(lorenz_odeint, y0, tarr, (rho, sigma, beta))[-1]

vjp_ys_odeint, vjp_func_odeint = vjp(evolve_odeint, y0, rho, sigma, beta)

print(vjp_ys_odeint, vjp_func_odeint(delta_ys), sep='\n')
# [ 2.11236071  3.72087324 11.39477673]
#(Array([  97.38263992, -104.75600712,  989.07885387], dtype=float64),
#Array(-251.93927603, dtype=float64),
#Array(-313.86969568, dtype=float64), Array(2527.59583685, dtype=float64))

### finally compute dy/dbeta with finite differences (choice of adjoint method is irrelevant for this)
eps = 1e-4
ys_low = odeint(lorenz_odeint, y0, tarr, (rho, sigma, beta - eps/2))[-1]
ys_high = odeint(lorenz_odeint, y0, tarr, (rho, sigma, beta + eps/2))[-1]
dy_dbeta = (ys_high - ys_low) / eps

print(dy_dbeta)
# 2527.39344683, 4462.07627433,  724.82886168 # agrees with others
```

Also @patrick-kidger: if diffrax's DirectAdjoint and RecursiveCheckpointAdjoint are not solving the adjoint/sensitivity ODEs, why do you still call these adjoint methods? Just curious if I'm misunderstanding the word adjoint in this context.

It does! If that's getting the same results then that's very mysterious. I'll take a look at your example.

The literature has been wildly inconsistent with the terminology on this point. For some folks, "the adjoint method" has referred specifically to optimise-then-discretise (`BacksolveAdjoint`), and for others it has referred specifically to discretise-then-optimise (`RecursiveCheckpointAdjoint`).

Some people take it to mean anything to do with reverse-mode autodifferentiation. And indeed this is where the word is derived from! Forward-mode autodiff is a linear transformation tangent_in->tangent_out, and reverse-mode autodiff is the adjoint of this transformation (in the same sense as the adjoint of a matrix), giving the linear transformation cotangent_out->cotangent_in.
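The "adjoint" in that sense can be checked numerically. In this plain-NumPy sketch (the functions `euler_step`, `jvp_euler`, `vjp_euler` are illustrative, not JAX's `jax.jvp`/`jax.vjp`), one explicit-Euler step of the Lorenz system has Jacobian `M`; forward mode applies `M`, reverse mode applies `M.T`, and the defining identity `<w, M v> = <M^T w, v>` holds:

```python
import numpy as np

def euler_step(y, h=0.01, rho=28.0, sigma=10.0, beta=8 / 3):
    """One explicit-Euler step of the Lorenz system (a stand-in for 'the solver')."""
    x, yy, z = y
    return y + h * np.array([sigma * (yy - x), x * (rho - z) - yy, x * yy - beta * z])

def euler_step_jacobian(y, h=0.01, rho=28.0, sigma=10.0, beta=8 / 3):
    """d(euler_step)/dy = I + h * df/dy."""
    x, yy, z = y
    J = np.array([[-sigma, sigma, 0.0], [rho - z, -1.0, -x], [yy, x, -beta]])
    return np.eye(3) + h * J

y = np.array([5.0, 5.0, 5.0])
M = euler_step_jacobian(y)

def jvp_euler(v):
    """Forward mode: tangent_in -> tangent_out."""
    return M @ v

def vjp_euler(w):
    """Reverse mode: cotangent_out -> cotangent_in, i.e. the adjoint (transpose) of jvp_euler."""
    return M.T @ w

rng = np.random.default_rng(0)
v, w = rng.standard_normal(3), rng.standard_normal(3)

lhs = w @ jvp_euler(v)    # <w, M v>
rhs = vjp_euler(w) @ v    # <M^T w, v>

# Sanity check that M really is the derivative of euler_step (central differences).
eps = 1e-6
fd_jvp = (euler_step(y + eps * v) - euler_step(y - eps * v)) / (2 * eps)
print(lhs, rhs)
```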

Anyway, I basically subscribe to that latter camp. I actually deliberately used the word `adjoint=...` in Diffrax to help emphasise that "adjoint" doesn't necessarily refer to any one specific method.


Can Diffrax AD through stiff solvers with Newton solves in the middle? Which linear solvers are allowed? A comparison with DE would be interesting.

For the chaotic properties to kick in you need a long enough time span. Try making the final time 100.

Alright, I figured this one out. It's because for `odeint` you are asking for output (on the forward pass) at `jnp.linspace(0, 10., 1000)`. This means that the backward solve doesn't need to go all the way from 10->0 in one go; it only needs to move between adjacent checkpoints.

In contrast, with Diffrax you were only getting the output at the final time, so the backward solve has no checkpoints it can use: it has to do the unstable backward solve over the whole interval 10->0. If you pass `saveat=SaveAt(ts=jnp.linspace(0., 10., 1000))` then you'll find that Diffrax with `BacksolveAdjoint` is able to successfully compute gradients as well.
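A small toy illustration of why the checkpoints matter. This is a hedged sketch that assumes a made-up 2x2 linear system with one fast-decaying direction (not the Lorenz system) and a hand-written fixed-step RK4 solver; it is not how Diffrax or `jax.experimental.ode` implement their backward passes. A single backward solve from `y(T)` over the whole interval amplifies round-off in the fast direction by about `exp(20 * 10)`, while restarting each short backward solve from a stored forward state keeps the amplification to about `exp(20 * 0.1)`:

```python
import numpy as np

# Toy linear system dy/dt = -A y (not the Lorenz system): A has eigenvalues 1 and 20
# in a rotated basis, so one direction decays much faster than the other.
theta = 0.3
Q = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta), np.cos(theta)]])
A = Q @ np.diag([1.0, 20.0]) @ Q.T

def rk4_step(y, h):
    """One classical RK4 step of dy/dt = -A y (h < 0 integrates backwards)."""
    k1 = -A @ y
    k2 = -A @ (y + 0.5 * h * k1)
    k3 = -A @ (y + 0.5 * h * k2)
    k4 = -A @ (y + h * k3)
    return y + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

def integrate(y, h, n_steps):
    for _ in range(n_steps):
        y = rk4_step(y, h)
    return y

T, h = 10.0, 0.005
steps_per_segment = 20                              # a checkpoint every 0.1 time units
n_segments = round(T / (h * steps_per_segment))     # 100 segments

# Forward pass: store the state at every checkpoint.
checkpoints = [np.array([1.0, 0.0])]
for _ in range(n_segments):
    checkpoints.append(integrate(checkpoints[-1], h, steps_per_segment))

# (a) A single backward solve from y(T) all the way to t = 0.
y0_backsolved = integrate(checkpoints[-1], -h, n_segments * steps_per_segment)
full_err = np.linalg.norm(y0_backsolved - checkpoints[0])

# (b) Backward solves only between adjacent checkpoints, each restarted from the
#     stored forward state.
segment_errs = [
    np.linalg.norm(integrate(checkpoints[i], -h, steps_per_segment) - checkpoints[i - 1])
    / np.linalg.norm(checkpoints[i - 1])
    for i in range(n_segments, 0, -1)
]

print("single backsolve, error in y(0):", full_err)
print("checkpointed, worst relative error per segment:", max(segment_errs))
```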

Yes, it can! It backprops through the Newton solves via the implicit function theorem.
Right now the linear solver is just hardcoded to an LU solve (see here), but this is actually about to be switched out! Sneak peek: we'll be announcing a slew of new linear and nonlinear solvers in the next few weeks.
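The implicit-function-theorem idea can be sketched for a single implicit (backward) Euler step of the Lorenz system. The step solves `g(x) = x - y_n - h f(x, p) = 0` with Newton's method, and the derivative comes from `dx/dp = -(dg/dx)^{-1} dg/dp = (I - h df/dy)^{-1} h df/dp`: one linear solve at the converged point, instead of differentiating the Newton iterations themselves. This is a minimal NumPy sketch of the general technique (the step size `h = 0.01` and all helper names are made up), not Diffrax's actual implementation; `np.linalg.solve` uses an LU factorisation internally:

```python
import numpy as np

def f(y, p):
    x, yy, z = y
    rho, sigma, beta = p
    return np.array([sigma * (yy - x), x * (rho - z) - yy, x * yy - beta * z])

def df_dy(y, p):
    x, yy, z = y
    rho, sigma, beta = p
    return np.array([[-sigma, sigma, 0.0], [rho - z, -1.0, -x], [yy, x, -beta]])

def df_dp(y, p):
    x, yy, z = y
    return np.array([[0.0, yy - x, 0.0], [x, 0.0, 0.0], [0.0, 0.0, -z]])

def implicit_euler_step(y_n, p, h, tol=1e-14, max_iter=50):
    """Solve g(x) = x - y_n - h f(x, p) = 0 for x with Newton's method."""
    x = y_n.copy()
    for _ in range(max_iter):
        g = x - y_n - h * f(x, p)
        dx = np.linalg.solve(np.eye(3) - h * df_dy(x, p), g)   # Newton update via LU solve
        x = x - dx
        if np.linalg.norm(dx) <= tol * (1.0 + np.linalg.norm(x)):
            break
    return x

def implicit_euler_step_dp(x, p, h):
    """IFT: dx/dp = -(dg/dx)^{-1} dg/dp, with dg/dx = I - h df/dy and dg/dp = -h df/dp."""
    return np.linalg.solve(np.eye(3) - h * df_dy(x, p), h * df_dp(x, p))

y_n = np.array([5.0, 5.0, 5.0])
p = np.array([28.0, 10.0, 8 / 3])
h = 0.01
x = implicit_euler_step(y_n, p, h)
dxdp = implicit_euler_step_dp(x, p, h)

# Compare with central finite differences through the whole Newton solve.
eps = 1e-6
E = np.eye(3) * eps
fd = np.column_stack([(implicit_euler_step(y_n, p + E[i], h)
                       - implicit_euler_step(y_n, p - E[i], h)) / (2 * eps) for i in range(3)])
print(dxdp)
print(fd)
```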


Thank you so much @patrick-kidger! Your fix to provide SaveAt to diffrax's BacksolveAdjoint works. Now it also gives the same gradients as odeint and the other two fwd/rev-mode diffrax autodiff methods through the internals of the ODE solver.

I also tried @ChrisRackauckas' suggestion to increase t_final to 100 instead of 10 to let the chaos really kick in.

Indeed, when I do this, all 4 autodiff methods give the same derivatives in the final state wrt free parameter variations and it looks like the derivatives are exploding, presumably due to the chaos. For example, all 4 methods give similar `[dx/dbeta, dy/dbeta, dz/dbeta] = [-1.62510744e+39, -2.93277731e+39, -1.04332235e+38]`.

I have to say it's interesting that even though this is chaotic, clearly the state is following the attractor solution and isn't fully occupying the "phase space". And yet the gradients explode like crazy despite the forward solution trajectory being confined to the attractor. I suppose it would be worse for other combinations of free parameters (rho,sigma,beta), and I suppose if I increased t_final to >> 100, then eventually the gradients would just become NaNs.
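A back-of-the-envelope check on those magnitudes (not from the thread): on the Lorenz attractor, small perturbations grow on average like `exp(λ t)`, where the commonly quoted largest Lyapunov exponent for `rho=28, sigma=10, beta=8/3` is `λ ≈ 0.906`. Being confined to the attractor bounds the state, not the stretching: nearby trajectories are still pulled apart exponentially (and folded back), and the derivatives measure exactly that stretching. The sketch below treats `λ` as an assumed input:

```python
import math

LYAPUNOV = 0.9056  # assumed: commonly quoted largest Lyapunov exponent for rho=28, sigma=10, beta=8/3

# log10 of exp(LYAPUNOV * t), i.e. the typical order of magnitude of the sensitivity growth
growth_log10 = {t: LYAPUNOV * t / math.log(10.0) for t in (1.0, 10.0, 100.0)}
for t, g in growth_log10.items():
    print(f"t_final = {t:5.1f}: perturbations grow by roughly 10^{g:.1f}")
```

That gives roughly 10^3.9 at `t_final = 10` and 10^39.3 at `t_final = 100`, the same ballpark as the ~2.5e3 and ~1e39 derivatives reported above; the finite-time growth along one particular trajectory can differ from this average.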

I have 2 last questions if/when you have time:

1. Now that I have gradients/Jacobians from autodiff'ing through the internals of the ODE solver, can I use those for inference problems? Specifically, say I generated a trajectory of the Lorenz system assuming some initial (x0,y0,z0) and parameters (rho,sigma,beta) and only evolved to t_final ~ 1, so before chaos really kicks in. And then I want to infer (x0,y0,z0,rho,sigma,beta) using something like Hamiltonian Monte Carlo with gradients provided by autodiff. My impression from your posts above is that this is not possible since the Lorenz system is chaotic in both the forward and reverse directions, and instead we need novel shadow methods like the ones @ChrisRackauckas described.

2. What about doing that kind of inference for non-linear but not necessarily chaotic ODE systems such as Lotka-Volterra or the SIR disease model that are still difficult to solve backwards in time (e.g., see this)? @patrick-kidger said that it's irrelevant whether you can solve the original/augmented ODEs backward in time, but how non-linear/chaotic can an ODE system be before parameter inference with autodiff becomes infeasible? Is there a practical way to compute the Lipschitz constant of some arbitrary non-linear ODE system, as @ChrisRackauckas mentioned above, to know whether inference is possible?

1. If I understand you correctly, I think you should be able to use these gradients (without shadow methods). If you're working over short time scales then chaos isn't an issue.

2. LV or SIR should be pretty simple to use alongside autodiff. (They're just not that complicated.) I don't think there's really a precise answer to the question "how non-linear/chaotic can an ODE system be before parameter inference with autodiff becomes infeasible?": this depends on things like how accurately you're stepping with your solver, your choice of floating point precision, your time horizon, etc. As for getting a Lipschitz constant: for arbitrary systems this can be quite finicky, but for many practical systems you can derive it analytically (or at least get bounds on it analytically).
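As an example of getting such a bound analytically, here is a sketch for Lotka-Volterra, `dx/dt = alpha x - beta x y`, `dy/dt = delta x y - gamma y`, with illustrative parameter values (assumed, not from the thread). Its Jacobian is `[[alpha - beta y, -beta x], [delta y, delta x - gamma]]`, so on a box `0 <= x <= X`, `0 <= y <= Y` the triangle inequality on the row sums gives an infinity-norm Lipschitz bound `max(alpha + beta (X + Y), gamma + delta (X + Y))`. The code compares it with the largest Jacobian norm actually seen along a simulated trajectory (a sampled estimate, not a rigorous bound):

```python
import numpy as np

alpha, beta, gamma, delta = 1.5, 1.0, 3.0, 1.0   # illustrative values only

def lv(u):
    x, y = u
    return np.array([alpha * x - beta * x * y, delta * x * y - gamma * y])

def lv_jacobian(u):
    x, y = u
    return np.array([[alpha - beta * y, -beta * x],
                     [delta * y, delta * x - gamma]])

def rk4_trajectory(u0, h, n_steps):
    """Fixed-step RK4; returns all states as an (n_steps + 1, 2) array."""
    us = [np.asarray(u0, dtype=float)]
    for _ in range(n_steps):
        u = us[-1]
        k1 = lv(u)
        k2 = lv(u + 0.5 * h * k1)
        k3 = lv(u + 0.5 * h * k2)
        k4 = lv(u + h * k3)
        us.append(u + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4))
    return np.array(us)

traj = rk4_trajectory([1.0, 1.0], h=1e-3, n_steps=10_000)

# Sampled estimate: largest Jacobian infinity-norm seen along the trajectory.
sampled = max(np.linalg.norm(lv_jacobian(u), np.inf) for u in traj)

# Analytic bound on the box 0 <= x <= X, 0 <= y <= Y that contains the trajectory.
X, Y = traj[:, 0].max(), traj[:, 1].max()
analytic = max(alpha + beta * (X + Y), gamma + delta * (X + Y))
print(sampled, analytic)
```

The analytic number is a genuine Lipschitz constant on that (convex) box; the sampled one only tells you what the trajectory actually visited, so it can be smaller.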
