I am trying to port the following TensorFlow/Keras code to Lux.jl:
```python
import tensorflow as tf

def generator(gru_units, dense_units, sequence_length, noise_dimension, model_dimension):
    # Inputs.
    inputs = tf.keras.layers.Input(shape=(sequence_length, model_dimension))
    # GRU block: every GRU except the last one returns the full sequence.
    outputs = tf.keras.layers.GRU(units=gru_units[0], return_sequences=False if len(gru_units) == 1 else True)(inputs)
    for i in range(1, len(gru_units)):
        outputs = tf.keras.layers.GRU(units=gru_units[i], return_sequences=True if i < len(gru_units) - 1 else False)(outputs)
    # Noise vector.
    noise = tf.keras.layers.Input(shape=(noise_dimension,))
    outputs = tf.keras.layers.Concatenate(axis=-1)([noise, outputs])
    # Dense layers.
    outputs = tf.keras.layers.Dense(units=dense_units)(outputs)
    outputs = tf.keras.layers.Dense(units=model_dimension)(outputs)
    return tf.keras.models.Model([inputs, noise], outputs)
```
What is the Lux.jl analog of `tf.keras.layers.GRU` (this will ultimately be a forecasting task)?
Is it `Lux.StatefulRecurrentCell(GRUCell(...))`, or `Lux.Recurrence(...; return_sequence = ...)`?
I am not sure I fully understand the documentation on this.
Moreover, I did not find any actual example of these layers in the documentation (there is the LSTM tutorial, but it uses a custom layer).
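As far as I understand the docstrings, the closest analog of a Keras `GRU` layer is a `GRUCell` wrapped in `Recurrence`: the cell computes a single time step and `Recurrence` loops it over the time dimension, with `return_sequence` playing the role of Keras' `return_sequences`. A minimal sketch, assuming Lux v1 and its default `BatchLastIndex()` ordering, in which the input is laid out as `(features, time, batch)` instead of Keras' `(batch, time, features)`; the sizes and the `permutedims` conversion are only illustrative.

```julia
using Lux, Random

rng = Xoshiro(0)

# Keras layout is (batch, time, features); Lux's default BatchLastIndex()
# ordering expects (features, time, batch), so permute the dimensions.
x_keras = rand(Float32, 16, 71, 2)      # (batch, time, features)
x = permutedims(x_keras, (3, 2, 1))     # (features, time, batch)

# Roughly tf.keras.layers.GRU(units=3, return_sequences=False):
# GRUCell handles one time step, Recurrence loops it over time.
gru_last = Recurrence(GRUCell(2 => 3); return_sequence=false)
ps, st = Lux.setup(rng, gru_last)
y_last, _ = gru_last(x, ps, st)         # output of the last time step: (3, 16)

# Roughly tf.keras.layers.GRU(units=3, return_sequences=True):
# the result is a collection of 71 matrices of size (3, 16), one per time step.
gru_seq = Recurrence(GRUCell(2 => 3); return_sequence=true)
ps, st = Lux.setup(rng, gru_seq)
y_seq, _ = gru_seq(x, ps, st)
```

Note that with `return_sequence=true` the output is a collection of matrices rather than a 3D array.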
I cannot get `StatefulRecurrentCell` to work.
For example, with `(features, timestep, N_seq) = (2, 71, 16)` and `x = rand(features, timestep, N_seq)`:
```julia
using Random, Lux

features, timestep, N_seq = 2, 71, 16
x = rand(features, timestep, N_seq)

model = Chain(
    StatefulRecurrentCell(GRUCell(features => 3)),
)
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model)
y, st = model(x, ps, st)
```

```
ERROR: MethodError: no method matching reshape(::Float64, ::Colon, ::Int64)
```
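My reading of the docstring is that `StatefulRecurrentCell` is a single-step wrapper: each call expects one time step as a `(features, batch)` matrix and keeps the hidden state in its returned state for the next call, so it cannot consume the whole `(features, time, batch)` array at once. A minimal sketch of driving it manually over the time dimension, assuming Lux v1; `run_steps` is a hypothetical helper name. The loop lives inside a function so that reassigning `st` is not affected by Julia's top-level soft-scope rules.

```julia
using Lux, Random

# Hypothetical helper: feed a StatefulRecurrentCell one time step at a time.
# The carry is stored in the returned state, so `st` is threaded through the loop.
function run_steps(cell, x, ps, st)
    y = nothing
    for t in axes(x, 2)
        xt = x[:, t, :]              # one time step: (features, batch)
        y, st = cell(xt, ps, st)
    end
    return y, st                     # output of the last time step
end

rng = Xoshiro(0)
x = rand(Float32, 2, 71, 16)         # (features, time, batch)
cell = StatefulRecurrentCell(GRUCell(2 => 3))
ps, st = Lux.setup(rng, cell)
y, st = run_steps(cell, x, ps, st)   # y has size (3, 16)
```

For a whole sequence in one call, `Recurrence` does this loop internally.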
With `Recurrence` instead, a single layer works:

```julia
model = Chain(
    Recurrence(GRUCell(features => 3); return_sequence=false),
)
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model)
y, st = model(x, ps, st)
# works
```
But stacking two `Recurrence` layers does not:

```julia
model = Chain(
    Recurrence(GRUCell(features => 3); return_sequence=false),
    Recurrence(GRUCell(3 => 3); return_sequence=false),
)
ps, st = Lux.setup(rng, model)
y, st = model(x, ps, st)
```

```
ERROR: `BatchLastIndex` not supported for AbstractMatrix. You probably want to use `TimeLastIndex`.
```

It doesn't work because the first layer outputs a 2D matrix, not a 3D array.
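The `Recurrence` docstring (Lux v1) suggests stacking by letting every layer except the last return the whole sequence: with `return_sequence=true` the first layer returns a collection of per-time-step matrices, which the next `Recurrence` accepts directly. The second cell's input size must also equal the first layer's output size (here `3 => 3`). A minimal sketch with the sizes from the example, assuming Lux v1:

```julia
using Lux, Random

rng = Xoshiro(0)
x = rand(Float32, 2, 71, 16)            # (features, time, batch)

model = Chain(
    # Intermediate layer: keep every time step so the next layer sees a sequence.
    Recurrence(GRUCell(2 => 3); return_sequence=true),
    # Last layer: input size must match the previous layer's output size.
    Recurrence(GRUCell(3 => 3); return_sequence=false),
)
ps, st = Lux.setup(rng, model)
y, st = model(x, ps, st)                 # last output: (3, 16)
```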
Maybe this is related to this question.
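For reference, a sketch of how the Keras `generator` at the top might map onto Lux, assuming Lux v1: stacked `Recurrence(GRUCell(...))` layers, a `Parallel` layer that receives the tuple `(x, noise)` and concatenates the noise with the GRU output along the feature dimension (the first dimension in Lux, hence `vcat`), then two `Dense` layers. Instead of two `Input`s, the model is called on a tuple, and `sequence_length` is unused because `Recurrence` accepts any number of time steps. Initializers and other defaults are not guaranteed to match Keras.

```julia
using Lux, Random

# Sketch of the Keras generator in Lux (assumes the Lux v1 API).
# Call as model((x, noise), ps, st) with x :: (model_dimension, time, batch)
# and noise :: (noise_dimension, batch).
function generator(gru_units, dense_units, sequence_length, noise_dimension, model_dimension)
    n = length(gru_units)
    in_dims = [model_dimension; gru_units[1:end-1]]
    # Every GRU except the last returns the full sequence, as in the Keras code.
    grus = [Recurrence(GRUCell(in_dims[i] => gru_units[i]); return_sequence=(i < n))
            for i in 1:n]
    return Chain(
        # Apply the GRU stack to x, pass noise through unchanged,
        # then concatenate [noise; gru_output] along the feature dimension.
        Parallel((h, z) -> vcat(z, h), Chain(grus...), NoOpLayer()),
        Dense(noise_dimension + gru_units[end] => dense_units),
        Dense(dense_units => model_dimension),
    )
end

rng = Xoshiro(0)
model = generator([8, 4], 16, 71, 5, 2)
ps, st = Lux.setup(rng, model)
x = rand(Float32, 2, 71, 16)        # (model_dimension, time, batch)
noise = randn(Float32, 5, 16)       # (noise_dimension, batch)
y, st = model((x, noise), ps, st)   # (model_dimension, batch) = (2, 16)
```

`Parallel` applies its i-th layer to the i-th element when the input is a tuple, which is what lets the GRU stack and the noise pass-through run side by side.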