Zygote much slower than JAX for automatic differentiation of energy

using Lux, Zygote, Random, LinearAlgebra, SparseArrays, BenchmarkTools
using Enzyme

# Seeded RNG, passed explicitly to every random call below (sprandn, the model's
# λ initialiser, Lux.setup) so that the script is reproducible.
rng = MersenneTwister(1234)

"""
    generate_combinations(N)

Return an `N × 2^N` `Matrix{Float32}` whose columns enumerate every assignment of
`±1` to `N` entries.

Entry `(i, j)` equals `1 - 2 * mod(cld(j, 2^(N - i)), 2)`: row `N` flips sign on
every column, and row `i` flips sign every `2^(N - i)` columns. The first column
is all `-1` and the last column is all `+1`.
"""
function generate_combinations(N)
    ncols = 2^N
    # Typed comprehension: rows index the entry, columns index the configuration.
    return Float32[1 - 2 * mod(cld(j, 2^(N - i)), 2) for i in 1:N, j in 1:ncols]
end

# Example usage:
N = 15 # equivalent to 15 spins
# N × 2^N matrix returned by generate_combinations.
all_configurations = generate_combinations(N)
size(all_configurations)

# all_configurations is a 15×32768 Matrix{Float32}

# Dimension of the square matrix H: one row/column per column of all_configurations.
dim = 2^N
# Fraction of nonzero entries requested from sprandn, i.e. about 2 per row on average.
density = 2 / dim

# Random sparse Float32 matrix with normally distributed nonzeros.
H = sprandn(rng, Float32, dim, dim, density)
# Average with the adjoint to make H symmetric.
H = (H + H') / 2

# 32768×32768 SparseMatrixCSC{Float32, Int64} with 131678 stored entries

# Lux model built with `@compact`; `λ` is a length-1 Float32 vector drawn from `rng`
# and scaled by 0.01 (presumably registered as the model's trainable parameter —
# see the Lux `@compact` documentation).
model = @compact(λ=rand(rng, Float32, 1) * 0.01f0) do x
    # Elementwise logsigmoid(λ * x); the length-1 λ broadcasts against x.
    y = @. logsigmoid(λ * x)
    # Sum over the first dimension (collapsing it to size 1) and halve the result.
    return sum(y; dims=1) / 2
end

"""
    to_array(model, ps, st, all_configurations)

Evaluate `model(all_configurations, ps, st)`, take the first element of the returned
tuple, exponentiate it elementwise, and return the result flattened to a vector.
The second element of the model's return value is discarded.
"""
function to_array(model, ps, st, all_configurations)
    output = model(all_configurations, ps, st)
    log_amplitudes = first(output)
    return vec(@. exp(log_amplitudes))
end

"""
    compute_energy(model, ps, st, H, all_configurations)

Return the quadratic form `dot(ψ, H, ψ)`, where `ψ` is the vector produced by
`to_array(model, ps, st, all_configurations)`.

The value is not divided by `dot(ψ, ψ)`; no normalisation is applied.
"""
function compute_energy(model, ps, st, H, all_configurations)
    amplitudes = to_array(model, ps, st, all_configurations)
    energy = dot(amplitudes, H, amplitudes)
    return energy
end

"""
    compute_energy_and_gradient(model, ps, st, H, all_configurations)

Differentiate `compute_energy` with Zygote and return a named tuple `(; val, grad)`.

`val` is the value of `compute_energy`, and `grad` is the gradient with respect to
`ps` (the second positional argument passed to `Zygote.withgradient`).
"""
function compute_energy_and_gradient(model, ps, st, H, all_configurations)
    result = Zygote.withgradient(compute_energy, model, ps, st, H, all_configurations)
    val = result.val
    # Gradients come back in argument order; index 2 corresponds to `ps`.
    grad = result.grad[2]
    return (; val, grad)
end

"""
    compute_energy_and_gradient_enzyme(model, ps, st, H, all_configurations)

Differentiate `compute_energy` with Enzyme in reverse mode and return a named tuple
`(; val, grad)`.

Only `ps` is differentiated: it is wrapped in `Duplicated` together with a zeroed
copy that receives the gradient. `model`, `st`, `H` and `all_configurations` are
passed as `Const`.
"""
function compute_energy_and_gradient_enzyme(model, ps, st, H, all_configurations)
    # Shadow of `ps`; Enzyme accumulates the gradient into it in place.
    grad = Enzyme.make_zero(ps)
    outcome = Enzyme.autodiff(
        ReverseWithPrimal,
        compute_energy,
        Active,
        Const(model),
        Duplicated(ps, grad),
        Const(st),
        Const(H),
        Const(all_configurations),
    )
    # With ReverseWithPrimal the second element of the result is the primal value.
    val = outcome[2]
    return (; val, grad)
end

# Initialise the model's parameters (ps) and state (st) from the seeded RNG.
ps, st = Lux.setup(rng, model)

# Evaluate each function once before benchmarking; the commented number is the
# energy value reported for this setup.
compute_energy(model, ps, st, H, all_configurations)
# 0.000719939269632178
compute_energy_and_gradient(model, ps, st, H, all_configurations)

compute_energy_and_gradient_enzyme(model, ps, st, H, all_configurations)

# Benchmarks: forward pass, Zygote gradient, Enzyme gradient. The `$` interpolation
# passes the global variables into the benchmark as values.
@benchmark compute_energy($model, $ps, $st, $H, $all_configurations)

@benchmark compute_energy_and_gradient($model, $ps, $st, $H, $all_configurations)

@benchmark compute_energy_and_gradient_enzyme($model, $ps, $st, $H, $all_configurations)

Enzyme takes 22 ms here, compared to Zygote’s 1 s, on my machine (which has a very slow CPU).

Also, this is in the range where Julia’s broadcast on the CPU performs poorly. For example,

# Variant of the model in which the elementwise logsigmoid is written as an explicit
# multithreaded loop instead of a broadcast.
model = @compact(λ=rand(rng, Float32, 1) * 0.01f0) do x
    # Uninitialised output buffer with the same shape and element type as x.
    y = similar(x)
    # Statically scheduled threaded loop; each iteration writes a distinct y[I].
    Threads.@threads :static for I in eachindex(y)
        # λ is a length-1 vector, so λ[1] extracts the scalar.
        @inbounds y[I] = logsigmoid(x[I] * λ[1])
    end
    # Sum over the first dimension (collapsing it to size 1) and halve the result.
    return sum(y; dims=1) / 2
end

This brings the forward pass down from 10 ms to 2 ms (with 16 threads). You would need Enzyme to differentiate this version, because Zygote does not support the in-place array mutation used in the loop.

1 Like