How do I batch sequences for a recurrent model in Flux to take advantage of the GPU during training?

Suppose I have a recurrent model in Flux, say

using Flux
D = 2
m = RNN(D,3)

If I want to apply it to a sequence, the documentation says I should represent my sequence seq as a vector of vectors, so something like

T = 5
seq = [rand(D) for _ in 1:T]
y = m.(seq)

Now, how should I batch my sequences so that I can take advantage of the GPU during training?
I’m guessing that if I have a batch of N sequences, each of length T, I should represent them as a vector of D×N matrices? So something like

N = 7
seq_batch = [rand(D,N) for _ in 1:T]
y_batch = m.(seq_batch)

My understanding is that this will create N hidden state vectors each independent of each other (and then doing reset!(m) will reset each one of them), which is what I’m looking for. Is my understanding correct?

You are right. If you examine m.state after the model has been applied to the batch, you will see that it is a 3×7 matrix, with one column of hidden state per sequence in the batch:

julia> m.state
Tracked 3×7 Array{Float64,2}:
  0.18987    0.514859   0.0933913  -0.159167   -0.722453   -0.776108  -0.488948
 -0.673184  -0.821717  -0.166874   -0.0589437  -0.0631722   0.111586  -0.187056
 -0.746151  -0.775699  -0.757262   -0.740736   -0.425856   -0.575085  -0.231328

julia> size(m.state)
(3, 7)
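
To see why the columns stay independent of each other, here is a minimal sketch of an Elman-style recurrent step written in plain Julia. It is a hand-written stand-in, not Flux's implementation, and the names `Wx`, `Wh`, `b`, `rnn_step` and `run_rnn` are made up for illustration. Running the whole batch as D×N matrices gives the same final state as running each sequence on its own:

```julia
# Hand-written Elman-style cell (an illustrative stand-in, not Flux's RNN).
D, H, N, T = 2, 3, 7, 5

Wx = randn(H, D)   # input-to-hidden weights
Wh = randn(H, H)   # hidden-to-hidden weights
b  = randn(H)      # bias

# One recurrent step; `h` and `x` may be vectors, or matrices with one column per sequence.
rnn_step(h, x) = tanh.(Wx * x .+ Wh * h .+ b)

# Run a whole sequence starting from a zero state and return the final state.
function run_rnn(xs)
    h = zeros(H)              # a vector state broadcasts across the batch columns
    for x in xs
        h = rnn_step(h, x)
    end
    return h
end

seq_batch = [rand(D, N) for _ in 1:T]   # T matrices of size D×N
h_batch = run_rnn(seq_batch)            # H×N matrix, one column per sequence

# The same computation, one sequence (column) at a time.
h_single = [run_rnn([x[:, n] for x in seq_batch]) for n in 1:N]

size(h_batch)                                   # (3, 7)
all(h_batch[:, n] ≈ h_single[n] for n in 1:N)   # true
```

Each column only ever interacts with the weights, never with its neighbours, which is why batching along the second dimension is safe.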

reset! sets the state back to its initial value, which is zero. Note that the state is then a 3-element vector again rather than a 3×7 matrix:

julia> Flux.reset!(m)

julia> m.state
3-element Array{Float32,1}:
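
If the data starts out as N separate sequences rather than already batched, the layout from the question (a vector of T matrices, each D×N) can be assembled by stacking the t-th step of every sequence as columns. A small sketch in plain Julia, assuming all sequences have the same length; the helper name `batch_sequences` is hypothetical and not part of Flux:

```julia
# Hypothetical helper: turn N sequences (each a vector of T vectors of length D)
# into T matrices of size D×N, with one column per sequence.
function batch_sequences(seqs)
    T = length(first(seqs))
    @assert all(length(s) == T for s in seqs) "all sequences must have the same length"
    return [reduce(hcat, [s[t] for s in seqs]) for t in 1:T]
end

D, T, N = 2, 5, 7
seqs = [[rand(D) for _ in 1:T] for _ in 1:N]   # N independent sequences
seq_batch = batch_sequences(seqs)

length(seq_batch)                 # 5
size(seq_batch[1])                # (2, 7)
seq_batch[3][:, 4] == seqs[4][3]  # true
```

The resulting `seq_batch` can be passed to the model with `m.(seq_batch)` as in the question. For GPU training, the model and each matrix would first be moved to the device, for example with Flux's `gpu` function.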