Lux recurrent networks like LSTM - why are hidden state and memory not part of the model state `st`?

Lux.jl is special in that it carries an explicit state variable `st`, which it distinguishes from the layer output and from the trainable parameters `ps`.
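For context, this is the usual Lux.jl calling convention as I understand it (a minimal sketch, assuming Lux is installed):

```julia
using Lux, Random

rng = Random.default_rng()
model = Dense(2 => 3, tanh)

# Lux.setup splits a layer into trainable parameters `ps` and state `st`
ps, st = Lux.setup(rng, model)

x = randn(rng, Float32, 2, 5)   # features × batch
y, st_new = model(x, ps, st)    # every layer call returns (output, new_state)
```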

Hence I am currently confused by the implementation of LSTMCell:

function initialstates(rng::AbstractRNG, ::LSTMCell)
    # FIXME(@avik-pal): Take PRNGs seriously
    randn(rng, 1)
    return (rng=replicate(rng),)
end

function (lstm::LSTMCell{use_bias, false, false})(
        x::AbstractMatrix, ps, st::NamedTuple) where {use_bias}
    rng = replicate(st.rng)
    @set! st.rng = rng
    hidden_state = _init_hidden_state(rng, lstm, x)
    memory = _init_hidden_state(rng, lstm, x)
    return lstm((x, (hidden_state, memory)), ps, st)
end

Why aren’t hidden_state and memory part of st?

Since they are not, why is there an `st` parameter at all? Wouldn't it simplify the interface if it were also passed as an input argument, like `(hidden_state, memory)` here?

It would be great if someone could explain this apparent design clash.

  1. Placing hidden_state and memory inside st would make the dispatch clunky. Currently, dispatch is on the type of `x`, which is easy to understand (and maintain).
  2. The implementation that you are pointing to is of a *Cell, which is distinct from an RNN. For example, LSTM in this case is Recurrence(AbstractRecurrentCell(....)), and if you look at the implementation of Recurrence, you will see that it hides all of the memory and hidden-state handling from the end user.
  3. Gradients do propagate through the hidden_state and memory. You could place these in st, but typically st is reserved for non-trainable state and for things that don't propagate gradients (though the interface doesn't enforce the latter).
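To make point 2 concrete, here is a sketch of manually unrolling an LSTMCell over a sequence, threading the carry (hidden_state, memory) through the *input*, while `st` only holds the replicated RNG. Assumes Lux is installed; the call pattern follows the Lux docs:

```julia
using Lux, Random

rng = Random.default_rng()
cell = LSTMCell(4 => 8)
ps, st = Lux.setup(rng, cell)

xs = [randn(rng, Float32, 4, 16) for _ in 1:10]  # sequence of 10 steps

# The first call with a plain matrix initializes (hidden_state, memory),
# which is exactly the code path quoted above
(y, carry), st = cell(xs[1], ps, st)
for x in xs[2:end]
    # Subsequent calls take the carry as part of the input, not of `st`
    (y, carry), st = cell((x, carry), ps, st)
end
```

Recurrence(LSTMCell(4 => 8)) performs essentially this loop internally, which is why the end user never has to touch the carry.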

Thank you. A follow-up question: besides the random number generator, what are other typical uses of `st`?

  • train/test mode flags
  • statistics tracking – Normalization Layers
  • Passing around complete solutions – often needed for DEQs, NeuralODEs
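The first two bullets can be seen together in a normalization layer, which keeps its running statistics and a train/test mode flag in `st`. A minimal sketch, assuming Lux is installed:

```julia
using Lux, Random

rng = Random.default_rng()
bn = BatchNorm(3)
ps, st = Lux.setup(rng, bn)

x = randn(rng, Float32, 3, 32)

st_train = Lux.trainmode(st)      # flip the mode flag stored in `st`
y, st2 = bn(x, ps, st_train)      # running mean/var in `st` are updated

st_test = Lux.testmode(st2)       # inference uses the accumulated statistics
y_test, _ = bn(x, ps, st_test)
```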