How to build a bidirectional RNN with Flux?


How would you build a bidirectional RNN with Flux? Drawing on a few examples:

  1. Bidirectional LSTM example A
  2. Bidirectional LSTM example B
  3. Knet bidirectional RNN source

I’ve written the following:

using Pkg; for p in ["Flux"] Pkg.add(p) end
using Flux

# Bidirectional RNN
struct BRNN{L,D}
  forward  :: L
  backward :: L
  dense    :: D
end

Flux.@functor BRNN

function BRNN(in::Integer, hidden::Integer, out::Integer, σ = relu)
  return BRNN(
    RNN(in, hidden, σ), # forward
    RNN(in, hidden, σ), # backward
    Dense(2hidden, out, σ)
  )
end

# Concatenate the forward state with the backward state (computed on the
# reversed input), then project through the dense layer.
function (m::BRNN)(xs)
  m.dense(vcat(m.forward(xs), m.backward(reverse(xs))))
end

inSize = 5
hiddenSize = 3
outSize = 1

trn = [(rand(inSize), rand(outSize)) for _ in 1:8]
@info "trn", summary(trn)

m = BRNN(inSize, hiddenSize, outSize)
loss(x, y) = Flux.mse(m(x), y)
ps = Flux.params(m)
opt = ADAM()

Flux.train!(loss, ps, trn, opt)

but training errors with

ERROR: LoadError: Mutating arrays is not supported

It seems Flux (through Zygote) cannot differentiate through reversing the input data.


Other examples (A, B) use

Flux.flip(m.backward, xs)

but that does not produce an output shape compatible with this example.
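
To make the shape mismatch concrete, here is a minimal sketch in plain Julia (no Flux needed). It assumes Flux.flip behaves roughly like "reverse the sequence, apply the layer to each time step, reverse the outputs back"; flip_sketch and toy_layer are hypothetical stand-ins I made up for illustration, not Flux's actual code.

```julia
# Hypothetical stand-in, assuming flip means:
# reverse the sequence, apply f to every element, reverse the results back.
flip_sketch(f, xs) = reverse(f.(reverse(xs)))

# Toy "layer" standing in for m.backward: maps its input to a 1-element vector.
toy_layer(x) = [sum(x)]

# Case 1: a sequence of 4 time steps, each a 5-element feature vector.
seq = [rand(5) for _ in 1:4]
out = flip_sketch(toy_layer, seq)      # one output per time step
@show length(out) size(out[1])

# Case 2: a single 5-element feature vector, like each x in trn above.
single = rand(5)
out2 = flip_sketch(toy_layer, single)  # f is broadcast over the 5 scalars instead
@show length(out2)
```

If that assumption holds, flip expects a vector of time steps and returns a vector of per-step outputs, so passing one feature vector (as trn does here) treats each scalar as its own time step, and the result cannot be passed to vcat and Dense the way the forward output can.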

How would you approach it?

See also the Flux issue "Zygote errors building bidirectional RNN".