Issue understanding Lux recurrent cells

I am trying to port the following Keras code to Lux.jl:

```python
import tensorflow as tf

def generator(gru_units, dense_units, sequence_length, noise_dimension, model_dimension):
    # Inputs.
    inputs = tf.keras.layers.Input(shape=(sequence_length, model_dimension))

    # GRU block: every GRU except the last returns the full sequence.
    outputs = tf.keras.layers.GRU(units=gru_units[0], return_sequences=len(gru_units) > 1)(inputs)
    for i in range(1, len(gru_units)):
        outputs = tf.keras.layers.GRU(units=gru_units[i], return_sequences=i < len(gru_units) - 1)(outputs)

    # Noise vector (Input expects a shape tuple).
    noise = tf.keras.layers.Input(shape=(noise_dimension,))
    outputs = tf.keras.layers.Concatenate(axis=-1)([noise, outputs])

    # Dense layers.
    outputs = tf.keras.layers.Dense(units=dense_units)(outputs)
    outputs = tf.keras.layers.Dense(units=model_dimension)(outputs)

    return tf.keras.models.Model([inputs, noise], outputs)
```

What is the Lux analog of `tf.keras.layers.GRU`? (This will ultimately be a forecasting task.)

  • `Lux.StatefulRecurrentCell(GRUCell(...))`?
  • `Lux.Recurrence(...; return_sequence = ...)`?

I am not sure I fully understand the documentation on this.

Moreover, I could not find any concrete example of these layers in the documentation (there is the LSTM tutorial, but it wraps the cell in a custom layer), and I cannot get `StatefulRecurrentCell` to work.

For example, with `(features, timestep, N_seq) = (2, 71, 16)`:

```julia
using Random, Lux

features, timestep, N_seq = 2, 71, 16
inputsize = features
x = rand(features, timestep, N_seq)

model = Chain(
    StatefulRecurrentCell(GRUCell(inputsize => 3)),
)
rng = Xoshiro(0)

ps, st = Lux.setup(rng, model)

y, st = model(x, ps, st)
# ERROR: MethodError: no method matching reshape(::Float64, ::Colon, ::Int64)
```
Using `Recurrence` instead works:

```julia
model = Chain(
    Recurrence(GRUCell(inputsize => 3); return_sequence=false),
)
rng = Xoshiro(0)

ps, st = Lux.setup(rng, model)

y, st = model(x, ps, st)
# works
```

But stacking two of them does not:

```julia
model = Chain(
    Recurrence(GRUCell(inputsize => 3); return_sequence=false),
    Recurrence(GRUCell(inputsize => 3); return_sequence=false),
)
ps, st = Lux.setup(rng, model)
y, st = model(x, ps, st)
# ERROR: `BatchLastIndex` not supported for AbstractMatrix. You probably want to use `TimeLastIndex`.
```

This fails because the first layer outputs a 2D matrix, not a 3D array.

This may be related to this question.

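You are explicitly asking the first layer not to return the sequence (`return_sequence=false`), so it outputs only the last hidden state, a `(hidden, batch)` matrix. The second `Recurrence` then receives a matrix with no time dimension to iterate over, which is what the `BatchLastIndex` error is complaining about.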
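To see the difference concretely, here is a minimal sketch (assuming a recent Lux.jl and `Float32` random data with the sizes from the question) that checks the output in both modes. With `return_sequence=true` the result is a collection of per-time-step matrices, which a following `Recurrence` accepts directly as its input sequence.

```julia
using Lux, Random

# Toy data with the sizes from the question: (features, timesteps, batch).
features, timesteps, batch = 2, 71, 16
x = rand(Float32, features, timesteps, batch)
rng = Xoshiro(0)

# return_sequence=false: only the hidden state after the last time step.
last_only = Recurrence(GRUCell(features => 3); return_sequence=false)
ps, st = Lux.setup(rng, last_only)
y_last, _ = last_only(x, ps, st)
println(size(y_last))          # (3, 16): a plain matrix, no time dimension left

# return_sequence=true: one (hidden, batch) matrix per time step.
full_seq = Recurrence(GRUCell(features => 3); return_sequence=true)
ps, st = Lux.setup(rng, full_seq)
y_seq, _ = full_seq(x, ps, st)
println(length(y_seq))         # 71 time steps
println(size(first(y_seq)))    # (3, 16)
```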

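What you want in this situation is:

```julia
model = Chain(
    # return_sequence=true: outputs one (3, batch) matrix per time step
    Recurrence(GRUCell(inputsize => 3); return_sequence=true),
    # the second cell's input size must match the first cell's output size (3)
    Recurrence(GRUCell(3 => 3); return_sequence=false),
)
```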
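Applied to the Keras generator from the question, the same pattern might look like the sketch below. It is untested and rests on assumptions: the function name and argument order are hypothetical, `sequence_length` is dropped because `Recurrence` does not need a fixed length, the model takes a tuple `(x, noise)` with `x` of shape `(model_dimension, sequence_length, batch)`, and Keras' `Concatenate(axis=-1)` becomes `vcat` because Lux keeps features in the first dimension.

```julia
using Lux, Random

# Hypothetical Lux port of the Keras `generator` (sequence_length is not needed).
function generator(gru_units, dense_units, noise_dimension, model_dimension)
    n = length(gru_units)
    # Input size of each GRU: the data features for the first, the previous GRU's units after that.
    in_dims = [model_dimension; gru_units[1:(end - 1)]]
    # Every GRU except the last returns the whole sequence (like Keras `return_sequences`).
    grus = [Recurrence(GRUCell(in_dims[i] => gru_units[i]); return_sequence=(i < n))
            for i in 1:n]
    return Chain(
        # Expects a tuple (x, noise): x goes through the GRU stack, noise passes through
        # unchanged, and the two are stacked along the feature (first) dimension, noise first.
        Parallel((h, z) -> vcat(z, h), Chain(grus...), NoOpLayer()),
        Dense(noise_dimension + gru_units[end] => dense_units),
        Dense(dense_units => model_dimension),
    )
end

rng = Xoshiro(0)
model = generator([8, 4], 16, 5, 2)
ps, st = Lux.setup(rng, model)
x = rand(Float32, 2, 71, 16)      # (model_dimension, sequence_length, batch)
noise = randn(Float32, 5, 16)     # (noise_dimension, batch)
y, st = model((x, noise), ps, st)
println(size(y))                  # (2, 16)
```

`Parallel` is used here because it applies its i-th layer to the i-th element of a tuple input, which mirrors the two-input Keras `Model([inputs, noise], ...)`.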

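`StatefulRecurrentCell` is a different kind of layer: it wraps a cell, processes a single time step per call (input of shape `(features, batch)`), and stores the carry in its state between calls. It is useful when you build a custom pipeline that steps through the sequence yourself and you don't want to manage the carry by hand.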
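If you do want `StatefulRecurrentCell`, a minimal sketch of stepping it through the sequence manually might look like this. It assumes data in `(features, timesteps, batch)` layout; the helper `last_hidden_state` is hypothetical. The loop lives inside a function so that reassigning `y` and `st` does not run into Julia's global soft-scope rules.

```julia
using Lux, Random

# Hypothetical helper: feed the sequence one time step at a time.
# StatefulRecurrentCell takes a (features, batch) matrix per call and
# carries the hidden state forward inside `st`.
function last_hidden_state(cell, x, ps, st)
    y = nothing
    for t in axes(x, 2)                       # x is (features, timesteps, batch)
        y, st = cell(x[:, t, :], ps, st)
    end
    return y, st
end

rng = Xoshiro(0)
cell = StatefulRecurrentCell(GRUCell(2 => 3))
ps, st = Lux.setup(rng, cell)
x = rand(Float32, 2, 71, 16)
y, st = last_hidden_state(cell, x, ps, st)
println(size(y))   # (3, 16)

# Reset the stored carry before feeding an unrelated sequence.
st = Lux.update_state(st, :carry, nothing)
```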