I need to take the gradient of a function. I can use ForwardDiff without any issues, but for part of my code I have found ReverseDiff to run much faster. The issue is that the other part of my code errors with ReverseDiff. The basic structure of my code looks something like this:
```julia
tmp = errors_with_reversediff(x)
```
Is it possible to differentiate my function using ForwardDiff on the part that doesn’t work with ReverseDiff and using ReverseDiff on the other part? (ReverseDiff.jl’s README seems to indicate it is possible.) If so, can someone please explain how?
Generally, if ReverseDiff (or Zygote) doesn’t handle a portion of a calculation (separated into some function), you should just use ChainRulesCore.jl to define a custom “pullback” (vector–Jacobian product) rule for that function, either with manual differentiation (typically by an adjoint method) or by using some other AD package (though forward-mode AD is not that efficient for pullbacks, since it needs roughly one forward pass per input, or per chunk of inputs, to build the Jacobian).
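A minimal sketch of that idea, assuming ChainRulesCore.jl, ForwardDiff.jl and Zygote.jl are installed. The body of `errors_with_reversediff` below is a hypothetical stand-in (the real function isn't shown); the rule computes the Jacobian with `ForwardDiff.jacobian` on the forward pass and returns a pullback that multiplies by its transpose, so Zygote uses the rule instead of tracing into the function.

```julia
using ChainRulesCore, ForwardDiff, Zygote

# Hypothetical stand-in for the part of the code that ReverseDiff cannot handle.
errors_with_reversediff(x) = [sum(abs2, x), prod(x), x[1] * x[end]]

# Custom reverse rule: Zygote calls this instead of differentiating the body above.
function ChainRulesCore.rrule(::typeof(errors_with_reversediff), x::AbstractVector{<:Real})
    y = errors_with_reversediff(x)
    # Forward mode builds the full m×n Jacobian (cost grows with length(x)).
    J = ForwardDiff.jacobian(errors_with_reversediff, x)
    function errors_with_reversediff_pullback(ȳ)
        x̄ = J' * unthunk(ȳ)          # vector–Jacobian product: x̄ = Jᵀ ȳ
        return NoTangent(), x̄         # no tangent for the function object itself
    end
    return y, errors_with_reversediff_pullback
end

loss(x) = sum(errors_with_reversediff(x))
x = [1.0, 2.0, 3.0]
g = Zygote.gradient(loss, x)[1]       # [11.0, 7.0, 9.0]
```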
In general, reverse-mode differentiation (a.k.a. backpropagation or adjoint methods) is much faster than forward-mode when you are computing gradients (i.e. the derivative of one output with respect to many inputs). See also our matrix-calculus course notes.
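To make the cost difference concrete, here is a sketch computing the same gradient both ways for a hypothetical scalar function `f` of 100 inputs (assuming ForwardDiff.jl and ReverseDiff.jl). With chunk size 10, ForwardDiff evaluates `f` 10 times, each pass carrying 10 partials; ReverseDiff records a tape once and gets all 100 partials from a single backward sweep.

```julia
using ForwardDiff, ReverseDiff

# Hypothetical scalar-valued function of many inputs.
f(x) = sum(sin, x) * sum(abs2, x)

x = collect(range(0.1, 1.0; length = 100))

# Forward mode: 100 inputs / chunk size 10 = 10 passes of f with 10 partials each.
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{10}())
g_fwd = ForwardDiff.gradient(f, x, cfg)

# Reverse mode: record the operations once, compile the tape, then one backward
# sweep yields the whole gradient. A compiled tape replays the recorded control
# flow, so it is only valid if branches don't depend on the values in x.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, x))
g_rev = similar(x)
ReverseDiff.gradient!(g_rev, tape, x)
```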
This was the direction I needed. It hadn’t crossed my mind to enable mixed-mode AD by defining a custom rrule that used ForwardDiff. I was able to get things working by doing so (though I had to switch to Zygote; I also tried ReverseDiff and Yota but they didn’t work with my code).
As a side note for those interested, using this mixed-mode AD with Zygote and ForwardDiff is actually slower and more memory intensive for my problem (i.e., I get OutOfMemoryErrors for larger problem sizes using mixed-mode AD, even when using a ForwardDiff.Chunk size of 1). So I’ll probably stick with just ForwardDiff for now.
Reverse-mode AD has to store the results of all of the intermediate steps of your algorithm in order to “backpropagate” the derivatives, so it is notorious for being memory-intensive if you are trying to differentiate some kind of iterative calculation.
However, there are often workarounds. For example, if you are using an iterative method to solve some system of equations (an iterative linear or nonlinear solver), you can instead define an “adjoint” pullback rule directly on the solution that completely avoids backpropagating through the iteration. People sometimes also apply techniques like “checkpointing” to trade off computation and memory.
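A minimal sketch of such an adjoint rule for a linear system A x = b, using a hypothetical `jacobi_solve` as the iterative solver. Since x = A⁻¹b, the pullback needs only one adjoint solve Aᵀλ = x̄, after which b̄ = λ and Ā = -λxᵀ; none of the Jacobi iterates are stored. It assumes A is strictly diagonally dominant in both rows and columns (e.g. symmetric), so Jacobi converges for both A and Aᵀ.

```julia
using ChainRulesCore, LinearAlgebra, Zygote

# Hypothetical iterative solver for A x = b: Jacobi iteration with a fixed 200 sweeps.
# Assumes A is strictly diagonally dominant so the iteration converges.
function jacobi_solve(A::AbstractMatrix, b::AbstractVector)
    d = diag(A)
    R = A - Diagonal(d)                       # off-diagonal part of A
    x = zeros(promote_type(eltype(A), eltype(b)), length(b))
    for _ in 1:200
        x = (b .- R * x) ./ d
    end
    return x
end

# Adjoint rule defined on the solution: Zygote never differentiates the loop above.
function ChainRulesCore.rrule(::typeof(jacobi_solve), A::AbstractMatrix, b::AbstractVector)
    x = jacobi_solve(A, b)
    function jacobi_solve_pullback(x̄)
        # One adjoint solve Aᵀλ = x̄, reusing the same solver on Aᵀ
        # (assumes Aᵀ is also diagonally dominant, e.g. A symmetric).
        λ = jacobi_solve(Matrix(A'), collect(unthunk(x̄)))
        Ā = -λ * x'                           # ∂L/∂A = -λ xᵀ
        return NoTangent(), Ā, λ              # ∂L/∂b = λ
    end
    return x, jacobi_solve_pullback
end

A = [4.0 1.0 0.0; 1.0 5.0 2.0; 0.0 2.0 6.0]   # symmetric, strictly diagonally dominant
b = [1.0, 2.0, 3.0]
gA, gb = Zygote.gradient((A, b) -> sum(jacobi_solve(A, b)), A, b)
```

Memory use here is independent of the number of iterations, which is the point of defining the rule on the solution rather than on the loop.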
In general, it is extremely helpful to know something about how forward and reverse-mode AD algorithms work in order to use them effectively, and sometimes to know when you should judiciously replace AD with manual derivatives (vector–Jacobian or Jacobian–vector products) for a portion of your code.
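As a small illustration of the two building blocks, here is a sketch (hypothetical function `h`, assuming ForwardDiff.jl and Zygote.jl) computing a Jacobian–vector product with one forward-mode pass and a vector–Jacobian product with one reverse-mode pullback, checked against the explicit Jacobian.

```julia
using ForwardDiff, Zygote

# Hypothetical vector-valued function from R^3 to R^2.
h(x) = [x[1] * x[2], sin(x[3]) + x[1]^2]

x = [1.0, 2.0, 3.0]
v = [1.0, 0.0, 0.0]    # direction in input space
w = [0.0, 1.0]         # weights in output space

# Jacobian–vector product J*v: differentiate along the line x + t*v at t = 0.
jvp = ForwardDiff.derivative(t -> h(x .+ t .* v), 0.0)

# Vector–Jacobian product Jᵀ*w: one reverse-mode pullback seeded with w.
y, back = Zygote.pullback(h, x)
vjp = back(w)[1]

# Explicit Jacobian, for comparison only.
J = ForwardDiff.jacobian(h, x)
```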
I would use Zygote to orchestrate the reverse-mode AD, then define one rrule for the function you want to differentiate with ForwardDiff and another rrule for the function you want to differentiate with ReverseDiff. That way, the time spent inside Zygote itself should be small. Just be aware that when you use ForwardDiff inside an rrule, you are essentially evaluating the whole Jacobian to define the pullback, which, depending on the structure of the Jacobian, may not be the most efficient approach.
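A sketch of that setup, mirroring the structure in the question with two hypothetical stages: `stage_forward` (the part that needs ForwardDiff) and `stage_reverse` (a scalar-valued part where ReverseDiff is assumed to be fast). Because `stage_reverse` returns a scalar, its pullback is just ȳ times one ReverseDiff gradient; the ForwardDiff stage pays for a full Jacobian, which is the caveat mentioned above.

```julia
using ChainRulesCore, ForwardDiff, ReverseDiff, Zygote

# Hypothetical stand-ins for the two parts of the computation.
stage_forward(x) = [sum(abs2, x), prod(x), x[1] * x[end]]   # pretend ReverseDiff fails here
stage_reverse(t) = sum(sin, t) + t[1] * t[2]                # pretend ReverseDiff is fastest here

# Pullback from a ForwardDiff Jacobian (roughly one dual pass per input chunk).
function ChainRulesCore.rrule(::typeof(stage_forward), x::AbstractVector{<:Real})
    J = ForwardDiff.jacobian(stage_forward, x)
    stage_forward_pullback(ȳ) = (NoTangent(), J' * unthunk(ȳ))
    return stage_forward(x), stage_forward_pullback
end

# Scalar output, so a single ReverseDiff gradient gives the whole pullback.
function ChainRulesCore.rrule(::typeof(stage_reverse), t::AbstractVector{<:Real})
    g = ReverseDiff.gradient(stage_reverse, t)
    stage_reverse_pullback(ȳ) = (NoTangent(), unthunk(ȳ) .* g)
    return stage_reverse(t), stage_reverse_pullback
end

# Zygote only has to chain the two custom rules together.
full(x) = stage_reverse(stage_forward(x))

x = [1.0, 2.0, 3.0]
g = Zygote.gradient(full, x)[1]
```

If the ReverseDiff stage were vector-valued instead, its pullback could compute `ReverseDiff.gradient(t -> sum(ȳ .* stage(t)), t)`, i.e. the gradient of ȳ·f(t), which is again a single reverse pass.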