Will Reactant.jl become a machine learning framework?

I made an observation that MLIR + Enzyme could form quite a capable ML framework, and then, well… much like the last time I theorized about how a good language would work and the Julia folks turned out to be one step ahead of me, having already created Julia, it seems they'll do it again this time?

Whatever brilliant idea I come up with, I usually end up finding out that someone else is already doing it, and they can actually pull it off. But I'm glad someone figured it out and did it.

So, anyway, back to the main question, will Reactant.jl become a machine learning framework?

I think "ML framework" is a bit of a misnomer, but it seems to be a fairly common way people refer to things like JAX. Either way, I believe it will definitely be very useful for ML.


Reactant is to Julia what JAX is to Python (see the arXiv paper "The State of Julia for Scientific Machine Learning" by Berman & Ginesin - #41 by mofeing for a comparison on this).

Reactant already converts NNlib functions to the corresponding StableHLO calls. If you are using Lux, most of the Lux tutorials (Tutorials | Lux.jl Docs) currently use Reactant. You can think of Lux as a nicer frontend for Reactant for ML tasks, with high-level layer implementations (similar to how Equinox/Flax makes it nicer to deal with JAX).
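For a concrete picture, here is a minimal sketch of that workflow (my own example, assuming recent Lux and Reactant versions; exact API details may differ slightly across releases):

```julia
using Lux, Reactant, Random

# A small Lux model; Lux keeps parameters and state explicit,
# much like Flax/Equinox on the JAX side.
model = Chain(Dense(2 => 16, tanh), Dense(16 => 1))
ps, st = Lux.setup(Random.default_rng(), model)

# Move inputs, parameters, and state onto Reactant's array type.
x    = Reactant.to_rarray(rand(Float32, 2, 32))
ps_r = Reactant.to_rarray(ps)
st_r = Reactant.to_rarray(st)

# Trace the model to StableHLO and compile it with XLA.
model_compiled = @compile model(x, ps_r, st_r)
y, _ = model_compiled(x, ps_r, st_r)
```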

One of the final bits that remains to be done in Reactant is to extend its support for the SciML packages (we need some features like custom adjoints and automatic tracing of loops without the @trace macro). Until then, my general recommendation is to use Lux + Reactant for ML tasks, and Lux + Zygote (or Enzyme) for SciML tasks.
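To illustrate what @trace does today, here is a hedged sketch (based on my reading of the Reactant docs; the exact semantics may not match the current release): value-dependent control flow has to be annotated so Reactant can lower it to a structured StableHLO loop instead of handling it during tracing.

```julia
using Reactant

# @trace asks Reactant to capture the loop as a single
# StableHLO while-loop in the traced IR, rather than
# unrolling it while tracing.
function add_ten(x)
    @trace for i in 1:10
        x = x .+ 1
    end
    return x
end

x = Reactant.to_rarray(ones(Float32, 4))
add_ten_compiled = @compile add_ten(x)
add_ten_compiled(x)  # elementwise equal to ones(Float32, 4) .+ 10
```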

(There is also some work to integrate Reactant into Flux: Support for Reactant.jl by mcabbott · Pull Request #28 · FluxML/Fluxperimental.jl · GitHub)


Doesn’t Reactant invalidate Julia’s goal to “solve the two-language problem”?

AFAIK, I can use the entire XLA machinery from Python with JAX. Plain Python is too slow, doesn't run on GPUs, and doesn't even have native support for multidimensional arrays, let alone autodiff, so I have to resort to JAX + XLA, which is written in C++. There are two two-language problems here: Python vs C++, as well as Python vs JAX, because JAX is like a domain-specific language that can be called from Python.

Reactant seems to be roughly the same as JAX, but for Julia. Why? Isn't Julia fast enough? Doesn't it support the necessary features like multidimensional arrays and dynamic dispatch for implementing autodiff? Why do I need to @compile a Reactant function if all Julia functions are compiled anyway? I understand this compiles to XLA while Julia compiles to native code, but that literally calls into another language from Julia, so why use Julia if I can do the same from Python? (Without even having to specify when I want my function compiled, by the way: I just call it and JAX compiles it automatically when needed.) Also, does it mean that Julia's autodiff (Zygote, ForwardDiff, Mooncake) isn't good enough, so we go back to XLA, thus seemingly not gaining any advantage over Python?
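As a concrete illustration of the distinction being asked about, here is a minimal sketch (my own example, not from the thread; assuming the current Reactant API): the plain call runs code natively compiled by Julia's own JIT, while @compile traces the function once to StableHLO and returns an XLA executable specialized to those argument shapes.

```julia
using Reactant

f(x) = sum(abs2, x)

x_plain = rand(Float32, 1000)
f(x_plain)               # ordinary Julia: native code from the Julia JIT

x_r = Reactant.to_rarray(x_plain)
f_xla = @compile f(x_r)  # traced to StableHLO, compiled by XLA for these shapes
f_xla(x_r)               # runs the XLA executable (CPU/GPU/TPU backends)
```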

Seems like Reactant introduces the two-language problem into Julia…

I think you are right about the two-language problem, but in practice well-known frameworks and packages use many languages (not just two). So why reinvent the wheel? I think integration between languages makes for great things.
By the way, Mooncake is written in Julia.