RNN Not learning

I have been trying to build an RNN in Julia using Flux and it does not seem to converge or learn at all.

The model is trained on a parsed Stanford Sentiment Treebank (i.e. an array of words paired with a sentiment score).
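
For context, a single parsed sample looks roughly like this (the tokens and label shown here are made-up placeholders, not actual treebank entries):

ssttrainxs[1]   # => ["this", "movie", "was", "surprisingly", "good"]
ssttrainys[1]   # => 3    (sentiment score from 0 to 4)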

The code is the following:

using Flux
using Flux: crossentropy, softmax, onehot
using Statistics: mean

#Quality of life assignments 

#Data Pre-Processing

function cathot(x)
    # Converts a categorical label (0-4) into a one-hot vector.
    Float64.(onehot(x, [0,1,2,3,4]))
end

seqembed(list) = [embedding(x) for x in list] # Embeds the sequence of words into a sequence of vectors via Word2Vec.

etrainxs , etrainys = seqembed.(ssttrainxs), cathot.(ssttrainys)
evalidxs , evalidys = seqembed.(sstvalidxs), cathot.(sstvalidys)
etestxs , etestys = seqembed.(ssttestxs), cathot.(ssttestys)

#Defining the models - Recurrent cells with a 3 level MLP

n = 32

learning_rate = 0.003

#Defining subcells to get training to work.

RNN_Part = RNN(300, n, identity)

LSTM_Part = LSTM(300, n)

GRU_part = GRU(300, n)

MLP = Chain(Dense(n,n, Flux.σ),
            #Dense(n,n, Flux.σ),
            Dense(n,5,Flux.σ), softmax)

#Defining full models

function model(x, encoder, decider)
    state = encoder.(x)[end]     # run the whole sequence and keep the last hidden state
    Flux.reset!(encoder)         # clear the recurrent state for the next sequence
    decider(state)               # pass the last hidden state through the MLP head
end 

#Defining Loss and Optimisers

#function loss(x, y)
#
#    # Reset internal Recurrent Cell state.
#    Flux.truncate!(RNN_Part)
#
#    # Iterate over every timepoint in the sentence.
#    y_hat_1 = RNN_Part.(x)[end]
#
#    # Take the very last output from the recurrent section, reduce it
#    y_hat_2 = MLP(y_hat_1)
#
#    # Calculate reduced output difference against `y`
#    delta = mean(Flux.logitcrossentropy(y_hat_2, y))
#
#    return delta
#
#end

opt = ADAM(learning_rate)

#import Pkg; Pkg.add("Zygote")
using Flux
using Flux: throttle, crossentropy, @epochs, gradient, @progress, params
using Statistics: mean
#using Zygote: Params

# average loss for a single epoch needed for plotting
avg_train_losses = []
avg_valid_losses = []

batch_train_loss = []
batch_valid_loss = []

call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)

struct SkipException <: Exception end

struct StopException <: Exception end

batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x

function fit(encoder, decider, opt, epochs)

    function loss(x, y)
        temp_model(x) = model(x, encoder, decider)
        return crossentropy(temp_model(x), y)
    end

    mp = params(encoder,decider)

    #num_tr = rand(1:length(etrainxs))
    #num_vl = rand(1:length(evalidxs))

    # This is to keep track of the loss for training and validation set.
    tracktx, trackty = (etrainxs[5], etrainys[5])
    track_vx, track_vy = (evalidxs[5], evalidys[5])

    evalcb = function()
        train_loss = loss(tracktx, trackty).data
        valid_loss = loss(track_vx, track_vy).data
        #.data gives us the non-tracked version of the array.
        push!(batch_train_loss, train_loss)
        push!(batch_valid_loss, valid_loss)
    end

    for i in 1:epochs
        println("epoch: $i \n")
        #Flux.train!(loss, mp, zip(etrainxs, etrainys), opt, cb = throttle(evalcb, 1))

        ps = mp
        cb = runall(throttle(evalcb, 1))
        @progress for d in zip(etrainxs[5:5], etrainys[5:5])
        #print(d)
          try
            gs = gradient(ps) do
              loss(batchmemaybe(d)...)
              #print(loss(batchmemaybe(d)...))
            end
            print(gs.grads, "\n")
            Flux.Optimise.update!(opt, ps, gs)
            cb()
          catch ex
            if ex isa StopException
              break
            elseif ex isa SkipException
              continue
            else
              rethrow(ex)
            end
          end
        end

        @show batch_train_loss
        @show batch_valid_loss
        avg_train_loss = mean(batch_train_loss)
        avg_valid_loss = mean(batch_valid_loss)
        push!(avg_train_losses, avg_train_loss)
        push!(avg_valid_losses, avg_valid_loss)

        global batch_train_loss = []
        global batch_valid_loss = []

    end

    @show avg_train_losses
    @show avg_valid_losses

    return model
end

epochs = 1
trained_model = fit(RNN_Part, MLP, opt, epochs)

# Plot the average loss on the training set and validation set for each epoch.

using Plots   # needed for plot() below

epoch_array = 1:epochs

p1 = plot(epoch_array, avg_train_losses, title="Loss vs Epochs (Training)", xlabel="Epoch",
    ylabel="Cross Entropy Loss")

p2 = plot(epoch_array, avg_valid_losses, title="Loss vs Epochs (Validation)", xlabel="Epoch",
    ylabel="Cross Entropy Loss")

plot(p1, p2, layout = (1, 2), legend = false)

When I try to get the derivative, I get

IdDict{Any,Any}(Tracked{Array{Float32,1}}(0x00000000, Call{Nothing,Tuple{}}(nothing, ()), true, Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])=>Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] (tracked),Tracked{Array{Float32,2}}(0x00000000, Call{Nothing,Tuple{}}(nothing, ()), true, Float32[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ...

What versions of Flux et al. are you using? Flux has not been using Tracker for quite a while now…

Pkg.add(Pkg.PackageSpec(;name="Flux", version=v"0.9.0"))

Ah yup, a lot has happened in 16 months. If you can, try running with the latest Flux/Julia (0.11.2 and 1.5.3 respectively). https://github.com/FluxML/Flux.jl/issues/1360#issuecomment-727396539 is a good intro to the current RNN interface.
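
For reference, here is a minimal sketch of what a single update step could look like on Flux 0.11 with Zygote. The names encoder, decider, seq and y are placeholders for one recurrent cell, the MLP head, one embedded sentence and its one-hot label; I have not run this against your data.

using Flux

encoder = RNN(300, 32)                      # Recur-wrapped RNN cell
decider = Chain(Dense(32, 5), softmax)

function loss(seq, y)
    h = map(encoder, seq)[end]              # run the sequence, keep the last hidden state
    Flux.Losses.crossentropy(decider(h), y)
end

ps  = Flux.params(encoder, decider)
opt = ADAM(3e-3)

Flux.reset!(encoder)                        # clear the hidden state before the sequence
gs = gradient(() -> loss(seq, y), ps)       # Zygote gradient w.r.t. the parameters
Flux.Optimise.update!(opt, ps, gs)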

Would it make things easier to run this with the train!(…) loop instead?

Not necessarily. train! can be a nice shortcut, but it can also hide certain errors in your training loop. If you’re comfortable using gradient with a custom training loop, I see no reason to stop doing that.
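
As a rough illustration (with loss, ps and opt as in your code, and data any iterator of (input, target) pairs, e.g. zip(etrainxs, etrainys)), train! is essentially a wrapper around the explicit loop:

# Explicit loop: roughly what Flux.train! does internally, minus the
# callback handling and exception handling.
for (x, y) in data
    gs = gradient(() -> loss(x, y), ps)
    Flux.Optimise.update!(opt, ps, gs)
end

# Shortcut form:
Flux.train!(loss, ps, data, opt)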

I am still fairly new to Flux and designing training loops in general. ^^

Do you have any suggestions for fixing the issue I have dug myself into in the code above? I can’t seem to understand where my theory went wrong.

I’m not sure there’s anything wrong with your theory, but it’s hard to tell whether there are bugs in the implementation. Have a look at https://github.com/FluxML/model-zoo/blob/master/text/treebank/recursive.jl and see if that helps. I’ve not used RNNs in Flux 0.9 either, so it would be great if you could create an MWE per “Please read: make it easier to help you” as well.

Hi @Emilio4d46

I made this short example for some members in my lab on using RNNs w/ Zygote. Hopefully it can help you figure out your issues.

While I would also need a minimum working example to help spot bugs, here are some areas that have been tricky in the past:

  • BPTT when broadcasting (i.e. model.(data)) is currently broken in Zygote (although I think it is fixed on master), so you should use map instead.
  • Make sure the problem is solvable by your model. I usually use the sequential MNIST test just to confirm my implementation before starting on a new dataset/problem (you can also look at PyTorch or TensorFlow RNN tutorials for example problems).
  • Make sure you are actually doing BPTT. If you are using Tracker this is less of an issue because of the tape, but if you want to use Zygote (which I would recommend moving towards), all your computations have to occur inside the gradient call; see the sketch after this list.
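
Here is a hypothetical sketch contrasting the first and third points; all names and dimensions are made up for illustration:

using Flux

# Hypothetical setup: a small recurrent encoder, a dense head, and one
# (sequence, one-hot label) pair.
encoder = RNN(300, 32)
decider = Chain(Dense(32, 5), softmax)
x  = [rand(Float32, 300) for _ in 1:10]    # a 10-step dummy sequence
y  = Flux.onehot(3, 0:4)
ps = Flux.params(encoder, decider)

# Risky with Zygote: broadcasting over the timesteps and doing the forward
# pass outside the gradient call; gradients can silently come back as zero.
h  = encoder.(x)[end]
gs = gradient(() -> Flux.Losses.crossentropy(decider(h), y), ps)

# Safer: map over the timesteps, with the whole forward pass inside the closure.
Flux.reset!(encoder)
gs = gradient(ps) do
    h = map(encoder, x)[end]
    Flux.Losses.crossentropy(decider(h), y)
end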

Thanks, that example has been very helpful for me.