Thoughts on JAX vs CuArrays and Zygote

#1

Google’s JAX looks like a combination of CuArrays and Zygote for Python. Can anyone compare and contrast JAX with CuArrays and Zygote?

I can’t tell how JAX is different, apart from the fact that it compiles to XLA.
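
For concreteness, here’s roughly the workflow in question (a toy sketch of my own, not from the JAX docs; `grad` and `jit` are real JAX transforms, the loss function is made up):

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # ordinary numpy-style code; jnp mirrors the NumPy API
    return jnp.sum(jnp.dot(x, w) ** 2)

grad_loss = jax.grad(loss)      # reverse-mode AD, the Zygote-like part
fast_grad = jax.jit(grad_loss)  # traced and compiled to XLA

w, x = jnp.ones(3), jnp.ones((4, 3))
print(fast_grad(w, x))          # gradient w.r.t. the first argument, w
```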


#2

I’m also interested in a comparison. My understanding is that JAX supports AD on a subset of Python + NumPy. I suspect it takes a lot of effort for the JAX team to maintain and expand that support, but I’m also not sure whether Julia and its various AD packages (Zygote in particular) are competitive yet, e.g. on par with the functionality in the JAX autodiff cookbook: https://colab.research.google.com/github/google/jax/blob/master/notebooks/autodiff_cookbook.ipynb
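
For reference, a small sketch of the kind of thing the cookbook covers (real JAX transforms, toy functions of my own):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[1])])

x = jnp.array([1.0, 2.0])
J_fwd = jax.jacfwd(f)(x)           # forward-mode Jacobian
J_rev = jax.jacrev(f)(x)           # reverse-mode Jacobian

def g(x):
    return jnp.sum(x ** 3)

H = jax.jacfwd(jax.jacrev(g))(x)   # Hessian by composing the two modes
```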


#3

Just saw this and I’m curious as well (I don’t know much about JAX). @MikeInnes, could you share your thoughts?


#4

JAX is really nice; at this point TensorFlow and PyTorch have converged into hybrids of each other, with a bunch of different ways of doing the same thing (“eager mode”, “static graphs”). JAX really refreshes and cleans up this design, with well-thought-out semantics and interfaces, along with a lot of nice ideas around composable code transforms that are very much on our wavelength (cf. Cassette).
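
To make the composability point concrete, a toy sketch (real JAX transforms): each transform takes a function and returns a function, so they nest freely.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.tanh(x))

# per-example gradients, batched with vmap, then compiled with jit
g = jax.jit(jax.vmap(jax.grad(f)))
print(g(jnp.ones((8, 4))))  # one length-4 gradient per batch row
```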

Python still has the fundamental limitations we discussed a while back, though. The major development since then is that explicit graph building has been replaced with eager semantics plus tracing; this is much more intuitive but technically largely equivalent. So in TF 2.0, JAX, or the PyTorch JIT, if you want control flow and performance, you still need to replace your loops with framework-compatible versions. That approach has limitations around recursion, mutation, custom data structures and so on, and in particular it can’t differentiate through code in other Python libraries.
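
For example (a sketch with toy functions; `lax.fori_loop` is the real JAX construct): under `jit`, a plain Python loop is unrolled at trace time, and a trip count that depends on a traced value won’t trace at all, so you rewrite it with the framework’s loop:

```python
import jax
from jax import lax

@jax.jit
def pow_unrolled(x):
    for _ in range(10):  # static trip count: unrolled into the trace
        x = x * x
    return x

@jax.jit
def pow_staged(x, n):
    # n may be a traced value here; the loop is staged as a single XLA op
    return lax.fori_loop(0, n, lambda i, acc: acc * acc, x)
```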

The paper on TF 2.0, which shares many ideas with JAX, discusses this a bit as well:

In TensorFlow Eager, users must manually stage computations, which might require refactoring code. An ideal framework for differentiable programming would automatically stage computations, without programmer intervention. One way to accomplish this is to embed the framework in a compiled procedural language and implement graph extraction and automatic differentiation as compiler rewrites; this is what, e.g., DLVM, Swift for TensorFlow, and Zygote do. Python’s flexibility makes it difficult for DSLs embedded in it to use such an approach.
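
You can watch the staging happen with `jax.make_jaxpr` (a real JAX utility; the function is a toy of mine): tracing runs the Python code once and records only the array operations, so Python-level control flow never makes it into the staged program.

```python
import jax
import jax.numpy as jnp

def f(x):
    if x.ndim == 1:  # resolved during tracing, via static shape info
        x = x[None, :]
    return jnp.tanh(x) @ x.T

# prints the extracted graph; the Python `if` is already gone
print(jax.make_jaxpr(f)(jnp.ones(3)))
```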
