Zygote much slower than JAX for automatic differentiation of energy

Hello,

I’m new to the neural network field, and I’m starting to study Neural Network Quantum States (for example, ground-state search using neural networks). They are usually implemented in Python with JAX, but I wondered whether I could get better performance with Julia (spoiler: not yet).

Below you will find the same example written twice, in JAX and in Lux.jl with Zygote. To keep the example minimal, I take a random sparse matrix as the Hamiltonian, generate all 2^N possible spin configurations, and then apply a very simple network to this set of states.

I’ll start with the Julia case:

using Lux
using Zygote
using Random
using LinearAlgebra
using SparseArrays
using BenchmarkTools

# Create an RNG object
rng = MersenneTwister(1234)

function generate_combinations(N)
    M = 2^N
    col = 1:M
    row = 1:N

    combinations = 1 .- 2 .* mod.(cld.(col', 2 .^ (N .- row)), 2)

    return Float32.(combinations)
end

# Example usage:
N = 15 # equivalent to 15 spins
all_configurations = generate_combinations(N)
size(all_configurations)

# 15×32768 Matrix{Float32}
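To see why the broadcast in generate_combinations enumerates every configuration exactly once, here is a small NumPy port. It is a sketch and not part of the original code. It mirrors the Julia formula and writes Julia's ceiling division `cld` as `-(-a // b)`. For column c, row r holds bit N - r of c - 1, mapped from {0, 1} to {-1, +1}, so row 1 is the most significant bit.

```python
import numpy as np

def generate_combinations(N):
    """NumPy port of the Julia `generate_combinations` (sketch).

    Returns an N x 2**N float32 array whose columns are all spin
    configurations with entries in {-1, +1}.
    """
    M = 2 ** N
    col = np.arange(1, M + 1)           # 1-based column index, as in Julia
    row = np.arange(1, N + 1)           # 1-based row index
    block_size = 2 ** (N - row)         # 2^(N - r) for each row r
    # Julia's cld(a, b) (ceiling division) for positive ints is -(-a // b)
    cld = -(-col[None, :] // block_size[:, None])
    return (1 - 2 * (cld % 2)).astype(np.float32)

configs = generate_combinations(3)
print(configs.shape)   # (3, 8)
print(configs.astype(int))
```

Every column is distinct. For N = 15, the 15 × 32768 matrix therefore covers the full 2^15-dimensional basis once, which is why ψ can be treated as a dense vector of length 2^N.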

Matrix generation

dim = 2^N
density = 2 / dim

H = sprandn(rng, Float32, dim, dim, density)
H = (H + H') / 2

# 32768×32768 SparseMatrixCSC{Float32, Int64} with 131678 stored entries

Neural Network definition

# I see that λ becomes Float64 although I declared it as Float32
model = @compact(λ=rand(rng, Float32, 1)*0.01) do x
    y = λ .* x
    return sum(logsigmoid(y), dims=1) / 2
end

function to_array(model, ps, st, all_configurations)
    ψ, _ = model(all_configurations, ps, st)
    return vec(exp.(ψ))
end

function compute_energy(model, ps, st, H, all_configurations)
    ψ = to_array(model, ps, st, all_configurations)
    return dot(ψ, H, ψ)
end

function compute_energy_and_gradient(model, ps, st, H, all_configurations)
    return Zygote.withgradient(ps -> compute_energy(model, ps, st, H, all_configurations), ps)
end

ps, st = Lux.setup(rng, model)

# I make the parameter Float32
ps = (λ = Float32.(ps.λ), )
# (λ = Float32[0.005596993],)

Try to compute the energy

compute_energy(model, ps, st, H, all_configurations)
# 0.000719939269632178

And try to benchmark it

@benchmark compute_energy(model, ps, st, H, all_configurations)
BenchmarkTools.Trial: 957 samples with 1 evaluation.
 Range (min … max):  4.577 ms …  15.013 ms  ┊ GC (min … max): 0.00% … 6.86%
 Time  (median):     5.038 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.217 ms ± 715.179 μs  ┊ GC (mean ± σ):  1.89% ± 3.71%

        ▂▆█▄          ▂▂▃                                      
  ▂▃▃▃▄▆████▇▅▅▄▄▄▄▄▄▆███▇▅▃▃▃▃▂▃▂▂▂▂▁▂▁▂▂▂▃▃▃▄▃▃▂▂▂▂▁▁▁▁▁▁▁▂ ▃
  4.58 ms         Histogram: frequency by time        6.76 ms <

 Memory estimate: 8.25 MiB, allocs estimate: 13.

Now benchmark the gradient

@benchmark compute_energy_and_gradient(model, ps, st, H, all_configurations)
BenchmarkTools.Trial: 5 samples with 1 evaluation.
 Range (min … max):  897.608 ms …   1.083 s  ┊ GC (min … max):  0.04% … 15.73%
 Time  (median):        1.020 s              ┊ GC (median):    16.70%
 Time  (mean ± σ):      1.013 s ± 74.032 ms  ┊ GC (mean ± σ):  13.33% ±  7.26%

  █                              █       █               █   █  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁█ ▁
  898 ms          Histogram: frequency by time          1.08 s <

 Memory estimate: 4.01 GiB, allocs estimate: 83.

And look at the 4GB of allocations!


Now let’s try the same with Python, JAX and Flax:

%env JAX_PLATFORM_NAME=cpu
import numpy as np
from scipy.sparse import random as sparse_random

import jax
import jax.random as jrandom
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

from functools import partial

# PRNG key for JAX
key = jrandom.PRNGKey(0)

import flax.linen as nn

jax.devices()
# [CpuDevice(id=0)]
def generate_combinations(N):
    M = 2**N
    col = jnp.arange(1, M + 1)
    row = jnp.arange(1, N + 1)
    
    @jax.jit
    def compute_combinations(col, row):
        return 1 - 2 * jnp.mod(jnp.ceil(col[:, None] / (2 ** (N - row))), 2)
    
    combinations = compute_combinations(col, row)
    return combinations

# Example usage:
N = 15
all_configurations = generate_combinations(N)
all_configurations.shape
# (32768, 15)

Generation of the random matrix

# Parameters for the sparse matrix
dim = 2**N  # Hilbert-space dimension
density = 2 / dim  # fraction of nonzeros: about 2 per row

# Calculate the number of non-zero elements
num_nonzeros = int(dim * dim * density)

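# Generate random indices for the sparse matrix.
# Use a separate key for each draw: with the same key, randint returns the
# same numbers, so row_indices == col_indices and H would be purely diagonal.
key_row, key_col, key_val = jrandom.split(key, 3)
row_indices = jrandom.randint(key_row, (num_nonzeros,), 0, dim)
col_indices = jrandom.randint(key_col, (num_nonzeros,), 0, dim)

# Generate random values for the sparse matrix
values = jrandom.normal(key_val, (num_nonzeros,))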

# Create the JAX sparse matrix (BCOO)
indices = jnp.vstack((row_indices, col_indices))
H = BCOO((values, indices.T), shape=(dim, dim))

H = (H + H.T) / 2
H
# BCOO(float32[32768, 32768], nse=131072)

Neural network definition

class MF(nn.Module):

    @nn.compact
    def __call__(self, x):
        lam = self.param(
            "lambda", nn.initializers.normal(), (1,), x.dtype
        )
        
        p = nn.log_sigmoid(lam*x)

        return 0.5 * jnp.sum(p, axis=-1)

def to_array(model, parameters, all_configurations):
    # now evaluate the model, and convert to a normalised wavefunction.
    logpsi = model.apply(parameters, all_configurations)
    psi = jnp.exp(logpsi)
    psi = psi / jnp.linalg.norm(psi)
    return psi

# we use partial to directly jit this function. Jitting the top-most will jit everything inside it as well.
@partial(jax.jit, static_argnames='model')
def compute_energy(model, parameters, H, all_configurations):
    psi_gs = to_array(model, parameters, all_configurations)
    return psi_gs.conj().T @ H @ psi_gs

@partial(jax.jit, static_argnames='model')
def compute_energy_and_gradient(model, parameters, H, all_configurations):
    grad_fun = jax.value_and_grad(compute_energy, argnums=1)
    return grad_fun(model, parameters, H, all_configurations)

# create an instance of the model
model = MF()

# initialise the weights
parameters = model.init(key, np.random.rand(N))
parameters
# {'params': {'lambda': Array([-0.01280743], dtype=float32)}}

Benchmarks energy computation

%timeit compute_energy(model, parameters, H, all_configurations)
# 822 µs ± 71.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Benchmark gradient

%timeit compute_energy_and_gradient(model, parameters, H, all_configurations)
# 2.29 ms ± 288 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

So for the gradient there is a difference of more than a factor 400 (about 1 s vs 2.29 ms). I don’t know if I am doing something wrong, but I certainly didn’t expect those 4 GiB of allocations with Zygote.


I have never used JAX so I can’t help on that front, but here are some remarks on the Julia side.

The number 0.01 has type Float64, so that might be the reason. If you replace it with 0.01f0 or Float32(0.01), I suspect everything will remain in 32 bits.

It’s a good habit to interpolate global variables with $ when using BenchmarkTools.jl:

@benchmark compute_energy($model, $ps, $st, $H, $all_configurations)

In this case it doesn’t make much of a difference, but global variables are the first thing I look for to explain weird benchmarking results.

You can try to profile your function to see where it spends time, which is the next logical step after benchmarking. For example, in VSCode, the following works well:

@profview for _ in 1:100; compute_energy(model, ps, st, H, all_configurations); end

From what I observe on the resulting flame graph, nearly all of the time is spent in the softplus activation function, which is applied elementwise on a 15×32768 matrix. That seems wasteful, because x only contains ±1, so this matrix can take only two different values: logsigmoid(λ) and logsigmoid(-λ). So the first answer would be to optimize that part.

With this small change

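# caution: `0f01` parses as 0 × 10^1 = 0.0f0 (the Float32 literal for 0.01 is `0.01f0`),
# which is why λ starts at exactly zero in the results below
model = @compact(λ=rand(rng, Float32, 1)*0f01) do x
    y = λ .* x
    return sum(logsigmoid(y), dims=1) / 2
end

model_fast = @compact(λ = rand(rng, Float32, 1) * 0.0f01) do x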
    a = only(logsigmoid(λ)) / 2  # scalar
    b = only(logsigmoid(-λ)) / 2  # scalar
    f(z) = ifelse(isone(z), a, b)  # function to apply elementwise during the sum
    return sum(f, x; dims=1)
end

I get the exact same numerics and a significant speedup.
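The speedup rests on a simple identity. Because every entry of x is ±1, each column sum of logsigmoid.(λ .* x) equals n_plus · logsigmoid(λ) + n_minus · logsigmoid(-λ), where n_plus and n_minus count the +1 and -1 spins in that configuration. A NumPy sketch of the same idea (the helper names are illustrative, not Lux functions) compares the two evaluations. The check is only informative for λ ≠ 0: at λ = 0 both branches equal log(1/2), so a swapped sign would go unnoticed, and the benchmark results below do start from λ = 0.

```python
import numpy as np

def logsigmoid(z):
    # log(1 / (1 + exp(-z))), written in a numerically stable form
    return -np.logaddexp(0.0, -z)

def log_psi_elementwise(lam, x):
    # evaluates logsigmoid on every entry of the N x M matrix
    return logsigmoid(lam * x).sum(axis=0) / 2

def log_psi_two_values(lam, x):
    # x only holds +1/-1, so only two logsigmoid values ever occur
    a = logsigmoid(lam) / 2
    b = logsigmoid(-lam) / 2
    n_plus = (x == 1).sum(axis=0)      # number of +1 spins per column
    n_minus = x.shape[0] - n_plus      # number of -1 spins per column
    return n_plus * a + n_minus * b

rng = np.random.default_rng(0)
x = rng.choice([-1.0, 1.0], size=(15, 1000))
lam = 0.37
print(np.allclose(log_psi_elementwise(lam, x), log_psi_two_values(lam, x)))  # True
```

The fast version evaluates logsigmoid twice instead of N·M times. The remaining work is a count per column, which is what the Julia `sum(f, x; dims=1)` version fuses into a single reduction.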

Benchmark results
julia> ps, st = Lux.setup(rng, model)
((λ = Float32[0.0],), NamedTuple())

julia> compute_energy(model, ps, st, H, all_configurations)
0.00066391635f0

julia> @benchmark compute_energy($model, $ps, $st, $H, $all_configurations)
BenchmarkTools.Trial: 406 samples with 1 evaluation.
 Range (min … max):  10.647 ms … 22.327 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     11.495 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   12.303 ms ±  1.856 ms  ┊ GC (mean ± σ):  0.78% ± 3.30%

  ▂▄█▆▃▄▃▂▁  ▁▃▄▁▁ ▁ ▁    ▁    ▂                               
  █████████▅▅███████▅█▇█▅▇█▄▅█▁█▆▄▄▅▄▄▆▅▄▅▄▁▅▁▁▅▁▄▁▄▁▄▁▄▁▅▁▁▄ ▇
  10.6 ms      Histogram: log(frequency) by time      19.2 ms <

 Memory estimate: 4.13 MiB, allocs estimate: 12.

julia> ps_fast, st_fast = Lux.setup(rng, model_fast)
((λ = Float32[0.0],), NamedTuple())

julia> compute_energy(model_fast, ps_fast, st_fast, H, all_configurations)
0.00066391635f0

julia> @benchmark compute_energy($model_fast, $ps_fast, $st_fast, $H, $all_configurations)
BenchmarkTools.Trial: 4478 samples with 1 evaluation.
 Range (min … max):  959.208 μs …   2.722 ms  ┊ GC (min … max): 0.00% … 33.34%
 Time  (median):       1.064 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.110 ms ± 166.208 μs  ┊ GC (mean ± σ):  0.29% ±  2.51%

    ▄ ▂█▂▄▃▂▂                                                    
  ▂▂█▆████████▇▅▄▄▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▁ ▂
  959 μs           Histogram: frequency by time         1.72 ms <

 Memory estimate: 256.36 KiB, allocs estimate: 9.

If we’re looking at the allocations themselves, I don’t observe the same as you: on my laptop, it’s 4 MiB rather than 4 GiB.
Can you share the versions of Julia and the packages that you are using?

As a side note, 4 GiB of allocations would be consistent with a dense version of the sparse H matrix (2^15 × 2^15 Float32 entries):

julia> @benchmark ones(Float32, 2^15, 2^15)
BenchmarkTools.Trial: 4 samples with 1 evaluation.
 Range (min … max):  1.505 s …    1.775 s  ┊ GC (min … max): 0.02% … 14.93%
 Time  (median):     1.647 s               ┊ GC (median):    6.40%
 Time  (mean ± σ):   1.643 s ± 133.053 ms  ┊ GC (mean ± σ):  7.24% ±  7.88%

  █         █                                      █       █  
  █▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁█ ▁
  1.5 s          Histogram: frequency by time         1.78 s <

 Memory estimate: 4.00 GiB, allocs estimate: 2.

I wonder if this is related to a naive chain rule that allocates the full matrix during a pullback.

Maybe @mohamed82008 will have an opinion.
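The 4 GiB figure matches a dense Float32 copy of H exactly, while the sparse matrix itself is tiny. Here is a quick back-of-the-envelope check. It assumes SparseMatrixCSC{Float32, Int64} storage: one 4-byte value and one 8-byte row index per stored entry, plus dim + 1 8-byte column pointers.

```python
dim = 2 ** 15
nnz = 131_678                                  # stored entries reported for H above

dense_bytes = dim * dim * 4                    # dense Float32 dim x dim matrix
sparse_bytes = nnz * (4 + 8) + (dim + 1) * 8   # values + row indices + column pointers
psi_bytes = dim * 4                            # one Float32 vector ψ

print(dense_bytes / 2**30)                     # 4.0   (GiB)
print(round(sparse_bytes / 2**20, 2))          # 1.76  (MiB)
print(psi_bytes / 2**10)                       # 128.0 (KiB)
```

An allocation of about 4 GiB in the gradient therefore points to something materialising a dense dim × dim array, for example the outer product ψψᵀ that would be the cotangent of H in dot(ψ, H, ψ).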

This is great, thanks. However, I think the same trick can be applied to the JAX version, improving performance on that side too, so I would benchmark both with the same code structure. BTW, good to know!

Even for the gradient calculation? I also get 4 MiB for the energy calculation, but not for the gradient.

This is my Julia version:

julia> versioninfo()
Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × 13th Gen Intel(R) Core(TM) i9-13900KF
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, goldmont)
Threads: 32 default, 0 interactive, 16 GC (on 32 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 32

And my packages:

(NQS Basics) pkg> st
Status `~/GitHub/Research/2024/Neural Network Basics/NQS Basics/Project.toml`
  [6e4b80f9] BenchmarkTools v1.5.0
  [052768ef] CUDA v5.3.3
  [13f3f980] CairoMakie v0.12.0
  [b2108857] Lux v0.5.46
  [d0bbae9a] LuxCUDA v0.3.2
  [eb30cadb] MLDatasets v0.7.14
  [0b1bfda6] OneHotArrays v0.2.5
  [3bd65402] Optimisers v0.3.3
  [6c2fb7c5] QuantumToolbox v0.8.0
  [e88e6eb3] Zygote v0.6.70
  [10745b16] Statistics v1.10.0

You’re definitely right that we would need to apply this improvement on both sides.
However, in some cases well-written code can be much easier for Zygote to differentiate than badly written code, so it’s not necessarily zero-sum.

By the way, I edited my code above with an even faster version, yielding a 10× speedup on the energy.

Silly me, I forgot to run the actual gradient computation… I now observe the same allocations as you, and my 10× faster energy function is actually… slower to differentiate. Very frustrating indeed.

julia> @benchmark compute_energy_and_gradient($model, $ps, $st, $H, $all_configurations)
BenchmarkTools.Trial: 3 samples with 1 evaluation.
 Range (min … max):  1.593 s …    1.842 s  ┊ GC (min … max):  0.03% … 10.68%
 Time  (median):     1.827 s               ┊ GC (median):    10.77%
 Time  (mean ± σ):   1.754 s ± 139.300 ms  ┊ GC (mean ± σ):   8.00% ±  6.64%

  █                                                     █  █  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁█ ▁
  1.59 s         Histogram: frequency by time         1.84 s <

 Memory estimate: 4.01 GiB, allocs estimate: 82.

julia> @benchmark compute_energy_and_gradient($model_fast, $ps_fast, $st_fast, $H, $all_configurations)
BenchmarkTools.Trial: 3 samples with 1 evaluation.
 Range (min … max):  1.782 s …   1.919 s  ┊ GC (min … max):  0.16% … 11.39%
 Time  (median):     1.840 s              ┊ GC (median):    11.87%
 Time  (mean ± σ):   1.847 s ± 68.330 ms  ┊ GC (mean ± σ):   8.09% ±  6.77%

  █                       █                               █  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  1.78 s         Histogram: frequency by time        1.92 s <

 Memory estimate: 4.04 GiB, allocs estimate: 324.

I was right, this specific gradient computation spends most of its time in the reverse rule for dot(x, A, y), specifically in a line which allocates a dense matrix. The same exact line that @mohamed82008 wants to fix with

However this part of the cotangent should never be un-thunked because the matrix A in question is fixed. @oxinabox any insights?
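For this model the parameter gradient never needs the cotangent of A at all. Let ψ_j = exp(½ Σ_i logsigmoid(λ x_ij)) and E = ψᵀ H ψ. The chain rule gives dE/dλ = ((H + Hᵀ) ψ) · dψ/dλ, with dψ_j/dλ = ψ_j · ½ Σ_i x_ij σ(-λ x_ij), using d/dz logsigmoid(z) = σ(-z). Only matrix-vector products with H appear. The NumPy sketch below checks this hand-written gradient against a central finite difference. It assumes a small N and a dense H for simplicity, and the function names are illustrative. A sparse H would work the same way, since only H @ ψ and Hᵀ @ ψ are used.

```python
import numpy as np

def logsigmoid(z):
    return -np.logaddexp(0.0, -z)

def sigmoid(z):
    return 0.5 * (1.0 + np.tanh(0.5 * z))    # stable logistic function

def psi_of(lam, x):
    return np.exp(logsigmoid(lam * x).sum(axis=0) / 2)

def energy(lam, H, x):
    psi = psi_of(lam, x)
    return psi @ (H @ psi)

def energy_grad(lam, H, x):
    psi = psi_of(lam, x)
    # d/dlam logsigmoid(lam * x) = x * sigmoid(-lam * x)
    dpsi = psi * (x * sigmoid(-lam * x)).sum(axis=0) / 2
    # cotangent of psi in psi^T H psi is (H + H^T) psi: two mat-vec products,
    # no dense outer product psi psi^T is ever formed
    g_psi = H @ psi + H.T @ psi
    return g_psi @ dpsi

N = 6
M = 2 ** N
bits = (np.arange(M)[None, :] >> np.arange(N)[:, None]) & 1
x = (2 * bits - 1).astype(float)              # all 2^N configurations as +/-1

rng = np.random.default_rng(1)
A = rng.standard_normal((M, M)) * (rng.random((M, M)) < 0.05)
H = (A + A.T) / 2                             # random sparse-ish symmetric "Hamiltonian"

lam, h = 0.3, 1e-6
fd = (energy(lam + h, H, x) - energy(lam - h, H, x)) / (2 * h)
print(energy_grad(lam, H, x), fd)
```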


As a workaround, you can use Enzyme.jl for differentiation instead of Zygote.jl:

using Enzyme

function compute_gradient_enzyme(model, ps, st, H, all_configurations)
    return Enzyme.gradient(
        Enzyme.Reverse, ps -> compute_energy(model, ps, st, H, all_configurations), ps
    )
end
julia> @benchmark compute_gradient_enzyme($model, $ps, $st, $H, $all_configurations)
BenchmarkTools.Trial: 154 samples with 1 evaluation.
 Range (min … max):  29.815 ms … 57.770 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     31.078 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   32.506 ms ±  3.847 ms  ┊ GC (mean ± σ):  0.34% ± 1.17%

  ▁▂▆█▅▄▃▁▅                                                    
  █████████▅██▅▁▆▆▅▅▆▆▁▆▅▅▆▁▁▅▁▁▅▅▅▅▁▅▅▁▅▅▁▁█▁▁▅▁▁▁▅▁▁▁▁▅▁▁▁▅ ▅
  29.8 ms      Histogram: log(frequency) by time      43.8 ms <

 Memory estimate: 8.25 MiB, allocs estimate: 26.

julia> @benchmark compute_gradient_enzyme($model_fast, $ps_fast, $st_fast, $H, $all_configurations)
BenchmarkTools.Trial: 942 samples with 1 evaluation.
 Range (min … max):  4.864 ms …   8.664 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     5.179 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.298 ms ± 428.144 μs  ┊ GC (mean ± σ):  0.14% ± 1.45%

     ▂█▄   ▁                                                   
  ▄▃▄█████▇██▇▆▆▄▄▃▃▃▃▃▃▃▃▃▃▂▃▂▂▁▂▃▂▂▂▂▂▂▁▂▂▁▁▂▂▂▂▂▂▁▂▁▂▂▁▁▂▂ ▃
  4.86 ms         Histogram: frequency by time         7.2 ms <

 Memory estimate: 513.09 KiB, allocs estimate: 20.

In that case you also benefit from the faster energy implementation.

I don’t know how much of Lux.jl supports it officially, but it does give you the speedup you need.


Zygote unthunks everything by default.
There are PRs open to fix that but they have gone stale.
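The cost of unthunking can be shown with a toy model. In ChainRulesCore, rules may wrap expensive cotangents in `@thunk`, and `unthunk` materialises them only when needed. The Python sketch below is a hypothetical miniature of that idea. The Thunk class and dot_xAy_pullback are illustrative only and are not ChainRulesCore or Zygote code. If the driver only unthunks the cotangents of the arguments being differentiated, the dense outer product for the constant matrix is never built.

```python
class Thunk:
    """Toy lazily evaluated cotangent (hypothetical, not the ChainRulesCore type)."""

    def __init__(self, compute):
        self._compute = compute
        self.evaluated = False

    def unthunk(self):
        self.evaluated = True
        return self._compute()


def matvec(A, v):
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def dot_xAy_pullback(x, A, y, dout):
    """Cotangents of dot(x, A, y) = x^T A y for a scalar output cotangent dout."""
    At = [list(col) for col in zip(*A)]
    dx = Thunk(lambda: [dout * v for v in matvec(A, y)])             # A y
    dA = Thunk(lambda: [[dout * xi * yj for yj in y] for xi in x])  # dense outer product x y^T
    dy = Thunk(lambda: [dout * v for v in matvec(At, x)])           # A^T x
    return dx, dA, dy


def selective_unthunk(cotangents, wanted):
    """Materialise only the cotangents whose positions are in `wanted`."""
    return [c.unthunk() if i in wanted else None for i, c in enumerate(cotangents)]


x = [1.0, 2.0]
A = [[0.0, 1.0], [1.0, 0.0]]
y = [3.0, 4.0]
cots = dot_xAy_pullback(x, A, y, 1.0)
grads = selective_unthunk(cots, wanted={0, 2})   # differentiate w.r.t. x and y only
print(grads)               # [[4.0, 3.0], None, [2.0, 1.0]]
print(cots[1].evaluated)   # False: the outer product for A was never built
```

A driver that unthunks every cotangent would call dA.unthunk() as well. For a 2^15 × 2^15 A, that single call is the 4 GiB allocation.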


Enzyme passes all tests as of this week:

Two performance improvements were identified:

But @avikpal and @wsmoses I think we’re at least at the point where we can consider Enzyme the default for Lux?

I’m porting over all SciML docs to start using Enzyme exclusively within the next month as well. I think it’s hit that level. There’s still a bit of a CUDA thing highlighted there but the core issues seem to be solved and the tests are passing all around.

That said, it’s literally this week kind of stuff, so as we do the full rollout we may find something.

I guess the way to put it is, there used to be major blockers we would identify as “this is why you couldn’t use it yet”. As of this week, those are gone and now it’s small bugs or things we haven’t uncovered yet.


A couple of things:

  1. Lux.jl and Flux.jl now test Enzyme for a number of models (at least on CPU), so this should generally be considered supported. It is incredibly recent, though, so performance work is in progress. Here, if I’m reading correctly, the Enzyme gradient is on par with the forward pass (both around 5 ms?), so I wouldn’t expect much more to be possible without making the original primal code faster.
  2. Enzyme is also usable in JAX and has been shown to give nontrivial speedups there as well. I should add some docs on it, but you can see here (Enzyme-JAX/test/bench_vs_xla.py at main · EnzymeAD/Enzyme-JAX · GitHub) and in other tests how to use it (essentially, just add an enzyme_jax_ir decorator as well).
  3. Minor, but to benchmark JAX code effectively, you’ll need to add a block_until_ready call. JAX dispatches computations asynchronously, so without it you may not be timing the actual execution you expect.
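Here is a minimal timing helper that follows point 3, as a sketch. It waits on every leaf of the result that has a block_until_ready() method (JAX arrays do), so the timed region includes the actual computation. It also calls the function once beforehand, so JIT compilation is excluded from the timing. The helper names are hypothetical. Recent JAX releases also provide jax.block_until_ready for whole pytrees, and in IPython the one-liner for a single array output is `%timeit compute_energy(model, parameters, H, all_configurations).block_until_ready()`.

```python
import time

def block(tree):
    """Recursively wait on every leaf that has a block_until_ready() method
    (JAX arrays do); other leaves are left untouched."""
    if isinstance(tree, dict):
        for v in tree.values():
            block(v)
    elif isinstance(tree, (list, tuple)):
        for v in tree:
            block(v)
    elif hasattr(tree, "block_until_ready"):
        tree.block_until_ready()
    return tree

def time_blocking(fn, *args, repeats=100):
    """Average wall-clock seconds per call, with device work synchronised."""
    block(fn(*args))                  # warm-up: triggers JIT compilation outside the timed loop
    start = time.perf_counter()
    for _ in range(repeats):
        block(fn(*args))              # wait for results so execution is inside the timed region
    return (time.perf_counter() - start) / repeats

# Usage with the functions from the first post (value and gradient pytree):
# time_blocking(compute_energy_and_gradient, model, parameters, H, all_configurations)
```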

Just to set expectations: differentiating CUDA device code from the host in Enzyme is still in progress. Support is actively expanding, but because it is partial, folks may hit something unsupported in those cases.


I know virtually nothing about any of this, but should this parameter really be a single-element vector instead of a scalar?

It needs to be, so that Lux recognizes it as a parameter and computes gradients for it. This can be bypassed if you don’t use the compact API, but currently that is how it works: Meta-Issue for improvements to `@compact` · Issue #606 · LuxDL/Lux.jl · GitHub

Looks like a chunk of time is being spent in the log and exp functions. We should probably use fastmath here (optionally, at least).


using Lux, Zygote, Random, LinearAlgebra, SparseArrays, BenchmarkTools
using Enzyme

# Create an RNG object
rng = MersenneTwister(1234)

function generate_combinations(N)
    M = 2^N
    col = 1:M
    row = 1:N

    combinations = 1 .- 2 .* mod.(cld.(col', 2 .^ (N .- row)), 2)

    return Float32.(combinations)
end

# Example usage:
N = 15 # equivalent to 15 spins
all_configurations = generate_combinations(N)
size(all_configurations)

# 15×32768 Matrix{Float32}

dim = 2^N
density = 2 / dim

H = sprandn(rng, Float32, dim, dim, density)
H = (H + H') / 2

# 32768×32768 SparseMatrixCSC{Float32, Int64} with 131678 stored entries

model = @compact(λ=rand(rng, Float32, 1) * 0.01f0) do x
    y = @. logsigmoid(λ * x)
    return sum(y; dims=1) / 2
end

function to_array(model, ps, st, all_configurations)
    ψ, _ = model(all_configurations, ps, st)
    return vec(exp.(ψ))
end

function compute_energy(model, ps, st, H, all_configurations)
    ψ = to_array(model, ps, st, all_configurations)
    return dot(ψ, H, ψ)
end

function compute_energy_and_gradient(model, ps, st, H, all_configurations)
    (; val, grad) = Zygote.withgradient(
        compute_energy, model, ps, st, H, all_configurations)
    return (; val, grad=grad[2])
end

function compute_energy_and_gradient_enzyme(model, ps, st, H, all_configurations)
    dps = Enzyme.make_zero(ps)
    _, val = Enzyme.autodiff(ReverseWithPrimal, compute_energy, Active, Const(model),
        Duplicated(ps, dps), Const(st), Const(H), Const(all_configurations))
    return (; val, grad=dps)
end

ps, st = Lux.setup(rng, model)

compute_energy(model, ps, st, H, all_configurations)
# 0.000719939269632178
compute_energy_and_gradient(model, ps, st, H, all_configurations)

compute_energy_and_gradient_enzyme(model, ps, st, H, all_configurations)

@benchmark compute_energy($model, $ps, $st, $H, $all_configurations)

@benchmark compute_energy_and_gradient($model, $ps, $st, $H, $all_configurations)

@benchmark compute_energy_and_gradient_enzyme($model, $ps, $st, $H, $all_configurations)

Enzyme is 22ms here compared to Zygote’s 1s on my machine (with a very crappy CPU)

Also, this is in the size range where Julia’s broadcast on CPU (which is single-threaded) performs poorly. For example,

model = @compact(λ=rand(rng, Float32, 1) * 0.01f0) do x
    y = similar(x)
    Threads.@threads :static for I in eachindex(y)
        @inbounds y[I] = logsigmoid(x[I] * λ[1])
    end
    return sum(y; dims=1) / 2
end

brings the forward pass down from 10 ms to 2 ms (with 16 threads). You would need Enzyme to differentiate this, since Zygote does not support array mutation such as y[I] = ....


@avikpal Enzyme should be able to differentiate Threads.@threads. However, if you’re going for performance, it may be useful to try out Enzyme.@parallel, which provides closure-free threading (it is obviously differentiable, but it also works without differentiation). See here: Enzyme.jl/test/threads.jl at 27859773e566377542c11c445aa936ed69dbe14d · EnzymeAD/Enzyme.jl · GitHub


It seems to perform the same as @threads:

model = @compact(λ=rand(rng, Float32, 1) * 0.01f0) do x
    y = similar(x)
    Enzyme.@parallel x y λ for i in eachindex(x)
        @inbounds y[i] = logsigmoid(x[i] * λ[1])
    end
    return sum(y; dims=1) / 2
end
julia> @benchmark compute_energy($model, $ps, $st, $H, $all_configurations)
BenchmarkTools.Trial: 878 samples with 1 evaluation.
 Range (min … max):  2.710 ms … 69.396 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     4.588 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.666 ms ±  3.883 ms  ┊ GC (mean ± σ):  0.72% ± 3.49%

     ███▇▆▅▄▃▃▄▁                                              
  █▆▇██████████████▇█▇▇█▇▁▇▆▇▆▇▇▇▄▇▄▄▅▁▆▁▄▅▄▄▅▁▆▄▅▁▁▄▄▄▁▁▁▁▄ █
  2.71 ms      Histogram: log(frequency) by time     20.1 ms <

 Memory estimate: 2.26 MiB, allocs estimate: 91.

julia> @benchmark compute_energy_and_gradient_enzyme($model, $ps, $st, $H, $all_configurations)
BenchmarkTools.Trial: 126 samples with 1 evaluation.
 Range (min … max):  29.315 ms … 128.247 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     35.490 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   39.914 ms ±  14.689 ms  ┊ GC (mean ± σ):  0.20% ± 0.81%

   ▃▃█▅                                                         
  ▇████▅▄▄▄▃▃▃▃▁▃▄▃▁▁▁▁▄▃▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▃
  29.3 ms         Histogram: frequency by time          127 ms <

 Memory estimate: 4.52 MiB, allocs estimate: 184.

~Though, taking a second look at the code, it should be written as a parallel mapreduce rather than a broadcast followed by a sum.~ I missed the earlier comment, which indeed gets this performance by computing the sum directly (a mapreduce) without broadcasting.


Yeah I don’t think it would matter much for the forward pass, it’s possibly helpful for the reverse pass, if there’s any aliasing info provided from the outside.

That reminds me that I need to finish tying that parallel to Tapir (the parallel IR/optimization compiler used by Enzyme & tested in the Julia compiler, not the new AD package which was recently registered with a confusingly similar name).

As you can see from slide 61 here: https://c.wsmoses.com/presentations/defense.pdf and published at SC’22 here (https://c.wsmoses.com/papers/enzymePar.pdf), doing Tapir-style parallel optimizations [in this case in OpenMP] allows the generated gradients to scale much better.


Thank you!

With Enzyme.jl the gradient calculation takes 12.5 ms compared to 1s with Zygote.jl!

JAX takes only 1.95 ms, but mostly because its energy calculation takes only 680 µs (compared to 5.1 ms with Lux.jl). If we compute the ratio of energy time to gradient time for JAX

$$\frac{0.68\ \text{ms}}{1.95\ \text{ms}} \approx 0.349$$

it is lower than for Lux.jl + Enzyme.jl

$$\frac{5.11\ \text{ms}}{12.6\ \text{ms}} \approx 0.406$$

So the remaining factor of about 6 comes from the Lux.jl forward pass rather than from Enzyme.jl. Does JAX perhaps use multithreading natively?
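The arithmetic behind these ratios uses only the timings quoted in this post. Inverting the ratios gives the cost of one gradient in units of one energy evaluation: about 2.9 for JAX vs 2.5 for Lux.jl + Enzyme.jl. The last two lines give the absolute gaps between the two stacks.

```python
# Timings quoted in the post, in milliseconds
jax_energy, jax_grad = 0.68, 1.95
lux_energy, lux_grad = 5.11, 12.6   # Lux.jl forward pass, Lux.jl + Enzyme.jl gradient

print(round(jax_energy / jax_grad, 3))    # 0.349
print(round(lux_energy / lux_grad, 3))    # 0.406
print(round(jax_grad / jax_energy, 2))    # 2.87  gradient cost in energy evaluations (JAX)
print(round(lux_grad / lux_energy, 2))    # 2.47  same for Lux.jl + Enzyme.jl
print(round(lux_grad / jax_grad, 1))      # 6.5   overall gradient gap
print(round(lux_energy / jax_energy, 1))  # 7.5   forward-pass gap
```

The relative cost of differentiation is comparable (slightly lower with Enzyme), so most of the remaining gap is already present in the forward pass.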


However, I noticed that it fails when porting everything to the GPU. The energy calculation works, and it takes about 40 µs, but the gradient calculation using Enzyme.jl fails.


To be precise, here are the main changes for GPU support:

all_configurations = all_configurations |> gpu_device()
H = H |> gpu_device()
ps, st = Lux.setup(rng, model) |> gpu_device()

@benchmark compute_energy($model, $ps, $st, $H, $all_configurations)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  41.412 μs …  32.012 ms  ┊ GC (min … max): 0.00% … 31.57%
 Time  (median):     44.698 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   51.313 μs ± 425.632 μs  ┊ GC (mean ± σ):  3.76% ±  0.45%

            ▄▇█▅▁                                               
  ▂▂▃▃▂▂▂▂▄▇██████▇▇▆▇▆▇▆▇▆▆▆▆▆▅▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▂▂▂▂ ▃
  41.4 μs         Histogram: frequency by time         53.8 μs <

 Memory estimate: 7.91 KiB, allocs estimate: 225.

And here is the error for the gradient:

function compute_energy_and_gradient(model, ps, st, H, all_configurations)
    # return Zygote.withgradient(ps -> compute_energy(model, ps, st, H, all_configurations), ps)
    dps = make_zero(ps)

    res = Enzyme.autodiff(ReverseWithPrimal, compute_energy, Active, Const(model), Duplicated(ps, dps), Const(st), Const(H), Const(all_configurations))
    
    res[2], dps
end

compute_energy_and_gradient(model, ps, st, H, all_configurations)
Error output
┌ Error: Found null pointer
│   arg = %active_repl.checked = load atomic {} addrspace(10)*, {} addrspace(10)** inttoptr (i64 138344612374880 to {} addrspace(10)**) unordered, align 8, !dbg !426, !tbaa !427, !alias.scope !418, !noalias !421
└ @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/absint.jl:117

task switch not allowed from inside staged nor pure functions

Stacktrace:
  [1] try_yieldto(undo::typeof(Base.ensure_rescheduled))
    @ Base ./task.jl:921
  [2] wait()
    @ Base ./task.jl:995
  [3] uv_write(s::Base.PipeEndpoint, p::Ptr{UInt8}, n::UInt64)
    @ Base ./stream.jl:1048
  [4] unsafe_write(s::Base.PipeEndpoint, p::Ptr{UInt8}, n::UInt64)
    @ Base ./stream.jl:1120
  [5] unsafe_write
    @ ./io.jl:431 [inlined]
  [6] unsafe_write
    @ ./io.jl:698 [inlined]
  [7] write(s::IJulia.IJuliaStdio{Base.PipeEndpoint}, a::Vector{UInt8})
    @ Base ./io.jl:721
  [8] handle_message(logger::Logging.ConsoleLogger, level::Base.CoreLogging.LogLevel, message::Any, _module::Any, group::Any, id::Any, filepath::Any, line::Any; kwargs...)
    @ Logging ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/Logging/src/ConsoleLogger.jl:178
  [9] #invokelatest#2
    @ ./essentials.jl:894 [inlined]
 [10] invokelatest
    @ ./essentials.jl:889 [inlined]
 [11] macro expansion
    @ ./logging.jl:365 [inlined]
 [12] absint(arg::LLVM.LoadInst, partial::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/absint.jl:117
 [13] abs_typeof(arg::LLVM.LoadInst, partial::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/absint.jl:307
 [14] check_ir!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, errors::Vector{Tuple{String, Vector{Base.StackTraces.StackFrame}, Any}}, imported::Set{String}, inst::LLVM.CallInst, calls::Vector{Any})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler/validation.jl:504
 [15] check_ir!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, errors::Vector{Tuple{String, Vector{Base.StackTraces.StackFrame}, Any}}, imported::Set{String}, f::LLVM.Function)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler/validation.jl:208
 [16] check_ir!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, errors::Vector{Tuple{String, Vector{Base.StackTraces.StackFrame}, Any}}, mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler/validation.jl:178
 [17] check_ir
    @ ~/.julia/packages/Enzyme/srACB/src/compiler/validation.jl:157 [inlined]
 [18] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:4520
 [19] codegen
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:4481 [inlined]
 [20] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5771
 [21] _thunk
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5771 [inlined]
 [22] cached_compilation
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5809 [inlined]
 [23] (::Enzyme.Compiler.var"#560#561"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, NTuple{6, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5875
 [24] JuliaContext(f::Enzyme.Compiler.var"#560#561"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, NTuple{6, Bool}, Int64, Bool, Bool, UInt64, DataType}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
 [25] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
 [26] #s2027#559
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5827 [inlined]
 [27] var"#s2027#559"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, ::Type, ::Type, ::Type, tt::Any, ::Type, ::Type, ::Type, ::Type, ::Type, ::Any)
    @ Enzyme.Compiler ./none:0
 [28] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [29] autodiff
    @ ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:286 [inlined]
 [30] autodiff
    @ ~/.julia/packages/Enzyme/srACB/src/Enzyme.jl:303 [inlined]
 [31] compute_energy_and_gradient(model::CompactLuxLayer{nothing, var"#1#2", @NamedTuple{λ::String}, @NamedTuple{}, Lux.ValueStorage{@NamedTuple{λ::Returns{Vector{Float32}}}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}, ps::@NamedTuple{λ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, st::@NamedTuple{}, H::CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32}, all_configurations::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Main ./In[19]:16
 [32] top-level scope
    @ In[35]:1