`Zygote.gradient` is 54,000 times slower than `jax.grad`

With Reactant and Enzyme it would look like this:

```julia
using BenchmarkTools, Enzyme, Statistics, Reactant, Random

Reactant.set_default_backend("cpu")

@fastmath logistic_sigmoid(x) = 1.0f0 / (1 + exp(-x))
# Vectorized loss function
function loss(
    Wmid, Wctx,
    nvocab, ndim,
    tok_mid, tok_ctx, x
)
    Wmid2 = reshape(Wmid, nvocab, ndim)
    Wctx2 = reshape(Wctx, nvocab, ndim)
    Wmid_selected = @view(Wmid2[tok_mid, :])
    Wctx_selected = @view(Wctx2[tok_ctx, :])
    dotprod = sum(Wmid_selected .* Wctx_selected, dims=2)
    p = logistic_sigmoid.(dotprod)
    # Mean Bernoulli log-likelihood (not negated, despite the name)
    nll = mean(x .* log.(p) .+ (1 .- x) .* log.(1 .- p))
    return nll
end

function get_grad(Wmid, Wctx, nvocab, ndim, tok_mid, tok_ctx, x)
    dWmid = Enzyme.make_zero(Wmid)
    dWctx = Enzyme.make_zero(Wctx)
    autodiff(
        Enzyme.Reverse, loss,
        Duplicated(Wmid, dWmid), DuplicatedNoNeed(Wctx, dWctx),
        Const(nvocab), Const(ndim), Const(tok_mid), Const(tok_ctx), Const(x),
    )
    return dWmid  # only the gradient w.r.t. Wmid is returned
end
dev = Reactant.to_rarray
ndim = 128
nsamples = 1000
nvocab = 100277
rng = Random.Xoshiro(5)
ntrues = nsamples ÷ 2
tok_mid = rand(rng, 1:nvocab, nsamples)
tok_ctx = rand(rng, 1:nvocab, nsamples) 
x = ([trues(ntrues); falses(nsamples - ntrues)] .|> Float32)  |> dev
weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim) |> dev
weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim) |> dev
L = @compile loss(weights_mid, weights_ctx, nvocab, ndim, tok_mid, tok_ctx, x)
@info L(weights_mid, weights_ctx, nvocab, ndim, tok_mid, tok_ctx, x)  # works fine

g = @compile get_grad(weights_mid, weights_ctx, nvocab, ndim, tok_mid, tok_ctx, x)  # errors
@info g(weights_mid, weights_ctx, nvocab, ndim, tok_mid, tok_ctx, x)
@btime g($weights_mid, $weights_ctx, $nvocab, $ndim, $tok_mid, $tok_ctx, $x)
```

but compiling the gradient errors with:

```
error: could not compute the adjoint for this operation %8 = "stablehlo.gather"(%arg1, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<12835456xf32>, tensor<128000x1xi64>) -> tensor<128000x1xf32>
error: could not compute the adjoint for this operation %6 = "stablehlo.gather"(%arg0, %5) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<12835456xf32>, tensor<128000x1xi64>) -> tensor<128000x1xf32>
ERROR: "failed to run pass manager on module"
```

I guess something is missing in Reactant to do that? It compiles the loss but not the gradient. Maybe @wsmoses can help? I made an MWE and filed an issue: Compiling gradient of slicing function fails · Issue #676 · EnzymeAD/Reactant.jl · GitHub
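
In the meantime, a possible workaround (untested sketch; `loss_onehot` and the selection matrices are my own additions, not part of the MWE) is to replace the gather-producing indexing with a dense one-hot matrix multiply, so the traced graph only contains ops whose adjoints already exist:

```julia
# Sketch of a gather-free variant: select rows by multiplying with one-hot
# matrices instead of indexing. Wasteful (each selection matrix is
# nsamples × nvocab, roughly 400 MB of Float32 at these sizes), but it
# keeps the reverse pass on plain matmuls and broadcasts.
function loss_onehot(Wmid, Wctx, nvocab, ndim, tok_mid, tok_ctx, x)
    Wmid2 = reshape(Wmid, nvocab, ndim)
    Wctx2 = reshape(Wctx, nvocab, ndim)
    Smid = Float32.(tok_mid .== (1:nvocab)')   # nsamples × nvocab one-hot rows
    Sctx = Float32.(tok_ctx .== (1:nvocab)')
    Wmid_selected = Smid * Wmid2               # nsamples × ndim
    Wctx_selected = Sctx * Wctx2
    dotprod = sum(Wmid_selected .* Wctx_selected, dims=2)
    p = logistic_sigmoid.(dotprod)
    return mean(x .* log.(p) .+ (1 .- x) .* log.(1 .- p))
end
```

The selection matrices could also be built once outside the loss and passed in as `Const` arguments, which at least moves the big allocations out of the differentiated code; I haven't checked how well Reactant traces the `.==` broadcast, so treat this as a starting point rather than a fix.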