Hi everyone!
Following this topic, I’d like to understand how to write efficient CPU/GPU-agnostic code using AcceleratedKernels.jl.
My problem is evaluating the joint PDF of two random variables, one of which is conditioned on the other. I need to implement different PDF models for the individual random variables, and I’d like to use Julia’s multiple dispatch to handle the different models.
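To spell out what the code below computes, the quantity I need is the normalised joint PDF, with both densities truncated to their allowed range (my notation):

$$
p(x_1, x_2) = \frac{p_{X_1}(x_1)}{\int_{x_{\mathrm{low}}}^{x_{\mathrm{high}}} p_{X_1}(t)\,\mathrm{d}t} \cdot \frac{p_{X_2\mid X_1}(x_2\mid x_1)}{\int_{x_{\mathrm{low}}}^{x_1} p_{X_2\mid X_1}(t\mid x_1)\,\mathrm{d}t}
$$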
Here’s the example Julia code I produced:
using CUDA
import AcceleratedKernels as AK
abstract type Params end
struct Data
x1::Float64
x2::Float64
end
const DataMatVec = Union{AbstractVector{Data},AbstractMatrix{Data}}
const BLOCK_SIZE = 256
# Interface for the primary and conditional PDFs: these have to be implemented model by model
function pdf_x1 end
function cdf_x1 end
function pdf_x2_given_x1 end
function cdf_x2_given_x1 end
# Joint distribution: this holds in general
function pdf_x1x2(data::DataMatVec, params::Params)
res = similar(data, Float64)
norm_p1 = cdf_x1(params)
AK.map!(x -> pdf_x1(x, params) * inv(norm_p1) * pdf_x2_given_x1(x, params) * inv(cdf_x2_given_x1(x, params)),
res,
data;
block_size=BLOCK_SIZE)
return res
end
# Core functions
@inline @fastmath function tpl(x::Float64, alpha::Float64, x_min::Float64, x_max::Float64)::Float64
if x_min < x < x_max
return x^alpha
else
return 0.0
end
end
@inline @fastmath function tpl_cdf(x::Float64, alpha::Float64, x_min::Float64)::Float64
if alpha == -1
return log(x_min) - log(x)
else
return (x^(1 + alpha) - x_min^(1 + alpha)) / (1 + alpha)
end
end
# Specific Implementation
struct TPLParams <: Params
x_low::Float64
x_high::Float64
alpha::Float64
beta::Float64
end
pdf_x1(data::Data, params::TPLParams) = tpl(data.x1, -params.alpha, params.x_low, params.x_high)
cdf_x1(params::TPLParams) = tpl_cdf(params.x_high, -params.alpha, params.x_low)
pdf_x2_given_x1(data::Data, params::TPLParams) = tpl(data.x2, params.beta, params.x_low, data.x1)
cdf_x2_given_x1(data::Data, params::TPLParams) = tpl_cdf(data.x1, params.beta, params.x_low)
### For any other model it is enough to implement the last four functions and pdf_x1x2 will just work; see the sketch below
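For instance, a hypothetical second model (purely illustrative, not part of the benchmarks below; ExpParams and texp are names I made up for the sketch) would only need to add:
# Hypothetical second model, only to illustrate how the interface is meant to be extended
struct ExpParams <: Params
    x_low::Float64
    x_high::Float64
    lambda1::Float64
    lambda2::Float64
end
# truncated exponential density and its normalisation integral from x_min to x
@inline texp(x::Float64, lambda::Float64, x_min::Float64, x_max::Float64) = x_min < x < x_max ? exp(-lambda * x) : 0.0
@inline texp_cdf(x::Float64, lambda::Float64, x_min::Float64) = (exp(-lambda * x_min) - exp(-lambda * x)) / lambda
pdf_x1(data::Data, params::ExpParams) = texp(data.x1, params.lambda1, params.x_low, params.x_high)
cdf_x1(params::ExpParams) = texp_cdf(params.x_high, params.lambda1, params.x_low)
pdf_x2_given_x1(data::Data, params::ExpParams) = texp(data.x2, params.lambda2, params.x_low, data.x1)
cdf_x2_given_x1(data::Data, params::ExpParams) = texp_cdf(data.x1, params.lambda2, params.x_low)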
The benchmarks on CPU/GPU for vectors of data are:
using BenchmarkTools
params = TPLParams(0.1,5.,2.,1.5)
x1_vec = range(2,4,5000)
x2_vec = range(1,3,5000)
x1_vec_gpu = CuArray(x1_vec)
x2_vec_gpu = CuArray(x2_vec)
data_vec = Data.(x1_vec, x2_vec)
data_vec_gpu = Data.(x1_vec_gpu, x2_vec_gpu)
#warmup
pdf_x1x2(data_vec, params)
pdf_x1x2(data_vec_gpu, params)
#benchmark
@btime pdf_x1x2($data_vec, $params) # 143.316 μs (89 allocations: 49.07 KiB)
@btime CUDA.@sync pdf_x1x2($data_vec_gpu, $params) # 127.649 μs (82 allocations: 3.00 KiB)
while for matrices I get:
x1_mat = repeat(x1_vec', 300, 1)
x2_mat = repeat(x2_vec', 300, 1)
x1_mat_gpu = CuArray(x1_mat)
x2_mat_gpu = CuArray(x2_mat)
data_mat = Data.(x1_mat, x2_mat)
data_mat_gpu = Data.(x1_mat_gpu, x2_mat_gpu)
#warmup
pdf_x1x2(data_mat, params)
pdf_x1x2(data_mat_gpu, params)
#benchmark
@btime pdf_x1x2($data_mat, $params) # 31.754 ms (89 allocations: 11.45 MiB)
@btime CUDA.@sync pdf_x1x2($data_mat_gpu, $params) # 30.811 ms (81 allocations: 3.02 KiB)
The equivalent JAX code, which uses flax to write JAX-compatible structs and plum to emulate multiple dispatch, is:
import jax
jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu') # or 'gpu'
import jax.numpy as jnp
from flax import struct
from plum import dispatch
# structs
@struct.dataclass
class TPLParams:
    x_low: float
    x_high: float
    alpha: float
    beta: float
@struct.dataclass
class Data:
    x1: jnp.ndarray
    x2: jnp.ndarray
#Core functions
@jax.jit
def tpl(x, alpha, x_low, x_high):
    return jnp.where((x_low < x) & (x < x_high),
                     x**alpha,
                     0.)
@jax.jit
def tpl_cdf(x, alpha, x_low):
    return jnp.where(alpha == -1,
                     jnp.log(x_low) - jnp.log(x),
                     (x**(1 + alpha) - x_low**(1 + alpha)) / (1 + alpha))
# Function implementation
@dispatch
@jax.jit
def pdf_x1(data, params: TPLParams):
    return tpl(data.x1, -params.alpha, params.x_low, params.x_high)
@dispatch
@jax.jit
def cdf_x1(params: TPLParams):
    return tpl_cdf(params.x_high, -params.alpha, params.x_low)
@dispatch
@jax.jit
def pdf_x2_given_x1(data, params: TPLParams):
    return tpl(data.x2, params.beta, params.x_low, data.x1)
@dispatch
@jax.jit
def cdf_x2_given_x1(data, params: TPLParams):
    return tpl_cdf(data.x1, params.beta, params.x_low)
@dispatch
@jax.jit
def pdf_x1x2(data, params):
    norm_p1 = cdf_x1(params)
    p1 = pdf_x1(data, params) / norm_p1
    p2 = pdf_x2_given_x1(data, params) / cdf_x2_given_x1(data, params)
    return p1 * p2
and the timings are better than in the AK case, especially on GPU:
#test
params = TPLParams(x_low=0.1, x_high=5., alpha=2., beta=1.)
x1_vec = jnp.linspace(2,4,5000)
x2_vec = jnp.linspace(1,3,5000)
x1_mat = jnp.array([x1_vec for _ in range(300)])
x2_mat = jnp.array([x2_vec for _ in range(300)])
data_vec = Data(x1_vec, x2_vec)
data_mat = Data(x1_mat, x2_mat)
# compile
pdf_x1x2(data_vec, params)
pdf_x1x2(data_mat, params)
#timing
%timeit pdf_x1x2(data_vec, params).block_until_ready()
# cpu: 115 μs ± 4.66 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# gpu: 112 μs ± 2.02 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit pdf_x1x2(data_mat, params).block_until_ready()
# cpu 19.1 ms ± 4.01 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
# gpu 8.14 ms ± 38.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
I’m fairly certain that it’s possible to achieve timings in Julia at least as good as in JAX. Do you see where the problems in my Julia code are?
Thanks!