Flux benchmark much slower than JAX

Hello everyone. I wanted to benchmark Flux against JAX, so I created this basic example. The TL;DR is that the Flux version (posted below) takes about 40-50 seconds on an NVIDIA A6000 (compilation excluded). The exact same architecture/example in JAX takes about 6 seconds. How can I find the bottleneck? Am I doing something wrong?

```julia
using Zygote,Flux,Plots, CUDA
using Statistics: mean
using Flux: @functor

x = collect(LinRange(-0,1,1024))

x = collect(x')
y = cos.(x*2*pi).^10

# dump to gpu
x = cu(x)
y = cu(y)

acti  = relu
latent= 128
epochs = 30000

model = Chain(
    Dense(1,latent,acti),
    Dense(latent,latent,acti),
    LayerNorm(latent),
    Dense(latent,1)
)

# %% testing inference
model = gpu(model) # gpu() and |> gpu are the same
display(model)
dump = model(x)

loss(x, y) = Flux.Losses.mse(model(x), y) |> gpu

optimizer = ADAM(0.001)
params = Flux.params(model)

@time for ii in 1:epochs
    grads = Flux.gradient(()->loss(x,y),params)
    Flux.Optimise.update!(optimizer,params,grads)
    if ii % 5000 === 0
        println("epoch: ",ii," |loss: ",loss(x,y))
    end
    #Flux.train!(loss,params,data,optimizer)
end
```

I won't be posting the JAX code here, since I doubt anyone would be interested, but if anyone asks, I will provide it. It is an exact replica of the example above: same architecture, number of points, floating-point representation, etc.

PS: I am not an expert Julia user, so I might have done something stupid. Also, this is my first Discourse post, so forgive me if I did something wrong.

PS2: Here is the aforementioned JAX code:

```python
import jax.numpy as jnp
from jax import jit,vmap,grad
import jax
import flax
import optax
from flax.training.train_state import TrainState
import flax.linen as nn
from typing import Sequence

x = jnp.linspace(-1,1,1024).reshape(-1,1)
y = jnp.sin(x)
lista = [128,128,1]

class MLP(nn.Module):
  features: Sequence[int]
  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.Dense(feat,kernel_init=nn.initializers.glorot_uniform())(x)
      x = nn.relu(x)
    x = nn.LayerNorm()(x)
    x = nn.Dense(self.features[-1],kernel_init=nn.initializers.glorot_uniform())(x)
    return x

model = MLP(lista)

params = model.init(jax.random.PRNGKey(0),x)
tx     = optax.adam(1e-3)
state  = TrainState.create(apply_fn=model.apply,params=params,tx=tx)

def loss_mse(params,X,Y):
  Y2 = model.apply(params,X)
  loss = jnp.mean((Y-Y2)**2)
  return loss

@jax.jit
def apply_model(state,X,Y):
  grad_fn = jax.value_and_grad(loss_mse, argnums=(0))
  loss, grads = grad_fn(state.params,X,Y)
  state = state.apply_gradients(grads=grads)
  return  loss,state

#begin Training
epochs = 30000
from time import time
# Creating X,Y for training of thdf and w
tt = time()
for ii in range(epochs):
  loss,state = apply_model(state,x,y)
  if ii%5000==0:
    print("epoch:",ii,'|| Loss:',loss)

print('Training took ',time()-tt,'seconds')
```

If you put your code in three back-ticks ``` then it will display nicely & won’t have smart quotes, like epoch: “,ii,”.

My first thought is that the problem seems pretty small for a GPU. This is a bit faster on my laptop (once I make x,y Float32) than on a V100. About 1/3 of the time is LayerNorm. It’s a little faster without implicit mode:

```julia
opt_state = Flux.setup(Adam(0.001), model)
@time for ii in 1:epochs
    grads = Flux.gradient(m -> Flux.Losses.mse(m(x), y), model)
    Flux.update!(opt_state, model, grads[1])
end
```

Posting the Jax code may let others see something in the comparison.


Thank you very much for the reply, I really appreciate it. I know the problem is small for an A6000, but the issue is that I am testing the same code on the same hardware and getting completely different performance. Flux running at, let's say, 50% of JAX would be understandable, but a factor of 9.5x is very weird, and it implies some (stupid and probably on my end) bottleneck.

Yes. I don't see anything obviously slow, but I may also be missing something.

It is possible that a small test is very sensitive to e.g. the GC time to clean up many small arrays, or some difference in whether it waits for some kernel to finish before launching the next. Seeing 10x in something like that would be less surprising than 10x in actual computation. It can't be spending much time computing if my laptop beats it. A total of 6s for 30k steps is 200μs per gradient, which a few ms of overhead would easily swamp.
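A quick sketch of that arithmetic, using only the totals reported in this thread (30,000 steps, about 6 s for JAX, 40-50 s for Flux). The helper name `per_step_us` is made up for illustration, and the numbers are averages, not profiler output:

```python
# Average per-step time implied by the totals reported in this thread.
STEPS = 30_000

def per_step_us(total_seconds, n_steps=STEPS):
    """Average wall-clock time per training step, in microseconds."""
    return total_seconds / n_steps * 1e6

jax_us = per_step_us(6.0)       # JAX: about 6 s in total
flux_lo_us = per_step_us(40.0)  # Flux: lower end of the reported range
flux_hi_us = per_step_us(50.0)  # Flux: upper end of the reported range

print(f"JAX : {jax_us:.0f} us per step")
print(f"Flux: {flux_lo_us:.0f} to {flux_hi_us:.0f} us per step")
print(f"gap : {flux_lo_us - jax_us:.0f} to {flux_hi_us - jax_us:.0f} us per step")
```

So the whole gap is on the order of 1 ms per step, which is the scale at which allocation, GC, or launch overhead could plausibly dominate.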


Excellent point! Also the code you mentioned does not work. Flux.setup is not defined for me.

I would say having the JAX code is just as important as having the Julia MWE above, because we need to know e.g. whether jax.jit is being used. The answer for where and what can be optimized on the Julia side will vary quite a bit depending on the specifics of the JAX implementation.


Updated the original post with the JAX code.


Thanks! Using the technique in the JAX pull request "ahead-of-time lowering and compilation for jit" (by froystig, Pull Request #7997, google/jax), I had a look at the compiled output of apply_model. It looks like JAX/XLA is fusing a number of the layernorm operations together. In practice, this should reduce the number of allocations required and thus allocation-related overhead.

The bad news is that, lacking something like XLA, we can't easily do similar automatic optimizations for Flux models. The silver lining is that, as @mcabbott noted, this overhead should be amortized for larger input and model sizes. If you are working at small scale most of the time, SimpleChains.jl (PumasAI/SimpleChains.jl on GitHub) may be worth a look.
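For anyone who wants to look at compiled output themselves, here is a minimal sketch of the ahead-of-time route, assuming a reasonably recent JAX. The `layer_norm` function is a hand-written stand-in (no learned scale or bias), not the Flax `nn.LayerNorm` used in the original post, and the exact HLO text will differ by backend and JAX version:

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    # Illustrative stand-in for a layer norm over the last axis;
    # not the Flax implementation from the original post.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

x = jnp.ones((1024, 128), dtype=jnp.float32)

lowered = jax.jit(layer_norm).lower(x)  # trace and lower, without executing
compiled = lowered.compile()            # run the XLA compiler
hlo_text = compiled.as_text()           # optimized HLO as plain text

# Count lines that mention a fusion; what shows up depends on the backend.
fusion_lines = [line for line in hlo_text.splitlines() if "fusion" in line]
print(len(fusion_lines), "lines mention a fusion")
```

Printing `hlo_text` in full shows which elementwise and reduction steps XLA grouped together for that backend.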


@ToucheSir this is an old issue, but are there any plans for such optimizations in the future? I am facing similar issues with my Flux models vs PyTorch 2.0 (which somehow makes the models run at 2x speed per training iteration). I wanted to check whether that is because Torch is doing similar optimizations.


Nothing has changed in the past couple of months. We lack the ability to do the kinds of automatic array-level optimizations JAX and PyTorch 2.0 can. That said, I find most Flux models posted here have low-hanging fruit to optimize, so if you have a specific example feel free to start a new topic for it.


Thanks for the response. Another related question: if my forward pass is about the same order of speed as JAX/Torch, should I see similar performance on the backward pass too? I know the question is a bit abstract, but I am trying to get a feel for what other optimizations the Julia ecosystem is missing. Thanks!


No, that’s not something you can rely on. Performance of the backwards pass is so library and context specific that I’m not sure it’s possible to make such a general statement.
