Okay this is pretty darn cool!
Simple proof of concept for tracing and JIT compiling a Python function with JAX, then running it in Julia and autodiffing it with Enzyme:
f, g = jax_from_julia_with_grad(
    """
def f(x):
    return jax.numpy.sum(jax.numpy.sin(x) + jax.numpy.cos(x))
""",
    ones(Float32, 100)
)
julia> r = Reactant.to_rarray(ones(Float32, 100));

julia> @time f(r)
  0.000034 seconds (4 allocations: 80 bytes)
(ConcreteRNumber{Float32}(138.17734f0),)

julia> @time g(r)
  0.000043 seconds (4 allocations: 96 bytes)
(ConcreteRArray{Float32, 1}(Float32[-0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868 … -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868, -0.30116868]),)
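Quick sanity check against plain Julia (my own addition, not part of the original snippet): for x = ones(Float32, 100) the exact value is 100*(sin(1) + cos(1)) ≈ 138.17734, and the gradient of sum(sin.(x) .+ cos.(x)) is cos.(x) .- sin.(x) ≈ -0.30116868 per element, exactly what comes back above:

x = ones(Float32, 100)
f_ref(x) = sum(sin.(x) .+ cos.(x))   # same function, written directly in Julia
g_ref(x) = cos.(x) .- sin.(x)        # analytic gradient of f_ref
f_ref(x)      # 138.17734f0
g_ref(x)[1]   # -0.30116868f0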
Code below:
using Reactant
using Reactant: Ops
using Enzyme
using PythonCall

const jax = pyimport("jax")
const numpy = pyimport("numpy")

function jax_from_julia_with_grad(func_str, example_inputs)
    # Define the Python function `f` in a fresh namespace
    namespace = pydict()
    namespace["jax"] = jax
    pyexec(func_str, namespace)
    py_func = namespace["f"]

    # Example input, used only to trace/lower the JAX function
    x = numpy.array(example_inputs, dtype=numpy.float32)

    # Lower the jitted function to StableHLO text
    lowered = jax.jit(py_func).lower(x)

    # Compile a Reactant thunk that calls directly into the lowered HLO
    _primal = @compile Ops.hlo_call(
        pyconvert(String, lowered.as_text()),
        Reactant.to_rarray(example_inputs)
    )

    # Reverse-mode gradient of the same HLO call, via Enzyme
    function _grad(i)
        function inner(inp)
            Ops.hlo_call(
                pyconvert(String, lowered.as_text()),
                inp
            )
        end
        return Enzyme.gradient(Reverse, inner, i)
    end
    __grad = @compile _grad(
        Reactant.to_rarray(example_inputs)
    )

    # The HLO string was baked in at compile time, so pass `nothing` in its slot
    return (args...) -> _primal(nothing, args...), __grad
end
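Nothing here is specific to sin/cos; any single-argument JAX function named f should go through the same path. A quick sketch with softplus (this example is mine, not from the original post):

f2, g2 = jax_from_julia_with_grad(
    """
def f(x):
    return jax.numpy.sum(jax.numpy.log1p(jax.numpy.exp(x)))
""",
    ones(Float32, 100)
)
r = Reactant.to_rarray(ones(Float32, 100))
f2(r)  # sum of softplus over 100 ones, ≈ 131.33
g2(r)  # elementwise sigmoid(1) ≈ 0.7310586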