Slow backward gradients in ODEs

First of all, this model is tiny. If you go by the results of [1812.01892] A Comparison of Automatic Differentiation and Continuous Sensitivity Analysis for Derivatives of Differential Equation Solutions, you need on the order of 100 ODEs before reverse-mode AD overcomes its overhead in the best case, and more than 1000 in the worst case. Reverse mode generally has higher overhead but better scaling, so it shouldn't be surprising that ForwardDiff is best in this case.
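For reference, the forward-mode pattern looks roughly like this. This is just a sketch on a toy system: the `lotka!` model, time span, `saveat`, and loss below are placeholders I'm assuming, not the original code.

```julia
using OrdinaryDiffEq, ForwardDiff

# Toy in-place ODE (Lotka-Volterra), standing in for the actual model
function lotka!(du, u, p, t)
    du[1] =  p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end

u0 = [1.0, 1.0]
p  = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka!, u0, (0.0, 10.0), p)

# Rebuild the problem with Dual-typed state so ForwardDiff's numbers can flow
# through the solve, then take a scalar loss of the saved trajectory.
function loss(p)
    _prob = remake(prob; u0 = convert.(eltype(p), prob.u0), p = p)
    sol = solve(_prob, Tsit5(); saveat = 0.1)
    sum(abs2, Array(sol))
end

ForwardDiff.gradient(loss, p)
```

With only a handful of parameters, pushing dual numbers through the solver like this is usually the cheapest option.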

For the other result, Zygote plugs into the DiffEqSensitivity setup (Local Sensitivity Analysis (Automatic Differentiation) · DifferentialEquations.jl), so it ends up "not being too bad": it dispatches to a hand-written adjoint rule, where most of the internal work is probably delegated to a vjp defined via Enzyme, which handles the mutation well.
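That path looks roughly like the sketch below, reusing `prob` and `p` from the example above. The `sensealg` choice shown here is one explicit configuration (in newer versions the package is SciMLSensitivity), not necessarily the default that was hit in the benchmark.

```julia
using OrdinaryDiffEq, DiffEqSensitivity, Zygote

# Differentiating through `solve` with Zygote hits the hand-written adjoint rule
# instead of tracing the solver; the internal vjp backend is picked via `autojacvec`.
function loss_adjoint(p)
    _prob = remake(prob; p = p)   # `prob` from the forward-mode sketch above
    sol = solve(_prob, Tsit5(); saveat = 0.1,
                sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()))
    sum(abs2, Array(sol))
end

Zygote.gradient(loss_adjoint, p)
```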

I'm not sure whether ReverseDiff.jl ends up hitting the same rules. I think we defined a connection here https://github.com/SciML/DiffEqBase.jl/blob/master/src/reversediff.jl#L47-L50, but I'd need to double check whether that actually catches it. The problem with ReverseDiff.jl is that if you accidentally end up in Array{TrackedReal} mode (instead of TrackedArray), you can miss dispatches like this and take a huge performance hit. I think the only real application for ReverseDiff.jl at the user level is compiled tape mode in cases where Array{TrackedReal} is used to avoid errors (i.e. to handle mutation).
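For completeness, the compiled-tape pattern looks roughly like this. The loss below is a made-up example with a mutated buffer, not anything from the original code, and keep in mind that compiled tapes assume the control flow doesn't depend on the input values.

```julia
using ReverseDiff

# A hypothetical loss that mutates an intermediate buffer. Under ReverseDiff the
# buffer holds TrackedReal elements, so the mutation is legal, but every scalar
# op is recorded individually, which is slow unless the tape is compiled.
function loss(p)
    buf = Vector{eltype(p)}(undef, length(p))
    for i in eachindex(p)
        buf[i] = p[i]^2 + sin(p[i])
    end
    sum(buf)
end

p = rand(10)
tape = ReverseDiff.compile(ReverseDiff.GradientTape(loss, p))

grad = similar(p)
ReverseDiff.gradient!(grad, tape, p)   # replaying the compiled tape amortizes the recording overhead
```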
