Using Reactant.jl with an existing test framework

Hello,

I have a generic function that tests some functionality of a package. It does something like this:

"""
    shared_function(T, array_type)

Backend-agnostic check. Build a random `10×10` matrix with element type `T` on the
host, convert it with `array_type` (e.g. `Array`, `CuArray`, `MtlArray`,
`ConcreteRArray`), and `@test` that `some_function` returns the same result for
the converted array and for the host array.
"""
function shared_function(T, array_type)
    x = rand(T, 10, 10)
    dx = array_type(x)
    # Compare with `==`: a single `=` here would be a definition/assignment,
    # not an equality check, so nothing would actually be tested.
    @test some_function(dx) == some_function(x)
end

This way I can test many backends just by changing `array_type` to, e.g., `Array`, `CuArray`, `MtlArray`, or `ConcreteRArray`.

However, Reactant.jl needs to compile the function first:

# With Reactant, `some_function` must first be compiled for the argument `dx`;
# `comp_f` is the compiled callable that is then invoked on `dx`.
# NOTE(review): exact `==` may be too strict if the result is floating point;
# consider `≈` in that case.
comp_f = @compile some_function(dx)
@test comp_f(dx) == some_function(x)

How should I handle this? Do I really have to add `@compile` to every function? That seems quite annoying.

3 Likes

Just want to say that I have the same question as @albertomercurio, for exactly the same reason :slight_smile:

You could just use `@jit` in tests; that’s completely fine.

It doesn’t seem to work:

# Generic backend check: draw 100 random values of type `T` on the host, convert
# them with `array_type`, and require the two sums to be *exactly* equal (`==`).
# NOTE(review): exact equality can fail for floating-point sums computed on a
# different backend; this snippet keeps `==` on purpose.
function test_fun(T, array_type)
    x = rand(T, 100)
    dx = x |> array_type
    return @test sum(dx) == sum(x)
end
julia> test_fun(Float32, Array)
Test Passed

julia> test_fun(Float32, Reactant.to_rarray)
Test Failed at REPL[9]:4
  Expression: sum(dx) == sum(x)
   Evaluated: ConcretePJRTNumber{Float32, 1}(47.982807f0) == 47.98281f0

ERROR: There was an error during testing

julia> test_fun_compiled = @compile test_fun(Float32, Reactant.to_rarray)
ERROR: Cannot convert TracedRArray to ConcreteArray
Stacktrace:
  [1] call_with_reactant(::Reactant.EnsureReturnType{Union{}}, ::typeof(error), ::String)
    @ Reactant ~/.julia/packages/Reactant/AO3KW/src/utils.jl:1064
  [2] call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Reactant.to_rarray_internal), ::Reactant.TracedRArray{…}, ::Type{…}, ::Reactant.Sharding.NoSharding, ::Val{…}, ::Nothing, ::Nothing)
    @ Reactant ~/.julia/packages/Reactant/AO3KW/src/utils.jl:1064
  [3] #to_rarray#153
    @ ~/.julia/packages/Reactant/AO3KW/src/Tracing.jl:2050 [inlined]
  [4] to_rarray
    @ ~/.julia/packages/Reactant/AO3KW/src/Tracing.jl:2040 [inlined]
  [5] test_fun
    @ ./REPL[9]:3
  [6] call_with_reactant(::typeof(test_fun), ::Type{Float32}, ::typeof(Reactant.to_rarray))
    @ Reactant ~/.julia/packages/Reactant/AO3KW/src/utils.jl:1064
  [7] make_mlir_fn(f::typeof(test_fun), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/AO3KW/src/TracedUtils.jl:355
  [8] make_mlir_fn
    @ ~/.julia/packages/Reactant/AO3KW/src/TracedUtils.jl:284 [inlined]
  [9] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(test_fun), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:1740
 [10] compile_mlir!
    @ ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:1702 [inlined]
 [11] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:3708
 [12] compile_xla
    @ ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:3680 [inlined]
 [13] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:3802
 [14] top-level scope
    @ ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:2850
Some type information was truncated. Use `show(err)` to see complete types.

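Currently `to_rarray` isn’t allowed inside a compiled function, which is what causes your last error. We should fix this, so please file an issue on GitHub; in the interim, just don’t do that. (Your first failure is a separate issue: it compares floating-point sums computed on different backends with exact `==`, so use `≈` instead.)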

Not tested (on a plane atm), but I assume something like this ought to work?

# Backend-generic test that also covers Reactant. Ordinary arrays are summed
# directly. Reactant arrays (`ConcreteRArray`) must go through `Reactant.@jit`,
# which compiles and runs `sum` for them. Results are compared with `≈` because
# floating-point sums may differ slightly across backends.
function test_fun(T, array_type)
    x = rand(T, 100)
    dx = array_type(x)
    if dx isa ConcreteRArray
        # Reactant arrays need the compiled path.
        @test Reactant.@jit(sum(dx)) ≈ sum(x)
    else
        # Plain arrays can be reduced eagerly.
        @test sum(dx) ≈ sum(x)
    end
end

test_fun(Float32, Array)
test_fun(Float32, ConcreteRArray)

This is fairly similar to what’s in Reactant’s test suite itself

Oh yes, this should work. However, it implies that I have to import Reactant.jl in the main tests. Currently the tests are structured so that every backend (CUDA, Metal, Reactant, and so on) has its own environment, with `test_fun` as a shared function. Your example would require Reactant to always be loaded.

I mean, there are other ways to structure this if you don’t want to import things. For example, you could also pass in an execute function, like

# Backend-generic test that delegates *how* code is run to `execute`.
# `execute(f, args...)` must return the result of applying `f` to `args` on the
# target backend (e.g. `default_exec` or a Reactant `@jit` wrapper); that result
# is then compared with the host-side sum using `≈`.
function test_fun(T, array_type, execute)
    x = rand(T, 100)
    dx = x |> array_type
    return @test execute(sum, dx) ≈ sum(x)
end

"""
    default_exec(f, args...)

Execution strategy for backends that need no compilation step: simply return
`f(args...)`.
"""
default_exec(f, args...) = f(args...)

"""
    reactant_exec(f, args...)

Execution strategy for Reactant: compile and run `f(args...)` via `@jit`,
returning its result.

NOTE(review): assumes `@jit` is in scope (e.g. via `using Reactant`) wherever
this is defined.
"""
reactant_exec(f, args...) = @jit f(args...)

Ok this seems the best way thanks.

I’m wondering whether Reactant.jl plans to compile functions automatically in the future, to avoid this extra step for every function.

That is an eventual goal, but unlikely to start anytime soon.

1 Like