Hello everyone. I wanted to benchmark Flux vs JAX, so I created this basic example. The TL;DR is that the Flux version (posted below) takes about 40-50 seconds on an NVIDIA A6000 (compilation excluded). The exact same architecture/example in JAX takes about 6 seconds. How can I find the bottleneck? Am I doing something wrong?
using Zygote, Flux, Plots, CUDA
using Statistics: mean
using Flux: @functor
x = collect(LinRange(0, 1, 1024))
x = collect(x')   # reshape to a 1×1024 row: Flux layers expect features × batch
y = cos.(x * 2 * pi) .^ 10
# move the data to the GPU (cu converts to Float32)
x = cu(x)
y = cu(y)
acti = relu
latent = 128
epochs = 30000
model = Chain(
    Dense(1, latent, acti),
    Dense(latent, latent, acti),
    LayerNorm(latent),
    Dense(latent, 1),
)
# %% testing inference
model = gpu(model)   # gpu(model) and model |> gpu are equivalent
display(model)
out = model(x)       # warm-up forward pass
loss(x, y) = Flux.Losses.mse(model(x), y)   # mse already returns a scalar; no need to pipe it to the GPU
optimizer = ADAM(0.001)
params = Flux.params(model)
@time for ii in 1:epochs
    grads = Flux.gradient(() -> loss(x, y), params)
    Flux.Optimise.update!(optimizer, params, grads)
    if ii % 5000 == 0
        println("epoch: ", ii, " | loss: ", loss(x, y))
    end
    # Flux.train!(loss, params, data, optimizer)
end
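To help narrow down where the time goes, one option is to time a single forward pass and a single gradient separately (a minimal sketch, assuming BenchmarkTools.jl is installed; CUDA.@sync is needed because GPU kernels launch asynchronously, so an unsynchronized timing mostly measures launch overhead):

using BenchmarkTools
# time the forward pass on its own
@btime CUDA.@sync model($x)
# time a full forward + backward pass through Zygote
@btime CUDA.@sync Flux.gradient(() -> loss($x, $y), $params)

If the gradient call dominates, the cost is in Zygote's backward pass rather than in the CUDA kernels themselves.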
I won't be posting the JAX code here, since I doubt anyone would be interested, but if anyone asks, I will provide it. It is an exact replica of the example above: same architecture, number of points, floating-point representation, etc.
PS: I am not an expert Julia user, so I might have done something stupid. Also, this is my first Discourse post, so forgive me if I did anything wrong.
PS2: This is the aforementioned JAX code:
import jax.numpy as jnp
from jax import jit, vmap, grad
import jax
import flax
import optax
from flax.training.train_state import TrainState
import flax.linen as nn
from typing import Sequence

x = jnp.linspace(-1, 1, 1024).reshape(-1, 1)
y = jnp.sin(x)

lista = [128, 128, 1]
class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        # hidden layers: Dense + relu, matching the Flux Chain above
        for feat in self.features[:-1]:
            x = nn.Dense(feat, kernel_init=nn.initializers.glorot_uniform())(x)
            x = nn.relu(x)
        # single LayerNorm before the output layer, as in the Flux model
        x = nn.LayerNorm()(x)
        x = nn.Dense(self.features[-1], kernel_init=nn.initializers.glorot_uniform())(x)
        return x
model = MLP(lista)
params = model.init(jax.random.PRNGKey(0), x)
tx = optax.adam(1e-3)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

def loss_mse(params, X, Y):
    Y2 = model.apply(params, X)
    loss = jnp.mean((Y - Y2) ** 2)
    return loss

@jax.jit
def apply_model(state, X, Y):
    grad_fn = jax.value_and_grad(loss_mse, argnums=0)
    loss, grads = grad_fn(state.params, X, Y)
    state = state.apply_gradients(grads=grads)
    return loss, state
# begin training
epochs = 30000
from time import time

tt = time()
for ii in range(epochs):
    loss, state = apply_model(state, x, y)
    if ii % 5000 == 0:
        print("epoch:", ii, "|| Loss:", loss)
loss.block_until_ready()  # JAX dispatches asynchronously; wait for the last step before stopping the clock
print("Training took", time() - tt, "seconds")
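PS3: Since some of the overhead might come from the implicit Flux.params bookkeeping, here is the same training loop rewritten against Flux's explicit-gradient API (a minimal sketch, assuming Flux 0.13 or newer where Flux.setup and Flux.update! are available; the name opt_state is mine, and I have not benchmarked this variant):

opt_state = Flux.setup(Adam(0.001), model)
for ii in 1:epochs
    # differentiate with respect to the model itself instead of a Params set
    grads = Flux.gradient(m -> Flux.Losses.mse(m(x), y), model)[1]
    Flux.update!(opt_state, model, grads)
end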