That shouldn’t affect performance. This is forward-mode autodiff of the Jacobians in the ODE solve, not the adjoint method or the VJP calculation, so `autodiff=false` should run within roughly a factor of 2 of the autodiff version.
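To show where that switch lives, here is a minimal sketch. It uses the classic Robertson stiff problem and OrdinaryDiffEq.jl, neither of which comes from the original code. It solves the same problem once with the default forward-mode AD Jacobian and once with finite differences. It uses the Bool `autodiff` keyword as in the post. Recent OrdinaryDiffEq releases may expect an ADTypes value such as `AutoFiniteDiff()` instead, so treat the keyword form as an assumption about your installed version.

```julia
using OrdinaryDiffEq

# Robertson chemical kinetics: a standard stiff test problem (out-of-place form)
function rober(u, p, t)
    y1, y2, y3 = u
    k1, k2, k3 = p
    return [-k1 * y1 + k3 * y2 * y3,
            k1 * y1 - k2 * y2^2 - k3 * y2 * y3,
            k2 * y2^2]
end

prob = ODEProblem(rober, [1.0, 0.0, 0.0], (0.0, 1e5), [0.04, 3e7, 1e4])

# Jacobian built with forward-mode AD (the default)
sol_ad = solve(prob, Rosenbrock23())

# Jacobian built with finite differences instead
sol_fd = solve(prob, Rosenbrock23(autodiff = false))
```

The Jacobian method changes the cost per step, not the answer beyond solver tolerance. For a rough timing comparison, wrap each `solve` in `@time` after a warm-up call.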
I’m turning this into a test case as we speak. A working version is:
```julia
sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
# g(u,p,t) = sum(u), so its gradient w.r.t. u is all ones; supply it in place
_, dp = adjoint_sensitivities(sol, Rosenbrock23(); sensealg=sensealg, g=(u, p, t)->sum(u), dgdu_continuous = (du,u,p,t)->du.=1);
```
Basically, the failure seems to come from the interaction of three things: a stiff ODE solver with autodiff on, not supplying the derivative of `g` (`dgdu_continuous`), and checkpointing. In that snippet I turned off checkpointing and supplied `dgdu_continuous`, and it works. I have a fix for the checkpointing part and will put in a patch once I figure out why omitting `dgdu_continuous` matters, but this should be enough for you to test the real impact on your code.
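To try the same settings outside your own model, here is a self-contained sketch on a toy problem made up for illustration: u' = -p u, u(0) = 1 on [0, 1]. It assumes current package names (OrdinaryDiffEq.jl, SciMLSensitivity.jl, Zygote.jl), and the names `f`, `g`, and `dgdu!` are hypothetical. `checkpointing = false` is written out explicitly even though it should be the `InterpolatingAdjoint` default. With g = sum(u), the cost is G(p) = ∫₀¹ e^(-pt) dt = (1 - e^(-p))/p. At p = 1 the gradient should therefore be 2/e - 1 ≈ -0.2642, which gives an analytic check.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Hypothetical toy problem: u' = -p[1] * u, written out-of-place for ZygoteVJP
f(u, p, t) = -p[1] .* u
prob = ODEProblem(f, [1.0], (0.0, 1.0), [1.0])
sol = solve(prob, Rosenbrock23(); abstol = 1e-8, reltol = 1e-8)

# Interpolating adjoint with Zygote VJPs and checkpointing explicitly off
sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = false)

# Continuous cost integrand and its hand-written in-place derivative dg/du (all ones)
g(u, p, t) = sum(u)
dgdu!(out, u, p, t) = (out .= 1; nothing)

du0, dp = adjoint_sensitivities(sol, Rosenbrock23(); sensealg = sensealg,
    g = g, dgdu_continuous = dgdu!, abstol = 1e-8, reltol = 1e-8)

# Analytic gradient of G(p) = (1 - exp(-p)) / p at p = 1
expected = 2 * exp(-1) - 1
println("adjoint dG/dp = ", dp[1], ", analytic = ", expected)
```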
For performance, I suspect you don’t want Rosenbrock methods for the adjoint. Methods like `TRBDF2` will likely be a lot faster in the backpass, regardless of which ODE solver you use for the forward pass.
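Here is a minimal sketch of that split. It uses the same kind of hypothetical toy problem (u' = -p u on [0, 1]) and assumes OrdinaryDiffEq.jl, SciMLSensitivity.jl, and Zygote.jl. The forward pass stays on `Rosenbrock23`, and only the solver handed to `adjoint_sensitivities` changes. The helper `backpass` is a hypothetical name, not a library function.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Hypothetical toy problem: u' = -p[1] * u, u(0) = 1
f(u, p, t) = -p[1] .* u
prob = ODEProblem(f, [1.0], (0.0, 1.0), [1.0])

# Forward pass: this solver choice is independent of the adjoint solver
sol = solve(prob, Rosenbrock23(); abstol = 1e-8, reltol = 1e-8)

sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())
g(u, p, t) = sum(u)
dgdu!(out, u, p, t) = (out .= 1; nothing)

# Run the adjoint with solver `alg` and return only the parameter gradient
backpass(alg) = adjoint_sensitivities(sol, alg; sensealg = sensealg, g = g,
    dgdu_continuous = dgdu!, abstol = 1e-8, reltol = 1e-8)[2]

dp_ros = backpass(Rosenbrock23())   # Rosenbrock backpass
dp_sdirk = backpass(TRBDF2())       # SDIRK (TRBDF2) backpass
```

Both backpasses should agree to within tolerance. A one-state toy like this will not show the speed difference, so any gain on a real stiff model has to be measured, for example with `@time` on `backpass` after a warm-up call.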