I wonder if someone can tell us about future plans? I did not see Reactant.jl mentioned here. It seems to me that this would be the Julia version of JAX/OpenXLA. Just like OpenXLA, it uses MLIR from LLVM.
I also wondered if MLIR would be a better target for Julia proper instead of the old LLVM API. Does this even make sense? I believe Mojo uses MLIR directly.
You can view it as Jax for Julia, because we both use XLA for compilation and an “operator overloading”-like tracing strategy for constructing the computational DAG, but there are a couple of differences:
We can trace mutation.
No need for vmap: Julia’s broadcasting automatically lowers to the enzyme.batch MLIR op, which will lower to batched MLIR/StableHLO ops in the future (it works; we just need to add an optimization pass for that).
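To make the mutation and broadcasting points concrete, here is a minimal sketch of the user-facing side, assuming the Reactant.to_rarray and @compile entry points (the function and array names here are just for illustration):

```julia
using Reactant

# A function that mutates its first argument and relies on broadcasting;
# both the mutation and the broadcast are traced when the function is compiled.
function scaled_add!(y, x)
    y .+= 2f0 .* x   # broadcasting, no vmap needed
    return y
end

x = Reactant.to_rarray(rand(Float32, 1024))
y = Reactant.to_rarray(zeros(Float32, 1024))

compiled_scaled_add! = @compile scaled_add!(y, x)  # traces the function into MLIR and compiles via XLA
compiled_scaled_add!(y, x)                         # runs the compiled executable, mutating y
```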
Differentiation:
In Jax, differentiation is performed on the computational DAG.
In Reactant, we directly create the MLIR code (no intermediate DAG) and use Enzyme for differentiation (actually a repo called Enzyme-JAX, which also adds that functionality to Jax). This gives us more differentiation superpowers, like differentiating through mutation, control flow, and ccalls.
Actually, you don’t need to define a chain rule for a Julia function (although you can for performance). It will automatically be differentiated by Enzyme’s MLIR backend as long as you teach Reactant how to lower it to MLIR.
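As a rough sketch of what that looks like in practice, assuming the @compile entry point and Enzyme’s gradient API (the function names are made up for illustration):

```julia
using Reactant, Enzyme

# Plain Julia function: no rrule/frule or custom chain rule defined anywhere.
sumsq(x) = sum(abs2, x)

# Wrapper so the whole gradient computation gets traced and compiled at once.
grad_sumsq(x) = Enzyme.gradient(Reverse, sumsq, x)

x = Reactant.to_rarray(rand(Float32, 16))

compiled_grad = @compile grad_sumsq(x)
dx = compiled_grad(x)   # gradient of sumsq at x (wrapped in a tuple on recent Enzyme versions)
```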
We’re currently adding tracing capabilities for control flow, something that Jax cannot add because of Python:
In principle, “operator overloading” tracing is a kind of partial evaluation, which means that (1) you cannot use a jax.array as part of the condition and (2) only the branch taken during tracing will be compiled.
In Julia, thanks to macros, you can just write @trace if ... and both branches will be compiled. Support for for and while loops is coming.
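Here is a minimal sketch of that @trace form, again assuming the Reactant.to_rarray and @compile entry points (the function is purely illustrative; note that both branches define the same variable):

```julia
using Reactant

function pick_branch(x)
    s = sum(x)          # s is a traced scalar, so the condition depends on runtime data
    @trace if s > 0
        y = x .* 2f0
    else
        y = x .- 1f0
    end
    return y
end

x = Reactant.to_rarray(randn(Float32, 8))
compiled = @compile pick_branch(x)
compiled(x)   # both branches were lowered; the one taken is selected at run time
```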
In the future, we want to add support for writing your own GPU kernels: you write your kernel with CUDA.jl / AMDGPU.jl, and we take its LLVM IR and inject it into the MLIR module. You won’t need to do anything fancy.
Enzyme-JAX adds some optimizations, like op fusion, which should simplify and speed up the code.
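If you want to see what those passes produce, one way is to print the optimized StableHLO module before it is handed to XLA; this assumes your Reactant version provides the @code_hlo macro (the function below is just an example):

```julia
using Reactant

# Two element-wise operations that the optimization passes can fuse/simplify.
f(x) = sum(exp.(x) .* exp.(x))

x = Reactant.to_rarray(rand(Float32, 32))

@code_hlo f(x)   # prints the (optimized) StableHLO module that will be handed to XLA
```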
…and more features in the future that we’re just starting to work on. Stay tuned
That’s exactly what Brutus does. @Pangoraw reimplemented it on top of MLIR.jl, and @merckxiaan has this nice PR that adds better support for it in MLIR.jl based on AbstractInterpreter.
It’s not there yet because we need a better MLIR C API, especially for the IRDL dialect, to write a Julia MLIR dialect that doesn’t depend so much on the LLVM version or build. Some effort is underway, but we’re currently more focused on Reactant.jl.