# How to do batching in Flux's recurrent sequence model to take advantage of GPU during training?

Suppose I have a recurrent model in Flux, say

```julia
using Flux
D = 2
m = RNN(D, 3)
```

If I want to apply it to a sequence, the documentation says I should represent my sequence `seq` as a vector of vectors, so something like

```julia
T = 5
seq = [rand(D) for _ in 1:T]
y = m.(seq)
```

Now, how should I batch my sequences so that I can take advantage of GPU during training?
I'm guessing that if I have a batch of `N` sequences, each of length `T`, then I should represent them as a vector of `T` matrices of size `D×N`? So something like

```julia
N = 7
seq_batch = [rand(D, N) for _ in 1:T]
y_batch = m.(seq_batch)
```

My understanding is that this will create `N` hidden state vectors each independent of each other (and then doing `reset!(m)` will reset each one of them), which is what I’m looking for. Is my understanding correct?
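To actually run this on the GPU, both the model and each batch matrix need to be moved there. A minimal sketch of the idea (assuming a working CUDA setup; `gpu` is a no-op on a CPU-only machine, and exact API details vary across Flux versions):

```julia
using Flux

D, H, T, N = 2, 3, 5, 7
m = RNN(D, H) |> gpu                        # move the model; no-op without a GPU

# batch of N sequences of length T: one D×N matrix per time step
seq_batch = [rand(Float32, D, N) |> gpu for _ in 1:T]
y_batch = m.(seq_batch)                     # each element is an H×N matrix

Flux.reset!(m)                              # zeroes all N hidden-state columns at once
```

Each matrix-valued call advances all `N` hidden states in one matrix multiply, which is what makes the GPU worthwhile.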


You are right. If you examine the `m.state` variable after the model is applied to the batch, you will see that it is a `3×7` matrix:

```julia
julia> m.state
Tracked 3×7 Array{Float64,2}:
  0.18987    0.514859   0.0933913  -0.159167   -0.722453   -0.776108  -0.488948
 -0.673184  -0.821717  -0.166874   -0.0589437  -0.0631722   0.111586  -0.187056
 -0.746151  -0.775699  -0.757262   -0.740736   -0.425856   -0.575085  -0.231328

julia> size(m.state)
(3, 7)
```

`reset!` sets the state back to the original value, which is zero.

```julia
julia> Flux.reset!(m)

julia> m.state
3-element Array{Float32,1}:
 0.0
 0.0
 0.0
```
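In a training loop you would typically call `reset!` once per batch, before unrolling over the sequence, so that gradients for one batch don't depend on the state left over from the previous one. A hedged sketch (the loss function and data here are made up for illustration, and the optimiser API differs between Flux versions):

```julia
using Flux

D, H, T, N = 2, 3, 5, 7
m = RNN(D, H)

# made-up objective: mean squared error summed over the sequence
function seq_loss(xs, ys)
    Flux.reset!(m)                       # start each batch from a zero state
    sum(Flux.mse(m(x), y) for (x, y) in zip(xs, ys))
end

xs = [rand(Float32, D, N) for _ in 1:T]  # inputs: one D×N matrix per step
ys = [rand(Float32, H, N) for _ in 1:T]  # targets: one H×N matrix per step

ps = Flux.params(m)
gs = gradient(() -> seq_loss(xs, ys), ps)
Flux.Optimise.update!(Descent(0.01), ps, gs)
```

Because `reset!` happens inside the loss, each gradient call sees the full unroll from a zero state across all `N` sequences in the batch.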