Slow LSTM on GPU in Flux

Turns out I was comparing apples to oranges… the correct code for Flux should be:

using Flux

function julia_nn(X, Y, epochs=100)
    lstm = Chain(
        LSTM(100 => 128),
        LSTM(128 => 128),
        Dense(128 => 1)
    ) |> gpu

    opt = ADAM()
    θ = Flux.params(lstm)
    for epoch ∈ 1:epochs
        Flux.reset!(lstm) # Reset the hidden state before each pass over the sequence
        ∇ = gradient(θ) do
            [lstm(x) for x ∈ X[1:end-1]] # Warm up: run the earlier timesteps to build up hidden state
            Flux.Losses.mse(lstm(X[end]), Y[end]) # MSE on the last item only
        end
        Flux.update!(opt, θ, ∇)
    end
end
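For completeness, here is a minimal sketch of how julia_nn might be called. The data layout (a sequence of 100-feature Float32 matrices, moved to the GPU to match the model) is an assumption for illustration, not something from the original benchmark:

# Hypothetical dummy data: seq_len timesteps of (features × batch) matrices.
seq_len, batch = 20, 32
X = [rand(Float32, 100, batch) |> gpu for _ ∈ 1:seq_len]
Y = [rand(Float32, 1, batch) |> gpu for _ ∈ 1:seq_len]
julia_nn(X, Y)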

Previously I was computing the MSE over the full sequence instead of on the last item only, which is what made the comparison unfair. With that fixed, performance gets much closer to PyTorch, at approximately 630 ms on average.
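For contrast, the earlier apples-to-oranges loss would have looked something like the following inside the gradient closure; this is a reconstruction of the slower variant, not the exact original code:

∇ = gradient(θ) do
    # Accumulate the MSE at every timestep instead of only the last one,
    # which backpropagates through the outputs of the whole sequence.
    sum(Flux.Losses.mse(lstm(x), y) for (x, y) ∈ zip(X, Y))
end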
