Thoughts on Jax vs CuArrays and Zygote

Jax is really nice; at this point TensorFlow and PyTorch have converged to being hybrids of each other, with a bunch of different ways of doing the same thing (“eager mode”, “static graphs”). Jax really refreshes and cleans up this design, with well-thought-out semantics and interfaces, along with a lot of nice ideas around composable code transforms that are very much on our wavelength (cf. Cassette).
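To make “composable code transforms” concrete, here’s a minimal sketch (my own illustration, not from Jax’s docs): `grad`, `vmap`, and `jit` are ordinary higher-order functions, so they stack in whatever order you like.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # A simple scalar function of the parameters.
    return jnp.sum(jnp.tanh(x @ w) ** 2)

# Transforms compose as plain function application:
# grad differentiates, vmap vectorizes over a batch, jit compiles.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.ones((3, 2))
xs = jnp.ones((5, 3))                  # a batch of five inputs
print(per_example_grads(w, xs).shape)  # (5, 3, 2): one gradient per example
```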

Python still has the fundamental limitations we discussed a while back, though. The major development since then is that explicit graph building has been replaced with eager semantics + tracing; this is much more intuitive but technically largely equivalent. So in TF 2.0, JAX, or the PyTorch JIT, if you want control flow and performance, you still need to replace your loops and branches with framework-compatible versions (as in the sketch below). This has limitations around recursion, mutation, custom data structures, etc., and in particular it can’t differentiate through code from any other Python library.
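Here’s what that rewrite looks like in JAX, as one small example: data-dependent Python control flow works eagerly, but under `jit` the input is an abstract tracer, so the branch has to be expressed through `jax.lax.cond` instead.

```python
import jax

def relu_pythonic(x):
    # Data-dependent Python control flow: fine eagerly, but under
    # jit the traced `x` has no concrete value, so `x > 0` raises
    # a tracer error and the branch cannot be staged.
    if x > 0:
        return x
    return 0.0

@jax.jit
def relu_staged(x):
    # The framework-compatible rewrite: lax.cond stages both
    # branches as a single conditional node in the traced program.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: 0.0, x)

print(relu_staged(3.0))   # 3.0
print(relu_staged(-3.0))  # 0.0
```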

The paper on TF 2.0, which shares many ideas with Jax, discusses this a bit as well:

In TensorFlow Eager, users must manually stage computations, which might require refactoring code. An ideal framework for differentiable programming would automatically stage computations, without programmer intervention. One way to accomplish this is to embed the framework in a compiled procedural language and implement graph extraction and automatic differentiation as compiler rewrites; this is what, e.g., DLVM, Swift for TensorFlow, and Zygote do. Python’s flexibility makes it difficult for DSLs embedded in it to use such an approach.
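The “Python’s flexibility makes it difficult” point is the same boundary mentioned above: only code written against the framework’s own API is visible to the tracer. A quick sketch of where that boundary sits in JAX:

```python
import numpy as np  # plain NumPy, not jax.numpy
import jax
import jax.numpy as jnp

def f(x):
    # np.sin is opaque to JAX's tracer: it demands a concrete
    # ndarray, so differentiating this raises a tracer error.
    return np.sin(x) * x

# jax.grad(f)(1.0)  # fails: JAX can't see inside other libraries

def g(x):
    # Only operations from JAX's own API can be traced and
    # differentiated.
    return jnp.sin(x) * x

print(jax.grad(g)(1.0))  # cos(1.0) * 1.0 + sin(1.0)
```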
