There are a few related things here to pick apart.
Firstly, `jax.jit` seems to have a much lower fixed compilation overhead than most Julia ADs. However, I've heard enough reports of people favouring Julia ADs despite the long TTFG (time to first gradient) because, once compiled, they were faster than JAX for their large models. Either way, TTFG is a big problem and we need better solutions for it.
Secondly, there's a technical problem mixed with a philosophical one here. If you asked someone to write the original loss function in pure Julia, they probably would've written something like what @yolhan_mannes did: minimal allocation and very loopy. Not surprisingly, this does well with certain ADs. If you asked someone to write the same function like they were a Python programmer, they probably would've written something like @mcabbott's best-performing examples: fully vectorized operations, minimal looping, and exploiting vectorized operator fusion where possible. That does well with other ADs. What does not work well is doing something in between: for example, looping over slices of an input array and applying vectorized (i.e. expensive and allocating) operations to each slice.
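To make the contrast concrete, here's a toy sketch of the same computation written all three ways. This is not the actual loss from this thread; the function and names are made up for illustration, with `W` an m×n weight matrix and `X`, `y` holding one sample per column.

```julia
# 1. "Julian" scalar style: explicit loops, no allocations.
#    This shape tends to suit scalar-capable ADs (e.g. Enzyme).
function loss_loopy(W, X, y)
    s = zero(eltype(W))
    for j in axes(X, 2), i in axes(W, 1)
        acc = zero(eltype(W))
        for k in axes(W, 2)
            acc += W[i, k] * X[k, j]
        end
        s += (acc - y[i, j])^2
    end
    return s
end

# 2. "Pythonic" vectorized style: one fused array expression.
#    This shape tends to suit array-level ADs (e.g. Zygote).
loss_vectorized(W, X, y) = sum(abs2, W * X .- y)

# 3. The in-between anti-pattern: a Julia loop over slices, with an
#    allocating vectorized operation on every iteration.
function loss_inbetween(W, X, y)
    s = zero(eltype(W))
    for j in axes(X, 2)
        s += sum(abs2, W * X[:, j] .- y[:, j])  # several temporaries per column
    end
    return s
end
```

All three compute the same value; the third is the shape that tends to hit the worst case for Julia ADs.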
So why do we see this kind of code again and again in the real world? One factor is that most people familiar with AD in both Python and Julia are more comfortable with the former. This means code examples are more likely to be written in the "in between" style, which is a worst case for Julia ADs. This IMO is an education and documentation problem: we need to direct people towards either writing more Pythonic vectorized code or more Julian scalarized code, depending on their use case.
But the other side is that the Python AD libraries provide a better "pit of success" for users trying to write idiomatic code. For example, the JAX example in the OP uses `vmap`, while @mcabbott's examples had to do some/all of that vectorization by hand. Granted, I think there's still a cultural/familiarity aspect in that a generator comprehension over array slices would be immediately flagged as a performance problem by any proficient JAX/PyTorch/NumPy user, but making the fast path the obvious path has been an evergreen challenge ever since I started following the Julia AD ecosystem.
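For the record, here's a toy sketch of what that manual vectorization looks like on the Julia side (hypothetical function and names, not code from this thread):

```julia
# Per-sample function: x is one sample (a vector).
persample(w, x) = sum(abs2, w .- x)

# Hand-batched version: X stores one sample per column, so we broadcast
# against the whole matrix and reduce along dims=1 to get one value per sample.
batched(w, X) = vec(sum(abs2, w .- X; dims=1))

# batched(w, X)[j] == persample(w, X[:, j]) for every column j.
```

This is roughly the rewrite that `jax.vmap(persample, in_axes=(None, 1))` would derive mechanically from the per-sample function in JAX; in Julia today it's done by hand, which is what I mean by a pit of success.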