Zygote is more than 50 times slower than the numerical derivative, while ForwardDiff is 2 times faster than it.
I know that Zygote uses reverse mode, but is that the reason for the performance difference? Can the Zygote call be improved?
Note that I only need the derivatives of rs with respect to x; I do not need the derivatives with respect to d. Is there a way to tell Zygote that, and would it even matter? The anonymous function was just a quick attempt at this that did not help.
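A minimal sketch of the two call styles, assuming a hypothetical stand-in rs(x, d) (scalar x, vector d, vector output), since the original MWE is not reproduced here; the body of rs, d0 and the sizes are made up for illustration. Zygote.jacobian returns one Jacobian per argument it is given, so closing over d leaves x as the only argument and only its Jacobian comes back.

```julia
using Zygote

# Hypothetical stand-in for the thread's rs: scalar x, vector d, vector output.
const d0 = collect(range(1.0, 3.0; length = 20))
rs(x, d) = exp.(-x .* d) .* d0

d = collect(range(0.1, 2.0; length = 20))
x = 0.5

# Passing both arguments: one Jacobian per argument is returned.
# For the scalar x it is a length-20 vector; for d it is a 20×20 matrix.
Jx, Jd = Zygote.jacobian(rs, x, d)

# Closing over d makes x the only differentiated argument,
# so only its Jacobian is returned.
Jx_only = only(Zygote.jacobian(x -> rs(x, d), x))
```

Note that the closure only trims what is returned: the captured d is still part of the code Zygote traces and pulls back through, so this alone is not expected to save much work.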
On the function's performance: can you replace sum(d .* d0) with dot(d, d0)? (You need using LinearAlgebra first.)
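A small sketch of the difference, using only the LinearAlgebra standard library: sum(d .* d0) first materializes the temporary array d .* d0 and then sums it, whereas dot(d, d0) reduces directly without a temporary. The helper names sum_prod, dot_prod and bytes_allocated are made up for this example, and random vectors stand in for the real data.

```julia
using LinearAlgebra

sum_prod(a, b) = sum(a .* b)   # allocates the temporary a .* b, then sums it
dot_prod(a, b) = dot(a, b)     # reduces directly, no temporary array

# Bytes allocated by one call, measured after a warm-up call
# so that compilation is not counted.
function bytes_allocated(f, a, b)
    f(a, b)
    return @allocated f(a, b)
end

d  = rand(20)
d0 = rand(20)

s1 = sum_prod(d, d0)
s2 = dot_prod(d, d0)
```

For 20-element vectors the temporary is tiny, so the timings may well be indistinguishable at this size.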
On AD performance: as a general rule of thumb, forward-mode AD is faster than reverse-mode AD at small problem sizes.
I was using dot(a, b) before; in this simple case it benchmarks identically to sum(a .* b).
I only kept sum() to avoid another dependency in the MWE.
I’m mainly curious about how Zygote is expected to behave in this case and the best way to make it ignore some inputs, so that I’m getting the most out of these packages.
I wonder whether the culprit is not which inputs are used, but rather that Zygote un-fuses broadcasts (which is why it isn’t ideal for this kind of highly scalarized code).
There is an algorithmic issue here too. The best case for reverse mode (like Zygote) is many parameters and a scalar output: it then does one reverse pass of (ideally) comparable cost to the original function. The best case for forward mode is what you have here, if I read this correctly: one scalar input leading to a vector output. Again, it does the original work plus propagating this one perturbation forward.
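To illustrate the forward-mode idea, here is a toy dual-number type: a minimal sketch of the principle ForwardDiff is built on, not ForwardDiff's actual implementation. MiniDual and the stand-in rs are hypothetical. Seeding the scalar input with derivative 1 and evaluating the function once yields every component of the derivative vector.

```julia
import Base: +, -, *, exp

# Hypothetical toy dual number: a value plus one tracked perturbation.
struct MiniDual <: Number
    val::Float64   # primal value
    der::Float64   # derivative along the single tracked perturbation
end

+(a::MiniDual, b::MiniDual) = MiniDual(a.val + b.val, a.der + b.der)
*(a::MiniDual, b::MiniDual) = MiniDual(a.val * b.val, a.der * b.val + a.val * b.der)  # product rule
-(a::MiniDual) = MiniDual(-a.val, -a.der)
exp(a::MiniDual) = MiniDual(exp(a.val), exp(a.val) * a.der)                          # chain rule

# Plain real numbers are constants: they carry no derivative.
*(a::MiniDual, b::Real) = MiniDual(a.val * b, a.der * b)
*(a::Real, b::MiniDual) = b * a

# Hypothetical stand-in for the thread's rs: scalar x in, vector out.
const d0 = collect(range(1.0, 3.0; length = 20))
rs(x, d) = exp.(-x .* d) .* d0

d = collect(range(0.1, 2.0; length = 20))
x = 0.5

# A single forward evaluation with x seeded as MiniDual(x, 1.0)
# carries d(rs)/dx for all 20 outputs at once.
out  = rs(MiniDual(x, 1.0), d)
vals = [o.val for o in out]   # ordinary function values
ders = [o.der for o in out]   # derivative of each output w.r.t. x
```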
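For many outputs, Zygote needs a whole reverse pass per output element. So the completely ideal expectation would be that drs_auto and drs_auto_v2 are 20 times slower than drs_auto_v3. In addition, reverse mode is inherently more complicated, which could well account for the remaining factor of about 5 (Zygote is roughly 100 times slower than ForwardDiff overall: more than 50 times slower than the numerical derivative, which ForwardDiff beats by a factor of 2).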
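A sketch of where the factor of 20 comes from, assuming the same kind of hypothetical stand-in rs with 20 outputs: building the Jacobian from Zygote.pullback requires seeding the pullback with one basis vector per output element, which is the per-element cost the Zygote.jacobian docstring describes. reverse_jacobian is a made-up helper for illustration; ForwardDiff.derivative gets the same vector from a single forward pass.

```julia
using Zygote, ForwardDiff

# Hypothetical stand-in for the thread's rs: scalar x in, 20-element vector out.
const d0 = collect(range(1.0, 3.0; length = 20))
rs(x, d) = exp.(-x .* d) .* d0

# Reverse mode: one forward pass to record the pullback, then one full
# reverse pass per output element, each seeded with the i-th basis vector.
function reverse_jacobian(f, x::Real)
    y, back = Zygote.pullback(f, x)
    J = similar(y)
    npasses = 0
    for i in eachindex(y)
        seed = zero(y)
        seed[i] = 1
        J[i] = only(back(seed))   # gives ∂y[i]/∂x
        npasses += 1
    end
    return J, npasses
end

d = collect(range(0.1, 2.0; length = 20))
x = 0.5
f = x -> rs(x, d)

J_rev, npasses = reverse_jacobian(f, x)   # one reverse pass per output
J_fwd = ForwardDiff.derivative(f, x)      # a single forward pass
```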
drs_auto_v2 won’t save much: Zygote will still work backwards through most of the function. From the docstring:
```
help?> Zygote.jacobian
jacobian(f, args...) -> Tuple
...
This reverse-mode Jacobian needs to evaluate the pullback once for each element of y. Doing so
is usually only efficient when length(y) is small compared to length(a), otherwise forward mode
is likely to be better.
```