Help needed to train simple RNN using Flux

I am trying to learn a function of time series data using a neural network. In the example below, I’m simulating an AR(1), and I’m trying to learn the function

f(x_t, x_{t - 1}, x_{t - 2}, . . .) = x_t^2 + x_{t - 1}^2 * lag_weight

The code seems to learn the function well when lag_weight = 0, but if lag_weight > 0, the neural network does not seem to learn the function. I’ve provided a minimal working example below.

The strategy for training is that I simulate an AR(1) from scratch for a length of period T. I do this simulation batch_size number of times. Following the documentation on RNNs for Flux, I structure the data as a vector of length T, and each element of the vector is a matrix of dimensions 1 x batch_size. I then calculate the predicted value using the RNN, and then I minimize the MSE against the true values. The fit does not improve much when I bump up T or the number of epochs. Also, when I increase the number of epochs, the fit for later periods (i.e. high values of t) tends to get worse.

using Flux, Distributions, Random, Plots

# Flags
configure_device = gpu # cpu or gpu, uses cpu if no gpu

# Step 0: Initialize parameters
ρ = Float32(0.1)
lag_weight = .1 # weight on lag
n_epochs = 1000
show_epoch_every = 100 # prints current epoch number
n_hidden = 10
n_train_ξ = 100 # Only used for evaluation, not training
n_train_x = 100
batch_size = 32
T = 100 # length of simulated time series
rnn_activation = tanh

########################
### Helper Functions ###
########################

function mdn_loss(a, ξ)
    # Learn function f({ξ}ₜ) = ξₜ² + ξₜ₋₁²
    lossed = 0f0
    for t in 2:length(ξ) # ξ is a vector of row vectors
        actual_value = ξ[t] .^ 2 - lag_weight .* ξ[t - 1] .^ 2
        # actual_value = ξ[t] .^ 2
        lossed = lossed + sum((a[t] - actual_value).^2)
    end
    return lossed / (length(ξ) - 1) / length(a[1])
end

################
### Main Run ###
################

function main(;n_epochs = n_epochs, n_hidden = n_hidden, n_train_ξ = n_train_ξ,
              batch_size = batch_size, T = T,
              rnn_activation = rnn_activation)

    # Step 1: Create policy NN going from xᵢ to yᵢ
    h_nn = Chain(
            RNN(1, n_hidden, rnn_activation),
            Dense(n_hidden, 2*n_hidden),
            Dense(2*n_hidden, 1)
        ) |> configure_device

    pars = Flux.params(h_nn)
    opt = ADAM()

    for epoch in 1:n_epochs
        if epoch % show_epoch_every == 0
            @show epoch
        end

        # Our data is simulated so we can completely re-draw data
        ## in each epoch. This should improve convergence since
        ## we are seeing new data each time
        ηs = rand(Normal(0f0, 1), batch_size, T)
        ξs = Vector{Matrix{Float32}}(undef, T + 1)
        ξs[1] = zeros(Float32, 1, batch_size)
        for t in 2:(T+1)
            ξs[t] = ρ .* ξs[t - 1] + reshape(ηs[:, t - 1], 1, batch_size)
        end
        ξs = ξs |> configure_device

        # Reset the hidden state of the RNN
        ## Otherwise, every call takes the past hidden state as input
        Flux.reset!(h_nn)

        # Compute gradients
        gs = Flux.gradient(pars) do
            a_out = h_nn.(ξs)
            Flux.reset!(h_nn)
            mdn_loss(a_out, ξs)
        end

        # Update neural networks
        Flux.update!(opt, pars, gs)
    end

    h_nn = h_nn |> cpu # ensure neural net is back on CPU
end

    return h_nn
end # Main function complete

h_nn = main()

###########################
### Testing Results ###
###########################
# Draw ξ
ηs = rand(Normal(0f0, 1), n_train_ξ, T)
ξs = Vector{Matrix{Float32}}(undef, T + 1)
ξs[1] = zeros(Float32, 1, n_train_ξ)
for t in 2:(T+1)
    ξs[t] = ρ .* ξs[t - 1] + reshape(ηs[:, t - 1], 1, n_train_ξ)
end

# Check
y_vec = Matrix(vcat(h_nn.(ξs)...)')[:, 2:end]
Flux.reset!(h_nn)
true_val = zeros(Float32, (n_train_ξ, T))
for t = 2:length(ξs)
    true_val[:, t-1] = vec(ξs[t] .^2 + lag_weight .* ξs[t - 1] .^2)
end
err_vec = vec((true_val - y_vec) .^ 2)
plot_ξ = Matrix(vcat(ξs...)')[:, 2:end]
println("The MSE is $(round(mean(err_vec), digits = 3))")
scatter(vec(plot_ξ), vec(true_val), label = "Truth")
scatter!(vec(plot_ξ), vec(y_vec), label = "NN")
1 Like