Why the separation of `ODEProblem` and `solve` in DifferentialEquations.jl?

https://diffeq.sciml.ai/latest/analysis/sensitivity/

TrackerAdjoint is able to use a TrackedArray form with out-of-place functions du = f(u,p,t), but requires an Array{TrackedReal} form for f(du,u,p,t), which mutates du. The latter has much more overhead and should be avoided if possible. Thus if solving non-ODEs with lots of parameters, using TrackerAdjoint with an out-of-place definition may be the current best option.
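To make the two function forms in that quote concrete, here is a minimal sketch in plain Julia with no AD packages involved. The names `decay` and `decay!` and the toy dynamics are made up for illustration and are not from the docs or this thread.

```julia
# Out-of-place form: du = f(u, p, t) allocates and returns a new array.
decay(u, p, t) = -p .* u

# In-place form: f(du, u, p, t) writes into a preallocated du.
function decay!(du, u, p, t)
    du .= -p .* u
    return nothing
end

u = [1.0, 2.0, 3.0]
p = [0.5, 1.0, 2.0]

du_oop = decay(u, p, 0.0)      # fresh array from the out-of-place call
du_iip = similar(u)            # buffer that the in-place call fills
decay!(du_iip, u, p, 0.0)

println(du_oop == du_iip)
```

Both forms compute the same derivative. The difference is only in who owns the output array, and that is what decides which tracked type a reverse-mode tool ends up pushing through the function.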

The same is true of the vjps, though if you don’t have branching you can use ReverseDiffVJP(true) to compile the tape for the backward pass. But I think your problem is probably much better off doing

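function aug_dynamics!(z, policy_params, t)
    x = @view z[2:end]              # state part of z
    u = policy(x, policy_params)    # control from the policy
    [x' * x + u' * u; u]            # out-of-place: returns a new array instead of mutating
end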

to accommodate those factors in reverse-mode AD (or use the AoS->SoA conversion and convert back).
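For the vjp choice mentioned above, here is a rough sketch of how the option is passed, assuming current package names (SciMLSensitivity was called DiffEqSensitivity when the linked docs were written). The toy right-hand side `growth`, the values, and `loss` are made up for illustration. They are not your problem.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Hypothetical out-of-place dynamics with no branching: du = p .* u.
growth(u, p, t) = p .* u

u0 = [1.0, 2.0]
p = [-0.5, -1.0]
prob = ODEProblem(growth, u0, (0.0, 1.0), p)

# Loss is the sum of the final state. The adjoint computes its vjps
# with ReverseDiff, and `true` asks for a compiled tape.
function loss(p)
    sol = solve(prob, Tsit5(); p = p, saveat = [1.0],
                abstol = 1e-10, reltol = 1e-10,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))
    sum(Array(sol)[:, end])
end

grad = Zygote.gradient(loss, p)[1]
```

For this linear system the exact gradient is `u0 .* exp.(p)`, which gives a quick sanity check. The compiled tape is only safe because `growth` has no value-dependent branches.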
