Trouble using @compact along reactant.jl

Hello,
I’m trying to set a simple network as described by the code below, but when I compiled it to GPU using reactant.jl it works if the sequence dimension (2nd dimension) is smaller than 100, but if I grow it bigger it takes ages to compile.
I suppose I need to use @trace but then it throws errors, is anyone kind enough to point me how to modifiy my @compact so that it compiles also for long sequences ! Thanks a lot

struct cmp_model{L} <: AbstractLuxContainerLayer{(:rnn_cell,:dense_in,:dense_out)}
    rnn_cell::RNNCell
    dense_in::Dense
    dense_out::Dense
end

function cmp_model(hidden_dims) #works on gpu with small size and cpu with any size, but not gpu with large size (OOM)
    return @compact(;
        dense_in=Dense(1 => hidden_dims),
        rnn_cell = RNNCell(hidden_dims => hidden_dims),
        dense_out = Dense(hidden_dims => 2),
    ) do x::AbstractArray{T,3} where {T}
        x = dense_in(x)
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = rnn_cell(x_init)
        ys = cat(y;dims=3)
        for x in x_rest
            y, carry = rnn_cell((x, carry))
            ys = cat(ys,y;dims=3)
        end
        ys = permutedims(ys, (1,3,2))
        ys = dense_out(ys)
        ys = cat(ys[1:1,:,:],cumsum(ys[2:2,:,:];dims=1);dims=1)
        @return ys
    end
end

cc @avikpal

[also what is @compact, mind putting a fully runnable code as a snippet?]

I’m guessing it’s Training a Simple LSTM | Lux.jl Docs

reproduced even though I wouldn’t call that ages and I’m at 1000 seqlen

using Reactant,Lux,Random

function cmp_model(hidden_dims) #works on gpu with small size and cpu with any size, but not gpu with large size (OOM)
           return @compact(;
               dense_in=Dense(1 => hidden_dims),
               rnn_cell = RNNCell(hidden_dims => hidden_dims),
               dense_out = Dense(hidden_dims => 2),
           ) do x::AbstractArray{T,3} where {T}
               x = dense_in(x)
               x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
               y, carry = rnn_cell(x_init)
               ys = cat(y;dims=3)
               for x in x_rest
                   y, carry = rnn_cell((x, carry))
                   ys = cat(ys,y;dims=3)
               end
               ys = permutedims(ys, (1,3,2))
               ys = dense_out(ys)
               ys = cat(ys[1:1,:,:],cumsum(ys[2:2,:,:];dims=1);dims=1)
               @return ys
           end
       end
model = cmp_model(32);
ps, st = Lux.setup(Random.default_rng(),model) |> reactant_device();
x = rand(Float32,1,1000,1000) |> reactant_device();
mc = @time( @compile sync=true model(x,ps,st) )
 42.431371 seconds (13.43 M allocations: 550.804 MiB, 0.20% gc time, 3.18% compilation time)
