Correct way of computing adjoints/gradients with a dense ODE solution

Hi all,

I am trying to use the interpolated solution of my ODE for adjoint sensitivity analysis. Since I have a very small system of ODEs (n=3), I am interested in evaluating the performance of continuous adjoint methods when computing a dense solution during the forward pass. However, when I solve with dense=true (together with save_everystep=true and saveat=nothing), I cannot get the gradients computed in the reverse pass.

Here is an MWE:

using SciMLSensitivity, OrdinaryDiffEqCore, OrdinaryDiffEqTsit5, Zygote
using BenchmarkTools
using Random, Distributions
using Optimization, OptimizationOptimisers, OptimizationOptimJL

# Lotka-Volterra dynamics
function fiip(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end
p = [1.5, 1.0, 3.0, 1.0];
u0 = [1.0; 1.0];

abstol = reltol = 1e-8
tspan = (0.0, 10.0)

prob = ODEProblem(fiip, u0, tspan, p)

# N random time points in tspan at which the loss is evaluated
N = 50
times = sort(rand(sampler(Uniform(tspan[1], tspan[2])), N))

function loss(u0, p)
    sol = solve(prob, Tsit5(), u0 = u0, p = p, saveat = times, 
                abstol=abstol, reltol=reltol, 
                sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)))
    return sum(sum(sol.u))
end

@benchmark du0, dp = Zygote.gradient(loss, u0, p)

This works correctly, but it only saves the solution at saveat = times. On the other hand, when computing a dense solution like this:

function loss_dense(u0, p)
    sol_dense = solve(prob, Tsit5(), u0 = u0, p = p, 
                    dense=true, save_everystep=true,
                    abstol=abstol, reltol=reltol, 
                    sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)))
    return sum(sum(sol_dense(times).u))
end

@benchmark du0, dp = Zygote.gradient(loss_dense, u0, p)

I get the following error:

LoadError: Standard interpolation is disabled due to sensitivity analysis being
used for the gradients. Only linear and constant interpolations are
compatible with non-AD sensitivity analysis calculations. Either
utilize tooling like saveat to avoid post-solution interpolation, use
the keyword argument dense=false for linear or constant interpolations,
or use the keyword argument sensealg=SensitivityADPassThrough() to revert
to AD-based derivatives.

Now, the message is quite clear: it seems that the continuous adjoint here does not like the interpolation (I get the same error when using the continuous adjoint method directly for optimization with AutoZygote()). However, is this the expected behaviour? Am I doing something fundamentally wrong, or has the continuous adjoint method for higher-order interpolations simply not been implemented?

My motivation here is to have a dense solution in the forward pass so I can compute my adjoints as fast as possible without worrying about memory use.

Thank you!!!

It’s expected, and it’s fundamental to the method. The continuous interpolation of most methods requires an extra calculation with respect to the k’s of each step, since the dense interpolations are built from a (usually one order lower) b_i*k_i summation. So, in order for the standard interpolation to be differentiable, you’d need to differentiate w.r.t. the k’s of the solve, which is equivalent to the discrete adjoint formulation since those are the core internal step information. As such, it does not make sense to add that as a default calculation in a continuous adjoint, because that would add all of the calculations of the discrete adjoint on top of the continuous adjoint, guaranteeing it would be the slowest of all. Thus, if you need to use the standard dense output of the ODE solver, it really only makes sense to use discrete adjoints.
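
For context, the dense output of an explicit Runge–Kutta step (like Tsit5’s free interpolant) schematically has the form

$$ u(t_n + \theta h) \approx u_n + h \sum_i b_i(\theta)\, k_i, \qquad \theta \in [0, 1], $$

where the b_i(θ) are interpolation polynomials and the k_i are the internal stages of that step. Differentiating sol(t) therefore means differentiating through every k_i of every step, which is exactly the information a discrete adjoint propagates.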

However, you could build an alternative interpolation on the output. There are pros and cons to this: it’s not guaranteed to have the same accuracy, for example. But if you take sol.u and give it to DataInterpolations.jl, you can get an interpolation like a CubicSpline based only on the sol.u. It thus does not use the internal derivative estimates, and therefore crucially may not be as robust to stiff behavior, but for some applications this can be sufficient. Since it’s based only on sol.u, it only needs to differentiate w.r.t. the saved values, which propagates back into the adjoint as the delta in the g delta functions and is thus fine.
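
Something like this sketch illustrates the idea (untested; loss_spline is a hypothetical name, the saveat grid is an arbitrary choice, and you’d want to check that the DataInterpolations calls are AD-compatible in your setup):

using DataInterpolations  # for CubicSpline

function loss_spline(u0, p)
    # save on a grid we control; the adjoint only needs derivatives
    # w.r.t. these saved values, not the solver's internal k stages
    sol = solve(prob, Tsit5(), u0 = u0, p = p, saveat = 0.1,
                abstol = abstol, reltol = reltol,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))
    # one cubic spline per state component, built only from the saved values
    itp1 = CubicSpline(sol[1, :], sol.t)
    itp2 = CubicSpline(sol[2, :], sol.t)
    return sum(itp1.(times)) + sum(itp2.(times))
end

du0, dp = Zygote.gradient(loss_spline, u0, p)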


Thank you for your response, @ChrisRackauckas!

Ok, that makes sense, yes. I can see the problem. So, in general, the only alternative if I want to use a dense solution of the forward pass is to use discrete adjoints? I guess I was trying to see how to set up this option in the continuous adjoint method for cases where I don’t care about memory but still want to use a continuous method instead of a discrete one.

If I understand you correctly, right now this only makes sense with the discrete adjoint method, or, for the continuous adjoint method, it would have to be done with a different interpolation technique that depends only on sol.u rather than on the internal interpolation used during solving. Am I right?

Yes, the interpolation is important to the continuous adjoint method because it allows the reverse pass to not require the same time steps as the forward pass (by reinterpolating u), and it allows the integral to be handled more optimally by choosing a minimal set of points (by reinterpolating the lambda). However, taking the derivative with respect to the interpolation is a much more complicated process that the continuous adjoint method cannot do easily. In fact, if you want that done, you either have to add all of the differentiation parts of the discrete adjoint (in which case you might as well do a discrete adjoint at that point), or you use a k-independent interpolation like a spline (which then has some accuracy trade-offs).
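
Schematically (stated loosely here, ignoring sign conventions and the exact cost terms), the reverse pass solves an adjoint ODE of the form

$$ \frac{d\lambda}{dt} = -\left(\frac{\partial f}{\partial u}\right)^{T} \lambda + \text{(cost terms)}, $$

and accumulates the parameter gradient via an integral like

$$ \frac{dG}{dp} = \int_{t_0}^{T} \lambda^{T} \frac{\partial f}{\partial p}\, dt + \text{(initial-condition terms)}. $$

Both ∂f/∂u and ∂f/∂p have to be evaluated at u(t) at whatever times the reverse solver and the quadrature choose, which is exactly what the forward interpolation of u provides.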

Makes sense.

So, just to be sure, in my original code,

function loss(u0, p)
    sol = solve(prob, Tsit5(), u0 = u0, p = p, saveat = times, 
                abstol=abstol, reltol=reltol, 
                sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)))
    return sum(sum(sol.u))
end

is this already using linear interpolation (not the interpolation of the solver) for constructing the continuous adjoint, or is it not using any form of interpolation at all?

Sorry, saveat is fine: it is interpolating, but as part of the adjoint process. It’s only post-solution interpolation that would require differentiating the k’s.
