RNN sentiment analysis example with Flux: how to use batches? Which size for the inner state?

Hello, the code below models an RNN for sentiment analysis of Amazon reviews.
My main problem is that I don't know how to batch a recurrent cell when the sequence (i.e. the review text) has variable length.
Indeed, in the [documentation of Recurrent networks] the sequences are constant-length, and the `x` passed to `train!` is an nRecords-long vector of nSeq-long vectors (the sequences) of elements (the words), where each element is in turn an nFeatures × nBatch matrix.
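To make the layout described above concrete, here is a minimal plain-Julia sketch (no Flux calls) that turns a batch of equal-length sequences into an nSeq-long vector of nFeatures × nBatch one-hot matrices. It assumes each sequence is stored as a vector of integer word indices, and the helper name `onehot_timesteps` is hypothetical.

```julia
# Hypothetical helper: convert equal-length sequences of word indices into the
# "one matrix per time step" layout used for batched recurrent layers.
# seqs: nBatch sequences, each a vector of nSeq indices in 1:nFeatures
function onehot_timesteps(seqs::Vector{Vector{Int}}, nFeatures::Int)
    nBatch = length(seqs)
    nSeq   = length(first(seqs))
    all(length(s) == nSeq for s in seqs) ||
        throw(ArgumentError("all sequences must have the same length"))
    steps = Vector{Matrix{Float32}}(undef, nSeq)
    for t in 1:nSeq
        Xt = zeros(Float32, nFeatures, nBatch)   # nFeatures × nBatch
        for b in 1:nBatch
            Xt[seqs[b][t], b] = 1f0              # column b holds word t of sequence b
        end
        steps[t] = Xt
    end
    return steps
end

# Two sequences of length 3 over a 5-word vocabulary
x = onehot_timesteps([[1, 4, 2], [3, 3, 5]], 5)
length(x)    # → 3 (time steps)
size(x[1])   # → (5, 2) (nFeatures × nBatch)
```

The constant-length requirement is exactly what breaks for reviews, since `nSeq` must be shared by every column of a batch.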

How do I (randomly) batch these variable-length sequences, considering also that each sequence is independent, so I need to reset the inner state for each sequence?
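A common, framework-independent approach is to bucket the reviews by length: sort them, cut the sorted order into batches of similar length, pad each batch only up to its own longest review, and shuffle the order of the batches at each epoch. The sketch below assumes each review is stored as a vector of integer word indices (rather than the `OneHotVector`s used in the code), and the helper names `length_buckets` and `leftpad_batch` are hypothetical. It pads on the left with all-zero vectors so that every review ends at the last time step, where the prediction is read. An all-zero input is not a true no-op for an RNN (the bias and recurrent weights still change the state), which is why keeping batches of similar length helps. Since each column of the state matrix belongs to one review, the state would then be reset once per batch rather than once per review.

```julia
using Random

# Hypothetical helper: group sequence indices into batches of similar length
# and return the batches in random order.
function length_buckets(lengths::Vector{Int}, batchsize::Int; rng = Random.default_rng())
    order   = sortperm(lengths)                                   # shortest to longest
    batches = [order[i:min(i + batchsize - 1, end)] for i in 1:batchsize:length(order)]
    return shuffle(rng, batches)                                  # random batch order
end

# Hypothetical helper: left-pad a batch of word-index sequences with zero vectors
# and return one nFeatures × nBatch matrix per time step.
function leftpad_batch(seqs::Vector{Vector{Int}}, nFeatures::Int)
    nBatch = length(seqs)
    nSeq   = maximum(length, seqs)
    steps  = [zeros(Float32, nFeatures, nBatch) for _ in 1:nSeq]
    for (b, s) in enumerate(seqs)
        offset = nSeq - length(s)              # the first `offset` steps stay all-zero
        for (t, w) in enumerate(s)
            steps[offset + t][w, b] = 1f0
        end
    end
    return steps
end

seqs    = [[1, 2], [3, 4, 5, 1], [2], [5, 5, 5]]
batches = length_buckets(length.(seqs), 2; rng = MersenneTwister(1))
for idx in batches
    x = leftpad_batch(seqs[idx], 5)
    # per batch: Flux.reset!(m), run m over x[1:end-1], then use m(x[end])
end
```

Each `x` has the layout expected for batched recurrent layers, so `m(x[end])` would return a 1 × nBatch row of logits to compare with the labels of that batch.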

I also have other questions, as the code below is very slow to learn: after hours of training, its accuracy is still well below what I get with a linear perceptron.

I have 4000 reviews in the training set, each with about 100 “words” on average, and a vocabulary of 14624 “words”.

  • what are reasonable choices for the size of the state matrix? I know I can cross-validate, but given the time it takes, is there any heuristic?
  • more generally, am I doing something really bad or wrong in the code below?
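This does not say which size is best, but a rough weight count shows where the cost of `Chain(RNN(nV, 30), Dense(30, 1))` lies with a one-hot input. The sketch assumes the usual Elman layout for `RNN(nV, nH)` (input weights nH × nV, recurrent weights nH × nH, a bias of length nH) and ignores any trainable initial state; `count_weights` is a hypothetical helper.

```julia
# Rough weight count for Chain(RNN(nV, nH), Dense(nH, 1)), assuming an Elman cell
# with input weights, recurrent weights and a bias (initial state ignored).
function count_weights(nV::Int, nH::Int)
    rnn   = nH * nV + nH * nH + nH   # Wi, Wh, b
    dense = nH * 1 + 1               # W, b
    return (input = nH * nV, total = rnn + dense)
end

count_weights(14624, 30)   # → (input = 438720, total = 439681)
```

About 99.8% of the weights sit in the input matrix, so doubling the hidden size roughly doubles the total.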
```julia
using Pkg
using Random

using DelimitedFiles, CSV, HTTP, Pipe, DataFrames, Flux

# Load the data and shuffle it in case it isn't already
dataURL = "https://raw.githubusercontent.com/sylvaticus/SPMLJ/main/lessonsMaterial/04_NN/sentimentAnalysis/productReviews.csv"
data    = @pipe HTTP.get(dataURL).body |> CSV.File(_,delim='\t') |> DataFrame
data    = data[shuffle(1:size(data,1)),:] # Shuffle the data in case it isn't already
# data = data[1:500,:] # let's work first on a subsample...
data.sentiment = max.(0,data.sentiment) # Converting the sentiment label from {-1,1} to {0,1}

"""
Inputs a text string, returns a list of lowercase words in the string.
Punctuation and digits are separated out into their own words.
"""
function extractWords(inputString)
    punctuation = "!\"#\$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
    digits = "0123456789"
    for c in punctuation * digits
        inputString = replace(inputString, c => " " * c * " ")
    end
    return lowercase(inputString) |> split
end

# Extract a vocabulary of unique words
vocabulary = unique(reduce(vcat, extractWords.(data.text)));
nV         = length(vocabulary)

traindata,valdata = data[1:4000,:],data[4001:end,:]

# nRecords-long vector of nSeq-long vectors of nVocabulary-long one-hot (Bool) vectors
xtrain = [Flux.onehot.(extractWords(traindata.text[i]),Ref(vocabulary)) for i in 1:size(traindata,1)]
xval   = [Flux.onehot.(extractWords(valdata.text[i]),Ref(vocabulary)) for i in 1:size(valdata,1)]

ytrain = traindata.sentiment
yval   = valdata.sentiment

trainxy = zip(xtrain,ytrain)

m  = Chain(RNN(nV, 30),  Dense(30, 1))

function loss(x, y)
    nSeq = length(x)
    Flux.reset!(m) # Reset the state (not the weights!)
    [m(x[i]) for i in 1:nSeq-1]  # ignores the output but updates the hidden states
    # Assumed loss: binary cross-entropy on the logit, matching the sigmoid used in predictSentiment
    return Flux.logitbinarycrossentropy(m(x[end])[1], y)
end

ps  = params(m)
opt = ADAM()
epochs = 4

function predictSentiment(m,x)
    nSeq = length(x)
    Flux.reset!(m) # Reset the state (not the weights!)
    [m(x[i]) for i in 1:nSeq-1]  # ignores the output but updates the hidden states
    return Int64(round(sigmoid(m(x[end])[1])))
end

trainAccs = Float64[]
valAccs   = Float64[]
for e in 1:epochs
    print("Epoch $e ")
    Flux.train!(loss, ps, trainxy, opt)
    ŷtrain        = predictSentiment.(Ref(m),xtrain)
    ŷval          = predictSentiment.(Ref(m),xval)
    trainaccuracy =  sum(ŷtrain .== ytrain)/length(xtrain)
    valaccuracy   =  sum(ŷval   .== yval)/length(xval)
    push!(trainAccs, trainaccuracy) # keep the per-epoch history
    push!(valAccs, valaccuracy)
    println("accuracies: $trainaccuracy - $valaccuracy")
end
```
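On the question of resetting the state: in a batched recurrent layer the state is a matrix with one column per sequence, and every operation in the update acts column by column, so sequences in the same batch never interact. The plain-Julia sketch below checks this with a hand-written Elman update `h = tanh.(Wi*x .+ Wh*h .+ b)`, used here as a stand-in for Flux's `RNN` (it is not Flux code), with random weights and inputs, comparing a batched run against running each sequence alone from a zero state.

```julia
using Random

# Hand-written Elman RNN step (a stand-in for an RNN cell, not Flux itself).
rnn_step(Wi, Wh, b, h, x) = tanh.(Wi * x .+ Wh * h .+ b)

# Run a sequence of nFeatures × nBatch inputs starting from a zero state.
function run_rnn(Wi, Wh, b, xs)
    h = zeros(size(Wh, 1), size(xs[1], 2))   # "reset": one zero column per sequence
    for x in xs
        h = rnn_step(Wi, Wh, b, h, x)
    end
    return h
end

rng = MersenneTwister(0)
nF, nH, nSeq, nBatch = 4, 3, 5, 2
Wi, Wh, b = randn(rng, nH, nF), randn(rng, nH, nH), randn(rng, nH)
xs = [randn(rng, nF, nBatch) for _ in 1:nSeq]

hbatch = run_rnn(Wi, Wh, b, xs)                                           # whole batch at once
hcols  = [run_rnn(Wi, Wh, b, [x[:, j:j] for x in xs]) for j in 1:nBatch]  # one sequence at a time
hbatch ≈ hcat(hcols...)   # → true: the columns never mix
```

So one reset at the start of each batch gives every sequence in it the same fresh starting state that a per-sequence reset would.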