RNN is not training

I’m trying to reproduce this example, but my model won’t train.
I’m using only the training set (so the model should be able to overfit), yet the prediction comes out as just a flat line:

using Pandas
using Plots
using Flux

df = read_csv("airline-passengers.csv")

# Build input/target pairs from the univariate time series:
# target_train is simply train shifted one step ahead
train = values(df[1:length(df)-1].Passengers)
target_train = values(df[2:length(df)].Passengers)

Plots.plot(train)
plot!(target_train)

train = [[Float32(train[i])] for i = 1:length(train)]                     # sequence of 1-element input vectors
target_train = [Float32(target_train[i]) for i = 1:length(target_train)] # sequence of scalar targets

# Model and Training 

model = Chain(RNN(1,4), Dense(4, 1, x->x))
loss(x, y) = sum(abs2.(eval_model(x) .- y))

function eval_model(x)
    #@show x
    out = model.(x)
    @show out
    #Flux.reset!(model)
    out = [out[i][1] for i=1:length(out)]
end


ps = Flux.params(model);
opt = Flux.ADAM(0.01)
epochs = 100

for epoch in 1:epochs
        @show epoch
        #loss(train, target_train)
        gs = Flux.gradient(ps) do
             @show loss(train, target_train)
        end
        Flux.Optimise.update!(opt, ps, gs)
end

out = model.(train)
out = [out[i][1] for i=1:length(out)]
Plots.plot(out)


Does anybody see the problem?

reset! needs to be called after every minibatch, otherwise the network keeps accumulating residual state from the previous set of sequences.

I uncommented the line with reset!. But only the Dense weights get non-zero gradients; the rest (all of the RNN’s) are zeros.

The line with reset! isn’t called at all during training. You’d need to move it to after the gs = ... block or after update!:

for epoch in 1:epochs
        @show epoch
        #loss(train, target_train)
        gs = Flux.gradient(ps) do
             @show loss(train, target_train)
        end
        Flux.Optimise.update!(opt, ps, gs)
        Flux.reset!(model)
end

I moved it now.
If I use Dense(1,1) it trains; if I add the RNN, it is not updated.
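
One way to check this (a sketch; the gs[p] indexing assumes the implicit-params style used above) is to look at the gradient of each parameter directly:

gs = Flux.gradient(() -> loss(train, target_train), ps)
for p in ps
    g = gs[p]
    # `g` can be `nothing` when a parameter receives no gradient at all
    @show size(p), (g === nothing ? 0f0 : sum(abs, g))
end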

Maybe try

out = map(model, x)

I know there was a bug some time back where RNNs wouldn’t work as expected with broadcasting in Zygote. I’m not sure whether it has been fixed. If that doesn’t work, I have an example using 0.10.x/0.11.x LSTMs (although, glancing at the code, it should also work with the 0.12.x revamp), which should work with RNNs as well.
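
Applied to the eval_model above, that would look like this (a sketch, assuming the broadcasting issue is indeed the culprit):

function eval_model(x)
    out = map(model, x)   # map instead of model.(x), sidestepping the broadcast
    [out[i][1] for i = 1:length(out)]
end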


Here’s a slightly reworked example that appears to train properly:

using CSV
using DataFrames
using Flux
using Plots

df = DataFrame(CSV.File("airline-passengers.txt"))

# Build input/target pairs from the univariate time series
# (scaled down roughly by a factor of 100)
X_raw = Float32.(df[!,:Passengers][1:nrow(df)-1]) ./ 100
Y_raw = Float32.(df[!,:Passengers][2:nrow(df)]) ./ 100

Plots.plot(X_raw)
plot!(Y_raw)

X = [X_raw[i:i] for i in 1:length(X_raw)]   # one 1-element input vector per time step
Y = Y_raw

# Model and Training 
model = Chain(RNN(1,4), Dense(4, 1), x -> reshape(x, :))
# model.(X)
loss(x, y) = sum(abs2.(Flux.stack(model.(x), 1) .- y))   # stack the per-step outputs into one array before comparing

ps = Flux.params(model);
opt = Flux.ADAM(0.01)
epochs = 100

for epoch in 1:epochs
        @show epoch
        gs = Flux.gradient(ps) do
             loss(X, Y)
        end
        Flux.Optimise.update!(opt, ps, gs)
end

out = vec(Flux.stack(model.(X), 1))
Plots.plot(out)
plot!(Y)
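
One note, tying back to the earlier reset! discussion: this loop never clears the recurrent state between epochs. If you want that here as well, the same fix applies (a sketch of the loop’s last two lines):

Flux.Optimise.update!(opt, ps, gs)
Flux.reset!(model)   # clear the hidden state before the next epoch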

Something that may have contributed to the issue is the scaling of the features/target. Here I roughly divided each by 100.
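
If you’d rather not hard-code the 100, a common alternative (a sketch; the μ/σ names are just for illustration) is to standardize by the training mean and standard deviation:

using Statistics

raw = Float32.(df[!,:Passengers])
μ, σ = mean(raw), std(raw)
scaled = (raw .- μ) ./ σ   # invert later with scaled .* σ .+ μ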
