I haven’t used the recurrent cells in Flux before, so I’m not sure this is the best way of doing things, but it does seem to learn, at least judging by the loss curve, and it can handle batches containing sequences of different lengths.
I’m a bit unsure about the reset! in the loss function: there was recently a post about something similar where, as I remember, it was suggested that this might not be good practice, but I can’t see why it would be a problem.
using Flux, Plots
function loss(m, xs, ys)
    total = 0f0 # avoid shadowing the function name `loss` with a local variable
    for (x, y) in zip(xs, ys)
        Flux.reset!(m) # reset the hidden state of the recurrent cell for each new sequence
        out = [m(xt) for xt in eachcol(x)][end] # feed one timestep (column) at a time; keep the last output
        total += sum(abs2.(out - y)) # squared error against the target (abs2, not exp2)
    end
    total
end
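As a small sanity check on the reset! question, here is a minimal sketch (model and dimensions are arbitrary, just for illustration) showing that Flux.reset! puts the cell back in its initial state, so the first call after a reset reproduces the original output exactly, which is why calling it once per sequence keeps the sequences independent:

```julia
using Flux

m = LSTM(4, 2)     # small LSTM, dimensions chosen arbitrarily
x = rand(Float32, 4)

y1 = m(x)          # first call advances the hidden state
y2 = m(x)          # same input, updated state, so the output usually differs
Flux.reset!(m)     # back to the initial state
y3 = m(x)          # reproduces the first output exactly
y1 == y3           # true: the forward pass is deterministic given the state
```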
dim_feature = 4
dim_sample = 100
min_seq_length = 3
max_seq_length = 6
dim_output = 2
x_seq_length = rand(min_seq_length:max_seq_length, dim_sample) # a random length for each sequence
x = rand.(Float32, dim_feature, x_seq_length) # each sample is a dim_feature × seq_length matrix
y = [rand(Float32, dim_output) for _ in 1:dim_sample] # Float32 targets to match the model
m = LSTM(dim_feature, dim_output)
data = Flux.DataLoader((x, y), batchsize=4)
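For what it’s worth, this is roughly what the loader yields (the concrete lengths below are just an illustration): each batch is a tuple of a vector of input matrices, possibly of different widths, plus the matching targets, which is why the loss loops over zip(xs, ys) instead of assuming a rectangular array:

```julia
using Flux

xs = [rand(Float32, 4, n) for n in (3, 6, 4, 5)]  # four sequences, different lengths
ys = [rand(Float32, 2) for _ in 1:4]
loader = Flux.DataLoader((xs, ys), batchsize=2)   # no shuffle, so batches come in order

for (xb, yb) in loader
    println(size.(xb))  # prints ((4, 3), (4, 6)) then ((4, 4), (4, 5))
end
```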
opt = ADAM()
losses = [loss(m, x, y)]
Flux.@epochs 40 begin
    Flux.train!((x, y) -> loss(m, x, y), Flux.params(m), data, opt)
    push!(losses, loss(m, x, y)) # record the full-dataset loss after each epoch
end
plot(losses)