`Zygote.gradient` is 54000 TIMES slower than `jax.gradient`

I haven’t checked this carefully, but I think that re-writing this so that it takes no slices at all can get us to about 2 ms:

julia> function loss(
               Wmid, Wctx,
               tok_mid, tok_ctx, x
       )
       tmp = sum(Wmid[tok_mid, :] .* Wctx[tok_ctx, :]; dims=2) |> vec
       -mean(@. x * log(logistic_sigmoid(tmp)) + (1 - x) * log(1 - logistic_sigmoid(tmp)))
       end
loss (generic function with 1 method)

julia> grad = @btime train(Random.Xoshiro(5), 100277, 100, 2);
  1.892 ms (274 allocations: 4.62 MiB)
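(`train` is defined earlier in the thread; presumably it wraps a `Zygote.gradient` call over this `loss`. For anyone reproducing the timing in isolation, a minimal stand-alone sketch of that call looks roughly like the following, with toy sizes and a placeholder `logistic_sigmoid` in place of the originals:)

julia> using Zygote, Statistics

julia> logistic_sigmoid(z) = 1 / (1 + exp(-z));  # stand-in for the definition used earlier in the thread

julia> Wmid, Wctx = randn(10, 4), randn(10, 4);  # toy sizes, not the benchmark's 100277×100

julia> tok_mid, tok_ctx, x = rand(1:10, 6), rand(1:10, 6), float.(rand(0:1, 6));

julia> gWmid, gWctx = Zygote.gradient((a, b) -> loss(a, b, tok_mid, tok_ctx, x), Wmid, Wctx);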

julia> using Tullio  # first way I thought of

julia> function loss(
               Wmid, Wctx,
               tok_mid, tok_ctx, x
       )
       @tullio tmp[k] := Wmid[tok_mid[k], c] * Wctx[tok_ctx[k], c]  # sum over c
       -mean(@. x * log(logistic_sigmoid(tmp)) + (1 - x) * log(1 - logistic_sigmoid(tmp)))  # sum over k
       end
loss (generic function with 1 method)

julia> grad = @btime train(Random.Xoshiro(5), 100277, 100, 2);
  1.916 ms (348 allocations: 4.62 MiB)
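(Quick consistency check, continuing the toy session above: the `@tullio` contraction should match the slice-and-sum version elementwise.)

julia> @tullio tmp[k] := Wmid[tok_mid[k], c] * Wctx[tok_ctx[k], c];

julia> tmp ≈ vec(sum(Wmid[tok_mid, :] .* Wctx[tok_ctx, :]; dims=2))
true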

All of these definitions of `loss` give me zero gradient, so there could be mistakes.
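(One crude way to check, continuing the toy session above: perturb one entry of `Wmid` that the loss actually touches and compare a finite difference against Zygote's answer. The step size and tolerance here are illustrative.)

julia> i = tok_mid[1];  # a row the loss actually reads

julia> ε = 1e-6; δ = zero(Wmid); δ[i, 1] = ε;

julia> fd = (loss(Wmid + δ, Wctx, tok_mid, tok_ctx, x) - loss(Wmid, Wctx, tok_mid, tok_ctx, x)) / ε;

julia> isapprox(fd, gWmid[i, 1]; rtol=1e-4)  # loose tolerance for a one-sided finite difference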
