I am trying to reproduce a simple NN model in Julia using Lux.jl. The original was built with the R package forecast, which trains an AR(2) model using nnet internally. I can almost reproduce it, but my fitted model is off by one index. See the figures below.
[Figure: R output]
[Figure: Julia Flux output]
As you can see, the fit would be perfect if the red curve were shifted by one unit. I don’t know where my problem is:
- I am guessing the problem is how I structure the input and response data for training
- Could also be how I am plotting this, but I don’t think that’s the issue.
Reproducible code:
using Lux, Random, Optimisers, Zygote
using CSV, DataFrames, DataFramesMeta
using Statistics
rng = MersenneTwister()
data = [2258, 2239, 2156, 2460, 2817, 3061, 3373, 2249, 2509, 2927, 2505, 2276, 1317, 1371, 1337, 1469, 1601, 1817, 2467, 2712, 2840, 3510, 3065, 2731, 2736, 2479, 2590, 2908, 2964, 2727, 2451, 2130, 2134, 2637, 2886, 3243, 3437, 3308, 2725, 2501, 2256, 2333, 1814, 1756, 1361, 1245, 1194, 1077, 1172, 980, 833, 759, 551, 556, 433, 359, 372, 350, 329, 375, 235, 209]
# set up x (input) and y (response) vectors for training 
_xs = data[1:46]  # use only 46 points for the input data 
xs = stack(zip(_xs, @view _xs[2:end])) # stack them i.e. (x1, x2), (x2, x3), (x3, x4)...
ys = data[3:47] # need to start at index 3 since input (x1, x2) -> y (at index 3)
# Setup Lux NN 
Random.seed!(rng, 24)
model = Lux.Chain(Dense(2 => 2, relu), Dense(2 => 1, relu))
# run this to see the output of the un-trained model 
# ps, st = Lux.setup(rng, model) # models don't hold state or parameters 
# ypred, st = Lux.apply(model, xs, ps, st) # run the model on data
# Let's train the model
function loss_function(model, ps, st, data)
    y_pred, st = Lux.apply(model, data[1], ps, st)
    mse_loss = mean(abs2, y_pred .- data[2])
    return mse_loss, st, ()
end
function training_loop(tstate::Lux.Experimental.TrainState, vjp, data, epochs)
    for epoch in 1:epochs
        grads, loss, stats, tstate = Lux.Training.compute_gradients(
            vjp, loss_function, data, tstate)
        #println("Epoch: $(epoch) || Loss: $(loss)")
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate
end
opt = Adam(0.03f0) # setup optimizer 
vjp_rule = Lux.Training.AutoZygote() # Zygote for our AD requirements.    
Random.seed!(rng, 24)
tstate = Lux.Training.TrainState(rng, model, opt)
tstate = training_loop(tstate, vjp_rule, (xs, ys), 250)
y_pred = Lux.apply(tstate.model, xs, tstate.parameters, tstate.states)[1]
# can plot y_pred against ys (data)  (uncomment to plot using `Gnuplot.jl`)
# @gp "reset" 
# @gp :- ys "with lines title 'data' lw 2 lc 'black'" 
# @gp :- vec(y_pred) "with lines title 'Julia NN' lw 2 lc 'red'" 
So my guess is that I am just stacking the xs wrong or something to do with how I am setting up the data.
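A quick way to check the alignment independently of Lux is to print each input pair next to its target and to compare array shapes. The following is only a minimal sketch in Base Julia (1.9 or later for `stack`); the names `series`, `X`, `y` and `fake_pred` are illustrative and not part of the code above, and `fake_pred` merely stands in for a model output of shape 1 × N.

```julia
# Illustrative series: the first six values of `data` above.
series = [2258, 2239, 2156, 2460, 2817, 3061]

# Lagged inputs: column i holds (series[i], series[i+1]).
X = stack(zip(series[1:end-2], series[2:end-1]))   # 2 × N matrix

# Targets: y[i] = series[i+2], the value that follows column i.
y = series[3:end]                                   # length-N vector

# Print every input pair beside the target it is supposed to predict.
for i in eachindex(y)
    println("inputs = ", X[:, i], "  ->  target = ", y[i])
end

# Stand-in for a model output (assumption: a Dense(2 => 1) layer applied
# to a 2 × N input returns a 1 × N matrix).
fake_pred = zeros(1, length(y))

# A 1 × N row minus a length-N vector broadcasts to N × N,
# whereas reshaping the targets to 1 × N gives N residuals.
@show size(fake_pred .- y)
@show size(fake_pred .- reshape(y, 1, :))
```

If the printed pairs line up as expected, the stacking itself is probably fine, and the shape comparison in the last two lines is the next thing worth looking at in the loss function.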
Notes:
- The R package forecast actually transforms the input values (scaling plus a Box-Cox transformation), though I am not sure that’s needed for me or whether that’s the problem.
- I have also tried Flux.jl, with the same issue, which leads me to believe that the problem is in how I am structuring my input/response vectors.
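
For reference, scaling the series before training could look roughly like the sketch below. This is plain standardisation to zero mean and unit standard deviation; whether it matches what forecast does internally is an assumption, and the Box-Cox step is left out. The names `series`, `scaled` and `unscale` are illustrative.

```julia
using Statistics

# Illustrative values: the first six entries of `data` above, as Float32.
series = Float32[2258, 2239, 2156, 2460, 2817, 3061]

# Standardise to zero mean and unit standard deviation.
mu, sigma = mean(series), std(series)
scaled = (series .- mu) ./ sigma

# Map values on the scaled axis back to the original units,
# e.g. to plot predictions against the raw data.
unscale(z) = z .* sigma .+ mu

@show mean(scaled) std(scaled)
@show unscale(scaled) ≈ series
```

The lagged inputs and the targets would then both be built from `scaled`, and predictions passed through `unscale` before plotting.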

