# Help needed to train simple RNN using Flux

I am trying to learn a function of time series data using a neural network. In the example below, I’m simulating an AR(1), and I’m trying to learn the function

f(x_t, x_{t-1}, x_{t-2}, ...) = x_t^2 + lag_weight * x_{t-1}^2

The code seems to learn the function well when lag_weight = 0, but if lag_weight > 0, the neural network does not seem to learn the function. I’ve provided a minimal working example below.

The training strategy is to simulate an AR(1) process from scratch for `T` periods and to repeat this simulation `batch_size` times. Following the Flux documentation on RNNs, I structure the data as a vector of length `T`, where each element is a matrix of dimensions `1 x batch_size`. I then compute the RNN's predictions and minimize the MSE against the true values. The fit does not improve much when I increase `T` or the number of epochs. Also, when I increase the number of epochs, the fit for later periods (i.e. high values of `t`) tends to get worse.
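To make the data layout described above concrete, here is a minimal sketch that uses only the standard library. The helper names `simulate_ar1` and `targets` are hypothetical and are not part of the code in this post. As in the code further down, the sequence holds `T + 1` elements because the first one is the zero initial condition.

```julia
using Random

# Hypothetical helper (not from the post): simulate `batch_size` independent
# AR(1) paths, stored as a length-(T + 1) vector of 1 × batch_size matrices.
function simulate_ar1(ρ::Float32, T::Int, batch_size::Int; rng = Random.default_rng())
    ξs = Vector{Matrix{Float32}}(undef, T + 1)
    ξs[1] = zeros(Float32, 1, batch_size)            # every path starts at zero
    for t in 2:(T + 1)
        shocks = randn(rng, Float32, 1, batch_size)  # standard normal innovations
        ξs[t] = ρ .* ξs[t - 1] .+ shocks
    end
    return ξs
end

# Hypothetical helper: the target x_t^2 + lag_weight * x_{t-1}^2 for t ≥ 2.
function targets(ξs, lag_weight)
    return [ξs[t] .^ 2 .+ lag_weight .* ξs[t - 1] .^ 2 for t in 2:length(ξs)]
end

ξs = simulate_ar1(0.1f0, 100, 32)
ys = targets(ξs, 0.1f0)
@show length(ξs) size(ξs[1]) length(ys) size(ys[1])
```

Each element of `ξs` is one time step for the whole batch, which is the shape an RNN expects when it is broadcast over the sequence with `h_nn.(ξs)`.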

``````julia
using Flux, Distributions, Random, Plots

# Flags
configure_device = gpu # cpu or gpu, uses cpu if no gpu

# Step 0: Initialize parameters
ρ = Float32(0.1)
lag_weight = 0.1f0 # weight on lag (Float32, to match the data and avoid promotion to Float64)
n_epochs = 1000
show_epoch_every = 100 # prints current epoch number
n_hidden = 10
n_train_ξ = 100 # Only used for evaluation, not training
n_train_x = 100
batch_size = 32
T = 100 # length of simulated time series
rnn_activation = tanh

########################
### Helper Functions ###
########################

function mdn_loss(a, ξ)
    # Learn function f({ξ}ₜ) = ξₜ² + lag_weight * ξₜ₋₁²
    lossed = 0f0
    for t in 2:length(ξ) # ξ is a vector of row vectors
        # Same sign on the lag term as in the evaluation code below
        actual_value = ξ[t] .^ 2 + lag_weight .* ξ[t - 1] .^ 2
        # actual_value = ξ[t] .^ 2
        lossed = lossed + sum((a[t] - actual_value).^2)
    end
    # Average over time steps 2:length(ξ) and over the batch
    return lossed / (length(ξ) - 1) / length(a[1])
end

################
### Main Run ###
################

function main(; n_epochs = n_epochs, n_hidden = n_hidden, n_train_ξ = n_train_ξ,
              batch_size = batch_size, T = T,
              rnn_activation = rnn_activation)

    # Step 1: Create policy NN going from xᵢ to yᵢ
    h_nn = Chain(
        RNN(1, n_hidden, rnn_activation),
        Dense(n_hidden, 2*n_hidden),
        Dense(2*n_hidden, 1)
    ) |> configure_device

    pars = Flux.params(h_nn)

    # NOTE: the optimiser definition does not appear in the code as posted;
    # ADAM with default settings is an assumption
    opt = ADAM()

    for epoch in 1:n_epochs
        if epoch % show_epoch_every == 0
            @show epoch
        end

        # Our data is simulated so we can completely re-draw data
        ## in each epoch. This should improve convergence since
        ## we are seeing new data each time
        ηs = rand(Normal(0f0, 1), batch_size, T)
        ξs = Vector{Matrix{Float32}}(undef, T + 1)
        ξs[1] = zeros(Float32, 1, batch_size)
        for t in 2:(T+1)
            ξs[t] = ρ .* ξs[t - 1] + reshape(ηs[:, t - 1], 1, batch_size)
        end
        ξs = ξs |> configure_device

        # Reset the hidden state of the RNN
        ## Otherwise, every call takes the past hidden state as input
        Flux.reset!(h_nn)

        # Gradient of the loss with respect to the network parameters
        # (the opening `gradient` line is missing from the code as posted
        # and is reconstructed here)
        gs = gradient(pars) do
            a_out = h_nn.(ξs)
            Flux.reset!(h_nn)
            mdn_loss(a_out, ξs)
        end

        # Update neural networks
        Flux.update!(opt, pars, gs)
    end

    h_nn = h_nn |> cpu # ensure neural net is back on CPU
    return h_nn
end # Main function complete

h_nn = main()

###########################
### Testing Results ###
###########################
# Draw ξ
ηs = rand(Normal(0f0, 1), n_train_ξ, T)
ξs = Vector{Matrix{Float32}}(undef, T + 1)
ξs[1] = zeros(Float32, 1, n_train_ξ)
for t in 2:(T+1)
ξs[t] = ρ .* ξs[t - 1] + reshape(ηs[:, t - 1], 1, n_train_ξ)
end

# Check
Flux.reset!(h_nn) # start evaluation from a fresh hidden state
y_vec = Matrix(vcat(h_nn.(ξs)...)')[:, 2:end]
Flux.reset!(h_nn)
true_val = zeros(Float32, (n_train_ξ, T))
for t = 2:length(ξs)
true_val[:, t-1] = vec(ξs[t] .^2 + lag_weight .* ξs[t - 1] .^2)
end
err_vec = vec((true_val - y_vec) .^ 2)
plot_ξ = Matrix(vcat(ξs...)')[:, 2:end]
println("The MSE is $(round(mean(err_vec), digits = 3))")
scatter(vec(plot_ξ), vec(true_val), label = "Truth")
scatter!(vec(plot_ξ), vec(y_vec), label = "NN")
``````
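Since the training loss and the evaluation code each write out the target separately, one cheap sanity check is to feed the loss the true target and confirm that it returns zero. The sketch below is self-contained and does not need Flux. The names `target` and `seq_mse` are hypothetical stand-ins, with `seq_mse` normalized the same way as `mdn_loss`.

```julia
# Hypothetical target, matching the formula at the top of the post.
target(ξ_t, ξ_lag, w) = ξ_t .^ 2 .+ w .* ξ_lag .^ 2

# Hypothetical sequence MSE: averages squared errors over
# t = 2, ..., length(ξ) and over the batch.
function seq_mse(pred, ξ, w)
    total = 0f0
    for t in 2:length(ξ)
        total += sum((pred[t] .- target(ξ[t], ξ[t - 1], w)) .^ 2)
    end
    return total / (length(ξ) - 1) / length(pred[1])
end

w = 0.1f0
ξ = [randn(Float32, 1, 4) for _ in 1:6]   # toy sequence: 6 steps, batch of 4

# Predictions equal to the target, and predictions with the lag sign flipped.
# The first element is a placeholder because the loss skips t = 1.
perfect = vcat([zeros(Float32, 1, 4)], [target(ξ[t], ξ[t - 1], w) for t in 2:6])
flipped = vcat([zeros(Float32, 1, 4)], [target(ξ[t], ξ[t - 1], -w) for t in 2:6])

@show seq_mse(perfect, ξ, w)   # zero when predictions equal the target
@show seq_mse(flipped, ξ, w)   # positive when the lag term has the wrong sign
```

If the same check on the real loss does not return zero for the target used in the evaluation code, the two definitions disagree and the reported MSE cannot reach zero no matter how long the network trains.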