Batched Matrix Multiply

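Here are a couple of alternatives. If you store the batch as a vector of matrices, broadcasting `*` over it is the faster option. If your data is already a 3-D array, you can get the same result with `eachslice`, at some performance cost (roughly 3× slower, with more allocations, in the benchmark below). And of course, for small fixed-size matrices like these it's much faster still with StaticArrays.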

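```julia
using BenchmarkTools, Test

# The batch as a vector of matrices...
A_ = [rand(4,3) for _ in 1:2];
B_ = [rand(3,4) for _ in 1:2];

# ...and the same data stacked into 4×3×2 and 3×4×2 arrays
A = cat(A_...; dims=3)
B = cat(B_...; dims=3)

# Vector of matrices: broadcasting `*` computes A_[i] * B_[i] for each i
foo(X, Y) = X .* Y

# 3-D array: slice along the batch dimension, then broadcast `*` over the slices
bar(X, Y) = eachslice(X; dims=3) .* eachslice(Y; dims=3)
```

```julia
julia> @test foo(A_, B_) == bar(A, B)
Test Passed

julia> @btime foo($A_, $B_);
  540.212 ns (3 allocations: 512 bytes)

julia> @btime bar($A, $B);
  1.797 μs (17 allocations: 1.11 KiB)
```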
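For the StaticArrays route, here is a minimal sketch. It assumes the matrix sizes are known at compile time, as they are here, and simply converts each matrix to an `SMatrix` while reusing `foo` unchanged. No timings are shown because they depend on the machine, so run the `@btime` line yourself. The comparison uses `≈` rather than `==` because the StaticArrays product may round slightly differently from BLAS.

```julia
using StaticArrays, BenchmarkTools, Test

A_ = [rand(4, 3) for _ in 1:2]
B_ = [rand(3, 4) for _ in 1:2]

# Fixed-size copies: the dimensions become part of the type, so each small
# product can be fully unrolled and needs no heap allocation per matrix
As = [SMatrix{4,3}(a) for a in A_]
Bs = [SMatrix{3,4}(b) for b in B_]

# Same function as before: broadcasting `*` over the pairs of matrices
foo(X, Y) = X .* Y

C_ = foo(A_, B_)   # vector of 4×4 Matrix{Float64}
Cs = foo(As, Bs)   # vector of 4×4 SMatrix

@test all(Cs .≈ C_)

@btime foo($As, $Bs);
```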