reproduced even though I wouldn’t call that ages and I’m at 1000 seqlen
using Reactant,Lux,Random
function cmp_model(hidden_dims) #works on gpu with small size and cpu with any size, but not gpu with large size (OOM)
return @compact(;
dense_in=Dense(1 => hidden_dims),
rnn_cell = RNNCell(hidden_dims => hidden_dims),
dense_out = Dense(hidden_dims => 2),
) do x::AbstractArray{T,3} where {T}
x = dense_in(x)
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
y, carry = rnn_cell(x_init)
ys = cat(y;dims=3)
for x in x_rest
y, carry = rnn_cell((x, carry))
ys = cat(ys,y;dims=3)
end
ys = permutedims(ys, (1,3,2))
ys = dense_out(ys)
ys = cat(ys[1:1,:,:],cumsum(ys[2:2,:,:];dims=1);dims=1)
@return ys
end
end
model = cmp_model(32);
ps, st = Lux.setup(Random.default_rng(),model) |> reactant_device();
x = rand(Float32,1,1000,1000) |> reactant_device();
mc = @time( @compile sync=true model(x,ps,st) )
42.431371 seconds (13.43 M allocations: 550.804 MiB, 0.20% gc time, 3.18% compilation time)
Reactant compiled function CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var"#cmp_model##4#cmp_model##5", Nothing, @NamedTuple{dense_in::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, rnn_cell::RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, Nothing, typeof(zeros32), Static.True}, dense_out::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}(static(:₋₋₋no_special_dispatch₋₋₋), var"#cmp_model##4#cmp_model##5"(), nothing, ("@compact", "x::(AbstractArray{T, 3} where T)", "begin\n x = dense_in(x)\n (x_init, x_rest) = Iterators.peel(LuxOps.eachslice(x, Val(2)))\n (y, carry) = rnn_cell(x_init)\n ys = cat(y; dims = 3)\n for x = x_rest\n (y, carry) = rnn_cell((x, carry))\n ys = cat(ys, y; dims = 3)\n end\n ys = permutedims(ys, (1, 3, 2))\n ys = dense_out(ys)\n ys = cat(ys[1:1, :, :], cumsum(ys[2:2, :, :]; dims = 1); dims = 1)\n return ys\nend"), (dense_in = Dense(1 => 32), rnn_cell = RNNCell(32 => 32, tanh), dense_out = Dense(32 => 2)), (dense_in = Dense(1 => 32), rnn_cell = RNNCell(32 => 32, tanh), dense_out = Dense(32 => 2)), Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}(NamedTuple(), NamedTuple()), ((), ())) (with tag ##CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var"#cmp_model##4#cmp_model##5", Nothing, @NamedTuple{dense_in::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, rnn_cell::RNNCell{Static.False, typeof(tanh), Int64, Int64, Nothing, Nothing, Nothing, typeof(zeros32), Static.True}, dense_out::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}(static(:₋₋₋no_special_dispatch₋₋₋), var"#cmp_model##4#cmp_model##5"(), nothing, ("@compact", "x::(AbstractArray{T, 3} where T)", "begin\n x = dense_in(x)\n (x_init, x_rest) = Iterators.peel(LuxOps.eachslice(x, Val(2)))\n (y, carry) = rnn_cell(x_init)\n ys = cat(y; dims = 3)\n for x = x_rest\n (y, carry) = rnn_cell((x, carry))\n ys = cat(ys, y; dims = 3)\n end\n ys = permutedims(ys, (1, 3, 2))\n ys = dense_out(ys)\n ys = cat(ys[1:1, :, :], cumsum(ys[2:2, :, :]; dims = 1); dims = 1)\n return ys\nend"), (dense_in = Dense(1 => 32), rnn_cell = RNNCell(32 => 32, tanh), dense_out = Dense(32 => 2)), (dense_in = Dense(1 => 32), rnn_cell = RNNCell(32 => 32, tanh), dense_out = Dense(32 => 2)), Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}(NamedTuple(), NamedTuple()), ((), ()))_reactant#504)
function cmp_model(hidden_dims) #works on gpu with small size and cpu with any size, but not gpu with large size (OOM)
return @compact(;
dense_in=Dense(1 => hidden_dims),
rnn_cell = RNNCell(hidden_dims => hidden_dims),
dense_out = Dense(hidden_dims => 2),
) do x::AbstractArray{T,3} where {T}
x = dense_in(x)
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
y, carry = rnn_cell(x_init)
ys = cat(y;dims=3)
@trace for x in x_rest
y, carry = rnn_cell((x, carry))
ys = cat(ys,y;dims=3)
end
ys = permutedims(ys, (1,3,2))
ys = dense_out(ys)
ys = cat(ys[1:1,:,:],cumsum(ys[2:2,:,:];dims=1);dims=1)
@return ys
end
end
model = cmp_model(32);
ps, st = Lux.setup(Random.default_rng(),model) |> reactant_device();
x = rand(Float32,1,1000,1000) |> reactant_device();
mc = @time( @compile sync=true model(x,ps,st) )
ERROR: MethodError: no method matching step(::Base.Iterators.Rest{Vector{Reactant.TracedRArray{Float32, 2}}, Int64})
The function `step` exists, but no method is defined for this combination of argument types.
Closest candidates are:
step(::CartesianIndices)
@ Base multidimensional.jl:465
step(::Static.OptionallyStaticStepRange{<:Any, Int64})
@ Static ~/.julia/packages/Static/TjBVO/src/ranges.jl:196
step(::Static.OptionallyStaticStepRange{<:Any, Static.StaticInt{S}}) where S
@ Static ~/.julia/packages/Static/TjBVO/src/ranges.jl:197
...
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/ReactantCore/KHVpc/src/ReactantCore.jl:450 [inlined]
[2] #cmp_model##6
@ ./none:0
[3] call_with_reactant
@ ./none:-1 [inlined]
[4] call_with_reactant(::Reactant.EnsureReturnType{…}, ::var"#cmp_model##6#cmp_model##7", ::@NamedTuple{…}, ::Reactant.TracedRArray{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
@ Reactant ~/.julia/packages/Reactant/vhs6d/src/utils.jl:0
[5] macro expansion
@ ~/.julia/packages/Lux/ZSmgp/src/helpers/compact.jl:381 [inlined]
[6] CompactLuxLayer
@ ~/.julia/packages/Lux/ZSmgp/src/helpers/compact.jl:373
[7] call_with_reactant
@ ./none:-1 [inlined]
[8] call_with_reactant(::Reactant.EnsureReturnType{…}, ::CompactLuxLayer{…}, ::Reactant.TracedRArray{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
@ Reactant ~/.julia/packages/Reactant/vhs6d/src/utils.jl:0
[9] call_with_reactant
@ ./none:-1 [inlined]
[10] call_with_reactant(::Reactant.EnsureReturnType{…}, ::Reactant.var"##apply#128", ::@Kwargs{}, ::typeof(Reactant.apply), ::CompactLuxLayer{…}, ::Reactant.TracedRArray{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
@ Reactant ~/.julia/packages/Reactant/vhs6d/src/utils.jl:0
[11] call_with_reactant
@ ./none:-1 [inlined]
[12] call_with_reactant(::typeof(Reactant.apply), ::CompactLuxLayer{…}, ::Reactant.TracedRArray{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
@ Reactant ~/.julia/packages/Reactant/vhs6d/src/utils.jl:0
[13] make_mlir_fn(f::typeof(Reactant.apply), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/vhs6d/src/TracedUtils.jl:370
[14] make_mlir_fn
@ ~/.julia/packages/Reactant/vhs6d/src/TracedUtils.jl:299 [inlined]
[15] #make_mlir_fn#4
@ ~/.julia/packages/Reactant/vhs6d/src/TracedUtils.jl:322 [inlined]
[16] make_mlir_fn
@ ~/.julia/packages/Reactant/vhs6d/src/TracedUtils.jl:299 [inlined]
[17] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::CompactLuxLayer{…}, args::Tuple{…}, compile_options::CompileOptions, debugcache::Vector{…}, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:278
[18] compile_mlir!
@ ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:237 [inlined]
[19] compile_xla(ctx::Reactant.MLIR.IR.Context, f::CompactLuxLayer{…}, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:1330
[20] compile_xla
@ ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:1305 [inlined]
[21] compile(ctx::Reactant.MLIR.IR.Context, f::CompactLuxLayer{…}, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/vhs6d/src/compiler/Compiler.jl:1422
[22] macro expansion
@ ~/.julia/packages/Reactant/vhs6d/src/compiler/Macros.jl:256 [inlined]
[23] macro expansion
@ ~/.julia/packages/LLVM/eBGq5/src/base.jl:113 [inlined]
[24] macro expansion
@ ~/.julia/packages/Reactant/vhs6d/src/compiler/Macros.jl:255 [inlined]
[25] macro expansion
@ ./timing.jl:697 [inlined]
[26] top-level scope
@ ./REPL[44]:0
Some type information was truncated. Use `show(err)` to see complete types.