Thoughts on JAX vs Julia

Hey Chris, where are you getting that @pytime macro? That seems like a huge nice-to-have and I can’t find it published anywhere.


Apologies for the delay in replying. My frustration with Julia in this regard essentially boils down to imperative programming being slower than declarative programming, because computers are imperative; i.e. it's not a frustration with Julia, but with computers themselves. JAX gets around this by being very limited in scope: because it can make stricter assumptions about what the code will do, it can make better optimisations from the same code.

While JAX isn't fully declarative (whatever that means), I think vmap is a fantastic abstraction, and I wish something with similar syntax existed in Julia. JAX's philosophy of function transformations plus immutability works well with a lot of my code, but in the cases where it doesn't (anything that doesn't vmap nicely), I wish I could instead write Julia, interleaved with blocks that respect JAX/XLA's limitations so that those blocks could use its compiler.
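As a rough illustration of the "function transformations plus immutability" style, here is a minimal sketch (not from the original post) that composes jax.grad, jax.vmap and jax.jit to get per-example gradients. The loss function and data are hypothetical; the point is only that a pure function with no mutation lets the transformations stack freely.

```python
import jax
import jax.numpy as jnp

# Hypothetical squared-error loss for a linear model on one example (x, y).
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad differentiates with respect to w (argument 0); vmap maps that gradient
# over a batch of examples while sharing w (in_axes=(None, 0, 0));
# jit compiles the composed function with XLA.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -2.0, 0.5])
xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 examples, 3 features each
ys = jnp.array([0.0, 1.0, 2.0, 3.0])

g = per_example_grads(w, xs, ys)
print(g.shape)  # (4, 3): one gradient vector per example
```

Because loss never mutates its inputs, the order of wrapping is the only design decision: grad first (per example), then vmap (over the batch), then jit (compile the whole thing).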

On vmap and broadcasting: I much prefer the vmap(f, (0, None))(x, y) syntax over f.(x, Ref(y)), as it clearly separates the verb from the object. Everyone loves abstraction.
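For readers less familiar with JAX, here is a small sketch of what vmap(f, (0, None))(x, y) does: the first argument is mapped over its leading axis and the second is passed through unchanged, roughly what wrapping y in Ref achieves in the Julia broadcast. The function f and the arrays are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical function of one vector x and one matrix y.
def f(x, y):
    return y @ x  # matrix-vector product

xs = jnp.arange(6.0).reshape(3, 2)        # batch of three 2-vectors, mapped over axis 0
y = jnp.array([[1.0, 2.0], [3.0, 4.0]])   # single matrix, shared across the batch

# (0, None): map over axis 0 of the first argument, broadcast the second as-is.
batched = jax.vmap(f, (0, None))(xs, y)

# Equivalent explicit loop, for comparison.
looped = jnp.stack([f(x, y) for x in xs])
```

The mapping specification lives with the transformation (the "verb"), while the call site only lists the arguments (the "objects").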

The other alternatives in JAX (lax.scan, lax.fori_loop, lax.while_loop) are essentially just there to make compiling for loops faster: each is a progressively less restrictive primitive that lets you do progressively more of the things JAX wasn't designed for. lax.fori_loop is implemented via lax.scan (when the trip count is static) or lax.while_loop (otherwise), and if even lax.fori_loop won't work, you have to fall back to a regular Python for loop and suffer the compilation time of having it completely unrolled, with each iteration compiled separately. See this implementation for an example where none of the loop primitives were possible, because JAX can't vmap over an array of functions (also, I need to remove that TODO).
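A minimal sketch of the trade-off above, assuming a hypothetical loop body step and trip count N_STEPS: both jitted functions compute the same result, but the lax.fori_loop version traces the body once, whereas the plain Python loop is unrolled into N_STEPS copies during tracing.

```python
import jax
import jax.numpy as jnp
from jax import lax

N_STEPS = 100  # hypothetical trip count

def step(i, x):
    # One iteration of a toy (contractive) update; x is the loop-carried state.
    return 0.9 * x + jnp.sin(0.1 * i)

@jax.jit
def with_fori_loop(x0):
    # The body is traced once. Both bounds are Python ints (static), so
    # fori_loop is implemented via lax.scan; otherwise it would use lax.while_loop.
    return lax.fori_loop(0, N_STEPS, step, x0)

@jax.jit
def with_python_loop(x0):
    # A plain Python loop is unrolled while tracing: the compiled program
    # contains N_STEPS copies of step, so compile time grows with N_STEPS.
    x = x0
    for i in range(N_STEPS):
        x = step(i, x)
    return x

x0 = jnp.array([0.0, 1.0, 2.0])
a = with_fori_loop(x0)
b = with_python_loop(x0)
```

The outputs should agree up to floating-point error; the difference is in how long the first call takes to compile as N_STEPS grows (not measured here).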
