What does the architecture of Flux.RNN look like?

Hi all,
I am confused about the architecture of the RNN in Flux.jl. If I define an RNN with RNN(3, 3), will it look like the left one or the right one in the picture below?

[Image: two candidate RNN architectures, left and right]

Based on the documentation, I think it is more like the right one, but I am not sure:

Flux.RNN — Function

RNN(in::Integer, out::Integer, σ = tanh)

The most basic recurrent layer; essentially acts as a Dense layer, but with the output fed back into the input each time step.
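If I read that correctly, a single call would do something like the hand-rolled recurrence below. This is just my own sketch of what the docstring seems to describe; Wi, Wh, b, and h are names I made up for illustration, not Flux internals:

Wi = randn(Float32, 3, 3)  # input-to-hidden weights
Wh = randn(Float32, 3, 3)  # hidden-to-hidden weights
b  = zeros(Float32, 3)     # bias
h  = zeros(Float32, 3)     # hidden state, fed back at every step

function step(x)
    global h = tanh.(Wi * x .+ Wh * h .+ b)  # the new state is also the output
    return h
end

y = step(rand(Float32, 3))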

I actually want to define a multi-input, multi-output RNN like the right one in the figure. If Flux.RNN() cannot achieve this, does anyone have ideas on how to write it manually?

Thanks a lot for your help!


Has there been any insight into this question?

You can use Flux.RNN either way; the only difference is how you manage your inputs.

Left:

using Flux

x1 = rand(Float32, 3)
x2 = rand(Float32, 3)
x3 = rand(Float32, 3)

m = Flux.RNN(3, 3)

y1, y2, y3 = m.([x1, x2, x3])  # Apply the RNN to each input in turn; the hidden state carries over between calls.
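One caveat: because the hidden state persists across calls, you should clear it before feeding a new, independent sequence. Assuming the stateful Recur-wrapped layer that RNN(3, 3) returns here, that is:

Flux.reset!(m)  # zero the hidden state before starting the next sequence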

Right:

x1 = rand(Float32, 3)
x2 = rand(Float32, 3)
x3 = rand(Float32, 3)

x = vcat(x1, x2, x3)  # Stack the three inputs into one length-9 vector.

m = Flux.RNN(9, 3)

y1, y2, y3 = m.([x, x, x])  # Feed the concatenated input at three consecutive steps; the outputs differ only through the evolving hidden state.

You can also use a Dense layer to learn a weighted combination of x1, x2, and x3, as sketched below.
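Something like this, assuming you put the Dense in front of the RNN (the layer sizes are just illustrative):

using Flux

combine = Dense(9, 3)                      # learns a weighted combination of the three stacked inputs
model   = Chain(combine, Flux.RNN(3, 3))   # mix first, then recur

x = vcat(rand(Float32, 3), rand(Float32, 3), rand(Float32, 3))
y = model(x)                               # one time step on the learned combination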
