Julia's Broadcast vs JAX's vmap

I ran the benchmarks on both the CPU and the GPU. The setup:

using BenchmarkTools, CuArrays
using LinearAlgebra: dot

D = 10^3
BS = 10^2

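x = randn(D)
X = randn(D, BS)       # D × BS: one example per column
y = randn(D)
cX = cu(X)             # cu() copies to the GPU (as Float32)
cy = cu(y)
Xt = permutedims(X)    # BS × D copy, so Xt * y yields one dot product per example
cXt = cu(Xt)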

dot(x, y)              # single dot product on the CPU
dot(cu(x), cy)         # ...and on the GPU

broadcast_dot(X, y) = [dot(x, y) for x in eachslice(X; dims = 2)]
matmul_dot(Xt, y) = Xt * y
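A quick way to confirm that the two definitions compute the same thing is a small CPU-only check. This is a sketch with arbitrary toy sizes for `D` and `BS`; it also compares against `X' * y`, which uses Julia's lazy adjoint instead of materialising the transpose with `permutedims`. That third variant was not part of the original benchmarks.

```julia
using LinearAlgebra: dot

# Toy sizes (arbitrary) so the check runs instantly.
D, BS = 5, 3
X = randn(D, BS)      # each column is one "example"
y = randn(D)
Xt = permutedims(X)   # BS × D, so Xt * y has one entry per column of X

broadcast_dot(X, y) = [dot(x, y) for x in eachslice(X; dims = 2)]
matmul_dot(Xt, y) = Xt * y

r1 = broadcast_dot(X, y)
r2 = matmul_dot(Xt, y)
r3 = X' * y           # lazy adjoint: same result without copying X

@assert length(r1) == BS
@assert r1 ≈ r2
@assert r1 ≈ r3
```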

Now running on the CPU:

@btime broadcast_dot($X, $y)
16.652 μs (108 allocations: 6.56 KiB)

@btime matmul_dot($Xt, $y)
13.867 μs (1 allocation: 896 bytes)

And on the GPU:

@btime CuArrays.@sync broadcast_dot($cX, $cy)
321.091 ms (208 allocations: 8.45 KiB)

@btime CuArrays.@sync matmul_dot($cXt, $cy)
238.609 μs (8 allocations: 208 bytes)
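For the comparison with JAX's `vmap`, which turns a per-example function into one batched operation, the same batched dot product can also be written purely with array-level operations: one broadcast followed by a reduction. The `reduce_dot` below is a hypothetical name and not one of the benchmarked definitions; the sketch is checked on CPU `Array`s only. It avoids per-column views, but it allocates a full `D × BS` temporary, and it has not been timed with CuArrays.

```julia
using LinearAlgebra: dot

# Hypothetical helper (not from the original post): y is broadcast across
# the BS columns of X, then the products are summed over the D dimension.
reduce_dot(X, y) = vec(sum(X .* y; dims = 1))

D, BS = 4, 3
X = randn(D, BS)
y = randn(D)

r = reduce_dot(X, y)
@assert length(r) == BS
@assert r ≈ [dot(c, y) for c in eachcol(X)]
```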

Note that the following broadcast-based definition did not work:

broadcast_dot(X, y) = dot.(eachslice(X; dims = 2), Ref(y))
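For reference, on plain CPU `Array`s (assuming Julia 1.1 or later, where `eachcol` and `eachslice` exist), a per-column broadcast and a `map` equivalent can be checked against the matrix-vector product as sketched below. This does not explain why the definition above failed in the original setting, and it is not tested with CuArrays.

```julia
using LinearAlgebra: dot

D, BS = 4, 3
X = randn(D, BS)
y = randn(D)

# Per-column broadcast: eachcol yields one column view per example, and
# Ref(y) makes broadcast treat y as a single value instead of iterating it.
r_bcast = dot.(eachcol(X), Ref(y))

# The same computation written with map instead of broadcast.
r_map = map(c -> dot(c, y), eachcol(X))

@assert r_bcast ≈ r_map
@assert r_bcast ≈ X' * y
```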