If the ODE function of a neural ODE includes the gradient f’(x) of a neural network f(x), we need to calculate second-order gradients of f(x) in the backward pass. For me, this led to all sorts of problems when the AD method was not carefully selected.
Assume the ODE function is g(f(x), f’(x)), where f(x) is again a neural network and g(y) some simple (let’s say rational) function. Any suggestions about the choice of AD method and library for both the calculation of f’(x) and the second order gradients in the adjoint method?
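To make explicit where the second-order terms come from, here is a sketch of the continuous adjoint equations for this setup. It assumes a scalar-valued network f_θ: R^n → R with parameters θ, a terminal loss L(x(t_1)), and one common sign convention; the symbols h, λ, ∂_1 g and ∂_2 g are introduced here for illustration only.

```latex
\begin{aligned}
\frac{dx}{dt} &= h_\theta(x) := g\big(f_\theta(x),\, \nabla_x f_\theta(x)\big), \\
\frac{d\lambda}{dt} &= -\left(\frac{\partial h_\theta}{\partial x}\right)^{\!\top} \lambda,
\qquad \lambda(t_1) = \frac{\partial L}{\partial x(t_1)}, \\
\frac{dL}{d\theta} &= \int_{t_0}^{t_1} \lambda^\top \frac{\partial h_\theta}{\partial \theta}\, dt, \\
\frac{\partial h_\theta}{\partial x} &= \partial_1 g \,(\nabla_x f_\theta)^\top + \partial_2 g \,\nabla_x^2 f_\theta, \\
\frac{\partial h_\theta}{\partial \theta} &= \partial_1 g \,\frac{\partial f_\theta}{\partial \theta} + \partial_2 g \,\frac{\partial (\nabla_x f_\theta)}{\partial \theta}.
\end{aligned}
```

The Hessian ∇²_x f_θ and the mixed derivative ∂(∇_x f_θ)/∂θ are the second-order quantities: whichever AD computes the vector-Jacobian products in the adjoint has to differentiate through the AD call that produced f’(x).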
So far, the only combination that worked for me was to use Zygote.gradient() for f’(x) in the forward pass and an optimize-then-discretize adjoint method (such as BacksolveAdjoint or InterpolatingAdjoint) in combination with autojacvec=ZygoteVJP(). However, I have not yet found a working method for differentiating through the ODE solver itself (ReverseDiffAdjoint, TrackerAdjoint, and ZygoteAdjoint all failed). This is especially a problem for DDEs, as there is currently no optimize-then-discretize adjoint implemented there.
Which combinations of AD methods could work here (especially for the DDE case)? Also does it make sense to combine forward and backward AD here and if yes, which libraries work well together?
There are a lot of details on new automatic differentiation libraries that will be released fairly soon. Just as quick spoilers, the ARPA-E DIFFERENTIATE program funded three Julia Computing projects, which led to massive efforts over the last year on new AD mechanisms. This has led to new compiler tooling in Julia, with the vast majority set to merge in Julia v1.7, which allows flexible compiler passes to be written from user code. This fixes essentially all of the issues we had with Cassette.jl and IRTools.jl (and thus the issues of Zygote.jl), which is why those libraries are somewhat in maintenance mode. The new AD, Diffractor.jl, will get a full announcement soon (so think of this as just the trailer) with a full explanation of how these issues were solved and what it’s currently being used and tested on.
With this new AD, there are projects starting up in the Julia Lab which will use the new composable pass structure to add features to the AD, like mutation and MPI support, to solve the issues of integrating AD with scientific computing code (since these issues are distinctly different from those of machine learning code). We’re also teaming up with people who have solved such issues in C++ and Fortran AD tools before, so that we have the right expertise on the team to do it correctly.
Again this has been a big project with lots of moving parts and it’s not complete yet, but you’ll start to hear announcements on it fairly soon.
Mixing forward and reverse almost always makes sense for higher order, limiting to only one or two reverses. The arguments for that can be found in Griewank’s tome IIRC and it has to do with how the complexity grows.
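As a minimal sketch of what forward-over-reverse looks like in practice, the snippet below uses ForwardDiff.jl on the outside and Zygote.jl on the inside. The tiny tanh layer (`W`, `b`, `f`) and the helper names `grad_f`, `hess_f` and `hvp` are hypothetical stand-ins, not part of any library, and this only illustrates the nesting pattern; it is not a claim about what works inside the adjoint of an ODE solve.

```julia
using ForwardDiff, Zygote

# Hypothetical stand-in for the network: one tanh layer with a scalar output.
const W = [0.5 -0.3; 0.2 0.8; -0.7 0.1]
const b = [0.1, -0.2, 0.3]
f(x) = sum(tanh.(W * x .+ b))

# Inner pass (reverse mode): the gradient f'(x).
grad_f(x) = Zygote.gradient(f, x)[1]

# Outer pass (forward mode): differentiate the gradient to get the Hessian.
hess_f(x) = ForwardDiff.jacobian(grad_f, x)

# Hessian-vector product from a single forward-mode directional derivative,
# without ever forming the full Hessian.
hvp(x, v) = ForwardDiff.derivative(ε -> grad_f(x .+ ε .* v), 0.0)

x = [0.4, -1.2]
v = [1.0, 2.0]
H = hess_f(x)
Hv = hvp(x, v)
```

Only one reverse pass is used here; the second derivative comes from pushing dual numbers through that reverse pass, which avoids taping a tape.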
Thanks a lot for the detailed answer. Those projects sound really interesting and it’s good to hear that AD will be further improved soon. This will make the SciML libraries as a whole even more amazing. In the meantime, I’ll try implementing the adjoint method for DDEs.
Even with the new AD, this will be needed because the missing component cannot be calculated without a derivative rule to catch the discontinuity. The issue to follow is:
An implementation of Enright’s corrections is needed regardless of what AD is used, so it would be much appreciated!
Ok, I do have a first implementation of Enright’s DDE adjoint method, but at the moment it still produces slightly different gradients from ForwardDiffSensitivity and ReverseDiffAdjoint, so there must be some small bug in my code. Once that is fixed, I could post it in the above issue on GitHub.
Ok, but isn’t this only a problem for the discontinuities coming from the possibly not C1 transition between initial history and DDE solution? Since in my example that transition should be smooth… Or is AD also having problems with the discontinuities in the adjoint state?