As promised, here are my thoughts. Excuse any mistakes as I did this late last night after JuliaCon.
Consider a generic broadcasted expression like:
```julia
y = f.(g.(h.(x)))
```
where x is some collection. Broadcasting allows us to conveniently express the following loop:

```julia
y = prepare_dest((f, g, h), x) # some function that allocates y
for I in axes(x)
    y[I] = f(g(h(x[I])))
end
```
But there is another type of looping that is implicit within this code. If I were to be more explicit about it, I might write
```julia
out_f = prepare_dest((f,), x)
out_g = prepare_dest((f, g), x)
out_h = prepare_dest((f, g, h), x)
y = prepare_dest((f, g, h), x)
for I in axes(x)
    val = x[I]
    for (dest, op) in zip((out_h, out_g, out_f), (h, g, f))
        dest[I] = op(val)
        val = dest[I]
    end
    y[I] = out_f[I]
end
```

Importantly, I allocated out_f, out_g, and out_h here even though I didn't need to. I'm trying to belabor the point that operator fusion is an optimization that contracts the inner-most loop above.
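Here is a runnable sketch of that unfused form (with `similar` in place of `prepare_dest`; the op definitions are made up for illustration), confirming it agrees with the fused broadcast:

```julia
# Example ops; one buffer per op, allocated even though only the last is needed.
f(v) = v + 1
g(v) = 2v
h(v) = v^2

x = [1.0, 2.0, 3.0]
out_h, out_g, out_f = similar(x), similar(x), similar(x)
y = similar(x)
for I in eachindex(x)
    val = x[I]
    for (dest, op) in zip((out_h, out_g, out_f), (h, g, f))
        dest[I] = op(val)
        val = dest[I]
    end
    y[I] = out_f[I]  # out_f holds f(g(h(x[I]))), the final value
end
y == f.(g.(h.(x)))
```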
What happens when we switch the order of the loops?
```julia
val = x
for op in (h, g, f)
    dest = prepare_dest((op,), val)
    for I in axes(val)
        dest[I] = op(val[I])
    end
    val = dest
end
y = val
```
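Wrapped in a function so the rebinding of `val` is unambiguous, the loop-swapped version runs as-is (again with `similar` standing in for `prepare_dest`; ops are example definitions):

```julia
# Apply each op over the whole collection before moving on to the next op.
function apply_chain(ops, x)
    val = x
    for op in ops
        dest = similar(val)   # one allocation per op
        for I in eachindex(val)
            dest[I] = op(val[I])
        end
        val = dest
    end
    return val
end

f(v) = v + 1
g(v) = 2v
h(v) = v^2
x = [1.0, 2.0, 3.0]
apply_chain((h, g, f), x) == f.(g.(h.(x)))  # same result as the fused broadcast
```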
What would contracting the inner-most loop mean now? First, you lose the memory benefits of operator fusion (dest is allocated for every op). But would the contraction be faster?

```julia
val = x
for op in (h, g, f)
    # some faster "denser" op works on all of val
    val = contract_op(op, axes(val))(val)
end
y = val
```
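For plain element-wise ops, one hypothetical `contract_op` is just "broadcast op over the whole collection"; the name and signature here are assumptions for illustration, not an existing API:

```julia
# Hypothetical: contracting the inner loop of an element-wise op means
# replacing the explicit index loop with one dense broadcast.
contract_op(op, _axes) = v -> op.(v)

function contracted_chain(ops, x)
    val = x
    for op in ops
        val = contract_op(op, axes(val))(val)
    end
    return val
end

f(v) = v + 1
g(v) = 2v
h(v) = v^2
x = [1.0, 2.0, 3.0]
contracted_chain((h, g, f), x) == f.(g.(h.(x)))
```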
Can we do this with broadcasting?
At this point, it is more helpful to deal with a concrete example. Namely, the expression at the start of this post.
```julia
# b is a vector, A is a matrix
y = dot.(Ref(b), eachcol(A))
```
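For concreteness: element j of this broadcast is dot(b, A[:, j]), so the whole expression collapses to a single matvec (example values below are made up):

```julia
using LinearAlgebra

b = [1.0, 2.0]
A = [1.0 3.0 5.0;
     2.0 4.0 6.0]

y = dot.(Ref(b), eachcol(A))
y == transpose(A) * b  # true: the dense call we would like to contract to
```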
If we write out the explicit form, we get

```julia
# ignore my abusive notation
y = prepare_dest((dot,), Ref(b), eachcol(A))
for I in axes(Ref(b)), J in axes(eachcol(A))
    K = get_out_index(I, J)
    for op in (dot,)
        y[K] = op(Ref(b)[I], eachcol(A)[J])
    end
end
```
```julia
# now we swap the loops
val = (Ref(b), eachcol(A))
for op in (dot,)
    dest = prepare_dest((op,), val...)
    for I in axes(val[1]), J in axes(val[2])
        K = get_out_index(I, J)
        dest[K] = op(val[1][I], val[2][J])
    end
    val = dest
end
y = val
```
```julia
# and we contract
val = (Ref(b), eachcol(A))
for op in (dot,)
    val = contract_op(op, axes(val))(val...)
    # this contract_op(op, axes((Ref(b), eachcol(A))))
    # should return (vec, mat) -> transpose(mat) * vec
end
y = val
```
How would we facilitate the steps above with broadcasting? With something like this PR, we could overload

```julia
broadcasted(::typeof(dot), vec::AbstractVector, mat::EachCol{<:AbstractMatrix}) = transpose(parent(mat)) * vec
```
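Since `EachCol` and the `broadcasted` hook only exist in that PR, here is a standalone sketch of the same dispatch idea, with a hypothetical `EachCol` wrapper and a made-up `bcast_dot` in place of `Base.Broadcast.broadcasted`:

```julia
using LinearAlgebra

# Hypothetical wrapper mimicking the PR's EachCol type.
struct EachCol{T<:AbstractMatrix}
    parent::T
end
Base.parent(c::EachCol) = c.parent

# Fallback: the naive column-by-column loop.
bcast_dot(vec::AbstractVector, cols) = [dot(vec, c) for c in cols]
# Specialization: contract the whole loop into one dense matvec.
bcast_dot(vec::AbstractVector, cols::EachCol) = transpose(parent(cols)) * vec

b = [1.0, 2.0]
A = [1.0 3.0 5.0; 2.0 4.0 6.0]
bcast_dot(b, EachCol(A)) == bcast_dot(b, eachcol(A))  # same answer, one matvec
```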
To be clear, this means that mapping broadcast operations to more efficient implementations requires the author of the function to define what to do. But if you look at Jax's lowered form for dot, it maps calls to dot to dot_general, which consumes axis information to call the most efficient operation. We're basically doing the same here.
Can we contract both loops?
The nice part about the approach above is that it requires no changes to the broadcast machinery, but that also poses some limitations. Specifically, we can't fuse multiple operators into a more performant call. To talk about this, I want to introduce two more concepts: the width and depth of a broadcasted optimization. When we think about a generic expression like

```julia
y = f.(g.(h.(x)))
```

the width refers to how many operations the optimization works on. So, the standard operator fusion is a full-width optimization, because it contracts over every function in the expression. The depth refers to optimizations applied within the outer-most (surface) functions. For example, if h contains a call to dot, and I were able to apply optimizations to that inner call, then I would have some depth to my optimization.
The approach above is single width, because it doesn't allow something like

```julia
out = dot.(Ref(y), eachcol(X)) .+ b
```

to map to the BLAS operation for expressions like A*x + b. To facilitate this, we can modify the Broadcasted type so that the axes field (or maybe an additional field) contains the EachCol-style information above. Then, when we materialize the nested Broadcasteds, we provide an overload-able contract(nested_broadcast) that allows someone to specify that the nested broadcast should be replaced by a single broadcast where Broadcasted.f == ((+) ∘ dot). Then, again, we do

```julia
broadcasted(::typeof((+) ∘ dot), x::AbstractVector, A::EachCol{<:AbstractMatrix}, b::AbstractVector) = transpose(parent(A)) * x + b
```

This approach allows for full-width optimizations that also contract the data indexing, but only single depth.
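Numerically, the target of that full-width contraction is the familiar fused matvec-plus-bias (values below are made up for illustration):

```julia
using LinearAlgebra

y = [1.0, 2.0]
X = [1.0 3.0 5.0; 2.0 4.0 6.0]
b = [0.5, 0.5, 0.5]

out = dot.(Ref(y), eachcol(X)) .+ b
out == transpose(X) * y .+ b  # the single fused call we want to lower to
```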
Why does Jax use the compiler?
If we can pass all the necessary info around, then why would Jax bother with IR passes? In order to get depth > 1, we need to inspect the body of functions in the outer broadcast expression which requires some kind of compiler pass. We’ve tested this behavior in a Zulip discussion, and Jax does seem to descend really far into expressions. There are a couple issues/PRs about custom compiler passes floating around, so I won’t really get into this.
Really, broadcasting?
The changes above are complex and will probably break stuff. Off the top of my head, it clearly breaks operator fusion, so a real implementation would require some kind of flag in Broadcasted to control the contraction behavior. Or maybe BroadcastStyle could be used for this. There are also other ways of implementing what I'm talking about, like this recent issue on KernelAbstractions.jl. That approach would do everything above so long as the kernel function passed to vmap! behaves the same as fused BLAS calls. So the user would need to provide the correct kernel function to get the performance, whereas the broadcasting approach gives that control primarily to the developer (making it appear "automatic" to the user).
More than BLAS?
This is a point that Chris has brought up often, and I've thought about it a lot; I believe it is a really strong point. Beyond fused BLAS operations, I can't think of really compelling examples to justify these changes. All I can say is that broadcasting appears to be an extremely succinct way to describe patterned loop blocks. Wherever there are loops, there is fertile ground for optimization. I could even see a more future-looking version of this allowing broadcasts over collections of complex data structures to use pmap and other parallel optimizations. The obvious downside is that the more expressibility you add to broadcasting, the worse it might become at doing each thing. Though on this front I'd echo Samuel's words on vmap: broadcasting, to me, does one thing only: express loops. Ultimately, I'm not arguing that it should do anything else, only that it can express all loops and that developers/users can overload its default behavior.