It doesn't seem to work with `Reactant.to_rarray`:
"""
    test_fun(T, array_type)

Check that summing a random vector gives the same result before and after
converting it with `array_type` (e.g. `Array` or `Reactant.to_rarray`).

Draws 100 random values of element type `T` with `rand`, converts them with
`array_type`, and runs `@test` on the two sums. Returns the `Test` result.

The comparison uses `≈` (`isapprox`) rather than `==`. Floating-point addition
is not associative, so a backend that reduces in a different order than Julia's
`sum` can legitimately differ in the last bits (e.g. `47.982807f0` vs
`47.98281f0`).

NOTE(review): `array_type` is applied *inside* this function. Compiling the
whole call with `Reactant.@compile test_fun(Float32, Reactant.to_rarray)` makes
`x` a traced array, and `to_rarray` then fails with "Cannot convert TracedRArray
to ConcreteArray". Presumably the conversion should happen outside the compiled
region — confirm the intended Reactant usage.
"""
function test_fun(T, array_type)
    x = rand(T, 100)
    dx = array_type(x)
    # Approximate comparison: reduction order (and thus rounding) may differ
    # between the host `sum` and the converted array's `sum`.
    return @test sum(dx) ≈ sum(x)
end
julia> test_fun(Float32, Array)
Test Passed
julia> test_fun(Float32, Reactant.to_rarray)
Test Failed at REPL[9]:4
Expression: sum(dx) == sum(x)
Evaluated: ConcretePJRTNumber{Float32, 1}(47.982807f0) == 47.98281f0
ERROR: There was an error during testing
julia> test_fun_compiled = @compile test_fun(Float32, Reactant.to_rarray)
ERROR: Cannot convert TracedRArray to ConcreteArray
Stacktrace:
[1] call_with_reactant(::Reactant.EnsureReturnType{Union{}}, ::typeof(error), ::String)
@ Reactant ~/.julia/packages/Reactant/AO3KW/src/utils.jl:1064
[2] call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Reactant.to_rarray_internal), ::Reactant.TracedRArray{…}, ::Type{…}, ::Reactant.Sharding.NoSharding, ::Val{…}, ::Nothing, ::Nothing)
@ Reactant ~/.julia/packages/Reactant/AO3KW/src/utils.jl:1064
[3] #to_rarray#153
@ ~/.julia/packages/Reactant/AO3KW/src/Tracing.jl:2050 [inlined]
[4] to_rarray
@ ~/.julia/packages/Reactant/AO3KW/src/Tracing.jl:2040 [inlined]
[5] test_fun
@ ./REPL[9]:3
[6] call_with_reactant(::typeof(test_fun), ::Type{Float32}, ::typeof(Reactant.to_rarray))
@ Reactant ~/.julia/packages/Reactant/AO3KW/src/utils.jl:1064
[7] make_mlir_fn(f::typeof(test_fun), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/AO3KW/src/TracedUtils.jl:355
[8] make_mlir_fn
@ ~/.julia/packages/Reactant/AO3KW/src/TracedUtils.jl:284 [inlined]
[9] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(test_fun), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:1740
[10] compile_mlir!
@ ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:1702 [inlined]
[11] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:3708
[12] compile_xla
@ ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:3680 [inlined]
[13] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:3802
[14] top-level scope
@ ~/.julia/packages/Reactant/AO3KW/src/Compiler.jl:2850
Some type information was truncated. Use `show(err)` to see complete types.