It would be very useful to be able to call JIT-compiled JAX functions from Python libraries in Julia, and, in turn, for Julia libraries to be able to call user-provided JAX functions.
This is possible today via PythonCall, e.g.:
    using PythonCall
    jax = pyimport("jax")
    numpy = pyimport("numpy")
    
    # Define a simple test function
    func_str = """
    def f(x):
        return jax.numpy.sin(x) + jax.numpy.cos(x)
    """
    
    # Create namespace and define function
    namespace = pydict()
    namespace["jax"] = jax
    pyexec(func_str, namespace)
    py_func = namespace["f"]
    
    # Create test input
    x = numpy.array([1.0], dtype=numpy.float32)
    
    # Get lowered representation
    lowered = jax.jit(py_func).lower(x)
    println("HLO text:")
    println(lowered.as_text())
    
    # Also test compilation and execution through Python
    compiled = lowered.compile()
    @time result = compiled(x)
    println("\nTest execution result: ", result)
    println("\nSum of result: ", sum(result))
The problem, of course, is that there is Python overhead of about 7 µs per call, because launching the JAX function still goes through Python (the compiled(x) call).
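A single @time on the first call can include one-time costs, so a benchmark loop gives a steadier per-call number. The sketch below assumes BenchmarkTools.jl is installed (it is not used in the snippet above) and redefines the same test function so it runs on its own. Calling block_until_ready() makes each sample wait for JAX's asynchronous dispatch to finish, so this measures end-to-end call latency rather than launch time alone.

```julia
using PythonCall
using BenchmarkTools  # assumption: BenchmarkTools.jl is installed

jax = pyimport("jax")
numpy = pyimport("numpy")

# Same test function as above, defined in a Python namespace
namespace = pydict()
namespace["jax"] = jax
pyexec("""
def f(x):
    return jax.numpy.sin(x) + jax.numpy.cos(x)
""", namespace)

x = numpy.array([1.0], dtype=numpy.float32)
compiled = jax.jit(namespace["f"]).lower(x).compile()

# Warm-up call so one-time costs are not part of the measurement
compiled(x)

# $-interpolation avoids timing global-variable lookups;
# @belapsed returns the minimum time per call in seconds
t = @belapsed $compiled($x).block_until_ready()
println("per-call time: ", round(t * 1e6; digits=2), " µs")
```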
It seems like it should be possible to call the JIT-compiled function directly. In the worst case, it should be possible to export the compiled JAX function and then run it without involving Python at all.
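A minimal sketch of the export route, assuming a JAX version that ships the public jax.export module (assumed ≥ 0.4.30). jax.export.export traces and lowers the jitted function for the shape and dtype of x, mlir_module() returns the StableHLO text, and serialize() produces a portable byte artifact. The file name f_exported.bin is hypothetical. The round trip at the end still goes through Python (deserialize, then call); loading the artifact from a PJRT client without Python is the open question here and is not shown.

```julia
using PythonCall

jax = pyimport("jax")
jax_export = pyimport("jax.export")  # assumption: JAX version with public jax.export
numpy = pyimport("numpy")

# Same test function as above, defined in a Python namespace
namespace = pydict()
namespace["jax"] = jax
pyexec("""
def f(x):
    return jax.numpy.sin(x) + jax.numpy.cos(x)
""", namespace)
x = numpy.array([1.0], dtype=numpy.float32)

# Trace and lower for the shape/dtype of x; returns an Exported object
exported = jax_export.export(jax.jit(namespace["f"]))(x)

# StableHLO text of the exported module
hlo_text = pyconvert(String, exported.mlir_module())
println(hlo_text)

# Write the serialized artifact (a Python bytearray) via Python's file API
path = "f_exported.bin"  # hypothetical file name
fh = pybuiltins.open(path, "wb")
try
    fh.write(exported.serialize())
finally
    fh.close()
end

# Round trip (still through Python): read back, deserialize, and call
fh = pybuiltins.open(path, "rb")
blob = try
    pybuiltins.bytearray(fh.read())
finally
    fh.close()
end
restored = jax_export.deserialize(blob)
result = restored.call(x)
println("Round-trip result: ", result)
```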
Here is a relevant GitHub discussion: "Calling pre-compiled JAX code from C++" (jax-ml/jax, Discussion #22184).
Has anyone tried this, or does anyone have suggestions on how to proceed? I guess we would need to get the PJRT C library built and added to Yggdrasil?