I am trying to work out how to build an RNN in Lux using only the built-in blocks (instead of the custom approach in the tutorial). I don’t understand why the following basic code throws an error – do I need to be explicit about the hidden state? And if so, how do I do this?
julia> size(x1)
(128, 7598)
julia> model = Recurrence(LSTMCell(128 => 1))
Recurrence(
cell = LSTMCell(128 => 1), # 520 parameters, plus 1
) # Total: 520 parameters,
# plus 1 states.
julia> ps, st = Lux.setup(Random.default_rng(), model)
((weight_i = Float32[-0.10227254 -0.20976427 … 0.18935335 -0.15638316; 0.018118372 -0.13053711 … -0.0667165 0.118691884; 0.12346243 -0.16388968 … 0.13615233 0.038347457; 0.11975594 -0.17523423 … 0.013865251 0.19539973], weight_h = Float32[-0.29179388; -1.6530173; -0.99460226; -0.39993474;;], bias = Float32[0.0; 0.0; 1.0; 0.0;;]), (rng = Xoshiro(0x905ab0e7cf8ef38e, 0x03c6ff989c980d8a, 0x5fe68f1f16bcc9d1, 0x7998abd5cf6ae0ee),))
julia> model(x1,ps,st)
ERROR: MethodError: no method matching (::LSTMCell{true, false, false, Tuple{typeof(zeros32), typeof(zeros32), typeof(ones32), typeof(zeros32)}, NTuple{4, typeof(glorot_uniform)}, typeof(zeros32), typeof(zeros32)})(::SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}, ::NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, ::NamedTuple{(:rng,), Tuple{Xoshiro}})
Closest candidates are:
(::LSTMCell{true})(::Tuple{AbstractMatrix, Tuple{AbstractMatrix, AbstractMatrix}}, ::Any, ::NamedTuple)
@ Lux ~/.julia/packages/Lux/1Iulg/src/layers/recurrent.jl:436
(::LSTMCell{use_bias, false, false})(::AbstractMatrix, ::Any, ::NamedTuple) where use_bias
@ Lux ~/.julia/packages/Lux/1Iulg/src/layers/recurrent.jl:397
Stacktrace:
[1] apply(model::LSTMCell{true, false, false, Tuple{typeof(zeros32), typeof(zeros32), typeof(ones32), typeof(zeros32)}, NTuple{4, typeof(glorot_uniform)}, typeof(zeros32), typeof(zeros32)}, x::SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}, ps::NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, st::NamedTuple{(:rng,), Tuple{Xoshiro}})
@ LuxCore ~/.julia/packages/LuxCore/aumFq/src/LuxCore.jl:115
[2] (::Recurrence{false, LSTMCell{true, false, false, Tuple{typeof(zeros32), typeof(zeros32), typeof(ones32), typeof(zeros32)}, NTuple{4, typeof(glorot_uniform)}, typeof(zeros32), typeof(zeros32)}, BatchLastIndex})(x::Vector{SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}}, ps::NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, st::NamedTuple{(:rng,), Tuple{Xoshiro}})
@ Lux ~/.julia/packages/Lux/1Iulg/src/layers/recurrent.jl:78
[3] apply(model::Recurrence{false, LSTMCell{true, false, false, Tuple{typeof(zeros32), typeof(zeros32), typeof(ones32), typeof(zeros32)}, NTuple{4, typeof(glorot_uniform)}, typeof(zeros32), typeof(zeros32)}, BatchLastIndex}, x::Vector{SubArray{Float32, 1, Matrix{Float32}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}}, ps::NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, st::NamedTuple{(:rng,), Tuple{Xoshiro}})
@ LuxCore ~/.julia/packages/LuxCore/aumFq/src/LuxCore.jl:115
[4] (::Recurrence{false, LSTMCell{true, false, false, Tuple{typeof(zeros32), typeof(zeros32), typeof(ones32), typeof(zeros32)}, NTuple{4, typeof(glorot_uniform)}, typeof(zeros32), typeof(zeros32)}, BatchLastIndex})(x::Matrix{Float32}, ps::NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, st::NamedTuple{(:rng,), Tuple{Xoshiro}})
@ Lux ~/.julia/packages/Lux/1Iulg/src/layers/recurrent.jl:74
[5] top-level scope
@ REPL[15]:1
[6] top-level scope
@ ~/.julia/packages/Metal/lnkVP/src/initialization.jl:57