Reactant compiled function CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var"#cmp_model##4#cmp_model##5", Nothing, @NamedTuple{dense_in::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, rnn_cell::RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, Nothing, typeof(zeros32), Static.True}, dense_out::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}(static(:₋₋₋no_special_dispatch₋₋₋), var"#cmp_model##4#cmp_model##5"(), nothing, ("@compact", "x::(AbstractArray{T, 3} where T)", "begin\n    x = dense_in(x)\n    (x_init, x_rest) = Iterators.peel(LuxOps.eachslice(x, Val(2)))\n    (y, carry) = rnn_cell(x_init)\n    ys = cat(y; dims = 3)\n    for x = x_rest\n        (y, carry) = rnn_cell((x, carry))\n        ys = cat(ys, y; dims = 3)\n    end\n    ys = permutedims(ys, (1, 3, 2))\n    ys = dense_out(ys)\n    ys = cat(ys[1:1, :, :], cumsum(ys[2:2, :, :]; dims = 1); dims = 1)\n    return ys\nend"), (dense_in = Dense(1 => 32), rnn_cell = RNNCell(32 => 32, tanh), dense_out = Dense(32 => 2)), (dense_in = Dense(1 => 32), rnn_cell = RNNCell(32 => 32, tanh), dense_out = Dense(32 => 2)), Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}(NamedTuple(), NamedTuple()), ((), ())) (with tag ##CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var"#cmp_model##4#cmp_model##5", Nothing, @NamedTuple{dense_in::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, rnn_cell::RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, Nothing, typeof(zeros32), Static.True}, dense_out::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}(static(:₋₋₋no_special_dispatch₋₋₋), var"#cmp_model##4#cmp_model##5"(), nothing, ("@compact", "x::(AbstractArray{T, 3} where T)", "begin\n    x = dense_in(x)\n    (x_init, x_rest) = Iterators.peel(LuxOps.eachslice(x, Val(2)))\n    (y, carry) = rnn_cell(x_init)\n    ys = cat(y; dims = 3)\n    for x = x_rest\n        (y, carry) = rnn_cell((x, carry))\n        ys = cat(ys, y; dims = 3)\n    end\n    ys = permutedims(ys, (1, 3, 2))\n    ys = dense_out(ys)\n    ys = cat(ys[1:1, :, :], cumsum(ys[2:2, :, :]; dims = 1); dims = 1)\n    return ys\nend"), (dense_in = Dense(1 => 32), rnn_cell = RNNCell(32 => 32, tanh), dense_out = Dense(32 => 2)), (dense_in = Dense(1 => 32), rnn_cell = RNNCell(32 => 32, tanh), dense_out = Dense(32 => 2)), Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}(NamedTuple(), NamedTuple()), ((), ()))_reactant#504)
function cmp_model(hidden_dims) #works on gpu with small size and cpu with any size, but not gpu with large size (OOM)
           return @compact(;
               dense_in=Dense(1 => hidden_dims),
               rnn_cell = RNNCell(hidden_dims => hidden_dims),
               dense_out = Dense(hidden_dims => 2),
           ) do x::AbstractArray{T,3} where {T}
               x = dense_in(x)
               x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
               y, carry = rnn_cell(x_init)
               ys = cat(y;dims=3)
               @trace for x in x_rest
                   y, carry = rnn_cell((x, carry))
                   ys = cat(ys,y;dims=3)
               end
               ys = permutedims(ys, (1,3,2))
               ys = dense_out(ys)
               ys = cat(ys[1:1,:,:],cumsum(ys[2:2,:,:];dims=1);dims=1)
               @return ys
           end
       end

model = cmp_model(32);
ps, st = Lux.setup(Random.default_rng(),model) |> reactant_device();
x = rand(Float32,1,1000,1000) |> reactant_device();
mc = @time( @compile sync=true model(x,ps,st) )
ERROR: MethodError: no method matching step(::Base.Iterators.Rest{Vector{Reactant.TracedRArray{Float32, 2}}, Int64})
The function `step` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  step(::CartesianIndices)
   @ Base multidimensional.jl:465
  step(::Static.OptionallyStaticStepRange{<:Any, Int64})
   @ Static ~/.julia/packages/Static/TjBVO/src/ranges.jl:196
  step(::Static.OptionallyStaticStepRange{<:Any, Static.StaticInt{S}}) where S
   @ Static ~/.julia/packages/Static/TjBVO/src/ranges.jl:197
  ...

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/ReactantCore/KHVpc/src/ReactantCore.jl:450 [inlined]
  [2] #cmp_model##6
    @ ./none:0
  [3] call_with_reactant
    @ ./none:-1 [inlined]
  [4] call_with_reactant(::Reactant.EnsureReturnType{…}, ::var"#cmp_model##6#cmp_model##7", ::@NamedTuple{…}, ::Reactant.TracedRArray{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
    @ Reactant ~/.julia/packages/Reactant/vhs6d/src/utils.jl:0
  [5] macro expansion
    @ ~/.julia/packages/Lux/ZSmgp/src/helpers/compact.jl:381 [inlined]
  [6] CompactLuxLayer
    @ ~/.julia/packages/Lux/ZSmgp/src/helpers/compact.jl:373
  [7] call_with_reactant
    @ ./none:-1 [inlined]
  [8] call_with_reactant(::Reactant.EnsureReturnType{…}, ::CompactLuxLayer{…}, ::Reactant.TracedRArray{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
    @ Reactant ~/.julia/packages/Reactant/vhs6d/src/utils.jl:0
  [9] call_with_reactant
    @ ./none:-1 [inlined]
 [10] call_with_reactant(::Reactant.EnsureReturnType{…}, ::Reactant.var"##apply#128", ::@Kwargs{}, ::typeof(Reactant.apply), ::CompactLuxLayer{…}, ::Reactant.TracedRArray{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
    @ Reactant ~/.julia/packages/Reactant/vhs6d/src/utils.jl:0
 [11] call_with_reactant
    @ ./none:-1 [inlined]
 [12] call_with_reactant(::typeof(Reactant.apply), ::CompactLuxLayer{…}, ::Reactant.TracedRArray{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
    @ Reactant ~/.julia/packages/Reactant/vhs6d/src/utils.jl:0
 [13] make_mlir_fn(f::typeof(Reactant.apply), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/vhs6d/src/TracedUtils.jl:370
 [14] make_mlir_fn
    @ ~/.julia/packages/Reactant/vhs6d/src/TracedUtils.jl:299 [inlined]
 [15] #make_mlir_fn#4
    @ ~/.julia/packages/Reactant/vhs6d/src/TracedUtils.jl:322 [inlined]
 [16] make_mlir_fn
    @ ~/.julia/packages/Reactant/vhs6d/src/TracedUtils.jl:299 [inlined]
 [17] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::CompactLuxLayer{…}, args::Tuple{…}, compile_options::CompileOptions, debugcache::Vector{…}, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:278
 [18] compile_mlir!
    @ ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:237 [inlined]
 [19] compile_xla(ctx::Reactant.MLIR.IR.Context, f::CompactLuxLayer{…}, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:1330
 [20] compile_xla
    @ ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:1305 [inlined]
 [21] compile(ctx::Reactant.MLIR.IR.Context, f::CompactLuxLayer{…}, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:1422
 [22] macro expansion
    @ ~/.julia/packages/Reactant/vhs6d/src/compiler/Macros.jl:256 [inlined]
 [23] macro expansion
    @ ~/.julia/packages/LLVM/eBGq5/src/base.jl:113 [inlined]
 [24] macro expansion
    @ ~/.julia/packages/Reactant/vhs6d/src/compiler/Macros.jl:255 [inlined]
 [25] macro expansion
    @ ./timing.jl:697 [inlined]
 [26] top-level scope
    @ ./REPL[44]:0
Some type information was truncated. Use `show(err)` to see complete types.

Would be nice to make Reactant tracing Iterator friendly, a small mwe of the trace issue is

using Reactant
function foo(x)
           res = copy(x)
           it = eachslice(x;dims=2)
           @trace for xi in it
               res .+= xi
           end
           res
       end
@time(@compile foo(x));
ERROR: MethodError: no method matching step(::ColumnSlices{Reactant.TracedRArray{…}, Tuple{…}, SubArray{…}})
The function `step` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  step(::CartesianIndices)
   @ Base multidimensional.jl:465
  step(::AbstractUnitRange{Bool})
   @ Base range.jl:714
  step(::AbstractUnitRange{T}) where T
   @ Base range.jl:713
  ...

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/ReactantCore/KHVpc/src/ReactantCore.jl:450 [inlined]
  [2] foo
    @ ./REPL[11]:4
  [3] call_with_reactant
    @ ./none:-1 [inlined]
  [4] call_with_reactant(::typeof(foo), ::Reactant.TracedRArray{Float32, 2})
    @ Reactant ~/.julia/packages/Reactant/vhs6d/src/utils.jl:0
  [5] make_mlir_fn(f::typeof(foo), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/vhs6d/src/TracedUtils.jl:370
  [6] make_mlir_fn
    @ ~/.julia/packages/Reactant/vhs6d/src/TracedUtils.jl:299 [inlined]
  [7] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(foo), args::Tuple{…}, compile_options::CompileOptions, debugcache::Vector{…}, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:278
  [8] compile_xla(ctx::Reactant.MLIR.IR.Context, f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:1330
  [9] compile_xla
    @ ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:1305 [inlined]
 [10] compile(ctx::Reactant.MLIR.IR.Context, f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:1422
 [11] macro expansion
    @ ~/.julia/packages/Reactant/vhs6d/src/compiler/Macros.jl:256 [inlined]
 [12] macro expansion
    @ ~/.julia/packages/LLVM/eBGq5/src/base.jl:113 [inlined]
 [13] macro expansion
    @ ~/.julia/packages/Reactant/vhs6d/src/compiler/Macros.jl:255 [inlined]
 [14] macro expansion
    @ ./timing.jl:697 [inlined]
 [15] top-level scope
    @ ./REPL[12]:1
Some type information was truncated. Use `show(err)` to see complete types.