Training Flux LSTM on GPU is slower than on CPU

Hi,

I’m trying to train an LSTM on time-series data using Flux, and I’m attempting to speed it up by training on the GPU, but it ends up taking longer than on the CPU.

The data is an n_features x n_observations matrix, as generated in the MWE below (the data in the MWE is not the real data; the real data has 200 features). My current goal is a network that can make a prediction at each time step with a ‘memory’ of batch_number steps.

I didn’t use DataLoader because I read this answer and would like to avoid doing extra reshaping in the training loop.

MWE:


using Flux, Statistics, Distributions
using Base.Iterators, LinearAlgebra

using Random
using CUDA, BenchmarkTools


function generate_data_matrix(start_elem,end_elem,T,func)
    #Generates n_features x n_observations Matrix with n_features = 2
    #First row is a function of time, second row is the element from the first row multiplied by a sine wave
    data_matrix = zeros(Float32,2,end_elem-start_elem+1)
    data_matrix[1,:] = func.(collect(start_elem:end_elem))
    data_matrix[2,:] = data_matrix[1,:].*(sin.((collect(start_elem:end_elem))./(T) .* (2*pi)))
    return data_matrix
end

function shape_recurrent(x::Matrix, axes)
    #Reshaping the data into a shape expected by LSTM
    return [x[:,i] for i in 1:size(x,axes)]
end


function batch_recurrent(full_matrix::Matrix, batch_number::Int)
    ax = 2
    # div would silently truncate, so check divisibility explicitly
    @assert (size(full_matrix, ax) - 1) % batch_number == 0
    batch_size = div(size(full_matrix, ax) - 1, batch_number)
    all_datatuples = Vector{Tuple{Vector{Vector{Float32}},Vector{Vector{Float32}}}}(undef, batch_number)
    for idx in eachindex(all_datatuples)
        # Input: columns (idx-1)*batch_size+1 : idx*batch_size; target: the same columns shifted one step ahead
        all_datatuples[idx] = (shape_recurrent(full_matrix[:, (idx-1)*batch_size+1:idx*batch_size], ax),
                               shape_recurrent(full_matrix[:, (idx-1)*batch_size+2:idx*batch_size+1], ax))
    end
    return all_datatuples
end


function batch_recurrent_gpu(full_matrix::Matrix, batch_number::Int)
    ax = 2
    # div would silently truncate, so check divisibility explicitly
    @assert (size(full_matrix, ax) - 1) % batch_number == 0
    batch_size = div(size(full_matrix, ax) - 1, batch_number)
    X_all = []
    Y_all = []
    for idx in 1:batch_number
        push!(X_all, gpu.(shape_recurrent(full_matrix[:, (idx-1)*batch_size+1:idx*batch_size], ax)))
        push!(Y_all, gpu.(shape_recurrent(full_matrix[:, (idx-1)*batch_size+2:idx*batch_size+1], ax)))
    end
    return gpu(zip(X_all, Y_all))
end


function train_withbatch_gpu(model,opt,epochs,data_tup)
    opt_state = Flux.setup(opt,model)
    losses = []
    for epoch in 1:epochs
        for (inp,outp) in data_tup
            x,y = inp,outp
            Flux.reset!(model)
            loss,grad = Flux.withgradient(model) do m
                sum(Flux.Losses.mse.([m(xi) for xi in x[1:end]], y[1:end]))
            end
            Flux.update!(opt_state, model, grad[1])
            push!(losses,loss)
        end
    end
    return losses,model
end

function train_withbatch_cpu(model,opt,epochs,data_tup)
    opt_state = Flux.setup(opt,model)
    losses = []
    for epoch in 1:epochs
        for (inp,outp) in data_tup
            x,y = inp,outp
            Flux.reset!(model)
            loss,grad = Flux.withgradient(model) do m
                m(x[1])
                sum(Flux.Losses.mse.([m(xi) for xi in x[2:end]], y[2:end]))
            end
            Flux.update!(opt_state, model, grad[1])
            push!(losses,loss)
        end
    end
    return losses,model
end

start_elem = 1
end_elem = 361
T = 10
batch_number = 40
data_dims = 2
hidden_dims = 5*data_dims

generated_data_matrix = generate_data_matrix(start_elem,end_elem,T,x->sqrt(x))

cpu_datatuples = batch_recurrent(generated_data_matrix,batch_number)
gpu_datatuples = batch_recurrent_gpu(generated_data_matrix,batch_number)

cpu_model = Chain(LSTM(data_dims=>hidden_dims),LSTM(hidden_dims=>hidden_dims),Dense(hidden_dims=>data_dims))
gpu_model = Chain(LSTM(data_dims=>hidden_dims),LSTM(hidden_dims=>hidden_dims),Dense(hidden_dims=>data_dims)) |> gpu

epochs = 500
opt = Adam()


gpu_losses, gpu_model = CUDA.@time train_withbatch_gpu(gpu_model,opt,epochs,gpu_datatuples)
cpu_losses, cpu_model = @time train_withbatch_cpu(cpu_model,opt,epochs,cpu_datatuples)

The MWE above does the job (I’ve plotted the results to confirm), but the GPU code is much slower than the CPU code. CUDA.@time also shows that there are more CPU allocations than GPU allocations.

Here are the outputs from CUDA.@time and @time (on an NVIDIA RTX A5000 and an Intel(R) Xeon(R) W-2265 CPU @ 3.50GHz):

138.760279 seconds (368.54 M CPU allocations: 14.604 GiB, 3.75% gc time) (16.12 M GPU allocations: 3.182 GiB, 21.28% memmgmt time)
12.789505 seconds (66.35 M allocations: 8.777 GiB, 4.22% gc time, 15.88% compilation time)

I’m new to GPUs and machine learning, so I’m not sure where I should start if I want to fix this issue.

I would appreciate any troubleshooting advice and any insights into why this is happening. Thanks!

GPU memory management could be an issue. You may be filling up GPU memory and then spending all your time shuffling data between GPU memory and host memory. How full is your GPU memory? You may be able to manually free items, or call the GC manually every few iterations with GC.gc(false).
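Something along these lines might help (an untested sketch reusing the names from your MWE; CUDA.memory_status() reports how full the device memory pool is, and the gc_every keyword is just made up for illustration):

using Flux, CUDA

CUDA.memory_status()  # prints how much of the GPU memory pool is currently in use

function train_withbatch_gpu_gc(model, opt, epochs, data_tup; gc_every = 10)
    opt_state = Flux.setup(opt, model)
    losses = Float32[]
    for epoch in 1:epochs
        for (x, y) in data_tup
            Flux.reset!(model)
            loss, grad = Flux.withgradient(model) do m
                sum(Flux.Losses.mse.([m(xi) for xi in x], y))
            end
            Flux.update!(opt_state, model, grad[1])
            push!(losses, loss)
        end
        if epoch % gc_every == 0
            GC.gc(false)    # incremental collection, frees dead CuArrays
            CUDA.reclaim()  # hand cached device memory back to the CUDA pool
        end
    end
    return losses, model
end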

Also, you could use @profview to see where the slowest parts are. If memory is the bottleneck, the gpu functions are likely to show up as hot spots.
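For example (assuming ProfileView.jl, which exports @profview; the VS Code Julia extension provides an equivalent macro):

using ProfileView  # provides @profview

# Profile a short run so the flame graph stays readable; look for time spent in
# data movement and allocation rather than in the actual GPU kernels.
@profview train_withbatch_gpu(gpu_model, opt, 5, gpu_datatuples)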