Hello,
I’m new to the neural network field, and I’m starting to study Neural Network Quantum States (for example, ground-state search with neural networks). These are usually implemented in Python with JAX, so I asked myself whether I could do better in Julia (spoiler: not yet).
Below you will find the two examples, one in JAX and one in Lux.jl with Zygote. To keep the example minimal, I take a random sparse matrix representing the Hamiltonian, generate all 2^N possible states, and apply a very simple network to this set of states.
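For context, the quantity computed below is just the variational energy of the network wavefunction evaluated on the full basis of 2^N spin configurations (this is my restatement of the setup, with f_θ denoting the network’s log-amplitude):

$$\psi_\theta(\sigma) = e^{f_\theta(\sigma)}, \qquad E(\theta) = \frac{\psi_\theta^\dagger H \, \psi_\theta}{\psi_\theta^\dagger \psi_\theta}$$

(The Julia snippet below benchmarks the unnormalised quadratic form ψ†Hψ; the JAX version normalises ψ first.)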
I start with the Julia case
using Lux
using Zygote
using Random
using LinearAlgebra
using SparseArrays
using BenchmarkTools
# Create an RNG object
rng = MersenneTwister(1234)
# Generate all 2^N configurations of N spins as columns of ±1 values
function generate_combinations(N)
    M = 2^N
    col = 1:M
    row = 1:N
    combinations = 1 .- 2 .* mod.(cld.(col', 2 .^ (N .- row)), 2)
    return Float32.(combinations)
end
# Example usage:
N = 15 # equivalent to 15 spins
all_configurations = generate_combinations(N)
size(all_configurations)
# (15, 32768)
Matrix generation
dim = 2^N
density = 2 / dim
H = sprandn(rng, Float32, dim, dim, density)
H = (H + H') / 2
# 32768×32768 SparseMatrixCSC{Float32, Int64} with 131678 stored entries
Neural Network definition
# Note: λ ends up as Float64 even though rand(rng, Float32, 1) is Float32,
# because multiplying by the Float64 literal 0.01 promotes the result.
model = @compact(λ = rand(rng, Float32, 1) * 0.01) do x
    y = λ .* x
    return sum(logsigmoid(y), dims=1) / 2
end
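As an aside, a minimal sketch of how to keep λ in Float32 from the start (assuming the Float64 literal is the only source of the promotion) is to use a Float32 constant; I keep the original model below for the benchmarks.
# Sketch: the Float32 literal 0.01f0 avoids the promotion, so the manual
# conversion of ps further below would not be needed.
model_f32 = @compact(λ = rand(rng, Float32, 1) * 0.01f0) do x
    y = λ .* x
    return sum(logsigmoid(y), dims=1) / 2
end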
# Evaluate the model on all configurations and exponentiate the log-amplitudes
function to_array(model, ps, st, all_configurations)
    ψ, _ = model(all_configurations, ps, st)
    return vec(exp.(ψ))
end

function compute_energy(model, ps, st, H, all_configurations)
    ψ = to_array(model, ps, st, all_configurations)
    return dot(ψ, H, ψ)
end

function compute_energy_and_gradient(model, ps, st, H, all_configurations)
    return Zygote.withgradient(ps -> compute_energy(model, ps, st, H, all_configurations), ps)
end
ps, st = Lux.setup(rng, model)
# I make the parameter Float32
ps = (λ = Float32.(ps.λ), )
# (λ = Float32[0.005596993],)
Try to compute the energy
compute_energy(model, ps, st, H, all_configurations)
# 0.000719939269632178
And try to benchmark it
@benchmark compute_energy(model, ps, st, H, all_configurations)
BenchmarkTools.Trial: 957 samples with 1 evaluation.
Range (min … max): 4.577 ms … 15.013 ms ┊ GC (min … max): 0.00% … 6.86%
Time (median): 5.038 ms ┊ GC (median): 0.00%
Time (mean ± σ): 5.217 ms ± 715.179 μs ┊ GC (mean ± σ): 1.89% ± 3.71%
▂▆█▄ ▂▂▃
▂▃▃▃▄▆████▇▅▅▄▄▄▄▄▄▆███▇▅▃▃▃▃▂▃▂▂▂▂▁▂▁▂▂▂▃▃▃▄▃▃▂▂▂▂▁▁▁▁▁▁▁▂ ▃
4.58 ms Histogram: frequency by time 6.76 ms <
Memory estimate: 8.25 MiB, allocs estimate: 13.
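To see where those ~5 ms go, a small breakdown I would also run (my own addition; ψ0 is just a fixed wavefunction vector) is to time the two building blocks of compute_energy separately:
# Sketch: time the network forward pass and the sparse mat-vec on their own,
# to see which of the two dominates the energy evaluation.
ψ0 = to_array(model, ps, st, all_configurations)
@benchmark model(all_configurations, ps, st)   # network forward pass only
@benchmark H * ψ0                              # sparse mat-vec only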
Now benchmark the gradient
@benchmark compute_energy_and_gradient(model, ps, st, H, all_configurations)
BenchmarkTools.Trial: 5 samples with 1 evaluation.
Range (min … max): 897.608 ms … 1.083 s ┊ GC (min … max): 0.04% … 15.73%
Time (median): 1.020 s ┊ GC (median): 16.70%
Time (mean ± σ): 1.013 s ± 74.032 ms ┊ GC (mean ± σ): 13.33% ± 7.26%
█ █ █ █ █
█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁█ ▁
898 ms Histogram: frequency by time 1.08 s <
Memory estimate: 4.01 GiB, allocs estimate: 83.
And look at the 4 GiB of allocations!
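Incidentally, a dense 32768×32768 Float32 matrix is exactly 4 GiB, so my suspicion (just a guess at this point) is that the reverse pass of the three-argument dot(ψ, H, ψ) materialises a dense gradient for the constant H. A minimal check, reusing the fixed vector ψ0 from above and differentiating only the quadratic form:
# Sketch: if this alone allocates ~4 GiB, the dense pullback for H is the culprit.
@benchmark Zygote.gradient(ψ -> dot(ψ, H, ψ), ψ0)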
Now let’s try the same with Python, JAX and Flax
%env JAX_PLATFORM_NAME=cpu
import numpy as np
from scipy.sparse import random as sparse_random
import jax
import jax.random as jrandom
import jax.numpy as jnp
from jax.experimental.sparse import BCOO
from functools import partial
# PRNG key for JAX
key = jrandom.PRNGKey(0)
import flax.linen as nn
jax.devices()
# [CpuDevice(id=0)]
def generate_combinations(N):
    M = 2**N
    col = jnp.arange(1, M + 1)
    row = jnp.arange(1, N + 1)

    @jax.jit
    def compute_combinations(col, row):
        return 1 - 2 * jnp.mod(jnp.ceil(col[:, None] / (2 ** (N - row))), 2)

    combinations = compute_combinations(col, row)
    return combinations
# Example usage:
N = 15
all_configurations = generate_combinations(N)
all_configurations.shape
# (32768, 15)
Generation of the random matrix
# Parameters for the sparse matrix
dim = 2**N           # Hilbert space dimension
density = 2 / dim    # about two nonzeros per row on average, as in the Julia case
# Calculate the number of non-zero elements
num_nonzeros = int(dim * dim * density)
# Use independent subkeys; reusing the same key would give identical draws
# for rows, columns and values
key_row, key_col, key_val = jrandom.split(key, 3)
# Generate random indices for the sparse matrix
row_indices = jrandom.randint(key_row, (num_nonzeros,), 0, dim)
col_indices = jrandom.randint(key_col, (num_nonzeros,), 0, dim)
# Generate random values for the sparse matrix
values = jrandom.normal(key_val, (num_nonzeros,))
# Create the JAX sparse matrix (BCOO)
indices = jnp.vstack((row_indices, col_indices))
H = BCOO((values, indices.T), shape=(dim, dim))
H = (H + H.T) / 2
H
# BCOO(float32[32768, 32768], nse=131072)
Neural network definition
class MF(nn.Module):
    @nn.compact
    def __call__(self, x):
        # single trainable parameter λ ("lambda")
        lam = self.param(
            "lambda", nn.initializers.normal(), (1,), x.dtype
        )
        p = nn.log_sigmoid(lam * x)
        return 0.5 * jnp.sum(p, axis=-1)
def to_array(model, parameters, all_configurations):
    # evaluate the model and convert to a normalised wavefunction
    logpsi = model.apply(parameters, all_configurations)
    psi = jnp.exp(logpsi)
    psi = psi / jnp.linalg.norm(psi)
    return psi
# We use partial to jit this function directly; jitting the top-most function
# jits everything inside it as well.
@partial(jax.jit, static_argnames='model')
def compute_energy(model, parameters, H, all_configurations):
    psi_gs = to_array(model, parameters, all_configurations)
    return psi_gs.conj().T @ H @ psi_gs

@partial(jax.jit, static_argnames='model')
def compute_energy_and_gradient(model, parameters, H, all_configurations):
    grad_fun = jax.value_and_grad(compute_energy, argnums=1)
    return grad_fun(model, parameters, H, all_configurations)
# create an instance of the model
model = MF()
# initialise the weights
parameters = model.init(key, np.random.rand(N))
parameters
# {'params': {'lambda': Array([-0.01280743], dtype=float32)}}
Benchmark the energy computation
%timeit compute_energy(model, parameters, H, all_configurations)
# 822 µs ± 71.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Benchmark gradient
%timeit compute_energy_and_gradient(model, parameters, H, all_configurations)
# 2.29 ms ± 288 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
So the gradient is slower by roughly a factor of 450 (about 1 s vs 2.3 ms), and even the plain energy evaluation is about 6× slower (about 5 ms vs 0.8 ms). I don’t know if I am doing something wrong; I certainly didn’t expect those 4 GiB of allocations from Zygote.
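One more thing I should probably do on the Julia side (it will not explain a factor of ~450, but it is the recommended way to benchmark with non-const globals): interpolate the variables into @benchmark so that global-variable access is excluded from the measurement.
# Sketch: interpolate the (non-const) globals into the benchmark expressions.
@benchmark compute_energy($model, $ps, $st, $H, $all_configurations)
@benchmark compute_energy_and_gradient($model, $ps, $st, $H, $all_configurations)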