Julia (AcceleratedKernels) vs JAX time comparison

Hi everyone!
Following this topic, I’d like to understand how to write efficient CPU/GPU-agnostic code using AcceleratedKernels.jl.

My problem is writing the joint PDF of two random variables, one of which is conditioned on the other. I need to implement different PDF models for the individual random variables, and I’d like to use Julia’s multiple dispatch to handle these different models.
Here’s the example Julia code I produced:

using CUDA
import AcceleratedKernels as AK

abstract type Params end

struct Data
  x1::Float64
  x2::Float64
end
const DataMatVec{Data} = Union{AbstractVector{Data},AbstractMatrix{Data}}
const BLOCK_SIZE = 256

# Interface for primary and secondary pdf: these have to be implemented case by case
function pdf_x1 end
function cdf_x1 end
function pdf_x2_given_x1 end
function cdf_x2_given_x1 end

# Joint distribution: this holds in general
function pdf_x1x2(data::DataMatVec, params::Params)
  res = similar(data, Float64)
  norm_p1 = cdf_x1(params)
  AK.map!(x -> pdf_x1(x, params) * inv(norm_p1) * pdf_x2_given_x1(x, params) * inv(cdf_x2_given_x1(x, params)),
    res,
    data;
    block_size=BLOCK_SIZE)
  return res
end

# Core functions 
@inline @fastmath function tpl(x::Float64, alpha::Float64, x_min::Float64, x_max::Float64)::Float64
  if x_min < x < x_max
    return x^alpha
  else
    return 0.0
  end
end

@inline @fastmath function tpl_cdf(x::Float64, alpha::Float64, x_min::Float64)::Float64
  if alpha == -1
    return log(x_min) - log(x)
  else
    return (x^(1 + alpha) - x_min^(1 + alpha)) / (1 + alpha)
  end
end

# Specific Implementation 

struct TPLParams <: Params
  x_low::Float64
  x_high::Float64
  alpha::Float64
  beta::Float64
end

pdf_x1(data::Data, params::TPLParams) = tpl(data.x1, -params.alpha, params.x_low, params.x_high)
cdf_x1(params::TPLParams) = tpl_cdf(params.x_high, -params.alpha, params.x_low)
pdf_x2_given_x1(data::Data, params::TPLParams) = tpl(data.x2, params.beta, params.x_low, data.x1)
cdf_x2_given_x1(data::Data, params::TPLParams) = tpl_cdf(data.x1, params.beta, params.x_low)

### for the other models it is sufficient to implement the last four functions and pdf_x1x2 will work
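For instance, a made-up second model only needs the same four methods (ExpParams, expk and expk_cdf below are illustrative names, not part of the actual problem):

struct ExpParams <: Params
  x_low::Float64
  x_high::Float64
  lambda::Float64   # rate for x1
  mu::Float64       # rate for x2 given x1
end

# truncated exponential kernel and its (unnormalised) CDF over (x_min, x)
@inline expk(x, rate, x_min, x_max) = (x_min < x < x_max) ? exp(-rate * x) : 0.0
@inline expk_cdf(x, rate, x_min) = (exp(-rate * x_min) - exp(-rate * x)) / rate

pdf_x1(data::Data, params::ExpParams) = expk(data.x1, params.lambda, params.x_low, params.x_high)
cdf_x1(params::ExpParams) = expk_cdf(params.x_high, params.lambda, params.x_low)
pdf_x2_given_x1(data::Data, params::ExpParams) = expk(data.x2, params.mu, params.x_low, data.x1)
cdf_x2_given_x1(data::Data, params::ExpParams) = expk_cdf(data.x1, params.mu, params.x_low)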

The benchmarks on CPU/GPU for vectors of data are:

using BenchmarkTools

params = TPLParams(0.1,5.,2.,1.5)

x1_vec = range(2,4,5000) 
x2_vec = range(1,3,5000) 
x1_vec_gpu = CuArray(x1_vec)
x2_vec_gpu = CuArray(x2_vec)

data_vec = Data.(x1_vec, x2_vec)
data_vec_gpu = Data.(x1_vec_gpu, x2_vec_gpu)

#warmup
pdf_x1x2(data_vec, params)
pdf_x1x2(data_vec_gpu, params)
#benchmark
@btime pdf_x1x2($data_vec, $params) # 143.316 μs (89 allocations: 49.07 KiB)
@btime @CUDA.sync pdf_x1x2($data_vec_gpu, $params) # 127.649 μs (82 allocations: 3.00 KiB)

while for matrices I get:

x1_mat = repeat(x1_vec', 300, 1) 
x2_mat = repeat(x2_vec', 300, 1) 
x1_mat_gpu = CuArray(x1_mat)
x2_mat_gpu = CuArray(x2_mat)

data_mat = Data.(x1_mat, x2_mat)
data_mat_gpu = Data.(x1_mat_gpu, x2_mat_gpu)

#warmup
pdf_x1x2(data_mat, params)
pdf_x1x2(data_mat_gpu, params)
#benchmark
@btime pdf_x1x2($data_mat, $params) # 31.754 ms (89 allocations: 11.45 MiB)
@btime @CUDA.sync pdf_x1x2($data_mat_gpu, $params) # 30.811 ms (81 allocations: 3.02 KiB)

The equivalent JAX code, which uses flax to write JAX-compatible structs and plum-dispatch to emulate multiple dispatch, is:

import jax
jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu') # or 'gpu'
import jax.numpy as jnp
from flax import struct  
from plum import dispatch

# structs
@struct.dataclass
class TPLParams:
  x_low: float
  x_high: float
  alpha: float
  beta: float

@struct.dataclass
class Data:
  x1: jnp.ndarray
  x2: jnp.ndarray

#Core functions
@jax.jit
def tpl(x, alpha, x_low, x_high):
  return jnp.where((x_low < x) & (x < x_high),
    x**alpha,
    0.)

@jax.jit
def tpl_cdf(x, alpha, x_low):
  return jnp.where(params.alpha==-1,
    jnp.log(params.x_low) - jnp.log(x),
    (x**(1 + params.alpha) - params.x_low**(1 + params.alpha)) / (1 + params.alpha)
  )

# Function implementation
@dispatch
@jax.jit
def pdf_x1(data, params:TPLParams):
  return tpl(data.x1, -params.alpha, params.x_low, params.x_high)

@dispatch
@jax.jit
def cdf_x1(params:TPLParams):
  return tpl_cdf(params.x_high, -params.alpha, params.x_low)

@dispatch
@jax.jit
def pdf_x2_given_x1(data, params:TPLParams):
  return  tpl(data.x2, params.beta, params.x_low, data.x1)

@dispatch
@jax.jit
def cdf_x2_given_x1(data, params:TPLParams):
  return tpl_cdf(data.x1, params.beta, params.x_low)

@dispatch
@jax.jit
def pdf_x1x2(data, params):
  norm_p1 = cdf_x1(params)
  p1 = pdf_x1(data, params)/norm_p1
  p2 = pdf_x2_given_x1(data, params)/cdf_x2_given_x1(data, params)
  return p1*p2

and the timings are better than in the AK case, especially on GPU:

#test
params = TPLParams(x_low=0.1, x_high=5., alpha=2., beta=1.)

x1_vec = jnp.linspace(2,4,5000)
x2_vec = jnp.linspace(1,3,5000)
x1_mat = jnp.array([x1_vec for _ in range(300)])
x2_mat = jnp.array([x2_vec for _ in range(300)])

data_vec = Data(x1_vec, x2_vec)
data_mat = Data(x1_mat, x2_mat)

# compile
pdf_x1x2(data_vec, params)
pdf_x1x2(data_mat, params)

#timing 
%timeit pdf_x1x2(data_vec, params).block_until_ready()
# cpu: 115 μs ± 4.66 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# gpu: 112 μs ± 2.02 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit pdf_x1x2(data_mat, params).block_until_ready()
# cpu 19.1 ms ± 4.01 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
# gpu 8.14 ms ± 38.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

I’m fairly certain that it’s possible to achieve the same or better timing in Julia as in JAX. Do you know where the problems in my Julia code are?
Thanks!


Are you using the same precision? JAX might default to 32-bit while Julia is doing 64-bit.

Yes, this line jax.config.update("jax_enable_x64", True) forces JAX to use 64-bit.


Oh, thanks! That option couldn’t be less clear.


And it’s just one of the reasons to switch to Julia! I just have to figure out how to make Julia equally fast and my whole code will look much cleaner and clearer.


Out of curiosity, could you share what GPU and CPU you’re using and other versioninfo?

I find it kinda suspicious that the CPU and GPU timings are so close for you, especially given that you’re doing Float64 math. I guess maybe all the time even in the GPU case is being taken up by the CPU side of things for you?
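A quick way to check how much of the GPU run is kernel time versus host-side overhead is CUDA.jl’s integrated profiler, e.g. (assuming pdf_x1x2, data_vec_gpu and params from the original post are in scope):

using CUDA
# prints a table of host/device time per API call and per kernel launch
CUDA.@profile pdf_x1x2(data_vec_gpu, params)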

For what it’s worth, here’s what I see on a Ryzen 5600x and a RTX 3070:

Bench CPU Vector:    63.711 μs (39 allocations: 42.98 KiB)
Bench GPU Vector:    42.251 μs (82 allocations: 3.00 KiB)
Bench CPU Matrix:    17.415 ms (39 allocations: 11.45 MiB)
Bench GPU Matrix:    3.653 ms (81 allocations: 3.02 KiB)

I don’t really want to go through the pain-in-the-butt of installing Jax, Flax, and Plum on my machine so I can’t give you comparison numbers for that software.


Here’s my versioninfo:

julia> versioninfo()
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 12 × AMD Ryzen 5 5600X 6-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 6 default, 0 interactive, 3 GC (on 12 virtual cores)
Environment:
  JULIA_NUM_THREADS = 6

and my CUDA versioninfo:

julia> CUDA.versioninfo()
CUDA runtime 12.9, artifact installation
CUDA driver 12.8
NVIDIA driver 570.153.2

CUDA libraries: 
- CUBLAS: 12.9.0
- CURAND: 10.3.10
- CUFFT: 11.4.0
- CUSOLVER: 11.7.4
- CUSPARSE: 12.5.9
- CUPTI: 2025.2.0 (API 27.0.0)
- NVML: 12.0.0+570.153.2

Julia packages: 
- CUDA: 5.8.2
- CUDA_Driver_jll: 0.13.0+0
- CUDA_Runtime_jll: 0.17.0+0

Toolchain:
- Julia: 1.11.5
- LLVM: 16.0.6

1 device:
  0: NVIDIA GeForce RTX 3070 (sm_86, 3.251 GiB / 8.000 GiB available)

Thanks for the reply and for trying my code! The fact that the code is actually faster on your GPU is a good sign (in the end I will run my code on a cluster of A100 GPUs).

Here’s my versioninfo:

Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 16 × 12th Gen Intel(R) Core(TM) i5-1240P
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_NUM_THREADS = 16

and my CUDA.versioninfo()

CUDA runtime 12.9, artifact installation
CUDA driver 12.9
NVIDIA driver 535.230.2

CUDA libraries: 
- CUBLAS: 12.9.0
- CURAND: 10.3.10
- CUFFT: 11.4.0
- CUSOLVER: 11.7.4
- CUSPARSE: 12.5.9
- CUPTI: 2025.2.0 (API 27.0.0)
- NVML: 12.0.0+535.230.2

Julia packages: 
- CUDA: 5.8.2
- CUDA_Driver_jll: 0.13.0+0
- CUDA_Runtime_jll: 0.17.0+0

Toolchain:
- Julia: 1.11.5
- LLVM: 16.0.6

1 device:
  0: NVIDIA T550 Laptop GPU (sm_75, 3.370 GiB / 4.000 GiB available)

Maybe my GPU is just not that powerful. Or do you see something else that is incorrect?


This is so weird, here are my times:
Julia:

vec cpu
  34.700 μs (109 allocations: 53.38 KiB)
vec gpu
  37.400 μs (82 allocations: 3.00 KiB)
mat cpu
  7.323 ms (109 allocations: 11.46 MiB)
mat gpu
  5.451 ms (81 allocations: 3.02 KiB)

JAX (only CPU for me):

cpuvec  0.1568954000249505 ms
cpumat  9.355242500081658 ms

By the way, how can I test on GPU with JAX? I get:

Traceback (most recent call last):
  File "C:\Users\yolha\Desktop\juju_tests\Tests\test.py", line 67, in <module>
    x1_vec = jnp.linspace(2,4,5000)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\yolha\AppData\Local\Programs\Python\Python312\Lib\site-packages\jax\_src\numpy\array_creation.py", line 504, in linspace
    return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: cpu      
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Thanks for this benchmark!

I tried Google Colab’s CPU (2 threads) and a Tesla T4 GPU. Here’s a summary of the benchmarks:

Julia:

CPU vec  385.728 μs (14 allocations: 40.44 KiB)
GPU vec  60.261 μs (130 allocations: 5.69 KiB)

CPU mat 118.224 ms (16 allocations: 11.45 MiB)
GPU mat  10.197 ms (87 allocations: 3.33 KiB)

JAX

CPU vec  259 µs ± 8.71 µs
GPU vec  544 μs ± 145 μs

CPU mat  42.6 ms ± 2.32 ms
GPU mat  3.27 ms ± 21.6 µs

While the timings are similar in the vector case, in the matrix case JAX is 3-4 times faster than Julia. I tried flattening the matrix and reshaping the result, but the timing does not change.
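By flattening I mean roughly this (a sketch reusing pdf_x1x2 and data_mat_gpu from my first post):

# run the kernel over a vector view of the matrix, then reshape the result back
data_flat_gpu = vec(data_mat_gpu)   # no copy, just a reshaped view
res_mat = reshape(pdf_x1x2(data_flat_gpu, params), size(data_mat_gpu))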

Tried Colab:
Julia:
cpuvec 0.379528 ms
gpuvec 0.058308 ms
cpumat 118.753 ms
gpumat 5.136 ms
JAX:
cpuvec 0.6700139019999938 ms
gpuvec 0.5178199249999977 ms
cpumat 67.233181347 ms
gpumat 3.4971002119999923 ms

Can we fix the block size on the JAX side? It seems like this is the only thing that would explain being so much worse for the vector and better for the matrix. Also, JAX may generate a 2D kernel for the matrix and a 1D one for the vector, while we only generate 1D kernels in Julia (in AK at least).

The CPU matrix case can be explained by SIMD optimizations happening in JAX but not in Julia.
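One way to probe the SIMD side would be a plain single-threaded loop over the same scalar kernel and a look at the generated code (a sketch reusing Data, pdf_x1, etc. from the first post; the branches and the x^alpha pow call probably limit vectorization here):

function pdf_x1x2_serial!(res, data, params)
  norm_p1 = cdf_x1(params)
  inorm_p1 = inv(norm_p1)
  @inbounds for i in eachindex(res, data)
    x = data[i]
    res[i] = pdf_x1(x, params) * inorm_p1 * pdf_x2_given_x1(x, params) * inv(cdf_x2_given_x1(x, params))
  end
  return res
end

# e.g. inspect whether the loop vectorizes:
# @code_llvm debuginfo=:none pdf_x1x2_serial!(similar(data_mat, Float64), data_mat, params)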


Can we fix the block size on the JAX side?

Nope, as far as I know. JAX on GPU is black-box magic to me.
I tried different block sizes in Julia but the results are stable.
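The kind of sweep I mean looks roughly like this (a sketch reusing the definitions from my first post):

using BenchmarkTools, CUDA
import AcceleratedKernels as AK

res_gpu = similar(data_mat_gpu, Float64)
norm_p1 = cdf_x1(params)
for bs in (64, 128, 256, 512, 1024)
  t = @belapsed CUDA.@sync AK.map!(
    x -> pdf_x1(x, $params) * inv($norm_p1) * pdf_x2_given_x1(x, $params) * inv(cdf_x2_given_x1(x, $params)),
    $res_gpu, $data_mat_gpu; block_size=$bs)
  println("block_size = $bs: ", round(1e3 * t; digits=3), " ms")
end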

Also, JAX may generate a 2D kernel for the matrix and a 1D one for the vector, while we only generate 1D kernels in Julia (in AK at least)

To test this I ran this simple experiment with JAX on GPU:

params = TPLParams(x_low=0.1, x_high=5., alpha=2., beta=1.)

x1_vec = jnp.linspace(2,4,5000)
x2_vec = jnp.linspace(1,3,5000)
x1_mat = jnp.array([x1_vec for _ in range(300)])
x2_mat = jnp.array([x2_vec for _ in range(300)])
data_vec = Data(x1_vec, x2_vec)
data_mat = Data(x1_mat, x2_mat)
data_mat_flatten = Data(x1_mat.flatten(), x2_mat.flatten())

# compile
pdf_x1x2(data_vec, params)
pdf_x1x2(data_mat, params)
pdf_x1x2(data_mat_flatten, params)

%timeit pdf_x1x2(data_mat_flatten, params).block_until_ready()
# 3.49 ms ± 201 µs 

%timeit pdf_x1x2(data_mat, params).block_until_ready()
# 3.28 ms ± 41.6 µs --> faster than Julia!

The timings in the 2D-matrix and 1D-flattened-matrix cases are pretty similar, so I don’t think this is the problem.

Any other ideas?

using CUDA, KernelAbstractions
import AcceleratedKernels as AK

# manual 2D CUDA kernel over the Data matrix (reuses Data, pdf_x1, etc. from above)
function specase_cuda(res,data,inorm_p1,params)
  bid = blockIdx()
  bdim = blockDim()
  tid = threadIdx()
  I = (bid.x-1i32) * bdim.x + tid.x
  J = (bid.y-1i32) * bdim.y + tid.y
  s1,s2 = size(data)
  if 1<=I<=s1 && 1<=J<=s2
    @inbounds x = data[I,J]
    @inbounds res[I,J] = pdf_x1(x,params)*inorm_p1*pdf_x2_given_x1(x,params)*inv(cdf_x2_given_x1(x,params))
  end
  return
end
function pdf_x1x2(data::AbstractMatrix, params::Params)
  res = similar(data, Float64)
  norm_p1 = cdf_x1(params)
  inorm_p1 = 1 / norm_p1
  bck = get_backend(data)
  if get_backend(data) isa GPU
    #specase!(bck,(BLOCK_SIZE,))(res,data,inorm_p1,params;ndrange=length(data))
    nb = cld.(size(data),(16,16))
    @cuda blocks=nb threads=(16,16) specase_cuda(res,data,inorm_p1,params)
  else
    AK.map!(x -> pdf_x1(x, params) * inorm_p1 * pdf_x2_given_x1(x, params) * inv(cdf_x2_given_x1(x, params)),
      res,
      data
      )
  end
  return res
end

This doesn’t help a lot actually, 5 ms at best, but there are a lot of block sizes to test since your matrix is (300, 5000), so I might have missed some. Also, in JAX you have a structure of two CUDA arrays, while in Julia you have a matrix of Data(::Float64, ::Float64) elements; I’m not sure it matters, but it might disable some optimizations.
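One way to check the layout question would be a structure-of-arrays variant that keeps x1 and x2 as two separate arrays and builds the Data element on the fly (a sketch reusing the scalar functions from the first post; pdf_x1x2_soa is my own name):

function pdf_x1x2_soa(x1::AbstractArray, x2::AbstractArray, params::Params)
  norm_p1 = cdf_x1(params)
  inorm_p1 = inv(norm_p1)
  # map over both arrays elementwise; works for Array and CuArray alike
  return map(x1, x2) do a, b
    d = Data(a, b)
    pdf_x1(d, params) * inorm_p1 * pdf_x2_given_x1(d, params) * inv(cdf_x2_given_x1(d, params))
  end
end

# e.g. pdf_x1x2_soa(x1_mat_gpu, x2_mat_gpu, params)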

OK, NEVER MIND,
you got something wrong on the JAX side:

@jax.jit
def tpl_cdf(x, alpha, x_low):
  return jnp.where(params.alpha==-1,
    jnp.log(params.x_low) - jnp.log(x),
    (x**(1 + params.alpha) - params.x_low**(1 + params.alpha)) / (1 + params.alpha)
  )

should be

@jax.jit
def tpl_cdf(x, alpha, x_low):
  return jnp.where(alpha==-1,
    jnp.log(x_low) - jnp.log(x),
    (x**(1 + alpha) - x_low**(1 + alpha)) / (1 + alpha)
  )

instead of using the global params struct (params.alpha, params.x_low, etc.).
Now it’s 5.213037197000176 ms for the matrix case, as in Julia.
Basically, the GPU didn’t need to read params from global memory: because params came from outside the function, JAX fully inlined it (or the result was simply wrong, I don’t know JAX well enough), and that lowered the time a lot. We could mimic this in Julia by using Val everywhere, or by making the parameters const and not passing them to the functions.
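A rough Julia version of the “bake the parameters in” idea could look like this (a sketch; PARAMS_CONST and pdf_x1x2_const are my own names, with the values from the benchmark above):

const PARAMS_CONST = TPLParams(0.1, 5.0, 2.0, 1.5)

# same kernel, but the parameter struct is a const global, so its fields can be
# constant-folded into the generated code instead of being passed as a kernel argument
function pdf_x1x2_const(data::DataMatVec)
  res = similar(data, Float64)
  norm_p1 = cdf_x1(PARAMS_CONST)
  AK.map!(x -> pdf_x1(x, PARAMS_CONST) * inv(norm_p1) * pdf_x2_given_x1(x, PARAMS_CONST) * inv(cdf_x2_given_x1(x, PARAMS_CONST)),
    res,
    data;
    block_size=BLOCK_SIZE)
  return res
end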

Edit:
the results had to be wrong anyway, because cdf_x1 and cdf_x2_given_x1 actually called tpl_cdf(params.x_high, params.alpha, params.x_low), since the function ignored whatever arguments went in.

So:

  • AK already had the performance, even when replaced with a manual CUDA kernel (5 ms)
  • JAX was only faster because a global variable was inlined and not read by the GPU (or the result was simply wrong, I didn’t test the values)
  • Julia beats JAX in the vector case on GPU (probably because we set the block size to 256 instead of 1024)

Out of curiosity, what are the new timings comparing JAX and Julia once the mistake is corrected?


Julia (AK code) on Colab:

cpuvec :   0.393902 ms 
gpuvec :   0.062267 ms
cpumat :   119.943 ms 
gpumat :   5.160 ms 

JAX with the fix on Colab:

cpuvec  0.6152610739999886 ms
gpuvec  0.7260259349999956 ms
cpumat  95.42822554800003 ms
gpumat  5.124378841000009 ms

Some hints:

  • cpuvec: Julia faster, probably because of the overhead of calling JAX
  • gpuvec: Julia far, far faster, probably because of the choice of 256 for the block dim and, again, the overhead of calling JAX
  • cpumat: JAX faster, probably because of better SIMD (not happening when tested on my local machine)
  • gpumat: similar timings, probably because the embarrassingly parallel code hits memory bandwidth for any block dim <= 256 (which is why I dug so deep to understand why)

Thanks! This is amazing, I think AK is definitely the best way to go in my case!


Thanks for this!

How can you tell?

No idea actually, it’s just a guess. Also weird:
on my local machine Julia is faster for the CPU matrix case, but only with f64, not with f32.

JAX is designed to rewrite code into muladds and SIMD, so that’s what I assumed.
Locally I have 20 threads, which means running in parallel is more important than SIMD => Julia wins.
On Colab only 2 threads are available on the T4 kernel, so SIMD is highly important => JAX wins.

Don’t ask why f32 changes everything.
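To compare f32 and f64 cleanly, the structs could be made parametric, something like this sketch (DataT and TPLParamsT are hypothetical names; the pdf_x1/cdf_x1 methods would need matching signatures):

struct DataT{T<:AbstractFloat}
  x1::T
  x2::T
end

struct TPLParamsT{T<:AbstractFloat} <: Params
  x_low::T
  x_high::T
  alpha::T
  beta::T
end

# e.g. Float32 inputs for the matrix benchmark:
# params32 = TPLParamsT{Float32}(0.1f0, 5.0f0, 2.0f0, 1.5f0)
# data32   = DataT{Float32}.(Float32.(x1_mat), Float32.(x2_mat))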


Can you add the missing functions so that it runs on its own in Segfault on `#main` branch with the PoCL backend · Issue #613 · JuliaGPU/KernelAbstractions.jl · GitHub, and try to reduce it a bit?

Maybe try to reproduce what your functions do in a very simple way:

you need a function foo(x, param) that does something to a float using a param struct and contains an if statement, then a kernel broadcasting it; that’s pretty much it.

I don’t know why this works while your code doesn’t, but we may be able to adapt it before filing the issue:

using KernelAbstractions

struct Param
  a::Float64
  b::Float64
end

function foo(x,param::Param)
  if x>param.a
    return x^param.b
  else
    return 0.0
  end
end

@kernel function kernel!(y,x,param::Param)
  i = @index(Global, Linear)
  y[i] = foo(x[i],param) + x[i]*inv(param.a)
end

function main(x,param)
  y = similar(x)
  bck = get_backend(x)
  ker = kernel!(bck)
  ker(y,x,param;ndrange=length(x), workgroupsize=256)
  return y
end

x = rand(300,5000)
param = Param(0.5,1.0)
y = main(x,param)

Also, your code seems to work on KA main; can you check that you didn’t do anything weird?

using KernelAbstractions

abstract type Params end

struct Data
  x1::Float64
  x2::Float64
end
const DataMatVec{Data} = Union{AbstractVector{Data},AbstractMatrix{Data}}
const BLOCK_SIZE = 256

# Interface for primary and secondary pdf: these have to be implemented case by case
function pdf_x1 end
function cdf_x1 end
function pdf_x2_given_x1 end
function cdf_x2_given_x1 end

# Joint distribution: this holds in general

@kernel function ker(res,data,params,norm_p1)
  i = @index(Global)
  x = data[i]
  res[i] = pdf_x1(x, params) * inv(norm_p1) * pdf_x2_given_x1(x, params) * inv(cdf_x2_given_x1(x, params))
end

function pdf_x1x2(data::DataMatVec, params::Params)
  res = similar(data, Float64)
  norm_p1 = cdf_x1(params)
  bck = get_backend(res)
  ker(bck)(res, data, params, norm_p1;ndrange=length(data),workgroupsize=BLOCK_SIZE)
  res
end

# Core functions 
@inline @fastmath function tpl(x::Float64, alpha::Float64, x_min::Float64, x_max::Float64)::Float64
  if x_min < x < x_max
    return x^alpha
  else
    return 0.0
  end
end

@inline @fastmath function tpl_cdf(x::Float64, alpha::Float64, x_min::Float64)::Float64
  if alpha == -1
    return log(x_min) - log(x)
  else
    return (x^(1 + alpha) - x_min^(1 + alpha)) / (1 + alpha)
  end
end

struct TPLParams <: Params
  x_low::Float64
  x_high::Float64
  alpha::Float64
  beta::Float64
end

pdf_x1(data::Data, params::TPLParams) = tpl(data.x1, -params.alpha, params.x_low, params.x_high)
cdf_x1(params::TPLParams) = tpl_cdf(params.x_high, -params.alpha, params.x_low)
pdf_x2_given_x1(data::Data, params::TPLParams) = tpl(data.x2, params.beta, params.x_low, data.x1)
cdf_x2_given_x1(data::Data, params::TPLParams) = tpl_cdf(data.x1, params.beta, params.x_low)


using BenchmarkTools

param = TPLParams(0.1,5.,2.,1.5)

x1_vec = range(2,4,5000) 
x2_vec = range(1,3,5000) 
data_vec = Data.(x1_vec, x2_vec)
pdf_x1x2(data_vec, param)
x1_mat = repeat(x1_vec', 300, 1) 
x2_mat = repeat(x2_vec', 300, 1) 

data_mat = Data.(x1_mat, x2_mat)
pdf_x1x2(data_mat, param)

Status `C:\Users\yolha\Desktop\juju_tests\Tests\Project.toml`
  [63c18a36] KernelAbstractions v0.10.0-dev `https://github.com/JuliaGPU/KernelAbstractions.jl.git#main`

Also, specifying a block size on the CPU is weird.
Last note:
AK actually uses OhMyThreads for its CPU parallel cases, not KA, and I’m not sure KA will ever beat that by design.
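Roughly, the CPU path amounts to something like this with OhMyThreads directly (a sketch, not AK’s actual internals):

using OhMyThreads: tmap!

function pdf_x1x2_omt(data, params::Params)
  res = similar(data, Float64)
  norm_p1 = cdf_x1(params)
  tmap!(x -> pdf_x1(x, params) * inv(norm_p1) * pdf_x2_given_x1(x, params) * inv(cdf_x2_given_x1(x, params)),
    res, data)
  return res
end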

Hi, I did not quite understand whether I have to write something on GitHub.
By the way, the last code you sent, which uses KA instead of AK, works fine for me.

It’s just that, for next time, try to reduce the problem as much as possible, even though random segfaults are hard to reproduce. Oh, are you not the one who posted the issue?