I'm down to 1.5 ms. Enzyme seems to have trouble with `reshape` on Julia 1.11 (possibly on 1.10 as well; I didn't try), so I used Mooncake instead.
using Statistics; import Random
using DifferentiationInterface,Tullio,LinearAlgebra,BenchmarkTools
import Zygote, Enzyme, Mooncake
# Shorthand type aliases: `AV{T}` is a one-dimensional `AbstractArray` with element
# type `T` (i.e. `AbstractVector{T}`), `AM{T}` the two-dimensional counterpart
# (i.e. `AbstractMatrix{T}`).
const AV{T} = AbstractArray{T,1}
const AM{T} = AbstractArray{T,2}
# logistic_sigmoid(x) -> 1 / (1 + exp(-x))
#
# Computed in the floating-point type of `x`: `Float16` stays `Float16`,
# `Float32` stays `Float32`, and so on. The previous `1.0f0` literal silently
# promoted `Float16` inputs to `Float32`. Evaluated under `@fastmath`, which
# relaxes IEEE semantics, so results are not guaranteed for inputs large enough
# that `exp(-x)` overflows.
@fastmath logistic_sigmoid(x) = inv(one(x) + exp(-x))
"""
    loss(Wmid, Wctx, nvocab, ndim, tok_mid, tok_ctx, x)

Return the mean binary cross-entropy (negative log-likelihood) of the labels `x`
under a logistic model. The model's logit for sample `c` is the dot product of
row `tok_mid[c]` of `Wmid` and row `tok_ctx[c]` of `Wctx`.

# Arguments
- `Wmid`, `Wctx`: flat weight vectors, each reshaped (no copy, column-major) into
  an `nvocab × ndim` matrix with one row per token.
- `nvocab`, `ndim`: the dimensions used for that reshape.
- `tok_mid`, `tok_ctx`: per-sample row indices into the two matrices.
- `x`: per-sample labels, e.g. a `BitVector`; real values in `[0, 1]` also work.

# Returns
`(1/n) Σ_c [softplus(z_c) - x_c z_c]`, where `z_c` is the logit. This is
algebraically equal to `-(1/n) Σ_c [x_c log σ(z_c) + (1 - x_c) log(1 - σ(z_c))]`,
but it never forms `σ`, so it stays finite when the logit saturates. The
sigmoid form gives `0 * log(0) = NaN` there. Returns `NaN` when `x` is empty.

# Throws
- `DimensionMismatch` if `x`, `tok_mid` and `tok_ctx` do not share the same
  indices, or if a weight vector's length is not `nvocab * ndim` (from `reshape`).
- `BoundsError` if any token index lies outside `1:nvocab`.
"""
function loss(
    Wmid, Wctx,
    nvocab, ndim,
    tok_mid, tok_ctx, x
)
    Wm = reshape(Wmid, nvocab, ndim)
    Wc = reshape(Wctx, nvocab, ndim)
    # `eachindex` over several arrays throws `DimensionMismatch` unless they all
    # share the same indices, so one range is valid for every per-sample array.
    samples = eachindex(x, tok_mid, tok_ctx)
    # Validate every token once, up front, so the `@inbounds` loop below is sound.
    checkbounds(Wm, tok_mid, :)
    checkbounds(Wc, tok_ctx, :)
    # Accumulate in the promoted float type so the loop stays type-stable for
    # Float32 or Float64 weights and for Bool or real-valued labels.
    nll = zero(float(promote_type(eltype(Wmid), eltype(Wctx), eltype(x))))
    @inbounds for c in samples
        z = dot(@view(Wm[tok_mid[c], :]), @view(Wc[tok_ctx[c], :]))
        # -[x log σ(z) + (1 - x) log(1 - σ(z))] == softplus(z) - x * z
        nll += _softplus(z) - x[c] * z
    end
    return nll / length(x)
end

# Numerically stable softplus, log(1 + exp(z)). It is finite for every finite `z`,
# whereas the naive form overflows `exp` for large positive `z`.
_softplus(z) = max(z, zero(z)) + log1p(exp(-abs(z)))
"""
    train(rng, nvocab, nsamples, ndim, backend)

Build a random toy problem and return the gradient of `loss` with respect to the
flattened "mid" embedding weights, computed with the DifferentiationInterface
`backend`.

No optimisation step is taken. The function:
1. Draws `nsamples` random `(mid, ctx)` token pairs from `1:nvocab`, advancing
   `rng`.
2. Labels the first `nsamples ÷ 2` pairs `true` and the rest `false`.
3. Initialises both weight vectors (length `nvocab * ndim`) as
   `0.1f0 .* randn(rng, Float32, ...)`.
4. Differentiates only with respect to the mid weights; every other argument of
   `loss` is passed as a `Constant` context.

# Returns
A `Vector{Float32}` of length `nvocab * ndim` holding ∂loss/∂weights_mid.
"""
function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer, backend)
    ntrues = nsamples ÷ 2
    tok_mid = rand(rng, 1:nvocab, nsamples)
    tok_ctx = rand(rng, 1:nvocab, nsamples)
    x = [trues(ntrues); falses(nsamples - ntrues)]
    weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
    weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
    # `zero` yields the same zero-filled buffer as `Enzyme.make_zero` without tying
    # this backend-agnostic helper to Enzyme when another backend is passed in.
    dweights_mid = zero(weights_mid)
    gradient!(
        loss, dweights_mid, backend, weights_mid,
        Constant(weights_ctx), Constant(nvocab), Constant(ndim),
        Constant(tok_mid), Constant(tok_ctx), Constant(x),
    )
    return dweights_mid
end
# Fixed seed so the token draws and initial weights are reproducible from run to run.
rng = Random.Xoshiro(5)
bck = AutoMooncake(; config = nothing)
# `$`-interpolation keeps the benchmark from timing non-const global lookups.
@btime train($rng, 100_277, 100, 2, $bck)