Calling Python function JIT compiled with JAX from Julia without overhead

You generally should not use Reactant's @jit outside of one-off tests. Instead, write f = @compile foo(…) and then call the compiled function f(…).

@MilesCranmer on my laptop, I see the following. We can probably bring the call-setup overhead down further (though also note that the overhead comes from calling the outermost function, so you won't see any such overhead when calling JAX code from within a Julia function – only from the outermost setup itself).

julia> using Reactant

julia> using PythonCall: pyimport

julia> using BenchmarkTools: @btime

julia> jax = pyimport("jax")
Python: <module 'jax' from '/Users/wmoses/git/Reactant.jl/py/.CondaPkg/.pixi/envs/default/lib/python3.12/site-packages/jax/__init__.py'>

julia> jax_sum(x) = Reactant.@jit jax.numpy.sum(x)
jax_sum (generic function with 1 method)

julia> @btime sum(x) setup=(x=randn(Float32,100))
  6.083 ns (0 allocations: 0 bytes)
-7.1881f0

julia> @btime jax_sum(x) setup=(x=Reactant.to_rarray(randn(Float32,100)))
  27.854 ms (1615 allocations: 451.13 KiB)
Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(-15.242416f0)

julia> jax_sum2 = Reactant.@compile sync=true jax.numpy.sum(Reactant.to_rarray(randn(Float32,100)))
Reactant compiled function <function sum at 0x35dfb9300> (with tag ##<function sum at 0x35dfb9300>_reactant#262)

julia> @btime jax_sum2(x) setup=(x=Reactant.to_rarray(randn(Float32,100)))
  1.708 μs (16 allocations: 704 bytes)
Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(2.3117552f0)

And yes you can call Reactant.to_rarray(0.0; track_numbers=true), or alternatively just ConcreteRNumber(0.0)


Just to elaborate: that's because @jit does two things – it compiles the function (what @compile does) and then calls it. If you benchmark a @jit call, even with @btime, you always measure Reactant's compilation plus the first-call latency of the (re-)compiled function. That's why separating @compile from the function call gives much more sensible timings.


Thanks. I’ve been using Julia for so long that I forgot JAX needs static shapes for compilation.

Are there any static shape versions of ConcretePJRTArray so that the shape specialization can be done automatically?

julia> jax_sum_32 = Rx.@compile sync=true jax.numpy.sum(Rx.to_rarray(randn(Float32,32)))
Reactant compiled function <function sum at 0x340b8c860> (with tag ##<function sum at 0x340b8c860>_reactant#250)

julia> jax_sum_32(Rx.to_rarray(randn(Float32,64)))
ERROR: INVALID_ARGUMENT: Executable expected parameter 0 of size 128 but got buffer with incompatible size 256

It might be nice to build a generic jax_sum(x) that automatically compiles for each shape, similar to what the @jax.jit function wrapper does.


Originally we had the shape in the type, but we removed it so that we could support dynamic shapes in the future (this is planned, but we only recently started updating our optimizations to support it).

But you can always build a separate cache keyed on the input sizes that's smarter than the current @jit macro (contributions welcome).
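Such a shape-keyed cache could be sketched as below. This is a hypothetical helper, not part of Reactant's API: `ShapeCache` and its fields are made up for illustration, and the `compile` argument stands in for a closure around Reactant.@compile, which is invoked once per distinct (eltype, size) key and whose result is reused on later calls with the same shape.

```julia
# Hypothetical shape-keyed compile cache (not part of Reactant's API).
# `f` is the function to specialize; `compile` produces a shape-specialized
# callable for a sample input, e.g. x -> Reactant.@compile sync=true f(x).
struct ShapeCache{F,C}
    f::F
    compile::C
    cache::Dict{Tuple{DataType,Tuple},Any}  # (eltype, size) => compiled callable
end

ShapeCache(f, compile) = ShapeCache(f, compile, Dict{Tuple{DataType,Tuple},Any}())

# Make the cache callable: compile on a cache miss, reuse on a hit.
function (sc::ShapeCache)(x)
    key = (eltype(x), size(x))
    g = get!(() -> sc.compile(x), sc.cache, key)
    return g(x)
end
```

With Reactant, one would then construct something like `ShapeCache(jax.numpy.sum, x -> Reactant.@compile sync=true jax.numpy.sum(x))` and call it with arrays of any shape; each new shape pays the compilation cost once, as in the @compile timings above, while repeat shapes hit the fast path.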