How to train Flux to learn a sequence conditional on some initial "seeds"?

I am trying to write an RNN that, given an initial “seed” sequence, reproduces the continuation of that sequence.
In the code below, dummy sequences are generated as a function of these initial seed points and an RNN approach is attempted, but when I plot them, the generated sequences match the “true” ones very poorly.

Where am I going wrong?

Setting the environment…

# Setting the environment...
using Pkg      
# Pkg.add(["Plots","Flux"])
# Pkg.resolve()   
# Pkg.instantiate()
using Random
using LinearAlgebra, Plots, Flux

Generating simulated data

The idea is to have a sequence that depends on its first 5 values. The first 5 values are random, but the rest of the sequence depends deterministically on them, and the objective is to recreate this second part of the sequence knowing the first 5 values.
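For reference, the rule implemented by `makeSequence` below has a simple closed form: each continuation value adds `0.5 * seeds[4]` to the previous one, starting from the last seed. So the continuation depends on seed 4 (the slope) and also on seed 5 (the starting level). The helper names `continuation` and `continuationLoop` are hypothetical, used only for this sketch; it checks that the closed form and the step-by-step recursion agree.

```julia
# Hypothetical helper: closed form of the rule in makeSequence,
# seq[i] = seq[i-1] + 0.5*seeds[4] for i > nSeeds, with seq[nSeeds] = seeds[end].
continuation(seeds, seqLength) = [seeds[end] + k * 0.5 * seeds[4] for k in 1:seqLength]

# The same rule written as the recursive loop, for comparison
function continuationLoop(seeds, seqLength)
    out  = Float64[]
    prev = seeds[end]            # the continuation starts from the last seed
    for _ in 1:seqLength
        prev += 0.5 * seeds[4]   # constant increment set by the 4th seed
        push!(out, prev)
    end
    return out
end

seeds = [0.1, -0.3, 0.7, 0.4, -0.2]
continuation(seeds, 5)           # ≈ [0.0, 0.2, 0.4, 0.6, 0.8]
```

A network that solves the task therefore has to carry both seed 4 and seed 5 in its hidden state until generation starts.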

nSeeds    = 5
seqLength = 5
nTrains   = 1000  
nVal      = 100
nTot = nTrains+nVal
makeSeeds(nSeeds) = 2 .* (rand(nSeeds) .- 0.5) # [-1,+1]
function makeSequence(seeds,seqLength)
  nSeeds = length(seeds)
  seq = Vector{Float32}(undef,seqLength+nSeeds) # Flux works with Float32 for performance reasons
  seq[1:nSeeds] .= seeds                        # the first values are the seeds themselves
  for i in nSeeds+1:(seqLength+nSeeds)
     seq[i] = seq[i-1] + (seeds[4]*0.5) # the only seed that matters is the 4th. Let's see if the RNN learns it!
  end
  return seq # full sequence: seeds followed by their continuation
end

x0   = [makeSeeds(nSeeds) for i in 1:nTot]
seqs = makeSequence.(x0,seqLength)
seqs_vectors = [[[e] for e in seq] for seq in seqs]
y    = [s[2:end] for s in seqs_vectors] # y is the sequence shifted by one: the target at each step is the next value

xtrain = seqs_vectors[1:nTrains]
xval   = seqs_vectors[nTrains+1:end]
ytrain = y[1:nTrains]
yval   = y[nTrains+1:end]

# Flux wants a vector of sequences, where each element of a sequence is itself a (feature) vector
allData   = xtrain;
aSequence = allData[1]
anElement = aSequence[1]
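The nesting above can be checked on a single toy sequence. This sketch mirrors how `seqs_vectors` and `y` are built (the values are made up): each time step becomes a 1-element feature vector, and the target sequence is the input shifted by one step.

```julia
# One toy sequence of scalars (stand-in for one element of `seqs`; values are made up)
seq = Float32[0.1, -0.3, 0.7, 0.4, -0.2, 0.0, 0.2]

# Each time step becomes a 1-element feature vector, as in seqs_vectors
x = [[e] for e in seq]      # Vector{Vector{Float32}}, length 7
y = x[2:end]                # targets: the next step, length 6

# y[i] is the element that follows x[i]
all(y[i] == x[i+1] for i in eachindex(y))   # true
```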

Utility functions

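function predictSequence(m,seeds,seqLength)
    nSeeds = length(seeds)
    seq = Vector{Vector{Float32}}(undef,seqLength+nSeeds-1)
    Flux.reset!(m)                       # Reset the state (not the weights!)
    for i in 1:nSeeds
        seq[i] = [Float32(seeds[i])]     # the seeds are copied as they are
    end
    for i in 1:nSeeds-1
        m(seq[i])                        # Feed seeds 1..nSeeds-1 to build the hidden state (output ignored), as in myloss
    end
    for i in nSeeds+1:nSeeds+seqLength-1
        seq[i] = m(seq[i-1])             # Autoregressive step: the previous value is fed to get the next one
    end
    return [s[1] for s in seq]
end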
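At prediction time, all the seeds have to pass through the recurrent state before generation starts, in the same order used by `myloss` during training (which feeds `x[1:nSeeds-1]` first). The sketch below shows why, using a hypothetical hand-written stateful "model" `OracleRNN` (not a Flux layer) that remembers the 4th input seen since its last reset and predicts `next = current + 0.5 * (4th input)`; `resetstate!` plays the role of `Flux.reset!`. If generation starts from the last seed without the warm-up, the model never sees seed 4 and its first outputs form a flat line.

```julia
# Hypothetical stateful predictor standing in for the trained LSTM (not Flux)
mutable struct OracleRNN
    nSeen::Int      # inputs seen since the last reset
    s4::Float64     # the 4th input seen, once available
end
OracleRNN() = OracleRNN(0, 0.0)
resetstate!(m::OracleRNN) = (m.nSeen = 0; m.s4 = 0.0; m)

function (m::OracleRNN)(x::Real)
    m.nSeen += 1
    m.nSeen == 4 && (m.s4 = x)   # store the 4th input in the "hidden state"
    return x + 0.5 * m.s4
end

# Same structure as predictSequence: reset, optionally warm up on seeds[1:end-1],
# then feed each output back as the next input
function generate(m, seeds, seqLength; warmup = true)
    resetstate!(m)
    if warmup
        for s in seeds[1:end-1]
            m(s)                 # output ignored, only the state is updated
        end
    end
    out  = Float64[]
    prev = seeds[end]
    for _ in 1:seqLength
        prev = m(prev)
        push!(out, prev)
    end
    return out
end

seeds = [0.1, -0.3, 0.7, 0.4, -0.2]
m = OracleRNN()
generate(m, seeds, 5)                   # ≈ [0.0, 0.2, 0.4, 0.6, 0.8], the true continuation
generate(m, seeds, 5; warmup = false)   # starts with -0.2, -0.2, -0.2: seed 4 was never seen
```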

function myloss(x, y)
    Flux.reset!(m)                 # Reset the state (not the weights!)
    for i in 1:nSeeds-1
        m(x[i])                    # Ignores the output but updates the hidden state
    end
    # y[i] is x[i+1], i.e. the next element
    sum(Flux.mse(m(xi), yi) for (xi, yi) in zip(x[nSeeds:(end-1)], y[nSeeds:end]))
end
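Which (input, target) pairs actually enter this loss can be checked on positions alone. In this sketch each value of a toy sequence equals its position, so the output directly shows the indices: the first `nSeeds-1` elements only warm up the state, and the loss scores exactly `seqLength` one-step-ahead predictions (teacher forcing, since the true previous value is fed at every step).

```julia
nSeeds, seqLength = 5, 5
x = collect(1:nSeeds+seqLength)   # stand-in for one sequence: the values are the positions 1..10
y = x[2:end]                      # targets built as for `y`: y[i] == x[i+1]

warmup = x[1:nSeeds-1]                                 # fed to the model, outputs ignored
scored = collect(zip(x[nSeeds:end-1], y[nSeeds:end]))  # (input, target) pairs in the loss
# warmup == [1, 2, 3, 4]
# scored == [(5, 6), (6, 7), (7, 8), (8, 9), (9, 10)]
```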

# Transform a vector of sequences of individual elements (each a feature vector) into a vector of batches,
# where each batch is a sequence of (features × batchSize) matrices.
# Only the arguments are used, so the shuffled data and the requested batch size are honored.
function batchSequences(x,batchSize)
    nRecords  = length(x)
    nItems    = length(x[1])
    nDims     = size(x[1][1],1)
    nBatches  = div(nRecords,batchSize)   # the trailing incomplete batch is dropped

    emptyBatchedElement = Matrix{Float32}(undef,nDims,batchSize)
    emptySeq = [similar(emptyBatchedElement) for i in 1:nItems]
    outx = [similar(emptySeq) for i in 1:nBatches]
    for b in 1:nBatches
        xmin = (b-1)*batchSize + 1
        xmax = b*batchSize
        for e in 1:nItems
            outx[b][e] = hcat([x[i][e][:,1] for i in xmin:xmax]... )  # column j = element e of record j in the batch
        end
    end
    return outx
end
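A compact sketch of the same batching transformation, useful to check the layout on a tiny input where every value encodes its position (`10*record + step`). `batchSeqs` is a hypothetical name; it uses `reduce(hcat, ...)` from Base instead of splatting into `hcat`, and like the function above it drops the trailing incomplete batch.

```julia
# Hypothetical re-implementation: every batch is a sequence of (nDims × batchSize) matrices
function batchSeqs(x, batchSize)
    nBatches = div(length(x), batchSize)   # the trailing incomplete batch is dropped
    nItems   = length(x[1])
    batches  = Vector{Vector{Matrix{Float32}}}(undef, nBatches)
    for b in 1:nBatches
        records = (b-1)*batchSize+1:b*batchSize
        # column j of step e holds element e of the j-th record in the batch
        batches[b] = [reduce(hcat, [x[i][e] for i in records]) for e in 1:nItems]
    end
    return batches
end

# Tiny example: 7 sequences of 4 one-feature elements, value = 10*record + step
x  = [[[Float32(10r + e)] for e in 1:4] for r in 1:7]
xb = batchSeqs(x, 3)
length(xb)       # 2 batches (record 7 is dropped)
size(xb[1][1])   # (1, 3)
xb[2][4]         # 1×3 Matrix{Float32}: [44.0 54.0 64.0]
```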

Defining the model

m    = Chain(Dense(1,3),LSTM(3, 3), Dense(3, 5,relu),Dense(5,1))
ps  = params(m)
opt = Flux.ADAM()

Plotting a random sequence and its prediction from the untrained model…

seq1True = makeSequence(x0[1],seqLength)
seq1Est0 = predictSequence(m,x0[1],seqLength)
plot(seq1True[1:end-1], label="true")
plot!(seq1Est0, label="est (untrained)")

Actual training

trainMSE  = Float64[]
valMSE    = Float64[]
epochs    = 20 
batchSize = 16
for e in 1:epochs
    print("Epoch $e ")
    # Shuffling at each epoch
    ids     = shuffle(1:length(xtrain))
    xtraine = batchSequences(xtrain[ids],batchSize)
    ytraine = batchSequences(ytrain[ids],batchSize)
    trainxy = zip(xtraine,ytraine)

    # Actual training
    Flux.train!(myloss, ps, trainxy, opt)

    # Making predictions with the trained model and computing the errors
    # (new names, so that the global `ytrain` and `yval` used for training are not overwritten)
    ŷtrain    = [predictSequence(m,x0[i],seqLength) for i in 1:nTrains]
    ŷval      = [predictSequence(m,x0[i],seqLength) for i in (nTrains+1):nTot]
    ytrainSeq = [makeSequence(x0[i],seqLength) for i in 1:nTrains]
    yvalSeq   = [makeSequence(x0[i],seqLength) for i in (nTrains+1):nTot]

    trainmse = sum(norm(ŷtrain[i][nSeeds+1:end] - ytrainSeq[i][nSeeds+1:end-1])^2 for i in 1:nTrains)/nTrains
    valmse   = sum(norm(ŷval[i][nSeeds+1:end] - yvalSeq[i][nSeeds+1:end-1])^2 for i in 1:nVal)/nVal
    push!(trainMSE, trainmse)   # push! mutates the global arrays, so no `global` declaration is needed
    push!(valMSE, valmse)
    println("Mean squared error (train - val): $trainmse - $valmse")
end

Plotting some random sequences

for i = rand(1:nTot,5)
    trueseq = makeSequence(x0[i],seqLength)
    estseq  = predictSequence(m,x0[i],seqLength)
    seqPlot = plot(trueseq[1:end-1],label="true", title = "Seq $i")
    plot!(seqPlot, estseq, label="est")
    display(seqPlot)   # plots created inside a loop are not shown automatically
end

Plotting the error

Strange, the validation error is always lower than the training error…

plot(trainMSE,label="Train MSE")
plot!(valMSE,label="Validation MSE")

The error changes depending on the parameters, but it always gets stuck in some local minimum, typically predicting the unconditional expected value of the sequence (i.e. a horizontal line):


And some true/predicted sequences look like this:


(Note that the estimate doesn’t change. Sometimes I can make it change, but the outputs are always very far from the intended sequence.)
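To judge whether the network has learned anything beyond a horizontal line, its error can be compared with trivial baselines under the same metric as the training loop, which scores only the `seqLength - 1 = 4` generated values. Assuming seeds uniform in [-1, 1] and the rule of `makeSequence`, the expected error of a flat line at the last seed is `7.5 * E[s4^2] = 2.5`, and that of always predicting 0 (the unconditional mean) is `4/3 + 2.5 = 23/6 ≈ 3.83`. This hypothetical Monte Carlo sketch checks both numbers; a model whose MSE stays near one of them is effectively producing a flat line.

```julia
using Random

# Closed form of the continuation generated by makeSequence (see the rule above)
continuation(seeds, nSteps) = [seeds[end] + k * 0.5 * seeds[4] for k in 1:nSteps]

# Average squared-norm error over the first nSteps continuation values,
# as in the trainmse/valmse computation of the training loop
function baselineErrors(nSamples; nSeeds = 5, nSteps = 4, rng = MersenneTwister(123))
    errFlat = 0.0   # predict a flat line at the last seed
    errZero = 0.0   # predict the unconditional mean (0 for seeds uniform in [-1, 1])
    for _ in 1:nSamples
        seeds = 2 .* (rand(rng, nSeeds) .- 0.5)
        y = continuation(seeds, nSteps)
        errFlat += sum(abs2, y .- seeds[end])
        errZero += sum(abs2, y)
    end
    return errFlat / nSamples, errZero / nSamples
end

baselineErrors(100_000)   # close to (2.5, 3.83)
```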

Crosspost on SO: deep learning - How to train Flux.jl to learn a sequence conditional to some initial "seeds"? - Stack Overflow

Can you double check the validity of the algorithmic approach with another library? There’s quite a bit of code to go through and I don’t see any obvious errors.

Ok, thanks… Which other Julia libraries support RNNs besides Flux? Knet.jl?

I believe it does, but for this kind of proof of concept you need not limit yourself to Julia libraries either.

It is strange that the train MSE rises for the first 7 epochs straight (which covers several hundred gradient steps). Does decreasing the learning rate help?
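In Flux the learning rate is the first argument of the optimiser constructor, e.g. `Flux.ADAM(1e-4)` instead of the default `1e-3`. Why a too-large step can make the loss rise is easiest to see on a toy problem. The sketch below uses plain gradient descent on `f(w) = w^2`, not Adam (which rescales its steps), so it is only an analogy: each step multiplies `w` by `(1 - 2η)`, and the loss grows whenever `|1 - 2η| > 1`.

```julia
# Plain gradient descent on f(w) = w^2 (gradient 2w); returns the loss after each step
function gdLosses(η; w0 = 1.0, nSteps = 10)
    w = w0
    losses = Float64[]
    for _ in 1:nSteps
        w -= η * 2w          # w ← (1 - 2η) w
        push!(losses, w^2)
    end
    return losses
end

gdLosses(1.1)   # |1 - 2η| = 1.2 > 1: the loss grows at every step
gdLosses(0.1)   # |1 - 2η| = 0.8 < 1: the loss shrinks at every step
```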

By the way, you wrote that the validation error is always smaller than the training error, but according to your plot that is not the case.