Hi all,
I am trying to use the interpolated solution of my ODE for adjoint sensitivity analysis. Since my system of ODEs is very small (n = 3), I am interested in evaluating the performance of continuous adjoint methods when computing a dense solution during the forward pass. However, when using `dense=true` (together with `save_everystep=true` and `saveat=nothing`), I cannot get the gradients computed in the reverse pass.
Here is a minimal working example (MWE):
```julia
using SciMLSensitivity, OrdinaryDiffEqCore, OrdinaryDiffEqTsit5, Zygote
using BenchmarkTools
using Random, Distributions
using Optimization, OptimizationOptimisers, OptimizationOptimJL

# Lotka-Volterra right-hand side (in-place form)
function fiip(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end

p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0, 1.0]
abstol = reltol = 1e-8
tspan = (0.0, 10.0)
prob = ODEProblem(fiip, u0, tspan, p)

# N random, sorted observation times inside tspan
N = 50
times = sort(rand(Uniform(tspan[1], tspan[2]), N))

# Loss on the solution saved only at the observation times
function loss(u0, p)
    sol = solve(prob, Tsit5(), u0 = u0, p = p, saveat = times,
        abstol = abstol, reltol = reltol,
        sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true)))
    return sum(sum(sol.u))
end

du0, dp = Zygote.gradient(loss, u0, p)
# Interpolate the globals with $ so BenchmarkTools does not time global lookups
@benchmark Zygote.gradient(loss, $u0, $p)
```
This works correctly, but it only saves the solution at `saveat = times`. On the other hand, when computing a dense solution as here:
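```julia
# Loss that interpolates a dense forward solution at the observation times
function loss_dense(u0, p)
    sol_dense = solve(prob, Tsit5(), u0 = u0, p = p,
        dense = true, save_everystep = true,
        abstol = abstol, reltol = reltol,
        sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))
    # Post-solution interpolation with the Tsit5 dense interpolant
    return sum(sol_dense(times))
end

du0, dp = Zygote.gradient(loss_dense, u0, p)
@benchmark Zygote.gradient(loss_dense, $u0, $p)
```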
I get the following error:
```
LoadError: Standard interpolation is disabled due to sensitivity analysis being
used for the gradients. Only linear and constant interpolations are
compatible with non-AD sensitivity analysis calculations. Either
utilize tooling like saveat to avoid post-solution interpolation, use
the keyword argument dense=false for linear or constant interpolations,
or use the keyword argument sensealg=SensitivityADPassThrough() to revert
to AD-based derivatives.
```
Now, the message is quite clear: the continuous adjoint does not seem to support this (higher-order) interpolation. I get the same error when using the continuous adjoint directly in an optimization with `AutoZygote()`. However, is this the expected behaviour? Am I doing something fundamentally wrong, or has the continuous adjoint for higher-order interpolations simply not been implemented?
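For reference, this is how I understand the "linear interpolation" case the error message refers to. It is a hypothetical, self-contained sketch in plain Julia (the helpers `bracket`, `linear_interp` and `linear_interp_pullback` are my own names, not SciMLSensitivity functions): with linear interpolation, a gradient with respect to u(t) only touches the two saved states that bracket t, so presumably the adjoint can handle it much like extra `saveat` points. The Tsit5 dense interpolant additionally depends on the solver's internal stage values (`sol.k`), which is my guess as to why it is rejected.

```julia
# Hypothetical helpers, NOT part of SciMLSensitivity: a sketch of how a
# gradient flows through linear interpolation of saved ODE states.

# Index i with ts[i] <= t <= ts[i+1] and weight θ (in [0, 1] for t inside [ts[1], ts[end]])
function bracket(ts::AbstractVector{<:Real}, t::Real)
    i = clamp(searchsortedlast(ts, t), 1, length(ts) - 1)
    θ = (t - ts[i]) / (ts[i+1] - ts[i])
    return i, θ
end

# Forward: u(t) ≈ (1 - θ) * us[i] + θ * us[i+1]
function linear_interp(ts, us, t)
    i, θ = bracket(ts, t)
    return (1 - θ) .* us[i] .+ θ .* us[i+1]
end

# Pullback: a cotangent Δ on u(t) becomes cotangents on the two
# neighbouring saved states only, weighted by (1 - θ) and θ
function linear_interp_pullback(ts, us, t, Δ)
    dus = [zero(u) for u in us]
    i, θ = bracket(ts, t)
    dus[i] .+= (1 - θ) .* Δ
    dus[i+1] .+= θ .* Δ
    return dus
end

ts = [0.0, 1.0, 2.0]
us = [[1.0, 1.0], [2.0, 0.5], [3.0, 0.0]]
linear_interp(ts, us, 1.5)                        # → [2.5, 0.25]
linear_interp_pullback(ts, us, 1.5, [1.0, 1.0])   # → [[0.0, 0.0], [0.5, 0.5], [0.5, 0.5]]
```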
My motivation here is to have a dense solution in the forward pass so I can compute my adjoints as fast as possible without worrying about memory use.
Thank you!!!