Building simple sequence-to-one RNN with Flux

I mean, this is already the case if you use model = Chain(...). There is no functionality in Flux for auto-calling reset!, so you will have to do it yourself at some point. That said, doing so is pretty straightforward:

using Flux

rnn = LSTM(10, 15)   # stateful recurrent layer: 10 features in, 15 hidden
fc  = Dense(15, 5)   # maps the last hidden state to 5 outputs

function model(seq)
  Flux.reset!(rnn)        # clear the hidden state before each new sequence
  x = rnn.(seq)[end]      # or use map, or just a loop; keep only the last output
  return fc(x)
end

function loss(seq, y, args...)
  y_hat = model(seq)
  return loss_func(y, y_hat)   # loss_func = whatever criterion you use
end

Now you can use model without needing to reset manually or put reset! into the loss function.
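
For completeness, here is one way this could be wired into a training loop. The data layout (seqs as a collection of per-timestep input vectors, with matching targets), the plain squared-error loss, and the ADAM optimiser are just assumptions for illustration, not the only option:

opt = ADAM(1e-3)                # optimiser choice is illustrative
ps  = Flux.params(rnn, fc)      # parameters of both layers

# seqs[i] is a Vector of per-timestep inputs, targets[i] the target for that sequence (assumed layout)
for (seq, y) in zip(seqs, targets)
    gs = Flux.gradient(ps) do
        y_hat = model(seq)                   # model calls reset! internally
        sum(abs2, y_hat .- y) / length(y)    # plain mean squared error
    end
    Flux.update!(opt, ps, gs)
end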

Have you actually tried calling the RNN with a minibatched input like you describe? It’s a little confusing and I think we could make it less so, but everything works as you’d expect:

julia> rnn = RNN(10, 3)
Recur(RNNCell(10, 3, tanh))

julia> rnn.state
3-element Vector{Float32}:
 0.0
 0.0
 0.0

julia> rnn.cell.h
3-element Vector{Float32}:
 0.0
 0.0
 0.0

julia> x = rand(Float32, 10, 8);

julia> rnn(x)
3×8 Matrix{Float32}:
 0.121863  0.0712726  0.468342   0.0159795  -0.50595    0.217166  0.321759  0.0969098
 0.78138   0.0184485  0.309471  -0.131435   -0.0146722  0.552875  0.227291  0.191328
 0.938252  0.981406   0.826487   0.98748     0.974808   0.960942  0.963614  0.964724

julia> rnn.state
3×8 Matrix{Float32}:
 0.121863  0.0712726  0.468342   0.0159795  -0.50595    0.217166  0.321759  0.0969098
 0.78138   0.0184485  0.309471  -0.131435   -0.0146722  0.552875  0.227291  0.191328
 0.938252  0.981406   0.826487   0.98748     0.974808   0.960942  0.963614  0.964724

julia> rnn.cell.h
3-element Vector{Float32}:
 0.0
 0.0
 0.0

julia> Flux.reset!(rnn)
3-element Vector{Float32}:
 0.0
 0.0
 0.0

julia> rnn.state
3-element Vector{Float32}:
 0.0
 0.0
 0.0

As you can see, the hidden state is actually stored in the Recur wrapper and not in the RNN cell. That hidden state starts off as a vector, but it is overwritten with a matrix that has one column per sample as soon as you pass a minibatched input.
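
So, going back to the LSTM + Dense model sketched at the top, the same sequence-to-one setup works on minibatches with no changes. A minimal sketch, where the sequence length of 20 and batch size of 8 are just example sizes:

seq = [rand(Float32, 10, 8) for _ in 1:20]  # 20 time steps, each a 10×8 minibatch
y_hat = model(seq)                          # 5×8 Matrix{Float32}: one prediction column per sample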
