You generally should not use @jit from Reactant outside of one-off tests. Instead, write f = @compile foo(…) and then call the compiled function f(…).
@MilesCranmer on my laptop, I see the following. We can probably bring the call-setup overhead down further (though note that the overhead comes from invoking the outermost function, so you won’t see any such overhead when calling JAX code from within a Julia function – only from the outermost setup itself).
julia> using Reactant
julia> using PythonCall: pyimport
julia> using BenchmarkTools: @btime
julia> jax = pyimport("jax")
Python: <module 'jax' from '/Users/wmoses/git/Reactant.jl/py/.CondaPkg/.pixi/envs/default/lib/python3.12/site-packages/jax/__init__.py'>
julia> jax_sum(x) = Reactant.@jit jax.numpy.sum(x)
jax_sum (generic function with 1 method)
julia> @btime sum(x) setup=(x=randn(Float32,100))
6.083 ns (0 allocations: 0 bytes)
-7.1881f0
julia> @btime jax_sum(x) setup=(x=Reactant.to_rarray(randn(Float32,100)))
27.854 ms (1615 allocations: 451.13 KiB)
Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(-15.242416f0)
julia> jax_sum2 = Reactant.@compile sync=true jax.numpy.sum(Reactant.to_rarray(randn(Float32,100)))
Reactant compiled function <function sum at 0x35dfb9300> (with tag ##<function sum at 0x35dfb9300>_reactant#262)
julia> @btime jax_sum2(x) setup=(x=Reactant.to_rarray(randn(Float32,100)))
1.708 μs (16 allocations: 704 bytes)
Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(2.3117552f0)
And yes, you can call Reactant.to_rarray(0.0; track_numbers=true), or alternatively just ConcreteRNumber(0.0).
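A minimal, untested sketch of those two equivalent ways to wrap a scalar (variable names are just for illustration; ConcreteRNumber is assumed to be exported by Reactant, as used above):

using Reactant

r1 = Reactant.to_rarray(0.0; track_numbers=true)  # scalars are only tracked with track_numbers=true
r2 = ConcreteRNumber(0.0)                         # shorthand for the same thing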
Just to elaborate, that’s because @jit does two things: it compiles the function (what @compile does) and then calls it. If you benchmark a @jit call, even with @btime, you always measure Reactant’s compilation plus the first-call latency of the (re-)compiled function. That’s why separating @compile from the function call gives much more sensible timings.
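To make the recommended pattern concrete, here is a minimal, untested sketch (foo is a hypothetical placeholder function, not from the session above):

using Reactant

foo(x) = sum(abs2, x)                        # hypothetical function to compile
x = Reactant.to_rarray(randn(Float32, 100))

compiled_foo = Reactant.@compile foo(x)      # pay the compilation cost once, up front
compiled_foo(x)                              # subsequent calls skip recompilation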
Thanks. I’ve been using Julia for so long that I forgot JAX needs static shapes for compilation.
Are there any static-shape versions of ConcretePJRTArray so that the shape specialization can be done automatically?
julia> jax_sum_32 = Rx.@compile sync=true jax.numpy.sum(Rx.to_rarray(randn(Float32,32)))
Reactant compiled function <function sum at 0x340b8c860> (with tag ##<function sum at 0x340b8c860>_reactant#250)
julia> jax_sum_32(Rx.to_rarray(randn(Float32,64)))
ERROR: INVALID_ARGUMENT: Executable expected parameter 0 of size 128 but got buffer with incompatible size 256
It might be nice to build a generic jax_sum(x) that automatically compiles for each shape, similar to how the @jax.jit function wrapper does it.
Originally we had the shape in the type, but we removed it so that we could eventually support dynamic shapes (which is planned, but we only recently started updating our optimizations to support it).
But you can always build a separate cache keyed on the sizes that’s smarter than the current jit macro (contributions welcome).
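For anyone wanting to prototype that, here is a rough, untested sketch (not Reactant’s API; jax_sum_cached and _sum_cache are hypothetical names) of a wrapper that compiles once per element type and shape, similar in spirit to what @jax.jit does:

using Reactant
using PythonCall: pyimport

const jax = pyimport("jax")
const _sum_cache = Dict{Any,Any}()           # (eltype, size) => compiled executable

function jax_sum_cached(x)
    key = (eltype(x), size(x))
    compiled = get!(_sum_cache, key) do
        Reactant.@compile jax.numpy.sum(x)   # compile once for this shape, reuse afterwards
    end
    return compiled(x)
end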