What is the correct way to autodiff a simple CUDA kernel using Reactant.jl and Enzyme.jl?

Nice. For now, the workaround

# Element-wise pass-through written as a two-collection `map`: every index of
# `x` is paired with the element stored there, and the element is returned
# unchanged. The result is whatever collection `map` builds from those values.
function foo(x)
    # The index argument is accepted (so the closure matches the two-collection
    # `map` call) but deliberately ignored.
    keep_element = (idx, elem) -> elem
    return map(keep_element, eachindex(x), x)
end

works, but don’t even try to trace anything inside the closure:

# Variant of `foo` that branches inside the mapped closure: for each
# (index, element) pair it yields `xi*0` when the index is greater than 2 and
# `xi` otherwise, and returns the collection produced by `map`.
#
# NOTE(review): `@trace` is presumably Reactant.jl's control-flow tracing macro
# (confirm against the Reactant docs). According to the accompanying error
# report, compiling this function fails with a `MethodError` about converting a
# `Reactant.TracedRArray{Float64, 1}` to a `Reactant.TracedRNumber{Float64}`.
# This snippet is the reproducer for that failure, so the code is intentionally
# left exactly as posted.
function foo(x)
           # Two-collection map: `i` iterates over `eachindex(x)`, `xi` over `x`.
           y = map(eachindex(x),x) do i,xi
               # Branch on the index; both arms produce a value derived from `xi`.
               @trace if i >2
                    xi*0
               else
                    xi
                end
           end
           return y
       end

which raises the following error when JIT-compiled:

ERROR: MethodError: Cannot `convert` an object of type Reactant.TracedRArray{Float64, 1} to an object of type Reactant.TracedRNumber{Float64}
The function `convert` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  convert(::Type{T}, ::CartesianIndex{1}) where T<:Number
   @ Base multidimensional.jl:136
  convert(::Type{T}, ::Number) where T<:Number
   @ Base number.jl:7
  convert(::Type{T}, ::T) where T<:Number
   @ Base number.jl:6
  ...

Stacktrace:
  [1] setindex!(A::Vector{Reactant.TracedRNumber{Float64}}, x::Reactant.TracedRArray{Float64, 1}, i::Int64)
    @ Base ./array.jl:987
  [2] make_tracer(seen::Reactant.OrderedIdDict{…}, prev::Array, path::Any, mode::Reactant.TraceMode; track_numbers::Type, sharding::Any, runtime::Any, device::Any, client::Any, kwargs::@Kwargs{…})
    @ Reactant ~/.julia/packages/Reactant/pQXes/src/Tracing.jl:1598
  [3] elem_apply(::Function, ::Reactant.TracedRArray{Int64, 1}, ::Reactant.TracedRArray{Float64, 1})
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/pQXes/src/TracedUtils.jl:1188
  [4] overloaded_map(f::Function, x::Base.OneTo{Int64}, xs::Reactant.TracedRArray{Float64, 1})
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/pQXes/src/TracedRArray.jl:1044
  [5] map(f::Function, x::Base.OneTo{Int64}, ys::Reactant.TracedRArray{Float64, 1})
    @ Reactant ~/.julia/packages/Reactant/pQXes/src/Overlay.jl:175
  [6] foo
    @ ./REPL[2]:2 [inlined]
  [7] (::Nothing)(none::typeof(foo), none::Reactant.TracedRArray{Float64, 1})
    @ Reactant ./<missing>:0
  [8] getproperty
    @ ./Base.jl:49 [inlined]
  [9] size
    @ ~/.julia/packages/Reactant/pQXes/src/TracedRArray.jl:248 [inlined]
 [10] axes
    @ ./abstractarray.jl:98 [inlined]
 [11] axes1
    @ ./abstractarray.jl:137 [inlined]
 [12] eachindex
    @ ./abstractarray.jl:321 [inlined]
 [13] foo
    @ ./REPL[2]:2 [inlined]
 [14] call_with_reactant(::typeof(foo), ::Reactant.TracedRArray{Float64, 1})
    @ Reactant ~/.julia/packages/Reactant/pQXes/src/utils.jl:0
 [15] make_mlir_fn(f::typeof(foo), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/pQXes/src/TracedUtils.jl:348
 [16] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:1575
 [17] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:1542 [inlined]
 [18] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:3464
 [19] compile_xla
    @ ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:3437 [inlined]
 [20] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:3536
 [21] top-level scope
    @ ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:2614
Some type information was truncated. Use `show(err)` to see complete types.