PythonCall spends a lot of time showing objects when converting arrays for JAX

Hi everyone, especially @cjdoris!

I’m trying to call JAX from Julia and I find that PythonCall.jl has a lot of overhead in a simple conversion case due to… printing? Has anyone run into the same issue?

Setup code:

using BenchmarkTools, CondaPkg, PythonCall
CondaPkg.add("numpy")
CondaPkg.add_pip("jax")
np = pyimport("numpy")
jnp = pyimport("jax.numpy")
x = rand(Float32, 1000);

Benchmark:

julia> @btime $(np.array)($x);  # fast
  2.049 μs (17 allocations: 464 bytes)

julia> @btime $(jnp.array)($x);  # slow
  2.400 ms (8386 allocations: 524.22 KiB)

Profiling:

julia> @profview for _ in 1:100; jnp.array(x); end  # @profview from ProfileView.jl (or the Julia VS Code extension)


Weird, it seems like some conversion is going through repr rather than using the Python buffer interface. What if you call jnp.asarray(x)?
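That is, the same benchmark as above but with the suggested call (untested on my end, jnp.asarray is just the variant I mean):

julia> @btime $(jnp.asarray)($x);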


Ok so it appears that jnp.array(x) is calling str(x) and repr(x) several times for some reason. I wonder if jnp.array tries a bunch of ways to convert x to an array, and the earlier tries involve throwing (then catching) an exception whose message includes x.
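One way to test that hypothesis is to time Python's repr of the wrapped array directly (a sketch; Py and pyrepr are PythonCall's wrapping and repr helpers, and x is the 1000-element vector from above):

julia> px = Py(x);  # wrap the Julia array as a Python object (no copy)

julia> @btime pyrepr(String, $px);

If a single repr already costs hundreds of microseconds, a handful of str/repr round-trips would account for the millisecond-scale overhead.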

Same results with jnp.asarray, unfortunately.

That’s probably what happens. By converting with np.array first and then calling jnp.array on the result, I cut the overhead by roughly a factor of 20 (2.4 ms down to ~105 μs):

julia> @btime $(np.array)($x);  # fast
  3.073 μs (17 allocations: 464 bytes)

julia> @btime $(jnp.array)($(np.array(x)));  # np.array(x) is interpolated, so its cost is excluded from the timing
  105.150 μs (2 allocations: 32 bytes)
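If this two-step conversion turns out to be the best option, a tiny helper keeps it tidy (to_jax is a hypothetical name, not part of any package):

# hypothetical helper: go through numpy first so JAX sees a buffer-protocol
# object instead of probing the wrapped Julia array
to_jax(v::AbstractArray) = jnp.array(np.array(v))

julia> y = to_jax(x);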

Of course the question is how much better this can get (it is a pretty crucial operation in DifferentiationInterfaceJAX.jl: https://github.com/gdalle/DifferentiationInterfaceJAX.jl).

EDIT: Perhaps I can get away with using only np.array, will report back.
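The idea being something like this (a sketch, relying on the fact that JAX functions accept numpy arrays and convert array-like inputs internally):

xnp = np.array(x)   # cheap: goes through the Python buffer interface
jnp.sum(xnp)        # JAX converts the numpy array itself on first use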
