Understanding `Flux.Data.DataLoader` when training an LSTM model

Hello, folks

For a while I have been playing with Flux.jl to learn how to put together ML models, train them, and use them for predictive tasks. At the moment, I am stuck with the training stage of an LSTM model.

Problem setup

Here’s a short summary of the problem: given a long time series (20 days of data at 15-second granularity), break it into smaller sequences, use those sequences to “fit” an LSTM model, and then generate new forecasts on the fly.
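For concreteness, this is roughly how I build the sequences — the names, the window length, and the random series here are purely illustrative stand-ins for my real data:

```julia
# Illustrative windowing of the long series into (input, target) pairs.
# 20 days at 15-second granularity: 20 * 24 * 3600 / 15 = 115_200 points.
series = rand(Float32, 115_200)          # stand-in for the real series
window = 6                               # input size of the model

X = [series[i:i+window-1] for i in 1:(length(series) - window)]
y = [series[i+window]     for i in 1:(length(series) - window)]

length(X)  # 115_194 windows, each of length 6
```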

Below there’s a piece of code for a function that trains the LSTM model. The setup is simple: the model is trained either until the loss is less than the loss tolerance or until it reaches a maximum epoch number. The loss is defined as the RMSE.

Initially, I used the Flux.Data.DataLoader as in line (!A!), shown commented below. But the results were poor. The data being modelled are a non-stationary, highly irregular time series, and the forecasts I got were always a constant (or very close to one). After staring at the code for far too long, I started wondering what data the model was actually “seeing” during training. With line (!A!) active, my impression was that training only ever saw part of the data.

When I ran Flux.Data.DataLoader and queried data and labels, I saw a vector of size 6 (the input size) and a 1-element vector, respectively. That made sense, but I don’t know how data and labels should change as training goes on.
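For reference, here is a minimal way to reproduce that inspection with toy stand-ins for my data (the sizes are illustrative, and I’m assuming the data are stored as a features × samples matrix):

```julia
using Flux

data_   = rand(Float32, 6, 100)   # 6 features, 100 samples (toy data)
labels_ = rand(Float32, 1, 100)

loader = Flux.Data.DataLoader((data_, labels_), batchsize=1, shuffle=false)
data, labels = Iterators.first(loader)   # only ever the FIRST batch

size(data), size(labels)  # ((6, 1), (1, 1)) -- one column of features, one label
```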

Then I added line (!B!) and saw that, no matter how many epochs it took, the printed value was always the same. That has to mean the first element of data never changed over the whole training run, i.e. the vector data was never updated (nor was the 1-element vector labels). And this doesn’t seem right.

How do I use Flux.Data.DataLoader properly? What should the variables data and labels be on every iteration of the while loop?

I know that I didn’t provide a working example, but I am more than happy to clarify any point that might help you understand the problem better.

The training stage

using Flux, Statistics  # `mean`, `params`, and `ADAM` come from these

function train_lstm(model, data_, labels_, lr, iter, max_epoch, loss_tolerance)
    """
    train_lstm: train an LSTM model

    - Inputs
        + model: a Flux.jl LSTM chain model
        + data_: training data
        + labels_: training labels
        + lr: learning rate
        + iter: (currently unused)
        + max_epoch: maximum number of training epochs
        + loss_tolerance: stop training once the loss drops below this value

    - Output
        + None. The training is done "in place": the already created `model`
          gets updated.
    """
    # Define the loss function (RMSE) ..........................................
    loss(x, y) = sqrt(mean((model(x) .- y) .^ 2))

    # Define optimizers ........................................................
    opt = ADAM(lr)

    # Create data loader .......................................................
    # (!A!) data, labels = Flux.Data.DataLoader((data_, labels_), batchsize=1, shuffle=true) |> Iterators.first
       
    # Training the model .......................................................
    # Set desired loss tolerance
    loss_tol = loss_tolerance

    # Initialize current loss
    current_loss = Inf

    # Set a max. number of iterations to prevent infinite loops
    max_iterations = max_epoch
    iteration = 0

    while current_loss > loss_tol && iteration < max_iterations
        data, labels = Flux.Data.DataLoader((data_, labels_), batchsize=1, shuffle=false) |> Iterators.first

        # (!B!) println(data[1])

        Flux.train!(loss, params(model), Iterators.repeated((data, labels), 1), opt)

        # Calculate the current loss and iteration
        current_loss = loss(data, labels)
        iteration += 1

        if isnan(current_loss)
            break
        end
    end
end

Extra Information

    model = Chain(
        Flux.LSTM(input_size => hidden_size),
        Flux.LSTM(hidden_size => aux),
        Flux.Dense(aux => output_size, sigmoid)
    )

This is the model I settled on after a few iterations. I looked for some sort of traditional/disciplined way to choose and place layers in an RNN model, but didn’t find much. I didn’t want too complex a model, so I “only” have three layers. The number of parameters, though, can be daunting if the input size and the number of hidden nodes are large.

In the absence of a MWE, some high level thoughts that should help you on your way:

RNNs in Flux (and PyTorch, and a couple of other libraries) are different to work with than other NN layers because they require a different input format. In the case of Flux, this is either a Vector of (features x batch) arrays, one per time step, as described in Recurrence · Flux, or a 3D array of features x batch x time.

This has a couple of immediately obvious implications. The first is that train! is probably out of the picture, because it doesn’t give you a chance to post-process data from the dataloader into one of the above forms before running AD on it. The second is that DataLoader wants to slice along the batch dimension, but the last/outer dimension for RNN inputs is the time one. The best way I’ve found to handle this is to pass the data to the DataLoader as (features x time) x batch or features x time x batch, then reverse the last two dimensions (e.g. using stack or permutedims) once you have a batch from the loader.
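As a concrete sketch of that last step (the shapes are made up for illustration):

```julia
# Suppose the DataLoader yields a batch shaped (features, time, batch) = (6, 20, 32)
batch = rand(Float32, 6, 20, 32)

# Flux's recurrent layers expect features x batch x time, so swap the last two dims:
rnn_input = permutedims(batch, (1, 3, 2))

size(rnn_input)  # (6, 32, 20)
```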

P.S. Make sure you’re calling Flux.reset! as mentioned in Recurrence · Flux on your model after every batch because RNN layers are stateful. I would also recommend getting rid of params since the whole “implicit params” mechanism is inefficient and going away in a future version of Flux + Zygote. The docs should have up to date versions of how to use both packages without relying on that.
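Putting those pieces together, here’s a rough sketch of a custom training loop in the explicit style. Everything in it — model sizes, optimiser, loader contents — is assumed for illustration, and `Flux.setup`/`Flux.withgradient` need a reasonably recent Flux:

```julia
using Flux, Statistics

# Hypothetical model and data, just to make the loop runnable
model = Chain(Flux.LSTM(6 => 16), Flux.Dense(16 => 1))
opt_state = Flux.setup(Adam(1e-3), model)   # explicit state, no params(model)

X = rand(Float32, 6, 10, 8)   # features x time x batch, as fed to the loader
Y = rand(Float32, 1, 10, 8)
loader = Flux.DataLoader((X, Y), batchsize=4, shuffle=true)

for (x, y) in loader
    x = permutedims(x, (1, 3, 2))            # -> features x batch x time
    Flux.reset!(model)                       # RNN layers are stateful
    batch_loss, grads = Flux.withgradient(model) do m
        outs = [m(x[:, :, t]) for t in axes(x, 3)]      # step through time
        sqrt(mean((outs[end] .- y[:, end, :]) .^ 2))    # RMSE on the last step
    end
    Flux.update!(opt_state, model, grads[1])
end
```

Note that the loader slices the batch dimension (the last one of `X`), and the `permutedims` inside the loop puts each batch into the layout the recurrent layers expect.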
