I’m trying to call JAX from Julia, and I find that PythonCall.jl adds a lot of overhead in a simple conversion case (a Julia `Vector{Float32}` passed to `jnp.array`), apparently due to… printing? Has anyone run into the same issue?
Setup code:
```julia
using BenchmarkTools, CondaPkg, PythonCall

# Install the Python dependencies into the CondaPkg environment
CondaPkg.add("numpy")
CondaPkg.add_pip("jax")

np = pyimport("numpy")
jnp = pyimport("jax.numpy")
x = rand(Float32, 1000);
```
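OK, so it appears that `jnp.array(x)` is calling `str(x)` and `repr(x)` several times for some reason. Since `x` reaches Python as PythonCall's wrapper around a Julia array, each of those calls presumably has to format the whole vector on the Julia side, which would explain the cost. I wonder if `jnp.array` tries a bunch of ways to convert `x` to an array, and the earlier attempts involve throwing (then catching) an exception whose message includes `x`.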
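A quick way to check this, sketched below (untested here; the `pyrepr` comparison and the `np.asarray` workaround are my own assumptions, not something I've confirmed in JAX's source), is to time one Python-side `repr` of the wrapped vector against the direct call, and then hand JAX a plain NumPy array so it never sees the Julia wrapper at all:

```julia
using BenchmarkTools, PythonCall

# Assumes numpy and jax are already installed in the CondaPkg environment
np  = pyimport("numpy")
jnp = pyimport("jax.numpy")
x = rand(Float32, 1000);

# Cost of a single Python-side repr of the wrapped Julia vector, for scale
@btime pyrepr($x);

# Direct call: JAX receives PythonCall's wrapper around the Julia array
@btime jnp.array($x);

# Hedged workaround: give JAX a NumPy array instead. np.asarray is expected
# to reuse the Julia buffer via the buffer protocol rather than copy it.
@btime jnp.array(np.asarray($x));
```

If the direct call costs roughly a small multiple of one `pyrepr(x)` while the NumPy route is much cheaper, that would support the exception-message theory and point at the wrapper path rather than the data transfer itself.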