Improve performance in Flux.jl

I tried to train an LSTM model in Flux.jl, but whatever I do, the GPU usage only reaches about 50%. I can't make the batch size bigger, as the VRAM is nearly full. I would appreciate any kind of help.

module Lstm
    using Flux
    using CUDA
    using cuDNN
    using CSV
    using DataFrames
    using ProgressMeter
    using Zygote
    using Plots
    using JLD2
    using Statistics


    # Forbid scalar indexing of GPU arrays: accidental element-wise host loops
    # then raise an error instead of silently running orders of magnitude slower.
    CUDA.allowscalar(false)

    """
        predicter(; epochs=1_000_000, batchsize=1000, initLearn=0.001,
                  clipping=1, plotting_mean_threshold=2, inloc="data")

    Train a stacked-LSTM regression model on data loaded from `inloc` and, after
    each epoch, plot a running outlier-filtered mean of the per-batch training
    losses.

    The keyword arguments generalize the previously hard-coded constants;
    calling `predicter()` with no arguments reproduces the original behavior.

    # Keywords
    - `epochs`: number of passes over the data set.
    - `batchsize`: mini-batch size for the `DataLoader`.
    - `initLearn`: initial AdamW learning rate.
    - `clipping`: gradient-norm clipping threshold (`ClipNorm`).
    - `plotting_mean_threshold`: losses farther than this many standard
      deviations from the epoch mean are excluded from the plotted average.
    - `inloc`: directory containing `big_train_data.jld2` / `big_true_data.jld2`.
    """
    function predicter(; epochs = 1_000_000, batchsize = 1000, initLearn = 0.001,
                       clipping = 1, plotting_mean_threshold = 2, inloc = "data")
        device = gpu_device()

        # Two stacked LSTM layers feeding a small dense head that emits a
        # single regression value per sequence.
        LSTM_NN = Chain(
            LSTM(14 => 256),
            LSTM(256 => 128),
            Dense(128 => 64, relu),
            Dense(64 => 1)
        ) |> device

        # Clip the gradient norm before AdamW to keep LSTM gradients from exploding.
        opt_state = Flux.setup(OptimiserChain(ClipNorm(clipping), AdamW(initLearn)), LSTM_NN)

        # Typed accumulators: bare `[]` would create Vector{Any}, which is
        # type-unstable and slows down mean/std/filter below.
        losses = Float32[]
        averaged_losses = Float32[]

        train_data = Float32.(load(inloc * "/big_train_data.jld2", "collected_train_data"))
        true_data = Float32.(load(inloc * "/big_true_data.jld2", "collected_true_data"))
        println(size(train_data), typeof(train_data), size(true_data), typeof(true_data))

        loader = Flux.DataLoader((data = train_data, label = true_data);
                                 batchsize = batchsize, shuffle = true)

        @showprogress for epoch_idx in 1:epochs

            for batch in loader
                # Move only the current mini-batch to the GPU.
                X, Y = batch |> device
                loss, grads = Zygote.withgradient(LSTM_NN) do m
                    # Keep only the final time step's scalar output per sample.
                    # NOTE(review): assumes m(X) is (features, time, batch) — confirm
                    # against the stored data layout.
                    output = m(X)[1, end, :]
                    Flux.Losses.mse(output, Y)
                end
                Flux.update!(opt_state, LSTM_NN, grads[1])
                # `loss` is already a plain CPU scalar; the previous `cpu(loss)`
                # was a no-op.
                push!(losses, loss)
            end

            # Removed: `states = Flux.state(cpu(LSTM_NN))`. It copied the whole
            # model GPU -> CPU every epoch and the result was never used — pure
            # overhead that also stalls the GPU. To checkpoint, do it explicitly,
            # e.g. `jldsave("checkpoint.jld2"; model_state = Flux.state(cpu(LSTM_NN)))`.
            println("Finished_epoch_$epoch_idx")

            # Drop per-batch losses farther than `plotting_mean_threshold`
            # standard deviations from the epoch mean, so a few outlier batches
            # don't dominate the plotted curve.
            M = mean(losses)
            S = std(losses)
            filtered_losses = filter(x -> abs(x - M) <= plotting_mean_threshold * S, losses)
            push!(averaged_losses, mean(filtered_losses))
            empty!(losses)
            display(plot(1:length(averaged_losses), averaged_losses))

        end
    end

    # Kick off training immediately when the module is loaded.
    predicter()
end