To wrap up a little, here is what we get with Enzyme and Mooncake on a non-allocating version of the code. Zygote just blows up my memory on this version; if someone has more RAM (32 GB) they can add it easily. (The loss being differentiated is written out right below the table.)
| Scenario | Time (s) AutoEnzyme() | Time (s) AutoMooncake{Nothing}(nothing) | Time (s) Torch Autograd |
|---|---|---|---|
| ndim = 2, nsamples = 50 | 0.00005110 | 0.00006670 | 0.000565 |
| ndim = 8, nsamples = 50 | 0.00059380 | 0.00052820 | 0.000748 |
| ndim = 32, nsamples = 50 | 0.00256860 | 0.00374450 | 0.006419 |
| ndim = 128, nsamples = 50 | 0.01127950 | 0.02048170 | 0.025756 |
| ndim = 128, nsamples = 100 | 0.01141160 | 0.02171200 | 0.027593 |
| ndim = 128, nsamples = 200 | 0.01139890 | 0.02136800 | 0.024627 |
| ndim = 128, nsamples = 1000 | 0.01214850 | 0.02213250 | 0.028803 |
Julia code:
```julia
using BenchmarkTools, DifferentiationInterface, DifferentiationInterfaceTest
using DataFrames
import Random
import Enzyme, Mooncake

@fastmath logistic_sigmoid(x) = 1.0f0 / (1 + exp(-x))

# Scalar, non-allocating loss: the dot products are accumulated by hand
function loss(
    Wmid, Wctx,
    nvocab, ndim,
    tok_mid, tok_ctx, x
)
    nll = zero(eltype(Wmid))
    for i in eachindex(x)
        dotprod = zero(eltype(Wmid))
        @inbounds @simd for j in 1:ndim
            id1 = (tok_mid[i] - 1) * ndim + j
            id2 = (tok_ctx[i] - 1) * ndim + j
            dotprod += Wmid[id1] * Wctx[id2]
        end
        @inbounds nll += x[i] * log(logistic_sigmoid(dotprod)) +
                         (1 - x[i]) * log(1 - logistic_sigmoid(dotprod))
    end
    nll / length(x)
end

backends = [AutoEnzyme(), AutoMooncake(config=nothing)]

# One scenario per (ndim, nsamples) pair; only Wmid is differentiated,
# everything else is passed as a Constant context.
scenarios = vcat(
    map([2, 8, 32, 128]) do ndim
        nsamples = 50
        nvocab = 100277
        rng = Random.Xoshiro(5)
        ntrues = nsamples ÷ 2
        tok_mid = rand(rng, 1:nvocab, nsamples)
        tok_ctx = rand(rng, 1:nvocab, nsamples)
        x = [trues(ntrues); falses(nsamples - ntrues)]
        weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
        weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
        Scenario{:gradient,:out}(loss, weights_mid;
            contexts = (Constant(weights_ctx), Constant(nvocab), Constant(ndim),
                        Constant(tok_mid), Constant(tok_ctx), Constant(x)),
            name = "ndim = $ndim, nsamples = $nsamples")
    end,
    map([100, 200, 1000]) do nsamples
        ndim = 128
        nvocab = 100277
        rng = Random.Xoshiro(5)
        ntrues = nsamples ÷ 2
        tok_mid = rand(rng, 1:nvocab, nsamples)
        tok_ctx = rand(rng, 1:nvocab, nsamples)
        x = [trues(ntrues); falses(nsamples - ntrues)]
        weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
        weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
        Scenario{:gradient,:out}(loss, weights_mid;
            contexts = (Constant(weights_ctx), Constant(nvocab), Constant(ndim),
                        Constant(tok_mid), Constant(tok_ctx), Constant(x)),
            name = "ndim = $ndim, nsamples = $nsamples")
    end,
)

df = benchmark_differentiation(backends, scenarios);

# Keep only the gradient rows and pivot into one column per backend
df_final = filter(row -> row.operator == :gradient, df)[!, [:backend, :scenario, :time, :bytes]]
df_pivot = unstack(df_final, :scenario, :backend, :time)
df_pivot
```
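If you want to time a single backend on a single scenario directly, without going through DifferentiationInterfaceTest, here is a minimal sketch using the DifferentiationInterface preparation API (I believe `prepare_gradient` / `gradient` with `Constant` contexts match DI ≥ 0.6; the data setup below just mirrors the ndim = 32, nsamples = 50 scenario, so adjust as needed):

```julia
using BenchmarkTools, DifferentiationInterface
import Enzyme, Random

# Same data as the "ndim = 32, nsamples = 50" scenario above
ndim, nsamples, nvocab = 32, 50, 100277
rng = Random.Xoshiro(5)
tok_mid = rand(rng, 1:nvocab, nsamples)
tok_ctx = rand(rng, 1:nvocab, nsamples)
x = [trues(nsamples ÷ 2); falses(nsamples - nsamples ÷ 2)]
weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)

backend = AutoEnzyme()

# Prepare once (caches/setup, depending on the backend), then time only
# the prepared gradient call
prep = prepare_gradient(loss, backend, weights_mid,
                        Constant(weights_ctx), Constant(nvocab), Constant(ndim),
                        Constant(tok_mid), Constant(tok_ctx), Constant(x))

@btime gradient($loss, $prep, $backend, $weights_mid,
                $(Constant(weights_ctx)), $(Constant(nvocab)), $(Constant(ndim)),
                $(Constant(tok_mid)), $(Constant(tok_ctx)), $(Constant(x)))
```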
Python code:
```python
import torch
import time

# Sigmoid function
def logistic_sigmoid(x):
    return 1.0 / (1.0 + torch.exp(-x))

# Vectorized loss function
def loss(Wmid, Wctx, tok_mid, tok_ctx, x, ndim):
    # Reshape 1D weights into (nvocab, ndim) for efficient indexing
    Wmid = Wmid.view(-1, ndim)
    Wctx = Wctx.view(-1, ndim)
    # Select the embeddings of the middle and context tokens
    Wmid_selected = Wmid[tok_mid]
    Wctx_selected = Wctx[tok_ctx]
    # Compute all dot products in a batch
    dotprod = torch.sum(Wmid_selected * Wctx_selected, dim=1)
    # Vectorized log-likelihood
    nll = torch.mean(x * torch.log(logistic_sigmoid(dotprod))
                     + (1 - x) * torch.log(1 - logistic_sigmoid(dotprod)))
    return nll

# Function to compute gradients
def get_grad(Wmid, Wctx, tok_mid, tok_ctx, x, ndim):
    Wmid.requires_grad_(True)
    Wctx.requires_grad_(True)
    # Reset gradients accumulated by previous runs
    Wmid.grad = None
    Wctx.grad = None
    l = loss(Wmid, Wctx, tok_mid, tok_ctx, x, ndim)
    l.backward()  # Compute gradients
    return Wmid.grad  # Return gradient of Wmid

# Define problem sizes
ndim_list = [2, 8, 32, 128]
nsamples_dict = {2: [50], 8: [50], 32: [50], 128: [50, 100, 200, 1000]}

# Number of runs per scenario for the mean
num_runs = 100

# Benchmark storage
results = []

# Run experiments
for ndim in ndim_list:
    for nsamples in nsamples_dict[ndim]:
        nvocab = 100277
        rng = torch.Generator().manual_seed(5)
        tok_mid = torch.randint(0, nvocab, (nsamples,), generator=rng)
        tok_ctx = torch.randint(0, nvocab, (nsamples,), generator=rng)
        x = torch.cat([torch.ones(nsamples // 2), torch.zeros(nsamples // 2)])
        Wmid = 0.1 * torch.randn((nvocab * ndim,), dtype=torch.float32, generator=rng)
        Wctx = 0.1 * torch.randn((nvocab * ndim,), dtype=torch.float32, generator=rng)

        # Run multiple times and take the mean time
        times = []
        for _ in range(num_runs):
            start = time.perf_counter()
            grad = get_grad(Wmid, Wctx, tok_mid, tok_ctx, x, ndim)
            end = time.perf_counter()
            times.append(end - start)
        avg_time = sum(times) / len(times)  # Mean execution time

        # Store results
        results.append((f"ndim = {ndim}, nsamples = {nsamples}", "Torch Autograd", avg_time))

# Print results as a Markdown table
print("\n| Scenario | Backend | Mean Time (s) |")
print("|---------------------------|----------------|---------------|")
for scenario, backend, avg_time in results:
    print(f"| {scenario:<25} | {backend:<14} | {avg_time:.6f} |")
```
Vectorized case in Julia:
| Scenario | Time (s) AutoZygote() | Time (s) AutoEnzyme() | Time (s) AutoMooncake{Nothing}(nothing) | Time (s) Torch Autograd |
|---|---|---|---|---|
| ndim = 2, nsamples = 50 | 0.0002035 | 5.63e-05 | 7.66e-05 | 0.000565 |
| ndim = 8, nsamples = 50 | 0.0011986 | 0.0005918 | 0.0005545 | 0.000748 |
| ndim = 32, nsamples = 50 | 0.0047636 | 0.0023224 | 0.0042134 | 0.006419 |
| ndim = 128, nsamples = 50 | 0.0199836 | 0.0100148 | 0.0233116 | 0.025756 |
| ndim = 128, nsamples = 100 | 0.0204402 | 0.0112289 | 0.0220511 | 0.027593 |
| ndim = 128, nsamples = 200 | 0.0186262 | 0.0110781 | 0.0229095 | 0.024627 |
| ndim = 128, nsamples = 1000 | 0.0250179 | 0.0151688 | 0.0368872 | 0.028803 |
Julia code:
```julia
using BenchmarkTools, DifferentiationInterface, DifferentiationInterfaceTest, Statistics
using DataFrames, LinearAlgebra
import Random
import Enzyme, Mooncake, Zygote

@fastmath logistic_sigmoid(x) = 1.0f0 / (1 + exp(-x))

# Vectorized loss function: select embedding rows and broadcast
function loss(
    Wmid, Wctx,
    nvocab, ndim,
    tok_mid, tok_ctx, x
)
    Wmid2 = reshape(Wmid, nvocab, ndim)
    Wctx2 = reshape(Wctx, nvocab, ndim)
    Wmid_selected = @view(Wmid2[tok_mid, :])
    Wctx_selected = @view(Wctx2[tok_ctx, :])
    dotprod = sum(Wmid_selected .* Wctx_selected, dims=2)
    nll = mean(x .* log.(logistic_sigmoid.(dotprod)) .+
               (1 .- x) .* log.(1 .- logistic_sigmoid.(dotprod)))
    return nll
end

backends = [AutoZygote(), AutoEnzyme(), AutoMooncake(config=nothing)]

# Same scenarios as before; only Wmid is differentiated
scenarios = vcat(
    map([2, 8, 32, 128]) do ndim
        nsamples = 50
        nvocab = 100277
        rng = Random.Xoshiro(5)
        ntrues = nsamples ÷ 2
        tok_mid = rand(rng, 1:nvocab, nsamples)
        tok_ctx = rand(rng, 1:nvocab, nsamples)
        x = [trues(ntrues); falses(nsamples - ntrues)]
        weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
        weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
        Scenario{:gradient,:out}(loss, weights_mid;
            contexts = (Constant(weights_ctx), Constant(nvocab), Constant(ndim),
                        Constant(tok_mid), Constant(tok_ctx), Constant(x)),
            name = "ndim = $ndim, nsamples = $nsamples")
    end,
    map([100, 200, 1000]) do nsamples
        ndim = 128
        nvocab = 100277
        rng = Random.Xoshiro(5)
        ntrues = nsamples ÷ 2
        tok_mid = rand(rng, 1:nvocab, nsamples)
        tok_ctx = rand(rng, 1:nvocab, nsamples)
        x = [trues(ntrues); falses(nsamples - ntrues)]
        weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
        weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
        Scenario{:gradient,:out}(loss, weights_mid;
            contexts = (Constant(weights_ctx), Constant(nvocab), Constant(ndim),
                        Constant(tok_mid), Constant(tok_ctx), Constant(x)),
            name = "ndim = $ndim, nsamples = $nsamples")
    end,
)

df = benchmark_differentiation(backends, scenarios);

# Keep only the gradient rows and pivot into one column per backend
df_final = filter(row -> row.operator == :gradient, df)[!, [:backend, :scenario, :time, :bytes]]
df_pivot = unstack(df_final, :scenario, :backend, :time)
df_pivot
```
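Finally, since the whole point of the scalar version was to avoid allocations, it may also be worth reusing a preallocated gradient buffer. A minimal sketch with DifferentiationInterface's in-place operator, assuming the `gradient!(f, grad, prep, backend, x, contexts...)` signature of DI ≥ 0.6 (the data setup below is just one scenario picked for illustration; it works with either version of `loss`):

```julia
using DifferentiationInterface
import Enzyme, Random

# Data for one scenario (ndim = 128, nsamples = 50), as above
ndim, nsamples, nvocab = 128, 50, 100277
rng = Random.Xoshiro(5)
tok_mid = rand(rng, 1:nvocab, nsamples)
tok_ctx = rand(rng, 1:nvocab, nsamples)
x = [trues(nsamples ÷ 2); falses(nsamples - nsamples ÷ 2)]
weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
ctx = (Constant(weights_ctx), Constant(nvocab), Constant(ndim),
       Constant(tok_mid), Constant(tok_ctx), Constant(x))

backend = AutoEnzyme()
grad = similar(weights_mid)  # preallocated gradient buffer
prep = prepare_gradient(loss, backend, weights_mid, ctx...)

# Overwrites `grad` in place instead of allocating a fresh array per call
gradient!(loss, grad, prep, backend, weights_mid, ctx...)
```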