Using control flow in Reactant

I’ve got some expensive code I’m running that I’d like to compile with Reactant. There is some control flow that’s not playing nice with Reactant. I’ve got an MWE producing a different error than my main code, but I’m hoping debugging this MWE will still help:

# Environment setup for the MWE: load packages, pick the backend, seed the RNG.
# NOTE(review): Statistics and BenchmarkTools are not referenced by the code shown below.
begin
    using Random
    using Reactant
    using Statistics
    using BenchmarkTools
    # Make "cpu" the default Reactant backend.
    Reactant.set_default_backend("cpu")
    # Fixed seed so the `rand` call further down is reproducible.
    Random.seed!(1234)
end

begin # Functions
    # Assemble a flat vector `dσ` holding one segment per row of `x`.
    # Row `i` contributes `length(thetas[i])` entries equal to
    # `sum(x[i,:]) * cos.(thetas[i])` when `x[i,2] > 0.5`, and 10 times that otherwise.
    # Returns `dσ`.
    function testFlow(x, thetas)
        # Number of rows of `x`; one entry of `thetas` is consumed per row.
        nlen = size(x,1)
        # Total output length: combined length of the first `nlen` entries of `thetas`.
        datalen = sum(length.(@view thetas[1:nlen]))
        # Output buffer filled with `zero(eltype(x))`.
        dσ = [zero(eltype(x)) for _ in 1:datalen] 
        # Start index in `dσ` of the segment currently being written.
        j = 1
        for i in 1:nlen
            exp_len = length(thetas[i])
            # One past the last index of this row's segment.
            j_next = j + exp_len
            # Initial value only: both branches below reassign `sig`.
            sig = [zero(eltype(x)) for _ in 1:exp_len]
            # The condition and both branch bodies index `x` / `thetas` with the loop
            # variable `i`.
            # NOTE(review): the MethodError reported in this thread
            # (`set_mlir_data!(::Int64, ...)`) is raised while this `if` is traced.
            @trace if x[i,2] > 0.5
                sig = sum(x[i,:])*cos.(thetas[i])
            else
                sig = 10*sum(x[i,:])*cos.(thetas[i])
            end
            # Copy this row's segment into the output and advance the write position.
            dσ[j:j_next-1] = sig
            j = j_next
        end
        return dσ
    end
end

# 2×2 Float32 input; testFlow reads column 2 of each row in its branch condition.
x = rand(Float32, 2, 2)
# Ragged parameters: one vector per row of `x` (lengths 5 and 4).
thetas = [
    [1.0, 10.0, 20.0, 45.0, 60],
    [5.0, 15.0, 25.0, 60]
]
# testFlow(x, thetas)
# NOTE(review): `xdev` is not defined anywhere in this snippet; the later version of
# the script in this thread defines it as `reactant_device()`.
# Compile testFlow for the converted inputs, then call the compiled function.
@allowscalar test_comp = @compile testFlow(xdev(x), xdev(thetas))
@allowscalar test = test_comp(xdev(x), xdev(thetas))

And the error:

LoadError: MethodError: no method matching set_mlir_data!(::Int64, ::Reactant.MLIR.IR.Value)
The function `set_mlir_data!` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  set_mlir_data!(::LinearAlgebra.LowerTriangular{Reactant.TracedRNumber{T}, <:AbstractArray{Reactant.TracedRNumber{T}, 2} where T}, ::Any) where T
   @ Reactant ~/.julia/packages/Reactant/QTNFa/src/stdlibs/LinearAlgebra.jl:165
  set_mlir_data!(::LinearAlgebra.UnitLowerTriangular{Reactant.TracedRNumber{T}, <:AbstractArray{Reactant.TracedRNumber{T}, 2} where T}, ::Any) where T
   @ Reactant ~/.julia/packages/Reactant/QTNFa/src/stdlibs/LinearAlgebra.jl:165
  set_mlir_data!(::LinearAlgebra.Symmetric{Reactant.TracedRNumber{T}, <:AbstractArray{Reactant.TracedRNumber{T}, 2} where T}, ::Any) where T
   @ Reactant ~/.julia/packages/Reactant/QTNFa/src/stdlibs/LinearAlgebra.jl:186
  ...

