Julia's Broadcast vs JAX's vmap

In the code example that you quoted, I am not trying to do any functional composition. Excuse my terrible overloading of the “contract” term, but contract_op(f, axes_info) is a function that maps f.(x) => F(x), where f is some element-level operation and F operates on all of x at once. The idea is that some functions, when broadcast over their arguments with a specific iteration pattern (axes_info), can be mapped to equivalent operations that are faster and more dense. By “dense” I mean that F operates on a larger chunk of x (probably all of x at once) than f does. The canonical example in this thread is that dot broadcast over the slices of a matrix against a vector can be computed more efficiently as a single matrix-vector multiplication.
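To make this concrete, here is the canonical example in plain Julia (nothing overloaded yet, just the two equivalent computations):

using LinearAlgebra

X = rand(100, 50); y = rand(50)

slow = dot.(eachrow(X), Ref(y))   # f.(x): dot applied slice by slice
fast = X * y                      # F(x): one BLAS matrix-vector product
slow ≈ fast                       # true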

The initial method that I proposed would only require adding some kind of “slice type” to indicate when a broadcast is being applied over eachslice, eachrow, etc. It does not do any functional composition and it doesn’t change the broadcasting machinery, but it enables the mapping I described above. Let’s call this optimization data contraction, because instead of iterating over the elements of x, it contracts (fuses) some of the axes so that F can be applied.
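A minimal sketch of this, assuming we piggyback on Base.RowSlices (what eachrow returns on recent Julia versions) rather than introducing a new slice type; overloading a Base function on Base types like this is type piracy, so treat it purely as an illustration:

using LinearAlgebra

# Hypothetical data contraction rule: broadcasting dot over the rows of a
# matrix against a single (Ref-wrapped) vector contracts to one
# matrix-vector product. The broadcast machinery itself is untouched; we
# just intercept this particular pattern before it gets built lazily.
function Base.Broadcast.broadcasted(::typeof(dot), rows::Base.RowSlices,
                                    y::Base.RefValue{<:AbstractVector})
    return parent(rows) * y[]
end

dot.(eachrow(X), Ref(y))  # now dispatches straight to X * y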

I really like that you described this as functional composition. I was not explicitly thinking of it that way (even though I use ∘ in my post), but I think this point of view actually makes things much clearer. I would argue that Julia’s broadcasting already does functional composition. Consider the code below:

y = f.(g.(h.(x)))

As described in the original blog post, a naive implementation of broadcasting would first apply h to each element of x, then apply g to each element of h.(x), then f to each element of g.(h.(x)), allocating an intermediate array at each step. Operator fusion is an optimization that maps

y = f.(g.(h.(x))) ==> y = (f ∘ g ∘ h).(x)

Of course, broadcasting does not do this explicitly (i.e. there is no preprocessing step where the broadcast machinery composes all the functions in an expression). Instead, since broadcasting is lazy and nested, the machinery “lifts” the compose operation into the Broadcasted type. So, by making iteration over the data the outer-most loop, the composition happens at “materialize time.”
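You can see this lifting directly by calling the underlying machinery by hand:

x = rand(10)
f, g, h = sqrt, abs, sin

# The lazy, nested representation that the dot syntax lowers to; nothing
# has been computed yet.
bc = Broadcast.broadcasted(f, Broadcast.broadcasted(g, Broadcast.broadcasted(h, x)))
bc isa Broadcast.Broadcasted  # true

# A single loop runs at materialize time, computing f(g(h(x[i]))) for each
# element -- i.e. (f ∘ g ∘ h).(x) without ever building the composed function.
out = Broadcast.materialize(bc)
out ≈ sqrt.(abs.(sin.(x)))  # true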

The downside to data contraction is that it makes iteration over the data the inner-most loop. This breaks the “materialize time” operator fusion. But there are cases (e.g. dot.(Ref(y), eachrow(X)) .+ b) where the efficient code requires both data contraction and operator fusion. To facilitate this, we need to expose the functional composition in broadcasting as something that users can overload. Again, even though I use dot ∘ + in my example, at no point will I actually apply this composed function to anything. It is simply a way for the user to request operator fusion and data contraction together, since in the end what actually gets executed is the BLAS call for X * y + b.
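As a sketch of what such an overload could look like (contract_compose is a hypothetical name for the exposed composition hook; nothing like it exists in Base, and the dispatch that would route a broadcast expression to it is elided):

using LinearAlgebra

# Hypothetical fused rule: “dot over the rows of X, then .+ b” becomes a
# single 5-arg mul! call. mul!(C, A, x, α, β) computes C .= α*(A*x) .+ β*C,
# so with C = copy(b) and α = β = 1 this is exactly X*y .+ b -- one BLAS
# gemv call, with no intermediate X*y array.
function contract_compose(::typeof(+), ::typeof(dot),
                          X::AbstractMatrix, y::AbstractVector, b::AbstractVector)
    return mul!(copy(b), X, y, true, true)
end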

EDIT: Also to be clear, the two “modes” of broadcasting I described above would be interoperable. Normal broadcasting and operator fusion will still happen as expected if no one overloads anything. If someone overloads dot to apply data contraction, then there would be a kind of barrier where dot is routed to matmul, and normal broadcasting/operator fusion proceeds as usual around it. Additionally, someone could further overload the new explicit composition function so that dot followed by + is routed to BLAS’s A * x + b, and again normal broadcasting proceeds for the rest of the expression.
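Putting the hypothetical pieces above together, an expression with extra broadcasting around the contraction would split cleanly at that barrier:

b = rand(size(X, 1))

# The inner dot call eagerly contracts to the dense vector X * y; the
# surrounding .+ and exp stay lazy and fuse into a single loop over that
# contracted result, exactly as normal broadcasting would.
z = exp.(dot.(eachrow(X), Ref(y)) .+ b)
z ≈ exp.(X * y .+ b)  # true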
