Nice — for now, this workaround works:
"""
    foo(x)

Workaround form: walk `x` together with its indices via `map` and hand each
element back untouched. The index argument is received but ignored.
"""
function foo(x)
    return map((_, xi) -> xi, eachindex(x), x)
end
It works, but don't even try to trace anything inside the `do` block:
"""
    foo(x)

Reproducer: map over `x` alongside its indices and, using Reactant's
`@trace if`, return `val * 0` for indices greater than 2 and `val` otherwise.
When compiled with Reactant this fails inside the traced `map` with a
`MethodError` converting a `TracedRArray` to a `TracedRNumber`.
"""
function foo(x)
    # NOTE(review): `ifelse(idx > 2, val * 0, val)` expresses the same
    # selection without traced control flow — not verified under Reactant here.
    return map(eachindex(x), x) do idx, val
        @trace if idx > 2
            val * 0
        else
            val
        end
    end
end
This version raises the following error when JIT-compiled:
ERROR: MethodError: Cannot `convert` an object of type Reactant.TracedRArray{Float64, 1} to an object of type Reactant.TracedRNumber{Float64}
The function `convert` exists, but no method is defined for this combination of argument types.
Closest candidates are:
convert(::Type{T}, ::CartesianIndex{1}) where T<:Number
@ Base multidimensional.jl:136
convert(::Type{T}, ::Number) where T<:Number
@ Base number.jl:7
convert(::Type{T}, ::T) where T<:Number
@ Base number.jl:6
...
Stacktrace:
[1] setindex!(A::Vector{Reactant.TracedRNumber{Float64}}, x::Reactant.TracedRArray{Float64, 1}, i::Int64)
@ Base ./array.jl:987
[2] make_tracer(seen::Reactant.OrderedIdDict{…}, prev::Array, path::Any, mode::Reactant.TraceMode; track_numbers::Type, sharding::Any, runtime::Any, device::Any, client::Any, kwargs::@Kwargs{…})
@ Reactant ~/.julia/packages/Reactant/pQXes/src/Tracing.jl:1598
[3] elem_apply(::Function, ::Reactant.TracedRArray{Int64, 1}, ::Reactant.TracedRArray{Float64, 1})
@ Reactant.TracedUtils ~/.julia/packages/Reactant/pQXes/src/TracedUtils.jl:1188
[4] overloaded_map(f::Function, x::Base.OneTo{Int64}, xs::Reactant.TracedRArray{Float64, 1})
@ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/pQXes/src/TracedRArray.jl:1044
[5] map(f::Function, x::Base.OneTo{Int64}, ys::Reactant.TracedRArray{Float64, 1})
@ Reactant ~/.julia/packages/Reactant/pQXes/src/Overlay.jl:175
[6] foo
@ ./REPL[2]:2 [inlined]
[7] (::Nothing)(none::typeof(foo), none::Reactant.TracedRArray{Float64, 1})
@ Reactant ./<missing>:0
[8] getproperty
@ ./Base.jl:49 [inlined]
[9] size
@ ~/.julia/packages/Reactant/pQXes/src/TracedRArray.jl:248 [inlined]
[10] axes
@ ./abstractarray.jl:98 [inlined]
[11] axes1
@ ./abstractarray.jl:137 [inlined]
[12] eachindex
@ ./abstractarray.jl:321 [inlined]
[13] foo
@ ./REPL[2]:2 [inlined]
[14] call_with_reactant(::typeof(foo), ::Reactant.TracedRArray{Float64, 1})
@ Reactant ~/.julia/packages/Reactant/pQXes/src/utils.jl:0
[15] make_mlir_fn(f::typeof(foo), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/pQXes/src/TracedUtils.jl:348
[16] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:1575
[17] compile_mlir! (repeats 2 times)
@ ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:1542 [inlined]
[18] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:3464
[19] compile_xla
@ ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:3437 [inlined]
[20] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:3536
[21] top-level scope
@ ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:2614
Some type information was truncated. Use `show(err)` to see complete types.