`Zygote.gradient` is 54000 TIMES slower than `jax.gradient`

This should work:

using Statistics
import Random
using LinearAlgebra, BenchmarkTools
import Enzyme

const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}
logistic_sigmoid(x) = 1.0f0 / (1 + exp(-x))

function loss(
    Wmid, Wctx,
    nvocab, ndim,
    tok_mid, tok_ctx, x
)

    Wmid_reshaped = reshape(Wmid, nvocab, ndim)
    Wctx_reshaped = reshape(Wctx, nvocab, ndim)
    nll::Float32 = 0.0f0
    @inbounds for c in eachindex(x)
        Xij = x[c]
        dotprod = dot(@view(Wmid_reshaped[tok_mid[c], :]), @view(Wctx_reshaped[tok_ctx[c], :]))
        nll += Xij * log(logistic_sigmoid(dotprod))
        nll += (1 - Xij) * log(1 - logistic_sigmoid(dotprod))
    end
    nll = nll / length(x)
    nll
end

function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer)
    ntrues = nsamples ÷ 2

    tok_mid = rand(rng, 1:nvocab, nsamples)
    tok_ctx = rand(rng, 1:nvocab, nsamples)
    x = [trues(ntrues); falses(nsamples - ntrues)]

    weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
    weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)

    dweights_mid = Enzyme.make_zero(weights_mid)

    Enzyme.autodiff(
        Enzyme.set_runtime_activity(Enzyme.Reverse),
        loss,
        Enzyme.Duplicated(weights_mid, dweights_mid),
        Enzyme.Const(weights_ctx),
        Enzyme.Const(nvocab),
        Enzyme.Const(ndim),
        Enzyme.Const(tok_mid),
        Enzyme.Const(tok_ctx),
        Enzyme.Const(x),
    )
    dweights_mid
end

rng = Random.Xoshiro(5)
@btime train($rng, 100277, 100, 2)

If you’re using Julia 1.10, this should also work and perhaps be a hair faster:

Julia 1.10 version
using Statistics
import Random
using LinearAlgebra, BenchmarkTools
import Enzyme

const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}
logistic_sigmoid(x) = 1.0f0 / (1 + exp(-x))

function loss(
    Wmid, Wctx,
    nvocab, ndim,
    tok_mid, tok_ctx, x
)

    Wmid_reshaped = reshape(Wmid, nvocab, ndim)
    Wctx_reshaped = reshape(Wctx, nvocab, ndim)
    nll::Float32 = 0.0f0
    @inbounds for c in eachindex(x)
        Xij = x[c]
        dotprod = dot(@view(Wmid_reshaped[tok_mid[c], :]), @view(Wctx_reshaped[tok_ctx[c], :]))
        nll += Xij * log(logistic_sigmoid(dotprod))
        nll += (1 - Xij) * log(1 - logistic_sigmoid(dotprod))
    end
    nll = nll / length(x)
    nll
end

function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer)
    ntrues = nsamples ÷ 2

    tok_mid = rand(rng, 1:nvocab, nsamples)
    tok_ctx = rand(rng, 1:nvocab, nsamples)
    x = [trues(ntrues); falses(nsamples - ntrues)]

    weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
    weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)

    dweights_mid = Enzyme.make_zero(weights_mid)

    Enzyme.autodiff(
        Enzyme.Reverse,
        loss,
        Enzyme.Duplicated(weights_mid, dweights_mid),
        Enzyme.Const(weights_ctx),
        Enzyme.Const(nvocab),
        Enzyme.Const(ndim),
        Enzyme.Const(tok_mid),
        Enzyme.Const(tok_ctx),
        Enzyme.Const(x),
    )
    dweights_mid
end

rng = Random.Xoshiro(5)
@btime train($rng, 100277, 100, 2)

The only difference is that Enzyme.set_runtime_activity(Enzyme.Reverse), which always works, is replaced with simply Enzyme.Reverse in the 1.10-only version.
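
For quick reference, here is the toggle in isolation on a tiny toy function (a sketch to show the two call forms, not the benchmark loss above):

import Enzyme

f(x, y) = sum(abs2, x) + sum(y)

x, y = rand(Float32, 3), rand(Float32, 3)

# Reverse mode with runtime activity enabled: works on both 1.10 and 1.11
dx = Enzyme.make_zero(x)
Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Reverse), f,
                Enzyme.Duplicated(x, dx), Enzyme.Const(y))

# Plain reverse mode, as in the 1.10-only variant above
dx2 = Enzyme.make_zero(x)
Enzyme.autodiff(Enzyme.Reverse, f,
                Enzyme.Duplicated(x, dx2), Enzyme.Const(y))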

The code here worked fine for me on v1.11 and ran fast.

2 hours later, benchmark time again! Now I use my original loss function with the loop.

Main results:

  • Zygote’s TTFG (Time To First Gradient) depends on the number of loop iterations in the objective (or, equivalently, the number of indexing operations). The worst I observed is 20 minutes for 200 observations.
  • Mooncake’s TTFG doesn’t depend on anything: the same 72 seconds everywhere. That’s still a lot, but much better than Zygote.

128-dim vectors, varying number of observations (loop iterations)

| Backend  | nobs | TTFG          | Other grads |
|----------|------|---------------|-------------|
| Zygote   | 10   | 18.553415 s   | 1.701 s     |
| Zygote   | 50   | 116.313049 s  | 7.636 s     |
| Zygote   | 100  | 482.955427 s  | 15.152 s    |
| Zygote   | 200  | 1204.946515 s | 30.213 s    |
| Mooncake | 10   | 72.063802 s   | 539.833 ms  |
| Mooncake | 50   | 72.163969 s   | 560.456 ms  |
| Mooncake | 100  | 72.083851 s   | 543.928 ms  |
| Mooncake | 200  | 71.951032 s   | 559.037 ms  |
| JAX      | 50   | 0.602 s       | 181.0 ms    |
| JAX      | 100  | 0.623 s       | 181.3 ms    |
| JAX      | 200  | 0.614 s       | 181.08 ms   |
| JAX      | 1000 | 0.612 s       | 183.109 ms  |
  • Zygote’s TTFG grows faster than the number of loop iterations in the objective.
  • Zygote’s runtime after compilation grows at the same rate as the number of loop iterations.
  • Mooncake’s TTFG doesn’t depend on the number of loop iterations.
  • Mooncake’s timings in general don’t depend on the number of loop iterations.
  • JAX is fast as usual. TTFG (I include JIT and jax.block_until_ready here!) is practically nonexistent. It also uses multithreading automatically. All Julia benchmarks here were run with julia --threads=auto, but I didn’t see any attempt at multithreading on the CPU monitor.

Varying vector dimensionality, 100 iterations

| Backend  | ndim | TTFG         | Other grads |
|----------|------|--------------|-------------|
| Zygote   | 2    | 466.198646 s | 83.078 ms   |
| Zygote   | 8    | 462.867302 s | 300.332 ms  |
| Zygote   | 32   | 471.075816 s | 1.155 s     |
| Zygote   | 128  | 482.955427 s | 15.152 s    |
| Mooncake | 2    | 71.578484 s  | 3.598 ms    |
| Mooncake | 8    | 71.606989 s  | 16.902 ms   |
| Mooncake | 32   | 71.323625 s  | 76.016 ms   |
| Mooncake | 128  | 72.083851 s  | 543.928 ms  |
  • Startup times are stable (Zygote’s increase a bit).
  • Zygote’s overall TTFG is 6.4 times worse than Mooncake’s.
  • Average gradient time grows, at first at the same rate as dimensionality, then (going from 32 to 128 dimensions, a ×4 step) it grows ×13 for Zygote and ×7 for Mooncake, where ×4 would be expected.
Full code
using Statistics, LinearAlgebra; import Random
using DifferentiationInterface
import Zygote, Mooncake#, Enzyme

const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}

logistic_sigmoid(x::Real) = 1 / (1 + exp(-x))

loss_original(
	Wmid::AM{<:Real}, Wctx::AM{<:Real},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) = -mean(
	let
		pij = @views logistic_sigmoid(dot(Wmid[i, :], Wctx[j, :]))
		Xij * log(pij) + (1 - Xij) * log(1 - pij)
	end
	for (i, j, Xij) in zip(tok_mid, tok_ctx, x)
)

function train(rng, loss, nvocab::Integer, nsamples::Integer, ndim::Integer, backend, quiet::Bool)
	ntrues = nsamples ÷ 2

	tok_mid = rand(rng, 1:nvocab, nsamples)
	tok_ctx = rand(rng, 1:nvocab, nsamples)
	x = [trues(ntrues); falses(nsamples - ntrues)]

	weights_mid = 0.1f0 .* randn(rng, Float32, nvocab, ndim)
	weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab, ndim)
	quiet || @info "Number of parameters:" size(weights_mid) size(weights_ctx) total=(length(weights_mid) + length(weights_ctx))

	dweights_mid = similar(weights_mid)
	quiet || @info "Computing gradient..." backend
	if !quiet
		@timev gradient!(
			loss,
			dweights_mid, backend,
			weights_mid, Constant(weights_ctx),
			Constant(tok_mid), Constant(tok_ctx), Constant(x)
		)
	else
		gradient!(
			loss,
			dweights_mid, backend,
			weights_mid, Constant(weights_ctx),
			Constant(tok_mid), Constant(tok_ctx), Constant(x)
		)
	end
	dweights_mid
end

using BenchmarkTools

function main(loss, ndim::Integer, nobs::Integer)
	for backend in [AutoZygote(), AutoMooncake(config=nothing)]
		@info "CODE: $loss; ndim: $ndim; nobs: $nobs; backend: $backend"
		grad = train(Random.Xoshiro(5), loss, 100277, nobs, ndim, backend, false)
		@info "Gradient info:" size(grad) sum(grad)

		println(mean(
			@benchmark let
				train(Random.Xoshiro(5), $loss, 100277, $nobs, $ndim, $backend, true)
			end
		))

		println("\n\n")
	end
end

main(loss_original, 128, 50)

Versions

  • Julia v1.11.3
  • Zygote v0.7.3
  • Mooncake v0.4.78
  • DifferentiationInterface v0.6.36

Oh wow, thank you! Zygote is really having a bad day on this function. I tried it too and it didn’t fit in my memory, so I gave up on Zygote, but for Mooncake and Enzyme it’s going quite nicely.

| Backend  | ndim | TTFG        | Other grads |
|----------|------|-------------|-------------|
| Mooncake | 2    | 24.539188 s | 2.201 ms    |
| Mooncake | 8    | 28.220113 s | 11.849 ms   |
| Mooncake | 32   | 27.454269 s | 61.279 ms   |
| Mooncake | 128  | 28.095298 s | 257.074 ms  |
| Enzyme   | 2    | 7.698519 s  | 1.296 ms    |
| Enzyme   | 8    | 8.057153 s  | 7.078 ms    |
| Enzyme   | 32   | 8.169446 s  | 27.356 ms   |
| Enzyme   | 128  | 9.084161 s  | 122.082 ms  |

with the fastest version of your function that I could get:

@fastmath logistic_sigmoid(x) = 1.0f0 / (1 + exp(-x))
function loss(
    Wmid, Wctx,
    nvocab, ndim,
    tok_mid, tok_ctx, x
)
    nll = zero(eltype(Wmid))
    for i in eachindex(x)
        dotprod = zero(eltype(Wmid))
        @inbounds for j in 1:ndim
            id1 = (tok_mid[i]-1)*ndim + j
            id2 = (tok_ctx[i]-1)*ndim + j
            dotprod += Wmid[id1]*Wctx[id2]
        end
        @inbounds nll += x[i]*log(logistic_sigmoid(dotprod)) + (1-x[i])*log(1-logistic_sigmoid(dotprod))
    end
    nll / length(x)
end

Could do better with SIMD, I think, but I’m not sure. So we got our answer: TTFG depends on the input size only for Zygote, for a very obscure reason, but since the problem is in how it interacts with the compiler, we will never know why. Oh, and Enzyme is a beast.

LMAO. I saw Zygote quickly consume 16 gigs of RAM when computing the first gradient for the IRL model with 100K observations. I’m sorry, is this what we call “21st century AD” nowadays? Sorry, I don’t want to be disrespectful, I’m sure a lot of work went into it and people are actively working on it (2.5K commits on GitHub, a lot of very recent ones), but 16GB of RAM for one gradient!! I get it, different ADs work well in different scenarios (`Zygote.gradient` is 54000 TIMES slower than `jax.gradient` - #56 by ToucheSir).


But can Enzyme or Mooncake tackle 1000 observations? JAX certainly can: 0.6 seconds TTFG and 183 ms average, see updated table in my comment. That’s on my PC which is 45 times slower than @mcabbott’s laptop on battery.

Yes, Zygote has its issues; what saves it is the ease of its rule system. For 128 ndim and 1000 obs, I get Mooncake at 325 ms and Enzyme at 131 ms, so it’s really not bad at all.

They are not actively working on it. The slogan is a holdover from ~2019, when the AD landscape in Julia was very different. There’s a reason libraries like Enzyme and Mooncake exist now.

To make Zygote work fast, you have to pretend you’re writing PyTorch code circa PyTorch 1.0. Or that you’re writing JAX code but you’re not allowed to use vmap. Much like with those two scenarios, any loop that runs for a large number of iterations is going to be terrible for performance, and that’s why the example in `Zygote.gradient` is 54000 TIMES slower than `jax.gradient` - #8 by mcabbott (which uses zero loops) is the only one that makes sense to benchmark Zygote on if you wanted to use it for yourself. Of course, if the intent is to show pathological behaviour then use whatever code suits your fancy.
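
For concreteness, “zero loops” means a formulation along these lines (a sketch; essentially the same shape as the vectorized loss that appears further down in this thread):

using Statistics

logistic_sigmoid(x) = 1 / (1 + exp(-x))

# One gather plus broadcasting, no scalar loop over observations:
# this is the kind of code Zygote is designed to handle well.
function loss_vectorized(Wmid, Wctx, tok_mid, tok_ctx, x)
    dotprod = vec(sum(Wmid[tok_mid, :] .* Wctx[tok_ctx, :]; dims=2))
    p = logistic_sigmoid.(dotprod)
    -mean(x .* log.(p) .+ (1 .- x) .* log.(1 .- p))
end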

At least now that we have DifferentiationInterface.jl, we can test almost all backends and see which suits the use case best.
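
For example, swapping backends is a one-argument change (a toy sketch, not the thread’s loss function):

using DifferentiationInterface
import Zygote, Mooncake

f(x) = sum(abs2, x)
x = rand(Float32, 10)

gradient(f, AutoZygote(), x)                  # Zygote backend
gradient(f, AutoMooncake(config=nothing), x)  # Mooncake backend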

TIL: Mooncake seems really good; I’d never heard of it before. The documentation is amazing, and it handles mutable arguments and loops, as in my case.

Yeah, and the fact that there are zero lines of C++ behind it makes it sexy.

I have a lot of hope for Enzyme becoming the new go-to AD for Julia. For now, I feel like there are still a few too many weird LLVM errors or outright crashes, but that’s normal since it’s still far from 1.0. Mooncake is even younger, so we will see there too; once it handles parallel code, it will be a big competitor for sure.

By the way, I forgot to say this, but for benchmarking this kind of thing you might want to check out DifferentiationInterfaceTest.jl. Its main goal is measuring the speed of different backends and ensuring their correctness. The API is a bit clunky, but as long as you don’t wander off the beaten path you should be fine.
Tutorial: Tutorial · DifferentiationInterfaceTest.jl
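
A minimal sketch of the workflow, using the same DifferentiationInterfaceTest API as the full benchmark code further down in this thread:

using DifferentiationInterface, DifferentiationInterfaceTest
import Enzyme, Mooncake

f(x) = sum(abs2, x)

backends = [AutoEnzyme(), AutoMooncake(config=nothing)]
# One out-of-place gradient scenario at a random point
scenarios = [Scenario{:gradient,:out}(f, rand(Float32, 100); name="toy gradient")]

# Returns a DataFrame of timings; filter on the :operator column as in the full code below
df = benchmark_differentiation(backends, scenarios)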

Does the fact that Julia is column-major play any role in the benchmark? I didn’t notice you paying attention to this when converting the code to Python.

I made a fast one above and the issue is really that Zygote doesn’t like looping and takes forever to compile.

I like using DI Test only when the type changes; otherwise following the name of the scenario is a nightmare. Could you allow giving hand-written names to the scenarios? I can make an example later on if I’m not being clear. It doesn’t have to be breaking: for instance, add a method where scenarios is a Dict, get the keys, build the vector of scenarios, and at the end use the keys as the DataFrame scenario names.

Oh ok I thought that would be breaking, thank you

To conclude a little, here is what we can get with Enzyme and Mooncake on a non-allocating version of the code. Zygote just goes boom on my memory here; if someone has more RAM (32 GB), they can add it easily.

| Scenario                    | Time (s) AutoEnzyme() | Time (s) AutoMooncake{Nothing}(nothing) | Time (s) Torch Autograd |
|-----------------------------|-----------------------|-----------------------------------------|-------------------------|
| ndim = 2, nsamples = 50     | 0.00005110            | 0.00006670                              | 0.000565                |
| ndim = 8, nsamples = 50     | 0.00059380            | 0.00052820                              | 0.000748                |
| ndim = 32, nsamples = 50    | 0.00256860            | 0.00374450                              | 0.006419                |
| ndim = 128, nsamples = 50   | 0.01127950            | 0.02048170                              | 0.025756                |
| ndim = 128, nsamples = 100  | 0.01141160            | 0.02171200                              | 0.027593                |
| ndim = 128, nsamples = 200  | 0.01139890            | 0.02136800                              | 0.024627                |
| ndim = 128, nsamples = 1000 | 0.01214850            | 0.02213250                              | 0.028803                |

Julia code:

using BenchmarkTools,DifferentiationInterface,DifferentiationInterfaceTest
import Random
import Enzyme,Mooncake

@fastmath logistic_sigmoid(x) = 1.0f0 / (1 + exp(-x))
function loss(
    Wmid, Wctx,
    nvocab, ndim,
    tok_mid, tok_ctx, x
)
    nll = zero(eltype(Wmid))
    for i in eachindex(x)
        dotprod = zero(eltype(Wmid))
        @inbounds @simd for j in 1:ndim
            id1 = (tok_mid[i]-1)*ndim + j
            id2 = (tok_ctx[i]-1)*ndim + j
            dotprod += Wmid[id1]*Wctx[id2]
        end
        @inbounds nll += x[i]*log(logistic_sigmoid(dotprod)) + (1-x[i])*log(1-logistic_sigmoid(dotprod))
    end
    nll / length(x)
end

backends = [AutoEnzyme(),AutoMooncake(config=nothing)]
scenarios = vcat(map([2,8,32,128]) do ndim
	nsamples = 50
	nvocab = 100277
	rng = Random.Xoshiro(5)
	ntrues = nsamples ÷ 2
	tok_mid = rand(rng, 1:nvocab, nsamples) 
	tok_ctx = rand(rng, 1:nvocab, nsamples)  
	x = [trues(ntrues); falses(nsamples - ntrues)] 
	weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim) 
	weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
	Scenario{:gradient,:out}(loss,weights_mid, contexts =(Constant(weights_ctx),Constant(nvocab), Constant(ndim),Constant(tok_mid), Constant(tok_ctx), Constant(x));name="ndim = $ndim, nsamples = $nsamples")
end,
map([100,200,1000]) do nsamples
	ndim = 128
	nvocab = 100277
	rng = Random.Xoshiro(5)
	ntrues = nsamples ÷ 2
	tok_mid = rand(rng, 1:nvocab, nsamples) 
	tok_ctx = rand(rng, 1:nvocab, nsamples)  
	x = [trues(ntrues); falses(nsamples - ntrues)] 
	weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim) 
	weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
	Scenario{:gradient,:out}(loss,weights_mid, contexts =(Constant(weights_ctx),Constant(nvocab), Constant(ndim),Constant(tok_mid), Constant(tok_ctx), Constant(x));name="ndim = $ndim, nsamples = $nsamples")
end
)
df = benchmark_differentiation(backends, scenarios);
df_final = filter(row->row.operator == :gradient,df)[!,[:backend,:scenario,:time,:bytes]]
using DataFrames
df_pivot = unstack(df_final, :scenario, :backend, :time)
df_pivot

Python code:

import torch
import time
import itertools

# Sigmoid function
def logistic_sigmoid(x):
    return 1.0 / (1.0 + torch.exp(-x))

# Vectorized loss function
def loss(Wmid, Wctx, tok_mid, tok_ctx, x, ndim):
    # Reshape 1D weights into (nvocab, ndim) for efficient indexing
    Wmid = Wmid.view(-1, ndim)
    Wctx = Wctx.view(-1, ndim)

    # Get the correct embeddings
    Wmid_selected = Wmid[tok_mid]
    Wctx_selected = Wctx[tok_ctx]

    # Compute dot products in a batch
    dotprod = torch.sum(Wmid_selected * Wctx_selected, dim=1)

    # Compute vectorized loss
    nll = torch.mean(x * torch.log(logistic_sigmoid(dotprod)) + (1 - x) * torch.log(1 - logistic_sigmoid(dotprod)))

    return nll

# Function to compute gradients
def get_grad(Wmid, Wctx, tok_mid, tok_ctx, x, ndim):
    Wmid.requires_grad_(True)
    Wctx.requires_grad_(True)

    l = loss(Wmid, Wctx, tok_mid, tok_ctx, x, ndim)
    l.backward()  # Compute gradients

    return Wmid.grad  # Return gradient of Wmid

# Define problem sizes
ndim_list = [2, 8, 32, 128]
nsamples_dict = {2: [50], 8: [50], 32: [50], 128: [50, 100, 200, 1000]}

# Number of runs for mean calculation
num_runs = 100

# Benchmark storage
results = []

# Run experiments
for ndim in ndim_list:
    for nsamples in nsamples_dict[ndim]:
        nvocab = 100277
        rng = torch.Generator().manual_seed(5)

        tok_mid = torch.randint(0, nvocab, (nsamples,), generator=rng)
        tok_ctx = torch.randint(0, nvocab, (nsamples,), generator=rng)
        x = torch.cat([torch.ones(nsamples // 2), torch.zeros(nsamples // 2)])

        Wmid = 0.1 * torch.randn((nvocab * ndim,), dtype=torch.float32, generator=rng)
        Wctx = 0.1 * torch.randn((nvocab * ndim,), dtype=torch.float32, generator=rng)

        # Run multiple times and take mean time
        times = []
        for _ in range(num_runs):
            start = time.perf_counter()
            grad = get_grad(Wmid, Wctx, tok_mid, tok_ctx, x, ndim)
            end = time.perf_counter()

            times.append(end - start)

        avg_time = times[49]  # Take the time of the 50th run as the representative value (not actually a mean)

        # Store results
        results.append((f"ndim = {ndim}, nsamples = {nsamples}", "Torch Autograd", avg_time))

# Print results as a Markdown table
print("\n| Scenario                  | Backend         | Mean Time (s) |")
print("|---------------------------|----------------|---------------|")
for scenario, backend, avg_time in results:
    print(f"| {scenario:<25} | {backend:<14} | {avg_time:.6f} |")

Vectorized case in Julia:

| Scenario                    | Time (s) AutoZygote() | Time (s) AutoEnzyme() | Time (s) AutoMooncake{Nothing}(nothing) | Time (s) Torch Autograd |
|-----------------------------|-----------------------|-----------------------|-----------------------------------------|-------------------------|
| ndim = 2, nsamples = 50     | 0.0002035             | 5.63e-05              | 7.66e-05                                | 0.000565                |
| ndim = 8, nsamples = 50     | 0.0011986             | 0.0005918             | 0.0005545                               | 0.000748                |
| ndim = 32, nsamples = 50    | 0.0047636             | 0.0023224             | 0.0042134                               | 0.006419                |
| ndim = 128, nsamples = 50   | 0.0199836             | 0.0100148             | 0.0233116                               | 0.025756                |
| ndim = 128, nsamples = 100  | 0.0204402             | 0.0112289             | 0.0220511                               | 0.027593                |
| ndim = 128, nsamples = 200  | 0.0186262             | 0.0110781             | 0.0229095                               | 0.024627                |
| ndim = 128, nsamples = 1000 | 0.0250179             | 0.0151688             | 0.0368872                               | 0.028803                |

Code:

using BenchmarkTools,DifferentiationInterface,DifferentiationInterfaceTest,Statistics
import Random
import Enzyme,Mooncake,Zygote

@fastmath logistic_sigmoid(x) = 1.0f0 / (1 + exp(-x))

using LinearAlgebra

# Vectorized loss function
function loss(
    Wmid, Wctx,
    nvocab, ndim,
    tok_mid, tok_ctx, x
)
    Wmid2 = reshape(Wmid, nvocab, ndim)
    Wctx2 = reshape(Wctx, nvocab, ndim)
    Wmid_selected = @view(Wmid2[tok_mid, :])
    Wctx_selected = @view(Wctx2[tok_ctx, :])
    dotprod = sum(Wmid_selected .* Wctx_selected, dims=2)
    nll = mean(x .* log.(logistic_sigmoid.(dotprod)) .+ (1 .- x) .* log.(1 .- logistic_sigmoid.(dotprod)))
    return nll
end

backends = [AutoZygote(),AutoEnzyme(),AutoMooncake(config=nothing)]
scenarios = vcat(map([2,8,32,128]) do ndim
	nsamples = 50
	nvocab = 100277
	rng = Random.Xoshiro(5)
	ntrues = nsamples ÷ 2
	tok_mid = rand(rng, 1:nvocab, nsamples) 
	tok_ctx = rand(rng, 1:nvocab, nsamples)  
	x = [trues(ntrues); falses(nsamples - ntrues)] 
	weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim) 
	weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
	Scenario{:gradient,:out}(loss,weights_mid, contexts =(Constant(weights_ctx),Constant(nvocab), Constant(ndim),Constant(tok_mid), Constant(tok_ctx), Constant(x));name="ndim = $ndim, nsamples = $nsamples")
end,
map([100,200,1000]) do nsamples
	ndim = 128
	nvocab = 100277
	rng = Random.Xoshiro(5)
	ntrues = nsamples ÷ 2
	tok_mid = rand(rng, 1:nvocab, nsamples) 
	tok_ctx = rand(rng, 1:nvocab, nsamples)  
	x = [trues(ntrues); falses(nsamples - ntrues)] 
	weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim) 
	weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
	Scenario{:gradient,:out}(loss,weights_mid, contexts =(Constant(weights_ctx),Constant(nvocab), Constant(ndim),Constant(tok_mid), Constant(tok_ctx), Constant(x));name="ndim = $ndim, nsamples = $nsamples")
end
)
df = benchmark_differentiation(backends, scenarios);
df_final = filter(row->row.operator == :gradient,df)[!,[:backend,:scenario,:time,:bytes]]
using DataFrames
df_pivot = unstack(df_final, :scenario, :backend, :time)
df_pivot

Note that this way of benchmarking removes the preparation step from the final timing, which is better for Mooncake because preparation can be expensive (unlike for Enzyme, where it is basically free).
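
In DifferentiationInterface terms, the split looks roughly like this (a sketch, assuming the DI 0.6 calling convention used elsewhere in this thread):

using DifferentiationInterface
import Mooncake

f(x) = sum(abs2, x)
backend = AutoMooncake(config=nothing)
x = rand(Float32, 10)

# Preparation: excluded from the reported timings above, and potentially expensive for Mooncake
prep = prepare_gradient(f, backend, x)

# The prepared in-place call is what actually gets timed
grad = similar(x)
gradient!(f, grad, prep, backend, x)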
