Julia's Broadcast vs Jax's vmap

As promised, here are my thoughts. Excuse any mistakes as I did this late last night after JuliaCon.

Consider a generic broadcasted expression like:

y = f.(g.(h.(x)))

where x is some collection. Broadcasting allows us to conveniently express the following loop:

y = prepare_dest((f, g, h), x) # some function that allocates y
for I in eachindex(x)
	y[I] = f(g(h(x[I])))
end

But there is another type of looping that is implicit within this code. If I were to be more explicit about it, I might write

out_h = prepare_dest((h,), x)
out_g = prepare_dest((g, h), x)
out_f = prepare_dest((f, g, h), x)
y = prepare_dest((f, g, h), x)
for I in eachindex(x)
	val = x[I]
	for (dest, op) in zip((out_h, out_g, out_f), (h, g, f))
		dest[I] = op(val)
		val = dest[I]
	end
	y[I] = out_f[I]
end

Importantly, I allocated out_f, out_g, and out_h here even though I didn’t need to. I’m trying to belabor the point that operator fusion is an optimization that contracts the inner-most loop above.
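
To make that concrete with a runnable toy example (the specific f, g, and h here are just placeholders I picked):

f(x) = sqrt(x)
g(x) = abs(x)
h(x) = x + 1
x = randn(10)

y = f.(g.(h.(x)))     # fused: one loop over eachindex(x), one allocation for y
y ≈ map(f ∘ g ∘ h, x) # true: same as applying the composed function elementwise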

What happens when we switch the order of the loops?

val = x
for op in (h, g, f)
	dest = prepare_dest((op,), val)
	for I in eachindex(val)
		dest[I] = op(val[I])
	end
	val = dest
end
y = val
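
Concretely, with the same toy f, g, and h as before, the swapped ordering is just one full elementwise pass per op:

f(x) = sqrt(x)
g(x) = abs(x)
h(x) = x + 1
x = randn(10)

# one pass (and one fresh allocation) per op; the inner index loop hides in each `op.`
val = foldl((v, op) -> op.(v), (h, g, f); init = x)
val ≈ f.(g.(h.(x)))   # true: same values, three allocations instead of one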

What would contracting the inner-most loop mean now? First, you lose the memory benefit of operator fusion (a fresh dest is allocated for every op). But could the contraction itself be faster? Something like:

val = x
for op in (h, g, f)
	# some faster "denser" op works on all of val
	val = contract_op(op, axes(val))(val)
end
y = val

Can we do this with broadcasting?

At this point, it is more helpful to deal with a concrete example: the expression at the start of this post.

# b is a vector, A is a matrix
y = dot.(Ref(b), eachcol(A))
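
As a quick sanity check (with made-up sizes), the whole expression is just a transposed matrix-vector product, i.e. the single BLAS call we would like it to reach:

using LinearAlgebra

A = rand(3, 4)
b = rand(3)

dot.(Ref(b), eachcol(A)) ≈ A' * b   # true: one dot per column collapses to A' * b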

If we write out the explicit form, we get

# ignore my abusive notation
y = prepare_dest((dot,), Ref(b), eachcol(A))
for I in axes(Ref(b)), J in axes(eachcol(A))
	K = get_out_index(I, J)
	for op in (dot,)
		y[K] = op(Ref(b)[I], eachcol(A)[J])
	end
end
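
In runnable (non-abusive) Julia, that explicit form amounts to something like:

using LinearAlgebra

A = rand(3, 4)
b = rand(3)

y = similar(b, size(A, 2))
for (J, col) in enumerate(eachcol(A))
	y[J] = dot(b, col)   # the only real indexing loop is over the columns of A
end
y ≈ A' * b   # true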

# now we swap the loops
val = (Ref(b), eachcol(A))
for op in (dot,)
	dest = prepare_dest((op,), val...)
	for I in axes(val[1]), J in axes(val[2])
		K = get_out_index(I, J)
		dest[K] = op(val[1][I], val[2][J])
	end
	val = dest
end
y = val

# and we contract
val = (Ref(b), eachcol(A))
for op in (dot,)
	val = contract_op(op, axes(val))(val)
	# here contract_op(op, axes((Ref(b), eachcol(A))))
	# should return (vec, mat) -> mat' * vec
end
y = val

How would we facilitate the steps above with broadcasting? With something like this PR, we could overload

broadcasted(::typeof(dot), vec::AbstractVector, mat::EachCol{<:AbstractMatrix}) = parent(mat)' * vec
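
Here is a minimal runnable sketch of that idea, using a hand-rolled EachCol wrapper as a stand-in for whatever the PR would actually provide (the wrapper, its fields, and dispatching on the Ref are all my own assumptions):

using LinearAlgebra
import Base.Broadcast: broadcasted

# toy stand-in for the wrapper eachcol might return under such a PR
struct EachCol{T<:AbstractMatrix}
	parent::T
end
Base.parent(x::EachCol) = x.parent

# intercept the lazy `broadcasted` call before a Broadcasted is built and
# replace the per-column dots with a single transposed mat-vec product
broadcasted(::typeof(dot), vec::Base.RefValue{<:AbstractVector}, mat::EachCol) =
	parent(mat)' * vec[]

A = rand(3, 4); b = rand(3)
dot.(Ref(b), EachCol(A)) ≈ [dot(b, c) for c in eachcol(A)]   # true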

To be clear, this means that mapping broadcast operations to more efficient implementations requires the author of the function to define what to do. But if you look at Jax's lowered form for dot, it maps calls to dot to dot_general, which consumes axis information to call the most efficient operation. We're basically doing the same thing here.

Can we contract both loops?

The nice part about the approach above is that it requires no changes to the broadcast machinery, but that also imposes some limitations. Specifically, we can't fuse multiple operators into a more performant call. To talk about this, I want to introduce two more concepts: the width and depth of a broadcast optimization. When we think about a generic expression like

y = f.(g.(h.(x)))

the width refers to how many operations in the expression the optimization spans. So, standard operator fusion is a full-width optimization, because it contracts over every function in the expression. The depth refers to optimizations applied inside the outer-most (surface) functions. For example, if h contains a call to dot and I were able to optimize that inner call, then my optimization has some depth.
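
For example (a toy case of my own), the dot below is a depth-one call that a broadcast-level rewrite of h.(xs) would never see:

using LinearAlgebra

h(v) = dot(v, v) + 1         # the dot is buried inside h
xs = [rand(3) for _ in 1:4]
h.(xs)                       # the surface broadcast only sees `h`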

The approach above is single width, because it doesn’t allow something like

out = dot.(Ref(y), eachcol(X)) .+ b

to map to the single BLAS-backed call X' * y + b. To facilitate this, we can modify the Broadcasted type so that the axes field (or maybe an additional field) contains the EachCol-style information above. Then, when we materialize the nested Broadcasteds, we provide an overload-able contract(nested_broadcast) that allows someone to specify that the nested broadcast should be replaced by a single broadcast where Broadcasted.f is the fused function (v, c, b) -> dot(v, c) + b (call it dot_plus). Then again, we do

broadcasted(::typeof(dot_plus), x::AbstractVector, A::EachCol{<:AbstractMatrix}, b::AbstractVector) = parent(A)' * x + b
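
As a sanity check on what that fused call should return (sizes made up):

using LinearAlgebra

X = rand(3, 4); y = rand(3); b = rand(4)

# the per-column dots plus the bias collapse into one mat-vec product and an add
dot.(Ref(y), eachcol(X)) .+ b ≈ X' * y + b   # true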

This approach allows for full-width optimizations that also contract the data indexing, but only single depth.

Why does Jax use the compiler?

If we can pass all the necessary info around, then why would Jax bother with IR passes? In order to get depth > 1, we need to inspect the body of functions in the outer broadcast expression, which requires some kind of compiler pass. We've tested this behavior in a Zulip discussion, and Jax does seem to descend really far into expressions. There are a couple of issues/PRs about custom compiler passes floating around, so I won't really get into this.

Really, broadcasting?

The changes above are complex and will probably break stuff. Off the top of my head, they clearly break operator fusion, so a real implementation would require some kind of flag in Broadcasted to control the contraction behavior. Or maybe BroadcastStyle could be used for this. There are also other ways of implementing what I'm talking about, like this recent issue on KernelAbstractions.jl. That approach would do everything above so long as the kernel function passed to vmap! behaves the same as fused BLAS calls. So the user would need to provide the correct kernel function to get the performance, whereas the broadcasting approach gives that control primarily to the developer (making it appear "automatic" to the user).

More than BLAS?

This is a point that Chris has brought up often, and having thought about it a lot, I believe it is a really strong one. Beyond fused BLAS operations, I can't think of really compelling examples that would justify these changes. All I can say is that broadcasting appears to be an extremely succinct way to describe patterned loop blocks, and wherever there are loops, there is fertile ground for optimization. I could even see a more future-looking version of this allowing broadcasting over collections of complex data structures to use pmap and other parallel optimizations. The obvious downside is that the more expressibility you add to broadcasting, the worse it might become at each individual thing. Though on this front I'd echo Samuel's words on vmap: to me, broadcasting does one thing only, and that is express loops. Ultimately, I'm not arguing that it should do anything else, only that it can express all loops and that developers/users can overload its default behavior.
