If you have a few parameters and a lot of functions to differentiate, probably use ForwardDiff (forward-mode differentiation); if you have a lot of parameters and a few functions, probably use Zygote or maybe ReverseDiff (reverse-mode differentiation).
We’re putting out a paper in a few days that will go through quite a few examples of AD package performance and how it differs, but one discussion of this can be found in the following paper:
You might want to watch the talk that explains the results:
Essentially, the cost of forward-mode methods scales with the number of inputs while reverse mode scales with the number of outputs, but in many applications this can look like O(states * parameters) for forward vs O(states + parameters) for reverse. So obviously reverse is better, right? Wrong: there are many natural reasons why reverse-mode AD will have a higher baseline overhead, for example it has to record the forward pass and keep intermediate values around for the backward sweep.
So if forward-mode AD is faster when problems are small and reverse-mode AD is faster when problems are large, where’s the cutoff? That’s very problem-dependent, and the other paper to be posted soon will show that the problem itself can also change which AD packages are fast. But one thing to look at is the following:
We found that at around a size-50 system, some reverse-mode methods (“based on” Enzyme.jl) were already faster than ForwardDiff.jl, and at around 100-150 or so versions of ReverseDiff.jl reached the cutoff as well. So “roughly 100” is a decent rule of thumb for switching from forward to reverse, depending on the properties of the package.
I’m also curious whether you observe any “significant” difference in speed among these different packages. Or are they roughly comparable, even if one is slightly faster in the different situations you described?
It’s not just a difference in software. There is a fundamental difference in algorithms and computational scaling between forward and reverse mode AD, as I said, and which one is better depends on the number of inputs vs the number of outputs; google it.
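To make the scaling argument concrete, here is a minimal sketch of both modes. It is written in plain Python purely as a language-neutral illustration; it is not how ForwardDiff.jl, Zygote.jl, or ReverseDiff.jl are implemented, and the names `Dual`, `Var`, `forward_gradient`, and `reverse_gradient` are made up for this example. The point is only the pass count: for a function with n inputs and one output, forward mode needs n evaluations of `f` while reverse mode needs one evaluation plus one backward sweep over the recorded graph.

```python
class Dual:
    """Forward mode: a value carried together with one directional derivative."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.dot + other.dot)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule applied alongside the primal computation.
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)
    __rmul__ = __mul__


class Var:
    """Reverse mode: a value plus links to its parents with local partials."""
    def __init__(self, val, parents=()):
        self.val, self.parents, self.grad = val, parents, 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.val + other.val, ((self, 1.0), (other, 1.0)))
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.val * other.val,
                   ((self, other.val), (other, self.val)))
    __rmul__ = __mul__


def backward(out):
    """Sweep the recorded graph once, from the output back to the inputs."""
    order, seen = [], set()

    def visit(node):
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(out)
    out.grad = 1.0
    # Reversed post-order visits every node before any of its parents.
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += local * node.grad


def f(x):
    # Toy scalar function of n inputs: f(x) = sum(x_i^2), so df/dx_i = 2 x_i.
    return sum(xi * xi for xi in x)


def forward_gradient(f, x):
    """One evaluation of f per input: seed a single direction each time."""
    grad, passes = [], 0
    for i in range(len(x)):
        seeded = [Dual(v, 1.0 if j == i else 0.0) for j, v in enumerate(x)]
        grad.append(f(seeded).dot)
        passes += 1
    return grad, passes


def reverse_gradient(f, x):
    """One evaluation of f (recording the graph), then one backward sweep."""
    inputs = [Var(v) for v in x]
    backward(f(inputs))
    return [v.grad for v in inputs], 1


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0]
    print(forward_gradient(f, x))
    print(reverse_gradient(f, x))
```

Both functions return the same gradient, but `forward_gradient` reports as many evaluations of `f` as there are inputs, while `reverse_gradient` always reports one. The price reverse mode pays is visible too: every `Var` operation allocates a graph node that has to be kept until the backward sweep, which is the kind of baseline overhead that makes forward mode win on small problems.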
Here’s the paper I mentioned where Appendix B describes how on the same application 4 or 5 different AD mechanisms can be the optimal choice depending on the user inputs.
This paper also conveniently describes AbstractDifferentiation.jl, a higher-level API for using any AD system, which I would recommend for handling this complexity.