LSTM training for a sequence of multiple features using a batch size of 30

I am trying to train an LSTM in batches on time-series data with multiple features.
Assume I have 5000 samples with 5 features each. The input covers the past 14 days and the output is a single value on the 15th day (so the number of time steps is 14). My data has the following sizes:
xtrain: (5000,14,5)
ytrain: (5000,1,1)

My model is below. How do I train on this data with a batch size of 30? I tried using DataLoader and Flux.train!, but neither works with this input size:

using Flux, Flux.Data

model = Chain(
    LSTM(5, 20),
    Dense(20, 1, σ)
)

# evaluating the prediction for one input sequence
function eval_model(x)
    Flux.reset!(model)                   # clear the LSTM state before each sequence
    inputs = [x[1, t, :] for t in 1:14]
    output = model.(inputs)
    return output[end]                   # prediction after the last (14th) time step
end

L(x, y) = Flux.mse(eval_model(x), y)
opt = ADAM(0.001)

# working fine until here

# creating batches and training (not working)
train_loader = DataLoader((xtrain, ytrain), batchsize=30, shuffle=true)
@time Flux.train!(L, params(model), train_loader, opt)

Have you tried looking at the error Flux raises when the dataloader is constructed?

julia> using Flux, Flux.Data

julia> X = rand(5000,14,5);

julia> Y = rand(5000,1,1);

julia> train_loader = DataLoader((X, Y), batchsize=30)
ERROR: DimensionMismatch("All data should contain same number of observations")
 [1] _nobs
   @ ~/.julia/packages/Flux/0c9kI/src/data/dataloader.jl:104 [inlined]
 [2] DataLoader(data::Tuple{Array{Float64, 3}, Array{Float64, 3}}; batchsize::Int64, shuffle::Bool, partial::Bool, rng::Random._GLOBAL_RNG)
   @ Flux.Data ~/.julia/packages/Flux/0c9kI/src/data/dataloader.jl:73
 [3] top-level scope
   @ REPL[9]:1
 [4] top-level scope
   @ ~/.julia/packages/CUDA/3VnCC/src/initialization.jl:81

Remember that Flux expects the batch dimension to be the last dimension of any data rather than the first (which you may be used to from Python libraries). If you make that change:

julia> X = rand(14,5,5000);

julia> Y = rand(1,1,5000);

julia> train_loader = DataLoader((X, Y), batchsize=30);

julia> size.(first(train_loader))
((14, 5, 30), (1, 1, 30))

Now the dataloader batching works as expected.
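For the data in the question, one way to make that change is `permutedims`, which moves the sample axis to the end while keeping every value attached to its own sample, day, and feature. A minimal sketch in plain Julia, using random placeholder arrays instead of the real `xtrain`/`ytrain` (the names `X_py` and `Y_py` are just for illustration):

```julia
# Placeholder data in the asker's original (samples, time steps, features) layout.
X_py = rand(5000, 14, 5)
Y_py = rand(5000, 1, 1)

# permutedims reorders the axes: new dim 1 = old dim 2, new dim 2 = old dim 3,
# new dim 3 = old dim 1 (the sample axis moves to the end).
X = permutedims(X_py, (2, 3, 1))   # size (14, 5, 5000)
Y = permutedims(Y_py, (2, 3, 1))   # size (1, 1, 5000)

# Element correspondence is preserved: sample i, day t, feature f.
i, t, f = 123, 7, 4
@assert X[t, f, i] == X_py[i, t, f]

# DataLoader counts observations along the last dimension, so these must match.
@assert size(X, ndims(X)) == size(Y, ndims(Y)) == 5000

# Split the 5000 sample indices into batches of 30.
batches = collect(Iterators.partition(1:5000, 30))
println(length(batches), " batches, last batch size ", length(last(batches)))
# prints "167 batches, last batch size 20"
```

The last lines split the sample indices the way I understand `DataLoader` does with its default `partial=true` setting: 166 full batches of 30 plus a final batch of 20.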


Thanks a lot for your reply! I adjusted the data using the function `reshape` and it worked.
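One caveat about `reshape`: it only changes the shape, not the order of the elements in memory, so reshaping a (5000, 14, 5) array to (14, 5, 5000) gives the right sizes but mixes values from different samples. `permutedims` actually moves the axes. A small sketch on a toy array (2 samples, 3 time steps, 4 features) that shows the difference:

```julia
# Toy array in (samples, time steps, features) layout.
A = reshape(collect(1:24), 2, 3, 4)

R = reshape(A, 3, 4, 2)          # same target size, values keep column-major order
P = permutedims(A, (2, 3, 1))    # axes actually moved: P[t, f, i] == A[i, t, f]

# permutedims keeps each sample's (time step, feature) values together...
@assert all(P[t, f, i] == A[i, t, f] for i in 1:2, t in 1:3, f in 1:4)

# ...while reshape re-reads the same memory in a new shape, so the
# "sample 1" slice of R contains values that belonged to both samples.
println(R[:, :, 1] == P[:, :, 1])   # prints "false"
```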


Hi, I ran this code, but I get the error "attempt to access 14×5×30 Array{Float64, 3} at index [1, 6, 1:30]". What is wrong?

How did you reshape the data?
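The error comes from `eval_model`: it indexes `x[1, t, :]` for `t` in `1:14`, which assumes the original (samples, time steps, features) layout. After moving the batch dimension to the end, one batch has size (14, 5, 30), so dimension 2 only has 5 entries and `t = 6` is out of bounds. A sketch of slicing a batch into one matrix per time step instead, assuming the (time steps, features, batch) layout from the answer above and using random placeholder data:

```julia
# One batch in (time steps, features, batch) layout.
x = rand(14, 5, 30)

# Indexing dimension 2 (the 5 features) with a time step of 6 is out of bounds,
# which is what triggers the "attempt to access ... at index [1, 6, 1:30]" error.
err = try
    x[1, 6, :]
    nothing
catch e
    e
end
@assert err isa BoundsError

# Instead, take one (features × batch) matrix per time step along dimension 1.
inputs = [x[t, :, :] for t in 1:size(x, 1)]

println(length(inputs), " time steps, each of size ", size(first(inputs)))
# prints "14 time steps, each of size (5, 30)"
```

Each element of `inputs` is a features × batch matrix, which as far as I know is the per-step input that recurrent layers such as `LSTM(5, 20)` accepted in the Flux versions used in this thread; `eval_model` could then map the model over `inputs` (calling `Flux.reset!(model)` first) and keep the last output. Newer Flux releases changed the recurrent-layer API, so check the documentation for your version.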