I am trying to port the following Keras code to Lux.jl:

```
import tensorflow as tf

def generator(gru_units, dense_units, sequence_length, noise_dimension, model_dimension):
    # Inputs.
    inputs = tf.keras.layers.Input(shape=(sequence_length, model_dimension))
    # GRU block: every GRU except the last one returns the full sequence.
    outputs = tf.keras.layers.GRU(units=gru_units[0], return_sequences=len(gru_units) > 1)(inputs)
    for i in range(1, len(gru_units)):
        outputs = tf.keras.layers.GRU(units=gru_units[i], return_sequences=i < len(gru_units) - 1)(outputs)
    # Noise vector.
    noise = tf.keras.layers.Input(shape=noise_dimension)
    outputs = tf.keras.layers.Concatenate(axis=-1)([noise, outputs])
    # Dense layers.
    outputs = tf.keras.layers.Dense(units=dense_units)(outputs)
    outputs = tf.keras.layers.Dense(units=model_dimension)(outputs)
    return tf.keras.models.Model([inputs, noise], outputs)
```
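To make explicit what needs porting, here is a minimal plain-Python sketch (no TensorFlow) of the `return_sequences` pattern in the snippet above: every GRU except the last keeps the time axis, and the last one keeps only the final step. The helper names `return_sequences_flags` and `trace_shapes` are hypothetical, and the shapes follow the Keras layout `(batch, time, features)`, which differs from the `(features, time, batch)` layout used for `x` further down.

```python
def return_sequences_flags(gru_units):
    """One flag per GRU layer: True for every layer except the last."""
    last = len(gru_units) - 1
    return [i < last for i in range(len(gru_units))]


def trace_shapes(batch, sequence_length, model_dimension, gru_units):
    """Follow the output shape through the GRU stack (Keras layout)."""
    shape = (batch, sequence_length, model_dimension)
    shapes = [shape]
    for units, keep_sequence in zip(gru_units, return_sequences_flags(gru_units)):
        # A layer that keeps the sequence stays 3D; otherwise the time axis is dropped.
        shape = (batch, sequence_length, units) if keep_sequence else (batch, units)
        shapes.append(shape)
    return shapes


print(return_sequences_flags([64, 32, 16]))
print(trace_shapes(16, 71, 2, [64, 32, 16]))
```

Only the last layer drops the time axis, so each intermediate GRU still receives a sequence as input.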

What is the Lux analog of `tf.keras.layers.GRU`? (This will ultimately be used for a forecasting task.)

Is it `Lux.StatefulRecurrentCell(GRUCell())`, or `Lux.Recurrence(...; return_sequence = ...)`?

I am not sure I fully understand the documentation on this.

Moreover, I did not find any actual example of these layers in the doc (there is the LSTM tutorial, but it uses a custom layer).

I cannot get `StatefulRecurrentCell` to work. For example, with `(features, timestep, N_seq) = (2, 71, 16)` and `x = rand(features, timestep, N_seq)`:

```
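using Random, Lux
features, timestep, N_seq = 2, 71, 16
inputsize = features  # assumed: the cell input size is the feature dimension
x = rand(features, timestep, N_seq)
model = Chain(
    StatefulRecurrentCell(GRUCell(inputsize => 3)),
)
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model)
y, st = model(x, ps, st)  # a Lux model returns (output, updated state)
# ERROR: MethodError: no method matching reshape(::Float64, ::Colon, ::Int64)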
```

```
model = Chain(
    Recurrence(GRUCell(inputsize => 3); return_sequence=false),
)
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model)
y, st = model(x, ps, st)  # a Lux model returns (output, updated state)
# This works.
# But stacking two of them does not:
model = Chain(
    Recurrence(GRUCell(inputsize => 3); return_sequence=false),
    Recurrence(GRUCell(inputsize => 3); return_sequence=false),
)
# ERROR: `BatchLastIndex` not supported for AbstractMatrix. You probably want to use `TimeLastIndex`.
# It doesn't work because the first layer outputs a 2D matrix, not a 3D array.
```

Maybe this is somehow linked to this question.