`Zygote.gradient` is 54000 times slower than `jax.grad`

I'm down to 1.5 ms. Enzyme seems to have issues with `reshape` on Julia 1.11 (possibly on 1.10 too; I didn't try), so I used Mooncake instead.

using Statistics; import Random
using DifferentiationInterface,Tullio,LinearAlgebra,BenchmarkTools
import Zygote, Enzyme, Mooncake

# Shorthand aliases: `AV{T}` is `AbstractVector{T}` and `AM{T}` is `AbstractMatrix{T}`.
const AV = AbstractVector{T} where {T}
const AM = AbstractMatrix{T} where {T}

"""
    logistic_sigmoid(x)

Evaluate the logistic function `1 / (1 + exp(-x))` with `@fastmath` semantics
(IEEE guarantees around Inf/NaN are relaxed).

The numerator is the literal `1.0f0`, so the result is promoted to at least `Float32`:
`Float16` input yields `Float32`, `Float64` input yields `Float64`.
"""
function logistic_sigmoid(x)
    denominator = @fastmath 1 + exp(-x)
    return @fastmath 1.0f0 / denominator
end

"""
    loss(Wmid, Wctx, nvocab, ndim, tok_mid, tok_ctx, x)

Return the mean binary cross-entropy (negative log-likelihood) of the labels `x`
under a logistic model. The logit of sample `c` is the dot product of row
`tok_mid[c]` of `Wmid` with row `tok_ctx[c]` of `Wctx`.

# Arguments
- `Wmid`, `Wctx`: flat weight vectors of length `nvocab * ndim`. Each is reshaped
  (column-major) into an `nvocab × ndim` matrix.
- `tok_mid`, `tok_ctx`: row indices into the reshaped matrices, one per sample.
- `x`: 0/1 (or `Bool`) target for each sample.

# Returns
The mean over samples of `-[x*log σ(z) + (1-x)*log(1-σ(z))]`, a quantity to minimise.
Returns `NaN` when `x` is empty.

# Throws
- `DimensionMismatch` if `x`, `tok_mid` and `tok_ctx` do not share the same indices.
- `BoundsError` if any token lies outside `1:nvocab`.
"""
function loss(
    Wmid, Wctx,
    nvocab, ndim,
    tok_mid, tok_ctx, x,
)
    # New names instead of rebinding the arguments, so no variable changes type.
    Emid = reshape(Wmid, nvocab, ndim)
    Ectx = reshape(Wctx, nvocab, ndim)

    # The hot loop runs under @inbounds, so validate every index it touches up front:
    # the per-sample arrays must line up, and every token must be a valid row.
    samples = eachindex(x, tok_mid, tok_ctx)
    checkbounds(Emid, tok_mid, :)
    checkbounds(Ectx, tok_ctx, :)

    # Accumulate in the weights' float type instead of a hard-coded Float32.
    nll = zero(float(promote_type(eltype(Emid), eltype(Ectx))))
    @inbounds @fastmath for c in samples
        z = dot(@view(Emid[tok_mid[c], :]), @view(Ectx[tok_ctx[c], :]))
        # -[x*log σ(z) + (1-x)*log(1-σ(z))] == softplus(z) - x*z.
        # softplus(z) = max(z, 0) + log1p(exp(-|z|)) never overflows, whereas
        # log(1 - σ(z)) hits log(0) = -Inf once σ(z) rounds to 1.
        nll += max(z, zero(z)) + log1p(exp(-abs(z))) - x[c] * z
    end
    return nll / length(x)
end

"""
    train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer, backend)

Build a synthetic batch and return the gradient of `loss` with respect to
`weights_mid`, computed with the DifferentiationInterface `backend`.

Draws `nsamples` random `(tok_mid, tok_ctx)` token pairs from `1:nvocab`. The first
`nsamples ÷ 2` samples are labelled `true` and the rest `false`. Both weight vectors
are `Float32` of length `nvocab * ndim`, with entries `0.1f0 * randn`. Every argument
of `loss` other than `weights_mid` is passed as a `Constant`, so it is not
differentiated.

Despite its name, this performs no parameter update; it computes a single gradient.
"""
function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer, backend)
    ntrues = nsamples ÷ 2

    tok_mid = rand(rng, 1:nvocab, nsamples)
    tok_ctx = rand(rng, 1:nvocab, nsamples)
    x = [trues(ntrues); falses(nsamples - ntrues)]

    weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
    weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)

    # `zero` gives the same all-zero buffer as `Enzyme.make_zero`, without tying this
    # backend-agnostic DifferentiationInterface call to Enzyme.
    dweights_mid = zero(weights_mid)
    gradient!(
        loss, dweights_mid, backend, weights_mid,
        Constant(weights_ctx), Constant(nvocab), Constant(ndim),
        Constant(tok_mid), Constant(tok_ctx), Constant(x),
    )
    return dweights_mid
end

# Benchmark one gradient evaluation (100277-token vocabulary, 100 samples, 2 dims)
# using the Mooncake backend via DifferentiationInterface.
rng = Random.Xoshiro(5)
bck = AutoMooncake(; config = nothing)
@btime train($rng, 100_277, 100, 2, $bck)
3 Likes