Hello,
I am trying to reproduce some basic code for training recurrent spiking networks. Here you can find an old port of spytorch to Flux.
I have already adapted it to the current Flux version, but I would really like to move to Lux, as I believe its support for Neural ODEs could come in very handy in later stages of the project.
The idea is that a hidden spiking layer receives spike inputs and emits spikes to an output layer. The input is therefore an InputNeurons × Time × Batchsize array, and the output should be an OutputNeurons × Time × Batchsize array.
As a first step, I tried doing this with an RNNCell, and the following works:
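To make that layout concrete, here is a small sketch of such an input tensor. It assumes rate-coded (Bernoulli) spike trains stored as Float32 zeros and ones; the sizes and the firing probability `p_spike` are hypothetical values chosen only for illustration.

```julia
using Random

# Hypothetical sizes for illustration: InputNeurons, Time, Batchsize
n_in, n_time, batch = 10, 50, 128
p_spike = 0.1f0  # assumed per-neuron, per-step firing probability

rng = Random.MersenneTwister(0)
# Bernoulli spike trains as Float32 0/1 values, laid out as
# InputNeurons × Time × Batchsize
spikes = Float32.(rand(rng, Float32, n_in, n_time, batch) .< p_spike)

size(spikes)  # (10, 50, 128)
```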
```julia
using Lux, Random, Optimisers

function RNNClassifierCompact(in_dims, hidden_dims, out_dims)
    return @compact(;
        input=Dense(in_dims => hidden_dims, sigmoid),
        # Recurrence runs the cell over the time (second) dimension and, with
        # return_sequence=true, returns one hidden × batch output per time step
        rnn_cell=Recurrence(RNNCell(hidden_dims => hidden_dims); return_sequence=true),
        classifier=Dense(hidden_dims => out_dims, sigmoid)
    ) do x::AbstractArray{T,3} where {T}
        # apply the readout to each time step's hidden × batch output
        out = map(rnn_cell(input(x))) do h
            classifier(h)
        end
        # stacking along dims=2 gives out_dims × time × batch directly
        @return stack(out; dims=2)
    end
end

layers = (10, 2, 5)  # (in_dims, hidden_dims, out_dims)
model = RNNClassifierCompact(layers...)
ps, st = Lux.setup(Random.default_rng(), model)
x = rand(Float32, 10, 50, 128)  # Float32 to match the parameter eltype
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
st_ = Lux.testmode(train_state.states)
ŷ, st_ = model(x, train_state.parameters, st_)
size(ŷ)  # (5, 50, 128)
```
However, the way I handle the output of the RNN cell seems wrong: I map the classifier over the per-time-step outputs and then reassemble them into a 3-D array by hand. In Flux, wrapping all the layers with Recurrence was sufficient. Any hints on how to do this in Lux in the cleanest and most performant way?
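One possible Lux counterpart to the Flux approach is a plain `Chain`, sketched here under the assumption of Lux v1 and Julia ≥ 1.9. It relies on `Dense` accepting a 3-D features × time × batch array and on `Recurrence(...; return_sequence=true)` returning one hidden × batch array per time step; the `WrappedFunction` that stacks those outputs back into a 3-D array is my own choice, not something taken from the Flux port, and `RNNClassifierChain` is a hypothetical name.

```julia
using Lux, Random

# Hypothetical Chain-based variant of RNNClassifierCompact
function RNNClassifierChain(in_dims, hidden_dims, out_dims)
    return Chain(
        # acts on the feature dimension of in_dims × time × batch
        Dense(in_dims => hidden_dims, sigmoid),
        # one hidden × batch output per time step
        Recurrence(RNNCell(hidden_dims => hidden_dims); return_sequence=true),
        # reassemble the per-step outputs into hidden × time × batch
        WrappedFunction(hs -> stack(hs; dims=2)),
        # readout applied to every time step at once
        Dense(hidden_dims => out_dims, sigmoid),
    )
end

model = RNNClassifierChain(10, 2, 5)
ps, st = Lux.setup(Random.default_rng(), model)
x = rand(Float32, 10, 50, 128)
ŷ, _ = model(x, ps, Lux.testmode(st))
size(ŷ)  # (5, 50, 128)
```

Applying the readout `Dense` once to the stacked array avoids both the per-step `map` and the splatted `cat`, which can be slow for long sequences.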
Thanks!