I am attempting to create a simple recurrent model that takes data generated by a parameterized data generating process (DGP) as input, and produces the parameters of the DGP as output. The goal is to learn the parameters, given data simulated from the model. The code below tries to learn the parameter of a moving average order 1 (MA(1)) model, given samples of size 100 from the model, with a U(0,1) prior on the MA(1) parameter.
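For context, the MA(1) process being simulated is y_t = e_t + ϕ·e_{t−1}, with the e_t iid standard normal. A minimal sketch of a single draw (the ϕ and n here are arbitrary illustration values, not the ones used in the full code below):

```julia
# MA(1): y_t = e_t + ϕ * e_{t-1}, with e_t iid N(0,1) and ϕ drawn from U(0,1).
ϕ = 0.5f0                           # illustration value; the full code draws ϕ = rand(Float32)
n = 5                               # illustration value; the full code uses n = 100
e = randn(Float32, n + 1)           # n+1 shocks: one extra for the lagged term
y = e[2:end] .+ ϕ .* e[1:end-1]     # sample of size n from the MA(1) process
```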
Because the target output is a scalar, only the final output of the recurrent layer is used to evaluate the loss.
The code runs, but the model does not learn: the loss remains constant during training. I suppose I’ve made some simple mistake, but I can’t find it. Any help would be appreciated!
When run, I get output like:

```
julia> include("test2.jl")
[ Info: Epoch 1
loss(x, θ) = 0.0734574248343706
loss(x, θ) = 0.0734574248343706
loss(x, θ) = 0.0734574248343706
[ Info: Epoch 2
loss(x, θ) = 0.0734574248343706
loss(x, θ) = 0.0734574248343706
loss(x, θ) = 0.0734574248343706

julia>
```
So, the model does not seem to learn during training. The code is:
```julia
# this tries to learn the MA(1) parameter given a sample of size n,
# using a recurrent NN.
using Flux
using Base.Iterators

# define the model
L1 = LSTM(1, 10) # number of vars in sample by number of learned stats
L2 = Dense(10, 5, tanh)
L3 = Dense(5, 1)
function m(x)
    # Flux.reset!(L1)
    L3(L2((L1.(x))[end]))
end

# Data generating process: returns sample of size n from MA(1) model,
# and the parameter that generated it
function dgp(reps)
    n = 100 # for future: make this random?
    ys = zeros(Float32, reps, n)
    θs = zeros(Float32, reps)
    for i = 1:reps
        ϕ = rand(Float32)
        e = randn(Float32, n+1)
        ys[i,:] = e[2:end] .+ ϕ*e[1:end-1] # MA1
        θs[i] = ϕ
    end
    return ys, θs
end

# make the data for the net: x is input, θ is output
nsamples = 1000
x, θ = dgp(nsamples) # these are a nsamples × 100 matrix, and an nsamples vector

# chunk the data into batches
batches = [(x[ind,:], θ[ind]) for ind in partition(1:size(x,1), 50)]

# train
loss(x, y) = Flux.huber_loss(m(x), y; δ=0.1) # define the loss function
opt = ADAM(0.001)
evalcb() = @show(loss(x, θ))
Flux.@epochs 2 Flux.train!(loss, Flux.params(m), batches, opt, cb = Flux.throttle(evalcb, 1))
```