using Lux, Zygote, Random, LinearAlgebra, SparseArrays, BenchmarkTools
using Enzyme
# Create a seeded RNG object; it is passed explicitly to every random call below
# (sparse matrix generation, parameter initialization, Lux.setup).
rng = MersenneTwister(1234)
"""
    generate_combinations(N)

Return an `N × 2^N` `Matrix{Float32}` with entries `-1` and `+1`.

Entry `(site, config)` is `1 - 2 * mod(cld(config, 2^(N - site)), 2)`: column `config`
holds the binary digits of `config - 1` (most significant digit in row 1), with digit `0`
mapped to `-1` and digit `1` mapped to `+1`. The first column is therefore all `-1` and
the last column all `+1`.
"""
function generate_combinations(N)
    nconfigs = 2^N
    # -1 when cld(config, 2^(N - site)) is odd, +1 when it is even.
    spin(site, config) = 1 - 2 * mod(cld(config, 2^(N - site)), 2)
    return Float32[spin(site, config) for site in 1:N, config in 1:nconfigs]
end
# Example usage:
N = 15 # equivalent to 15 spins
all_configurations = generate_combinations(N)
size(all_configurations)
# all_configurations is a 15×32768 Matrix{Float32} (N rows, 2^N columns)
# One row/column of H per configuration, i.e. per column of all_configurations.
dim = 2^N
# Probability that any given entry of the random matrix is stored, so the expected
# number of stored entries before symmetrization is density * dim^2 = 2 * dim.
density = 2 / dim
H = sprandn(rng, Float32, dim, dim, density)
# Symmetrize by averaging H with its adjoint.
H = (H + H') / 2
# 32768×32768 SparseMatrixCSC{Float32, Int64} with 131678 stored entries
# Model with a single trainable one-element Float32 vector `λ`, initialized as a uniform
# random draw scaled by 0.01. Every input entry is multiplied by `λ`, passed through
# `logsigmoid`, and the result is summed over the first dimension and halved, giving one
# output value per input column.
model = @compact(λ=rand(rng, Float32, 1) * 0.01f0) do x
    activations = logsigmoid.(λ .* x)
    return sum(activations; dims=1) ./ 2
end
"""
    to_array(model, ps, st, all_configurations)

Call `model(all_configurations, ps, st)`, keep the first of the two values it returns,
exponentiate it element-wise, and return the result flattened to a vector.

The second value returned by the model call is discarded.
"""
function to_array(model, ps, st, all_configurations)
    log_values, _ = model(all_configurations, ps, st)
    values = exp.(log_values)
    return vec(values)
end
"""
    compute_energy(model, ps, st, H, all_configurations)

Return the quadratic form `dot(ψ, H, ψ)`, where `ψ` is the vector produced by
`to_array(model, ps, st, all_configurations)`.

The value is not divided by `dot(ψ, ψ)`, i.e. no normalization is applied here.
"""
function compute_energy(model, ps, st, H, all_configurations)
    amplitudes = to_array(model, ps, st, all_configurations)
    # Three-argument `dot` computes dot(amplitudes, H * amplitudes).
    return dot(amplitudes, H, amplitudes)
end
"""
    compute_energy_and_gradient(model, ps, st, H, all_configurations)

Evaluate `compute_energy` and its gradient with `Zygote.withgradient`.

All five positional arguments are handed to Zygote; only the second gradient entry, the
one belonging to `ps`, is kept.

# Returns
A named tuple `(; val, grad)`, where `val` is the value of `compute_energy` and `grad` is
its gradient with respect to `ps`.
"""
function compute_energy_and_gradient(model, ps, st, H, all_configurations)
    result = Zygote.withgradient(compute_energy, model, ps, st, H, all_configurations)
    energy = result.val
    # Gradient entries follow argument order, so index 2 corresponds to `ps`.
    ps_gradient = result.grad[2]
    return (; val=energy, grad=ps_gradient)
end
"""
    compute_energy_and_gradient_enzyme(model, ps, st, H, all_configurations)

Evaluate `compute_energy` and its gradient with respect to `ps` using Enzyme in
`ReverseWithPrimal` mode.

`ps` is marked `Duplicated` with a zeroed shadow created by `Enzyme.make_zero`; `model`,
`st`, `H`, and `all_configurations` are marked `Const`, and the return value is `Active`.

# Returns
A named tuple `(; val, grad)`, where `val` is the primal value of `compute_energy` and
`grad` is the shadow of `ps` after the reverse pass.
"""
function compute_energy_and_gradient_enzyme(model, ps, st, H, all_configurations)
    # Zero-initialized shadow of `ps`; the reverse pass writes the gradient into it.
    ps_shadow = Enzyme.make_zero(ps)
    _, energy = Enzyme.autodiff(
        ReverseWithPrimal,
        compute_energy,
        Active,
        Const(model),
        Duplicated(ps, ps_shadow),
        Const(st),
        Const(H),
        Const(all_configurations),
    )
    return (; val=energy, grad=ps_shadow)
end
# Initialize the model's parameters (`ps`) and state (`st`) using the seeded RNG.
ps, st = Lux.setup(rng, model)
# Evaluate the energy once before benchmarking.
compute_energy(model, ps, st, H, all_configurations)
# Value reported for the call above: 0.000719939269632178
# Call both gradient implementations once before benchmarking them
# (presumably as a warm-up so that compilation has already happened).
compute_energy_and_gradient(model, ps, st, H, all_configurations)
compute_energy_and_gradient_enzyme(model, ps, st, H, all_configurations)
# Benchmark the forward pass, the Zygote gradient, and the Enzyme gradient, in that order.
# The `$` interpolation passes the global variables to the benchmark as values.
@benchmark compute_energy($model, $ps, $st, $H, $all_configurations)
@benchmark compute_energy_and_gradient($model, $ps, $st, $H, $all_configurations)
@benchmark compute_energy_and_gradient_enzyme($model, $ps, $st, $H, $all_configurations)
# Enzyme takes about 22 ms here, compared to Zygote's 1 s, on my machine (which has a very slow CPU).
# This is also the size range in which Julia's broadcast on the CPU performs very poorly. For example,
# Multithreaded variant of the model: the element-wise `logsigmoid(x * λ)` is written into a
# preallocated buffer by a statically scheduled threaded loop instead of a broadcast. The
# output is again the sum over the first dimension divided by 2.
model = @compact(λ=rand(rng, Float32, 1) * 0.01f0) do x
    # Uninitialized buffer with the same element type and shape as `x`; every entry is
    # overwritten by the loop below before it is read.
    buffer = similar(x)
    Threads.@threads :static for idx in eachindex(buffer)
        # `buffer` was created by `similar(x)`, so `idx` is a valid index for both arrays.
        @inbounds buffer[idx] = logsigmoid(x[idx] * λ[1])
    end
    return sum(buffer; dims=1) / 2
end
# brings the forward pass down from 10 ms to 2 ms (with 16 threads). You would need Enzyme to differentiate this, because the loop mutates `y` in place, which Zygote does not support.