Julia's Broadcast vs Jax's vmap

My (high-level) understanding of jax.vmap is that it automatically vectorizes a function along a specified axis of its inputs by introducing an “abstract” batch axis and compiling the code as though the inputs were shaped accordingly.
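
Roughly (this is just a mental-model sketch, not how JAX implements it internally), vmap(f) behaves as if f were mapped over slices of its input, but the whole thing is compiled as a single batched computation:

import jax.numpy as np
from jax import vmap

def f(x):
    return np.dot(np.sin(x), x)  # some per-example function

xs = np.ones((4, 3))  # a batch of 4 length-3 inputs

# Semantically, vmap(f)(xs) matches mapping f over the leading axis...
looped = np.stack([f(x) for x in xs])
# ...but it is compiled as one vectorized computation, not a Python loop.
vmapped = vmap(f)(xs)

np.allclose(looped, vmapped)  # True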

To see the difference let’s consider a very simple example where Julia’s broadcasting is much less performant than jax.vmap.

Let’s consider how jax internally represents vector-vector dot products:


import jax.numpy as np
from jax import jit, vmap, make_jaxpr
import numpy.random as npr

D = 10**3 # Data Dim
BS = 10**2 # Broadcast/Batch Dim


# Vector-Vector Dot

x = npr.randn(D)
y = npr.randn(D)

np.dot(x,y)

# lowers to intermediate representation
make_jaxpr(np.dot)(x,y)


#{ lambda  ; a b.
#  let c = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
#                       precision=None ] a b
#  in (c,) }

Compare this to how it represents matrix-vector product:


# Matrix-vector product

X = npr.randn(BS,D)
y = npr.randn(D)

np.matmul(X,y)

# lowers to IR
make_jaxpr(np.matmul)(X,y)

#{ lambda  ; a b.
#  let c = reshape[ dimensions=None
#                   new_sizes=(1000, 1) ] b
#      d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
#                       precision=None ] a c
#      e = reshape[ dimensions=None
#                   new_sizes=(100,) ] d
#  in (e,) }

In particular, notice that both dot and matmul lower to the same internal primitive, dot_general, and that dot has dimension_numbers=(((0,), (0,)), ((), ())) where matmul has dimension_numbers=(((1,), (0,)), ((), ())).
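
To make those dimension_numbers concrete, here is a quick sanity-check sketch calling the underlying primitive jax.lax.dot_general directly: the first pair lists the contracting axes of each argument, and the second pair lists the batch axes.

import jax.numpy as np
from jax import lax
import numpy.random as npr

D = 10**3
BS = 10**2
x, y = npr.randn(D), npr.randn(D)
X = npr.randn(BS, D)

# dot's numbers: contract axis 0 of x with axis 0 of y, no batch axes
np.allclose(lax.dot_general(x, y, (((0,), (0,)), ((), ()))), np.dot(x, y))    # True

# matmul's numbers: contract axis 1 of X with axis 0 of y, no batch axes
np.allclose(lax.dot_general(X, y, (((1,), (0,)), ((), ()))), np.matmul(X, y)) # True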

With this in mind, let’s consider what happens when we use jax.vmap. It’s clear from x = npr.randn(D) vs X = npr.randn(BS,D) that the latter can be thought of as a collection of BS D-dimensional x’s stacked along the 0th axis (sorry, Python is 0-based)…

So, the output of matmul can be achieved by broadcasting dot over the 0th axis of the first argument:

Xy = np.matmul(X,y)

# broadcast dot over 0th axis of first argument
# and do not broadcast over the second argument (the analogue of Julia's Ref(y))
broadcast_dot = vmap(np.dot, in_axes=(0,None))

# These are equivalent
np.allclose(Xy,broadcast_dot(X,y)) # True

# The IR for the broadcasted operation 
# lowers to the equivalent matmul operation
make_jaxpr(broadcast_dot)(X,y)

#{ lambda  ; a b.
#  let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
#                       precision=None ] a b
#  in (c,) }

Critically, note that the vmap-“broadcasted” dot lowers to the same underlying representation as matmul. When these are compiled for hardware (BLAS, CUDA, XLA…), they will call the same operations.

This is reflected in benchmarking. After jax.jit both functions evaluate in approximately the same time:

jbd = jit(broadcast_dot)
%timeit jbd(X,y) 
# 232 µs ± 2.63 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

jmm = jit(np.matmul)
%timeit jmm(X,y)
# 235 µs ± 4.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

This is true for other functions, even those with “built-in” broadcasting:


# np.mean has optional axis argument
# lambda only for jitting
j_np_mean = jit(lambda X: np.mean(X, axis=0))

# can be achieved by vmapping over second axis
j_vmap_mean = jit(vmap(np.mean, in_axes = (1,)))

np.allclose(j_np_mean(X),j_vmap_mean(X)) # True

%timeit j_np_mean(X)
# 191 µs ± 10.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit j_vmap_mean(X)
# 180 µs ± 3.78 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Further, though I won’t demonstrate it here, this all composes, e.g. inside other jit’d functions and with AD.
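
As a rough sketch of that composability, the vmapped dot can sit inside a jit’d, differentiated function like any other JAX op:

import jax.numpy as np
from jax import jit, grad, vmap
import numpy.random as npr

D = 10**3
BS = 10**2
X = npr.randn(BS, D)
y = npr.randn(D)

broadcast_dot = vmap(np.dot, in_axes=(0, None))

# a scalar loss built on top of the vmapped dot, jit-compiled and differentiated
loss = jit(lambda y: np.sum(broadcast_dot(X, y) ** 2))
dloss_dy = jit(grad(loss))

loss(y)            # scalar
dloss_dy(y).shape  # (1000,)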

Julia Broadcasting

Let’s see how Julia’s broadcasting compares. (Note that I am not suggesting comparing the timing results between Jax and Julia here, only comparing the hand-vectorized vs vmap’d/broadcasted timings within each language. Thanks @Mason for clarifying.)


using BenchmarkTools
using LinearAlgebra: dot

D = 10^3
BS = 10^2

x = randn(D)
X = randn(BS,D)
y = randn(D)

dot(x,y);

Using broadcasting (with Julia’s . syntax sugar) we can compute the matrix-vector product by broadcasting dot over the dims=1 slices (rows) of the first argument.

Since we don’t want to broadcast at all over the second argument, we use
Ref… yuck.


broadcast_dot(X,y) = dot.(eachslice(X,dims=1), Ref(y))

isapprox( broadcast_dot(X,y), X*y) #true


# Performance is worse
@btime X*y;
# 18.488 μs (1 allocation: 896 bytes)

@btime broadcast_dot(X,y);
# 70.838 μs (108 allocations: 6.56 KiB)

And it’s clear from the IR that broadcasted dot does not get lowered to a single call to a more efficient matrix-vector multiplication:


@code_lowered X*y
# CodeInfo(
# 1 ─       Core.NewvarNode(:(y))
# │         TS = LinearAlgebra.promote_op(LinearAlgebra.matprod, $(Expr(:static_parameter, 1)), $(Expr(:static_parameter, 2)))
# │   %3  = LinearAlgebra.isconcretetype(TS)
# └──       goto #3 if not %3
# 2 ─ %5  = Core.apply_type(LinearAlgebra.AbstractVector, TS)
# │         @_6 = LinearAlgebra.convert(%5, x)
# └──       goto #4
# 3 ─       @_6 = x
# 4 ┄       y = @_6
# │   %10 = TS
# │   %11 = LinearAlgebra.size(A, 1)
# │   %12 = LinearAlgebra.similar(x, %10, %11)
# │   %13 = LinearAlgebra.mul!(%12, A, y)
# └──       return %13
#)


@code_lowered broadcast_dot(X,y)
# CodeInfo(
# 1 ─ %1 = (:dims,)
# │   %2 = Core.apply_type(Core.NamedTuple, %1)
# │   %3 = Core.tuple(1)
# │   %4 = (%2)(%3)
# │   %5 = Core.kwfunc(Main.eachslice)
# │   %6 = (%5)(%4, Main.eachslice, X)
# │   %7 = Main.Ref(y)
# │   %8 = Base.broadcasted(Main.dot, %6, %7)
# │   %9 = Base.materialize(%8)
# └──      return %9
# )

Also true in the mean example:


using Statistics: mean

@btime mean($X, dims=1);
# 19.842 μs (1 allocation: 7.94 KiB)


@btime mean.($(eachslice(X,dims=2)));
# 29.634 μs (1002 allocations: 62.75 KiB)

Critically, unlike jax.vmap, Julia’s broadcast lowers to Base.broadcasted(Main.dot,...) and not to a call to LinearAlgebra.mul!. This has significant ramifications for downstream tasks. For instance, AD of broadcasted functions with Zygote is considerably more complex and less performant, because it cannot leverage a more performant rule like the adjoint of matmul, and the code emitted for broadcast is messy. This is especially noticeable at scales where hardware optimizations, such as efficient matrix multiplication on GPU, dominate the broadcasted dot. (I may rerun these on GPU later.)

The solution in Julia for me, unfortunately, has been to always write code that is batch-aware. That is, write all functions from the beginning as though they will accept a pre-specified batch dimension, so I can use hardware-backed kernels for fast matmul, as well as for the gradients. Writing in Jax, with vmap, is comparatively much more freeing: I can write all my functions as though they are applied elementwise, and only consider batches of data when I am ready to apply them to batched data. (It is especially nice that I don’t have to rely on conventions like the batch dimension being first in Python and last in Julia.)
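
As a sketch of that workflow (with a made-up predict function, just for illustration): the per-example code below never mentions a batch dimension, and batching is bolted on afterwards with vmap.

import jax.numpy as np
from jax import vmap
import numpy.random as npr

D = 10**3
BS = 10**2

# A per-example "model", written as though it only ever sees one input vector.
def predict(w, x):
    return np.tanh(np.dot(w, x))

w = npr.randn(D)
xs = npr.randn(BS, D)  # the batch only shows up at the call site

# Batch over x (axis 0) while holding w fixed; no batch-aware rewrite needed.
batched_predict = vmap(predict, in_axes=(None, 0))
batched_predict(w, xs).shape  # (100,)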

17 Likes

@Mason @mcabbott @darsnack @mbauman

Link to the relevant Zulip thread that started this: Zulip

3 Likes

I don’t think a compiler pass is necessary. Using the EachSlice type and overloading broadcasting could work for the 0-level-deep case. Beyond that, you could make operator fusion a part of the Broadcasted struct and have a macro turn it on or off. When it is off, the behavior could be to run the broadcasted expression over the slice iterators in lock-step, like Hydra. This could handle the depth issue, since broadcasting can contain nested Broadcasted args anyway.

You could extend the macro to invoke a compiler pass if it turns out deciding when to turn on/off operator fusion is a more complicated cost function.

I’m still a little uneasy about this point. Why is the timing for a matmul call so different between Jax and Julia here? My earlier concern was that the Jax benchmarks being run weren’t actually measuring the same thing, or even a computation at all.

I remember seeing some erroneous twitter benchmarks a while ago where it turned out the benchmark was just measuring some lazy process and not the actual computation.

3 Likes

Julia arrays are column-major, so iterating over row slices like this is very inefficient for the cache. In contrast, Jax and NumPy arrays are row-major by default, so row slices are contiguous. You can see this if you transpose the X matrix and do the multiplication by columns (which is equivalent to using row-major storage à la NumPy):

julia> @btime broadcast_dot($X,$y);
  80.013 μs (108 allocations: 6.56 KiB)

julia> broadcast_dotT(y, Xt) = [dot(y, x) for x in eachslice(Xt, dims=2)]
broadcast_dotT (generic function with 1 method)

julia> Xt = Matrix(X');

julia> @btime broadcast_dotT($y, $Xt);
  19.694 μs (106 allocations: 5.67 KiB)

There’s your missing factor of 4.

(The BLAS matrix-vector multiplication takes the column-major storage into account and changes its algorithm to load several elements from each column at a time.)

16 Likes

Thanks for that. I suspected my timing could have been made fairer by respecting column- vs row-major storage.

To @Mason’s point, I do think that I’m timing the jax code correctly, but it’s possible I’m hitting that same lazy thing. I doubt it though, seeing as it favours Julia.

I’m less interested in optimizing these timing concerns though until I set up these experiments on a GPU where matmul will really be a considerable advantage.

But @stevengj’s point is significant, because with the correct transposing the broadcasted dot is now comparable to the matmul on CPU…

I think the trick from that thread was something like this, and it seems to have little effect here. But these times are shorter than whatever lazy time was first measured there, so perhaps this is largely measuring some launch cost:

jbd = jit(broadcast_dot) 

def test_jbd(X,y): 
    return jbd(X,y).block_until_ready()

%timeit test_jbd(X,y)
# 284 µs ± 20.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

jmm = jit(np.matmul)
def test_jmm(X,y): 
    return jmm(X,y).block_until_ready() 

%timeit test_jmm(X,y)
# 276 µs ± 19.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

But regardless of the times, I guess we can believe make_jaxpr is not lying, and does indeed lower to what it prints out.

1 Like

I ran the benchmarks on the GPU:

using BenchmarkTools, CuArrays
using LinearAlgebra: dot

D = 10^3
BS = 10^2

x = randn(D)
X = randn(D, BS)
y = randn(D)
cX = cu(X)
cy = cu(y)
Xt = permutedims(X)
cXt = cu(Xt)

dot(x, y)
dot(cu(x), cy)

broadcast_dot(X, y) = [dot(x, y) for x in eachslice(X; dims = 2)]
matmul_dot(Xt, y) = Xt * y

Now running on the CPU:

@btime broadcast_dot($X, $y)
16.652 μs (108 allocations: 6.56 KiB)

@btime matmul_dot($Xt, $y)
13.867 μs (1 allocation: 896 bytes)

And on the GPU:

@btime CuArrays.@sync broadcast_dot($cX, $cy)
321.091 ms (208 allocations: 8.45 KiB)

@btime CuArrays.@sync matmul_dot($cXt, $cy)
238.609 μs (8 allocations: 208 bytes)

Of note is that the following definition did not work:

broadcast_dot(X, y) = dot.(eachslice(X; dims = 2), Ref(y))
3 Likes

NB the ms vs μs!

It’s my rough understanding that these generators get iterated through on the CPU, and each CuArray slice is handled in turn. I don’t know whether this can be done in a better way, but obviously fusing it to matmul_dot is one solution.

julia> gen = eachcol(cu(rand(2,3))); # slice a CuArray

julia> map(dot, gen, gen)
3-element Array{Float32,1}:
 0.11035925
 0.35250273
 0.23367213

julia> dot.(gen, gen)
3-element Array{Float32,1}:
 0.11035925
 0.35250273
 0.23367213

So those are the difficulties @mcabbott and @darsnack are having with getting broadcast to just work on GPU, let alone dispatch to performant kernels. But another critical feature of jax.vmap is its composition with other transformations, namely AD.

In this case, lowering the broadcasted dot to a matmul allows JAX to emit adjoint code that uses the adjoint of matmul, rather than broadcasting the adjoint of dot.

Using @stevengj’s corrections for column-major storage (thanks), we can see that this corresponds to a significant performance reduction in reverse-mode AD, even if the primal evaluations are comparably performant.

using BenchmarkTools
using LinearAlgebra: dot
using Zygote

D = 10^3
BS = 10^2

X = randn(BS,D)
Xt = Matrix(X')
y = randn(D)

broadcast_dotT(y,Xt) = dot.(Ref(y),eachslice(Xt,dims=2))

isapprox( broadcast_dotT(y,Xt), X*y) #true

# wrapped in sum for scalar needed by gradient
mul(y) = sum(X*y)
bdot(y) = sum(broadcast_dotT(y,Xt))


# similar primal pass
@btime mul(y);
#   18.398 μs (2 allocations: 912 bytes)

@btime bdot(y);
# 18.195 μs (109 allocations: 6.58 KiB)

# worse pullback performance
# only measuring difference in reverse pass
Xy, pb_mul = Zygote.pullback(mul,y);
@btime pb_mul(1.);
#  278.239 μs (14 allocations: 2.30 MiB)

Xy, pb_bdot = Zygote.pullback(bdot,y);
@btime pb_bdot(1.);
# 75.052 ms (2648 allocations: 155.01 MiB)

As @mcabbott mentioned, note that this is μs vs ms!

Again, all timing is done on CPU here. I expect this to be much worse on GPU.

2 Likes

I definitely like being able to explicitly annotate the broadcasted dimensions. There’s great usefulness in Julia’s broadcast permissiveness, but it’d be really cool to have an explicit mode like Jax’s.

We do need a dedicated struct for eachslice. Once we have that, we could totally add a broadcast “rule” that transformed broadcasted dot over eachslice and a 0-dim container like Ref into a matvec. It’s quite the tiny peephole, but I think it could be reasonable if that’s causing major issues.

We also need to implement better broadcasting over generators, like the sort that eachslice currently returns, in any case.

The mean performance issue looks to be exacerbated by a performance issue with mean itself, I think.

Thanks for putting this all together!

10 Likes

One thing to note is that (from a Julia perspective) the way vmap is implemented in JAX really is just dispatch on a Batched wrapper type; there’s no additional compiler magic or pattern matching. (Much of the rest of JAX is also implemented with dispatch: forward-mode autodiff is essentially dispatch on a tagged Dual wrapper type, and when we “trace” a user-defined function into a “jaxpr” data structure for XLA compilation or reverse-mode autodiff, that’s dispatch on what could be called a Staged wrapper type.) It’s only at the API level that these capabilities are exposed as function transformations (matching the API of other transformations that can’t be implemented purely with dispatch).
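
To make that concrete, here is a toy sketch of “batching as dispatch” in plain NumPy. This is emphatically not JAX’s actual internals, just an illustration of the idea: a wrapper records which axis carries the batch, and each primitive gets a batching rule describing how to handle wrapped arguments.

import numpy as onp

class Batched:
    """Toy wrapper: a value plus the axis that carries the (implicit) batch."""
    def __init__(self, val, axis):
        self.val, self.axis = val, axis

def dot_batch_rule(a, b):
    # Batching rule for dot when only the first argument is batched:
    # a whole batch of dots is a single matrix-vector product.
    A = onp.moveaxis(a.val, a.axis, 0)
    return Batched(A @ b, axis=0)

def toy_vmap(batch_rule, in_axes):
    # "vmap": wrap the batched inputs, apply the rule, unwrap the result.
    def batched(*args):
        wrapped = [Batched(arg, ax) if ax is not None else arg
                   for arg, ax in zip(args, in_axes)]
        return batch_rule(*wrapped).val
    return batched

X = onp.random.randn(100, 1000)
y = onp.random.randn(1000)
onp.allclose(toy_vmap(dot_batch_rule, (0, None))(X, y), X @ y)  # True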

So what are we doing differently? For one thing, JAX maintains a strict equivalence between the information available to dispatch and the intermediate representation created by tracing; we’re free to round-trip between the two at will. (The whole JAX system would have the same semantics if it always traced to IR and transformed/analyzed that, but the dispatch mechanism exists as an optimization for performance and debuggability). For vmap this means it behaves the same way if applied before or after staging.

The fact that we trace to IR using dispatch (and so much of Python isn’t available for overloading) significantly restricts the information we can get out of the source language. We can’t see Python control flow, list/dict operations, exceptions, etc. (except to the extent that we often get to control the errors Python throws when it doesn’t know what to do with our dispatch wrappers). But, on the flip side, it means our IR is both very simple and very static. On top of assumptions checked by construction like no data-dependent Python control flow, the staging mechanism makes assumptions about user code that it generally can’t check: no side effects that escape the traced function, no depending on aliasing relationships.

These assumptions point towards a different way of looking at the relationship between vmap and broadcasting: JAX vmap is like a LoopVectorization.jl or KernelAbstractions.jl-annotated parallel loop: the iteration region can be traversed in any order or in any number of parallel threads, because the loop body can’t contain side effects or rely on aliasing. Because the implementation uses dispatch rather than a macro, vmap is able to look inside functions called in the body and emit BLAS-like primitives in addition to broadcasted scalar primitives.

Downstream, the parts of a vmapped region that don’t lower to BLAS on CPU/GPU almost always end up as XLA fusion regions (a construct with the same goal as Julia broadcast fusion but with JAX-like purity assumptions). With XLA’s TPU backend, heroic efforts from the compiler team mean the BLAS parts end up in the fusion regions, too.

22 Likes

I’d just like to add that vmap is perhaps my favorite feature of JAX. In fact I think it may be the main barrier for me switching to Julia from Python/JAX. At least that’s how I happened upon this thread…

As @jekbradbury mentioned, JAX IR is limited, especially in terms of control flow. As a result, control flow in JAX can be quite painful, but if that’s the price I need to pay to have vmap then so be it.

6 Likes

Do you perhaps have a few examples of what you use it for? Less trivial than the broadcast_dot things above, presumably…

My usual example is something like implementing a particle filter. So for each timestep you need to have a set of particles, and each particle gets resampled into a bunch of other particles, then you need to sample a new set of particles out of those, yada yada. But now maybe you want to operate over more than one sequence at a time. With JAX you just kind of vmap the whole thing together, and minimal care is required in order to get optimal performance. Compare that to maintaining tensors with shapes like sequences x timesteps x particles x next_particles x observations.

Here’s another use case I came across recently: it’s possible to vmap(odeint(...)). This batches the operations across all of the different ODE solves simultaneously. Pretty nuts. Matt Johnson wrote a more nuanced explanation here.
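
For context, a minimal sketch of that pattern with a toy ODE (using jax.experimental.ode.odeint; not the example from the linked explanation):

import jax.numpy as np
from jax import vmap
from jax.experimental.ode import odeint

def f(y, t):
    return -y  # simple linear ODE: dy/dt = -y

t = np.linspace(0.0, 1.0, 10)
y0_batch = np.arange(1.0, 5.0)  # four different initial conditions

# Solve all four initial-value problems "at once" by vmapping over y0.
solve = lambda y0: odeint(f, y0, t)
vmap(solve)(y0_batch).shape  # (4, 10): four solutions, ten timesteps each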

4 Likes

That’s just https://diffeq.sciml.ai/latest/features/ensemble/ but with fewer features, less parallelism, fewer solver choices, support for fewer differential equation types, etc. Why would it be interesting? Is it just the syntax? That’s somewhat odd to point out then, because it’s a cute example, but the syntax choice is not expressive enough to actually cover the space of things that need to be done, so ultimately it needs a different interface anyway.

The point is that vmap is fully general and comes at no performance cost regardless of what operations you’re broadcasting. It obviates the need for frameworks to build custom pipelines like https://docs.sciml.ai/latest/features/ensemble/ in the first place.

You could just write a for loop…

4 Likes

You could just write a for loop…

True, although you’d be giving up SIMD loop fusion and all that good stuff as per @jessebett’s examples.