PythonCall spends a lot of time showing stuff for JAX

Hi everyone, especially @cjdoris!

I’m trying to call JAX from Julia and I find that PythonCall.jl has a lot of overhead in a simple conversion case due to… printing? Has anyone run into the same issue?

Setup code:

using BenchmarkTools, CondaPkg, PythonCall
CondaPkg.add("numpy")
CondaPkg.add_pip("jax")
np = pyimport("numpy")
jnp = pyimport("jax.numpy")
x = rand(Float32, 1000);

Benchmark:

julia> @btime $(np.array)($x);  # fast
  2.049 μs (17 allocations: 464 bytes)

julia> @btime $(jnp.array)($x);  # slow
  2.400 ms (8386 allocations: 524.22 KiB)

Profiling:

julia> @profview for _ in 1:100; jnp.array(x); end

Weird, it seems like some conversion is going through repr rather than using the Python buffer interface. What if you call jnp.asarray(x)?
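For context, the fast path here is the buffer protocol: any object exposing it can be wrapped by numpy without per-element conversion. A minimal Python sketch, using the stdlib array module as a stand-in for the memory PythonCall wraps around a Julia Vector{Float32}:

```python
import array
import numpy as np

# An object exposing the Python buffer protocol, standing in for
# the wrapper PythonCall puts around a Julia Vector{Float32}.
raw = array.array("f", [1.0, 2.0, 3.0])

# np.asarray goes through the buffer protocol: no per-element
# conversion, and the result shares memory with `raw`.
arr = np.asarray(raw)
arr[0] = 42.0
print(raw[0])  # the write is visible through the original buffer
```

If jnp.array instead falls back to a generic path that stringifies the argument, the cost scales with the printed representation rather than staying near-constant like the buffer wrap.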

Ok so it appears that jnp.array(x) is calling str(x) and repr(x) several times for some reason. I wonder if jnp.array tries a bunch of ways to convert x to an array, and the earlier tries involve throwing (then catching) an exception whose message includes x.
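The mechanism would look something like this: merely building an exception message forces repr on the argument, even if the exception is caught right away. A small illustration (the Tracked class is hypothetical, only there to count the calls):

```python
class Tracked:
    """Hypothetical object that counts how often it is repr'd."""
    def __init__(self):
        self.repr_calls = 0
    def __repr__(self):
        self.repr_calls += 1
        return "Tracked()"

x = Tracked()
try:
    # Formatting the message already calls repr(x), before the
    # exception is raised, let alone caught.
    raise TypeError(f"cannot convert {x!r} to an array")
except TypeError:
    pass

print(x.repr_calls)  # 1: the repr work happened despite the catch
```

For a wrapped 1000-element Julia vector, each such stringification walks the whole array, which would explain the milliseconds of overhead.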

Same results unfortunately.

That’s probably what happens. By first converting with np.array and then calling jnp.array on the result, I cut the overhead by roughly a factor of 20:

julia> @btime $(np.array)($x);  # fast
  3.073 μs (17 allocations: 464 bytes)

julia> @btime $(jnp.array)($(np.array(x)));
  105.150 μs (2 allocations: 32 bytes)

Of course the question is how much better this can get (it is a pretty crucial operation in gdalle/DifferentiationInterfaceJAX.jl).
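One way to package the workaround (purely a sketch; to_jax is a hypothetical helper, and in the real package the jnp module would come from pyimport):

```python
import numpy as np

def to_jax(x, jnp):
    # Hypothetical helper: route the conversion through numpy's
    # buffer-protocol fast path first, so jax.numpy receives an
    # ndarray instead of an unfamiliar wrapper object and skips
    # its slower generic fallback.
    return jnp.asarray(np.asarray(x))
```

On the Julia side the equivalent one-liner via PythonCall would be jnp.asarray(np.asarray(x)).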

EDIT: Perhaps I can get away with using only np.array, will report back.