I mean, this is already the case if you use model = Chain(...). There is no functionality in Flux for auto-calling reset!, so you will have to do it yourself at some point. That said, doing so is pretty straightforward:
using Flux

rnn = LSTM(10, 15)
fc = Dense(15, 5)

function model(seq)
    Flux.reset!(rnn)    # clear the hidden state before each new sequence
    x = rnn.(seq)[end]  # run every timestep; or use map, or just a loop
    return fc(x)        # classify using only the final hidden output
end

function loss(seq, y)
    y_hat = model(seq)
    return loss_func(y, y_hat)
end
Now you can use model without needing to reset manually or to put reset! into the loss function.
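For completeness, here is a minimal sketch of how that plugs into a training loop; the data layout (a vector of (seq, y) pairs) and the mean-squared-error loss_func are assumptions for illustration:

loss_func(y, y_hat) = Flux.Losses.mse(y_hat, y)  # assumed loss; use whatever fits your task

# Hypothetical data: 100 sequences of 20 timesteps, each timestep a
# length-10 input vector, with one length-5 target per sequence.
data = [([rand(Float32, 10) for _ in 1:20], rand(Float32, 5)) for _ in 1:100]

opt = ADAM()
ps = Flux.params(rnn, fc)
Flux.train!(loss, ps, data, opt)  # train! splats each (seq, y) pair into loss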
Have you actually tried calling the RNN with a minibatched input like you describe? It’s a little confusing and I think we could make it less so, but everything works as you’d expect:
julia> rnn = RNN(10, 3)
Recur(RNNCell(10, 3, tanh))
julia> rnn.state
3-element Vector{Float32}:
0.0
0.0
0.0
julia> rnn.cell.h
3-element Vector{Float32}:
0.0
0.0
0.0
julia> x = rand(Float32, 10, 8);
julia> rnn(x)
3×8 Matrix{Float32}:
0.121863 0.0712726 0.468342 0.0159795 -0.50595 0.217166 0.321759 0.0969098
0.78138 0.0184485 0.309471 -0.131435 -0.0146722 0.552875 0.227291 0.191328
0.938252 0.981406 0.826487 0.98748 0.974808 0.960942 0.963614 0.964724
julia> rnn.state
3×8 Matrix{Float32}:
0.121863 0.0712726 0.468342 0.0159795 -0.50595 0.217166 0.321759 0.0969098
0.78138 0.0184485 0.309471 -0.131435 -0.0146722 0.552875 0.227291 0.191328
0.938252 0.981406 0.826487 0.98748 0.974808 0.960942 0.963614 0.964724
julia> rnn.cell.h
3-element Vector{Float32}:
0.0
0.0
0.0
julia> Flux.reset!(rnn)
3-element Vector{Float32}:
0.0
0.0
0.0
julia> rnn.state
3-element Vector{Float32}:
0.0
0.0
0.0
As you can see, the hidden state is actually stored in Recur and not in the RNN cell (rnn.cell.h only holds the initial state). That hidden state does start off as a vector, but it will be overwritten with a matrix holding one column per sample if you pass in a minibatched input.
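To make that concrete, here is a short sketch (shapes assumed for illustration) of running a whole sequence of minibatched timesteps through the same rnn and then resetting:

# A sequence of 6 timesteps, each timestep a 10×8 minibatch (8 samples).
seq = [rand(Float32, 10, 8) for _ in 1:6]

Flux.reset!(rnn)             # rnn.state is the 3-element zero vector again
out = [rnn(x) for x in seq]  # after the first call, rnn.state is 3×8
size(out[end])               # (3, 8): one length-3 hidden state per sample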