My (high-level) understanding of `jax.vmap` is that it automatically vectorizes a function along a specified axis of its input by introducing an “abstract” batch axis and tracing the code as though the inputs were shaped accordingly.

To see the difference, let’s consider a very simple example where Julia’s broadcasting is much less performant than `jax.vmap`.

First, let’s look at how `jax` internally represents vector-vector dot products:

```
import jax.numpy as np
from jax import jit, vmap, make_jaxpr  # jax.api was removed in newer JAX releases
import numpy.random as npr
D = 10**3 # Data Dim
BS = 10**2 # Broadcast/Batch Dim
# Vector-Vector Dot
x = npr.randn(D)
y = npr.randn(D)
np.dot(x,y)
# lowers to intermediate representation
make_jaxpr(np.dot)(x,y)
#{ lambda ; a b.
# let c = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
# precision=None ] a b
# in (c,) }
```

Compare this to how it represents a matrix-vector product:

```
# Matrix-vector product
X = npr.randn(BS,D)
y = npr.rand(D)
np.matmul(X,y)
# lowers to IR
make_jaxpr(np.matmul)(X,y)
#{ lambda ; a b.
# let c = reshape[ dimensions=None
# new_sizes=(1000, 1) ] b
# d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
# precision=None ] a c
# e = reshape[ dimensions=None
# new_sizes=(100,) ] d
# in (e,) }
```

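In particular, notice that both `dot` and `matmul` lower to an internal primitive called `dot_general`, and that `dot` has `dimension_numbers=(((0,), (0,)), ((), ()))` whereas `matmul` has `dimension_numbers=(((1,), (0,)), ((), ()))`. The first pair lists the contracting axes of each operand and the second pair the batch axes, so `dot` contracts axis 0 of `x` with axis 0 of `y`, while `matmul` contracts axis 1 of `X` with axis 0 of `y`.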
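To make those `dimension_numbers` concrete, here is a minimal sketch that calls `jax.lax.dot_general` directly with the two tuples from the jaxprs above. It reuses the post’s aliases (`np` for `jax.numpy`, `npr` for `numpy.random`); the seed and variable names are just for illustration.

```python
import numpy.random as npr
import jax.numpy as np
from jax import lax

npr.seed(0)
D, BS = 1000, 100
x, y = npr.randn(D), npr.randn(D)
X = npr.randn(BS, D)

# dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
# dot: contract axis 0 of x with axis 0 of y; no batch axes -> scalar
vv = lax.dot_general(x, y, dimension_numbers=(((0,), (0,)), ((), ())))
# matmul: contract axis 1 of X with axis 0 of y; no batch axes -> shape (BS,)
mv = lax.dot_general(X, y, dimension_numbers=(((1,), (0,)), ((), ())))
```

Only the contracting axis of the left operand changes between the two calls; everything else about the primitive is identical.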

With this in mind, let’s consider what happens when we use `jax.vmap`. Comparing `x = npr.randn(D)` with `X = npr.randn(BS,D)`, it’s clear that the latter can be thought of as a collection of `BS` D-dimensional `x`’s stacked along the `0`th axis (sorry, Python is 0-based)…

So the output of `matmul` can be obtained by broadcasting `dot` over the 0th axis of `X`:

```
Xy = np.matmul(X,y)
# broadcast dot over 0th axis of first argument
# and do not broadcast over the second argument (like Ref(y) in Julia)
broadcast_dot = vmap(np.dot, in_axes=(0,None))
# These are equivalent
np.allclose(Xy,broadcast_dot(X,y)) # True
# The IR for the broadcasted operation
# lowers to the equivalent matmul operation
make_jaxpr(broadcast_dot)(X,y)
#{ lambda ; a b.
# let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
# precision=None ] a b
# in (c,) }
```

Critically, note that the vmap’d (“broadcasted”) `dot` lowers to the same underlying representation as `matmul`. When these are compiled by XLA for the target hardware (BLAS on CPU, CUDA on GPU, …), they will call the same operations.
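The same mechanism covers the case where *both* arguments are batched. A minimal sketch, assuming JAX’s batching rule for `dot_general` folds the mapped axis into the primitive’s batch dimensions (the printed jaxpr is the way to check this on your JAX version); `Y` and `batched_dot` are names introduced here for illustration:

```python
import numpy.random as npr
import jax.numpy as np
from jax import vmap, make_jaxpr

npr.seed(0)
BS, D = 100, 1000
X = npr.randn(BS, D)
Y = npr.randn(BS, D)

# Map over axis 0 of *both* arguments: one dot product per pair of rows
batched_dot = vmap(np.dot, in_axes=(0, 0))
out = batched_dot(X, Y)

# The same computation written by hand with einsum
ref = np.einsum("bd,bd->b", X, Y)

# Inspect the IR: look for a single dot_general with non-empty batch dims
print(make_jaxpr(batched_dot)(X, Y))
```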

This is reflected in benchmarking. After `jax.jit`, both functions evaluate in approximately the same time:

```
jbd = jit(broadcast_dot)
%timeit jbd(X,y)
# 232 µs ± 2.63 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
jmm = jit(np.matmul)
%timeit jmm(X,y)
# 235 µs ± 4.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
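One caveat when timing JAX: it dispatches work asynchronously, and a NumPy input is copied to the device on every call. A stricter benchmark would move inputs to the device once and wait on each result with `block_until_ready()`. This is a sketch under those assumptions; `time_jax` is a hypothetical helper, and in IPython the equivalent is `%timeit jbd(X, y).block_until_ready()`.

```python
import time
import numpy.random as npr
import jax
import jax.numpy as np
from jax import jit, vmap

def time_jax(f, *args, n=100):
    """Mean wall-clock seconds per call of f(*args), blocking on each result."""
    f(*args).block_until_ready()  # warm-up call so compilation is not timed
    start = time.perf_counter()
    for _ in range(n):
        f(*args).block_until_ready()  # wait for the async computation to finish
    return (time.perf_counter() - start) / n

npr.seed(0)
X = jax.device_put(npr.randn(100, 1000))  # transfer inputs to the device once
y = jax.device_put(npr.randn(1000))

jbd = jit(vmap(np.dot, in_axes=(0, None)))
jmm = jit(np.matmul)
t_vmap = time_jax(jbd, X, y)
t_matmul = time_jax(jmm, X, y)
print(f"vmap'd dot: {t_vmap * 1e6:.1f} µs, matmul: {t_matmul * 1e6:.1f} µs")
```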

This is true for other functions too, even those with “built-in” broadcasting (here, an `axis` argument):

```
# np.mean has an optional axis argument
# the lambda fixes axis=0 so that jit only sees the array argument
j_np_mean = jit(lambda X: np.mean(X, axis=0))
# the same result can be achieved by vmapping over the second axis (axis 1)
j_vmap_mean = jit(vmap(np.mean, in_axes=(1,)))
np.allclose(j_np_mean(X),j_vmap_mean(X)) # True
%timeit j_np_mean(X)
# 191 µs ± 10.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit j_vmap_mean(X)
# 180 µs ± 3.78 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
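As with `dot`, the jaxpr is a cheap way to see why the timings match. In this sketch we expect the vmap’d `mean` to trace to a single `reduce_sum` (followed by a division) rather than to a loop over the columns; the exact printout varies across JAX versions.

```python
import numpy.random as npr
import jax.numpy as np
from jax import vmap, make_jaxpr

npr.seed(0)
X = npr.randn(100, 1000)

# Map np.mean over the columns of X (axis 1); each call sees a (100,) vector
vmap_mean = vmap(np.mean, in_axes=(1,))
jaxpr_str = str(make_jaxpr(vmap_mean)(X))
print(jaxpr_str)
```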

Further, although I won’t demonstrate it here, all of this is composable, e.g. within other jit’d functions and with AD.
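For readers who want to check the composability claim, here is a minimal sketch: a matrix-matrix product built from the vector-vector `dot` by vmapping twice, wrapped in `jit`, and then differentiated with `grad`. `Y`, `row_times_Y`, `mm`, and `loss` are illustrative names, not part of the original example.

```python
import numpy.random as npr
import jax.numpy as np
from jax import jit, vmap, grad

npr.seed(0)
X = npr.randn(100, 1000)
Y = npr.randn(1000, 50)

# Inner vmap: dot one row of X against every column of Y (axis 1) -> shape (50,)
row_times_Y = vmap(np.dot, in_axes=(None, 1))
# Outer vmap: repeat for every row of X (axis 0) -> shape (100, 50), i.e. X @ Y
mm = jit(vmap(row_times_Y, in_axes=(0, None)))

# Differentiate a scalar loss through the jit'd, doubly vmap'd function
loss = lambda X, Y: np.sum(mm(X, Y))
dloss_dY = jit(grad(loss, argnums=1))(X, Y)
```

Since `loss` is the sum of all entries of `X @ Y`, its gradient with respect to `Y` is the column sums of `X` repeated across the 50 columns, which gives an easy sanity check.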

## Julia Broadcasting

Let’s see how Julia’s broadcasting compares. (Note: I am not suggesting comparing timings between JAX and Julia here, only comparing the hand-vectorized vs. vmap’d/broadcasted timings within each language. Thanks @Mason for clarifying.)

```
using BenchmarkTools
using LinearAlgebra: dot
D = 10^3
BS = 10^2
x = randn(D)
X = randn(BS,D)
y = randn(D)
dot(x,y);
```

By broadcasting (with Julia’s `.` syntax sugar), we can compute the matrix-vector product by broadcasting `dot` over the slices of the first argument along dimension `1` (i.e. its rows).

Since we don’t want to broadcast over the second argument at all, we wrap it in `Ref`… yuck.

```
broadcast_dot(X,y) = dot.(eachslice(X,dims=1), Ref(y))
isapprox( broadcast_dot(X,y), X*y) #true
# Performance is worse
@btime X*y;
# 18.488 μs (1 allocation: 896 bytes)
@btime broadcast_dot(X,y);
# 70.838 μs (108 allocations: 6.56 KiB)
```

And it’s clear from the lowered code that the broadcasted `dot` does not get lowered to a single call to the more efficient matrix-vector multiplication:

```
@code_lowered X*y
# CodeInfo(
# 1 ─ Core.NewvarNode(:(y))
# │ TS = LinearAlgebra.promote_op(LinearAlgebra.matprod, $(Expr(:static_parameter, 1)), $(Expr(:static_parameter, 2)))
# │ %3 = LinearAlgebra.isconcretetype(TS)
# └── goto #3 if not %3
# 2 ─ %5 = Core.apply_type(LinearAlgebra.AbstractVector, TS)
# │ @_6 = LinearAlgebra.convert(%5, x)
# └── goto #4
# 3 ─ @_6 = x
# 4 ┄ y = @_6
# │ %10 = TS
# │ %11 = LinearAlgebra.size(A, 1)
# │ %12 = LinearAlgebra.similar(x, %10, %11)
# │ %13 = LinearAlgebra.mul!(%12, A, y)
# └── return %13
#)
@code_lowered broadcast_dot(X,y)
# CodeInfo(
# 1 ─ %1 = (:dims,)
# │ %2 = Core.apply_type(Core.NamedTuple, %1)
# │ %3 = Core.tuple(1)
# │ %4 = (%2)(%3)
# │ %5 = Core.kwfunc(Main.eachslice)
# │ %6 = (%5)(%4, Main.eachslice, X)
# │ %7 = Main.Ref(y)
# │ %8 = Base.broadcasted(Main.dot, %6, %7)
# │ %9 = Base.materialize(%8)
# └── return %9
# )
```

The same holds in the `mean` example:

```
using Statistics: mean
@btime mean($X, dims=1);
# 19.842 μs (1 allocation: 7.94 KiB)
@btime mean.($(eachslice(X,dims=2)));
# 29.634 μs (1002 allocations: 62.75 KiB)
```

Critically, unlike with `jax.vmap`, Julia’s broadcast lowers to `Base.broadcasted(Main.dot,...)` and not to a call to `LinearAlgebra.mul!`. This has significant ramifications for downstream tasks. For instance, AD of broadcasted functions with Zygote is considerably more complex and less performant, because it cannot leverage a more efficient rule such as the adjoint of `matmul`, and the code emitted for broadcasting is messy. This is an especially noticeable issue at scales where hardware optimizations, such as efficient matrix multiplication on GPU, would dominate the broadcasted `dot`. (I may rerun these on GPU later.)
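On the JAX side, the gradient of a vmap’d `dot` goes through the transpose rule of `dot_general`, so we would expect the backward pass to consist of matrix-vector `dot_general`s rather than per-row work. A minimal sketch to inspect this (the squared-sum `loss` is an arbitrary choice for illustration; its gradient with respect to `y` is `2 * X.T @ (X @ y)`):

```python
import numpy.random as npr
import jax.numpy as np
from jax import grad, vmap, make_jaxpr

npr.seed(0)
X = npr.randn(100, 1000)
y = npr.randn(1000)

broadcast_dot = vmap(np.dot, in_axes=(0, None))

# Scalar loss built from the vmap'd dot, differentiated w.r.t. y
loss = lambda X, y: np.sum(broadcast_dot(X, y) ** 2)
grad_y = grad(loss, argnums=1)

g = grad_y(X, y)
backward_ir = str(make_jaxpr(grad_y)(X, y))
print(backward_ir)  # look for dot_general in both the forward and backward parts
```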

The solution in Julia for me, unfortunately, has been to always write batch-aware code. That is, I write every function from the start as though it will receive a pre-specified batch dimension, so I can use hardware-backed kernels for fast matmul, as well as for the gradients. Writing in JAX with `vmap` is, by comparison, much more freeing. I can write all my functions as though they are applied to a single example, and only consider batches of data when I am ready to apply them to batched data. (It’s especially nice that I don’t have to rely on conventions like the batch dimension being `0` in Python and `end` in Julia.)
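A small sketch of that workflow: a hypothetical per-example `layer` (a tiny linear map plus `tanh`) written for a single vector, then batched with `vmap`. The batch can live on whichever axis the data uses, and `in_axes`/`out_axes` say where it is, rather than a convention baked into the function.

```python
import numpy.random as npr
import jax.numpy as np
from jax import vmap

# Hypothetical per-example function, written for a single input x of shape (D,)
def layer(W, b, x):
    return np.tanh(W @ x + b)

npr.seed(0)
D, H, BS = 8, 4, 16
W = np.asarray(npr.randn(H, D))
b = np.asarray(npr.randn(H))

# Batch stored along axis 0 (the usual Python convention) -> output (BS, H)
X_rows = npr.randn(BS, D)
out_rows = vmap(layer, in_axes=(None, None, 0))(W, b, X_rows)

# Same data with the batch along the last axis (Julia-style, axis 1 for 2-D);
# out_axes=1 also places the batch last in the output -> (H, BS)
X_cols = X_rows.T
out_cols = vmap(layer, in_axes=(None, None, 1), out_axes=1)(W, b, X_cols)
```

`layer` itself never changes; only the `vmap` call encodes where the batch axis is.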