I am trying to learn a function of time series data using a neural network. In the example below, I’m simulating an AR(1), and I’m trying to learn the function
f(x_t, x_{t-1}, x_{t-2}, ...) = x_t^2 + lag_weight * x_{t-1}^2
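Written as a plain Julia function, the target at time t is (target_f is just an illustrative name, not something used in the code below):
target_f(x, t, lag_weight) = x[t]^2 + lag_weight * x[t - 1]^2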
The code seems to learn the function well when lag_weight = 0, but if lag_weight > 0, the neural network does not seem to learn the function. I’ve provided a minimal working example below.
The strategy for training is that I simulate an AR(1) from scratch for T periods, and I repeat this simulation batch_size times. Following the Flux documentation on RNNs, I structure the data as a vector of length T, where each element of the vector is a matrix of dimensions 1 x batch_size. I then compute the predicted values with the RNN and minimize the MSE against the true values. The fit does not improve much when I increase T or the number of epochs. Also, when I increase the number of epochs, the fit for later periods (i.e. high values of t) tends to get worse.
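To make that data layout concrete, here is a minimal sketch, separate from the MWE below (the toy_* names are illustrative only, and I am assuming the same Flux version as in the MWE, where RNN(in, out, σ) and broadcasting over timesteps is the documented pattern):
using Flux
toy_T, toy_batch = 5, 3
toy_seq = [randn(Float32, 1, toy_batch) for _ in 1:toy_T] # vector of length T; each element is 1 x batch_size
toy_rnn = RNN(1, 4, tanh)
Flux.reset!(toy_rnn)
toy_out = toy_rnn.(toy_seq) # one state update per timestep; the hidden state carries across calls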
using Flux, Distributions, Random, Plots
# Flags
configure_device = gpu # cpu or gpu, uses cpu if no gpu
# Step 0: Initialize parameters
ρ = Float32(0.1)
lag_weight = 0.1f0 # weight on lag
n_epochs = 1000
show_epoch_every = 100 # prints current epoch number
n_hidden = 10
n_train_ξ = 100 # Only used for evaluation, not training
n_train_x = 100
batch_size = 32
T = 100 # length of simulated time series
rnn_activation = tanh
########################
### Helper Functions ###
########################
function mdn_loss(a, ξ)
# Learn function f({ξ}ₜ) = ξₜ² + lag_weight * ξₜ₋₁²
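# a: RNN predictions, a vector of 1 x batch_size matrices (one per timestep)
# ξ: the simulated states, laid out the same way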
lossed = 0f0
for t in 2:length(ξ) # ξ is a vector of row vectors
actual_value = ξ[t] .^ 2 + lag_weight .* ξ[t - 1] .^ 2
# actual_value = ξ[t] .^ 2
lossed = lossed + sum((a[t] - actual_value).^2)
end
return lossed / (length(ξ) - 1) / length(a[1])
end
################
### Main Run ###
################
function main(;n_epochs = n_epochs, n_hidden = n_hidden, n_train_ξ = n_train_ξ,
batch_size = batch_size, T = T,
rnn_activation = rnn_activation)
# Step 1: Create policy NN going from xᵢ to yᵢ
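## Architecture: scalar input ξₜ → recurrent layer with n_hidden units → dense layers → scalar output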
h_nn = Chain(
RNN(1, n_hidden, rnn_activation),
Dense(n_hidden, 2*n_hidden),
Dense(2*n_hidden, 1)
) |> configure_device
pars = Flux.params(h_nn)
opt = ADAM()
for epoch in 1:n_epochs
if epoch % show_epoch_every == 0
@show epoch
end
# Our data is simulated so we can completely re-draw data
## in each epoch. This should improve convergence since
## we are seeing new data each time
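## ηs holds the innovations: ξs[t] = ρ * ξs[t - 1] + ηs[:, t - 1], starting from ξs[1] = 0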
ηs = rand(Normal(0f0, 1), batch_size, T)
ξs = Vector{Matrix{Float32}}(undef, T + 1)
ξs[1] = zeros(Float32, 1, batch_size)
for t in 2:(T+1)
ξs[t] = ρ .* ξs[t - 1] + reshape(ηs[:, t - 1], 1, batch_size)
end
ξs = ξs |> configure_device
# Reset the hidden state of the RNN
## Otherwise, every call takes the past hidden state as input
Flux.reset!(h_nn)
# Compute gradients
gs = Flux.gradient(pars) do
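# Broadcasting h_nn over ξs applies it one timestep at a time, carrying the hidden state forward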
a_out = h_nn.(ξs)
Flux.reset!(h_nn)
mdn_loss(a_out, ξs)
end
# Update neural networks
Flux.update!(opt, pars, gs)
end
h_nn = h_nn |> cpu # ensure neural net is back on CPU
return h_nn
end # Main function complete
h_nn = main()
###########################
### Testing Results ###
###########################
# Draw ξ
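## Simulate a fresh panel of n_train_ξ series of length T for out-of-sample evaluation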
ηs = rand(Normal(0f0, 1), n_train_ξ, T)
ξs = Vector{Matrix{Float32}}(undef, T + 1)
ξs[1] = zeros(Float32, 1, n_train_ξ)
for t in 2:(T+1)
ξs[t] = ρ .* ξs[t - 1] + reshape(ηs[:, t - 1], 1, n_train_ξ)
end
# Check
Flux.reset!(h_nn) # reset the hidden state before the evaluation pass
# Stack the per-timestep 1 x n_train_ξ predictions into an n_train_ξ x T matrix, dropping t = 1
y_vec = Matrix(vcat(h_nn.(ξs)...)')[:, 2:end]
true_val = zeros(Float32, (n_train_ξ, T))
for t = 2:length(ξs)
true_val[:, t-1] = vec(ξs[t] .^2 + lag_weight .* ξs[t - 1] .^2)
end
err_vec = vec((true_val - y_vec) .^ 2)
plot_ξ = Matrix(vcat(ξs...)')[:, 2:end]
println("The MSE is $(round(mean(err_vec), digits = 3))")
scatter(vec(plot_ξ), vec(true_val), label = "Truth")
scatter!(vec(plot_ξ), vec(y_vec), label = "NN")