Julia's Broadcast vs Jax's vmap

How do you tell vmap to take random initial conditions, repeat on observations that fail, batch by 50, and perform a low-memory expansion and reduction via Welford’s algorithm?

1 Like

If you write an explicit loop, then by construction the loops are already fused, and the compiler can do SIMD.

6 Likes

Again, the point here is that vmap is completely agnostic to the function that it’s batching…

That’s great! I wasn’t aware that Julia’s compiler went to those lengths. But it still seems as though there are significant performance bottlenecks elsewhere, as in @jessebett’s examples. The real value-prop for jax.vmap for me is that it frees me from worrying about all those bottlenecks and instead lets me write intuitive, straight-from-the-equation code. I would be stoked if such a thing could be possible in Julia though!

3 Likes

But can it do what I mentioned? You said it’s “completely general” so you don’t need library functions like the one I pointed to, but from what I can see, vmap of 10000 ODEs is going to need the memory for 10000 ODE solutions, while the library I linked won’t need more memory than that of 50 solutions at a time.

4 Likes

I think this line of questioning is getting a bit off topic, as the issues the OP brought up around vmap/Julia broadcasting don’t relate to memory pressure.

But in any case, what you’re looking for can be done in JAX. Something like this should do the trick:

import jax.numpy as jnp
from jax import vmap  # needed for the vmap call below

really_big_dataset = ...  # placeholder for your actual data
aggregation = 0
# jnp.split with an integer splits the leading axis into that many equal chunks
for chunk in jnp.split(really_big_dataset, 50):
  really_big_result = vmap(your_func)(chunk)
  aggregation += some_kind_of_aggregation(really_big_result)

And you can swap that Python for loop out for a jax.lax.fori_loop for some extra speed if you’d like! This pattern is actually used in some of the JAX example code for training loops, etc. Or you could use jax.pmap to run vmap-batched computations in parallel across multiple devices and then aggregate them together. The possibilities are endless!

2 Likes

I mean, you kind of just proved the exact opposite of your point. Either vmap is useful here because it’s a high level interface that does a bunch of common things to make Monte Carlo etc. ensemble simulations easy to do, or it’s just a for loop and you have to write everything yourself. But since it doesn’t do what the library does and instead is the latter, why not just use a language which has fast loops and auto-vectorization of said loops, i.e. Julia?

(that’s not to say there aren’t compelling examples for vmap, but this isn’t one)

3 Likes

On the contrary, the value of jax.vmap is that it does precisely one thing and does it very well. Having a grab-bag of every possible feature does not align with the philosophy of JAX as I understand it. Rather, the goal is to be able to build complexity out of composable, simple primitives. Fast loops are great, but the point here is that they don’t save you: Julia's Broadcast vs Jax's vmap - #9 by darsnack.

I’m worried this is devolving into some sort of Julia vs JAX flame war. My understanding of OP’s intent was to start a discussion around comparisons between jax.vmap and Julia’s broadcasting.

11 Likes

Yes, that’s a compelling example of vmap, but that doesn’t mean the other example is. The problem with vmap is that its extra utility over a for loop is limited to cases where you can exploit extra SIMD, like fusing a loop of BLAS2 kernels into a BLAS3 kernel. That is compelling, but in other cases it’s essentially just a loop, which isn’t compelling. But since the compelling cases have to do with really tight loops like in linear algebra, that seems like a missing abstraction in handling linear algebra, not something that is more generally useful.
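
To make that concrete, here is the kind of rewrite being described, sketched in Julia with arbitrary sizes: a loop of matrix-vector products (BLAS2) collapsed into a single matrix-matrix product (BLAS3).

using LinearAlgebra

A = randn(100, 100)
X = randn(100, 64)                      # 64 "batched" vectors stored as columns
y_loop = [A * x for x in eachcol(X)]    # 64 separate BLAS2 (gemv) calls
Y_fused = A * X                         # one BLAS3 (gemm) call
@assert all(Y_fused[:, j] ≈ y_loop[j] for j in 1:size(X, 2))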

If that’s the case, does XLA.jl not cover all of the cases where you’d get meaningful amounts of extra SIMD? I still haven’t seen an example that contradicts that, which tells me that we should be slapping xla(f) on a few more functions to get the same as vmap and call it a day.

That’s one source of low-level utility, and another is that it seems easier to make AD understand something explicit like this, instead of an arbitrary loop. Likewise translation to GPUs.

Besides those, people evidently find value in having a uniform, high-level way to specify such things. Even if it gives up some control. The same could be said of ordinary “scalar” broadcasting.

Thanks for the examples, @Samuel_Ainsworth.

5 Likes

If it’s a loop over small/scalar operations, then yes this matters a lot. If it’s a loop over matmuls, then no you’re not gaining anything by clumping this in the AD.

Once again, that works for the simple linear algebra case but not for more difficult cases like an ODE solver. Handing a whole function off to asynchronously while-loop over many things of varying length is a good way to get a lot of divergent warps: you need quite fancy handling to do this well. And that’s even without considering things like event-handling callbacks (which aren’t implemented in the Jax one to begin with). So yes, that works on linear algebra, but does it work beyond that?

Agreed, but is there an example where it’s useful beyond BLAS1 or BLAS2 linear algebra commands? That’s a valid question, because all of the examples keep being the same thing. If that’s the case, it’s a neat syntax, but it could be covered by just making linear algebra operations lazy.

1 Like

Not trying to get into some hot debate here, but I think people do need to realize that Jax is not Python, and “fully general” means “as long as everything one does is in NumPy”, which is in itself a package (ecosystem) limitation.

5 Likes

Now, I don’t understand much of the technical discussion in this thread, but I have a feeling this is not the best way we can introduce people to our community; I feel we are sometimes a bit over-protective…

13 Likes

Aside from the practical aspects, like Python having a MUCH larger user base, etc., I wonder, purely in a technical, idealised world: would implementing JAX in Julia result in a much more technically superior product? Just wondering out loud.

Yes and no. Well there’s kind of two questions there. If the question is whether implementing something just like Jax on Julia would be better, i.e. something that expands statements to a simplified IR without control flow on XLA, there are some advantages in terms of ecosystem but no technical advantages I can think of. It would allow us to easily merge BLAS2 → BLAS3 kernels and all of that, and get all of the same XLA optimizations, but it would also have the same limitations, like difficulties handling while loops and dynamic allocations, because of how it performs the expansion and generates the XLA IR. FWIW, you can try this today since this is essentially what FluxML/XLA.jl is doing:

(It even uses the Jax build of XLA). It’s a great way to accelerate standard run-of-the-mill machine learning tasks, and to use it you just do |> xla on a chain. I wouldn’t be surprised if you see XLA.jl (and thus a “Jax like style” under the hood) recommended more in the Flux documentation soon, since it covers the cases mentioned here.

(The ecosystem advantage is that this would work on existing Julia packages, whereas Jax doesn’t work directly on most Python packages without modification because it requires a very functional style and its own NumPy replacement, jax.numpy)

The second way to interpret your question though is whether something like Jax could be done “on Julia”, i.e. these kinds of transformation except on the Julia IR. This is somewhat what Zygote is doing, though there’s kind of a “refresh” for new compiler tools that target Julia’s SSA form. In this sense it would be a technically superior method because it could handle all of those edge cases by not having to expand any control flow. It’s more technically difficult but it would integrate with the language a lot better, and it’s something to watch out for. But until then, XLA.jl (and also just Zygote.jl) is a good solution.

I didn’t mean to be antagonistic, but come on, “completely general” is just bait for an example of something it can’t do, and I happened to be working on a real user issue (so something that shows up in practice) that vmap doesn’t handle. Let’s just stay grounded since there’s an interesting technical discussion to be had.

10 Likes

Chris and I had a bit of a “flame war” on this topic way back (only jesting; I thought it was a good discussion, and I’ve been lurking in the Julia community long enough to know there were no ill intentions). Actually, I was supposed to post my thoughts here, but I totally forgot, and it has been so long that I’m going to take some time to recollect them.

But as someone in the “pro-vmap” camp, I want to make clear that I don’t believe that vmap provides any functionality beyond for-loops. Maybe in other languages, but not in Julia AFAIK. On the other hand, Julia’s broadcasting mechanism is no doubt central to the expressiveness of the language. My angle in this argument is that broadcasting handles a certain kind of for-loop, and I would like to extend its semantics to include other kinds. This will not allow us to do something we couldn’t do before, but it might allow us to do it in a different, more expressive way.

16 Likes

As promised, here are my thoughts. Excuse any mistakes as I did this late last night after JuliaCon.

Consider a generic broadcasted expression like:

y = f.(g.(h.(x)))

where x is some collection. Broadcasting allows us to conveniently express the following loop:

y = prepare_dest((f, g, h), x) # some function that allocates y
for I in eachindex(x)
	y[I] = f(g(h(x[I])))
end
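
For concreteness, here is a runnable instance of that loop, with arbitrary placeholder functions standing in for f, g, and h, and similar standing in for prepare_dest:

f, g, h = sqrt, t -> t + 1, abs
x = randn(8)
y = similar(x)              # plays the role of prepare_dest((f, g, h), x)
for I in eachindex(x)
	y[I] = f(g(h(x[I])))
end
@assert y == f.(g.(h.(x)))  # identical to the fused broadcast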

But there is another type of looping that is implicit within this code. If I were to be more explicit about it, I might write

out_f = prepare_dest((f, g, h), x)
out_g = prepare_dest((g, h), x)
out_h = prepare_dest((h,), x)
y = prepare_dest((f, g, h), x)
for I in eachindex(x)
	val = x[I]
	for (dest, op) in zip((out_h, out_g, out_f), (h, g, f))
		dest[I] = op(val)
		val = dest[I]
	end
	y[I] = out_f[I]
end

Importantly, I allocated out_f, out_g, and out_h here even though I didn’t need to. I’m trying to belabor the point that operator fusion is an optimization that contracts the inner-most loop above.
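
To see that cost concretely (functions chosen arbitrarily), the un-contracted version materializes an intermediate array per operation, while the fused broadcast allocates only the output:

x = randn(1000)
out_h = abs.(x)                  # intermediate allocation
out_g = out_h .+ 1               # another
out_f = sqrt.(out_g)             # another
y_fused = sqrt.(abs.(x) .+ 1)    # single fused pass, one allocation
@assert y_fused == out_f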

What happens when we switch the order of the loops?

val = x
for op in (h, g, f)
	dest = prepare_dest((op,), x)
	for I in eachindex(val)
		dest[I] = op(val[I])
	end
	val = dest
end
y = val

What would contracting the inner-most loop mean now? First, you lose the memory benefits of operator fusion (dest is allocated for every op). But would the contraction be faster:

val = x
for op in (h, g, f)
	# some faster "denser" op works on all of val
	val = contract_op(op, axes(val))(val)
end
y = val

Can we do this with broadcasting?

At this point, it is more helpful to deal with a concrete example. Namely, the expression at the start of this post.

# b is a vector, A is a matrix
y = dot.(Ref(b), eachcol(A))

If we write out the explicit form, we get

# ignore my abusive notation
y = prepare_dest((dot,), Ref(b), eachcol(A))
for I in axes(Ref(b)), J in axes(eachcol(A))
	K = get_out_index(I, J)
	for op in (dot,)
		y[K] = op(Ref(b)[I], eachcol(A)[J])
	end
end

# now we swap the loops
val = (Ref(b), eachcol(A))
for op in (dot,)
	dest = prepare_dest((op,), val...)
	for I in axes(val[1]), J in axes(val[2])
		K = get_out_index(I, J)
		dest[K] = op(val[1][I], val[2][J])
	end
	val = dest
end
y = val

# and we contract
val = (Ref(b), eachcol(A))
for op in (dot,)
	val = contract_op(op, axes(val))(val)
	# this contract_op(op, axes((Ref(b), eachcol(A))))
	# should return (vec, mat) -> transpose(mat) * vec
end
y = val

How would we facilitate the steps above with broadcasting? With something like this PR, we could overload

broadcasted(::typeof(dot), vec::AbstractVector, mat::EachCol{<:AbstractMatrix}) = transpose(parent(mat)) * vec

To be clear, this means that mapping broadcast operations to more efficient implementations requires the author of the function to define what to do. But if you look at Jax’s lowered form for dot, it maps calls to dot to dot_general which consumes axis information to call the most efficient operation. We’re basically doing the same here.
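
Here is a rough, self-contained sketch of that idea. The EachCol wrapper is a hypothetical stand-in for what the PR above would provide, and note that in the actual lowered broadcast call the vector arrives wrapped in a Ref, hence the RefValue in the signature:

using LinearAlgebra

struct EachCol{T<:AbstractMatrix}   # hypothetical "slice type" marker
	parent::T
end
Base.parent(x::EachCol) = x.parent

# Route dot.(Ref(b), EachCol(A)) to a single dense matvec, transpose(A) * b,
# before the generic broadcast machinery runs:
Base.Broadcast.broadcasted(::typeof(dot), b::Base.RefValue{<:AbstractVector}, cols::EachCol) =
	transpose(parent(cols)) * b[]

A = randn(5, 3); b = randn(5)
@assert dot.(Ref(b), EachCol(A)) ≈ [dot(b, c) for c in eachcol(A)]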

Can we contract both loops?

The nice part about the approach above is that it requires no changes to the broadcast machinery, but that also imposes some limitations. Specifically, we can’t fuse multiple operators into a more performant call. To talk about this, I want to introduce two more concepts: the width and depth of a broadcasted optimization. When we think about a generic expression like

y = f.(g.(h.(x)))

the width refers to how many operations the optimization works on. So, the standard operator fusion is a full-width optimization, because it contracts over every function in the expression. The depth refers to optimizations applied within the outer-most (surface) functions. For example, if h contains a call to dot, and I was able to apply optimizations to that inner call, then I have some depth to my optimization.

The approach above is single width, because it doesn’t allow something like

out = dot.(Ref(y), eachcol(X)) .+ b

to map to the BLAS operation for expressions like A*x + b. To facilitate this, we can modify the Broadcasted type so that the axes field (or maybe an additional field) contains the EachCol-style information above. Then when we materialize the nested Broadcasteds, we provide an overloadable contract(nested_broadcast) that allows someone to specify that the nested broadcast should be replaced by a single broadcast where Broadcasted.f == (dot ∘ +). Then again, we do

broadcasted(::typeof(dot ∘ +), x::AbstractVector, A::EachCol{<:AbstractMatrix}, b::AbstractVector) = transpose(parent(A)) * x + b

This approach allows for full-width optimizations that also contract the data indexing, but only single depth.
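
As a sanity check on the target of that full-width contraction (sizes arbitrary), the fused matvec-plus-add agrees with the naive broadcast:

using LinearAlgebra

X = randn(6, 4); y = randn(6); b = randn(4)
@assert dot.(Ref(y), eachcol(X)) .+ b ≈ transpose(X) * y + b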

Why does Jax use the compiler?

If we can pass all the necessary info around, then why would Jax bother with IR passes? In order to get depth > 1, we need to inspect the body of functions in the outer broadcast expression which requires some kind of compiler pass. We’ve tested this behavior in a Zulip discussion, and Jax does seem to descend really far into expressions. There are a couple issues/PRs about custom compiler passes floating around, so I won’t really get into this.

Really, broadcasting?

The changes above are complex and will probably break stuff. Off the top of my head, they clearly break operator fusion, so a real implementation would require some kind of flag in Broadcasted to control the contraction behavior. Or maybe BroadcastStyle could be used for this. There are also other ways of implementing what I’m talking about, like this recent issue on KernelAbstractions.jl. That approach would do everything above so long as the kernel function passed to vmap! behaves the same as fused BLAS calls. So the user would need to provide the correct kernel function to get the performance, whereas the broadcasting approach gives that control primarily to the developer (making it appear “automatic” to the user).

More than BLAS?

This is a point that Chris has brought up often, and I’ve thought about it a lot and believe it is a really strong point. Beyond fused BLAS operations, I can’t think of really compelling examples to motivate these changes. All I can say is that broadcasting appears to be an extremely succinct way to describe patterned loop blocks. Wherever there are loops, there is fertile ground for optimization. I could even see a more future-looking version of this allowing broadcasting over collections of complex data structures to use pmap and other parallel optimizations. The obvious downside is that the more expressiveness you add to broadcasting, the worse it might become at doing each thing. Though on this front I’d echo Samuel’s words on vmap: broadcasting to me does one thing only, express loops. Ultimately, I’m not arguing that it should do anything else, only that it can express all loops and that developers/users can overload its default behavior.

7 Likes

For me, I see the future of broadcasting as being a super optimized iterator. On the syntax side, I don’t see too much more necessary besides terse lambdas.

Very likely not, as this may be type unstable for generic ops. YMMV.

Sorry if I am missing something, but I think that function composition and broadcasting are entirely different operations, and trying to unify them may not offer practical benefits. Composition can be understood as, e.g., foldr (with emphasis on the order of associativity), while broadcasting can do things in any order under ideal conditions.

4 Likes

In the code example that you quoted, I am not trying to do any functional composition. Excuse my terrible overloading of the “contract” term, but contract_op(f, axes_info) is a function that maps f.(x) => F(x), where f is some element-level operation and F operates on all of x at once. The idea is that some functions, when broadcasted over their arguments with a specific iteration behavior (axes_info), can be mapped to equivalent operations that are faster and more dense. By “dense” I mean that F operates on a larger chunk of x (probably all of x at once) than f does. The canonical example of that in this thread is that dot broadcasted over a vector and the slices of a matrix can be done more efficiently by matrix multiplication.
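
A quick numerical check of that canonical example (sizes arbitrary): the element-level broadcast and its contracted, dense equivalent agree:

using LinearAlgebra

A = randn(5, 3); b = randn(5)
@assert dot.(Ref(b), eachcol(A)) ≈ transpose(A) * b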

The initial method that I proposed would only require adding some kind of “slice type” to indicate when a broadcast is being applied to eachslice, eachrow, etc. It does not do any functional composition and it doesn’t change the broadcasting machinery, but it allows for the mapping I described above. Let’s call this optimization data contraction, because instead of iterating the elements of x, it contracts or fuses some of the axes to apply F.

I really like that you described this as functional composition. I was not explicitly thinking of it that way (even though I use ∘ in my post), but I think this point of view actually makes things much clearer. I would argue that Julia’s broadcasting already does functional composition. Consider the code below:

y = f.(g.(h.(x)))

As described in the original blog post, a naive implementation of broadcasting would first apply h to each element of x, then apply g to each element of h.(x), then f to each element of g.(h.(x)). Operator fusion is an optimization that maps

y = f.(g.(h.(x))) ==> y = (f ∘ g ∘ h).(x)

Of course, broadcasting does not do this explicitly (i.e. there is not a preprocessing step where the broadcast machinery composes all the functions in an expression). Instead, since broadcasting is lazy and nested, the machinery “lifts” the compose operation into the Broadcasted type. So, by making iteration over the data the outer-most loop, the composition happens at “materialize time.”
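
A tiny demonstration of that equivalence, with arbitrary concrete functions:

f, g, h = sqrt, t -> t + 1, abs
x = rand(16)
@assert f.(g.(h.(x))) == (f ∘ g ∘ h).(x)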

The downside to data contraction is that it makes iterating the data the inner-most loop. This breaks the “materialize time” operator fusion. But there are cases (e.g. dot.(Ref(y), eachcol(X)) .+ b) where the efficient code requires both data contraction and operator fusion. To facilitate this, we need to make the functional composition in broadcasting explicit, as something that can be overloaded by users. Again, even though I use dot ∘ + in my example, at no point will I actually apply this composed function to anything. This was my way of allowing the user to explicitly apply operator fusion and data contraction together, since in the end what actually gets executed is the BLAS call for transpose(X) * y + b.

EDIT: Also to be clear, the two “modes” of broadcasting I described above would actually be interoperable. So normal broadcasting and operator fusion will still happen as expected if no one overloads anything. If someone overloads dot to apply data contraction, then there would be this kind of barrier where dot is routed to matmul, and then normal broadcasting/operator fusion proceeds as usual. Additionally, someone could further overload the new explicit composition function so that dot followed by + is routed to the BLAS call for transpose(A) * x + b, and then again normal broadcasting proceeds for the other expressions.

2 Likes