Calling a JAX JIT-compiled Python function from Julia without overhead

It would be very useful to call JIT-compiled JAX functions from Python libraries in Julia, and, in turn, for Julia libraries to be able to call user-provided JAX functions.

This is possible today via PythonCall, e.g.:

    using PythonCall
    jax = pyimport("jax")
    numpy = pyimport("numpy")
    
    # Define a simple test function
    func_str = """
    def f(x):
        return jax.numpy.sin(x) + jax.numpy.cos(x)
    """
    
    # Create namespace and define function
    namespace = pydict()
    namespace["jax"] = jax
    pyexec(func_str, namespace)
    py_func = namespace["f"]
    
    # Create test input
    x = numpy.array([1.0], dtype=numpy.float32)
    
    # Get lowered representation
    lowered = jax.jit(py_func).lower(x)
    println("HLO text:")
    println(lowered.as_text())
    
    # Also test compilation and execution through Python
    compiled = lowered.compile()
    @time result = compiled(x)
    println("\nTest execution result:", result)
    println("\nSum of result: ", sum(result))

The problem, of course, is that we have overhead from Python, about 7 µs per call, because launching the JAX function still goes through Python (compiled(x)...).
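To put a number on that per-call Python cost, here is a minimal sketch assuming BenchmarkTools is installed, with `jax.numpy.sin` standing in for the user's function. `@btime` with `$` interpolation avoids global-variable overhead, and `block_until_ready()` waits for JAX's asynchronous dispatch so the timing includes the actual execution.

```julia
using PythonCall, BenchmarkTools

jax = pyimport("jax")
numpy = pyimport("numpy")

# Small Float32 input, as in the example above
x = numpy.ones(100, dtype=numpy.float32)

# Lower and compile once; every call below still dispatches through Python
compiled = jax.jit(jax.numpy.sin).lower(x).compile()
compiled(x).block_until_ready()  # warm-up call

# Per-call latency, waiting for the asynchronous result to be ready
@btime $compiled($x).block_until_ready()
```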

It seems like it should be possible to call the JIT-compiled function directly. In the worst case, it should be possible to export the compiled JAX function and then run it without involving Python.
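For the export route, JAX's `jax.export` module can serialize a jitted function to StableHLO text without running it. The following is a sketch via PythonCall. It assumes a JAX release that ships `jax.export` (0.4.30 or later) and a fixed Float32 input of shape (100,). The file name `f.mlir` is illustrative only. Because `export` is a Julia keyword, the attribute is fetched with `pygetattr`. Actually executing the exported module without Python would still need a runtime such as PJRT.

```julia
using PythonCall

jax = pyimport("jax")
jexport = pyimport("jax.export")
numpy = pyimport("numpy")

# Python function to export, defined the same way as in the post above
namespace = pydict()
namespace["jax"] = jax
pyexec("""
def f(x):
    return jax.numpy.sin(x) + jax.numpy.cos(x)
""", namespace)
py_func = namespace["f"]

# Abstract input: shape and dtype only, no data needed
spec = jax.ShapeDtypeStruct((100,), numpy.float32)

# `export` is a Julia keyword, so look the attribute up by name
export_fn = pygetattr(jexport, "export")
exported = export_fn(jax.jit(py_func))(spec)

# StableHLO text that no longer depends on the Python source
hlo = pyconvert(String, exported.mlir_module())
write("f.mlir", hlo)
```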

Here is a relevant GitHub discussion: Calling pre-compiled JAX code from C++ · jax-ml/jax · Discussion #22184 · GitHub

Has anyone tried this, or does anyone have suggestions on how to proceed? I guess we would need to get the PJRT C library built and into Yggdrasil?

Might I recommend Reactant.jl (GitHub - EnzymeAD/Reactant.jl)?

cc @avikpal and @mofeing

There are docs in Lux (which we should move to Reactant proper) for exporting Julia into JAX; we can make docs for doing the same in reverse: Exporting Lux Models to Jax (via EnzymeJAX & Reactant) | Lux.jl Docs

Hi @wsmoses , thanks for this!

I started digging into the Reactant code.

Through PythonCall I can JIT-compile a function and get its “executable”.

After that, it looks like the Reactant function codegen_xla_call is the next thing to use, right?

Kind of. Honestly, I think the answer here is that we should just add an Ops.hlo_call that takes a string containing a StableHLO module, like we have in Enzyme-JAX (Enzyme-JAX/test/testffi.py at main · EnzymeAD/Enzyme-JAX · GitHub).

Then, with that, you could even make a JAX call with something like this:

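    function jax_call(py_func, args...)
        # Lower with abstract Float32 shapes; JAX can't consume the traced arrays themselves
        specs = map(a -> jax.ShapeDtypeStruct(size(a), numpy.float32), args)
        lowered = jax.jit(py_func).lower(specs...)
        # hlo_call takes the StableHLO module as a Julia String
        Ops.hlo_call(pyconvert(String, lowered.as_text()), args...)
    end

    x = Reactant.ConcreteRArray(ones(Float32, 10))
    jlfn = @compile jax_call(py_func, x)
    jlfn(py_func, x)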

If you’re interested in helping add this let me know and I can help you with the setup!

I’ve spent some time reading up more on Reactant. Very cool!

Once a traced function is compiled with Reactant, is it cheap to call into that function (like a regular Julia function, without the cost of launching a JAX kernel from Python)? I.e., the entire Julia program doesn’t have to be traced by Reactant to get good performance, right?
It looks like the answer is yes, which is great!

Similarly, if I use Reactant for one part of a computation, can I use Enzyme to take gradients through a larger Julia program that happens to contain code compiled through Reactant?

Thanks!

Yeah, it should be super cheap to call into the function. And correct: you compile/trace once and can reuse it as much as you want (i.e., you won’t retrace it).
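Here is a minimal sketch of the compile-once, call-many pattern, assuming Reactant and BenchmarkTools are installed. The function `h` is a hypothetical example. `@compile` traces and compiles once, and the returned thunk is then called like an ordinary Julia function with no retracing.

```julia
using Reactant, BenchmarkTools

# Hypothetical example function, written in ordinary Julia
h(x) = sum(sin.(x) .+ cos.(x))

x = Reactant.to_rarray(ones(Float32, 100))

# Trace and compile once...
h_compiled = @compile h(x)

# ...then reuse the compiled thunk as often as needed
y = h_compiled(x)
@btime $h_compiled($x)
```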

Currently, the setup assumes that Reactant will be on the outside of an Enzyme autodiff (if the autodiff uses Reactant types). We may be able to support the other direction, but you’re going to get significantly better performance by putting Reactant on the outside anyway, so I’m curious about your use case here.
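The following is a sketch of what "Reactant on the outside" looks like. It assumes recent Reactant and Enzyme releases where `Enzyme.gradient` returns a tuple, which matches the outputs later in this thread. The `Enzyme.gradient` call sits inside the function that `@compile` traces; Enzyme is not differentiating an already-compiled thunk. `loss` and `grad_loss` are hypothetical names.

```julia
using Reactant, Enzyme

# Hypothetical scalar loss
loss(x) = sum(sin.(x))

# The Enzyme call lives inside the traced function: Reactant is on the outside
grad_loss(x) = Enzyme.gradient(Reverse, loss, x)

x = Reactant.to_rarray(Float32[0, 1, 2])
grad_compiled = @compile grad_loss(x)

# One gradient per differentiated argument; here it should equal cos.(x)
(dx,) = grad_compiled(x)
```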

Thanks, just trying to get a clear picture of how it’s intended to work!

My main use case is providing a Julia MCMC library to Python users. Ideally:

  • we expose the library using JuliaCall (e.g., like PySR)
  • they provide a JAX JIT-traced Python function
  • we run and autodiff their function using Reactant and Enzyme, ideally on multiple threads

I’m also interested in using some Python Gaussian process libraries written in JAX within a larger Julia program, which might not yet be amenable to full tracing with Reactant (but maybe I’m wrong).

It sounds like almost everything is already in place for this, which is great! If I want to help create a Reactant op for general XLA code, where should I start?

So @Pangoraw just added support for hlo_call in this PR (Add Ops.hlo_call(::String, args...) by Pangoraw · Pull Request #358 · EnzymeAD/Reactant.jl · GitHub).
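To illustrate the new op in isolation, here is a sketch that passes a hand-written StableHLO module straight to `Ops.hlo_call`. It makes two assumptions: that the module's entry function is named `@main` (the name JAX also emits), and that, as in the outputs later in this thread, the call returns a tuple of results. `DOUBLE_HLO` is a hypothetical name.

```julia
using Reactant
using Reactant: Ops

# A tiny StableHLO module: elementwise x + x on a 3-element Float32 vector
const DOUBLE_HLO = """
module {
  func.func @main(%arg0: tensor<3xf32>) -> tensor<3xf32> {
    %0 = stablehlo.add %arg0, %arg0 : tensor<3xf32>
    return %0 : tensor<3xf32>
  }
}
"""

x = Reactant.to_rarray(Float32[1, 2, 3])

# hlo_call yields one entry per result of @main (a single one here)
(y,) = Reactant.@jit Ops.hlo_call(DOUBLE_HLO, x)
```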

Probably the next thing to do in your case is to add a Reactant extension for PythonCall that, when a Python function is called on a TracedRArray, gets the StableHLO out of JAX and emits a corresponding Reactant.Ops.hlo_call instead.

Want to give it a go? We’d be happy to help!

x/ref: Trace over Python · Issue #354 · EnzymeAD/Reactant.jl · GitHub

Okay this is pretty darn cool!

Here’s a simple proof of concept for tracing and JIT-compiling a Python function with JAX, then running it in Julia and differentiating it with Enzyme:

    f, g = jax_from_julia_with_grad(
        """
        def f(x):
            return jax.numpy.sum(jax.numpy.sin(x) + jax.numpy.cos(x))
        """,
        ones(Float32, 100)
    )
    r = Reactant.to_rarray(ones(Float32, 100))

    julia> @time f(r)
      0.000034 seconds (4 allocations: 80 bytes)
    (ConcreteRNumber{Float32}(138.17734f0),)

    julia> @time g(r)
      0.000043 seconds (4 allocations: 96 bytes)
    (ConcreteRArray{Float32, 1}(Float32[-0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868  …  -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868]),)

Code below:

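    using Reactant
    using Reactant: Ops
    using Enzyme
    using PythonCall

    jax = pyimport("jax")
    numpy = pyimport("numpy")

    # Trace the Python function `f` defined in `func_str` with JAX, then compile
    # its StableHLO with Reactant. Assumes Float32 inputs.
    function jax_from_julia_with_grad(func_str, example_inputs)
        # Create namespace and define function
        namespace = pydict()
        namespace["jax"] = jax
        pyexec(func_str, namespace)
        py_func = namespace["f"]

        # Create example input for lowering
        x = numpy.array(example_inputs, dtype=numpy.float32)

        # Lower with JAX and extract the StableHLO module as a Julia String
        lowered = jax.jit(py_func).lower(x)
        hlo = pyconvert(String, lowered.as_text())

        # Forward pass, compiled once by Reactant
        _primal = @compile Ops.hlo_call(hlo, Reactant.to_rarray(example_inputs))

        # Reverse-mode gradient through the HLO call, also compiled by Reactant
        function _grad(i)
            inner(inp) = Ops.hlo_call(hlo, inp)
            return Enzyme.gradient(Reverse, inner, i)
        end
        __grad = @compile _grad(Reactant.to_rarray(example_inputs))

        # `nothing` fills the slot of the HLO string argument, which is already
        # baked into the compiled executable
        return (args...) -> _primal(nothing, args...), __grad
    end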

The above works (shockingly!) well as a way to call a given Python/JAX function without the per-call Python overhead.

I’ve been looking at how this could be integrated more smoothly into PythonCall + Reactant, though, and it’s not clear that PythonCall has the necessary extension points.

The public interface for extending PythonCall is adding new argument conversion rules. Here, we would need something that operates on an entire function call plus any arguments that are TracedRArrays, something like:

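    using Reactant
    using Reactant: Ops
    using Enzyme
    using PythonCall
    jax = pyimport("jax")
    numpy = pyimport("numpy")

    # Attempted overload: intercept Python calls whose arguments are TracedRArrays
    function PythonCall.pycall(f::Py, args::Reactant.TracedRArray...; kwargs...)
        # `.lower(args...)` is itself a Python call on TracedRArrays, so it
        # re-enters this method, presumably the source of the stack overflow
        lowered = jax.jit(f).lower(args...)
        # Intended as placeholders; note numpy.array(size(arg)) builds a vector
        # of the dimensions, not an array of that shape
        inputs = map(args) do arg
            numpy.array(size(arg), dtype=numpy.float32)
        end
        return Ops.hlo_call(
            pyconvert(String, lowered.as_text()),
            inputs...
        )
    end

    f = @compile jax.numpy.sum(
        Reactant.to_rarray(Float32[1, 2, 3]),
    )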

but this is reaching into internals (plus it stack overflows).

I mean, worst case, we could add a python_call function to Reactant that has no definition unless PythonCall is loaded.

Also, for ease, would you mind moving this discussion to the relevant Reactant issue for visibility?
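The "function without a definition" idea follows the usual pattern for Julia package extensions (Julia 1.9+). The main package declares a function with no methods. An extension module, listed under `[weakdeps]` and `[extensions]` in `Project.toml`, adds methods only when the weak dependency (here PythonCall) is loaded. Below is a self-contained sketch of just the dispatch mechanism. `MainPkg` is a hypothetical stand-in for Reactant, and the method body is a trivial placeholder, not the real JAX bridge.

```julia
# Hypothetical stand-in for the main package (e.g. Reactant)
module MainPkg
# Zero-method generic function: declared here, implemented elsewhere
function python_call end
end

# Before any extension is loaded, calling the stub throws a MethodError
stub_error = try
    MainPkg.python_call(identity, 1.0)
    nothing
catch err
    err
end

# What an extension module would do once PythonCall is loaded:
# add a method to the existing stub (placeholder body for illustration)
MainPkg.python_call(f, args...) = f(args...)

result = MainPkg.python_call(+, 1, 2)
```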

Following up here: this now works natively in Reactant, with optimization and differentiation across Julia/Python:

    using Reactant
    using PythonCall
    using Test

    jax = pyimport("jax")

    result = Reactant.@jit jax.numpy.sum(Reactant.to_rarray(Float32[1, 2, 3]))
    @test typeof(result) == ConcreteRNumber{Float32}
    @test result ≈ 6

From Reactant.jl/test/integration/python.jl at main · EnzymeAD/Reactant.jl · GitHub
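Building on that test, here is a sketch that mixes a Julia broadcast and a JAX reduction in one traced function and differentiates through both. It assumes two things: that Reactant's PythonCall integration intercepts `jax.numpy.sum` on traced arrays inside a larger Julia function the same way it does at top level, and that `Enzyme.gradient` returns a tuple. `mixed` and `mixed_grad` are hypothetical names.

```julia
using Reactant, Enzyme, PythonCall

jax = pyimport("jax")

# Julia broadcast followed by a JAX reduction, traced together by Reactant
mixed(x) = jax.numpy.sum(sin.(x))
mixed_grad(x) = Enzyme.gradient(Reverse, mixed, x)

x = Reactant.to_rarray(Float32[1, 2, 3])
value = Reactant.@jit mixed(x)

# Gradient of sum(sin.(x)), expected to match cos.(x)
(dx,) = Reactant.@jit mixed_grad(x)
```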

This is amazing, thanks @wsmoses and team!

I also notice that the overhead has decreased significantly since the experiments I posted above. Those had a latency floor of about 30 µs, and this example, when compiled, is more like 1.5 µs. Amazing!
