Slow LSTM on GPU in Flux

Hi,

I’ve been running into an issue where training a neural network with Flux is much slower on the GPU than on the CPU. To investigate, I compared an MWE using Flux against an MWE that runs the same model in PyTorch via PyCall.

The PyTorch version runs in approximately 500 ms on average, while the Flux version is four times slower, averaging 2 seconds.

The code I used is the following:

# Necessary packages
using PyCall
using BenchmarkTools
using Flux
using StatsBase

# ===== PyTorch example =====
py"""
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__() # Base class constructor
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2, 
            batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        h = self.lstm(x)[0]
        return self.linear(h)

def train_network(X, Y, epochs=100):
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    lstm = LSTM(100, 128, 1).to(dev)
    X, Y = torch.from_numpy(X).to(dev), torch.from_numpy(Y).to(dev)
    lstm.train()
    criterion = nn.MSELoss()
    opt = torch.optim.Adam(lstm.parameters())
    for epoch in range(epochs):
        output = lstm(X)
        output = output[:, -1, :]
        loss = criterion(output, Y)
        opt.zero_grad()
        loss.backward()
        opt.step()
"""

train_pytorch = py"train_network"

X = rand(Float32, 1_000, 20, 100);
Y = rand(Float32, 1_000, 1);
@benchmark train_pytorch(X, Y) # Roughly 500 ms


# ===== Flux example =====

function julia_nn(X, Y, epochs=100)
    lstm = Chain(
        LSTM(100 => 128),
        LSTM(128 => 128),
        Dense(128 => 1)
    ) |> gpu

    opt = ADAM()
    θ = Flux.params(lstm)
    for epoch ∈ 1:epochs
        Flux.reset!(lstm)
        ∇ = gradient(θ) do 
            lstm(X[1]) # Warm up
            Flux.Losses.mse.([lstm(x) for x ∈ X[2:end]], Y[2:end]) |> mean
        end
        Flux.update!(opt, θ, ∇)
    end
end

X = gpu.([rand(Float32, 100, 1_000) for _ ∈ 1:20]);
Y = gpu.([rand(Float32, 1, 1_000) for _ ∈ 1:20]);

@benchmark julia_nn(X, Y) # Roughly 2 sec

As you can see, even though the features and targets in the Flux example are moved to the GPU outside of the benchmark (which is not the case for the PyTorch version), the Flux network still trains very slowly.

Does anybody have an idea as to how one can improve the Flux example to reach the same speed as PyTorch? I am using CUDA Toolkit 11.6 with an RTX 3090.
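
For completeness, this is roughly how I verify that CUDA.jl actually sees the card (just the standard CUDA.jl queries):

using CUDA
CUDA.functional()    # returns true when the GPU is usable from Julia
CUDA.versioninfo()   # prints the CUDA toolkit/driver versions and the available devices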

As an aside: in general, training networks on the GPU with Flux is only faster than training them on the CPU when the network has many parameters or there are a lot of features. In contrast, any network I run in PyTorch is massively sped up on the GPU. I am grateful for any hints or suggestions on how to improve the performance of my Flux models.


Turns out I was comparing apples to oranges… the correct code for Flux should be:

function julia_nn(X, Y, epochs=100)
    lstm = Chain(
        LSTM(100 => 128),
        LSTM(128 => 128),
        Dense(128 => 1)
    ) |> gpu

    opt = ADAM()
    θ = Flux.params(lstm)
    for epoch ∈ 1:epochs
        Flux.reset!(lstm)
        ∇ = gradient(θ) do 
            [lstm(x) for x ∈ X[1:end-1]] # Advance the hidden state through all but the last step
            Flux.Losses.mse(lstm(X[end]), Y[end]) # MSE on last item only
        end
        Flux.update!(opt, θ, ∇)
    end
end

In the original version I was computing the MSE on the full sequence instead of on the last item only. Fixing this brings the performance much closer to PyTorch, at approximately 630 ms on average.


What about CPU performance compared to PyTorch?

Hmm, not great actually.

According to the benchmark, PyTorch runs the above code in ~11.5 seconds on my machine, while the Flux version (the corrected version) runs in ~30 seconds. This is a very large discrepancy.

Are you measuring the entire script end-to-end? I suspect that compiling the model will take at least 15-20 seconds with Flux, so if you want a pure runtime estimate I’d recommend calling gradient once as a “warm-up” using some small, dummy data (e.g. random batch of 1) before your training loop. PyTorch startup latency is usually negligible, but it wouldn’t hurt to isolate the training loop timing there from imports and model instantiation as well.


Not the entire script, but the function julia_nn in which the network is indeed being created. I will look into this and update you with the results. Might take me a few days due to other more pressing problems. Thank you for the hints.

Right, so everything up to and including the first run of the loop will incur compilation latency. If you can separate that code out and run one loop iteration before the main training loop (you can use a throwaway model and data for this), then you’ll have isolated just the runtime bits.
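
Something like this before calling julia_nn should do it (untested sketch that just reuses your layer sizes; drop the gpu calls if you’re timing the CPU path):

warmup_model = Chain(LSTM(100 => 128), LSTM(128 => 128), Dense(128 => 1)) |> gpu
warmup_X = gpu.([rand(Float32, 100, 1) for _ ∈ 1:2])  # throwaway sequence, batch of 1
warmup_Y = gpu(rand(Float32, 1, 1))
θ = Flux.params(warmup_model)
∇ = gradient(θ) do
    warmup_model(warmup_X[1])                               # step the recurrent state once
    Flux.Losses.mse(warmup_model(warmup_X[end]), warmup_Y)  # compile the loss path
end
Flux.update!(ADAM(), θ, ∇)  # also compiles the optimiser update
# Now only the actual runtime is measured:
@time julia_nn(X, Y)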

So, if I just “warm up” the model before the function, I end up with ~27.5 seconds on the CPU.

If I pull everything out of the timed part (initialization of the optimizer, extraction of the parameters, a first run of the loop with gradient computation and update), I unfortunately still end up with roughly the same time.

Can you try adding MKL.jl (JuliaLinearAlgebra/MKL.jl on GitHub, the Intel MKL linear algebra backend for Julia) to your project and importing it at the top of the script (see the README for how to check it’s loaded)? 11.5 s already seems way too high for running an RNN over a single batch of data, so perhaps I’m missing something there too.
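
For reference, checking that MKL is actually active should look something like this (going from memory of the README):

using MKL, LinearAlgebra
BLAS.get_config()  # should report libmkl_rt rather than libopenblas once MKL.jl is loaded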

So I tried MKL.jl out. It only slightly speeds up the Flux version (~23 seconds). The PyTorch version, on the other hand, instantly crashes with MKL, giving the following error message:

┌ Warning: both Julia and NumPy are linked with MKL, which may cause conflicts and crashes (#433).
└ @ PyCall ~/.julia/packages/PyCall/7a7w0/src/numpy.jl:77

signal (11): Segmentation fault
in expression starting at REPL[14]:1
Allocations: 28076327 (Pool: 28069689; Big: 6638); GC: 21
Segmentation fault (core dumped)

Concerning the comment about a single batch of data: note that I am running 100 epochs with a batch size of 1,000 and 100 features per observation. Perhaps I did not make this explicit enough.

Are the PyTorch and Flux code not separated into different files? I believe PyTorch already uses MKL on the CPU, so I could see how loading another copy could mess things up.
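
To double-check what your PyTorch build links against, something like this (untested) should print its build configuration:

using PyCall
torch_config = pyimport("torch.__config__")  # torch must be importable in the PyCall environment
println(torch_config.show())                 # lists the BLAS/MKL settings PyTorch was compiled with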

Another reason I ask about separating out the code is that you may be running into a GC bottleneck. If # ===== Flux example ===== and # ===== PyTorch example ===== are already in separate files, then disregard this. Another thing to note: if you’re using @benchmark to time, I believe that messes with GC a bit, so I’d recommend a plain @time instead.

They were not separated, but I was only running parts of the script in fresh Julia sessions, so I believe it should amount to the same thing. In any case, I went ahead and created two separate scripts and also tried both @time and @benchmark. The results do not really change for @time:

Flux: 25.607051 seconds (1.25 M allocations: 49.215 GiB, 50.91% gc time, 7.06% compilation time)
PyTorch: 12.022518 seconds (1.44 M allocations: 77.844 MiB, 0.22% gc time, 2.91% compilation time)

When using @benchmark, however, I now run into errors, e.g. telling me that X is not defined or lstm is not defined. I’m not sure how @benchmark works, but I’m guessing these new problems come from it, since they don’t happen when using @time.
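
Maybe the globals need to be interpolated with $, as the BenchmarkTools docs suggest? I.e. something like:

@benchmark julia_nn($X, $Y)  # interpolate the globals so the benchmarked expression doesn't look them up by name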

I have created a gist with both files, which I then simply run from the console with julia mwe_flux.jl. Strangely enough, the PyTorch example also results in an error (though it does run and print the ~12 seconds first), even when using only @time. Perhaps I’m doing something wrong in the way my scripts are run.


Thanks. I pulled the Python script out into its own file to remove Julia as a factor in the timing, and I can’t replicate that 12 s result. This makes me suspect @time is succeeding in your version despite an error being thrown.

Full Python source:
import torch
import torch.nn as nn

from timeit import default_timer
from torch.profiler import profile, record_function, ProfilerActivity

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__() # Base class constructor
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2,
            batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h = self.lstm(x)[0]
        return self.linear(h)

def train_network(X, Y, epochs=100):
    dev = torch.device("cpu") # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    lstm = LSTM(100, 128, 1).to(dev)
    X, Y = X.to(dev), Y.to(dev)
    lstm.train()
    criterion = nn.MSELoss()
    opt = torch.optim.Adam(lstm.parameters())

    start = default_timer()
    for epoch in range(epochs):
        output = lstm(X)
        output = output[:, -1, :]
        opt.zero_grad()
        loss = criterion(output, Y)
        loss.backward()
        opt.step()
    end = default_timer()
    print(end - start)

if __name__ == "__main__":
    X = torch.rand(1000, 20, 100)
    Y = torch.rand(1000, 1)
    # with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    train_network(X, Y)
    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))

Timings for the Julia version are ~23 s, like you saw earlier, whereas the timing for Python is ~19 s.

I played around with the LSTM cell implementation and managed to get ~17 s. This is now a PR (FluxML/Flux.jl#2023, “Use muladd for LSTM cell matmuls”), but it could use some validation to ensure it doesn’t compromise GPU performance.
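
The gist of the change, as a simplified sketch rather than the exact diff in the PR:

# Fuse the matmuls and the bias add with Base.muladd so fewer intermediate
# arrays are allocated every time the cell runs (sizes here mirror the MWE).
Wi, Wh, b = rand(Float32, 4*128, 100), rand(Float32, 4*128, 128), rand(Float32, 4*128)
x, h = rand(Float32, 100, 1000), rand(Float32, 128, 1000)
g_plain = Wi * x .+ Wh * h .+ b             # separate temporaries for each term
g_fused = muladd(Wi, x, muladd(Wh, h, b))   # combined multiply-adds
g_plain ≈ g_fused                           # same numbers either way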


Thank you so much for all your work.

I also ran the Python script directly, i.e. without using PyCall. I still get 12 s with Python 3.10.4, PyTorch 1.12.0 and an Intel(R) Core™ i9-10980XE CPU @ 3.00GHz. Also, this doesn’t produce any errors as long as I don’t run it through the .jl script I posted as a gist above.

I might be wrong here, but the PyTorch version seems to apply the Dense layer only once, while the Flux version applies it multiple times.


Initially, I thought this was correct as well, so I went ahead and used


struct Seq2One
    rnn
    fc
end
Flux.@functor Seq2One

function (m::Seq2One)(X)
    [m.rnn(x) for x ∈ X[1:end-1]]  # advance the recurrent state over the sequence
    m.fc(m.rnn(X[end]))            # apply the Dense layer to the final step only
end

to set up the model. However, running the full script with this model still results in a 23 second runtime on my machine, so I’m a bit confused. Let me know if I misunderstood what you meant and you think I should implement it differently.

I can see the performance difference on my machine as well. It seems to come mostly from GC (around 50% of the runtime). A batch size of 1000 with ~250k model parameters triggers the GC quite often. If I reduce the batch size, Julia is faster than PyTorch up to a batch size of around 256, but beyond that GC time climbs above 40%.

import Lux, Optimisers, Random, Statistics, Zygote
using MKL

struct StackedLSTM{N, C} <: Lux.AbstractExplicitContainerLayer{(:cells,)}
  cells::NTuple{N, C}
end

# Hardcoded for the two-cell case, but not hard to generalize
function Lux.initialparameters(rng::Random.AbstractRNG, l::StackedLSTM)
  return (cell_1 = Lux.initialparameters(rng, l.cells[1]),
          cell_2 = Lux.initialparameters(rng, l.cells[2]))
end

function Lux.initialstates(rng::Random.AbstractRNG, l::StackedLSTM)
  return (cell_1 = Lux.initialstates(rng, l.cells[1]),
          cell_2 = Lux.initialstates(rng, l.cells[2]))
end

function (l::StackedLSTM)(x::NTuple{N, <:AbstractMatrix}, ps, st) where {N}
  (h1, c1), st_c1 = l.cells[1](x[1], ps.cell_1, st.cell_1)
  (h2, c2), st_c2 = l.cells[2](h1, ps.cell_2, st.cell_2)
  for i in 2:N
    (h1, c1), st_c1 = l.cells[1]((x[i], h1, c1), ps.cell_1, st_c1)
    (h2, c2), st_c2 = l.cells[2]((h1, h2, c2), ps.cell_2, st_c2)
  end
  return h2, (cell_1 = st_c1, cell_2 = st_c2)
end

function train()
  rng = Random.default_rng()

  x = Tuple([rand(rng, Float32, 100, 1_000) for _ in 1:20])
  y = Tuple([rand(rng, Float32, 1, 1_000) for _ in 1:20])

  lstm = StackedLSTM((Lux.LSTMCell(100 => 128), Lux.LSTMCell(128 => 128)))
  model = Lux.Chain(lstm, Lux.Dense(128 => 1))
  ps, st = Lux.setup(rng, model)

  opt = Optimisers.Adam(0.001)
  st_opt = Optimisers.setup(opt, ps)

  loss_function(ps) = Statistics.mean(abs2, model(x, ps, st)[1] .- y[end])

  # Warmup
  @time model(x, ps, st)
  @time Zygote.gradient(loss_function, ps)
  @time Optimisers.update(st_opt, ps, ps)

  @time for i in 1:100
    gs = Zygote.gradient(loss_function, ps)[1]
    st_opt, ps = Optimisers.update(st_opt, ps, gs)
  end
end

train()
  • Julia: 52.183093 seconds (1.49 M allocations: 86.648 GiB, 25.23% gc time, 0.03% compilation time)
  • PyTorch: 34.413 s

I’m on an older CPU and Python 3.9, but it’s odd that the Julia perf stays the same while Python perf improves 1.5x+. Have you tried LinearAlgebra.BLAS.set_num_threads(<number of CPU cores>) as well?
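
That is, something along these lines (with 8 standing in for however many physical cores you have):

using LinearAlgebra
BLAS.get_num_threads()   # check what it is currently set to
BLAS.set_num_threads(8)  # replace 8 with your physical core count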

@avikpal you’ll want to use the in-place Optimisers.update!, because the out-of-place version does a ton of defensive copying (for reference, the lowest I was able to get allocations down to was ~40 GB). I suspect that Lux’s LSTMCell could also benefit from the muladd change I linked above.
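
Concretely, with the ps, st_opt and loss_function from your snippet, the timed loop would look roughly like this:

# In-place sketch: update! reuses the optimiser state and parameter arrays
# instead of copying them on every iteration.
@time for i in 1:100
    gs = Zygote.gradient(loss_function, ps)[1]
    st_opt, ps = Optimisers.update!(st_opt, ps, gs)
end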

Would choosing a different AD make any difference, or even be possible at all?
Some people use Nabla, Yota or Enzyme.

AFAIK none of those support enough of the DL stack to be useful right now.