Passing whole sequence to next layer from RNNCell in Lux

Hello,

I am trying to reproduce some basic code for training recurrent spiking networks. Here you can find an old Flux port of spytorch.

I have already adapted it to the current Flux version, but I would really like to move to Lux, as I believe its support for Neural ODEs could come in very handy in later stages of the project.

The idea is that there is a hidden spiking layer that receives spike inputs and yields spikes to an output layer. As such, the input is an InputNeurons × Time × Batchsize array, and the output should be an OutputNeurons × Time × Batchsize array.

On my way there, I tried to do this using an RNNCell, and I can make it work as below:

```julia
using Lux, Random, Optimisers

function RNNClassifierCompact(in_dims, hidden_dims, out_dims)
    return @compact(;
        input=Dense(in_dims => hidden_dims, sigmoid),
        rnn_cell=Recurrence(RNNCell(hidden_dims => hidden_dims); return_sequence=true),
        classifier=Dense(hidden_dims => out_dims, sigmoid)
    ) do x::AbstractArray{T,3} where {T}
        # return_sequence=true yields one hidden state per timestep
        out = map(rnn_cell(input(x))) do h
            classifier(h)
        end
        @return cat(out..., dims=3) |> y -> permutedims(y, (1, 3, 2))
    end
end
```
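To spell out what that last line is doing (a plain-Julia sketch, no Lux; the shapes are assumed from the code above): with `return_sequence=true` the recurrence hands back a length-`T` vector of matrices, so after the classifier each element is `out_dims × batch`, and `cat(..., dims=3)` stacks them along a trailing axis that then has to be permuted into the middle:

```julia
# Dummy stand-in for the per-timestep classifier outputs
T, out_dims, batch = 3, 5, 4
outs = [fill(Float32(t), out_dims, batch) for t in 1:T]  # one matrix per timestep

stacked = cat(outs..., dims=3)       # (out_dims, batch, T)
y = permutedims(stacked, (1, 3, 2))  # (out_dims, T, batch)

size(y)          # (5, 3, 4)
y[1, 2, 1] == 2  # the timestep axis now sits in dimension 2
```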

```julia
layers = (10, 2, 5)
model = RNNClassifierCompact(layers...)
ps, st = Lux.setup(Random.default_rng(), model)

x = rand(Float32, 10, 50, 128)  # Float32 to match the parameter eltype
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
st_ = Lux.testmode(train_state.states)
ŷ, st_ = model(x, train_state.parameters, st_)

size(ŷ) # (5, 50, 128)
```

However, the way I handle the output of the RNN cell (collecting per-timestep outputs, splatting them into `cat`, and permuting the result) feels wrong. In Flux I was wrapping all the layers with Recurrence and that was sufficient. Any hint on how to do this in Lux in the cleanest and most performant way?
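For what it's worth, the splat itself seems avoidable with `Base.stack` (Julia ≥ 1.9). A minimal plain-Julia check (no Lux; assuming, as above, that the sequence comes back as a vector of `out_dims × batch` matrices) that it reproduces the `cat`/`permutedims` result in one step:

```julia
# Dummy vector-of-matrices output, one matrix per timestep
T, out_dims, batch = 50, 5, 128
out = [rand(Float32, out_dims, batch) for _ in 1:T]

# What I do now: splat into cat, then permute
y1 = permutedims(cat(out..., dims=3), (1, 3, 2))

# Candidate replacement: place the timestep axis at dims=2 directly
y2 = stack(out; dims=2)

size(y2) == (out_dims, T, batch)  # true
y1 == y2                          # true
```

Whether this is the idiomatic fix, or whether Lux has a built-in way to map the classifier over the whole sequence, is exactly what I am unsure about.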

Thanks!