My (high-level) understanding of jax.vmap is that it automatically vectorizes a function along a specified axis of its input by introducing an “abstract” batch axis and compiling the code as though the inputs were shaped accordingly.
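Roughly (this is a sketch of the semantics only, not of how JAX actually implements it): vmapping over the leading axis behaves like stacking the results of applying the function to each slice, except that no Python loop is ever run:
import jax.numpy as np
from jax import vmap

def f(x):                      # written for a single example
    return np.sum(x ** 2)

xs = np.ones((8, 3))
# Semantically like np.stack([f(x) for x in xs]), but traced and
# compiled as a single batched operation -- no Python loop is run.
vmap(f, in_axes=0)(xs)         # shape (8,)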
To see the difference, let’s consider a very simple example where Julia’s broadcasting is much less performant than jax.vmap.
Let’s consider how jax internally represents vector-vector dot products:
import jax.numpy as np
from jax.api import jit, vmap
from jax import make_jaxpr
import numpy.random as npr
D = 10**3 # Data Dim
BS = 10**2 # Broadcast/Batch Dim
# Vector-Vector Dot
x = npr.randn(D)
y = npr.randn(D)
np.dot(x,y)
# lowers to intermediate representation
make_jaxpr(np.dot)(x,y)
#{ lambda ; a b.
# let c = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
# precision=None ] a b
# in (c,) }
Compare this to how it represents a matrix-vector product:
# Matrix-vector product
X = npr.randn(BS,D)
y = npr.randn(D)
np.matmul(X,y)
# lowers to IR
make_jaxpr(np.matmul)(X,y)
#{ lambda ; a b.
# let c = reshape[ dimensions=None
# new_sizes=(1000, 1) ] b
# d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
# precision=None ] a c
# e = reshape[ dimensions=None
# new_sizes=(100,) ] d
# in (e,) }
In particular, notice that both dot and matmul lower to an internal representation function called dot_general, and that dot has dimension_numbers=(((0,), (0,)), ((), ())) where matmul has dimension_numbers=(((1,), (0,)), ((), ())).
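(For reference, and assuming the lax API I remember: you can call jax.lax.dot_general directly with those dimension_numbers. The first pair lists the contracting axes of each argument, the second pair the batch axes, which are empty here.)
from jax import lax

# dot: contract axis 0 of x with axis 0 of y; no batch axes
lax.dot_general(x, y, dimension_numbers=(((0,), (0,)), ((), ())))

# matmul: contract axis 1 of X with axis 0 of y; no batch axes
lax.dot_general(X, y, dimension_numbers=(((1,), (0,)), ((), ())))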
With this in mind, let’s consider what happens when we use jax.vmap. It’s clear from x = npr.randn(D) vs X = npr.randn(BS,D) that the latter can be thought of as a collection of D-dimensional x’s stacked along the 0th axis (sorry, Python is 0-based)… So the output of matmul can be reproduced by broadcasting dot over the 0th axis:
Xy = np.matmul(X,y)
# broadcast dot over 0th axis of first argument
# and do not broadcast over the second argument (analogous to Julia's Ref(y), used below)
broadcast_dot = vmap(np.dot, in_axes=(0,None))
# These are equivalent
np.allclose(Xy,broadcast_dot(X,y)) # True
# The IR for the broadcasted operation
# lowers to the equivalent matmul operation
make_jaxpr(broadcast_dot)(X,y)
#{ lambda ; a b.
# let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
# precision=None ] a b
# in (c,) }
Critically, note that the vmap-“broadcasted” dot lowers to the same underlying representation as matmul. When these are compiled for hardware (BLAS, CUDA, XLA, …), they will call the same operations.
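One way to sanity-check this (assuming a JAX version that still provides jax.xla_computation) is to dump the compiled XLA HLO for both and eyeball that they contain the same dot op:
from jax import xla_computation

# Both dumps should contain essentially the same dot/gemm call.
print(xla_computation(broadcast_dot)(X, y).as_hlo_text())
print(xla_computation(np.matmul)(X, y).as_hlo_text())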
This is reflected in benchmarking. After jax.jit, both functions evaluate in approximately the same time:
jbd = jit(broadcast_dot)
%timeit jbd(X,y)
# 232 µs ± 2.63 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
jmm = jit(np.matmul)
%timeit jmm(X,y)
# 235 µs ± 4.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
This is true for other functions, even those with “built-in” broadcasting:
# np.mean has optional axis argument
# lambda only for jitting
j_np_mean = jit(lambda X: np.mean(X, axis=0))
# can be achieved by vmapping over second axis
j_vmap_mean = jit(vmap(np.mean, in_axes = (1,)))
np.allclose(j_np_mean(X),j_vmap_mean(X)) # True
%timeit j_np_mean(X)
# 191 µs ± 10.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit j_vmap_mean(X)
# 180 µs ± 3.78 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Further, though I won’t demonstrate it here, this all composes, e.g. within other jit’d functions and with AD.
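Just as a rough sketch of what I mean (the toy loss below is only illustrative, not anything from the benchmarks above):
from jax import grad

# A toy scalar loss built on the vmapped dot; w is an assumed parameter vector.
def loss(w):
    return np.sum(broadcast_dot(X, w) ** 2)

# grad, jit, and vmap all compose; the gradient has the shape of w.
jit_grad_loss = jit(grad(loss))
jit_grad_loss(y).shape  # (1000,)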
Julia Broadcasting
Let’s see how Julia’s broadcasting compares. (Note: I am not suggesting comparing the timing results between Jax and Julia here, only comparing the hand-vectorized vs vmap’d/broadcasted timings within each language. Thanks @Mason for clarifying.)
using BenchmarkTools
using LinearAlgebra: dot
D = 10^3
BS = 10^2
x = randn(D)
X = randn(BS,D)
y = randn(D)
dot(x,y);
By broadcasting (with Julia’s . syntax sugar) we can compute the matrix-vector multiply by broadcasting dot over the slices along the 1st dimension of the first argument. Since we don’t want to broadcast over the second argument at all, we wrap it in Ref … yuck.
broadcast_dot(X,y) = dot.(eachslice(X,dims=1), Ref(y))
isapprox( broadcast_dot(X,y), X*y) #true
# Performance is worse
@btime X*y;
# 18.488 μs (1 allocation: 896 bytes)
@btime broadcast_dot(X,y);
# 70.838 μs (108 allocations: 6.56 KiB)
And it’s clear from the IR that broadcasted dot does not get lowered to a single call to a more efficient matrix-vector multiplication:
@code_lowered X*y
# CodeInfo(
# 1 ─ Core.NewvarNode(:(y))
# │ TS = LinearAlgebra.promote_op(LinearAlgebra.matprod, $(Expr(:static_parameter, 1)), $(Expr(:static_parameter, 2)))
# │ %3 = LinearAlgebra.isconcretetype(TS)
# └── goto #3 if not %3
# 2 ─ %5 = Core.apply_type(LinearAlgebra.AbstractVector, TS)
# │ @_6 = LinearAlgebra.convert(%5, x)
# └── goto #4
# 3 ─ @_6 = x
# 4 ┄ y = @_6
# │ %10 = TS
# │ %11 = LinearAlgebra.size(A, 1)
# │ %12 = LinearAlgebra.similar(x, %10, %11)
# │ %13 = LinearAlgebra.mul!(%12, A, y)
# └── return %13
#)
@code_lowered broadcast_dot(X,y)
# CodeInfo(
# 1 ─ %1 = (:dims,)
# │ %2 = Core.apply_type(Core.NamedTuple, %1)
# │ %3 = Core.tuple(1)
# │ %4 = (%2)(%3)
# │ %5 = Core.kwfunc(Main.eachslice)
# │ %6 = (%5)(%4, Main.eachslice, X)
# │ %7 = Main.Ref(y)
# │ %8 = Base.broadcasted(Main.dot, %6, %7)
# │ %9 = Base.materialize(%8)
# └── return %9
# )
The same is true for the mean example:
using Statistics: mean
@btime mean($X, dims=1);
# 19.842 μs (1 allocation: 7.94 KiB)
@btime mean.($(eachslice(X,dims=2)));
# 29.634 μs (1002 allocations: 62.75 KiB)
Critically, unlike jax.vmap, Julia’s broadcast lowers to Base.broadcasted(Main.dot, ...) and not to a call to LinearAlgebra.mul!. This has significant ramifications for downstream tasks. For instance, AD through broadcasted functions with Zygote is considerably more complex and less performant, because it cannot leverage a more efficient rule like the adjoint of matmul, and the code broadcast emits is messy. This is an especially noticeable issue at scales where hardware optimizations, such as efficient matrix multiplication on GPU, dominate the broadcasted dot. (I may rerun these on GPU later.)
The solution in Julia for me, unfortunately, has been to always write code that is batch-aware. That is, to write every function from the beginning as though it will accept a pre-specified batch dimension, so I can use hardware-backed kernels for fast matmul, as well as for the gradients. By comparison, writing in Jax with vmap is much more freeing: I can write all my functions as though they are applied elementwise, and only think about batches when I am ready to apply them to batched data. (It’s especially nice that I don’t have to rely on conventions like the batch dim being 0 in Python and end in Julia.)
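To make that concrete, here is a minimal sketch of the workflow I mean; the predict function is just an illustrative stand-in, reusing X and y from the Jax snippets above:
# Write the model for a single example...
def predict(w, x):
    return np.tanh(np.dot(w, x))

# ...and only decide on the batch axis when batched data shows up.
batched_predict = vmap(predict, in_axes=(None, 0))   # batch over rows of X
batched_predict(y, X).shape   # (100,)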