# Batched Matrix Multiply

I’d like to be able to broadcast matrix multiplication across multidimensional arrays, similar to the following:

``````
a = rand(4,3,2)
b = rand(3,4,2)

a .* b  # expect a (4,4,2) array, but instead errors
``````

I understand this would be ambiguous in the case of two 4×4×2 arrays as to what I wanted to do. Is there currently a way to help broadcasting along by specifying a dimension?

Now, I know I can do this with a for loop, iteration, etc. It looks like batched matrix multiplication has already been discussed by this community; as far as I can tell it was never implemented, but I might be missing something.

The real payoff here would be using this syntax with the array-interface GPU programming provided by the CuArrays.jl/CUDA.jl packages, where the parallelism can really be exploited. It looks like there is already a `gemm_batched` function wrapping the equivalent cuBLAS functionality, but I can’t access it with plain Julia LinearAlgebra calls yet.

You could use an array of arrays. For small matrices, using an array of `SMatrix` (from StaticArrays) should be especially efficient.
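
For example, something along these lines (a minimal sketch; the shapes just mirror the question and the variable names are mine):

``````
using StaticArrays

a = [@SMatrix(rand(4, 3)) for _ in 1:2]  # vector of two 4×3 static matrices
b = [@SMatrix(rand(3, 4)) for _ in 1:2]  # vector of two 3×4 static matrices

c = a .* b  # broadcasts `*` over the vectors: two 4×4 SMatrix products
``````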

There are so many features in `MKL` that could improve many real world use cases.
I wish two things would happen:

1. The `MKL` integration (be it `MKL.jl`) taking advantage of that.
2. Julia working with OpenBLAS to implement them as well.

You are probably looking for `NNlib.batched_mul`.

On the CPU this is a simple loop, because nobody has got around to hooking it up to the special MKL routines. On the GPU it calls the cuBLAS function.
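
With the shapes from the original post it looks like this (a quick sketch, not benchmarked here):

``````
using NNlib

a = rand(4, 3, 2)
b = rand(3, 4, 2)

c = batched_mul(a, b)  # multiplies a[:, :, k] * b[:, :, k] for each k
size(c)                # (4, 4, 2)
``````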

Here are a couple of alternatives. If you use a vector of matrices, that’s faster; otherwise you can get the same effect with `eachslice`, though at some performance cost. And of course it’s much faster with StaticArrays.

``````
using BenchmarkTools, Test

A_ = [rand(4,3) for _ in 1:2];
B_ = [rand(3,4) for _ in 1:2];
A = cat(A_...; dims=3)
B = cat(B_...; dims=3)

foo(X, Y) = X .* Y
bar(X, Y) = eachslice(X; dims=3) .* eachslice(Y; dims=3)
``````
``````
julia> @test foo(A_, B_) == bar(A, B)
Test Passed

julia> @btime foo($A_, $B_)
540.212 ns (3 allocations: 512 bytes)

julia> @btime bar($A, $B);
1.797 μs (17 allocations: 1.11 KiB)
``````

The idea behind batch multiplication isn’t the coding style of the loop.
The trick in `MKL` and other libraries implementing batched multiplication (very popular in DL-oriented libraries) is getting the computational efficiency of large matrix multiplication. It is done by restructuring the data into a more efficient layout.

I just tried to answer the question in the OP.

Didn’t mean to say you didn’t. I apologize if it came across as offensive in any way.
I meant it in the context of the other discussion he linked to, which mentions how batch mode is done correctly by rebuilding the data into a structure that maximizes the efficiency of the computation.

Has anyone hooked up this MKL batched_gemm stuff to Julia? If I’m reading correctly, the link from before discusses operations which act on an array of matrices, while what you describe sounds more like packed / compact gemm (for many small matrices, stored interleaved).

On the CPU, `batched_mul` is similar to the `eachslice` function above (except that it slices the output too and calls `mul!`). Now that we have https://github.com/JuliaLang/julia/pull/36360 it should be upgraded to multi-thread the outer loop.
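
Roughly, the per-slice loop looks something like this (a hand-written sketch, not the actual NNlib source; `naive_batched_mul` is just an illustrative name):

``````
using LinearAlgebra

function naive_batched_mul(A::AbstractArray{T,3}, B::AbstractArray{S,3}) where {T,S}
    C = similar(A, promote_type(T, S), size(A, 1), size(B, 2), size(A, 3))
    for k in axes(A, 3)
        # multiply each slice in place into the preallocated output
        @views mul!(C[:, :, k], A[:, :, k], B[:, :, k])
    end
    return C
end
``````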

For tiny matrices like the example above, just writing the loops is faster than calling BLAS in any form. Perhaps StaticArrays would be faster still.

``````
using NNlib, Einsum
ein(A,B) = @einsum C[i,j,b] := A[i,k,b] * B[k,j,b]
batched_mul(A, B) ≈ ein(A, B) ≈ cat(bar(A,B)...; dims=3)
@code_warntype bar(A, B) # Any
``````
``````
julia> @btime ein($A, $B);
133.104 ns (1 allocation: 336 bytes)

julia> @btime batched_mul($A, $B);
328.615 ns (1 allocation: 336 bytes)
``````

Thanks for the really fast and thorough answers everyone!

I’m actually using pretty large arrays, so I’d like to avoid copy operations if possible - especially on the GPU. My inputs are two large N-D arrays, and I’m having trouble converting them to an array of arrays on the GPU without a copy, whereas a `reshape` of an N-D array avoids the copy. There’s probably a trick I’m missing.

At first glance it seems `batched_mul` might be the best solution for my application, because it works on both CPU and GPU. Ideally, I’d like a function that could also do a broadcast batch multiply. Does anyone know if that exists? Something like

``````
a = rand(4,3)
b = rand(3,4,2)
batched_mul(a, b)
``````

On the GPU, `NNlib.batched_mul` calls `gemm_strided_batched!`, which wants contiguous arrays rather than an array of pointers, and is more efficient (IIRC).
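
In practice that means you can keep the data as one contiguous 3D `CuArray` and call the same function (a sketch, assuming the GPU support for NNlib is loaded in your package versions):

``````
using CUDA, NNlib

a = CUDA.rand(4, 3, 2)
b = CUDA.rand(3, 4, 2)

c = batched_mul(a, b)  # should hit the strided batched cuBLAS path, giving a (4,4,2) CuArray
``````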

This ought to understand broadcasting in the sense of using the same matrix `a` for every slice of 3D `b`, but does not right now. I had a PR to allow this (among other things) https://github.com/JuliaGPU/CuArrays.jl/pull/664, after which `batched_mul(reshape(a, size(a)...,1), b)` should work. But I didn’t finish it.

You can of course write `reshape(a * reshape(b, 3,:), 4,4,2)` to do this as one ordinary multiplication, which (again IIRC) is not as quick as the batched version for large square-ish arrays.
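
Spelled out with the shapes from the question, as a sanity check:

``````
a = rand(4, 3)
b = rand(3, 4, 2)

c = reshape(a * reshape(b, 3, :), 4, 4, 2)

c[:, :, 1] ≈ a * b[:, :, 1]  # true
c[:, :, 2] ≈ a * b[:, :, 2]  # true
``````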