I have been trying to build an RNN in Julia using Flux, and it does not seem to converge or learn at all.
The model is trained on a parsed Stanford Sentiment Treebank (i.e. an array of words plus a sentiment score for each sentence).
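For context, the parsed data is assumed to look roughly like this (illustrative shapes only; the actual parsing code is not shown here):
# Assumed shape of the parsed SST data (illustrative, not my real parsing code):
# ssttrainxs :: Vector{Vector{String}} - each element is the list of words of one sentence
# ssttrainys :: Vector{Int}            - each element is a sentiment label from 0 to 4
# (the same layout is assumed for sstvalidxs/sstvalidys and ssttestxs/ssttestys)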
The code is the following:
using Flux
using Flux: crossentropy, softmax, onehot
using Statistics: mean
#Quality of life assignments
#Data Pre-Processing
function cathot(x)
    # Converts a categorical label into a one-hot vector.
    Float64.(onehot(x, [0, 1, 2, 3, 4]))
end
seqembed(list) = [embedding(x) for x in list] # Embeds a sequence of words into a sequence of vectors via Word2Vec.
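# NOTE: `embedding` is defined elsewhere in my script and not shown here. It is a Word2Vec
# lookup mapping a word (String) to a 300-element vector (300 matches the input size of the
# recurrent cells below); conceptually something like this hypothetical helper:
#     embedding(word) = get(word2vec_table, word, zeros(Float32, 300))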
etrainxs, etrainys = seqembed.(ssttrainxs), cathot.(ssttrainys)
evalidxs, evalidys = seqembed.(sstvalidxs), cathot.(sstvalidys)
etestxs, etestys = seqembed.(ssttestxs), cathot.(ssttestys)
#Defining the models - Recurrent cells with a 3 level MLP
n = 32
learning_rate = 0.003
#Defining subcells to get training to work.
RNN_Part = RNN(300, n, identity)
LSTM_Part = LSTM(300, n)
GRU_part = GRU(300, n)
MLP = Chain(Dense(n, n, Flux.σ),
            #Dense(n, n, Flux.σ),
            Dense(n, 5, Flux.σ), softmax)
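# Shape sanity check (hypothetical input, just to illustrate the expected wiring):
# a 300-dim word vector -> 32-dim recurrent state -> 5-class probability vector, e.g.
#     MLP(RNN_Part(rand(Float32, 300)))   # length-5 vector that sums to 1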
#Defining full models
function model(x, encoder, decider)
    state = encoder.(x)[end] # run the recurrent cell over every word and keep only the last hidden state
    Flux.reset!(encoder)
    decider(state)           # map the final hidden state to a 5-class probability vector
end
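# Usage sketch: model(etrainxs[5], RNN_Part, MLP) should give a length-5 probability
# vector for the training example at index 5 (the same example tracked in fit below).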
#Defining Loss and Optimisers
#function loss(x, y)
#
# # Reset internal Recurrent Cell state.
# Flux.truncate!(RNN_Part)
#
# # Iterate over every timepoint in the sentence.
# y_hat_1 = RNN_Part.(x)[end]
#
# # Take the very last output from the recurrent section, reduce it
# y_hat_2 = MLP(y_hat_1)
#
# # Calculate reduced output difference against `y`
# delta = mean(Flux.logitcrossentropy(y_hat_2, y))
#
# return delta
#
#end
opt = ADAM(learning_rate)
#import Pkg; Pkg.add("Zygote")
using Flux
using Flux: throttle, crossentropy, @epochs, gradient, @progress, params
using Statistics: mean
#using Zygote: Params
# average loss for a single epoch needed for plotting
avg_train_losses = []
avg_valid_losses = []
batch_train_loss = []
batch_valid_loss = []
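# The helpers below mirror the internals of Flux.train!, so the hand-rolled training
# loop inside fit can behave the same way as Flux.train! does.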
call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
struct SkipException <: Exception end
struct StopException <: Exception end
batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x
function fit(encoder, decider, opt, epochs)
    function loss(x, y)
        temp_model(x) = model(x, encoder, decider)
        return crossentropy(temp_model(x), y)
    end
    mp = params(encoder, decider)
    #num_tr = rand(1:length(etrainxs))
    #num_vl = rand(1:length(evalidxs))
    # This is to keep track of the loss for the training and validation sets.
    tracktx, trackty = (etrainxs[5], etrainys[5])
    track_vx, track_vy = (evalidxs[5], evalidys[5])
    evalcb = function()
        train_loss = loss(tracktx, trackty).data
        valid_loss = loss(track_vx, track_vy).data
        # .data gives us the non-tracked version of the array.
        push!(batch_train_loss, train_loss)
        push!(batch_valid_loss, valid_loss)
    end
    for i in 1:epochs
        println("epoch: $i \n")
        #Flux.train!(loss, mp, zip(etrainxs, etrainys), opt, cb = throttle(evalcb, 1))
        ps = mp
        cb = runall(throttle(evalcb, 1))
        @progress for d in zip(etrainxs[5:5], etrainys[5:5])
            #print(d)
            try
                gs = gradient(ps) do
                    loss(batchmemaybe(d)...)
                    #print(loss(batchmemaybe(d)...))
                end
                print(gs.grads, "\n")
                Flux.Optimise.update!(opt, ps, gs)
                cb()
            catch ex
                if ex isa StopException
                    break
                elseif ex isa SkipException
                    continue
                else
                    rethrow(ex)
                end
            end
        end
        @show batch_train_loss
        @show batch_valid_loss
        avg_train_loss = mean(batch_train_loss)
        avg_valid_loss = mean(batch_valid_loss)
        push!(avg_train_losses, avg_train_loss)
        push!(avg_valid_losses, avg_valid_loss)
        global batch_train_loss = []
        global batch_valid_loss = []
    end
    @show avg_train_losses
    @show avg_valid_losses
    return model
end
epochs = 1
trained_model = fit(RNN_Part, MLP, opt, epochs)
# Plot the average loss on the training set and validation set for each epoch.
using Plots
epoch_array = 1:epochs
p1 = plot(epoch_array, avg_train_losses, title="Loss vs Epochs (Training)", xlabel="Epoch",
          ylabel="Cross Entropy Loss")
p2 = plot(epoch_array, avg_valid_losses, title="Loss vs Epochs (Validation)", xlabel="Epoch",
          ylabel="Cross Entropy Loss")
plot(p1, p2, layout = (1, 2), legend = false)
When I try to get the derivative (the gs.grads printed inside the training loop above), I get:
IdDict{Any,Any}(
  Tracked{Array{Float32,1}}(0x00000000, Call{Nothing,Tuple{}}(nothing, ()), true, Float32[0.0, 0.0, … 32 zeros]) => Float32[0.0, 0.0, … 32 zeros],
  Tracked{Array{Float32,2}}(0x00000000, Call{Nothing,Tuple{}}(nothing, ()), true, Float32[0.0 0.0 … all zeros]) => Float32[0.0 0.0 … all zeros],
  …)
(output truncated here - every entry in the gradient IdDict is an all-zero array)
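For reference, the gradient call inside the loop boils down to this stand-alone sketch (same index-5 example as in fit, nothing new, just easier to inspect):
ps_check = params(RNN_Part, MLP)
gs_check = gradient(() -> crossentropy(model(etrainxs[5], RNN_Part, MLP), etrainys[5]), ps_check)
println(gs_check.grads)  # should print an IdDict like the one above - every gradient array comes out as zeros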