# How to train Flux to learn a sequence conditional on some initial "seeds"?

I am trying to write an RNN that, given an initial “seed” sequence, reproduces the continuation of the sequence.
In the code below, dummy sequences are generated as a function of these initial seed points and an RNN approach is attempted, but when I plot the generated sequences, they are very poorly connected with the “true” ones.

Where am I going wrong?

### Setting the environment…

```julia
# Setting the environment...
cd(@__DIR__)
using Pkg
Pkg.activate(".")
# Pkg.resolve()
# Pkg.instantiate()
using Random
Random.seed!(123)
using LinearAlgebra, Plots, Flux
```

### Generating simulated data

The idea is to have a sequence that depends on its first 5 values. The first 5 values are random, but the rest of the sequence depends deterministically on them, and the objective is to recreate this second part of the sequence knowing the first 5 values.
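Assuming the recurrence in `makeSequence` below is `seq[i] = seq[i-1] + 0.5*seeds[4]` (as its comment about the 4th seed suggests), the continuation has a closed form: the k-th generated value is the 5th seed plus k/2 times the 4th seed. So the target actually depends on two seeds, the 4th (the slope) and the 5th (the starting level). A minimal pure-Julia check of this, with hypothetical helper names `continuationLoop` and `continuationClosedForm`:

```julia
# Assumed recurrence (from the comment in makeSequence): seq[i] = seq[i-1] + 0.5 * seeds[4]
function continuationLoop(seeds, seqLength)
    nS  = length(seeds)
    seq = zeros(Float64, nS + seqLength)
    seq[1:nS] .= seeds
    for i in nS+1:nS+seqLength
        seq[i] = seq[i-1] + 0.5 * seeds[4]
    end
    return seq[nS+1:end]            # only the generated part
end

# Closed form: the k-th generated value is seeds[end] + k * 0.5 * seeds[4]
continuationClosedForm(seeds, seqLength) = [seeds[end] + k * 0.5 * seeds[4] for k in 1:seqLength]

seeds = [0.3, -0.8, 0.1, 0.6, -0.2]
continuationClosedForm(seeds, 5)    # ≈ [0.1, 0.4, 0.7, 1.0, 1.3]
```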

```julia
nSeeds    = 5
seqLength = 5
nTrains   = 1000
nVal      = 100
nTot      = nTrains + nVal
makeSeeds(nSeeds) = 2 .* (rand(nSeeds) .- 0.5) # uniform in [-1,+1]
function makeSequence(seeds, seqLength)
    seq = Vector{Float32}(undef, seqLength + nSeeds) # Flux works with Float32 for performance reasons
    seq[1:nSeeds] .= seeds
    for i in nSeeds+1:(seqLength+nSeeds)
        seq[i] = seq[i-1] + (seeds[4] * 0.5) # the only seed that matters is the 4th. Let's see if the RNN learns it!
    end
    return seq
end

x0   = [makeSeeds(nSeeds) for i in 1:nTot]
seqs = makeSequence.(x0, seqLength)
seqs_vectors = [[[e] for e in seq] for seq in seqs]
y    = [s[2:end] for s in seqs_vectors] # y here is the value of the sequence itself at the next step

xtrain = seqs_vectors[1:nTrains]
xval   = seqs_vectors[nTrains+1:end]
ytrain = y[1:nTrains]
yval   = y[nTrains+1:end]

# Flux wants a vector of sequences of individual items, where these in turn are vectors
allData   = xtrain;
aSequence = allData[1]
anElement = aSequence[1]
```
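Because `y` is each sequence shifted by one step, `y[t]` is the target for the input `x[t]`. The loss defined in the next section scores only the pairs from `t = nSeeds` onward; the first `nSeeds - 1` inputs only update the hidden state. A toy illustration of that indexing, with no Flux involved (`nSeedsToy`, `xToy`, `yToy` and `scoredPairs` are hypothetical names):

```julia
nSeedsToy = 5
xToy = [[Float32(t)] for t in 1:10]   # 10 elements, each a 1-feature vector
yToy = xToy[2:end]                     # target at step t is the element at step t+1

# Mirrors zip(x[nSeeds:(end-1)], y[nSeeds:end]) in myloss
scoredPairs = collect(zip(xToy[nSeedsToy:(end-1)], yToy[nSeedsToy:end]))

# (input, target) values of the scored pairs: (5, 6), (6, 7), ..., (9, 10)
[(xi[1], yi[1]) for (xi, yi) in scoredPairs]
```

The first scored prediction is made from the 5th seed and compared with the first generated value, which is exactly the step the model has to get right at inference time.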

### Utility functions

```julia
function predictSequence(m, seeds, seqLength)
    nS  = length(seeds)
    seq = Vector{Vector{Float32}}(undef, nS + seqLength - 1)
    Flux.reset!(m)                       # Reset the state (not the weights!)
    for i in 1:nS
        seq[i] = [convert(Float32, seeds[i])]
    end
    for i in 1:nS-1
        m(seq[i])                        # Ignores the output but updates the hidden state, as in myloss
    end
    for i in nS+1:nS+seqLength-1
        seq[i] = m(seq[i-1])             # Each prediction becomes the next input
    end
    return [s[1] for s in seq]           # Unwrap the one-element vectors
end

function myloss(x, y)
    Flux.reset!(m)                 # Reset the state (not the weights!)
    [m(x[i]) for i in 1:nSeeds-1]  # Ignores the output but updates the hidden state
    # y_i is x_(i+1), i.e. the next element
    sum(Flux.mse(m(xi), yi) for (xi, yi) in zip(x[nSeeds:(end-1)], y[nSeeds:end]))
end

"""
    batchSequences(x, batchSize)

Transform a vector of sequences of individual elements, each represented as a feature vector, into a vector of sequences whose elements are features × batchSize matrices (one column per record). The last incomplete batch is dropped.
"""
function batchSequences(x, batchSize)
    nRecords = length(x)
    nItems   = length(x[1])
    nDims    = size(x[1][1], 1)
    nBatches = Int(floor(nRecords / batchSize))

    emptyBatchedElement = Matrix{Float32}(undef, nDims, batchSize)
    emptySeq = [similar(emptyBatchedElement) for i in 1:nItems]
    outx = [similar(emptySeq) for i in 1:nBatches]
    for b in 1:nBatches
        xmin = (b-1)*batchSize + 1
        xmax = b*batchSize
        for e in 1:nItems
            outx[b][e] = hcat([x[i][e][:, 1] for i in xmin:xmax]...)
        end
    end
    return outx
end
```
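The training code feeds the model one matrix per time step, with features along the rows and the records of a batch along the columns, so a batch of sequences becomes a single sequence of matrices. The reshaping that `batchSequences` performs can be sanity-checked on toy data with a self-contained variant; `batchSequencesSketch` is a hypothetical name and, like the function above, it drops the last incomplete batch:

```julia
function batchSequencesSketch(x, batchSize)
    nItems   = length(x[1])          # time steps per sequence
    nBatches = length(x) ÷ batchSize # the last incomplete batch is dropped
    out = Vector{Vector{Matrix{Float32}}}(undef, nBatches)
    for b in 1:nBatches
        ids = (b-1)*batchSize+1 : b*batchSize
        # for each time step e, stack the feature vectors of the batch side by side
        out[b] = [reduce(hcat, [x[i][e] for i in ids]) for e in 1:nItems]
    end
    return out
end

# 7 toy sequences of 4 steps, 1 feature each; value = 10*sequence + step
xToy = [[[Float32(10s + t)] for t in 1:4] for s in 1:7]
batched = batchSequencesSketch(xToy, 3)
length(batched)        # 2 batches (sequence 7 is dropped)
size(batched[1][2])    # (1, 3): 1 feature × 3 sequences at time step 2
```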

### Defining the model

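```julia
m   = Chain(Dense(1, 3), LSTM(3, 3), Dense(3, 5, relu), Dense(5, 1))
ps  = params(m)
opt = ADAM()  # Flux.train! below needs an optimiser; ADAM with default settings is an assumed choice
```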

### Plotting a random sequence and its prediction from the untrained model…

```julia
seq1True = makeSequence(x0[1], seqLength)
seq1Est0 = predictSequence(m, x0[1], seqLength)
plot(seq1True)
plot!(seq1Est0)
```

### Actual training

```julia
trainMSE  = Float64[]
valMSE    = Float64[]
epochs    = 20
batchSize = 16
for e in 1:epochs
    print("Epoch $e ")
    # Shuffling at each epoch
    ids = shuffle(1:length(xtrain))
    xtraine = batchSequences(xtrain[ids], batchSize)
    ytraine = batchSequences(ytrain[ids], batchSize)
    trainxy = zip(xtraine, ytraine)

    # Actual training
    Flux.train!(myloss, ps, trainxy, opt)

    # Making predictions with the trained model and computing the errors
    ŷtrain = [predictSequence(m, x0[i], seqLength) for i in 1:nTrains]
    ŷval   = [predictSequence(m, x0[i], seqLength) for i in (nTrains+1):nTot]
    # True sequences, under new names so that the training targets ytrain/yval are not overwritten
    trueTrain = [makeSequence(x0[i], seqLength) for i in 1:nTrains]
    trueVal   = [makeSequence(x0[i], seqLength) for i in (nTrains+1):nTot]

    trainmse = sum(norm(ŷtrain[i][nSeeds+1:end] - trueTrain[i][nSeeds+1:end-1])^2 for i in 1:nTrains) / nTrains
    valmse   = sum(norm(ŷval[i][nSeeds+1:end] - trueVal[i][nSeeds+1:end-1])^2 for i in 1:nVal) / nVal
    push!(trainMSE, trainmse)
    push!(valMSE, valmse)
    println("Mean Sq Error: $trainmse - $valmse")
end
```
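The error printed in the loop is, for each sequence, the squared Euclidean distance between the generated part of the prediction and the same positions of the true sequence, averaged over sequences. `predictSequence` returns one element fewer than `makeSequence`, hence the `end-1` on the true sequence. A standalone version of that metric on two hand-made sequences (`generatedPartMSE` is a hypothetical helper):

```julia
using LinearAlgebra

# ŷs: predicted sequences of length nSeeds+seqLength-1
# trues: true sequences of length nSeeds+seqLength
function generatedPartMSE(ŷs, trues, nSeeds)
    n = length(ŷs)
    return sum(norm(ŷs[i][nSeeds+1:end] - trues[i][nSeeds+1:end-1])^2 for i in 1:n) / n
end

# Two toy sequences with nSeeds = 2 and seqLength = 3
trues = [[0.0, 1.0, 2.0, 3.0, 4.0], [0.0, -1.0, -2.0, -3.0, -4.0]]
ŷs    = [[0.0, 1.0, 2.5, 3.0],      [0.0, -1.0, -2.0, -2.0]]
generatedPartMSE(ŷs, trues, 2)    # (0.5^2 + 1^2) / 2 = 0.625
```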

### Plotting some random sequences

```julia
for i in rand(1:nTot, 5)
    trueseq = makeSequence(x0[i], seqLength)
    estseq  = predictSequence(m, x0[i], seqLength)
    seqPlot = plot(trueseq[1:end-1], label="true", title="Seq $i")
    plot!(seqPlot, estseq, label="est")
    display(seqPlot)
end
```

### Plotting the error

Strange, the validation error is always lower than the training error…

```julia
plot(trainMSE, label="Train MSE")
plot!(valMSE, label="Validation MSE")
```

The error changes depending on the parameters, but it always gets stuck in some local minimum, typically predicting the unconditional expected value of the sequence (i.e. a horizontal line). The true/predicted plots of some sequences show that the estimate doesn't change. Sometimes I can make it change, but the outputs are always very far from the intended sequence.
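A constant prediction at the position-wise mean is a useful reference point here: if the validation MSE stays close to the error of this baseline, the model is effectively ignoring the seeds. Assuming seeds uniform in [-1, 1] (as `makeSeeds` generates them) and the continuation rule `seq[i] = seq[i-1] + 0.5*seeds[4]`, the baseline can be estimated in plain Julia with the same metric as the training loop; `continuation` and `meanBaselineMSE` are hypothetical helpers. Analytically, the k-th generated value has variance 1/3 + k^2/12, so over 5 generated steps the expected baseline error is 75/12 = 6.25.

```julia
using Random

# Generated part of a sequence, assuming seq[i] = seq[i-1] + 0.5*seeds[4]
continuation(seeds, seqLength) = [seeds[end] + k * 0.5 * seeds[4] for k in 1:seqLength]

# MSE (summed over generated steps, averaged over sequences) of a constant mean prediction
function meanBaselineMSE(nSeqs, nSeeds, seqLength; rng = Random.default_rng())
    seeds = [2 .* (rand(rng, nSeeds) .- 0.5) for _ in 1:nSeqs]   # uniform in [-1, 1]
    conts = [continuation(s, seqLength) for s in seeds]
    μ = sum(conts) / nSeqs                                         # position-wise mean
    return sum(sum(abs2, c .- μ) for c in conts) / nSeqs
end

baseline = meanBaselineMSE(10_000, 5, 5; rng = MersenneTwister(123))   # close to 6.25
```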

Can you double check the validity of the algorithmic approach with another library? There’s quite a bit of code to go through and I don’t see any obvious errors.

Ok, thanks… Which other Julia libraries support RNNs other than Flux? Knet.jl?

I believe it does, but for this kind of proof of concept you need not limit yourself to Julia libraries either.
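One proof of concept that needs no RNN at all: under the same assumed generating rule, every generated value is a linear function of the seeds (the 5th seed plus k/2 times the 4th), so ordinary least squares from the five seeds to the generated values should recover it exactly. If such a linear model fits while the RNN does not, the problem is more likely in the training or prediction code than in the task itself. A minimal sketch in plain Julia (`continuation` and `fitLinearBaseline` are hypothetical helpers):

```julia
using Random, LinearAlgebra

# Generated part of a sequence, assuming seq[i] = seq[i-1] + 0.5*seeds[4]
continuation(seeds, seqLength) = [seeds[end] + k * 0.5 * seeds[4] for k in 1:seqLength]

# Fit one linear map from the seeds to all generated values with least squares
function fitLinearBaseline(seedsList, seqLength)
    X = reduce(vcat, permutedims.(seedsList))                           # nSeqs × nSeeds
    Y = reduce(vcat, permutedims.(continuation.(seedsList, seqLength))) # nSeqs × seqLength
    return X \ Y                                                        # nSeeds × seqLength weights
end

rng = MersenneTwister(1)
seedsList = [2 .* (rand(rng, 5) .- 0.5) for _ in 1:200]
W = fitLinearBaseline(seedsList, 5)
round.(W; digits = 6)   # rows 1 to 3 ≈ 0, row 4 ≈ [0.5 1.0 1.5 2.0 2.5], row 5 ≈ 1
```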

It is strange that the train MSE rises for the first 7 epochs straight (which covers several hundred gradient steps). Does decreasing the learning rate help?
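On the learning-rate question, a toy analogy (not the RNN itself, whose loss is not quadratic): for plain gradient descent on a least-squares loss with Hessian H, a step size above 2/λmax(H) makes the iterates diverge and the loss grow, while a step size below 1/λmax(H) makes the loss never increase. The sketch shows both regimes on random data; `lsLoss`, `lsGrad` and `gdLosses` are hypothetical helpers:

```julia
using Random, LinearAlgebra

# Mean squared error loss L(w) = ‖Xw - y‖² / n and its gradient
lsLoss(X, y, w) = sum(abs2, X * w - y) / length(y)
lsGrad(X, y, w) = 2 .* (X' * (X * w - y)) ./ length(y)

# Plain gradient descent from w = 0, recording the loss at each step
function gdLosses(X, y, η, nSteps)
    w = zeros(size(X, 2))
    losses = [lsLoss(X, y, w)]
    for _ in 1:nSteps
        w -= η .* lsGrad(X, y, w)
        push!(losses, lsLoss(X, y, w))
    end
    return losses
end

rng = MersenneTwister(0)
X = randn(rng, 100, 5)
y = X * randn(rng, 5) .+ 0.1 .* randn(rng, 100)
λmax = eigmax(Symmetric(2 .* (X' * X) ./ length(y)))   # largest Hessian eigenvalue

small = gdLosses(X, y, 0.5 / λmax, 50)   # loss decreases step after step
large = gdLosses(X, y, 2.5 / λmax, 50)   # loss grows geometrically
```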

By the way you wrote that the validation error is always smaller than the training error, but according to your plot that is not the case.