I have implemented a BLSTM layer for a batch of sequences, where the input is stored as a 3D array with dimensions ordered as [features × time × batch]. I wonder whether there is a faster way to implement it — I'm interested in the speed of both the forward and the backward pass. Here is my implementation:
using Flux
using Flux: @functor, Recur, LSTMCell
using Zygote
using Zygote: Buffer
"""
    BLSTM{M,V}

Bidirectional LSTM layer: two independent `Flux` LSTM recurrences, one run
over the sequence left-to-right and one right-to-left, whose hidden states
are concatenated along the feature dimension (total output size `2 * outdim`).
"""
struct BLSTM{M <: DenseMatrix, V <: DenseVector}
    forward :: Recur{LSTMCell{M,V}}   # left-to-right (forward-time) recurrence
    backward :: Recur{LSTMCell{M,V}}  # right-to-left (reversed-time) recurrence
    outdim :: Int                     # hidden size of EACH direction; layer outputs 2*outdim features
end
# Restrict the trainable parameters to the two recurrences; `outdim` is a
# plain Int bookkeeping field and must not appear in `params(m)`.
Flux.trainable(m::BLSTM) = (m.forward, m.backward)
# Make BLSTM a functor so `gpu`/`cpu`/`params` recurse into its fields.
@functor BLSTM
"""
    BLSTM(in::Integer, out::Integer)

Construct a bidirectional LSTM layer mapping `in` input features to
`2 * out` output features (`out` per direction). The layer is moved to
the GPU if one is available (`gpu` is a no-op otherwise).
"""
function BLSTM(in::Integer, out::Integer)
    # Two independent LSTMs: one consumes the sequence left-to-right,
    # the other right-to-left.
    fwd = LSTM(in, out)
    bwd = LSTM(in, out)
    return gpu(BLSTM(fwd, bwd, Int(out)))
end
"""
    (m::BLSTM)(Xs::DenseArray{<:Real,3})

Run the bidirectional LSTM over a batch of sequences.

`Xs` is laid out as [features × time × batch]. Returns a
`2*m.outdim × time × batch` array whose first `m.outdim` feature rows come
from the forward recurrence and the remaining rows from the backward one.
"""
function (m::BLSTM)(Xs::DenseArray{<:Real,3})
    # Move time to the last axis so each step consumes a contiguous
    # [features × batch] matrix (Julia arrays are column-major).
    Xs = permutedims(Xs, (1, 3, 2)) # [features × time × batch] -> [features × batch × time]
    # Preallocate the output; Zygote.Buffer permits in-place writes while
    # remaining differentiable (`copy(::Buffer)` has an adjoint).
    # After the permute: size(Xs,3) == time, size(Xs,2) == batch.
    Ys = Buffer(Xs, 2m.outdim, size(Xs,3), size(Xs,2))
    axisYs₁ = axes(Ys, 1)
    time_f = axes(Ys, 2)     # forward time order: 1, 2, …, T
    time_b = reverse(time_f) # backward time order: T, …, 2, 1
    @inbounds begin
        # Feature-row slices: the first outdim rows receive forward hidden
        # states, the remaining outdim rows receive backward hidden states.
        slice_f = axisYs₁[1:m.outdim]
        slice_b = axisYs₁[(m.outdim+1):end]
        # Bidirectional run step.
        # NOTE(review): this relies on broadcasting evaluating the time
        # indices in order, so each Recur sees the steps sequentially —
        # true for Base broadcasting over ranges today, but not a
        # documented guarantee; confirm before upgrading Julia/Zygote.
        setindex!.((Ys,), m.forward.(view.((Xs,), :, :, time_f)), (slice_f,), time_f, :)
        setindex!.((Ys,), m.backward.(view.((Xs,), :, :, time_b)), (slice_b,), time_b, :)
        # the same as
        # @views for (t_f, t_b) ∈ zip(time_f, time_b)
        #     Ys[slice_f, t_f, :] = m.forward(Xs[:, :, t_f])
        #     Ys[slice_b, t_b, :] = m.backward(Xs[:, :, t_b])
        # end
        # but implemented via broadcasting as Zygote differentiates loops much slower than broadcasting
    end
    return copy(Ys) # [features × time × batch]
end
And here are the results I get:
using BenchmarkTools
# Problem size: feature dimension, sequence length, batch size.
D, T, B = 512, 800, 32
# Each direction gets D÷2 hidden units, so the layer outputs D features.
m = BLSTM(D, D÷2)
θ = params(m)
# Use D, T, B here rather than repeating the literals 512/800/32, so that
# changing the problem size above changes the benchmark input consistently.
Xs = rand(Float32, D, T, B) |> gpu
# Forward pass reduced to a scalar loss so a gradient can be taken.
f(m, Xs) = sum(m(Xs))
# Backward pass: gradient of the loss w.r.t. all model parameters.
g(m, Xs, θ) = gradient(() -> f(m, Xs), θ)
julia> @benchmark f($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 26.50 MiB
allocs estimate: 722527
--------------
minimum time: 607.991 ms (0.00% GC)
median time: 629.910 ms (1.41% GC)
mean time: 625.312 ms (1.13% GC)
maximum time: 631.088 ms (1.46% GC)
--------------
samples: 9
evals/sample: 1
julia> Flux.reset!(m); @benchmark g($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 28.98 MiB
allocs estimate: 797407
--------------
minimum time: 603.341 ms (0.00% GC)
median time: 622.730 ms (1.30% GC)
mean time: 621.413 ms (1.17% GC)
maximum time: 629.907 ms (1.38% GC)
--------------
samples: 9
evals/sample: 1
Any improvement suggestions?