Local sensitivities via ForwardSensitivity from DifferentialEquations in optimization with AD

I am trying to optimize local sensitivities of an ODE problem using ODEForwardSensitivityProblem from the DifferentialEquations suite, but the last line of the example below throws a TypeError.

I am using forward sensitivities because the ODE system is small.
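For context, here is a sketch of the standard forward sensitivity equations. The symbols s_j, n and m are my own notation, not something from the package docs: s_j is the derivative of the state with respect to parameter p_j, n is the number of states and m is the number of parameters.

```latex
\frac{du}{dt} = f(u, p, t), \qquad
\frac{ds_j}{dt} = \frac{\partial f}{\partial u}\, s_j + \frac{\partial f}{\partial p_j}, \qquad
s_j(t) = \frac{\partial u(t)}{\partial p_j}, \quad j = 1, \dots, m
```

Each s_j starts at zero when u(0) does not depend on p. The augmented system has n(1 + m) states, so the example below (n = 2, m = 3) solves 8 equations in total, which is why the forward approach stays cheap for small systems.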

The minimal example is taken from here.

using DiffEqSensitivity, OrdinaryDiffEq
using ForwardDiff

function f(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
  du[2] = dy = -p[3]*u[2] + u[1]*u[2]
end

p = [1.5,1.0,3.0]
prob = ODEForwardSensitivityProblem(f,[1.0;1.0],(0.0,10.0),p)

function sensitivity(x)
  _u0 = convert.(eltype(x), prob.u0)
  _u0[1] = x[1]
  _p = convert.(eltype(x), prob.p)
  _p[3] = x[2]
  _prob = remake(prob,u0=_u0,p=_p)
  sol = solve(_prob,Tsit5())
  dp = extract_local_sensitivities(sol)[2]
  dp[1][end]
end

sensitivity(ones(2))
eval_grad = x -> ForwardDiff.gradient(sensitivity, x)
eval_grad(ones(2))

Any help is much appreciated. As it stands, it does not seem possible to use ODEForwardSensitivityProblem with any optimizer where an exact Jacobian is required or beneficial, at least with the AD packages I tried, i.e. ForwardDiff and Zygote.

If you’re doing ForwardDiff on the outside, then the problem is AD-differentiable, in which case you might as well use double ForwardDiff. The following works fine:

using DiffEqSensitivity, OrdinaryDiffEq
using ForwardDiff

function f(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
  du[2] = dy = -p[3]*u[2] + u[1]*u[2]
end

p = [1.5,1.0,3.0]
prob = ODEProblem(f,[1.0;1.0],(0.0,10.0),p)
function sensitivity(x)
  _u0 = convert.(eltype(x), prob.u0)
  _u0[1] = x[1]
  _p = convert.(eltype(x), prob.p)
  _p[3] = x[2]
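  prob2 = remake(prob,u0=_u0,p=_p)
  # inner ForwardDiff pass: Jacobian of the final state with respect to the parameters
  dp = ForwardDiff.jacobian(_p) do _p2
      # promote u0 to the inner dual type so the state can carry the derivatives
      prob3 = remake(prob2,u0=convert.(eltype(_p2),prob2.u0),p=_p2)
      solve(prob3,Tsit5())[end]
  end
  dp[1] # derivative of the first state at the final time with respect to p[1]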
end

sensitivity(ones(2))
eval_grad = x -> ForwardDiff.gradient(sensitivity, x)
eval_grad(ones(2))

Now if you really want to use the sensitivity problem: in your code the problem is constructed once, outside of the function being differentiated, so you’re not differentiating the problem construction. You’d want to do something like:

using DiffEqSensitivity, OrdinaryDiffEq
using ForwardDiff

function f(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
  du[2] = dy = -p[3]*u[2] + u[1]*u[2]
end

p = [1.5,1.0,3.0]
u0 = [1.0,1.0]

function sensitivity(x)
  # no prob exists in this snippet, so start from the raw u0 and p
  _u0 = convert.(eltype(x), u0)
  _u0[1] = x[1]
  _p = convert.(eltype(x), p)
  _p[3] = x[2]
  _prob = ODEForwardSensitivityProblem(f,_u0,(0.0,10.0),_p,ForwardSensitivity(autojacvec=false))
  sol = solve(_prob,Tsit5())
  dp = extract_local_sensitivities(sol)[2]
  dp[1][end]
end

sensitivity(ones(2))
eval_grad = x -> ForwardDiff.gradient(sensitivity, x)
eval_grad(ones(2))

However, somewhere the gradients are getting dropped, so that will need an issue. Could you file one? Even once that is fixed, I’d still suggest the doubled ForwardDiff approach.