Is there a way to speed up this implementation of a BLSTM layer?

I have implemented a BLSTM (bidirectional LSTM) layer for a batch of input sequences stored as a 3D array with dimensions ordered as [features × time × batch], and I wonder whether there is a faster way to implement it. I'm interested in the speed of both the forward and the backward pass. Here is my implementation:

```julia
using Flux
using Flux: @functor, Recur, LSTMCell
using Zygote
using Zygote: Buffer

struct BLSTM{M <: DenseMatrix, V <: DenseVector}
   forward  :: Recur{LSTMCell{M,V}}
   backward :: Recur{LSTMCell{M,V}}
   outdim   :: Int
end

Flux.trainable(m::BLSTM) = (m.forward, m.backward)
@functor BLSTM

function BLSTM(in::Integer, out::Integer)
   forward  = LSTM(in, out)
   backward = LSTM(in, out)
   return BLSTM(forward, backward, Int(out)) |> gpu
end

function (m::BLSTM)(Xs::DenseArray{<:Real,3})
   Xs = permutedims(Xs, (1, 3, 2)) # [features × time × batch] -> [features × batch × time]
   # preallocate output buffer: [2outdim × time × batch]
   Ys = Buffer(Xs, 2m.outdim, size(Xs,3), size(Xs,2))
   axisYs₁ = axes(Ys, 1)
   time_f = axes(Ys, 2)
   time_b = reverse(time_f)
   @inbounds begin
      # get forward and backward slice indices
      slice_f = axisYs₁[1:m.outdim]
      slice_b = axisYs₁[(m.outdim+1):end]
      # bidirectional run step
      setindex!.((Ys,),  m.forward.(view.((Xs,), :, :, time_f)), (slice_f,), time_f, :)
      setindex!.((Ys,), m.backward.(view.((Xs,), :, :, time_b)), (slice_b,), time_b, :)
      # equivalent to
      # @views for (t_f, t_b) ∈ zip(time_f, time_b)
      #    Ys[slice_f, t_f, :] =  m.forward(Xs[:, :, t_f])
      #    Ys[slice_b, t_b, :] = m.backward(Xs[:, :, t_b])
      # end
      # but implemented via broadcasting, because Zygote differentiates loops much more slowly than broadcasting
   end
   return copy(Ys) # [features × time × batch]
end
```
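To sanity-check the claim in the code comment that the broadcasted `setindex!` calls do the same thing as the loop, here is a minimal plain-Julia sketch (no Flux, Zygote or GPU needed). It replaces the two `Recur` layers with a hypothetical stateful toy cell, `RunningSum`, that returns the running sum of its inputs, so the forward rows of the output should hold prefix sums over time and the backward rows suffix sums, both stored at the original time index. Like the layer above, it assumes that broadcasting calls the stateful function in index order; that is how Base's broadcast behaves in practice, but I am not aware of it being a documented guarantee.

```julia
# Hypothetical stand-in for Flux's `Recur`: a stateful callable whose output
# is the running sum of every input it has seen (illustration only).
mutable struct RunningSum
    state::Matrix{Float64}
end
RunningSum(nfeat::Integer, nbatch::Integer) = RunningSum(zeros(nfeat, nbatch))

function (r::RunningSum)(x::AbstractMatrix)
    r.state = r.state .+ x  # fresh array each step, like a hidden-state update
    return r.state
end

F, T, B = 2, 5, 3
X  = reshape(collect(1.0:F*T*B), F, T, B)   # [features × time × batch]
Xp = permutedims(X, (1, 3, 2))              # [features × batch × time]

time_f  = 1:T
time_b  = reverse(time_f)
slice_f = 1:F           # output rows for the forward direction
slice_b = (F+1):2F      # output rows for the backward direction

# Broadcast version, written the same way as in the BLSTM layer
fwd, bwd = RunningSum(F, B), RunningSum(F, B)
Y = zeros(2F, T, B)
setindex!.((Y,), fwd.(view.((Xp,), :, :, time_f)), (slice_f,), time_f, :)
setindex!.((Y,), bwd.(view.((Xp,), :, :, time_b)), (slice_b,), time_b, :)

# Loop version from the code comment, with fresh cells
fwd2, bwd2 = RunningSum(F, B), RunningSum(F, B)
Yloop = zeros(2F, T, B)
@views for (t_f, t_b) in zip(time_f, time_b)
    Yloop[slice_f, t_f, :] = fwd2(Xp[:, :, t_f])
    Yloop[slice_b, t_b, :] = bwd2(Xp[:, :, t_b])
end

println(Y == Yloop)  # true
```

Wrapping `Y` in a one-element tuple `(Y,)` is what makes broadcasting treat it as a scalar, so every `setindex!` call writes into the same array rather than into its elements.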

And here are the results I get:

```julia
using BenchmarkTools

D, T, B = 512, 800, 32
m = BLSTM(D, D÷2)
θ = params(m)
Xs = rand(Float32, 512, 800, 32) |> gpu

f(m, Xs) = sum(m(Xs))
g(m, Xs, θ) = gradient(() -> f(m, Xs), θ)
```

```
julia> @benchmark f($m, $Xs)
  memory estimate:  26.50 MiB
  allocs estimate:  722527
  minimum time:     607.991 ms (0.00% GC)
  median time:      629.910 ms (1.41% GC)
  mean time:        625.312 ms (1.13% GC)
  maximum time:     631.088 ms (1.46% GC)
  samples:          9
  evals/sample:     1

julia> Flux.reset!(m); @benchmark g($m, $Xs, $θ)
  memory estimate:  28.98 MiB
  allocs estimate:  797407
  minimum time:     603.341 ms (0.00% GC)
  median time:      622.730 ms (1.30% GC)
  mean time:        621.413 ms (1.17% GC)
  maximum time:     629.907 ms (1.38% GC)
  samples:          9
  evals/sample:     1
```

Any suggestions for improvement?

I discovered that the speed-up above is due to a bug in Zygote, which drops the gradients. Filed here.