Timeseries model training using Lux.jl

Hi,
I’m using the Lux.jl simple LSTM classifier example ( Training a Simple LSTM | Lux.jl Docs ) to try to train a simple LSTM to predict timeseries data. I’m running into issues, however, because the example model is designed to produce a single output per sequence, so I’m not sure how to adapt it to predict a timeseries.

Here are my model and model training functions:

using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics

function LSTMChain(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims)
    return @compact(; lstm_cell, classifier) do x
        # Iterate over the second dimension of x, feeding one slice at a time
        # through the LSTM cell and threading the carry (hidden/cell state)
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        # Map the final hidden state to the output and flatten it
        @return vec(classifier(y))
    end
end

function train_model_lstm(model, train_data, val_data, epochs; lossfn = MAELoss(),
    rng = Xoshiro(0), opt=Adam(0.03), batch_size=50)
    ps, st = Lux.setup(rng, model)

    train_state = Training.TrainState(model, ps, st, opt)

    for epoch in 1:epochs
        # Train the model
        for (x, y) in eachobs(train_data, batchsize=batch_size)
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in eachobs(val_data, batchsize=batch_size)
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            @printf "Validation: Loss %4.5f " loss
        end
    end

    return (train_state.parameters, train_state.states)
end

The model is then trained with:

# Split data into training and testing sets (60% training data)
cv_data, test_data = MLUtils.splitobs((X[:, vars_ind]', Y'); at=0.60)

model = LSTMChain(length(vars_ind), 128, 1)

ps_trained, st_trained = train_model_lstm(model, cv_data, test_data, 100)

The data has 77 timesteps, 100 scenarios and 5 features, but is arranged so that the timeseries are stacked into a (7700 × 5) matrix, which is then transposed to match the (features × batch_size) layout. This stacking of timeseries works fine with a chain of dense layers (such as Chain(Dense(input, hidden), Dense(hidden, output))), but I’m guessing my issue is that the LSTM layers behave differently. The error I get is
DimensionMismatch: loss function expects size(ŷ) = (1,) to match size(y) = (1, 50)
(here I am using a batch size of 50)
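
For context, here is roughly how the data ends up stacked before the split (a minimal sketch with made-up array names; my real loading code differs, but the shapes are the same):

n_timesteps, n_scenarios, n_features = 77, 100, 5
raw = rand(Float32, n_timesteps, n_scenarios, n_features)  # (77, 100, 5)

# Stack every scenario's timesteps on top of each other: (7700, 5)
X = reshape(raw, n_timesteps * n_scenarios, n_features)
Y = rand(Float32, n_timesteps * n_scenarios, 1)            # one target per row

# After the adjoint in splitobs, the model sees (features, batch) = (5, 7700)
size(X')  # (5, 7700)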

Should I not be stacking the timeseries, and instead loop over the timesteps and train on each time step? Or is something else going wrong?
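
(The alternative I had in mind is to keep each scenario as its own sequence, something like the sketch below; raw is the same illustrative array as above.)

# Keep scenarios as separate sequences: (features, timesteps, scenarios)
X3 = permutedims(raw, (3, 1, 2))  # (5, 77, 100)

# With this layout, eachslice(x, Val(2)) in the model would iterate over the
# 77 timesteps, with the 100 scenarios as the batch dimension at every step.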

Thanks!