I’ve got some expensive code I’m running that I’d like to compile with Reactant. There is some control flow that’s not playing nice with Reactant. I’ve got an MWE producing a different error than my main code, but I’m hoping debugging this MWE will still help:
begin
using Random
using Reactant
using Statistics
using BenchmarkTools
Reactant.set_default_backend("cpu")
Random.seed!(1234)
end
begin # Functions
function testFlow(x, thetas)
nlen = size(x,1)
datalen = sum(length.(@view thetas[1:nlen]))
dσ = [zero(eltype(x)) for _ in 1:datalen]
j = 1
for i in 1:nlen
exp_len = length(thetas[i])
j_next = j + exp_len
sig = [zero(eltype(x)) for _ in 1:exp_len]
@trace if x[i,2] > 0.5
sig = sum(x[i,:])*cos.(thetas[i])
else
sig = 10*sum(x[i,:])*cos.(thetas[i])
end
dσ[j:j_next-1] = sig
j = j_next
end
return dσ
end
end
x = rand(Float32, 2, 2)
thetas = [
[1.0, 10.0, 20.0, 45.0, 60],
[5.0, 15.0, 25.0, 60]
]
# testFlow(x, thetas)
@allowscalar test_comp = @compile testFlow(xdev(x), xdev(thetas))
@allowscalar test = test_comp(xdev(x), xdev(thetas))
And the error:
LoadError: MethodError: no method matching set_mlir_data!(::Int64, ::Reactant.MLIR.IR.Value)
The function `set_mlir_data!` exists, but no method is defined for this combination of argument types.
Closest candidates are:
set_mlir_data!(::LinearAlgebra.LowerTriangular{Reactant.TracedRNumber{T}, <:AbstractArray{Reactant.TracedRNumber{T}, 2} where T}, ::Any) where T
@ Reactant ~/.julia/packages/Reactant/QTNFa/src/stdlibs/LinearAlgebra.jl:165
set_mlir_data!(::LinearAlgebra.UnitLowerTriangular{Reactant.TracedRNumber{T}, <:AbstractArray{Reactant.TracedRNumber{T}, 2} where T}, ::Any) where T
@ Reactant ~/.julia/packages/Reactant/QTNFa/src/stdlibs/LinearAlgebra.jl:165
set_mlir_data!(::LinearAlgebra.Symmetric{Reactant.TracedRNumber{T}, <:AbstractArray{Reactant.TracedRNumber{T}, 2} where T}, ::Any) where T
@ Reactant ~/.julia/packages/Reactant/QTNFa/src/stdlibs/LinearAlgebra.jl:186
...
Stacktrace:
[1] set!(x::Tuple{Vector{Reactant.TracedRArray{Float64, 1}}, Int64, Reactant.TracedRArray{Float32, 2}}, path::Tuple{Int64}, tostore::Reactant.MLIR.IR.Value; emptypath::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/QTNFa/src/TracedUtils.jl:1067
[2] set!(x::Tuple{Vector{Reactant.TracedRArray{Float64, 1}}, Int64, Reactant.TracedRArray{Float32, 2}}, path::Tuple{Int64}, tostore::Reactant.MLIR.IR.Value)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/QTNFa/src/TracedUtils.jl:1062
[3] if_condition(::Reactant.TracedRNumber{Bool}, ::var"#157#161", ::var"#158#162", ::Vector{Reactant.TracedRArray{Float64, 1}}, ::Vararg{Any}; track_numbers::Type, location::Reactant.MLIR.IR.Location)
@ Reactant.Ops ~/.julia/packages/Reactant/QTNFa/src/Ops.jl:0
[4] #traced_if#126
@ ~/.julia/packages/Reactant/QTNFa/src/ControlFlow.jl:4 [inlined]
[5] traced_if
@ ~/.julia/packages/Reactant/QTNFa/src/ControlFlow.jl:1 [inlined]
[6] (::Nothing)(none::typeof(Core.kwcall), none::@NamedTuple{track_numbers::DataType}, none::typeof(ReactantCore.traced_if), none::Reactant.TracedRNumber{Bool}, none::var"#157#161", none::var"#158#162", none::Tuple{Vector{…}, Int64, Reactant.TracedRArray{…}})
@ Reactant ./<missing>:0
[7] call_with_reactant(::typeof(Core.kwcall), ::@NamedTuple{track_numbers::DataType}, ::typeof(ReactantCore.traced_if), ::Reactant.TracedRNumber{Bool}, ::var"#157#161", ::var"#158#162", ::Tuple{Vector{…}, Int64, Reactant.TracedRArray{…}})
@ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:501
[8] macro expansion
@ ~/.julia/packages/ReactantCore/9hY4Z/src/ReactantCore.jl:524 [inlined]
[9] testFlow
@ ~/Documents/Code/nuclear-diffprog/MWEs/reactcontrolflow.jl:20 [inlined]
[10] (::Nothing)(none::typeof(testFlow), none::Reactant.TracedRArray{Float32, 2}, none::Vector{Reactant.TracedRArray{Float64, 1}})
@ Reactant ./<missing>:0
[11] Box
@ ./boot.jl:434 [inlined]
[12] testFlow
@ ~/Documents/Code/nuclear-diffprog/MWEs/reactcontrolflow.jl:12 [inlined]
[13] call_with_reactant(::typeof(testFlow), ::Reactant.TracedRArray{Float32, 2}, ::Vector{Reactant.TracedRArray{Float64, 1}})
@ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:0
[14] make_mlir_fn(f::typeof(testFlow), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/QTNFa/src/TracedUtils.jl:332
[15] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:1555
[16] compile_mlir! (repeats 2 times)
@ ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:1522 [inlined]
[17] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{…}, Vector{…}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
@ Reactant.Compiler ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:3433
[18] compile_xla
@ ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:3406 [inlined]
[19] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:3505
[20] macro expansion
@ ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:2586 [inlined]
[21] top-level scope
@ ~/Documents/Code/nuclear-diffprog/MWEs/reactcontrolflow.jl:206
[22] include(fname::String)
@ Main ./sysimg.jl:38
[23] top-level scope
@ REPL[10]:1
in expression starting at /Users/daningburg/Documents/Code/nuclear-diffprog/MWEs/reactcontrolflow.jl:38
Some type information was truncated. Use `show(err)` to see complete types.
I’ve read the Reactant page on control flow but it’s not clear to me how to apply it to my situation: Control Flow | Reactant.jl