Hi everyone! I’m new to Julia and I’m trying to translate some JAX code into Julia using the Tullio macro, which keeps the code simple and clean. However, I’m not able to get the same performance, even for simple functions. Here’s a minimal JAX working example and the corresponding timings on my laptop:
import jax
jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu')
import jax.numpy as jnp
from flax import struct
@struct.dataclass
class TPLParams:
    x_low: float
    x_high: float
    alpha: float
@jax.jit
def pdf(x, params):
    # not normalized
    return jnp.where((params.x_low < x) & (x < params.x_high),
                     x**params.alpha,
                     0.)
@jax.jit
def cdf(x, params):
    # not normalized; if x = x_high the result is the normalization of the pdf
    return jnp.where(params.alpha == -1,
                     jnp.log(x) - jnp.log(params.x_low),
                     (x**(1 + params.alpha) - params.x_low**(1 + params.alpha)) / (1 + params.alpha))
@jax.jit
def test(x, params):
    norm = cdf(params.x_high, params)
    return pdf(x, params) / norm
#test
p = TPLParams(x_low=1., x_high=10., alpha=2.)
x_scalar = 5.
x_vec = jnp.linspace(0,11,5000)
x_mat = jnp.array([x_vec for _ in range(300)])
# compile
test(x_scalar, p)
test(x_vec, p)
test(x_mat, p)
#timing
%timeit test(x_scalar, p).block_until_ready()
10.6 μs ± 544 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
%timeit test(x_vec, p).block_until_ready()
74 μs ± 15.3 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit test(x_mat, p).block_until_ready()
5.41 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
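For reference, this is the integral the cdf is meant to compute (just spelling out the two branches):

$$
\mathrm{cdf}(x) = \int_{x_{\mathrm{low}}}^{x} t^{\alpha}\,dt =
\begin{cases}
\dfrac{x^{1+\alpha} - x_{\mathrm{low}}^{1+\alpha}}{1+\alpha}, & \alpha \neq -1,\\[4pt]
\log x - \log x_{\mathrm{low}}, & \alpha = -1,
\end{cases}
$$

so evaluating it at x = x_high gives the normalization of the pdf.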
To write the equivalent Julia code, I decided to write the pdf and cdf functions for scalar inputs and broadcast them with Tullio. This is a naive attempt at writing CPU/GPU-agnostic code in the spirit of JAX.
using LoopVectorization
using Tullio
using BenchmarkTools
abstract type Params end
struct TPLParams <: Params
    x_low::Float64
    x_high::Float64
    alpha::Float64
end
# functions
function pdf(x::Real, params::TPLParams)::Float64
    if params.x_low < x < params.x_high
        return x^params.alpha
    else
        return 0.0
    end
end
pdf(x::AbstractVector{<:Real}, params::Params) = @tullio result[i] := pdf(x[i], params)
pdf(x::AbstractArray{<:Real,2}, params::Params) = @tullio result[i, j] := pdf(x[i, j], params)
function cdf(x::Real, params::TPLParams)::Float64
    if params.alpha == -1
        return log(x) - log(params.x_low)
    else
        return (x^(1 + params.alpha) - params.x_low^(1 + params.alpha)) / (1 + params.alpha)
    end
end
cdf(x::AbstractVector{<:Real}, params::Params) = @tullio result[i] := cdf(x[i], params)
cdf(x::AbstractArray{<:Real,2}, params::Params) = @tullio result[i, j] := cdf(x[i, j], params)
function test(x::Union{Real,AbstractVector{<:Real},AbstractArray{<:Real,2}}, params::Params)
    norm = cdf(params.x_high, params)
    return pdf(x, params) / norm
end
#test
p = TPLParams(1., 10., 2.)
x_scalar = 5.
x_vec = range(0,11,5000)
x_mat = repeat(x_vec', 300, 1)
# compile
test(x_scalar, p)
test(x_vec, p)
test(x_mat, p)
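The Julia timings were taken with BenchmarkTools, roughly like this (a sketch; the $-interpolation avoids benchmarking global-variable access):

#timing (BenchmarkTools is already loaded above)
@btime test($x_scalar, $p);
@btime test($x_vec, $p);
@btime test($x_mat, $p);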
However, the computational times are roughly 2× slower than JAX, and the standard deviations are much larger in Julia than in JAX, especially in the x_mat case, which is the one I use most.
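For comparison, the same computation can also be written with plain broadcasting instead of Tullio; this is just a sketch of the baseline I compare against (Ref makes broadcasting treat the struct as a scalar):

# Baseline without Tullio: broadcast the scalar methods directly.
# Ref(params) shields the struct from being broadcast over.
test_bc(x, params) = pdf.(x, Ref(params)) ./ cdf(params.x_high, params)
test_bc(x_vec, p)  # same values as test(x_vec, p)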
During the Julia execution I also get this LoopVectorization warning:
┌ Warning: #= /home/mt/.julia/packages/Tullio/2zyFP/src/macro.jl:1093 =#:
│ `LoopVectorization.check_args` on your inputs failed; running fallback `@inbounds @fastmath` loop instead.
│ Use `warn_check_args=false`, e.g. `@turbo warn_check_args=false ...`, to disable this warning.
└ @ Main ~/.julia/packages/LoopVectorization/ImqiY/src/condense_loopset.jl:1166
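My guess (unverified) is that LoopVectorization.check_args rejects the TPLParams struct and/or the opaque pdf call inside the Tullio kernel, which triggers the fallback. Below is a variant I’m considering, with only scalars and ifelse inside the expression so LoopVectorization can vectorize it; pdf_inline is a hypothetical name, not something I’ve benchmarked:

# Hypothetical variant: pull the fields out as plain Float64 scalars and
# inline the branch with `ifelse`, so no struct or user function appears
# inside the kernel. (If x is a lazy range, collect(x) may also be needed.)
function pdf_inline(x::AbstractVector{<:Real}, params::TPLParams)
    lo, hi, a = params.x_low, params.x_high, params.alpha
    @tullio result[i] := ifelse((lo < x[i]) & (x[i] < hi), x[i]^a, 0.0)
end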
Here’s my system information:
Hostname: mt-ThinkPad-P14s-Gen-3
CPU(s): 1 x 12th Gen Intel(R) Core(TM) i5-1240P
CPU target: alderlake
∘ CPU 1:
→ 12 cores (16 CPU-threads due to 2-way SMT)
→ 8 "efficiency cores", 4 "performance cores".
→ 1 NUMA domain
Detected GPUs: 1
Notebook launched with 16 threads
Can someone help me understand where the performance differences come from and how to make the two codes equally fast? Thanks!