State of machine learning in Julia

Thanks for the shout-out on the CTPG paper! I whole-heartedly agree that the paper would not have been possible without Julia and DifferentialEquations.jl. It’s cool to see what can be done when you break out of the usual static-graph mode. Thanks and kudos to @ChrisRackauckas!

I work in ML, largely in Python, but I have a soft spot for Julia as well. I think @patrick-kidger’s response summarizes things very well. I’ll just chip in a few of my own experiences/thoughts:

What does ML even mean?

There are so many different types of models/problems/architectures these days that it’s worth pointing out that there’s a big difference between “conventional deep learning” – transformers, convnets, large models – and “other” more obscure models – differentiable physics, neural ODEs, implicit models, etc. So far I think Julia is doing better in the “other” category.

It’s worth noting that the requirements for “conventional” vs “other” can be drastically different. Everything from model parallelism to compute architecture to float32 vs float64.

Speed

Compilation speed is entirely irrelevant (cf. JAX, whose compile times are also long). What matters at the end of the day is iterations per second. Right now JAX/XLA seem to have Julia beat in the “conventional large model” space, since they have all kinds of optimizations for linear algebra, specific kernels, and TPUs. At this point just about every last drop of performance has been squeezed out of PyTorch/TF/JAX for conventional large models.
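
To make that concrete, here is roughly how I would measure it in Julia. This is just a toy sketch with arbitrary sizes, not a real benchmark; the point is that the first gradient call pays the compilation cost, and the steady-state time per iteration is what actually matters.

```julia
using Zygote, BenchmarkTools

# Toy model: what we care about is the steady-state time per forward+backward pass,
# not the one-off compilation cost of the first call.
W = randn(Float32, 256, 256)
x = randn(Float32, 256, 128)
loss(W) = sum(abs2, tanh.(W * x))

@time Zygote.gradient(loss, W)     # first call: includes compilation
@btime Zygote.gradient(loss, $W)   # steady-state cost per iteration
```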

That being said, I am extremely bullish on the Metatheory.jl line of work on e-graph-based optimization. Ultimately I think it is a superior design to anything the competition has. But the devil will be in the details of making it production-ready, especially on GPUs/TPUs.
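
For anyone unfamiliar with the e-graph idea: you declare algebraic rewrite rules, saturate an e-graph with all equivalent forms of an expression, and then extract the cheapest one. A minimal sketch of what that looks like with Metatheory.jl (written from memory, so treat the exact function names as approximate and check the current docs):

```julia
using Metatheory

# A few algebraic simplification rules (--> is a directed rewrite, == is bidirectional).
t = @theory a b begin
    a * b == b * a
    a * 1 --> a
    a + 0 --> a
end

# API names as I remember them (~Metatheory 2.x); they may differ between versions.
g = EGraph(:(1 * (x + 0)))   # build an e-graph from the expression
saturate!(g, t)              # apply rules until no new equivalences appear
extract!(g, astsize)         # pull out the smallest equivalent expression, here :x
```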

Correctness

Like @patrick-kidger, I have been bitten by incorrect-gradient bugs in Zygote/ReverseDiff.jl. This cost me weeks of my life and has thoroughly shaken my confidence in the entire Julia AD landscape. As a result I now avoid Julia AD frameworks when I can. At minimum, I cross-check all of their results against JAX… at which point I might as well just use the JAX implementation. (Excited to check out Diffractor.jl when it’s ready, though!)

In all my years of working with PyTorch/TF/JAX I have not once encountered an incorrect gradient bug.
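
For what it's worth, that kind of cross-check is cheap to automate. Here is a minimal sketch, checking against a central finite difference instead of JAX (same idea, fewer moving parts):

```julia
using Zygote

# Cross-check an AD gradient against a central finite difference.
f(x) = sum(sin.(x) .* x.^2)

x = randn(10)
g_ad = Zygote.gradient(f, x)[1]

ε = 1e-6
g_fd = map(1:length(x)) do i
    e = zeros(length(x)); e[i] = ε
    (f(x .+ e) - f(x .- e)) / (2ε)
end

@assert maximum(abs.(g_ad .- g_fd)) < 1e-4   # flag suspicious gradients early
```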

Ecosystem and library scope

I really wish there were something like JAX in Julia. Flux.jl is too high-level for me most of the time, and Zygote is often too low-level. I do like the idea of source-to-source AD, though. Maybe new frameworks built on top of Zygote/Diffractor just need to spring up? I don’t know. I expect solutions here will emerge naturally as more investment is made in ML in Julia and people bump into the limitations of the existing tooling…
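
To illustrate what I mean by "something like JAX": pure functions over explicit parameter containers, with gradients and updates spelled out by hand. You can already write in roughly that style with Zygote alone; the (hypothetical) sketch below is about the level of abstraction I would want a framework to build on.

```julia
using Zygote

# "JAX-style" in Julia: explicit params, pure loss function, explicit gradient + update.
params = (W = randn(2, 3), b = zeros(2))

predict(p, x) = p.W * x .+ p.b
loss(p, x, y) = sum(abs2, predict(p, x) .- y)

x, y = randn(3), randn(2)
grads = Zygote.gradient(p -> loss(p, x, y), params)[1]  # NamedTuple of gradients

η = 0.1
params = map((p, g) -> p .- η .* g, params, grads)      # manual SGD step
```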

I’m optimistic about the future of ML in Julia. I really am. For me personally, it’s not ready for what I need it to do just yet, but I’m hopeful that will change over time.
