I am attempting to create a simple recurrent model that takes, as input, data generated by a parameterized data-generating process (DGP), and outputs the parameters of that DGP. The goal is to learn the parameters, given data from the model. The code below tries to learn the parameter of a moving-average order 1 (MA(1)) model, given samples of size 100 from the model, with a U(0,1) prior on the MA(1) parameter.
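Concretely, each draw from the DGP is y_t = e_t + ϕ*e_{t-1} for t = 1, …, 100, with e_t ~ N(0,1) i.i.d. and ϕ ~ U(0,1) (this is what the dgp function below implements).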
Because the target output is a scalar, only the final output of the recurrent layer is passed on to the dense layers that produce the value used in the loss.
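In isolation, that read-out pattern looks like the following sketch (made-up sizes, and assuming Flux's usual convention of a sequence as a vector of per-time-step inputs):

using Flux
rnn = LSTM(1, 10)                         # 1 input feature, 10 hidden units
seq = [rand(Float32, 1) for _ in 1:100]   # 100 time steps, one length-1 input each
last_out = rnn.(seq)[end]                 # apply per time step, keep only the final output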
The code runs, but the model does not learn: the loss remains constant during training. I suppose I’ve made some simple mistake, but I can’t find it. Any help would be appreciated!
When I run it, I get output like this:
julia> include("test2.jl")
[ Info: Epoch 1
loss(x, θ) = 0.0734574248343706
loss(x, θ) = 0.0734574248343706
loss(x, θ) = 0.0734574248343706
[ Info: Epoch 2
loss(x, θ) = 0.0734574248343706
loss(x, θ) = 0.0734574248343706
loss(x, θ) = 0.0734574248343706
julia>
As the output shows, the loss never changes during training. The full code is:
# this tries to learn the MA(1) parameter given a sample of size n,
# using a recurrent NN.
using Flux
using Base.Iterators
# define the model
L1 = LSTM(1, 10) # input: 1 variable per time step; output: 10 hidden units (learned stats)
L2 = Dense(10, 5, tanh) # map the final hidden state to a smaller feature vector
L3 = Dense(5, 1)        # final layer: scalar parameter estimate
function m(x)
    # Flux.reset!(L1)
    # run the LSTM over the sample; only its final output feeds the dense layers
    L3(L2((L1.(x))[end]))[]
end
# Data generating process: returns reps samples of size n from the MA(1) model, and the parameters that generated them
function dgp(reps)
    n = 100 # for future: make this random?
    ys = zeros(Float32, reps, n)
    θs = zeros(Float32, reps)
    for i = 1:reps
        ϕ = rand(Float32)
        e = randn(Float32, n+1)
        ys[i,:] = e[2:end] .+ ϕ*e[1:end-1] # MA(1): y_t = e_t + ϕ*e_{t-1}
        θs[i] = ϕ
    end
    return ys, θs
end
# make the data for the net: x is input, θ is output
nsamples = 1000
x, θ = dgp(nsamples) # an nsamples × 100 matrix and an nsamples-vector
# chunk the data into batches
batches = [(x[ind,:], θ[ind]) for ind in partition(1:size(x,1), 50)]
# train
loss(x,y) = Flux.huber_loss(m(x), y; δ=0.1) # Define the loss function
opt = ADAM(0.001)
evalcb() = @show(loss(x, θ))
Flux.@epochs 2 Flux.train!(loss, Flux.params(m), batches, opt, cb = Flux.throttle(evalcb, 1))
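One thing I haven't been able to rule out: m is a plain function rather than a Flux layer, so I'm not sure whether Flux.params(m) actually collects any trainable parameters for train! to update. A sanity check I've been looking at (just a sketch; I don't know if this is the problem):

@show length(Flux.params(m))          # a bare function has no trainable fields of its own
@show length(Flux.params(L1, L2, L3)) # collecting directly from the layers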