I don’t think you need to vectorize, and for this example the norm(x) makes a huge difference. Running your original example on my machine for N=200, I got 156 ms for ForwardDiff and a slower 782 ms for ReverseDiff. This seems consistent with your experience.
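(For reference, what I'm calling the "original example" is, as I understand it from earlier in the thread, roughly the following; this is my reconstruction, so details may differ:)
function ff(x)
    res = 0.0
    for i in 1:100
        for j in eachindex(x)
            res += sin(exp(x[j]) - norm(x))  # norm(x) recomputed on every inner iteration
        end
    end
    atan(res)
end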
Like @roflmaostc I moved the norm out of the outer loop, but kept the inner loop:
function ff2(x)
    res = 0.0
    nmx = norm(x)
    for i in 1:100
        for j in eachindex(x)
            res += sin(exp(x[j]) - nmx)
        end
    end
    atan(res)
end
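For completeness, the setup I'm assuming for the benchmarks below (carried over from earlier in the thread, so the exact values of y and res are my guess):
using LinearAlgebra, ForwardDiff, ReverseDiff, BenchmarkTools
y = randn(200)     # N = 200 input vector
res = similar(y)   # preallocated gradient buffer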
@btime ForwardDiff.gradient!($res, $ff2, $y); # 9.764 ms
ftape2 = ReverseDiff.GradientTape(ff2, y)
compiled_ftape2 = ReverseDiff.compile(ftape2)
@btime ReverseDiff.gradient!($res, $compiled_ftape2, $y) # 15.827 ms
Now forward runs 16x faster and reverse 49x faster than before (though reverse is still ~60% slower than the new forward), all just because of the norm. I did not find @roflmaostc's ff_vec2 to make much of a difference; the timings were about comparable.
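For readers of the thread, a vectorized variant along these lines is what's meant (my own sketch here, not necessarily @roflmaostc's exact code):
function ff_vec2(x)
    nmx = norm(x)
    # broadcast once over x; the body doesn't depend on i, so the outer loop is just a factor of 100
    res = 100 * sum(sin.(exp.(x) .- nmx))
    atan(res)
end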
I am very inexperienced with this, but here's my impression. This example is quite simple, so the norm clearly outweighs the other computations. The loop body is executed 100N times, but only the norm itself has N terms, so with the norm inside the loop it dominates for large N. That single operation is trivial to differentiate whether forward or reverse, though reverse mode carries some overhead. Also, sin and exp are trivial to differentiate.
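Back-of-the-envelope, for N = 200 (just counting the dominant terms):
N = 200
sin_exp_calls    = 100 * N       # 20_000 sin/exp evaluations either way
norm_work_hoisted = N            # one N-term norm, computed once
norm_work_inner   = 100 * N * N  # an N-term norm on every iteration: 4_000_000 operations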
So I don’t think the issue is vectorization, and it’s up to you whether the norm here is representative of your real problem. But I believe backpropagation gains the most when you have a lot of graph nodes with opportunities to re-use pieces of the forward pass, or gradients in the backward pass. Maybe this example is too trivial to get good re-use relative to the overhead, which was not so bad (~60% worse performance). I don’t know enough to comment on potential gains from things like @simd, @inbounds, or other optimizations.
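For what it's worth, the kind of thing I mean is the following (untested sketch, same ff2 with the usual loop annotations; the generic zero keeps the accumulator the same type as the input, which matters for ForwardDiff's Dual numbers):
function ff2_simd(x)
    res = zero(eltype(x))
    nmx = norm(x)
    for i in 1:100
        @inbounds @simd for j in eachindex(x)
            res += sin(exp(x[j]) - nmx)
        end
    end
    atan(res)
end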
[EDIT:] I take it back about ff_vec2. I ran it again with N=200 and got 13.2 ms forwards and 1.47 ms backwards, so it seems reverse mode wins with vectorization, the norm computed once, and larger N. I don’t know why my ff2 is faster than ff_vec2 forwards, but backwards, vectorization seems to be a big help.
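In case it helps to reproduce, I timed ff_vec2 the same way as ff2 (assuming a definition like the sketch above):
@btime ForwardDiff.gradient!($res, $ff_vec2, $y);
ftape_vec = ReverseDiff.GradientTape(ff_vec2, y)
compiled_ftape_vec = ReverseDiff.compile(ftape_vec)
@btime ReverseDiff.gradient!($res, $compiled_ftape_vec, $y);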