I tried to use Zygote.jl to define a custom adjoint for Flux.logitbinarycrossentropy, but I don’t get the expected speed-up when computing the gradient. The context is that the XGBoost tree boosting algorithm requires taking the derivative of the same function at many different points. Here is what I tried:
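using Flux: logitbinarycrossentropy, gpu
using Zygote               # brings the module name into scope for Zygote.gradient
using Zygote: @adjoint
using ForwardDiff          # needed for ForwardDiff.gradient and ForwardDiff.derivative
using BenchmarkTools       # provides @benchmark
using CuArrays
CuArrays.allowscalar(false)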
prevw = rand(150000);
target = rand([0, 1.], length(prevw));
# custom adjoint: d/dw = sigmoid(w) - t, d/dt = -w
@adjoint logitbinarycrossentropy(w, t) = logitbinarycrossentropy(w, t), delta -> (delta * (1/(1+exp(-w)) - t), -delta * w)
g(prevw, target) = Zygote.gradient(prevw->logitbinarycrossentropy(prevw, target), prevw)[1]
g2(prevw, target) = Zygote.gradient(logitbinarycrossentropy, prevw, target)[1]
g3(prevw, target) = ForwardDiff.gradient(x->logitbinarycrossentropy(x[1], target), [prevw])[1]
g4(prevw, target) = ForwardDiff.derivative(prevw->logitbinarycrossentropy(prevw, target), prevw)
easy(w, t) = 1/(1+exp(-w)) - t
@benchmark g.($prevw, $target)
@benchmark g2.($prevw, $target)
@benchmark g3.($prevw, $target)
@benchmark g4.($prevw, $target)
@benchmark easy.($prevw, $target)
# do it on the GPU
gtarget = gpu(target);
gprevw = gpu(prevw);
@benchmark easy.($gprevw, $gtarget)
As can be seen, I have defined the adjoint of logitbinarycrossentropy, so I thought g2 would be quite fast, perhaps approaching the speed of easy, which codes the derivative directly. But g2 is slower than g. Of course the GPU version of easy is the fastest, and I can’t get any of the g* functions to work on the GPU.

How do I use Zygote.@adjoint more effectively?
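For reference, here is where the closed form in easy comes from. This is a sketch that assumes the scalar loss is defined as (1 - t)*w - logσ(w), which I believe is how Flux writes it; the symbols ℓ and σ are my own notation, not Flux's.

```latex
\begin{aligned}
\ell(w, t) &= (1 - t)\,w - \log \sigma(w) = (1 - t)\,w + \log\!\left(1 + e^{-w}\right),
  \qquad \sigma(w) = \frac{1}{1 + e^{-w}} \\
\frac{\partial \ell}{\partial w} &= (1 - t) - \frac{e^{-w}}{1 + e^{-w}}
  = (1 - t) - \bigl(1 - \sigma(w)\bigr) = \sigma(w) - t \\
\frac{\partial \ell}{\partial t} &= -w
\end{aligned}
```

Under that assumption, a pullback for the pair (w, t) with incoming sensitivity delta would return (delta * (σ(w) - t), -delta * w).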
julia> @benchmark g.($prevw, $target)
BenchmarkTools.Trial:
memory estimate: 53.79 MiB
allocs estimate: 2400002
--------------
minimum time: 123.166 ms (4.75% GC)
median time: 128.427 ms (4.79% GC)
mean time: 129.541 ms (5.51% GC)
maximum time: 142.943 ms (4.94% GC)
--------------
samples: 39
evals/sample: 1
julia> @benchmark g2.($prevw, $target)
BenchmarkTools.Trial:
memory estimate: 81.25 MiB
allocs estimate: 3300002
--------------
minimum time: 201.589 ms (4.15% GC)
median time: 208.014 ms (5.38% GC)
mean time: 212.042 ms (5.00% GC)
maximum time: 234.661 ms (4.78% GC)
--------------
samples: 24
evals/sample: 1
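Independently of the timings, the closed form used in easy can be sanity-checked without any AD package. The sketch below is only an illustration: loss is a hand-written scalar version of the logit binary cross-entropy (assumed to match Flux's definition, not imported from it), and fd is a hypothetical helper that takes a central finite difference in w.

```julia
# Scalar logit binary cross-entropy, written out by hand.
# Assumption: this matches Flux's scalar definition; it is not imported from Flux.
loss(w, t) = (1 - t) * w + log1p(exp(-w))

# Closed-form derivative with respect to w, same expression as `easy` in the post.
easy(w, t) = 1 / (1 + exp(-w)) - t

# Central finite difference in w (hypothetical helper, used only as an independent check).
fd(w, t; h = 1e-6) = (loss(w + h, t) - loss(w - h, t)) / (2h)

# Compare the two on a small grid of logits and both target values.
ws = range(-3, stop = 3, length = 13)
ts = [0.0, 1.0]
maxerr = maximum(abs(easy(w, t) - fd(w, t)) for w in ws, t in ts)
println("max abs error: ", maxerr)
```

If the reported error is small (close to the finite-difference noise floor), the closed form agrees with the loss, so any slowdown in g and g2 comes from how the gradient is computed rather than from the formula itself.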