This is the training function I ended up with:
function seq_batch_train!(loss, ps, data, opt; cb = () -> ())
    local training_loss
    cb = Flux.Optimise.runall(cb)
    x, y = data
    x = x |> gpu
    y = y |> gpu
    gs = Flux.gradient(ps) do
        training_loss = loss(x, y)
        return training_loss
    end
    @show(training_loss)
    Flux.Optimise.update!(opt, ps, gs)
    cb()
end
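The `local training_loss` declaration is what lets the loss computed inside the `Flux.gradient` do-block escape to the enclosing function. Here is a minimal CPU-only sketch of that capture pattern, with a plain closure standing in for `Flux.gradient` (the names `capture_demo` and `run_block` are illustrative, not from the original code):

```julia
# Sketch of the `local` capture pattern used in seq_batch_train! above.
# `run_block` stands in for `Flux.gradient(ps) do ... end`; the do-block
# assigns to a variable declared `local` in the enclosing scope.
function capture_demo()
    local captured            # declared outside the do-block, like training_loss
    run_block(f) = f()        # stand-in for Flux.gradient
    run_block() do
        captured = 2 + 3      # like `training_loss = loss(x, y)`
    end
    captured                  # still visible here, thanks to the closure capture
end
```

Without the `local` declaration, the assignment inside the do-block would create a variable scoped to the closure, and the loss value would be lost after `Flux.gradient` returns.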
And here is the training loop over epochs:
using BSON: @save
using Dates: now

for epoch in 1:1
    @show(epoch)
    for batch in train_loader
        seq_batch_train!(loss, ps, batch, opt, cb = evalcb)
    end
    # Checkpoint the model and optimiser state after every epoch.
    @save "model_$(now())_epoch-$epoch.bson" m opt
end
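The checkpoint written by `@save` can be read back with BSON's `@load`. A small round-trip sketch, assuming the BSON package is available (the `Dict` and the file name here are stand-ins for the real model `m` and the timestamped name used above):

```julia
using BSON: @save, @load

# Round-trip sketch of the checkpointing above. The Dict stands in for
# the real model; the path is illustrative.
path = joinpath(mktempdir(), "model_epoch-1.bson")
state = Dict(:w => [1.0, 2.0, 3.0])
@save path state      # writes the binding `state` into the .bson file
@load path state      # rebinds `state` from the file's contents
```

To resume training, `@load "model_....bson" m opt` restores both the model and the optimiser under their original names.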
The trick is in the loss function:
function loss(x, y)
    Flux.reset!(m)  # clear the recurrent state before each sequence
    # Accumulate the squared error one time step at a time, then average.
    total = [(m(x[:, xi, :]) - y[:, xi, :]).^2 for xi in axes(x, 2)] |> sum |> sum
    total / length(x)
end
hcat() is very slow on the GPU, so the MSE is calculated manually instead of using Flux.mse.
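To see that this per-time-step accumulation really is the MSE, here is a small CPU-only sketch with toy arrays standing in for the model output and the targets (no GPU or Flux involved; all names here are illustrative):

```julia
# Toy data: features × time × batch, standing in for model output and targets.
yhat = rand(Float32, 4, 5, 3)
y    = rand(Float32, 4, 5, 3)

# Per-time-step accumulation, as in the loss function above.
per_step   = [sum((yhat[:, t, :] .- y[:, t, :]).^2) for t in axes(yhat, 2)] |> sum
manual_mse = per_step / length(yhat)

# Plain MSE over the whole 3-D array, which is what Flux.mse would compute.
whole_mse = sum((yhat .- y).^2) / length(yhat)
```

Summing the squared errors slice by slice and dividing by `length(x)` gives the same mean as one big element-wise operation, so no `hcat` of the per-step outputs is needed.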
Currently a batch size of 10 uses about 8 GB of GPU memory. I can see some tweaks to get it to work with longer time series.