What are the future plans for Scientific ML?

I wonder if someone can tell us about future plans? I did not see Reactant.jl mentioned here. It seems to me that this would be the Julia version of JAX/OpenXLA. Just like OpenXLA, it uses MLIR from the LLVM project.

I also wondered if MLIR would be a better target for Julia proper instead of the old LLVM API. Does this even make sense? I believe Mojo uses MLIR directly.

You can view it as JAX for Julia, because both use XLA for compilation and an “operator overloading”-like tracing strategy to construct the computational DAG, but there are several differences:

  • We can trace mutation.
  • No need for vmap: Julia’s broadcasting automatically lowers to the enzyme.batch MLIR op, which will lower to batched MLIR/StableHLO ops in the future (it works; we just need to add an optimization pass for that).
  • Differentiation:
    • In JAX, differentiation is performed on the computational DAG.
    • In Reactant, we directly create the MLIR code (no intermediate DAG) and use Enzyme (actually a repo called Enzyme-JAX, which also adds that functionality to JAX) for differentiation. This gives us more differentiation superpowers, like differentiating through mutation, control flow and ccalls.
      • Actually, you don’t need to define a chain rule for a Julia function (although you can for performance). It will automatically be differentiated by Enzyme’s MLIR backend if you just teach Reactant how to lower it to MLIR.
  • We’re currently adding tracing capabilities for control flow, something that JAX cannot add due to Python:
    • In principle, “operator overloading” tracing is a kind of partial evaluation, which means that (1) you cannot use a JAX array as part of the condition and (2) only the branch taken during tracing gets compiled.
    • In Julia, thanks to macros, you can just do @trace if ... and you will have both branches compiled. Support for loops (for and while) is coming.
  • In the future, we want to add the capability of writing your own GPU kernels: you write your kernel with CUDA.jl / AMDGPU.jl, and we take its LLVM IR and inject it into the MLIR. You don’t need to do anything fancy.
  • Enzyme-JAX adds some optimizations, like op fusion, which should simplify and speed up the code.
  • …and more features in the future that we’re still starting to work on. Stay tuned :wink:
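
To make the partial-evaluation point about “operator overloading” tracing concrete, here is a toy sketch in plain Python. It is not how JAX or Reactant are actually implemented; Traced, trace, branch_on_flag and branch_on_data are hypothetical names invented for this illustration. The idea it tries to show: a Python if runs while tracing, so only the branch taken ends up in the recorded graph, and a condition that depends on a traced value cannot be evaluated at all.

```python
class Traced:
    """Symbolic value: records operations into a graph instead of computing them."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph

    def _emit(self, op, other):
        # Record (result, op, lhs, rhs) and return a new symbolic value.
        rhs = other.name if isinstance(other, Traced) else repr(other)
        out = Traced(f"v{len(self.graph)}", self.graph)
        self.graph.append((out.name, op, self.name, rhs))
        return out

    def __add__(self, other):
        return self._emit("add", other)

    def __mul__(self, other):
        return self._emit("mul", other)

    def __gt__(self, other):
        return self._emit("gt", other)

    def __bool__(self):
        # A traced value has no concrete truth value while tracing.
        raise TypeError("cannot branch on a traced value at trace time")


def trace(fn, *static_args):
    """Run fn once on a symbolic input and return the recorded graph."""
    graph = []
    fn(Traced("x", graph), *static_args)
    return graph


def branch_on_flag(x, flag):
    # flag is an ordinary Python bool, so this `if` is resolved during tracing.
    if flag:
        return x * 2
    return x + 1


def branch_on_data(x):
    # The condition depends on the traced value itself.
    if x > 0:
        return x * 2
    return x + 1


print(trace(branch_on_flag, True))   # only the `mul` branch is recorded
print(trace(branch_on_flag, False))  # only the `add` branch is recorded
try:
    trace(branch_on_data)
except TypeError as err:
    print("tracing failed:", err)
```

A macro such as @trace if sidesteps this because it sees the syntax of both branches before anything runs, so it can emit a single conditional op containing both.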

That’s exactly Brutus. @Pangoraw reimplemented it on top of MLIR.jl and @merckxiaan has this nice PR to add better support for it on MLIR.jl based on AbstractInterpreter.

It’s not there yet because we need a better MLIR C API, especially for the IRDL dialect, to write a Julia MLIR dialect that doesn’t depend so much on the LLVM version or build. Some work is being done, but we’re currently more focused on Reactant.jl.
