Variable sequence length RNN in Flux

How can I implement an RNN for variable-sequence-length data with minibatching in Flux?
According to the Flux documentation,

"In Flux, those 3 dimensions are provided through a vector of seq length containing a matrix (features, samples)."
Recurrence · Flux

but it’s impossible to create such a vector of matrices when the sequences have different lengths.
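
To make the layout in the quoted sentence concrete, here is a minimal sketch in plain Julia (no Flux needed). `to_time_major` is a hypothetical helper written for this illustration, not a Flux function. It assumes each sample is a (features × seq_length) matrix and shows that the stacking only works when all lengths agree:

```julia
# Hypothetical helper (not part of Flux): convert equal-length sequences,
# each a (features × seq_length) matrix, into a vector of length seq_length
# whose t-th entry is a (features × samples) matrix.
function to_time_major(seqs::Vector{<:AbstractMatrix})
    lens = size.(seqs, 2)
    all(==(first(lens)), lens) ||
        throw(ArgumentError("sequences have different lengths: $(sort(unique(lens)))"))
    return [reduce(hcat, [s[:, t] for s in seqs]) for t in 1:first(lens)]
end

# Three samples with 4 features and 5 time steps each: stacking works.
same_len = [rand(Float32, 4, 5) for _ in 1:3]
steps = to_time_major(same_len)
@show length(steps) size(steps[1])

# Three samples with different lengths: there is no common time axis to stack on.
ragged = [rand(Float32, 4, n) for n in (3, 5, 6)]
ragged_failed = try
    to_time_major(ragged)
    false
catch err
    println("cannot stack: ", err.msg)
    true
end
```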

For example, I’m assuming the following situation:

using Flux

dim_feature = 4
dim_sample = 100
min_seq_length = 3
max_seq_length = 6

dim_output = 2

x_seq_length = rand(min_seq_length:max_seq_length, dim_sample)

# input dataset: one (features × seq_length) matrix per sample
x = Matrix{Float32}[]
for i = 1:dim_sample
    push!(x, rand(Float32, dim_feature, x_seq_length[i]))
end

network = LSTM(dim_feature, dim_output)

# Goal: feed minibatches of the input dataset `x` to `network` for minibatch SGD

I haven’t used the recurrent cells in Flux before, so I’m not sure this is the best way of doing things, but judging by the plotted loss it does at least seem to learn, and it can handle batches containing sequences of different lengths.

I’m a bit unsure about calling reset! inside the loss function. There was a recent post about something similar where, as I remember it, this was suggested not to be good practice, but I can’t see why it would be a problem.
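
To illustrate what the reset is for, here is a toy sketch in plain Julia. `ToyRecur` and `toy_reset!` are hypothetical stand-ins invented for this example and are not Flux's API; the only assumption carried over is that a stateful recurrent layer keeps its hidden state between calls until it is reset:

```julia
# Toy stand-in for a stateful recurrent layer (hypothetical, NOT Flux's API):
# the hidden state is simply a running sum of every input seen so far.
mutable struct ToyRecur
    state::Float32
end
(m::ToyRecur)(x) = (m.state += x; m.state)     # process one time step
toy_reset!(m::ToyRecur) = (m.state = 0f0; m)   # return to the initial state

# Run a whole sequence step by step and return the output of the last step.
last_output(m, seq) = last([m(x) for x in seq])

seq_a = Float32[1, 2, 3]
seq_b = Float32[10, 20]

m = ToyRecur(0f0)
last_output(m, seq_a)
leaky = last_output(m, seq_b)   # state left over from seq_a leaks in: 6 + 30

toy_reset!(m)
clean = last_output(m, seq_b)   # depends on seq_b only: 30

@show leaky clean
```

Without the reset, the result for the second sequence depends on whichever sequence happened to be processed before it, which is not wanted when the sequences in a batch are independent samples.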

using Flux, Plots

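function loss(m, xs, ys)
    total = 0f0
    for (x, y) in zip(xs, ys)
        Flux.reset!(m) # Reset the state of the recurrent cell for each new sequence
        # With the stateful recurrent API used here, a matrix input is read as
        # (features, samples), so feed one time step (one column) at a time
        # and keep the output of the last step
        out = [m(x[:, t]) for t in axes(x, 2)][end]
        total += sum(abs2, out .- y) # squared error against the target
    end
    total
end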

dim_feature = 4
dim_sample = 100
min_seq_length = 3
max_seq_length = 6
dim_output = 2

x_seq_length = rand(min_seq_length:max_seq_length, dim_sample)
x = rand.(Float32, dim_feature, x_seq_length)
y = [rand(Float32, dim_output) for _ in 1:dim_sample]

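m = LSTM(dim_feature, dim_output)

# Each minibatch is a tuple (vector of 4 sequences, vector of 4 targets)
data = Flux.DataLoader((x, y), batchsize=4)
opt = ADAM()
losses = [loss(m, x, y)] # loss over the full dataset before training
Flux.@epochs 40 begin
    Flux.train!((x, y) -> loss(m, x, y), Flux.params(m), data, opt)
    push!(losses, loss(m, x, y)) # loss over the full dataset after each epoch
end
plot(losses)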
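
The loop above processes one sequence at a time. A different, commonly used technique is to zero-pad every sequence in a minibatch to a common length and keep a mask of the real time steps, which yields exactly the vector of (features, samples) matrices described in the docs. The following is only a sketch in plain Julia: `pad_batch` is a hypothetical helper, not a Flux function, and it assumes each sample is a (features × seq_length) matrix.

```julia
# Hypothetical helper (not part of Flux): zero-pad variable-length sequences
# to the longest length in the batch. Returns
#   steps: vector over time steps of (features × samples) matrices
#   mask:  (max_len × samples) Bool matrix, true where a step holds real data
function pad_batch(seqs::Vector{<:AbstractMatrix{T}}) where {T}
    nfeat   = size(first(seqs), 1)
    lens    = size.(seqs, 2)
    max_len = maximum(lens)
    steps = [zeros(T, nfeat, length(seqs)) for _ in 1:max_len]
    for (j, s) in enumerate(seqs), t in 1:lens[j]
        steps[t][:, j] = s[:, t]   # copy real data; padded entries stay zero
    end
    mask = [t <= lens[j] for t in 1:max_len, j in eachindex(seqs)]
    return steps, mask
end

batch = [rand(Float32, 4, n) for n in (3, 6, 4)]
steps, mask = pad_batch(batch)
@show length(steps) size(steps[1])
@show vec(sum(mask, dims=1))   # recovers the original sequence lengths
```

With such a mask, each sample's output would be read at its own last real time step rather than at the padded end. How best to wire that into the training loop depends on the Flux version in use, so treat this as a starting point rather than a drop-in replacement.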