Calling a JAX-JIT-compiled Python function from Julia without overhead

Okay, this is pretty darn cool!

A simple proof of concept for tracing and JIT-compiling a Python function with JAX, then running it in Julia and autodiffing it with Enzyme:

# Execute the Python source below, lower `f` with JAX, and compile both the function
# (`f`) and its gradient (`g`) with Reactant, using a 100-element Float32 vector as
# the example input.
f, g = jax_from_julia_with_grad(
    """
    def f(x):
        return jax.numpy.sum(jax.numpy.sin(x) + jax.numpy.cos(x))
    """,
    ones(Float32, 100)
)
# Build a Reactant array with the same shape and element type as the example input
# so it can be passed to the compiled thunks.
r = Reactant.to_rarray(ones(Float32, 100))

julia> @time f(r)
  0.000034 seconds (4 allocations: 80 bytes)
(ConcreteRNumber{Float32}(138.17734f0),)

julia> @time g(r)
  0.000043 seconds (4 allocations: 96 bytes)
(ConcreteRArray{Float32, 1}(Float32[-0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868  …  -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868]),)

The code is below (it uses the Python modules `jax` and `numpy` through PythonCall):

using Reactant
using Reactant: Ops
using Enzyme
using PythonCall
"""
    jax_from_julia_with_grad(func_str, example_inputs; name="f")

Execute Python source defining a JAX function, lower it with `jax.jit(...).lower`, and
compile both the function and its reverse-mode gradient with Reactant so they can be
called from Julia.

# Arguments
- `func_str`: Python source that defines a function called `name`. The module `jax` is
  available in the namespace the source is executed in.
- `example_inputs`: example array used to trace the JAX function (converted to a
  float32 NumPy array) and to build the Reactant arrays the thunks are compiled against.

# Keywords
- `name="f"`: name of the Python function to look up after executing `func_str`.

# Returns
A tuple `(primal, grad)`:
- `primal(args...)` calls the compiled `Ops.hlo_call` thunk on Reactant arrays.
- `grad` is the Reactant-compiled thunk computing `Enzyme.gradient(Reverse, ...)` of the
  lowered function with respect to its single input.
"""
function jax_from_julia_with_grad(func_str, example_inputs; name::AbstractString="f")
    # Load the Python modules here so the function does not depend on globals named
    # `jax` / `numpy` that would otherwise have to be defined elsewhere.
    jax = pyimport("jax")
    np = pyimport("numpy")

    # Execute the user source in a fresh namespace that exposes `jax`, then fetch the
    # requested function from it.
    namespace = pydict()
    namespace["jax"] = jax
    pyexec(func_str, namespace)
    py_func = namespace[name]

    # Tracing input. NOTE(review): the dtype is fixed to float32, presumably to match
    # JAX's default (x64 disabled). Example inputs with other element types will likely
    # produce an HLO signature that does not match the Reactant arrays. Confirm before
    # generalizing.
    x = np.array(example_inputs, dtype=np.float32)

    # Lower once and convert the HLO text to a Julia String once. Both compiled thunks
    # reuse it instead of re-querying the Python object (including during tracing).
    hlo_text = pyconvert(String, jax.jit(py_func).lower(x).as_text())

    primal = @compile Ops.hlo_call(
        hlo_text,
        Reactant.to_rarray(example_inputs)
    )

    # Differentiate through the embedded HLO call with Enzyme's reverse mode.
    function _grad(i)
        inner(inp) = Ops.hlo_call(hlo_text, inp)
        return Enzyme.gradient(Reverse, inner, i)
    end
    grad = @compile _grad(
        Reactant.to_rarray(example_inputs)
    )

    # The compiled primal thunk is called with `nothing` in the position of the HLO
    # string argument, exactly as in the original code. NOTE(review): presumably the
    # string is baked in at compile time and only array arguments matter. Confirm
    # against Reactant's `@compile` semantics.
    return (args...) -> primal(nothing, args...), grad
end

2 Likes