I’ll offer a perspective from someone who (as a conscious choice) primarily uses Python over Julia. I work with, and maintain libraries for, all of PyTorch, JAX, and Julia.
For context, my answers will draw some parallels between:
- JAX, with Equinox for neural networks;
- Julia, with Flux for neural networks;
as these actually feel remarkably similar. JAX and Julia are both based around jit-compilers; both ubiquitously perform program transforms via homoiconicity. Equinox and Flux both build models using the same `Flux.@functor`-style way of thinking about things. Etc.
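To make that parallel concrete, here's a minimal sketch of the `Flux.@functor` style (the `Affine` layer and its field names are made up for illustration):

```julia
using Flux

# A custom layer is just a plain struct...
struct Affine
    W
    b
end

# ...which @functor registers with Flux, so that its fields are
# traversed for parameters, gradients, GPU movement, and so on.
Flux.@functor Affine

(m::Affine)(x) = m.W * x .+ m.b

model = Affine(randn(3, 2), zeros(3))
model(randn(2))  # 3-element output
```

Equinox models are declared in essentially the same way: a model is a transparent container of its fields.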
Question 1: where does ML-in-Julia shine?
(A) Runtime speed.
Standard Julia story there, really. This is most noticeable compared to PyTorch, at least when doing operations that aren’t just BLAS/cuDNN/etc.-dominated. (JAX is generally faster in my experience.)
(B) Compilation speed.
No, really! Julia is substantially faster than JAX on this front. (It really doesn’t help that JAX is essentially a compiler written in Python. JAX is a lovely framework, but IMO it would have been better to handle its program transformations in another language.)
It’s been great watching the recent progress here in Julia.
(C) Introspection.
Julia offers tools like `@code_warntype`, `@code_native`, etc. Meanwhile JAX offers almost nothing. (Once you hit the XLA backend, it becomes inscrutable.) For example, I’ve recently had to navigate some serious performance bugs in the XLA compiler essentially by trial-and-error.
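As a quick illustration of the kind of introspection Julia gives you for free (`unstable` is a made-up toy function):

```julia
# A deliberately type-unstable function: it returns a Float64 on one
# branch and an Int on the other.
unstable(x) = x > 0 ? x : 0

# At the REPL, @code_warntype flags the problematic Union return type:
@code_warntype unstable(1.0)
# ...
# Body::Union{Float64, Int64}
```

There's simply no analogue of this once your JAX program has been lowered to XLA.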
(D) Julia is a programming language, not a DSL.
JAX/XLA have limitations like not being able to backpropagate through while loops, or to specify when to modify a buffer in-place. As a “full” programming language, Julia doesn’t share these limitations.
Julia offers native syntax over, e.g., `jax.lax.fori_loop(...)`.
(PyTorch does just fine on this front, though.)
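For instance, here's a value-dependent while loop, in native syntax, that Zygote differentiates directly (the function is made up for illustration; this is a sketch, not a benchmark):

```julia
using Zygote  # the AD used by Flux

# The trip count depends on the runtime value of x. In JAX this would
# need jax.lax.while_loop, which can't be reverse-mode differentiated;
# in Julia it's just a loop.
function halve_until_small(x)
    while x > 1
        x = x / 2
    end
    return x
end

Zygote.gradient(halve_until_small, 10.0)  # (0.0625,), i.e. (1/2)^4 after four halvings
```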
Question 2.
(A) Poor documentation.
If I want to do the equivalent of PyTorch’s `detach` or JAX’s `stop_gradient`, how should I do that in Flux?
First of all, it’s not in the Flux documentation. Instead it’s in the separate Zygote documentation. So you have to check both.
Once you’ve determined which set of documentation you need to look in, there are the entirely separate `Zygote.dropgrad` and `Zygote.ignore`.
What’s the difference? Unclear. Will they sometimes throw mysterious errors? Yes. Do I actually know which to use at this point? Nope.
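For reference, here's my best understanding of the two, hedged appropriately given the documentation situation just described (behaviour as of the Zygote versions I've used; the examples are made up):

```julia
using Zygote

# dropgrad: treat a single value as a constant under differentiation.
Zygote.gradient(x -> x * Zygote.dropgrad(x), 2.0)  # (2.0,): the second factor is held constant

# ignore: run a whole block without gradient tracking.
Zygote.gradient(2.0) do x
    y = Zygote.ignore(() -> x^2)  # y is a constant as far as AD is concerned
    x * y
end  # (4.0,)
```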
(B) Inscrutable errors.
Whenever the developer misuses a library, the error messages returned are typically more akin to “C++-template-verbiage” than “helpful-Rust-compiler”. (That is to say, less than helpful.) Especially when coupled with point (A), it can feel near-impossible to figure out what one actually did wrong.
Moreover, at least a few times I’ve had cases where what I did was theoretically correct, and the error was actually reflective of a bug in the library. (Incidentally, Julia gives library authors very few tools to verify the correctness of their work.)
Put simply, the trial-and-error development process is slow.
(C) Unreliable gradients
I remember all too un-fondly a time in which one of my models was failing to train. I spent multiple months on-and-off trying to get it working, trying every trick I could think of.
Eventually (eventually) I found the error: Zygote was returning incorrect gradients. After having spent so much energy wrestling with points (A) and (B) above, this was the point where I simply gave up. Two hours of development work later, I had the model successfully training… in PyTorch.
(D) Low code quality
(D.1)
It’s pretty common to see posts on this forum saying “XYZ doesn’t work”, followed by a reply from one of the library maintainers stating something like “This is an upstream bug in the new version a.b.c of the ABC library, which XYZ depends upon. We’ll get a fix pushed ASAP.”
Getting fixes pushed ASAP is great, of course. What’s bad is that the error happened in the first place. In contrast I essentially never get this experience as an end user of PyTorch or JAX.
(D.2)
Code quality is generally low in Julia packages. (Perhaps because there’s an above-average number of people from academia etc., without any formal training in software development?)
Even in the major well-known well-respected Julia packages, I see obvious cases of unused local variables, dead code branches that can never be reached, etc.
In Python these are things that a linter (or code review!) would catch. And the use of such linters is ubiquitous. (Moreover in something like Rust, the compiler would catch these errors as well.) Meanwhile Julia simply hasn’t reached the same level of professionalism. (I’m not taking shots at anyone in particular here.)
[Additionally there’s the whole #4600 / FromFile.jl / `include` problem that IMO hinders readability. But I’ve spoken about that extensively before, and it seems to be controversial here, so I’ll skip any further mention of it.]
(D.3) Math variable names
APIs like `Optimiser(learning_rate=...)` are simply more readable than those like `Optimiser(η=...)`. (I suspect some here will disagree with me on this. After all, APL exists.)
(E) Painful array syntax
- Julia makes a distinction between `A[1]` and `A[1, :]`;
- the need to put `@view` everywhere is annoying;
- the need for `selectdim` over something (NumPy-style) like `A[..., 1, :]` reduces readability;
- the lack of a built-in `stack` function is annoying (to the extent that Flux provides one!).
Array manipulation is such an important part of ML, and these really hinder usability/readability. One gets there eventually, of course, but my PyTorch/JAX code is simply prettier to read, and to understand.
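To make those points concrete, a sketch (the arrays here are arbitrary; `reduce(hcat, ...)` is one common stand-in for the missing `stack`):

```julia
A = rand(4, 5, 6)

A[1]                # linear indexing: a single element, not the first slice
A[1, :, :]          # the first slice along dim 1, and it's a copy
@view A[1, :, :]    # the same slice without the copy, at the cost of extra noise
selectdim(A, 3, 1)  # roughly NumPy's A[..., 0]; there's no native syntax for it

# No built-in `stack`, so one reaches for workarounds:
xs = [rand(3) for _ in 1:4]
reduce(hcat, xs)    # 3×4 Matrix
```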
(F) No built-in/ubiquitous way to catch passing arrays of the wrong shape, which is a very common error. (At least none that I know of.)
Meanwhile:
- JAX probably has the best usable offering for this. Incorrect shapes can be caught during jit compilation using an `assert` statement. The only downside is that actually doing so is very unusual (not even close to culturally ubiquitous), probably because of the need for extra code.
- Hasktorch and Dex encode the entirety of an array’s shape into its type. Huge amounts of safety, ubiquitously. Only downside here is that both are experimental research projects.
- PyTorch has `torchtyping`, which provides runtime or test-time shape checking, essentially as part of the type system.
My “dream come true” in this regard would be something with the safety of the Rust compiler and the array types of `torchtyping`.
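Absent anything like that in Julia, one ends up hand-rolling checks; a hypothetical example of the kind of boilerplate involved:

```julia
# Manual shape checking: correct, but verbose and easy to omit,
# which is exactly why it never becomes ubiquitous.
function predict(W::AbstractMatrix, x::AbstractVector)
    size(W, 2) == length(x) || throw(DimensionMismatch(
        "expected x of length $(size(W, 2)), got $(length(x))"))
    return W * x
end
```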
Question 3
My experience has been that all of PyTorch/JAX/Julia are fast enough. I don’t really find myself caring about speed differences between them, and will pick a tool based on other considerations. (Primarily those listed above.)
Question 4
Maybe not an “experiment” in the sense you mean, but: more Q&As like this one, in particular at other venues where there are likely to be more folks who have (either just for a project or more broadly) decided against Julia.
Question 5
Best case argument:
Imagine working in an environment that has both the elegance of JAX (Julia has arrays, a jit compiler, and vmap all built into the language!) and the usability of PyTorch (Julia is a language, not a DSL!). Julia still has issues to fix, but come and pitch in if this is a dream you want to see become reality.
Impactful contributions:
- Static compilation. Julia’s deployment story is simply nonexistent, and IMO this sharply limits its commercial applicability.
- Better automatic differentiation. I know there’s ongoing work in this space (e.g. Diffractor.jl, which I haven’t tried yet), but so far IMO Julia hasn’t yet caught up to PyTorch/JAX on this front.
- Fixing all of the negatives I raised above. Right now none of those are issues suffered by the major Python alternatives.
Question 6
As Q5.
Question 7
What packages? Right now, I’m tending to reach for JAX, Equinox, Optax (all in Python).
Why those packages? They provide the best trade-off between speed/usability for me right now.
What do I wish existed? Solutions to the above problems with Julia ML. I find that I really like the Julia language, but its ML ecosystem problems hold me back.