# State of reverse mode AD tools

(Opening a new topic instead of resurrecting some older ones.)

I am wondering about the current state of libraries for reverse mode AD, specifically for differentiating $\mathbb{R}^n \to \mathbb{R}$ functions that

1. accept an `AbstractVector` and return something `Real`,
2. may contain branches (in my experience, everything nontrivial does),
3. have an n too large for a single `SVector` (think 1–6k), though the input is broken up into smaller pieces (groups of parameters, hyperparameters, etc.) that may contain `SVector`s.

While I follow developments with interest, at the moment I am mainly interested in libraries that either work or can be made to work with reasonable effort (e.g. reporting an issue and getting an answer within a few days, with a solution or hints on how to make a PR).

My experience is the following:

1. ForwardDiff is of course forward mode, but super-robust and a good fallback. Surprisingly, it is quite competitive for large n, with some tricks (chunk size, etc, detailed in the manual).
2. ReverseDiff is reliable, but pre-compiling the tape quickly gets one in trouble unless the code path is 100%-and-I-mean-it deterministic. Recreating a tape each time slows it down.
3. Flux is reliable, but a bit more picky about code being really generic.
4. Nabla requires that code work with types that are not necessarily `<: AbstractVector`, which makes it difficult to use.
5. Zygote is fast when it works, but I found that breaking issues are not fixed for months, so in practice it is not usable.
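To make point 1 concrete: the chunk-size tuning mentioned in the manual amounts to fixing the number of partials carried per forward pass and preallocating the work buffers once. This is only a sketch; the objective `f` and the chunk size 8 are arbitrary illustrations, not recommendations.

```julia
using ForwardDiff

# A toy branchy R^n -> R objective (as in point 2, nontrivial code branches).
f(x) = sum(abs2, x) > 1 ? sum(abs2, x) : 2 * sum(abs2, x)

x = randn(1000)

# Default call: ForwardDiff picks a chunk size heuristically.
g_default = ForwardDiff.gradient(f, x)

# Tuned call: fix the chunk size and preallocate buffers in a config,
# then reuse the config across many gradient evaluations.
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{8}())
g_tuned = ForwardDiff.gradient(f, x, cfg)
```

Reusing one `GradientConfig` across calls is where most of the benefit comes from when the gradient is evaluated many times at the same size n.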

I have not tried Yota, or the other libraries. I am curious to hear stories from AD users — maybe I missed something obvious.


Can you provide a couple of examples of branching you meet in practice? I have a couple of ideas for how to add support for dynamic graphs to Yota (albeit with lower performance), but without real use cases they might be a waste of time.

My experience is that it is very difficult not to run into (insidious, well-hidden) branches in any sufficiently complex, nonlinear calculation of otherwise (mathematically) continuous functions.

A simple example would be anything calling `StatsFuns.log1pexp`.
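To show why this function is a problem for taping, here is a hedged sketch of the piecewise implementation behind `log1pexp`; the thresholds below are illustrative, not StatsFuns' exact values. The function is mathematically smooth, but the accurate implementation switches formulas depending on x:

```julia
# Sketch of the branching inside an accurate log(1 + exp(x)):
function log1pexp_sketch(x::Real)
    if x < 18          # small x: direct formula is fine
        return log1p(exp(x))
    elseif x < 33      # medium x: rearranged for accuracy
        return x + exp(-x)
    else               # large x: log(1 + exp(x)) == x to machine precision
        return x
    end
end
```

A tape recorded at, say, x = 1.0 bakes in the first branch; replaying that tape at x = 800.0 evaluates `log1p(exp(800.0))`, where `exp` overflows and the replay returns `Inf` instead of 800.0.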

Even if branches are not handled, it would be great if the user got an error when a path different from the taped one is taken. Perhaps this could be done by turning branches into assertions that throw a `BranchOutsideTapeError`, which the user could catch before recompiling the tape.

Of course, the other workaround is to write AD primitives for all of these functions. But they are so easy to miss in practice that the error above would be better than silently getting an incorrect result.

My goal is to have automatic differentiation support using multivariable dual numbers with Grassmann.jl

Once that’s implemented, things will get much more interesting for me…

Although I don’t work with `AbstractVector` but with `TensorAlgebra` instead, and instead of `Real` I have a separate basis for scalar values (from which you can extract the real values contained inside).


Can you point me to your issue / bump it on GitHub? We’re definitely still in a beta stage so I can’t make any promises, but I can often prioritise stuff that’s blocking people’s work.
