Will Reactant.jl become a machine learning framework?

I made an observation that MLIR + Enzyme could form quite a capable ML framework, and then, well… much like the last time I theorized about how a good language would work and the Julia folks turned out to be one step ahead of me (by creating Julia), will they do it again this time?

Whenever I come up with some brilliant idea, I usually end up finding out that someone else is already doing it, and they can actually pull it off. But I'm glad someone figured it out and did it.

So, anyway, back to the main question, will Reactant.jl become a machine learning framework?

I think "ML framework" is a bit of a misnomer, but it seems to be a fairly common way people also refer to things like JAX. But I believe it will definitely be very useful for ML.

1 Like

So, anyway, back to the main question, will Reactant.jl become a machine learning framework?

Reactant is to Julia what Jax is to Python (see arXiv: "The State of Julia for Scientific Machine Learning" by Berman & Ginesin - #41 by mofeing for a comparison on this).

Reactant already converts NNlib functions to the corresponding StableHLO calls. If you are using Lux, most of the Lux tutorials (Tutorials | Lux.jl Docs) currently use Reactant. You can think of Lux as a nicer frontend for Reactant for ML tasks, with high-level layer implementations (similar to how Equinox/Flax makes it nicer to deal with JAX).
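For instance, here is a minimal sketch of that workflow (the model, sizes, and data are made up for illustration; see the Lux tutorials linked above for complete, up-to-date examples):

using Lux, Reactant, Random

model = Chain(Dense(4 => 16, tanh), Dense(16 => 2))
ps, st = Lux.setup(Random.default_rng(), model)

# Move parameters, state, and inputs onto Reactant arrays
ps_ra = Reactant.to_rarray(ps)
st_ra = Reactant.to_rarray(st)
x_ra = Reactant.to_rarray(rand(Float32, 4, 32))

# Compile the forward pass to StableHLO/XLA once, then call it like a normal function
forward = @compile model(x_ra, ps_ra, st_ra)
y, _ = forward(x_ra, ps_ra, st_ra)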

One of the final bits that remains to be done in Reactant is to extend its support for the SciML packages (we need some features like custom adjoints and automatic tracing of loops without the @trace macro). Until then, my general recommendation is to use Lux + Reactant for ML tasks and Lux + Zygote (or Enzyme) for SciML tasks.

(There is also some work to integrate Reactant into Flux: Support for Reactant.jl by mcabbott · Pull Request #28 · FluxML/Fluxperimental.jl · GitHub)

13 Likes

Doesn’t Reactant invalidate Julia’s goal to “solve the two-language problem”?

AFAIK, I can use the entire XLA machinery from Python with JAX. Plain Python is too slow, doesn't run on GPUs, and doesn't even have native support for multidimensional arrays, let alone autodiff, so I have to resort to JAX + XLA written in C++. There are two two-language problems here: Python vs C++, as well as Python vs JAX, because JAX is like a domain-specific language that can be called from Python.

Reactant seems to be roughly the same as JAX, but for Julia. Why? Isn’t Julia fast enough? Doesn’t it support the necessary features like multidimensional arrays and dynamic dispatch for implementing autodiff? Why do I need to @compile a Reactant function if all Julia functions are compiled anyway? I understand this compiles to XLA while Julia compiles to native code, but this literally calls into another language from Julia, so why use Julia if I can do the same from Python? (Without having to specify when I want my function to be compiled, by the way. I just call it and JAX automatically compiles it when needed) Also, does it mean that Julia’s autodiff (Zygote, ForwardDiff, Mooncake) isn’t good enough, so we go back to XLA, thus seemingly not gaining any advantage over Python?

Seems like Reactant introduces the 2-language problem into Julia…

3 Likes

I think you are right about the two-language problem, but in practice well-known frameworks and packages use many languages (not just two). So why reinvent the wheel? I think integration between languages makes for great things.
By the way, Mooncake is written in Julia.

2 Likes

Reactant is emphatically not an autodiff tool. Of course you can use Enzyme from inside it just like you can use it within Julia presently (now it will use EnzymeMLIR instead of EnzymeLLVM).

Reactant is a tool to compile Julia code in a way that preserves high level structure and semantics (e.g. understanding that Base.mul! is a matmul, and not just a bunch of for loops), getting rid of type instabilities, and generally compiling away performance and usability pitfalls in Julia. It does so by extending the existing Abstract Interpretation work in Julia’s compiler, and compiling to MLIR as a way to preserve and optimize this information.

Unlike JAX, Reactant works in the Julia compiler itself, which lets it generally work well with existing code. In particular, multiple dispatch means that you can keep using your existing Base.mul!(c, a, b) and not need to rewrite things as jax.numpy.matmul(a, b). Mutation is supported, existing CUDA kernels are supported, control flow works nicely (with ongoing work here), etc.
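As a minimal sketch of what that looks like (the function and sizes below are made up, not from this thread), you can compile ordinary Julia code that mutates an array via Base.mul!:

using Reactant, LinearAlgebra

# Plain Julia: in-place matmul plus a reduction; Reactant sees mul! as a matmul
function loss!(C, A, B)
    mul!(C, A, B)
    return sum(abs2, C)
end

A = Reactant.to_rarray(rand(Float32, 64, 64))
B = Reactant.to_rarray(rand(Float32, 64, 64))
C = Reactant.to_rarray(zeros(Float32, 64, 64))

loss_compiled = @compile loss!(C, A, B)
loss_compiled(C, A, B)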

That said, Reactant is not "JAX for Julia". JAX is a fantastic project that brings JIT compilation, autodiff, and partial evaluation to Python. Reactant aims to answer a completely separate question: "can we make it easier/better for people to write effective Julia code" (which we do by building a significantly more powerful compiler). It incorporates hundreds of tensor-level optimizations (e.g. transpose(matmul(A, B)) → matmul(transpose(B), transpose(A))) [for a complete list see Enzyme-JAX/src/enzyme_ad/jax/TransformOps/TransformOps.td at main · EnzymeAD/Enzyme-JAX · GitHub], which have been shown to provide double-digit speedups for JAX code.

In contrast, Reactant aims to make it easy to write existing Julia code and have it automatically run on your favorite backend (CPU, GPU, TPU, etc.) – including a distributed cluster thereof – without needing to rewrite your code. This includes automatically rewriting CUDA.jl kernels to effectively run faster on GPUs and even CPUs (see Reactant.jl/test/integration/cuda.jl at b506bfcf77ee68e60ccbbb649089f65f69bd709c · EnzymeAD/Reactant.jl · GitHub and https://dl.acm.org/doi/pdf/10.1145/3572848.3577475 for more information). Reactant gets rid of unnecessary allocations for temporaries and fuses loops together.

Reactant also enables you to use Julia code inside of other languages (see Exporting Lux Models to Jax (via EnzymeJAX & Reactant) | Lux.jl Docs) and makes it easy for you to use other languages from Julia (Reactant.jl/test/integration/python.jl at main · EnzymeAD/Reactant.jl · GitHub), all while preserving nice things like autodiff, optimizations, and multi-device support. These additional optimizations and usability improvements make autodiff like EnzymeMLIR much faster/more effective (which is one reason we started the project).
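If you want to see what these rewrites do to your own functions, Reactant provides an @code_hlo macro to inspect the generated (and optimized) StableHLO. A small, made-up example:

using Reactant

f(A, B) = transpose(A * B)

A = Reactant.to_rarray(rand(Float32, 32, 32))
B = Reactant.to_rarray(rand(Float32, 32, 32))

# Print the StableHLO module Reactant would compile; rewrites such as
# transpose(A*B) → transpose(B)*transpose(A) happen at this level
@code_hlo f(A, B)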

I strongly disagree with your take on the two-language problem here. If anything, Reactant is trying to help Julia with the two-language problem – by making it such that you can write as much code in Julia as possible and still be able to run it effectively. Many of the problems that Reactant solves (type instabilities, burdensome allocations/GC time, etc.) are quite hard, and presently the best solutions we have for them are tools that try to find where they happen (GitHub - MilesCranmer/DispatchDoctor.jl: The dispatch doctor prescribes type stability for type instabilities; GitHub - JuliaLang/AllocCheck.jl: AllocCheck for allocations, etc.) and then force the user to significantly rewrite their code into a contorted subset of Julia that might not have these issues.

Instead, Reactant solves this problem by building a compiler. This compiler aims to let you keep writing your favorite Julia code.

Of course the inner Reactant compiler isn't pure Julia (though if you look at GitHub - EnzymeAD/Reactant.jl: Optimize Julia Functions With MLIR and XLA for High-Performance Execution on CPU, GPU, TPU and more., GitHub's language stats say most of our code is actually pure Julia); it goes through C++ because core compiler libraries like MLIR/XLA/LLVM are written in C++. The Julia compiler is written in C++ for the very same reason (julia/src/cgutils.cpp at 99fd5d9a92190e826bc462d5739e7be948a3bf44 · JuliaLang/julia · GitHub). Julia matrix libraries call Fortran code (libopenblas). The CUDA compiler is written in C++, and CUDA.jl provides the bindings that let you use it from Julia. I could go on…

The point of the two-language problem is that users have to write code in multiple languages to be effective. You are welcome to try using Julia without ever touching the core utilities that enable the user to avoid the two-language problem (like BLAS, CUDA.jl, etc.), but it would likely mean you stop using Julia (especially since the Julia compiler itself wouldn't be usable).

In short, Reactant tries to make it such that you can write your favorite pure Julia code like perhaps

function foo(model, x, y)
   neural_net(model, mul(x,ocean_predict(y)))
end

and have it automatically run fast/effectively on your laptop and computing cluster (thanks to compiler-based linear algebra optimizations, parallel optimizations, device optimizations, autodiff optimizations, etc).

39 Likes

Could there be a future where the Reactant toolchain becomes the standard or will users need to opt into it for the foreseeable future?

1 Like

Maybe, and we do hope that many optimizations we’re developing can be integrated “natively” into Julia.

For example, part of the infrastructure within the "@compile" macro is there because the base Julia compiler in 1.10/1.11 doesn't support injecting custom typed code. This will partially be improved in 1.12, and potentially longer term we will be able to inject custom optimizations into Base Julia.

That said, the project is currently in a research/design/exploration phase. We are figuring out what we can do and the design/implementation implications thereof. Reactant currently relies on partial evaluation for some parts of itself, which may be difficult to enable without a construct to explicitly opt in. However, we are currently working on being able to automatically support advanced properties of control flow/function calls that may remove this requirement.

6 Likes

I think there’s a high-level thing to note here about “what is a compiler” and “what a compiler is allowed to do”. When you write the code A*B + C, is it supposed to do A*B with BLAS and then tmp + C as a broadcast? If Julia is an imperative programming language where you define and call functions, the answer is… yes, duh, that’s how those functions are defined. But people who know the BLAS interfaces know that there is a call, gemm!, that does A*B + C together, and it’s faster. So, is the compiler allowed to reinterpret A*B + C as gemm!?

You might immediately think, duh, yes, that’s faster so do it. But now we’re contradicting our first intuition, which is that a programmer in the language should know that it calls the functions they’ve actually written. And in fact this demonstration is nice because the two are not actually equivalent: you will get subtle floating-point differences between a fused gemm! and an unfused A*B + C.
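You can see the fused form directly in Julia via the 5-argument mul! (which lowers to gemm! for dense matrices); this is just a small illustrative sketch, not something from the original post:

using LinearAlgebra

A, B, C = rand(100, 100), rand(100, 100), rand(100, 100)

# Unfused: allocates a temporary for A*B, then adds C in a second pass
D1 = A*B + C

# Fused: a single gemm! call computing 1.0*A*B + 1.0*C in place
D2 = mul!(copy(C), A, B, 1.0, 1.0)

D1 == D2   # not guaranteed; results can differ in the last bits
D1 ≈ D2    # true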

So then it comes down to philosophy: is a compiler allowed to change the floating-point semantics in this case? But it’s really a case-by-case thing, because, for example, it’s already done with SIMD and muladd operations: both of those can change the floating-point result, but a compiler can do this automatically. But you can see there are different levels here:

  1. Julia and LLVM in general will try to SIMD loops automatically when they can at the default -O2 setting.
  2. Julia will not change a*b + c to muladd(a,b,c) by default, though @fastmath or MuladdMacro.jl are ways to make this more automatic, with the former of course being built into the compiler / LLVM and just enabled in the pass stack (see the small example after this list).
  3. Julia will not change A*B + C to a single gemm!, but Reactant.jl will do that.
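To make item 2 concrete, here is a tiny (made-up) snippet showing how the fused and unfused scalar forms can differ:

a, b, c = 0.1, 0.2, 0.3

x = a*b + c          # rounds after the multiply, then again after the add
y = muladd(a, b, c)  # fused multiply-add, rounds only once

x == y   # not guaranteed to be true
x ≈ y    # true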

Should the compiler be allowed to change the floating-point value of an output? If you say no, then it cannot auto-SIMD loops either: reductions like sum would get a different value. So I think at some level we have already loosened the idea of imperative programming a bit, and we’re using Julia language semantics somewhat declaratively: I meant sum(x), but I didn’t mean precisely what the Julia function is written as; you can modify that as needed to be faster.

But, if the user writes a*b+c, did they actually mean muladd(a,b,c)? That really is treating the code as a declarative spec: you said one very precise thing (multiply and then add), but I’m going to interpret it as “multiply-and-add”, and I’m going to replace it with something mathematically equivalent because it’s what you really meant, right? Reactant.jl is then taking that to a completely different level: you didn’t mean “multiply the matrix A by B and then add C”, which is what the imperative programming language semantics mean, you meant “A*B + C”, and I’m going to replace that with an equivalent mathematical expression that tends to be a better way to evaluate it.

What should a compiler be allowed to do? That is a very good question. If you have well-defined higher-level functions, then maybe the compiler should be allowed to assume not just what the implementation is, but also the “intent” of the user, and be able to swap implementations based on that higher-level “intent”. Though that breaks many principles of “you know how it’s computed”, which is generally not done in an imperative programming language, because normally if you write the functions * and + then it will use them. With Reactant 🤷 it might just look at the LLVM and think “I know how to do that better”.

Because of this, I think there’s a lot of space to explore here. Not just in the ability to write such compilation passes, but also in the interface to the user. Should it be done by default or opt-in? If you want some passes but not others, can this be fine-grained? It’s hard to tell how this evolves. But at least the safest way to build it out is to make it fully opt-in, which is the Reactant.jl of today. But there have been prototypes of Julia to MLIR, such as Brutus.jl and other things, so one major difference from Python is that it’s actually possible for something like Reactant to become just a standard part of the compiler some day. I think we still need to work out some interface ideas in order to really say that it should.

In the meantime, it’s really easy to opt-in and it does well. So that’s the current state, but likely not the last statement.

12 Likes

Ah, so Reactant is like an alternative compiler for Julia that can do all the nice optimizations and supports autodiff and neural nets, etc. Then I get it: write Julia code (hopefully without explicit @compile calls and having to think about tracing tensors like in JAX), it’s translated to a common representation, optimized, and then compiled for various hardware. This sounds great. Unfortunately, I couldn’t quite grasp what exactly Reactant is by reading the docs.

1 Like

In fairness we are currently in the process of building and designing Reactant, so we haven’t really added docs. Contributions welcome!

I will say, however, that presently a lot of Reactant’s infrastructure is based on tracing/partial evaluation, so you don’t escape all of those concepts from JAX. That said, it isn’t a requirement for us (for example, kernels don’t have any tracing, and we have the ability to support real control flow [available via a macro, with automatic deduction in progress]).

I also imagine that for a non-negligible time in the future you will need to call Reactant via @compile or a similar macro. We also have a simple @jit macro which just calls compile and runs it, but in practice you’ll almost always want to explicitly compile once and then use that compiled function. We’ll work on reducing compilation time and/or moving that into Julia compilation itself – but there are plenty of other features that are higher priority to build.
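As a rough, made-up sketch of the difference between the two macros:

using Reactant

square(x) = x .^ 2
x = Reactant.to_rarray(rand(Float32, 16))

# Compile once, then reuse the compiled function
square_c = @compile square(x)
square_c(x)

# Or compile-and-run in a single step
@jit square(x)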

7 Likes

I think this is one of the problems with the Julia language. There is no consensus on what the language or anything else is allowed or not allowed to do. It was designed to be a general-purpose language, but that means dealing with several types of users with different interests. I think it’s going to get even messier when deciding, say, what exactly a number is in Julia and so on. Is a Quaternion a number? Different users with different interests will likely want different definitions and want packages to adhere to their definitions, and it will be quite messy indeed.

So Reactant.jl is our APL?

Just want to clarify that Reactant.jl uses the XLA compiler and the StableHLO dialect, which are designed with ML and array processing in mind. So if your code is not array-processing heavy, Reactant.jl won’t accelerate it much, but Julia can still be great at it.

Also, we allow you to insert your own Julia function implementations if you want (currently only available for CUDA kernels) and compile them together.

Reactant.jl is not replacing Julia. If anything, it’s making Julia better in some areas thanks to aggressive superoptimization. Also, Julia features like multiple dispatch and abstract interpretation allow us to do sooo much stuff.

6 Likes

It would be cool to see a really nice demo/snippet highlighting the unique interplay between Julia and Reactant. Coming from a JAX mindset, it is kind of hard for me to visualize what I’m missing (for sure there is a lot!). Though as I’m writing this reply, I’m currently writing some JAX code and wish I could just design it in a way that uses some of Julia’s features.

3 Likes