Using Reactant.jl in an existing test framework

Hello,

I have a generic function that tests some functionality of a package. It does something like this:

"""
    shared_function(T, array_type)

Backend-agnostic check: build a 10×10 host array of random `T` values, convert it with
`array_type` (e.g. `Array`, `CuArray`, `MtlArray`, `ConcreteRArray`), and test that
`some_function` gives the same result on both.

Results are compared with `≈` so that last-bit floating-point differences between
backends do not fail the test. NOTE(review): this assumes `some_function` returns
numbers or numeric arrays for which `isapprox` is defined — confirm.
"""
function shared_function(T, array_type)
    x = rand(T, 10, 10)
    dx = array_type(x)
    # A comparison operator is required here: `@test a = b` would be an
    # assignment/definition expression, not an equality check.
    @test some_function(dx) ≈ some_function(x)
end

This way, I can test many backends just by changing `array_type` to, e.g., `Array`, `CuArray`, `MtlArray`, or `ConcreteRArray`.

However, with Reactant.jl the function needs to be compiled first:

# Reactant needs the function compiled before it can run on a Reactant array:
# `@compile` builds a callable for `some_function` applied to `dx`, which is then
# called on the same Reactant array and compared with the plain-Julia result.
# NOTE(review): for floating-point results `==` can fail on last-bit differences
# between the compiled and host results; consider `≈` instead.
comp_f = @compile some_function(dx)
@test comp_f(dx) == some_function(x)

How should I handle this? Do I really have to add `@compile` to every function call? That seems quite annoying.

1 Like

I just want to say that I have the same question as @albertomercurio, for exactly the same reason :slight_smile:

You could just use `@jit` in tests; that's completely fine.

It doesn’t seem to work

"""
    test_fun(T, array_type)

Check that `sum` over a backend array built with `array_type` matches `sum` over the host
array of 100 random `T` values.

The sums are compared with `≈` rather than `==`: results from different backends may
differ in the last bits (e.g. a Reactant result of 47.982807f0 against a host result of
47.98281f0), which makes exact equality an unreliable test.
"""
function test_fun(T, array_type)
    x = rand(T, 100)
    dx = array_type(x)
    @test sum(dx) ≈ sum(x)
end
julia> test_fun(Float32, Array)
Test Passed

julia> test_fun(Float32, Reactant.to_rarray)
Test Failed at REPL[9]:4
  Expression: sum(dx) == sum(x)
   Evaluated: ConcretePJRTNumber{Float32, 1}(47.982807f0) == 47.98281f0

ERROR: There was an error during testing

julia> test_fun_compiled = @compile test_fun(Float32, Reactant.to_rarray)
ERROR: Cannot convert TracedRArray to ConcreteArray
Stacktrace:
  [1] call_with_reactant(::Reactant.EnsureReturnType{Union{}}, ::typeof(error), ::String)
    @ Reactant ~/.julia/packages/Reactant/AO3KW/src/utils.jl:1064
  [2] call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Reactant.to_rarray_internal), ::Reactant.TracedRArray{…}, ::Type{…}, ::Reactant.Sharding.NoSharding, ::Val{…}, ::Nothing, ::Nothing)
    @ Reactant ~/.julia/packages/Reactant/AO3KW/src/utils.jl:1064
  [3] #to_rarray#153
    @ ~/.julia/packages/Reactant/AO3KW/src/Tracing.jl:2050 [inlined]
  [4] to_rarray
    @ ~/.julia/packages/Reactant/AO3KW/src/Tracing.jl:2040 [inlined]
  [5] test_fun
    @ ./REPL[9]:3
  [6] call_with_reactant(::typeof(test_fun), ::Type{Float32}, ::typeof(Reactant.to_rarray))
    @ Reactant ~/.julia/packages/Reactant/AO3KW/src/utils.jl:1064
  [7] make_mlir_fn(f::typeof(test_fun), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/AO3KW/src/TracedUtils.jl:355
  [8] make_mlir_fn
    @ ~/.julia/packages/Reactant/AO3KW/src/TracedUtils.jl:284 [inlined]
  [9] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(test_fun), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:1740
 [10] compile_mlir!
    @ ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:1702 [inlined]
 [11] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:3708
 [12] compile_xla
    @ ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:3680 [inlined]
 [13] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:3802
 [14] top-level scope
    @ ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:2850
Some type information was truncated. Use `show(err)` to see complete types.

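Currently, `to_rarray` isn't allowed inside a compiled function, which is what causes your last error. We should fix this (please file an issue on GitHub), but in the interim, avoid calling it inside `@compile`. Your first failure is only floating-point rounding: the Reactant sum (47.982807f0) and the host sum (47.98281f0) differ in the last digits, so compare with `≈` instead of `==`.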

Not tested [on a plane atm], but I assume something like this ought to work?

"""
    test_fun(T, array_type)

Compare `sum` on a backend array (built by `array_type`) with `sum` on the host array of
100 random `T` values. Reactant arrays (`ConcreteRArray`) are run through
`Reactant.@jit`; every other array type calls `sum` directly.

Requires Reactant to be loaded, because the branch refers to `ConcreteRArray`.
"""
function test_fun(T, array_type)
    x = rand(T, 100)
    dx = array_type(x)
    if dx isa ConcreteRArray
        # Reactant arrays must be compiled before running, so route them through @jit.
        @test Reactant.@jit(sum(dx)) ≈ sum(x)
    else
        @test sum(dx) ≈ sum(x)
    end
end

# Run the same check on a plain host array and on a Reactant array.
for backend in (Array, ConcreteRArray)
    test_fun(Float32, backend)
end

This is fairly similar to what’s in Reactant’s test suite itself

Oh yes, this should work. However, it means I have to import Reactant.jl in the main tests. Currently the tests are structured so that each backend (CUDA, Metal, Reactant, and so on) has its own environment, with `test_fun` as a shared function. Your example would require Reactant to always be loaded.

There are other ways to structure this if you don't want to import Reactant in the shared tests. For example, you could pass in an execute function, like this:

"""
    test_fun(T, array_type, execute)

Compare `sum` on a backend array with `sum` on the host array it was built from. The
backend-side call goes through `execute(f, args...)`, so a backend that needs a
compilation step can supply its own runner without the shared test importing it.
"""
function test_fun(T, array_type, execute)
    x = rand(T, 100)
    dx = x |> array_type
    return @test execute(sum, dx) ≈ sum(x)
end

"""
    default_exec(f, args...)

Call `f` on `args` directly, with no compilation step.
"""
default_exec(f, args...) = f(args...)

"""
    reactant_exec(f, args...)

Run `f(args...)` through Reactant's `@jit`, which compiles the call for the given
arguments before executing it.
"""
reactant_exec(f, args...) = @jit f(args...)