Custom Zygote.jl adjoint not giving the expected speed-up, and how to move it to the GPU?

I tried to use Zygote.jl to define a custom adjoint for Flux.logitbinarycrossentropy, but I don't get the expected speed-up when computing the gradient. The context is that the XGBoost tree-boosting algorithm requires taking the derivative of the same function at many different points.
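For reference, XGBoost's second-order boosting needs, for every training point, the first and second derivative of the loss with respect to the current raw score w. For the logistic loss both have closed forms, so they can be computed for all points in one broadcast. A minimal sketch in plain Base Julia, assuming the loss (1 - t)*w - log(σ(w)); the helpers `sigmoid` and `grad_hess` are hypothetical names, not Flux or XGBoost API:

```julia
# Logistic sigmoid σ(w) = 1 / (1 + exp(-w)).
sigmoid(w) = 1 / (1 + exp(-w))

# Per-point gradient and hessian of the logistic loss (1 - t)*w - log(σ(w))
# with respect to the raw score w: the two quantities second-order boosting uses.
function grad_hess(w::AbstractVector, t::AbstractVector)
    p = sigmoid.(w)        # predicted probabilities
    grad = p .- t          # first derivative:  σ(w) - t
    hess = p .* (1 .- p)   # second derivative: σ(w) * (1 - σ(w))
    return grad, hess
end

grad, hess = grad_hess([0.0, 2.0], [1.0, 0.0])
```

Computing both vectors at once avoids calling an AD system per point, which is where the per-call overhead in the benchmarks below comes from.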

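using Flux: logitbinarycrossentropy, gpu
import Zygote              # binds the name Zygote so that Zygote.gradient resolves
using Zygote: @adjoint
import ForwardDiff         # binds ForwardDiff for ForwardDiff.gradient / ForwardDiff.derivative
using BenchmarkTools       # provides @benchmark
using CuArrays
CuArrays.allowscalar(false)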


prevw = rand(150000);
target = rand([0, 1.], length(prevw));


# Pullback of the scalar loss (1 - t)*w - logσ(w): ∂/∂w = σ(w) - t, ∂/∂t = -w
@adjoint logitbinarycrossentropy(w, t) = logitbinarycrossentropy(w, t), delta -> (delta*(1/(1+exp(-w)) - t), -delta*w)

g(prevw, target) = Zygote.gradient(prevw->logitbinarycrossentropy(prevw, target), prevw)[1]
g2(prevw, target) = Zygote.gradient(logitbinarycrossentropy, prevw, target)[1]
g3(prevw, target) = ForwardDiff.gradient(x->logitbinarycrossentropy(x[1], target), [prevw])[1]
g4(prevw, target) = ForwardDiff.derivative(prevw->logitbinarycrossentropy(prevw, target), prevw)
easy(w, t) = 1/(1+exp(-w)) - t

@benchmark g.($prevw, $target)
@benchmark g2.($prevw, $target)
@benchmark g3.($prevw, $target)
@benchmark g4.($prevw, $target)
@benchmark easy.($prevw, $target)

# do it on the GPU
gtarget = gpu(target);
gprevw = gpu(prevw);
@benchmark easy.($gprevw, $gtarget)
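For reference, assuming Flux's scalar definition of the loss at the time, ℓ(w, t) = (1 - t)·w - log σ(w), the derivatives that the pullback in the @adjoint should return are:

```latex
\sigma(w) = \frac{1}{1+e^{-w}}, \qquad \frac{d}{dw}\log\sigma(w) = 1 - \sigma(w)

\frac{\partial \ell}{\partial w} = (1 - t) - \bigl(1 - \sigma(w)\bigr) = \sigma(w) - t,
\qquad
\frac{\partial \ell}{\partial t} = -w
```

So the w-component is exactly what easy computes, and the t-component of the pullback is -delta*w.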

As can be seen, I have defined the adjoint of logitbinarycrossentropy, so I expected g2 to be quite fast, perhaps approaching the speed of easy, which codes the derivative directly.

However, g2 is actually slower than g. Unsurprisingly, the GPU version of easy is the fastest, but I can't get any of the g* functions to run on the GPU.
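One way around the GPU problem is to skip per-element gradient calls altogether and express the derivative as a single fused broadcast over whole arrays, which is what easy. already does. A minimal sketch; `lbce_grad` and `lbce_grad!` are hypothetical names. Because the functions use only elementwise broadcasting and no scalar indexing, they should also accept CuArray inputs with allowscalar(false), although the snippet below only exercises plain Arrays:

```julia
# Out-of-place: one fused broadcast over whole arrays, no scalar indexing.
lbce_grad(w::AbstractArray, t::AbstractArray) = @. 1 / (1 + exp(-w)) - t

# In-place: writes into a preallocated `out`, which can be reused across
# boosting rounds to avoid allocating a new gradient array each time.
function lbce_grad!(out::AbstractArray, w::AbstractArray, t::AbstractArray)
    @. out = 1 / (1 + exp(-w)) - t
    return out
end

w = [0.0, 0.0]
t = [0.0, 1.0]
d = lbce_grad(w, t)
out = lbce_grad!(similar(w), w, t)
```

On the GPU the analogous call would be something like lbce_grad!(similar(gprevw), gprevw, gtarget) (untested here).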

How can I use Zygote.@adjoint more effectively?
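Each Zygote.gradient call inside g.(...) builds and runs a fresh pullback for a single scalar, so that fixed overhead is paid once per element. A likely better use of @adjoint is to define it once at the array level, for the loss summed over all points, so that a single pullback returns the whole gradient vector. The sketch below is Base Julia only: `lbce_sum` and `lbce_sum_pullback` are hypothetical names, and the function returns the same (value, Δ -> (∂w, ∂t)) pair that an @adjoint definition returns, checked against a central finite difference:

```julia
sigmoid(w) = 1 / (1 + exp(-w))

# Summed logistic loss Σᵢ (1 - tᵢ)*wᵢ - log(σ(wᵢ)), using -log(σ(w)) = log1p(exp(-w)).
# Not overflow-safe for very negative w; good enough for a sketch.
lbce_sum(w, t) = sum(@. (1 - t) * w + log1p(exp(-w)))

# Forward value plus pullback closure, the (value, Δ -> (∂w, ∂t)) shape that a
# Zygote @adjoint body returns. One call covers every element of w and t.
function lbce_sum_pullback(w, t)
    p = sigmoid.(w)
    back(Δ) = (Δ .* (p .- t), -Δ .* w)
    return lbce_sum(w, t), back
end

w = [0.3, -1.2, 2.0]
t = [1.0, 0.0, 1.0]
val, pb = lbce_sum_pullback(w, t)
dw, dt = pb(1.0)

# Central finite difference on the first coordinate as a sanity check.
h = 1e-6
e1 = [1.0, 0.0, 0.0]
fd = (lbce_sum(w .+ h .* e1, t) - lbce_sum(w .- h .* e1, t)) / (2h)
```

If the check passes, the same body should be attachable to Zygote with something like Zygote.@adjoint lbce_sum(w, t) = lbce_sum_pullback(w, t), after which one Zygote.gradient(w -> lbce_sum(w, target), prevw)[1] call would return all 150,000 derivatives (untested here).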

julia> @benchmark g.($prevw, $target)
BenchmarkTools.Trial:
  memory estimate:  53.79 MiB
  allocs estimate:  2400002
  --------------
  minimum time:     123.166 ms (4.75% GC)
  median time:      128.427 ms (4.79% GC)
  mean time:        129.541 ms (5.51% GC)
  maximum time:     142.943 ms (4.94% GC)
  --------------
  samples:          39
  evals/sample:     1

julia> @benchmark g2.($prevw, $target)
BenchmarkTools.Trial:
  memory estimate:  81.25 MiB
  allocs estimate:  3300002
  --------------
  minimum time:     201.589 ms (4.15% GC)
  median time:      208.014 ms (5.38% GC)
  mean time:        212.042 ms (5.00% GC)
  maximum time:     234.661 ms (4.78% GC)
  --------------
  samples:          24
  evals/sample:     1
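Dividing the benchmark totals above by the 150,000 elements puts the per-call cost in perspective. This is plain arithmetic on the reported numbers; attributing the trailing 2 allocations to the broadcast's output array is an assumption:

```julia
n = 150_000   # length(prevw) in the post

# Allocations per element; the trailing 2 is assumed to belong to the
# broadcast's output array rather than to each gradient call.
allocs_g  = (2_400_002 - 2) / n
allocs_g2 = (3_300_002 - 2) / n

# Minimum time per element, in nanoseconds.
ns_g  = 123.166e6 / n
ns_g2 = 201.589e6 / n
```

That works out to roughly 16 versus 22 heap allocations and about 0.82 µs versus 1.34 µs per scalar gradient call for g and g2 respectively.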