I’d like to be able to broadcast matrix multiplication across multidimensional arrays, similar to the following:
a = rand(4,3,2)
b = rand(3,4,2)
a .* b # want a 4×4×2 array of slice-wise matrix products, but this throws a DimensionMismatch error
I understand this would be ambiguous for two 4×4×2 arrays, since .* already means elementwise multiplication there. Is there currently a way to broadcast over a specified dimension?
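For reference, the loop version looks roughly like this. It's a minimal sketch, assuming the batch dimension is the last (third) one; batched_mul_loop! and batched_mul_loop are hypothetical helper names, not functions from any package. Each slice is passed to mul! as a view, so no slice is copied.

```julia
using LinearAlgebra

# Hypothetical helper (not from any package): for each k along dimension 3,
# compute C[:, :, k] = A[:, :, k] * B[:, :, k] in place.
function batched_mul_loop!(C::AbstractArray{<:Any,3},
                           A::AbstractArray{<:Any,3},
                           B::AbstractArray{<:Any,3})
    size(A, 3) == size(B, 3) == size(C, 3) ||
        throw(DimensionMismatch("batch sizes differ"))
    for k in axes(A, 3)
        # views avoid copying the slices; mul! writes straight into C
        mul!(view(C, :, :, k), view(A, :, :, k), view(B, :, :, k))
    end
    return C
end

# Allocating wrapper: the output is size(A, 1) × size(B, 2) × batch size
batched_mul_loop(A, B) = batched_mul_loop!(
    similar(A, promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2), size(A, 3)),
    A, B)

a = rand(4, 3, 2)
b = rand(3, 4, 2)
c = batched_mul_loop(a, b)   # 4×4×2 array
```

To batch over a different dimension you could permutedims the inputs first, at the cost of a copy.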
I know I can do this with a for loop or other iteration. Batched matrix multiplication seems to have been discussed here before, but as far as I can tell it was never implemented. I might be missing something, though.
The real payoff here is being able to use this syntax with the array-based GPU programming interface provided by CuArrays.jl/CUDA.jl, where the parallelism can really be exploited. There already seems to be a gemm_batched function wrapping the equivalent cuBLAS functionality, but I can’t yet reach it through ordinary Julia LinearAlgebra calls.
Here are a couple of alternatives. If you use a vector of matrices, that’s faster, otherwise you can get the same effect with eachslice, though at some performance cost. And of course, it’s much faster with StaticArrays.
using BenchmarkTools, Test
A_ = [rand(4,3) for _ in 1:2];
B_ = [rand(3,4) for _ in 1:2];
A = cat(A_...; dims=3)
B = cat(B_...; dims=3)
foo(X, Y) = X .* Y
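bar(X, Y) = eachslice(X; dims=3) .* eachslice(Y; dims=3)
@test cat(foo(A_, B_)...; dims=3) ≈ cat(bar(A, B)...; dims=3)
@btime foo($A_, $B_);
@btime bar($A, $B);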
The point of batched multiplication isn’t how the loop is written.
The trick in MKL and other libraries that implement batched multiplication (very popular in deep-learning libraries) is to get the computational efficiency of one large matrix multiplication. This is done by restructuring the data into a new layout.
Didn’t mean to say you didn’t. I apologize if it was offensive in any way.
I meant it in the context of the other discussion he linked to, which mentions that batched mode is done properly by rebuilding the data into a structure that maximizes the efficiency of the computation.
Has anyone hooked up this MKL batched_gemm stuff to Julia? If I’m reading correctly, the link from before discusses operations which act on an array of matrices, while what you describe sounds more like packed / compact gemm (for many small matrices, stored interleaved).
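To make the "interleaved" idea concrete, here is a minimal sketch of it in plain Julia. It is not MKL's compact format or API, just an assumed layout where the batch index comes first, so entry (i, j) of every matrix in the batch sits contiguously in memory. The innermost loop then runs across matrices, which is what lets many tiny products vectorize. interleaved_mul! is a hypothetical name.

```julia
# Assumed layout (not MKL's exact format): A is nb×m×p, B is nb×p×n,
# C is nb×m×n, i.e. the batch index is the fastest-varying one.
function interleaved_mul!(C::Array{T,3}, A::Array{T,3}, B::Array{T,3}) where {T}
    nb, m, p = size(A)
    n = size(B, 3)
    size(B) == (nb, p, n) && size(C) == (nb, m, n) ||
        throw(DimensionMismatch("incompatible sizes"))
    fill!(C, zero(T))
    @inbounds for j in 1:n, k in 1:p, i in 1:m
        # innermost loop walks the batch, so it can vectorize across matrices
        @simd for b in 1:nb
            C[b, i, j] += A[b, i, k] * B[b, k, j]
        end
    end
    return C
end

A = rand(4, 3, 2)
B = rand(3, 4, 2)
# Restructuring step: move the batch dimension to the front (this copies)
Ai = permutedims(A, (3, 1, 2))   # 2×4×3
Bi = permutedims(B, (3, 1, 2))   # 2×3×4
Ci = interleaved_mul!(zeros(2, 4, 4), Ai, Bi)
C = permutedims(Ci, (2, 3, 1))   # back to 4×4×2
```

The permutedims calls are the "restructuring" cost; the layout only pays off if the data can stay interleaved across many operations.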
On the CPU, NNlib’s batched_mul is similar to the eachslice function above, except that it also slices the output and calls mul! on each slice. Now that we have https://github.com/JuliaLang/julia/pull/36360 it should be upgraded to multi-thread the outer loop.
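Threading the outer loop might look like the sketch below. This is an assumption about the approach, not NNlib’s code; threaded_batched_mul! is a hypothetical name. Each thread writes to its own slice of C, so there is no data race. Start Julia with several threads (e.g. julia -t 4); for many small matrices it may also help to call BLAS.set_num_threads(1) so BLAS threads don’t compete with Julia threads, though that trade-off should be benchmarked.

```julia
using LinearAlgebra

# Hypothetical helper: one mul! per slice, with the batch loop spread over
# Julia threads. Slices of C are disjoint, so the writes don't race.
function threaded_batched_mul!(C::AbstractArray{<:Any,3},
                               A::AbstractArray{<:Any,3},
                               B::AbstractArray{<:Any,3})
    Threads.@threads for k in axes(A, 3)
        mul!(view(C, :, :, k), view(A, :, :, k), view(B, :, :, k))
    end
    return C
end

A = rand(4, 3, 8)
B = rand(3, 4, 8)
C = threaded_batched_mul!(zeros(4, 4, 8), A, B)
```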
For tiny matrices like the example above, just writing the loops is faster than calling BLAS in any form. Perhaps StaticArrays would be faster still.
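"Just writing the loops" could look like this minimal sketch (loop_batched_mul is a hypothetical name). It avoids any per-slice BLAS call, which is where the overhead for 4×3 by 3×4 products comes from; whether it actually beats BLAS for a given size is something to check with @btime.

```julia
# Hypothetical helper: plain nested loops over a standard m×p×nb layout.
function loop_batched_mul(A::Array{TA,3}, B::Array{TB,3}) where {TA,TB}
    m, p, nb = size(A)
    p2, n, nb2 = size(B)
    (p == p2 && nb == nb2) || throw(DimensionMismatch("incompatible sizes"))
    C = zeros(promote_type(TA, TB), m, n, nb)
    @inbounds for b in 1:nb, j in 1:n, k in 1:p
        bkj = B[k, j, b]
        # innermost loop runs down a column of A and C (contiguous in memory)
        @simd for i in 1:m
            C[i, j, b] += A[i, k, b] * bkj
        end
    end
    return C
end

A = rand(4, 3, 2)
B = rand(3, 4, 2)
C = loop_batched_mul(A, B)   # 4×4×2
```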
using NNlib, Einsum
ein(A,B) = @einsum C[i,j,b] := A[i,k,b] * B[k,j,b]
batched_mul(A, B) ≈ ein(A, B) ≈ cat(bar(A,B)...; dims=3)
@code_warntype bar(A, B) # return type inferred as Any
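If the Any from eachslice matters (it depends on the Julia version; older versions return a generator from eachslice), a comprehension over the batch axis with views is one type-stable alternative. A sketch, with baz as a hypothetical name:

```julia
using LinearAlgebra, Test

# Hypothetical type-stable variant of bar: the element type of the
# comprehension (Matrix{Float64} for Float64 inputs) is inferable.
baz(X, Y) = [view(X, :, :, k) * view(Y, :, :, k) for k in axes(X, 3)]

A = rand(4, 3, 2)
B = rand(3, 4, 2)
C = @inferred baz(A, B)   # @inferred throws if the return type isn't inferred
```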
Thanks for the really fast and thorough answers everyone!
I’m actually using fairly large arrays, so I’d like to avoid copies where possible, especially on the GPU. My inputs are two large N-D arrays, and I’m having trouble converting them to an array of arrays on the GPU without a copy, whereas a reshape of the N-D array works. There’s probably a trick I’m missing.
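On the CPU at least, both reshape and view share memory with the parent array, so an N-D input can be turned into a 3-D batch, and then into a vector of matrix views, without copying any data. A sketch, assuming the first two dimensions are the matrix dimensions and everything after them is batch:

```julia
# 3×4 matrices laid out over a 5×6 grid of batch indices
X = rand(3, 4, 5, 6)

# Merge all trailing dimensions into one batch dimension: 3×4×30, no copy
X3 = reshape(X, size(X, 1), size(X, 2), :)

# A vector of 3×4 views into the same memory, also no copy
slices = [view(X3, :, :, k) for k in axes(X3, 3)]

# Writing through the reshaped array is visible in X and in the views
X3[1, 1, 1] = 42.0
```

Batch index k of X3 corresponds to the column-major linear index over the 5×6 grid, so for example slices[7] is X[:, :, 2, 2].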
At first glance NNlib’s batched_mul seems like the best fit for my application, because it works on both CPU and GPU. Ideally, I’d also like it to broadcast, i.e. multiply one matrix against every slice of a 3-D array. Does anyone know if that exists? Something like
a = rand(4,3)
b = rand(3,4,2)
batched_mul(a, b) # hoped-for result: a * b[:, :, k] for each k, i.e. a 4×4×2 array
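Until something like that exists, a CPU loop that reuses the same matrix for every slice is short to write. This is a sketch, not an NNlib function; broadcast_batched_mul is a hypothetical name.

```julia
using LinearAlgebra

# Hypothetical helper: C[:, :, k] = a * b[:, :, k] for every k.
function broadcast_batched_mul(a::AbstractMatrix, b::AbstractArray{<:Any,3})
    T = promote_type(eltype(a), eltype(b))
    C = similar(b, T, size(a, 1), size(b, 2), size(b, 3))
    for k in axes(b, 3)
        mul!(view(C, :, :, k), a, view(b, :, :, k))   # same `a` each time
    end
    return C
end

a = rand(4, 3)
b = rand(3, 4, 2)
C = broadcast_batched_mul(a, b)   # 4×4×2
```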
On the GPU, NNlib.batched_mul calls gemm_strided_batched!, which expects contiguous arrays rather than an array of pointers, and (IIRC) is more efficient.
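The reason a contiguous 3-D array suits a "strided batched" call is that one base pointer plus one fixed stride locates every matrix, whereas a pointer-array API needs a separate pointer per matrix. The offset arithmetic, shown on an ordinary CPU Array:

```julia
A = rand(4, 3, 5)           # five 4×3 matrices, stored back to back

# Distance (in elements) between the starts of consecutive matrices: 4*3
s = stride(A, 3)

k = 4
offset = (k - 1) * s        # linear offset of A[1, 1, k] from A[1, 1, 1]

# The k-th matrix is the next s elements starting after that offset
block = A[offset+1 : offset+s]
```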
This ought to understand broadcasting in the sense of using the same matrix a for every slice of 3D b, but does not right now. I had a PR to allow this (among other things) https://github.com/JuliaGPU/CuArrays.jl/pull/664, after which batched_mul(reshape(a, size(a)...,1), b) should work. But I didn’t finish it.
You can of course write reshape(a * reshape(b, 3, :), 4, 4, 2) to do this as one ordinary multiplication, though (again IIRC) that is not as quick for large square-ish arrays as the batched version.
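Why the reshape trick works: in column-major order, reshape(b, p, :) lays the slices of b side by side as one p×(n·nb) matrix, so a single product computes a * b[:, :, k] for every k at once, and reshaping back splits the columns into slices again. Only the product allocates; neither reshape copies. The same one-liner with general sizes:

```julia
a = rand(4, 3)
b = rand(3, 4, 2)

m, p = size(a)
p2, n, nb = size(b)
@assert p == p2

# reshape(b, p, :) is p×(n*nb): columns 1:n are b[:, :, 1], columns n+1:2n are b[:, :, 2], ...
C = reshape(a * reshape(b, p, :), m, n, nb)
```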