I am trying to reproduce a simple NN model in Julia using Lux.jl. The original model was built with the R package forecast, which trains an AR(2) model using nnet internally. I am almost able to reproduce it, but my fitted model is off by one index! See the figures below:
[Figure: R output]
[Figure: Julia Flux output]
As you can see, the fit would be perfect if the red curve were shifted by one unit. I don't know where my problem is:
- I am guessing the problem is in how I structure the input and response data for training.
- It could also be how I am plotting this, but I don't think that's the issue.
Reproducible code:
using Lux, Random, Optimisers, Zygote
using CSV, DataFrames, DataFramesMeta
using Statistics
rng = MersenneTwister()
data = [2258, 2239, 2156, 2460, 2817, 3061, 3373, 2249, 2509, 2927, 2505, 2276, 1317, 1371, 1337, 1469, 1601, 1817, 2467, 2712, 2840, 3510, 3065, 2731, 2736, 2479, 2590, 2908, 2964, 2727, 2451, 2130, 2134, 2637, 2886, 3243, 3437, 3308, 2725, 2501, 2256, 2333, 1814, 1756, 1361, 1245, 1194, 1077, 1172, 980, 833, 759, 551, 556, 433, 359, 372, 350, 329, 375, 235, 209]
# set up x (input) and y (response) vectors for training
_xs = data[1:46] # use only 46 points for the input data
xs = stack(zip(_xs, @view _xs[2:end])) # stack lag pairs as columns, i.e. (x1, x2), (x2, x3), (x3, x4), ...
ys = data[3:47] # need to start at index 3 since input (x1, x2) -> y (at index 3)
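# Sanity check (optional, added for illustration): each column of `xs` lines
# up with the entry of `ys` at the same index, i.e. (data[t], data[t+1]) -> data[t+2]
# size(xs)          # (2, 45): 2 lags x 45 training samples
# xs[:, 1], ys[1]   # ([2258, 2239], 2156), i.e. (x1, x2) predicts x3
# xs[:, 2], ys[2]   # ([2239, 2156], 2460), i.e. (x2, x3) predicts x4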
# Setup Lux NN
Random.seed!(rng, 24)
model = Lux.Chain(Dense(2 => 2, relu), Dense(2 => 1, relu))
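# Aside: R's nnet uses a single logistic hidden layer, and nnetar fits it with
# a linear output (linout = TRUE), so a closer match to the R model might be
# the following (untested sketch; may or may not be related to the shift):
# model = Lux.Chain(Dense(2 => 2, sigmoid), Dense(2 => 1))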
# run this to see the output of the un-trained model
# ps, st = Lux.setup(rng, model) # models don't hold state or parameters
# ypred, st = Lux.apply(model, xs, ps, st) # run the model on data
# Lets train the model
function loss_function(model, ps, st, data)
y_pred, st = Lux.apply(model, data[1], ps, st)
mse_loss = mean(abs2, y_pred .- data[2])
return mse_loss, st, ()
end
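# Quick smoke test of the loss on the untrained model (optional, reuses the
# commented-out setup above). Note the shapes: `Lux.apply` returns a 1x45
# matrix here, while `ys` is a length-45 vector.
# ps, st = Lux.setup(rng, model)
# l, _, _ = loss_function(model, ps, st, (xs, ys))
# println("initial loss: $(l)")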
function training_loop(tstate::Lux.Training.TrainState, vjp, data, epochs)
for epoch in 1:epochs
grads, loss, stats, tstate = Lux.Training.compute_gradients(
vjp, loss_function, data, tstate)
#println("Epoch: $(epoch) || Loss: $(loss)")
tstate = Lux.Training.apply_gradients(tstate, grads)
end
return tstate
end
opt = Adam(0.03f0) # setup optimizer
vjp_rule = Lux.Training.AutoZygote() # Zygote for our AD requirements.
Random.seed!(rng, 24)
tstate = Lux.Training.TrainState(rng, model, opt)
tstate = training_loop(tstate, vjp_rule, (xs, ys), 250)
y_pred = Lux.apply(tstate.model, xs, tstate.parameters, tstate.states)[1]
# can plot y_pred against ys (data) (uncomment to plot using `Gnuplot.jl`)
# @gp "reset"
# @gp :- ys "with lines title 'data' lw 2 lc 'black'"
# @gp :- vec(y_pred) "with lines title 'Julia NN' lw 2 lc 'red'"
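To quantify the suspected shift numerically, here is a quick diagnostic (it only uses `vec` and `mean` on the arrays defined above): compare the predictions against the targets at the same index and at a one-step offset in either direction.
yp = vec(y_pred)
mse_aligned = mean(abs2, yp .- ys)               # prediction vs. target at the same index
mse_lead = mean(abs2, yp[2:end] .- ys[1:end-1])  # tests whether predictions run one step behind
mse_lag = mean(abs2, yp[1:end-1] .- ys[2:end])   # tests whether predictions run one step ahead
# if one of the offset errors is much smaller than mse_aligned, the fit really is off by one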
So my guess is that I am just stacking the xs wrong, or it is something to do with how I am setting up the data.
Notes:
- The R package forecast actually transforms the input values (it scales them and applies a Box-Cox transformation). I am not sure that is needed for me, or whether that is the problem; a rough sketch of that preprocessing is below.
- I have also tried Flux.jl with the same issue, which leads me to believe that the problem is in how I am structuring my input/response vectors.
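For reference, a minimal sketch of that kind of preprocessing using a plain standardization (forecast's actual scaling and automatic Box-Cox lambda selection are more involved than this):
μ, σ = mean(data), std(data)
scaled = (data .- μ) ./ σ   # center and scale the series before building the lag pairs
# train on `scaled` instead of `data`, then invert the predictions:
# y_unscaled = y_pred .* σ .+ μ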