Why can ForwardDiff get the gradient with respect to multiple parameters in one function call?

I wonder how ForwardDiff.jl can compute the gradient with respect to multiple parameters of a function my_func(a,b,c,d,...) in one function evaluation. According to Giles, Michael B. “Monte Carlo evaluation of sensitivities in computational finance.” (2007), p. 9 link:

If one is interested in the sensitivity to $N_I$ different input elements, then $\dot{\mathbf{u}}^{N} = D^{N} D^{N-1} \cdots D^{2} D^{1} \, \dot{\mathbf{u}}^{0}$ must be evaluated for each one, at a cost which is proportional to $N_I$.

Is there something I am not getting here?

ForwardDiff propagates sensitivities with respect to all the parameters, so while it’s technically a single function call, inside that one call it does the equivalent work of many function calls. To compute gradients in about the same time as a single function call (in theory) for functions with scalar output, use ReverseDiff.
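To make that concrete, here is a minimal sketch of both calls, with a hypothetical three-parameter function standing in for my_func (the gradient APIs are ForwardDiff.gradient and ReverseDiff.gradient):

```julia
using ForwardDiff, ReverseDiff

# Hypothetical 3-parameter, scalar-output stand-in for my_func
my_func(p) = p[1]^2 * sin(p[2]) + exp(p[3])

p = [1.0, 2.0, 3.0]

# The function body executes once, but every operation inside also
# propagates partials with respect to all three inputs (via dual numbers).
ForwardDiff.gradient(my_func, p)

# ReverseDiff records the operations on a tape and replays them backwards,
# producing the whole gradient from a single reverse sweep.
ReverseDiff.gradient(my_func, p)
```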

There are many resources on how forward and reverse differentiation work, but I was confused about it for a long time before I found a way of thinking about it that spoke to me: for a function f with Jacobian Jf computed as the composition of many functions (say f = f1 ∘ f2 for simplicity), forward-mode differentiation is a way of efficiently computing Jf*delta_x for any given delta_x. By the chain rule, Jf*delta_x = Jf1*(Jf2*delta_x), and so all you have to do is propagate the sensitivities forward, taking about the same time as calling f. For a function R^n → R, though, you have to do this for n different delta_x (one per basis vector) to build the gradient.

Reverse-mode differentiation is instead a way of computing Jf^T * delta_f for any given delta_f, again via the chain rule, but this time transposed: Jf^T*delta_f = Jf2^T*(Jf1^T * delta_f). Now it’s enough to do it only once to compute a gradient, which is why it’s much more efficient for gradients. However, because the transpose changes the order of application of the inner functions, you have to go through the computation tree in reverse, which can cause overhead.
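A small sketch of the forward-mode (Jacobian-vector product) view, using a toy function. Note this is only the conceptual picture: ForwardDiff itself carries all n partials at once inside its dual numbers rather than making n separate passes.

```julia
using ForwardDiff

f(x) = sin(x[1]) * x[2] + exp(x[3])   # toy R^3 -> R function
x = [1.0, 2.0, 3.0]

# Forward mode computes the Jacobian-vector product Jf*v (a directional
# derivative) in roughly one function-call's worth of work:
jvp(f, x, v) = ForwardDiff.derivative(t -> f(x .+ t .* v), 0.0)

# Building the full gradient of an R^n -> R function from JVPs needs one
# pass per basis vector, i.e. n passes in total:
basis(i, n) = (e = zeros(n); e[i] = 1.0; e)
grad_from_jvps = [jvp(f, x, basis(i, 3)) for i in 1:3]

grad_from_jvps ≈ ForwardDiff.gradient(f, x)   # same result, n forward passes
```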


Solid answer, thanks! I just tried some simple printing tests to see how many times my function was called; everything was only printed once, so I thought there was only one function call. But apparently the extra work happens implicitly inside that single call.

My function has a scalar output and the input is only three parameters. However, the output is the result of a Monte Carlo simulation involving a lot of operations. Given that one has to store and traverse the computation tree in reverse, is it likely that reverse mode will give any speed-up in this case?

ForwardDiff only calls your function once, but for each operation within your function [e.g., sin(x)] it also calculates the derivative [cos(x)*partial]. So your function runs once, but every operation inside also evaluates its derivative and propagates the partials.
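A toy sketch of that dual-number idea (not ForwardDiff's actual implementation, which carries several partials at once and covers far more operations):

```julia
# Minimal dual number: a value together with one derivative (sensitivity).
struct Dual
    val::Float64
    der::Float64
end

# Each primitive operation also propagates the derivative via its rule:
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)   # chain rule: cos(x)*x'

# Differentiate g(x) = x * sin(x) at x = 2 by seeding der = 1:
g(x) = x * sin(x)
d = g(Dual(2.0, 1.0))
d.val   # g(2)
d.der   # g'(2) = sin(2) + 2*cos(2)
```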

You can see a list of many of the rules ForwardDiff uses to calculate derivatives here:
https://github.com/JuliaDiff/DiffBase.jl/blob/master/src/rules.jl

No, 3 parameters is likely not enough to make ReverseDiff faster than ForwardDiff (but the only way to be sure is to try both and compare them).
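If you do want to compare them, a rough benchmarking sketch could look like the following, assuming BenchmarkTools is available (the loss function here is just a made-up stand-in for the real Monte Carlo objective):

```julia
using ForwardDiff, ReverseDiff, BenchmarkTools

# Made-up 3-parameter, scalar-output objective for illustration
loss(p) = p[1]^2 + p[2]^2 + sin(p[1]) * p[2] * p[3]
p = [0.1, 0.2, 0.3]

@btime ForwardDiff.gradient($loss, $p)
@btime ReverseDiff.gradient($loss, $p)
```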

I see, that is what I thought. I tried ReverseDiff, but I got errors saying that my function did not support TrackedArray as input…, although ForwardDiff works fine.
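(For reference, that error usually means the method signature, or allocations inside the function, are restricted to Float64, so the TrackedArray/TrackedReal elements ReverseDiff passes in cannot flow through. A hedged sketch of the difference, with made-up simulate_bad/simulate_ok names:)

```julia
using ReverseDiff

# Fails under ReverseDiff: the signature only accepts Vector{Float64},
# so the TrackedArray that ReverseDiff passes in hits a MethodError.
simulate_bad(p::Vector{Float64}) = sum(sin(p[1]) + p[2] * i + p[3] for i in 1:10)

# Works: accept any AbstractVector and let the tracked element type flow through.
simulate_ok(p::AbstractVector) = sum(sin(p[1]) + p[2] * i + p[3] for i in 1:10)

ReverseDiff.gradient(simulate_ok, [0.1, 0.2, 0.3])    # fine
# ReverseDiff.gradient(simulate_bad, [0.1, 0.2, 0.3]) # MethodError
```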