# Is there a way to speed up this implementation of BLSTM layer?

I have implemented a BLSTM layer that takes a batch of sequences stored as a 3D array with dimensions ordered as `[features × time × batch]`. Is there a faster way to implement it? I'm interested in the speed of both the forward and the backward pass. Here is my implementation:

```julia
using Flux
using Flux: @functor, Recur, LSTMCell
using Zygote
using Zygote: Buffer

struct BLSTM{M <: DenseMatrix, V <: DenseVector}
    forward  :: Recur{LSTMCell{M,V}}
    backward :: Recur{LSTMCell{M,V}}
    outdim   :: Int
end

Flux.trainable(m::BLSTM) = (m.forward, m.backward)
@functor BLSTM

function BLSTM(in::Integer, out::Integer)
    forward  = LSTM(in, out)
    backward = LSTM(in, out)
    return BLSTM(forward, backward, Int(out)) |> gpu
end

function (m::BLSTM)(Xs::DenseArray{<:Real,3})
    Xs = permutedims(Xs, (1, 3, 2)) # [features × time × batch] -> [features × batch × time]
    # preallocate output buffer
    Ys = Buffer(Xs, 2m.outdim, size(Xs,3), size(Xs,2))
    axisYs₁ = axes(Ys, 1)
    time_f = axes(Ys, 2)
    time_b = reverse(time_f)
    @inbounds begin
        # get forward and backward slice indices
        slice_f = axisYs₁[1:m.outdim]
        slice_b = axisYs₁[(m.outdim+1):end]
        # bidirectional run step
        setindex!.((Ys,),  m.forward.(view.((Xs,), :, :, time_f)), (slice_f,), time_f, :)
        setindex!.((Ys,), m.backward.(view.((Xs,), :, :, time_b)), (slice_b,), time_b, :)
        # the same as
        # @views for (t_f, t_b) ∈ zip(time_f, time_b)
        #    Ys[slice_f, t_f, :] =  m.forward(Xs[:, :, t_f])
        #    Ys[slice_b, t_b, :] = m.backward(Xs[:, :, t_b])
        # end
        # but implemented via broadcasting as Zygote differentiates loops much slower than broadcasting
    end
    return copy(Ys) # [features × time × batch]
end
```
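The indexing logic of the commented-out loop can be checked in isolation, without Flux or Zygote. The following is only a minimal plain-Julia sketch under stated assumptions: `make_cumsum_cell` is a hypothetical stand-in for a stateful `Recur` cell (it just accumulates its inputs), and `bidirectional` is an illustrative name, not part of the implementation above. It writes into an ordinary `Array` rather than a `Buffer`, so it is not meant to be differentiated by Zygote as written.

```julia
# Hypothetical stand-in for a stateful recurrent cell: each call returns the
# running sum of all inputs seen so far (the state lives in a closure).
function make_cumsum_cell(outdim::Integer, batch::Integer)
    state = zeros(Float32, outdim, batch)
    function cell(x)
        state = state .+ x   # rebinds `state`, so earlier outputs are not mutated
        return state
    end
    return cell
end

# Illustrative bidirectional pass over a [features × time × batch] array.
# The forward cell walks t = 1..T, the backward cell walks t = T..1, and the
# two outputs are stacked along the first (feature) dimension.
function bidirectional(fwd, bwd, Xs::AbstractArray{<:Real,3}, outdim::Integer)
    _, T, B = size(Xs)
    Ys = zeros(Float32, 2outdim, T, B)
    slice_f = 1:outdim
    slice_b = (outdim + 1):(2outdim)
    for (t_f, t_b) in zip(1:T, T:-1:1)
        Ys[slice_f, t_f, :] = fwd(Xs[:, t_f, :])
        Ys[slice_b, t_b, :] = bwd(Xs[:, t_b, :])
    end
    return Ys   # [2outdim × time × batch]
end

D, T, B = 2, 4, 3
Xs = ones(Float32, D, T, B)
Ys = bidirectional(make_cumsum_cell(D, B), make_cumsum_cell(D, B), Xs, D)

size(Ys)       # (4, 4, 3)
Ys[1, :, 1]    # forward half:  [1, 2, 3, 4]
Ys[3, :, 1]    # backward half: [4, 3, 2, 1]
```

With an all-ones input, the forward half counts up along the time axis and the backward half counts down, which confirms that each direction's output lands at the time step of the input it was computed from.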

And here are the results I get:

```julia
using BenchmarkTools

D, T, B = 512, 800, 32
m = BLSTM(D, D÷2)
θ = params(m)
Xs = rand(Float32, D, T, B) |> gpu

f(m, Xs) = sum(m(Xs))
g(m, Xs, θ) = gradient(() -> f(m, Xs), θ)

julia> @benchmark f($m, $Xs)
BenchmarkTools.Trial:
  memory estimate:  26.50 MiB
  allocs estimate:  722527
  --------------
  minimum time:     607.991 ms (0.00% GC)
  median time:      629.910 ms (1.41% GC)
  mean time:        625.312 ms (1.13% GC)
  maximum time:     631.088 ms (1.46% GC)
  --------------
  samples:          9
  evals/sample:     1

julia> Flux.reset!(m); @benchmark g($m, $Xs, $θ)
BenchmarkTools.Trial:
  memory estimate:  28.98 MiB
  allocs estimate:  797407
  --------------
  minimum time:     603.341 ms (0.00% GC)
  median time:      622.730 ms (1.30% GC)
  mean time:        621.413 ms (1.17% GC)
  maximum time:     629.907 ms (1.38% GC)
  --------------
  samples:          9
  evals/sample:     1
```

Any improvement suggestions?


I discovered that the speed-up of broadcasting over the loop, noted in the code comment above, is due to a bug in Zygote that drops the gradients. Filed here.
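One way to catch dropped gradients like this is to compare whatever the AD system returns against a finite-difference reference on a small input. The sketch below is a plain-Julia illustration under stated assumptions: `finite_diff_grad`, `gradient_looks_dropped`, and `toy_loss` are hypothetical helper names, and the "AD gradient" is passed in by hand here rather than computed with Zygote. In a real check you would pass the array (or `nothing`) that Zygote's `gradient` returns for the parameter in question.

```julia
# Central finite-difference gradient of a scalar-valued function f at x.
# Intended only for small inputs: it calls f twice per element.
function finite_diff_grad(f, x::AbstractArray{<:Real}; ϵ = 1e-3)
    xp = Float64.(x)                 # work in Float64 to limit round-off
    g = zeros(Float64, size(xp))
    for i in eachindex(xp)
        old = xp[i]
        xp[i] = old + ϵ
        fp = f(xp)
        xp[i] = old - ϵ
        fm = f(xp)
        xp[i] = old                  # restore the perturbed element
        g[i] = (fp - fm) / (2 * ϵ)
    end
    return g
end

# Returns true if the candidate gradient is missing, is all zeros while the
# reference is not, or disagrees with the reference beyond `atol`.
function gradient_looks_dropped(g_ad, g_ref; atol = 1e-3)
    g_ad === nothing && return true
    all(iszero, g_ad) && !all(iszero, g_ref) && return true
    return !all(isapprox.(g_ad, g_ref; atol = atol))
end

toy_loss(x) = sum(abs2, x)           # analytic gradient is 2x
x = Float32[1, 2, 3]
g_ref = finite_diff_grad(toy_loss, x)

gradient_looks_dropped(2 .* x, g_ref)    # false: matches the analytic gradient
gradient_looks_dropped(zero(x), g_ref)   # true: all-zero gradient
gradient_looks_dropped(nothing, g_ref)   # true: gradient missing entirely
```

A backward pass that takes about as long as the forward pass, as in the benchmark above, is a reasonable trigger for running a check like this before trusting the timing.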