Hello everyone. I wanted to benchmark Flux vs JAX, so I created this basic example. The TL;DR is that the Flux version (posted below) takes about 40-50 seconds on an NVIDIA A6000 (compilation excluded). The exact same architecture/example in JAX takes about 6 seconds. How can I find the bottleneck? Am I doing something wrong?
```julia
using Zygote, Flux, Plots, CUDA
using Statistics: mean
using Flux: @functor

x = collect(LinRange(-0, 1, 1024))
x = collect(x')
y = cos.(x * 2 * pi) .^ 10

# dump to gpu
x = cu(x)
y = cu(y)

acti = relu
latent = 128
epochs = 30000

model = Chain(
    Dense(1, latent, acti),
    Dense(latent, latent, acti),
    LayerNorm(latent),
    Dense(latent, 1)
)

# %% testing inference
model = gpu(model)  # gpu() and |> gpu are the same
display(model)
dump = model(x)

loss(x, y) = Flux.Losses.mse(model(x), y) |> gpu
optimizer = ADAM(0.001)
params = Flux.params(model)

@time for ii in 1:epochs
    grads = Flux.gradient(() -> loss(x, y), params)
    Flux.Optimise.update!(optimizer, params, grads)
    if ii % 5000 === 0
        println("epoch: ", ii, " |loss: ", loss(x, y))
    end
    # Flux.train!(loss, params, data, optimizer)
end
```
I won't be posting the JAX code here, since I doubt anyone would be interested, but if anyone asks, I will provide it. It is an exact replica of the example above: same architecture, number of points, floating-point representation, etc.
PS: I am not an expert Julia user, so I might have done something stupid. Also, this is my first Discourse post, so if I did something wrong, forgive me.
PS2: This is the aforementioned JAX code:
```python
import jax.numpy as jnp
from jax import jit, vmap, grad
import jax
import flax
import optax
from flax.training.train_state import TrainState
import flax.linen as nn
from typing import Sequence

x = jnp.linspace(-1, 1, 1024).reshape(-1, 1)
y = jnp.sin(x)

lista = [128, 128, 1]

class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.Dense(feat, kernel_init=nn.initializers.glorot_uniform())(x)
            x = nn.relu(x)
        x = nn.LayerNorm()(x)
        x = nn.Dense(self.features[-1], kernel_init=nn.initializers.glorot_uniform())(x)
        return x

model = MLP(lista)
params = model.init(jax.random.PRNGKey(0), x)
tx = optax.adam(1e-3)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

def loss_mse(params, X, Y):
    Y2 = model.apply(params, X)
    loss = jnp.mean((Y - Y2) ** 2)
    return loss

@jax.jit
def apply_model(state, X, Y):
    grad_fn = jax.value_and_grad(loss_mse, argnums=0)
    loss, grads = grad_fn(state.params, X, Y)
    state = state.apply_gradients(grads=grads)
    return loss, state

# begin training
epochs = 30000
from time import time

tt = time()
for ii in range(epochs):
    loss, state = apply_model(state, x, y)
    if ii % 5000 == 0:
        print("epoch:", ii, "|| Loss:", loss)
print("Training took", time() - tt, "seconds")
```
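One caveat about my own timing above: JAX dispatches GPU work asynchronously, so `time()` can stop before the queued work has actually finished. To be safe, one can block on the last result before reading the clock. A minimal sketch of a synchronized timing helper (the `step_fn` argument stands in for a jitted step like `apply_model` above; the helper itself is not part of the original benchmark):

```python
import time
import jax.numpy as jnp

def timed_steps(step_fn, state, x, y, n):
    """Run n training steps, then block until the GPU work completes
    before the timer stops, so the elapsed time is not understated."""
    t0 = time.time()
    loss = None
    for _ in range(n):
        loss, state = step_fn(state, x, y)
    loss.block_until_ready()  # wait for asynchronously dispatched work
    return time.time() - t0, loss, state
```

In practice the loss print every 5000 iterations already forces a device sync, so the 6-second figure should be in the right ballpark, but `block_until_ready()` makes the comparison explicit.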