I'm trying to train an LSTM model in Flux.jl, but no matter what I do, GPU utilization only reaches about 50%. I also can't increase the batch size, since VRAM is already nearly full. I would appreciate any kind of help.
module Lstm
using Flux
using CUDA
using cuDNN
using CSV
using DataFrames
using ProgressMeter
using Zygote
using Plots
using JLD2
using Statistics
CUDA.allowscalar(false)  # error out instead of silently falling back to slow scalar indexing on GPU arrays
function predicter()
    device = gpu_device()

    # hyperparameters
    epochs = 1_000_000
    batchsize = 1000
    initLearn = 0.001
    clipping = 1
    plotting_mean_threshold = 2

    # two stacked LSTM layers followed by a dense head that maps to a single output per time step
    LSTM_NN = Chain(
        LSTM(14 => 256),
        LSTM(256 => 128),
        Dense(128 => 64, relu),
        Dense(64 => 1)
    ) |> device

    # gradient-norm clipping followed by AdamW
    opt_state = Flux.setup(OptimiserChain(ClipNorm(clipping), AdamW(initLearn)), LSTM_NN)

    inloc = "data"
    losses = []
    averaged_losses = []

    # the full dataset lives in host RAM; batches are copied to the GPU inside the training loop
    train_data = Float32.(load(inloc * "/big_train_data.jld2", "collected_train_data"))
    true_data = Float32.(load(inloc * "/big_true_data.jld2", "collected_true_data"))
    println(size(train_data), typeof(train_data), size(true_data), typeof(true_data))

    loader = Flux.DataLoader((data=train_data, label=true_data); batchsize=batchsize, shuffle=true)
    @showprogress for epoch_idx in 1:epochs
        for batch in loader
            # copy the current batch from host memory to the GPU
            X, Y = batch |> device

            loss, grads = Zygote.withgradient(LSTM_NN) do m
                # take the prediction at the last time step for every sequence in the batch
                output = m(X)[1, end, :]
                Flux.Losses.mse(output, Y)
            end

            Flux.update!(opt_state, LSTM_NN, grads[1])
            push!(losses, cpu(loss))
        end

        # CPU snapshot of the model state (not written to disk anywhere in this script)
        states = Flux.state(cpu(LSTM_NN))
        println("Finished_epoch_$epoch_idx")

        # average the epoch's losses after dropping outliers beyond
        # `plotting_mean_threshold` standard deviations from the mean
        M = mean(losses)
        S = std(losses)
        filtered_losses = filter(x -> abs(x - M) <= plotting_mean_threshold * S, losses)
        push!(averaged_losses, mean(filtered_losses))
        resize!(losses, 0)  # clear the per-batch losses for the next epoch

        display(plot(1:length(averaged_losses), averaged_losses))
    end
end
predicter()
end
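
In case it helps with diagnosing this: below is a minimal timing sketch (not part of the script above) that I could call once, e.g. as `time_one_batch(LSTM_NN, loader, device)` right after `loader` is built, to see whether the per-batch host-to-GPU copy or the forward/backward pass dominates. `time_one_batch` is just a hypothetical helper name; `CUDA.@time` and `CUDA.@sync` come from CUDA.jl.

using Flux, CUDA, Zygote

# Hypothetical helper: compare the cost of copying one batch to the GPU
# with the cost of a single forward/backward pass through the model.
function time_one_batch(model, loader, device)
    batch = first(loader)

    # time the host -> GPU copy of a single batch
    X, Y = CUDA.@time (batch.data |> device, batch.label |> device)

    # time one forward/backward pass; CUDA.@sync waits for the GPU kernels to
    # finish, so the measurement is not just the kernel-launch overhead
    loss_fn(m) = Flux.Losses.mse(m(X)[1, end, :], Y)
    CUDA.@time CUDA.@sync Zygote.withgradient(loss_fn, model)

    return nothing
end

If the copy turns out to dominate, I assume the next thing to try would be keeping the dataset on the GPU (memory permitting) instead of copying each batch inside the loop.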