Thoughts on JAX vs Julia

Hi everyone.

I just posted to Forem about my thoughts on using JAX and Julia:

I didn’t write it in the post, but I’d highly recommend that everyone who hasn’t tried JAX yet give it a go (and I’d say the same to JAX users who haven’t tried Julia).


Just a heads-up: This sentence made me double back twice:

JAX offers a DSL with many restrictions on function purity and control flow and writing idiomatic Julia code is necessary to achieve its maximum performance.

I thought I had to write idiomatic Julia code to get max performance from JAX.


True: I rewrote the second half without touching the first, and you’re right that it doesn’t work anymore.

While I agree that JAX is more restrictive than Julia in terms of types (unless you do a lot of pytree stuff), I would say JAX is less restrictive in terms of its automatic differentiation capabilities.

There are things JAX can do with a single function call that Julia can’t even dream of. A perfect example is physics-informed neural networks (PINNs), where you need to reverse-mode differentiate a loss function that is itself defined with forward-mode derivatives. There are some workarounds in NeuralPDE.jl using finite differences, but they really kill the hope of handling complex geometry without a lot of work.
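A minimal JAX sketch of the PINN pattern described above: the residual uses a second derivative with respect to the input (computed here with nested forward mode, `jax.jacfwd`), and the training gradient is taken with reverse mode (`jax.grad`) with respect to the network parameters. The tiny one-hidden-layer network, the toy equation u'' + u = 0, and all function names are hypothetical, chosen only for illustration; boundary conditions are omitted.

```python
import jax
import jax.numpy as jnp

def init_params(key, hidden=16):
    # Hypothetical one-hidden-layer network u(x) with scalar input and output.
    k1, k2, k3 = jax.random.split(key, 3)
    w1 = jax.random.normal(k1, (hidden,))
    b1 = jax.random.normal(k2, (hidden,))
    w2 = jax.random.normal(k3, (hidden,)) / hidden
    b2 = jnp.array(0.0)
    return (w1, b1, w2, b2)

def u(params, x):
    # Network output u(x) for a scalar x.
    w1, b1, w2, b2 = params
    return jnp.dot(w2, jnp.tanh(w1 * x + b1)) + b2

def u_xx(params, x):
    # d^2u/dx^2 via forward-over-forward AD with respect to the input x.
    du_dx = jax.jacfwd(u, argnums=1)
    return jax.jacfwd(du_dx, argnums=1)(params, x)

def loss(params, xs):
    # Mean squared residual of the toy ODE u'' + u = 0 at collocation points.
    residual = jax.vmap(lambda x: u_xx(params, x) + u(params, x))(xs)
    return jnp.mean(residual ** 2)

# Reverse mode over a loss that already contains forward-mode derivatives.
grad_loss = jax.jit(jax.grad(loss))

params = init_params(jax.random.PRNGKey(0))
xs = jnp.linspace(0.0, 1.0, 32)
grads = grad_loss(params, xs)  # same tuple structure as params
```

Swapping the inner `jax.jacfwd` calls for `jax.grad` should give the same second derivative, just computed in a different mode; the point is that the transformations compose freely.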

Maybe this is just a symptom of Julia’s wild-west frontier of AD packages. But really, JAX should be the target for interface design, given how easy it is to use, and to extend if you happen to hit an edge case where a derivative isn’t well defined analytically (think repeated singular values in an SVD).
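As a small illustration of that extension point, here is a sketch using `jax.custom_jvp` to override the derivative of a Euclidean norm at the origin, where the naive derivative evaluates to NaN. This is a much simpler stand-in for the SVD case, and returning 0 at the origin is an assumption (one valid subgradient choice), not a general rule; `naive_norm` and `safe_norm` are made-up names.

```python
import jax
import jax.numpy as jnp

def naive_norm(x):
    # d/ds sqrt(s) is infinite at s = 0, so the chain rule yields inf * 0 = NaN.
    return jnp.sqrt(jnp.sum(x ** 2))

@jax.custom_jvp
def safe_norm(x):
    return jnp.sqrt(jnp.sum(x ** 2))

@safe_norm.defjvp
def safe_norm_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = safe_norm(x)
    # Away from the origin the directional derivative is <x, x_dot> / ||x||.
    # At the origin we pick the subgradient 0; safe_y avoids dividing by 0
    # in the branch that jnp.where discards.
    safe_y = jnp.where(y > 0, y, 1.0)
    y_dot = jnp.where(y > 0, jnp.dot(x, x_dot) / safe_y, 0.0)
    return y, y_dot

print(jax.grad(naive_norm)(jnp.zeros(3)))  # NaNs
print(jax.grad(safe_norm)(jnp.zeros(3)))   # zeros
```

Because the rule is given as a JVP that is linear in the tangent, JAX derives reverse mode (`jax.grad`) from it automatically.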

Just my two cents.


Interesting. I’m not very familiar with Jax, so please forgive any naivety on my part here. I’m rather surprised by your comment, as I was under the impression that this is exactly the kind of thing Julia would be extremely good at. Less popular, edge-case models that lack a highly optimized Python workflow usually benefit from being implemented in Julia. More importantly, why is it currently hard for Julia to handle this particular case, in your view?

What about Jax not being able to do if-else?

What about Jax not even being able to handle a variable-length vector (i.e., a normal vector)?

What about Jax not being able to handle while loops (any non-unrolled control flow)?


Well, for PINNs it’s not just hard; I don’t think it’s currently doable (unless there’s a new AD package I’m unaware of). I’ve tried Zygote, ForwardDiff, and ReverseDiff, and kept running into either errors from trying to AD through AD code, or Zygote just hanging all night trying to compile.

With JAX, you don’t necessarily need to worry about optimized Python, since its `jit` takes care of that as long as you don’t use native Python loops.


Jax can handle conditional branching: maybe not native Python if/else, but it has a functional version (`lax.cond`).
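For reference, a minimal sketch of `jax.lax.cond` (the piecewise function is made up for illustration). A Python `if x > 0:` inside a jitted function would fail because `x` is a tracer; `lax.cond` takes the predicate as a traced value and both branches as functions, and it can be differentiated through.

```python
import jax
from jax import lax

@jax.jit
def piecewise(x):
    # x**2 for x > 0, -x otherwise; the branch is chosen at run time.
    return lax.cond(x > 0, lambda v: v ** 2, lambda v: -v, x)

print(piecewise(3.0), piecewise(-2.0))                      # 9.0 2.0
print(jax.grad(piecewise)(3.0), jax.grad(piecewise)(-2.0))  # 6.0 -1.0
```

For simple elementwise choices `jnp.where` is often used instead; unlike `lax.cond`, it evaluates both sides.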

Sure, it can’t handle variable array sizes in some situations, but I don’t think that’s a game changer. You could likely work around it with some pytree games.
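One common workaround (padding plus a mask, rather than pytree tricks) is to give every input a fixed maximum length so the compiled function only ever sees one shape. A minimal sketch; `MAX_LEN`, `pad`, and `masked_mean` are hypothetical names, and the maximum length is assumed to be known ahead of time.

```python
import jax
import jax.numpy as jnp

MAX_LEN = 8  # assumed upper bound on the vector length

def pad(v, max_len=MAX_LEN):
    # Runs outside jit: copy v into a fixed-size buffer and record which
    # entries are real data.
    n = v.shape[0]
    padded = jnp.zeros(max_len, dtype=v.dtype).at[:n].set(v)
    mask = jnp.arange(max_len) < n
    return padded, mask

@jax.jit
def masked_mean(padded, mask):
    # Always traced with shape (MAX_LEN,), so one compilation serves every
    # input length up to MAX_LEN.
    return jnp.sum(jnp.where(mask, padded, 0.0)) / jnp.sum(mask)

print(masked_mean(*pad(jnp.array([1.0, 2.0, 3.0]))))  # 2.0
print(masked_mean(*pad(jnp.array([4.0, 6.0]))))       # 5.0
```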

Jax can handle a while loop (`lax.while_loop`); you just can’t reverse-mode differentiate it, since reverse mode has to store every iteration’s intermediate values and you don’t know in advance how many iterations it will run.
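A small sketch of a genuine, data-dependent `lax.while_loop` (`grow` is a made-up function). In current JAX it is forward mode (`jax.jvp`) that works through `lax.while_loop`, while reverse mode (`jax.grad`) raises an error, because reverse mode would have to store an unknown number of iterations.

```python
import jax
import jax.numpy as jnp
from jax import lax

def grow(x):
    # Multiply by 1.5 until the value reaches 10; how many iterations run
    # depends on x, so the loop cannot be unrolled at trace time.
    return lax.while_loop(lambda v: v < 10.0, lambda v: v * 1.5, x)

x = jnp.array(1.0)
y, dy = jax.jvp(grow, (x,), (jnp.array(1.0),))  # forward mode works
print(y, dy)  # 11.390625 11.390625  (six iterations: 1.5**6)

try:
    jax.grad(grow)(x)  # reverse mode is rejected for while_loop
except Exception as err:
    print("reverse mode failed:", type(err).__name__)
```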


Thanks for the reply. I’ll definitely look some more into Jax.


Well, Jax’s approach has asymptotically terrible performance, though: it’s O(n^d) in the number of neural network parameters, instead of the O(n) that NeuralPDE takes. So there’s an important trade-off to note there, a giant asterisk on it.

That’s a big overstatement. Take a look at:

Yes, Jax is fine if you use the special lax stuff, but you have to look at the other part of the docs to see the big caveat.


Oh, in order to do the thing that actually finishes (because the “normal form” doesn’t even finish), you have to do something that disables reverse mode AD compatibility.
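For contrast, when the trip count is known at trace time, the loop can be written with `lax.scan`, which does support reverse mode because the number of stored intermediate states is fixed. A minimal sketch; the recurrence x_{k+1} = θ·x_k + 1 and the name `rollout` are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

def rollout(theta, x0, n_steps=5):
    # Apply x -> theta * x + 1 n_steps times. n_steps is a Python int, so
    # the loop length is static and reverse mode can store every carry.
    def step(x, _):
        x_next = theta * x + 1.0
        return x_next, x_next  # (new carry, per-step output)
    final, trajectory = lax.scan(step, x0, None, length=n_steps)
    return final

theta, x0 = jnp.array(2.0), jnp.array(0.0)
print(rollout(theta, x0))            # 31.0  (1 + 2 + 4 + 8 + 16)
print(jax.grad(rollout)(theta, x0))  # 49.0  (d/dθ of θ^4 + θ^3 + θ^2 + θ + 1 at θ = 2)
```

According to the JAX documentation, `lax.fori_loop` with static (Python-int) bounds is implemented via `scan` and supports reverse mode, whereas dynamic bounds fall back to `while_loop`, which does not.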

There are some nice things about Jax’s AD and some nice things about Julia’s AD’s, but there’s no reason to oversell one or the other. Basically:

Enzyme is amazing, with the huge caveat that right now you need to avoid the Julia runtime. But if you write fully non-allocating code, it’s pretty much always the winner (by a good margin). It has a really high performance ceiling (you can get a ton of performance out of it), but also a really high skill floor (I wouldn’t expect most Julia users to know what I mean by “avoid the Julia runtime”).

Jax is nice because it’s always decent. Its loop handling is meh, its compiler passes meh, its reverse-mode optimizations meh (some optimizations aren’t possible because of when it traces and such), and its higher-order AD meh (missing optimizations). It’s not very automated (you have to rewrite your code into a nice, pure functional style, rewrite your loops with the lax constructs, etc.), and it sometimes defaults to fastmath-style kernels, which would scare a numerical analyst (we only found this by digging into a few cases). It’s missing high-end features for controlling memory and doing manual SIMD. So it has a lower performance ceiling, but a higher performance floor. It’s the League of Legends to the DOTA. Being decent everywhere does at least make it easy to use in some sense: there’s no major caveat once you’ve rewritten your code.
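As a small example of what the functional-style rewrite means for mutation: JAX arrays are immutable, so NumPy-style in-place assignment has to become an indexed update that returns a new array. A minimal sketch; `set_boundary` is a made-up helper.

```python
import jax
import jax.numpy as jnp

def set_boundary(u, value):
    # NumPy style would be `u[0] = value; u[-1] = value`, which raises a
    # TypeError on a JAX array. .at[...].set(...) returns an updated copy
    # (under jit, XLA can often perform the update in place).
    return u.at[0].set(value).at[-1].set(value)

u = jnp.arange(5.0)
v = set_boundary(u, 10.0)  # v == [10, 1, 2, 3, 10]; u is unchanged

# The functional form stays differentiable: d/dvalue sum(v) = 2.
print(jax.grad(lambda val: jnp.sum(set_boundary(u, val)))(10.0))  # 2.0
```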

ForwardDiff really hits a lot of nice points: it allows manual preallocation (SciML/PreallocationTools.jl) mixed with chunking, so the only thing that really beats it in performance is Enzyme. But it’s a lot more automatic (“most” well-written generic code works with ForwardDiff without any changes). It’s pretty Pareto-optimal (which is why people love to use it!), but it’s only forward mode, and its design does not extend to reverse mode.

ReverseDiff’s non-compiled mode is pretty automatic but doesn’t support GPUs. Its compiled mode is extremely fast (almost as fast as Enzyme) but doesn’t support any runtime-valued control flow (i.e., it only handles quasi-static code, like Jax). The main issue is that its compiled mode isn’t documented very well, so very few people use it.
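To make “quasi-static” concrete on the JAX side: `jax.jit` traces the function with abstract values, so Python control flow may depend on shapes or on arguments marked static, but not on runtime array values. A minimal sketch; `clip_python` and `power_sum` are made-up names.

```python
from functools import partial

import jax
import jax.numpy as jnp

def clip_python(x):
    # Python `if` on the *value* of x.
    if x > 1.0:
        return jnp.ones_like(x)
    return x

# Eager execution is fine: x is a concrete array here.
print(clip_python(jnp.array(2.0)))  # 1.0

try:
    jax.jit(clip_python)(jnp.array(2.0))
except jax.errors.ConcretizationTypeError as err:
    # Under jit, x is an abstract tracer, so `x > 1.0` has no concrete bool.
    print("jit failed:", type(err).__name__)

@partial(jax.jit, static_argnums=1)
def power_sum(x, n):
    # Control flow on n is fine: n is static, so the loop is unrolled at
    # trace time (and each new value of n triggers a recompile).
    total = jnp.zeros_like(x)
    for k in range(n):
        total = total + x ** k
    return total

print(power_sum(jnp.array(2.0), 3))  # 7.0  (1 + 2 + 4)
```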

Zygote is Zygote. Its main feature is that it’s really easy to add ChainRules and define your own derivatives: you can teach a class of undergrads and get them extending the ecosystem in a single class period. But it’s only about as automatic as Jax because you have to avoid mutation, and its compiler design has some performance issues.

Enzyme is growing quickly, and it’s the right design. It’ll be “the” thing for Julia in about 2 years, but until its GC support is robust it’s hard to recommend to anyone but experts (since most people won’t know which operations cause the GC to hit :sweat_smile:). But SciML already defaults to it internally most of the time for many things; we just can’t when we mix Flux in (because of Flux intricacies).


Will Diffractor ever be useful before useless?


I wasn’t trying to oversell JAX or Julia. Personally, I wish Julia’s ecosystem were where it needs to be for my purposes, but I’ll sacrifice O(N^D) run time for O(1) developer time, versus O(p) developer time where p is the number of Julia AD packages.


Is it really O(p), or more like O(N/A) because no library checks enough boxes for you? I think part of the disconnect in these discussions is that JAX is purposely designed to support “mainstream” ML (really DL) use cases first and foremost. It is really, really difficult to overemphasize how much limitations in Python ML libraries (including their ADs) influence ML research. All the tradeoffs that have been discussed in @jacobusmmsmit’s article and on this forum stem from that.

That said, you may well ask why we can’t just have a slice of the Julia ecosystem which works like JAX. After all, we have ADs which do tracing, libraries to do similar optimizations and not one but two attempts to make Julia talk to XLA. I’m not sure what the answer is, but throwing out some ideas:

  • Libraries like JAX can only come from a team of full-time contributors backed by big tech resources. Thus trying to replicate them with only part-time volunteer effort is a fool’s errand.
  • Most people who want something like JAX…just use JAX :stuck_out_tongue:. Nothing wrong with that, it’s the rare individual who feels motivated to write an AD system in order to get their other work done.
  • Most AD authors (in any language) are scratching itches for their specific problem domain. In Julia land, this has only sometimes historically been “standard” ML, whereas in Python land it is almost always that.
  • Many Julia AD users do not find the benefit/tradeoff ratio of JAX large enough to roll up their sleeves and work on something better. For those working on e.g. highly scalar numeric code with loops, mutation and low overhead requirements, the numerator there is close to zero. For others where JAX is starting to gain a foothold (e.g. astronomy/physics, going off JuliaCon talks), the ratio may be ~1 but still not enough to overcome the activation energy required to shake up the status quo in the Julia ecosystem.
  • Relatedly, for a while any clamouring for a shake-up was assuaged by talk and demos of next gen AD tech which would address many of the existing gaps in the ecosystem. I believe PINNs were even a poster child example for these. However, those schedules slipped something fierce and we’re now in a bit of a collective limbo when it comes to what to do in the short + medium term (Chris already described part of the long term plan).

These are all the same thing. Diffractor’s category theory background is designed to make higher order differentiation have asymptotically better scaling. It’s great for PINNs. I’m not sure other use cases have much of a use for it though, so I don’t tend to think about it as a general purpose AD as much. It has the same major limitation of Zygote (mutation support), so :person_shrugging:.
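On the JAX side, there is an experimental Taylor-mode module, `jax.experimental.jet`, aimed at higher-order derivatives: it pushes a truncated Taylor series through a function in one pass rather than nesting first-order AD. A minimal sketch for comparison only (it says nothing about how Diffractor itself works, and `jet` is experimental, so its interface may change). With the input path t ↦ x + t, the returned series holds the first three derivatives of sin at x.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

x = jnp.array(0.5)
one, zero = jnp.ones_like(x), jnp.zeros_like(x)

# Input series for the path t -> x + t: first derivative 1, higher ones 0.
# jet returns sin(x) and the derivatives of sin(x + t) at t = 0, up to order 3.
y, (d1, d2, d3) = jet(jnp.sin, (x,), ((one, zero, zero),))

print(y, d1, d2, d3)  # sin(x), cos(x), -sin(x), -cos(x)
```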

Yeah, I don’t fault you for that. There’s a lot in flux right now.


Just a tiny comment: very well written @jacobusmmsmit! It’s a nice read and I also mostly agree with your points.

Btw. I really had to laugh out loud while reading this :laughing:

The best efforts in the Python ecosystem to replicate Cargo’s successes are Poetry (v1.0 in December 2019) and Hatch (v1.0 in April 2022). However, their adoption has been less-than-instant in the scientific community, a people well known for unenthusiastically learning badly-designed software and developing Stockholm syndrome once something new comes along.


@jacobusmmsmit I’d love to see a couple examples and/or understand the general types of situations where you’re forced to use a for loop in Julia to achieve similar functionality/performance to Jax.

When array programming in Julia I’ve often had a similar feeling of wishing there was a more functional primitive to replace a for loop, but I haven’t been able to pin down in what sort of cases this happens (Julia has got the common cases down pretty well with broadcasting etc.), and not having used Jax much I don’t have a good sense of the alternatives, so would love to hear your thoughts on this :slightly_smiling_face:
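For readers who haven’t used JAX: its usual replacement for a per-example for loop is `jax.vmap`. You write the function for a single example and let `vmap` add the batch dimension (sequential loops with a carried state go to `lax.scan` instead). A minimal sketch with a made-up squared-error loss, including per-example gradients, a case where the two transformations compose:

```python
import jax
import jax.numpy as jnp

def per_example_loss(w, x, y):
    # Written for ONE example: x is a feature vector, y a scalar target.
    return (jnp.dot(w, x) - y) ** 2

def losses_loop(w, xs, ys):
    # Explicit Python loop over examples, for comparison.
    return jnp.stack([per_example_loss(w, x, y) for x, y in zip(xs, ys)])

# in_axes=(None, 0, 0): share w, map over the leading axis of xs and ys.
losses_vmap = jax.vmap(per_example_loss, in_axes=(None, 0, 0))
per_example_grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 1.0, 2.0])
print(losses_vmap(w, xs, ys))        # same values as losses_loop(w, xs, ys)
print(per_example_grads(w, xs, ys))  # one gradient row per example, shape (3, 2)
```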


Could there be an Enzyme Linter?

That would be a bit hard, because the runtime support isn’t nothing; it’s just spotty. Right now, if you throw some code at it and it allocates, it could work, or it could hit part of the GC that isn’t currently covered. And the coverage is changing every month. Building such a linter would just take time away from improving the coverage :sweat_smile:


In the case of GC specifically (i.e., this would not lint type-unstable code), you could take the current Enzyme GC PR and have it always emit an error when the requested cache allocation contains a Julia object, rather than the current in-progress approach of “have it function within Julia’s GC”, and that would probably catch most GC issues (though not other kinds).

Of course, from our perspective on the dev side, we’re chugging away at making that work generically rather than stopping halfway, but anyone who has the cycles is more than free to add that!


This seems like the best path. If at some point it becomes unreasonable to add further features in a reasonable timeframe, that would be the time to start linting/error messaging.