Hi,
I’ve been running into an issue where training a neural network with Flux is much slower on the GPU than on the CPU. To narrow it down, I compared an MWE written in Flux against an MWE that runs the same model in PyTorch via PyCall.
The PyTorch version runs in approximately 500 ms on average, while the Flux version is four times slower with an average of 2 seconds.
The code I used is the following:
# Necessary packages
using PyCall
using BenchmarkTools
using Flux
using StatsBase
# ===== PyTorch example =====
py"""
import torch
import torch.nn as nn
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()  # Base class constructor
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2,
                            batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h = self.lstm(x)[0]
        return self.linear(h)

def train_network(X, Y, epochs=100):
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    lstm = LSTM(100, 128, 1).to(dev)
    X, Y = torch.from_numpy(X).to(dev), torch.from_numpy(Y).to(dev)
    lstm.train()
    criterion = nn.MSELoss()
    opt = torch.optim.Adam(lstm.parameters())
    for epoch in range(epochs):
        output = lstm(X)
        output = output[:, -1, :]
        loss = criterion(output, Y)
        opt.zero_grad()
        loss.backward()
        opt.step()
"""
train_pytorch = py"train_network"
X = rand(Float32, 1_000, 20, 100);
Y = rand(Float32, 1_000, 1);
@benchmark train_pytorch(X, Y) # Roughly 500 ms
# ===== Flux example =====
function julia_nn(X, Y, epochs=100)
    lstm = Chain(
        LSTM(100 => 128),
        LSTM(128 => 128),
        Dense(128 => 1)
    ) |> gpu
    opt = ADAM()
    θ = Flux.params(lstm)
    for epoch ∈ 1:epochs
        Flux.reset!(lstm)
        ∇ = gradient(θ) do
            lstm(X[1]) # Warm up
            Flux.Losses.mse.([lstm(x) for x ∈ X[2:end]], Y[2:end]) |> mean
        end
        Flux.update!(opt, θ, ∇)
    end
end
X = gpu.([rand(Float32, 100, 1_000) for _ ∈ 1:20]);
Y = gpu.([rand(Float32, 1, 1_000) for _ ∈ 1:20]);
@benchmark julia_nn(X, Y) # Roughly 2 sec
As you can see, even though the features and targets in the Flux example are moved to the GPU outside of the benchmark (whereas the PyTorch version pays for that host-to-device transfer inside the timed function), the Flux network still trains far more slowly.
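To make the comparison fully symmetric, the host-to-device transfer could also be pulled out of the timed PyTorch call. Here is a rough sketch of what I mean (the helper names to_device and train_network_pre are made up for illustration, and the GPU tensors are passed back and forth as PyObjects):
py"""
def to_device(X, Y):
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return torch.from_numpy(X).to(dev), torch.from_numpy(Y).to(dev)

def train_network_pre(X, Y, epochs=100):
    # Same training loop as train_network, but X and Y already live on the device
    lstm = LSTM(100, 128, 1).to(X.device)
    lstm.train()
    criterion = nn.MSELoss()
    opt = torch.optim.Adam(lstm.parameters())
    for epoch in range(epochs):
        output = lstm(X)[:, -1, :]
        loss = criterion(output, Y)
        opt.zero_grad()
        loss.backward()
        opt.step()
"""
to_device = py"to_device"
train_pytorch_pre = py"train_network_pre"
Xp = rand(Float32, 1_000, 20, 100);  # fresh copies, since X and Y were re-bound above
Yp = rand(Float32, 1_000, 1);
Xg, Yg = to_device(Xp, Yp)
@benchmark train_pytorch_pre(Xg, Yg)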
Does anybody have an idea as to how one can improve the Flux example to reach the same speed as PyTorch? I am using CUDA Toolkit 11.6 with an RTX 3090.
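In case the timing methodology matters: GPU kernels launch asynchronously, so a variant of the benchmark that forces synchronization before the timer stops would look like this (a minimal sketch, assuming CUDA.jl is loaded alongside Flux):
using CUDA
# Interpolate the inputs so BenchmarkTools does not treat them as globals,
# and block until all queued GPU work has finished before the timer stops.
@benchmark CUDA.@sync julia_nn($X, $Y)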
I am grateful for any hints or suggestions on how to improve the performance of my Flux models. As an aside, in my experience training networks on the GPU with Flux only beats the CPU when the network has many parameters or a large number of features, whereas every network I run in PyTorch sees a substantial GPU speed-up.
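For concreteness, the kind of CPU-vs-GPU comparison I mean is the same model and training loop with only the device-moving function swapped out; a minimal sketch (the name julia_nn_dev is purely for illustration):
# Same model and loop as julia_nn, parameterised over the device mover (cpu or gpu)
function julia_nn_dev(X, Y, to_dev; epochs=100)
    lstm = Chain(LSTM(100 => 128), LSTM(128 => 128), Dense(128 => 1)) |> to_dev
    opt = ADAM()
    θ = Flux.params(lstm)
    for epoch ∈ 1:epochs
        Flux.reset!(lstm)
        ∇ = gradient(θ) do
            lstm(X[1]) # Warm up
            Flux.Losses.mse.([lstm(x) for x ∈ X[2:end]], Y[2:end]) |> mean
        end
        Flux.update!(opt, θ, ∇)
    end
end
Xc = [rand(Float32, 100, 1_000) for _ ∈ 1:20];
Yc = [rand(Float32, 1, 1_000) for _ ∈ 1:20];
@benchmark julia_nn_dev(Xc, Yc, cpu)              # CPU run
@benchmark julia_nn_dev(gpu.(Xc), gpu.(Yc), gpu)  # GPU run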