It would be very useful to be able to call JIT-compiled JAX functions from Python libraries in Julia and, in turn, for Julia libraries to call user-provided JAX functions.
This is possible today via PythonCall, e.g.:
```julia
using PythonCall

jax = pyimport("jax")
numpy = pyimport("numpy")

# Define a simple test function as Python source
# (the body must be indented, or Python raises an IndentationError)
func_str = """
def f(x):
    return jax.numpy.sin(x) + jax.numpy.cos(x)
"""

# Create a namespace that exposes `jax`, then define the function in it
namespace = pydict()
namespace["jax"] = jax
pyexec(func_str, namespace)
py_func = namespace["f"]

# Create test input
x = numpy.array([1.0], dtype=numpy.float32)

# Get the lowered representation and print it as text
lowered = jax.jit(py_func).lower(x)
println("HLO text:")
println(lowered.as_text())

# Also test compilation and execution through Python
compiled = lowered.compile()
@time result = compiled(x)
println("\nTest execution result: ", result)
println("\nSum of result: ", sum(result))
```
The problem, of course, is that we pay about 7 µs of Python overhead per call, because launching the JAX function (the compiled(x) call) still goes through Python.
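To separate JAX's own dispatch cost from the Julia-to-Python crossing, one could time the same compiled function from pure Python. This is a minimal sketch, assuming a recent JAX with the `jit(...).lower(...).compile()` API; the function and the call count are illustrative. The number it reports excludes PythonCall entirely, so it is only a baseline to compare against the ~7 µs seen from Julia.

```python
import timeit

import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) + jnp.cos(x)


x = jnp.array([1.0], dtype=jnp.float32)

# Lower and compile once, ahead of time, mirroring the Julia snippet.
compiled = jax.jit(f).lower(x).compile()

# Warm-up call; block_until_ready() waits for the asynchronously dispatched result.
out = compiled(x).block_until_ready()

# Average wall-clock time per call over many calls, in microseconds.
n_calls = 10_000
total_s = timeit.timeit(lambda: compiled(x).block_until_ready(), number=n_calls)
per_call_us = total_s / n_calls * 1e6
print(f"Python-side time per call: {per_call_us:.2f} us")
```

Averaging over many calls avoids the noise of timing a single call, which is what the `@time` line in the Julia snippet does.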
It seems like it should be possible to call the JIT-compiled function directly. In the worst case, it should be possible to export the compiled JAX function and then run it without involving Python.
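For the export route, JAX's `jax.export` module (public in JAX 0.4.30 and later, an assumption about the installed version) can serialize a jitted function to StableHLO. Below is a minimal sketch, assuming a non-Python PJRT client would be the eventual consumer of `mlir_module_serialized`; I have not verified which PJRT plugins accept that byte format directly, and the Python round trip at the end only checks that the export itself works.

```python
import jax
import jax.numpy as jnp
from jax import export


def f(x):
    return jnp.sin(x) + jnp.cos(x)


# Describe the input by shape and dtype only; no concrete data is needed.
arg_spec = jax.ShapeDtypeStruct((1,), jnp.float32)

# Trace and lower f into an Exported object wrapping a StableHLO module.
exported = export.export(jax.jit(f))(arg_spec)

# Human-readable StableHLO (MLIR text).
module_text = exported.mlir_module()
print(module_text)

# Serialized StableHLO bytecode. These bytes could be written to disk and
# handed to a PJRT client outside Python (assumption, not shown here).
stablehlo_bytes = exported.mlir_module_serialized

# Round-trip check from Python: serialize the whole Exported, reload, call it.
restored = export.deserialize(exported.serialize())
y = restored.call(jnp.array([1.0], dtype=jnp.float32))
print(y)
```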
Here is a relevant GitHub discussion: "Calling pre-compiled JAX code from C++" (jax-ml/jax, Discussion #22184).
Has anyone tried this, or does anyone have suggestions on how to proceed? I guess we would need to get the PJRT C library built and added to Yggdrasil?