Mixed-mode automatic differentiation using ForwardDiff and ReverseDiff

I need to take the gradient of a function. I can use ForwardDiff without any issues, but for part of my code I have found ReverseDiff to run much faster. The issue is that the other part of my code errors with ReverseDiff. The basic structure of my code looks something like

function take_my_gradient(x)
    # this part errors under ReverseDiff (but works with ForwardDiff)
    tmp = errors_with_reversediff(x)
    # this part runs much faster under ReverseDiff
    return faster_with_reversediff(tmp)
end

Is it possible to differentiate my function using ForwardDiff on the part that doesn’t work with ReverseDiff and using ReverseDiff on the other part? (ReverseDiff.jl’s README seems to indicate it is possible.) If so, can someone please explain how?

2 Likes

If anyone has any ideas about this, I would greatly appreciate the help.

Generally, if ReverseDiff (or Zygote) doesn’t handle a portion of a calculation (separated into some function), you should just use ChainRulesCore.jl to define a custom “pullback” (vector–Jacobian product) rule for that function, either with manual differentiation (typically by an adjoint method) or by using some other AD package (though forward-mode AD is not that efficient for pullbacks).
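For concreteness, here is a minimal sketch of that approach, assuming errors_with_reversediff from the original post maps a vector to a vector (if it returns a scalar, ForwardDiff.gradient would take the place of ForwardDiff.jacobian); a hand-derived vector–Jacobian product would usually be cheaper than forming the dense Jacobian like this:

using ChainRulesCore, ForwardDiff

# Wrap the ReverseDiff-incompatible step in a custom rrule whose pullback
# is built from a ForwardDiff Jacobian.
function ChainRulesCore.rrule(::typeof(errors_with_reversediff), x)
    y = errors_with_reversediff(x)
    # dense Jacobian via forward mode (fine for small/medium input dimension)
    J = ForwardDiff.jacobian(errors_with_reversediff, x)
    function errors_pullback(ȳ)
        # vector–Jacobian product: x̄ = Jᵀ ȳ
        return NoTangent(), J' * unthunk(ȳ)
    end
    return y, errors_pullback
end

Zygote picks up ChainRulesCore rrules automatically, so with this rule in place it calls the pullback instead of trying to trace through errors_with_reversediff itself.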

In general, reverse-mode differentiation (a.k.a. backpropagation or adjoint methods) is much faster than forward-mode when you are computing gradients (i.e. the derivative of one output with respect to many inputs). See also our matrix-calculus course notes.
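As a toy illustration of that scaling (the function here is made up, just something scalar-valued with many inputs):

using ForwardDiff, ReverseDiff

f(x) = sum(abs2, x) / (1 + sum(abs, x))   # one output, many inputs

x = randn(1000)

# forward mode needs on the order of length(x) directional derivatives (in chunks)...
g_fwd = ForwardDiff.gradient(f, x)

# ...while reverse mode gets the whole gradient from one forward + one reverse sweep
g_rev = ReverseDiff.gradient(f, x)

g_fwd ≈ g_rev   # the two agree up to floating-point error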

5 Likes

Reversed?

2 Likes

Whoops, fixed.

Thanks for your response!

This was the direction I needed. It hadn’t crossed my mind to enable mixed-mode AD by defining a custom rrule that used ForwardDiff. I was able to get things working by doing so (though I had to switch to Zygote; I also tried ReverseDiff and Yota but they didn’t work with my code).


As a side note for those interested, using this mixed-mode AD with Zygote and ForwardDiff is actually slower and more memory intensive for my problem (i.e., I get OutOfMemoryErrors for larger problem sizes using mixed-mode AD, even when using a ForwardDiff.Chunk size of 1). So I’ll probably stick with just ForwardDiff for now.
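For anyone wondering, setting the chunk size looks something like this (loss here is just a stand-in for my actual objective):

using ForwardDiff

loss(x) = sum(abs2, x)   # stand-in for the real objective

x = randn(100)

# Chunk{1}() carries one directional derivative per pass, minimizing the
# memory used per evaluation at the cost of more passes
cfg = ForwardDiff.GradientConfig(loss, x, ForwardDiff.Chunk{1}())
g = ForwardDiff.gradient(loss, x, cfg)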

Reverse-mode AD has to store the results of all of the intermediate steps of your algorithm in order to “backpropagate” the derivatives, so it is notoriously memory-intensive if you are trying to differentiate some kind of iterative calculation.

However, there are often workarounds. For example, if you are using an iterative method to solve some system of equations (an iterative linear or nonlinear solver), you can instead define an “adjoint” pullback rule directly on the solution that completely avoids backpropagating through the iteration. People sometimes also apply techniques like “checkpointing” to trade off computation and memory.
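To sketch the first idea, suppose solve(A, b) is some hypothetical iterative routine that returns x with A*x = b (real case); the pullback can be written in terms of one extra solve with Aᵀ, so the solver's iterations are never recorded:

using ChainRulesCore, LinearAlgebra

solve(A, b) = A \ b   # stand-in for CG/GMRES/Newton/etc.

function ChainRulesCore.rrule(::typeof(solve), A, b)
    x = solve(A, b)
    function solve_pullback(x̄)
        # adjoint equations: one solve with Aᵀ replaces backpropagating
        # through every iteration of the solver
        λ = solve(A', unthunk(x̄))
        Ā = -λ * x'   # sensitivity with respect to A
        b̄ = λ         # sensitivity with respect to b
        return NoTangent(), Ā, b̄
    end
    return x, solve_pullback
end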

In general, it is extremely helpful to know something about how forward and reverse-mode AD algorithms work in order to use them effectively, and sometimes to know when you should judiciously replace AD with manual derivatives (vector–Jacobian or Jacobian–vector products) for a portion of your code.

2 Likes

I would use Zygote to orchestrate the reverse-mode AD, then define an rrule for the function that I want to differentiate with ForwardDiff and another rrule for the function that I want to differentiate with ReverseDiff. That way, the time spent inside Zygote itself should be small. Just be careful: when you use ForwardDiff inside an rrule, you are essentially evaluating the whole Jacobian to define the pullback, which, depending on the structure of the Jacobian, may not be the most efficient thing to do.
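For the ReverseDiff side, one way to write such an rrule (sketched for the faster_with_reversediff function from the original post, assumed to map a vector to a vector) is to get the vector–Jacobian product as the ReverseDiff gradient of a scalar:

using ChainRulesCore, ReverseDiff, LinearAlgebra

function ChainRulesCore.rrule(::typeof(faster_with_reversediff), x)
    y = faster_with_reversediff(x)
    function faster_pullback(ȳ)
        v = unthunk(ȳ)
        # Jᵀv is the gradient of z ↦ dot(v, f(z)); v is a constant here,
        # so only z is traced by ReverseDiff
        x̄ = ReverseDiff.gradient(z -> dot(v, faster_with_reversediff(z)), x)
        return NoTangent(), x̄
    end
    return y, faster_pullback
end

Note that this evaluates faster_with_reversediff once in the rrule and again inside the pullback.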

A bit of shameless self-promotion: I made this video, Understanding automatic differentiation (in Julia) - YouTube, explaining some of these concepts, which you may find useful.

3 Likes

Thanks, I’ve linked to that from our Matrix Calculus course (GitHub - mitmath/matrixcalc: MIT IAP short course: Matrix Calculus for Machine Learning and Beyond).

1 Like

Thanks for the link; I’ll be sure to check it out.