Thoughts on Jax vs CuArrays and Zygote

Google’s JAX looks like a combination of CuArrays and Zygote for Python. Can anyone compare and contrast Jax with CuArrays and Zygote?

I can’t tell how Jax differs, apart from the fact that it compiles to XLA.


I’m also interested in a comparison. My understanding is that Jax supports AD on a subset of Python+NumPy. I suspect maintaining and expanding that support will take a lot of effort for Jax, but I’m not sure whether Julia and its various AD packages (Zygote in particular) are competitive yet, e.g. on par with the functionality in the Jax cookbook:


Just saw this and I’m curious as well (I don’t know much about Jax). @MikeInnes could you share your thoughts?


Jax is really nice. At this point TensorFlow and PyTorch have converged into hybrids of each other, each with several different ways of doing the same thing (“eager mode”, “static graphs”). Jax refreshes and cleans up this design with well-thought-out semantics and interfaces, and it has a lot of nice ideas around composable code transforms that are very much on our wavelength (cf. Cassette).
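
To make “composable code transforms” concrete, here is a toy sketch in plain Python: a forward-mode `grad` built on dual numbers, which can be applied to its own output to get a second derivative. The names `Dual` and `grad` are hypothetical; this is not how JAX or Zygote implement AD, and it ignores perturbation confusion, so nesting is only reliable for simple functions that don’t close over an outer derivative’s input.

```python
class Dual:
    """A number primal + tangent*eps with eps**2 == 0; tangent carries a derivative."""

    def __init__(self, primal, tangent):
        self.primal = primal
        self.tangent = tangent

    def __add__(self, other):
        if isinstance(other, Dual):
            return Dual(self.primal + other.primal, self.tangent + other.tangent)
        return Dual(self.primal + other, self.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        if isinstance(other, Dual):
            # Product rule: (a + b eps)(c + d eps) = ac + (ad + bc) eps
            return Dual(self.primal * other.primal,
                        self.primal * other.tangent + self.tangent * other.primal)
        return Dual(self.primal * other, self.tangent * other)

    __rmul__ = __mul__


def grad(f):
    """Transform f into a function computing df/dx by seeding a unit tangent."""
    def df(x):
        out = f(Dual(x, 1))
        return out.tangent if isinstance(out, Dual) else 0
    return df


def f(x):
    return x * x * x + 2 * x   # f(x) = x^3 + 2x


df = grad(f)          # 3x^2 + 2
d2f = grad(grad(f))   # 6x, obtained by composing the transform with itself
print(df(2.0), d2f(2.0))  # prints 14.0 12.0
```

The point is that `grad` takes a function and returns a function, so transforms stack freely; JAX applies the same idea to `jit`, `vmap` and `grad`.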

Python still has the fundamental limitations we discussed a while back, though. The major change since then is that explicit graph building has been replaced with eager semantics plus tracing; this is much more intuitive but technically largely equivalent. So in TF 2.0, JAX, or PyTorch JIT, if you want control flow and performance, you still need to replace your loops with framework-compatible versions. That approach has limitations around recursion, mutation, custom data structures, etc., and in particular it can’t differentiate through any other Python library code.
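
A toy illustration of why tracing is largely equivalent to graph building: the sketch below records operations on a placeholder value while an ordinary Python function runs. The `Tracer` class and `trace` helper are hypothetical, not JAX’s API. Arithmetic gets recorded, a loop with a fixed Python trip count is simply unrolled into the trace, and a Python `if` on a traced value fails because there is no concrete value to branch on. In JAX itself, the usual workaround is structured primitives such as `jax.lax.cond` and `jax.lax.fori_loop`.

```python
class Tracer:
    """Placeholder for a value whose concrete number is unknown at trace time."""

    def __init__(self, graph, name):
        self.graph = graph   # shared list of recorded operations
        self.name = name     # name of the node this tracer stands for

    def _record(self, op, other):
        other_name = other.name if isinstance(other, Tracer) else repr(other)
        out = Tracer(self.graph, f"v{len(self.graph)}")
        self.graph.append(f"{out.name} = {op}({self.name}, {other_name})")
        return out

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)

    def __gt__(self, other):
        return self._record("gt", other)   # the comparison itself can be recorded...

    def __bool__(self):
        # ...but Python's `if`/`while` need a concrete bool, which a tracer lacks.
        raise TypeError("cannot branch on a traced value")


def trace(f):
    """Run f on a placeholder input and return the list of recorded operations."""
    graph = []
    f(Tracer(graph, "x"))
    return graph


def unrolled(x):
    for _ in range(3):      # trip count is a plain Python int, so the loop unrolls
        x = x * 2
    return x + 1


def branchy(x):
    if x > 0:               # depends on the runtime value of x
        return x
    return x * -1


for line in trace(unrolled):
    print(line)             # v0 = mul(x, 2) ... v3 = add(v2, 1)

try:
    trace(branchy)
except TypeError as err:
    print("trace failed:", err)
```

Recursion and mutation run into the same wall: whatever Python does at trace time is baked in, and only the recorded operations survive.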

The paper on TF 2.0, which shares many ideas with Jax, discusses this a bit as well:

In TensorFlow Eager, users must manually stage computations, which might require refactoring code. An ideal framework for differentiable programming would automatically stage computations, without programmer intervention. One way to accomplish this is to embed the framework in a compiled procedural language and implement graph extraction and automatic differentiation as compiler rewrites; this is what, e.g., DLVM, Swift for TensorFlow, and Zygote do. Python’s flexibility makes it difficult for DSLs embedded in it to use such an approach.


I think JAX is a bit more than “just” CUDA + autodiff: the XLA compiler also produces highly optimized CPU code (5x faster than NumPy on a real use case I had), with the added bonus that the exact same code also runs on a GPU when one is available.

Maybe Julia could compile to XLA to gain speed on all the hardware XLA supports? I’ve seen XLA.jl, but it seemed to be focused on TPUs.

Logical conclusion: give up on Python. Coalesce around Julia and build more libraries in Julia for purposes beyond data science. That day will come.


I think the general direction we’re heading is to use MLIR (an LLVM project) as Julia’s backend compiler and intermediate representation. MLIR has “dialects”, like Affine, that allow representing operations on tensors natively within the IR, as well as optimization passes that can operate on this “tensor IR”. So it should be able to get us to a similar level of performance as XLA, while still providing full access to LLVM via the LLVM dialect.


Are you saying the plan is to replace Julia’s IR with an MLIR dialect, or that it will be one of the IRs in the stack?

This is amazing to hear. Are you referring to the Tensor Compute Primitives Proposal as the “tensor IR” in question?

Also, have any of the JuliaGPU contributors looked into MLIR-based tensor/linalg runtimes such as IREE? I know XLATools.jl exists, but being able to wrap a native library instead of jaxlib seems like a plus :slight_smile:

It would become another IR in the stack, most likely. However, to make full use of it, we’ll need some way to “lower” Julia’s array operations into the statements that match the semantics we need, so Julia’s own IR(s) could potentially be influenced by this work to make said lowering easier to implement.


You’re probably right. I’m not the MLIR expert or developer by any means, I’m just communicating what I know about some of the work that’s being done, so everything I say about this should be taken with a grain of salt :slightly_smiling_face:

I can’t speak for the JuliaGPU contributors, but I haven’t heard of anyone targeting IREE. I suspect that’s because we don’t yet have MLIR support in Julia, which is probably a blocker for targeting IREE. In a few months it may be worth putting that option on the table.