Hey Chris, where are you getting that `@pytime` macro? That seems like a huge nice-to-have, and I can’t find it published anywhere.
Apologies for the delay in replying. My frustration with Julia in this regard essentially boils down to how imperative programming is slower than declarative programming because computers are imperative, i.e. it’s not a frustration with Julia but with computers themselves. JAX gets around this by being very limited in scope: because it can make stricter assumptions about what code will do, it can apply better optimisations to the same code.
While JAX isn’t fully declarative (whatever that means), I think `vmap` is a fantastic abstraction, and I wish something with similar syntax existed in Julia. JAX’s philosophy of function transformations plus immutability works well with a lot of my code, but where it doesn’t (anything that doesn’t `vmap` nicely), I wish I could instead write Julia interleaved with blocks that accept JAX/XLA’s limitations in exchange for access to its compiler.
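As a rough illustration of code that doesn’t `vmap` nicely (a made-up example, not from my actual code; the function names are hypothetical): a Python `if` on a traced value fails under `jax.vmap`, while the same logic written with `jnp.where` batches fine.

```python
import jax
import jax.numpy as jnp

def relu_branchy(x):
    # Python `if` needs a concrete bool, but under vmap `x` is an abstract tracer.
    if x > 0:
        return x
    return 0.0 * x

def relu_where(x):
    # Same logic expressed as an array op, which vmap knows how to batch.
    return jnp.where(x > 0, x, 0.0)

xs = jnp.array([-1.0, 2.0, -3.0, 4.0])

batched = jax.vmap(relu_where)(xs)
print(batched)  # [0. 2. 0. 4.]

branchy_failed = False
try:
    jax.vmap(relu_branchy)(xs)
except Exception as err:  # JAX raises a tracer/concretization error here
    branchy_failed = True
    print(type(err).__name__)
```

Rewriting branches as array ops works for toy cases like this, but it gets painful quickly for real control flow, which is where I’d rather drop back to plain Julia.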
On `vmap` and broadcasting: I much prefer the `vmap(f, (0, None))(x, y)` syntax over `f.(x, Ref(y))`, as it clearly separates the verb from the object. Everyone loves abstraction.
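A minimal sketch of that pattern, assuming a toy `f` that takes one row of `xs` and a shared vector `y`: the `0` in `(0, None)` maps over the leading axis of the first argument, and `None` passes the second argument through unbatched, which is roughly the role `Ref(y)` plays in Julia broadcasting.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # Toy function: dot product of one row with a shared vector.
    return jnp.dot(x, y)

xs = jnp.arange(6.0).reshape(3, 2)  # 3 rows to map over
y = jnp.array([1.0, 10.0])          # shared across the batch, not mapped

# in_axes=(0, None): map over axis 0 of xs, broadcast y as-is.
out = jax.vmap(f, (0, None))(xs, y)

# The same computation as an explicit Python loop, for comparison.
looped = jnp.stack([f(row, y) for row in xs])
print(out)  # [10. 32. 54.]
```

The mapping spec lives on the transformation rather than being sprinkled through the call’s arguments, which is the verb/object separation I mean.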
The other alternatives in JAX are essentially just there to make compilation of `for` loops faster, by providing progressively less restrictive primitives that let you do progressively more of what JAX wasn’t designed for. `lax.fori_loop` is implemented via `lax.scan` (when the trip count is static) or `lax.while_loop` (otherwise), and if `lax.fori_loop` won’t work, you have to use a regular `for` loop and suffer the compilation time of having it completely unrolled, with each iteration compiled separately. See this implementation for an example where none of the loop primitives were possible, because JAX can’t `vmap` over an array of functions (also, I need to remove that TODO).
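A small sketch of that trade-off, with a hypothetical loop body `step`: with static (Python `int`) bounds, `lax.fori_loop` traces the body once, whereas a plain Python `for` loop is unrolled into the traced program. Counting top-level equations with `jax.make_jaxpr` makes the unrolling visible; exact counts will vary with the JAX version, so I haven’t written them down.

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(i, x):
    # Hypothetical loop body: one iteration of a simple recurrence.
    return 0.5 * x + i

def fori_version(x0):
    # Static bounds, so fori_loop lowers to a single scan; body traced once.
    return lax.fori_loop(0, 100, step, x0)

def unrolled_version(x0):
    # Plain Python loop: every iteration is traced, so the program holds
    # 100 copies of the body and compile time grows with the trip count.
    x = x0
    for i in range(100):
        x = step(i, x)
    return x

x0 = jnp.float32(1.0)
a = jax.jit(fori_version)(x0)
b = jax.jit(unrolled_version)(x0)

# Count top-level operations in each traced program.
n_fori = len(jax.make_jaxpr(fori_version)(x0).jaxpr.eqns)
n_unrolled = len(jax.make_jaxpr(unrolled_version)(x0).jaxpr.eqns)
print(a, b, n_fori, n_unrolled)
```

Both give the same answer; the difference is purely in how much the compiler has to chew through.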