Lux.jl is special in that it has the extra `st` state variable, which it distinguishes from the output and from the other, fixed parameters `ps`.
Hence, I am currently confused about the implementation of the LSTM:
# Initial (non-trainable) state of an `LSTMCell`.
#
# The returned NamedTuple has a single field, `rng`, produced by `replicate(rng)`.
# The recurrent carry (hidden state and memory) is not stored in this state.
function initialstates(rng::AbstractRNG, ::LSTMCell)
    # FIXME(@avik-pal): Take PRNGs seriously
    # Consume one draw from the caller's `rng` before replicating it.
    # The sample value itself is discarded.
    randn(rng, 1)
    stored_rng = replicate(rng)  # NOTE(review): presumably an independent copy; see Lux's `replicate`
    return (; rng = stored_rng)
end
# Call an `LSTMCell` on input `x` alone, i.e. without an explicit carry.
#
# This method is selected for `LSTMCell{use_bias, false, false}`. The meaning of the
# two `false` type parameters is defined on `LSTMCell` elsewhere; TODO confirm.
#
# It replaces `st.rng` with `replicate(st.rng)` and builds a fresh
# `(hidden_state, memory)` pair with `_init_hidden_state`, using that replicated RNG.
# It then delegates to the method taking `(x, (hidden_state, memory))`, forwarding
# the updated `st`.
function (lstm::LSTMCell{use_bias, false, false})(
        x::AbstractMatrix, ps, st::NamedTuple) where {use_bias}
    fresh_rng = replicate(st.rng)
    @set! st.rng = fresh_rng
    # Both carry components are initialised from the same replicated RNG,
    # hidden state first, then memory.
    h0 = _init_hidden_state(fresh_rng, lstm, x)
    c0 = _init_hidden_state(fresh_rng, lstm, x)
    carry = (h0, c0)
    return lstm((x, carry), ps, st)
end
Why aren’t `hidden_state` and `memory` part of `st`?
Since they are not, why is there an `st` parameter at all? Wouldn’t it simplify the interface if the state were also passed as an input argument, the way `(hidden_state, memory)` is passed here?
It would be great if someone could explain the design clash I perceive here.