Stacktrace:
  [1] set!(x::Tuple{Vector{Reactant.TracedRArray{Float64, 1}}, Int64, Reactant.TracedRArray{Float32, 2}}, path::Tuple{Int64}, tostore::Reactant.MLIR.IR.Value; emptypath::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/QTNFa/src/TracedUtils.jl:1067
  [2] set!(x::Tuple{Vector{Reactant.TracedRArray{Float64, 1}}, Int64, Reactant.TracedRArray{Float32, 2}}, path::Tuple{Int64}, tostore::Reactant.MLIR.IR.Value)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/QTNFa/src/TracedUtils.jl:1062
  [3] if_condition(::Reactant.TracedRNumber{Bool}, ::var"#157#161", ::var"#158#162", ::Vector{Reactant.TracedRArray{Float64, 1}}, ::Vararg{Any}; track_numbers::Type, location::Reactant.MLIR.IR.Location)
    @ Reactant.Ops ~/.julia/packages/Reactant/QTNFa/src/Ops.jl:0
  [4] #traced_if#126
    @ ~/.julia/packages/Reactant/QTNFa/src/ControlFlow.jl:4 [inlined]
  [5] traced_if
    @ ~/.julia/packages/Reactant/QTNFa/src/ControlFlow.jl:1 [inlined]
  [6] (::Nothing)(none::typeof(Core.kwcall), none::@NamedTuple{track_numbers::DataType}, none::typeof(ReactantCore.traced_if), none::Reactant.TracedRNumber{Bool}, none::var"#157#161", none::var"#158#162", none::Tuple{Vector{…}, Int64, Reactant.TracedRArray{…}})
    @ Reactant ./<missing>:0
  [7] call_with_reactant(::typeof(Core.kwcall), ::@NamedTuple{track_numbers::DataType}, ::typeof(ReactantCore.traced_if), ::Reactant.TracedRNumber{Bool}, ::var"#157#161", ::var"#158#162", ::Tuple{Vector{…}, Int64, Reactant.TracedRArray{…}})
    @ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:501
  [8] macro expansion
    @ ~/.julia/packages/ReactantCore/9hY4Z/src/ReactantCore.jl:524 [inlined]
  [9] testFlow
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactcontrolflow.jl:20 [inlined]
 [10] (::Nothing)(none::typeof(testFlow), none::Reactant.TracedRArray{Float32, 2}, none::Vector{Reactant.TracedRArray{Float64, 1}})
    @ Reactant ./<missing>:0
 [11] Box
    @ ./boot.jl:434 [inlined]
 [12] testFlow
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactcontrolflow.jl:12 [inlined]
 [13] call_with_reactant(::typeof(testFlow), ::Reactant.TracedRArray{Float32, 2}, ::Vector{Reactant.TracedRArray{Float64, 1}})
    @ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:0
 [14] make_mlir_fn(f::typeof(testFlow), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/QTNFa/src/TracedUtils.jl:332
 [15] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:1555
 [16] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:1522 [inlined]
 [17] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{…}, Vector{…}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
    @ Reactant.Compiler ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:3433
 [18] compile_xla
    @ ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:3406 [inlined]
 [19] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:3505
 [20] macro expansion
    @ ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:2586 [inlined]
 [21] top-level scope
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactcontrolflow.jl:206
 [22] include(fname::String)
    @ Main ./sysimg.jl:38
 [23] top-level scope
    @ REPL[10]:1
in expression starting at /Users/daningburg/Documents/Code/nuclear-diffprog/MWEs/reactcontrolflow.jl:38
Some type information was truncated. Use `show(err)` to see complete types.

I’ve read the Reactant page on control flow but it’s not clear to me how to apply it to my situation: Control Flow | Reactant.jl

Can you open an issue? Here is an MWE:

using Reactant

# Reduced reproducer: branch on the first element of `x` inside `@trace if`.
# Starts from `one(eltype(x))`, adds `x[1]` when `x[1] > 0` and subtracts it otherwise,
# then returns the result.
function foo(x)
    t = one(eltype(x))
    # Both the condition and the branch bodies index `x` directly.
    @trace if x[1] > 0
        t += x[1]
    else
        t -= x[1]
    end
    return t
end

# Convert a length-1 random vector with `Reactant.to_rarray`, then compile and run
# `foo` on it in one step with `@jit`.
x = Reactant.to_rarray(rand(1))
@allowscalar @jit foo(x)

You can work around this with `ifelse()` for now, but that also seems to have an issue: it can currently only handle TracedNumbers, not arrays. As for the name of the issue, I think it should be something like “issue when tracing control flow with an element of a traced array”.
Looking at Reactant.jl/test/control_flow.jl at main · EnzymeAD/Reactant.jl · GitHub, it seems like this case (an `if` on an indexed element of an array) is never tested.
PS: Reactant.jl is really young and in active development, so it will take time to figure out all the corner cases.

1 Like

I was able to cheat a little, but beyond that it’s hard for now without testing a million things:

    # Reworked testFlow: the row slice `x[i,:]` and `thetas[i]` are bound to locals
    # before the traced branch, and the branch value is taken from the `@trace if`
    # expression instead of being assigned inside it.
    # Returns the flat vector `dσ` with one segment of length `length(thetas[i])` per row.
    function testFlow(x, thetas)
        nlen = size(x,1)
        # Total output length: combined length of the first `nlen` entries of `thetas`.
        datalen = sum(length.(@view thetas[1:nlen]))
        dσ = [zero(eltype(x)) for _ in 1:datalen] 
        # Start index in `dσ` of the segment currently being written.
        j = 1
        for i in 1:nlen
            exp_len = length(thetas[i])
            j_next = j + exp_len
            # Overwritten by `sig = res[1]` below before it is read.
            sig = [zero(eltype(x)) for _ in 1:exp_len]
            # Indexing with the loop variable happens here, outside the traced branch;
            # the branch itself only references `xi` and `ti`.
            xi = x[i,:]
            ti = thetas[i]
            res = @trace if xi[2] > 0.5
                sum(xi)*cos.(ti)
            else
                10.0 .* sum(xi) .* cos.(ti)
            end
            # The value of the `@trace if` expression is indexed with `[1]` to get the
            # branch result. NOTE(review): the author states the reason is unclear.
            sig = res[1]
            dσ[j:j_next-1] = sig
            j = j_next
        end
        return dσ
    end
end

Now I think the issue is actually about indexing within a trace. Also, I have no idea why I need to index `res` like that to get the result of the trace.

This is really odd, but I dropped your function into my test script, renamed it, and now my code isn’t recognizing `reactant_device`:

# Environment setup for the test script: load packages, pick the backend, seed the RNG.
# NOTE(review): none of the packages loaded here is shown to provide `reactant_device`,
# which the script calls further down.
begin
    using Random
    using Reactant
    using Statistics
    using BenchmarkTools
    # Make "cpu" the default Reactant backend.
    Reactant.set_default_backend("cpu")
    # Fixed seed so the `rand` call further down is reproducible.
    Random.seed!(1234)
end

begin # Functions
    # Original reproducer: build a flat vector `dσ` with one segment per row of `x`.
    # Row `i` yields `sum(x[i,:]) * cos.(thetas[i])` when `x[i,2] > 0.5`, and 10 times
    # that otherwise; segments are written back to back. Returns `dσ`.
    function testFlow(x, thetas)
        # Number of rows of `x`; one entry of `thetas` is consumed per row.
        nlen = size(x,1)
        # Total output length: combined length of the first `nlen` entries of `thetas`.
        datalen = sum(length.(@view thetas[1:nlen]))
        # Output buffer filled with `zero(eltype(x))`.
        dσ = [zero(eltype(x)) for _ in 1:datalen] 
        # Start index in `dσ` of the segment currently being written.
        j = 1
        for i in 1:nlen
            exp_len = length(thetas[i])
            # One past the last index of this row's segment.
            j_next = j + exp_len
            # Initial value only: both branches below reassign `sig`.
            sig = [zero(eltype(x)) for _ in 1:exp_len]
            # The condition and both branch bodies index `x` / `thetas` with the loop
            # variable `i` inside the traced `if`.
            @trace if x[i,2] > 0.5
                sig = sum(x[i,:])*cos.(thetas[i])
            else
                sig = 10*sum(x[i,:])*cos.(thetas[i])
            end
            # Copy this row's segment into the output and advance the write position.
            dσ[j:j_next-1] = sig
            j = j_next
        end
        return dσ
    end

    # Variant of testFlow: `x[i,:]` and `thetas[i]` are bound to locals before the
    # traced branch, and the branch value is taken from the `@trace if` expression
    # rather than assigned inside it. Returns the flat vector `dσ`.
    # NOTE: the else branch uses the literal `10.0` where testFlow uses `10`.
    function testFlow2(x, thetas)
        nlen = size(x,1)
        # Total output length: combined length of the first `nlen` entries of `thetas`.
        datalen = sum(length.(@view thetas[1:nlen]))
        dσ = [zero(eltype(x)) for _ in 1:datalen] 
        # Start index in `dσ` of the segment currently being written.
        j = 1
        for i in 1:nlen
            exp_len = length(thetas[i])
            j_next = j + exp_len
            # Overwritten by `sig = res[1]` below before it is read.
            sig = [zero(eltype(x)) for _ in 1:exp_len]
            # Indexing with the loop variable is done here, so the traced branch only
            # references `xi` and `ti`.
            xi = x[i,:]
            ti = thetas[i]
            res = @trace if xi[2] > 0.5
                sum(xi)*cos.(ti)
            else
                10.0 .* sum(xi) .* cos.(ti)
            end
            # The branch result is read as the first element of the `@trace if` value.
            # NOTE(review): why this indexing is needed is left unexplained in the thread.
            sig = res[1]
            dσ[j:j_next-1] = sig
            j = j_next
        end
        return dσ
    end
end
# NOTE(review): this line raises the UndefVarError shown below. Per the reply in this
# thread, `reactant_device` lives in Lux, and `Reactant.to_rarray` should be used instead.
const xdev = reactant_device()

# 2×2 Float32 input; testFlow reads column 2 of each row in its branch condition.
x = rand(Float32, 2, 2)
# Ragged parameters: one vector per row of `x` (lengths 5 and 4).
thetas = [
    [1.0, 10.0, 20.0, 45.0, 60],
    [5.0, 15.0, 25.0, 60]
]
# testFlow(x, thetas)
# Compile testFlow for the converted inputs, then call the compiled function.
# Only testFlow is compiled here; testFlow2 is defined above but not called.
@allowscalar test_comp = @compile testFlow(xdev(x), xdev(thetas))
@allowscalar test = test_comp(xdev(x), xdev(thetas))
LoadError: UndefVarError: `reactant_device` not defined in `Main`
Suggestion: check for spelling errors or missing imports.
Stacktrace:
 [1] top-level scope
   @ ~/Documents/Code/nuclear-diffprog/MWEs/reactcontrolflow.jl:54
 [2] include(fname::String)
   @ Main ./sysimg.jl:38
 [3] top-level scope
   @ REPL[1]:1
in expression starting at /Users/daningburg/Documents/Code/nuclear-diffprog/MWEs/reactcontrolflow.jl:54

I’m sure it’s unrelated to your code, maybe something about my environment. I tried restarting Julia to no effect.

It should be `Reactant.to_rarray`; the device function (`reactant_device`) is in Lux.

1 Like