Turns out I was comparing apples to oranges… the correct code for Flux should be:
```julia
using Flux  # CUDA.jl must also be loaded for `gpu` to actually move to the device

function julia_nn(X, Y, epochs=100)
    lstm = Chain(
        LSTM(100 => 128),
        LSTM(128 => 128),
        Dense(128 => 1)
    ) |> gpu
    opt = ADAM()
    θ = Flux.params(lstm)
    for epoch ∈ 1:epochs
        Flux.reset!(lstm)  # clear the hidden state between epochs
        ∇ = gradient(θ) do
            [lstm(x) for x ∈ X[1:end-1]]           # warm up: run the sequence through
            Flux.Losses.mse(lstm(X[end]), Y[end])  # MSE on the last item only
        end
        Flux.update!(opt, θ, ∇)
    end
end
```
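For reference, a call would look roughly like this, with each element of `X` being a features × batch matrix (the sequence length, batch size, and random data below are placeholder assumptions, not my benchmark data):

```julia
# Hypothetical input: a 50-step sequence, 100 features, batch of 32
# (all sizes here are assumptions for illustration only).
X = [gpu(rand(Float32, 100, 32)) for _ in 1:50]
Y = [gpu(rand(Float32, 1, 32)) for _ in 1:50]
julia_nn(X, Y)
```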
The issue was that I had been computing the MSE over the full sequence instead of on the last item only. Fixing that brings the performance much closer to PyTorch, at approximately 630 ms average runtime.
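For contrast, the loss I was effectively benchmarking before looked roughly like this (a sketch reconstructed from the description above, not my exact original code); it differentiates through every timestep's output, which is the extra work the fixed version avoids:

```julia
using Flux

# Assumed reconstruction of the earlier full-sequence loss: every
# timestep contributes to the MSE, so the backward pass runs through
# all of the sequence's outputs instead of just the last one.
function full_sequence_loss(lstm, X, Y)
    sum(Flux.Losses.mse(lstm(x), y) for (x, y) ∈ zip(X, Y))
end
```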