If the ODE function of a neural ODE includes the gradient f’(x) of a neural network f(x), we need to calculate second-order derivatives of f(x) in the backward pass. For me this led to all sorts of problems whenever the AD method was not carefully selected.
Assume the ODE function is g(f(x), f’(x)), where f(x) is again a neural network and g is some simple (let’s say rational) function of its two arguments. Any suggestions about the choice of AD method and library for both the calculation of f’(x) and the second-order derivatives in the adjoint method?
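To make explicit where the second-order terms come from, here is a rough sketch. It assumes f is scalar-valued with parameters θ (so that f’(x) is the gradient ∇ₓf), and it introduces the symbols h, a, u and v purely for illustration; sign conventions for the adjoint vary between references.

$$
\dot{x} = h(x, \theta) = g\big(u, v\big), \qquad u = f_\theta(x), \quad v = \nabla_x f_\theta(x)
$$

$$
a^\top \frac{\partial h}{\partial x} = \Big(a^\top \frac{\partial g}{\partial u}\Big)\, \nabla_x f_\theta(x)^\top + \Big(a^\top \frac{\partial g}{\partial v}\Big)\, \nabla_x^2 f_\theta(x)
$$

$$
a^\top \frac{\partial h}{\partial \theta} = \Big(a^\top \frac{\partial g}{\partial u}\Big)\, \frac{\partial f_\theta(x)}{\partial \theta} + \Big(a^\top \frac{\partial g}{\partial v}\Big)\, \frac{\partial}{\partial \theta} \nabla_x f_\theta(x)
$$

Under these assumptions, the vector-Jacobian products needed by the adjoint contain a Hessian-vector product of f with respect to x and a mixed derivative of ∇ₓf with respect to θ, so whatever AD computes f’(x) in the forward pass has to be differentiable again by the AD used for the VJP.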
So far, the only combination that has worked for me is Zygote.gradient() for f’(x) in the forward pass together with an optimize-then-discretize adjoint method (such as BacksolveAdjoint or InterpolatingAdjoint) using autojacvec=ZygoteVJP(). However, I have not yet found a working method for differentiating through the ODE solver itself (ReverseDiffAdjoint, TrackerAdjoint, and ZygoteAdjoint() all failed). This is especially a problem for DDEs, as there is currently no optimize-then-discretize adjoint implemented for them.
Which combinations of AD methods could work here (especially for the DDE case)? Also, does it make sense to combine forward-mode and reverse-mode AD here, and if so, which libraries work well together?