Matrix Multiplication of a Large Number of Small Matrices

I'm interested in computing matrix products of 2x2 matrices, but these are sampled at a very large number of points. Here is my current benchmark:

using Tullio, TensorCast, LinearAlgebra, StaticArrays

function MatrixMultiplication(A,B,C)

    @tullio X[k,m,l,i,j] := A[k,i,a] * B[l,k,a,b] * A[k,b,c] *C[m,k,c,j] #- 2*(A[j,k,a] * C[i,j,a,b] * A[j,b,c] *B[m,j,c,l])

    return X;
end;
function RowMatrixMultiplication(A,B,C)
    @tullio X[i,j,k,m,l] := A[i,a,k] * B[a,b,m,k] * A[b,c,k]* C[c,j,l,k] 
    
    return X;
end;

function CastMatrixMultiplication(A,B,C)
    @cast A2[k]{i,j} := A[k,i,j]
    @cast B2[m,k]{i,j} := B[m,k,i,j]
    @cast C2[m,k]{i,j} := C[m,k,i,j]
    @tullio X[m,l,k] := A2[k]*B2[m,k]*A2[k]*C2[l,k]
    return X
end

which I call via

function main()
    A_dim = 6*6
    B_dim = 2^9
    C_dim = B_dim
    A = rand(A_dim,2,2)+rand(A_dim,2,2)*1im
    B = rand(B_dim,A_dim,2,2)+rand(B_dim,A_dim,2,2)*1im
    C = rand(C_dim,A_dim,2,2)+rand(C_dim,A_dim,2,2)*1im
    ws = rand(B_dim)
    @time xa = MatrixMultiplication(A,B,C);
    @time xaa = MatrixMultiplication(A,B,C);
    @time xb = CastMatrixMultiplication(A,B,C);
    @time xbb = CastMatrixMultiplication(A,B,C);
    A = rand(2,2,A_dim)+rand(2,2,A_dim)*1im
    B = rand(2,2,B_dim,A_dim)+rand(2,2,B_dim,A_dim)*1im
    C = rand(2,2,C_dim,A_dim)+rand(2,2,C_dim,A_dim)*1im
    @time xc = RowMatrixMultiplication(A,B,C);
    @time xcc = RowMatrixMultiplication(A,B,C);

end

and I get (for the second run of each, so without compilation time):

  1.601916 seconds (3 allocations: 576.000 MiB, 9.17% gc time) # MatrixMultiplication
  2.227399 seconds (47.19 M allocations: 4.078 GiB, 27.83% gc time) # CastMatrixMultiplication
  1.286178 seconds (3 allocations: 576.000 MiB, 1.95% gc time) # RowMatrixMultiplication

My best try in Python is

def Matrix(A,B,C):
    X = A[None,None,:,:,:] @ B[:,None,:,:,:] @ A[None,None,:,:,:] @ C[None,:,:,:,:]
    return(X)

which gives ~1.6 seconds as well. The row-first version (RowMatrixMultiplication) beats Python somewhat, but with Tullio I'm used to seeing 10x speed-ups somewhere. Is this down to BLAS speeds, or is there anything I can do? Thanks in advance. I know I can speed it up somewhat by pre-allocating the output array as well, but I'm hoping for an even more dramatic improvement.
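For readers mapping between the two languages: the broadcast `@` chain in `Matrix` computes the same contraction as the Tullio line in `CastMatrixMultiplication`, `X[m,l,k] := A2[k]*B2[m,k]*A2[k]*C2[l,k]`, and it can be written with explicit indices via `np.einsum`. This is a minimal sketch with small, hypothetical sizes; `crand` and `matrix_einsum` are illustrative helpers, not code from the thread. With `optimize=True`, NumPy may pick a pairwise contraction order instead of one loop over all eight indices.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical small sizes (the thread uses A_dim=36, B_dim=C_dim=512)
A_dim, B_dim, C_dim = 4, 5, 6

def crand(*shape):
    """Hypothetical helper: random complex array of the given shape."""
    return rng.random(shape) + 1j * rng.random(shape)

A = crand(A_dim, 2, 2)         # A[k, i, j]
B = crand(B_dim, A_dim, 2, 2)  # B[m, k, i, j]
C = crand(C_dim, A_dim, 2, 2)  # C[l, k, i, j]

def Matrix(A, B, C):
    # Broadcast matmul from the question; result has shape (m, l, k, 2, 2)
    return A[None, None, :, :, :] @ B[:, None, :, :, :] @ A[None, None, :, :, :] @ C[None, :, :, :, :]

def matrix_einsum(A, B, C):
    # Same chain with explicit indices, like the Tullio version:
    # X[m,l,k,i,j] = sum_{a,b,c} A[k,i,a] B[m,k,a,b] A[k,b,c] C[l,k,c,j]
    return np.einsum("kia,mkab,kbc,lkcj->mlkij", A, B, A, C, optimize=True)

X1 = Matrix(A, B, C)
X2 = matrix_einsum(A, B, C)
print(X1.shape)             # (5, 6, 4, 2, 2)
print(np.allclose(X1, X2))  # True
```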

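Julia and Python use opposite default memory orders (Julia arrays are column-major, NumPy arrays are row-major by default), so your CastMatrixMultiplication is inefficient: slicing A[k,i,j] into 2x2 blocks over the last two indices gathers elements that are far apart in memory. The simplest thing you can do to speed this up substantially is to transpose that operation, so the 2x2 indices come first: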

function CastRowMatrixMultiplication(A,B,C)
    @cast A2[k]{i,j} := A[i,j,k]
    @cast B2[m,k]{i,j} := B[i,j,m,k]
    @cast C2[m,k]{i,j} := C[i,j, m,k]
    @tullio X[m,l,k] := A2[k]*B2[m,k]*A2[k]*C2[l,k]
    return X
end

using BenchmarkTools

let
    A_dim = 6*6
    B_dim = 2^9
    C_dim = B_dim
    A = rand(A_dim,2,2)+rand(A_dim,2,2)*1im
    B = rand(B_dim,A_dim,2,2)+rand(B_dim,A_dim,2,2)*1im
    C = rand(C_dim,A_dim,2,2)+rand(C_dim,A_dim,2,2)*1im

    @btime xa = MatrixMultiplication($A,$B,$C);
    @btime xb = CastMatrixMultiplication($A,$B,$C);

    A = rand(2,2,A_dim)+rand(2,2,A_dim)*1im
    B = rand(2,2,B_dim,A_dim)+rand(2,2,B_dim,A_dim)*1im
    C = rand(2,2,C_dim,A_dim)+rand(2,2,C_dim,A_dim)*1im
    
    @btime xcc = RowMatrixMultiplication($A,$B,$C);
    @btime xcc = CastRowMatrixMultiplication($A,$B,$C);
    
end

gives me

  325.585 ms (77 allocations: 576.00 MiB) # MatrixMultiplication
  177.564 ms (95 allocations: 576.01 MiB) # CastMatrixMultiplication
  226.280 ms (77 allocations: 576.00 MiB) # RowMatrixMultiplication
  87.582 ms (98 allocations: 576.01 MiB) # CastRowMatrixMultiplication
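The major-order point can be checked from the NumPy side, because `order="F"` gives a NumPy array exactly Julia's column-major layout. In this sketch the Julia shape `A[k,i,j]` leaves the four entries of one 2x2 block `A_dim` elements apart, while `A[i,j,k]` keeps them adjacent. The sizes and dtype mirror the thread (`ComplexF64` is 16 bytes); the variable names are illustrative only. It also suggests why the Python version puts the batch axes first (row-major), while the fast Julia versions put the 2x2 axes first.

```python
import numpy as np

A_dim = 36  # number of 2x2 blocks, as in the thread

# order="F" reproduces Julia's column-major memory layout
julia_kij = np.zeros((A_dim, 2, 2), dtype=np.complex128, order="F")  # like Julia A[k,i,j]
julia_ijk = np.zeros((2, 2, A_dim), dtype=np.complex128, order="F")  # like Julia A[i,j,k]

block_kij = julia_kij[0, :, :]  # the 2x2 block for k = 0
block_ijk = julia_ijk[:, :, 0]

# Byte strides between neighbouring entries of one block (16 bytes per complex number)
print(block_kij.strides)                # (576, 1152): entries are A_dim*16 bytes apart
print(block_ijk.strides)                # (16, 32): the four entries are adjacent
print(block_ijk.flags["F_CONTIGUOUS"])  # True
print(block_kij.flags["F_CONTIGUOUS"])  # False
```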

@Mason When I use the @cast versions, @btime gives me

2.253 s (47185934 allocations: 4.08 GiB)

for both versions. What version of Julia are you using, or what might I be doing wrong?

Try StaticArrays.


@stevengj This is exactly what the Cast functions do…

To get the full benefit, you need the inputs to be arrays of SMatrix, so that the compiler knows the sizes. Don’t convert to static arrays only for the multiplication step, use them from the beginning (and elsewhere throughout your program), e.g. generate them via:

T = SMatrix{2,2,ComplexF64,4}
A = rand(T,A_dim)
B = rand(T,B_dim,A_dim)
C = rand(T,C_dim,A_dim)

This will also fix the issue of the dimension ordering noted by @Mason above (and issues with constant propagation of array sizes, which is no longer needed), though to get even better locality I would also probably transpose A_dim with B_dim so that it becomes:

@tullio X[k,l,m] := A[k]*B[k,m]*A[k]*C[k,l]

(but you can experiment a bit with different dimension orderings of X and B to see which one is fastest with Tullio).
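Part of what `SMatrix` buys is that, with the size fixed at 2x2 in the type, the compiler can unroll each product into a few scalar multiply-adds instead of running a general matrix routine. The idea can be mimicked in NumPy by writing the 2x2 product out entry by entry over the batch axes. This is an illustrative sketch only; `mul2x2` is a hypothetical helper, not part of StaticArrays or NumPy, and in NumPy it is not necessarily faster than `@`.

```python
import numpy as np

def mul2x2(X, Y):
    """Hypothetical helper: batched 2x2 product written out entry by entry.

    X and Y have shape (..., 2, 2) and broadcast over the leading axes,
    mimicking the straight-line code a fixed-size 2x2 product allows.
    """
    x00, x01, x10, x11 = X[..., 0, 0], X[..., 0, 1], X[..., 1, 0], X[..., 1, 1]
    y00, y01, y10, y11 = Y[..., 0, 0], Y[..., 0, 1], Y[..., 1, 0], Y[..., 1, 1]
    row0 = np.stack([x00 * y00 + x01 * y10, x00 * y01 + x01 * y11], axis=-1)
    row1 = np.stack([x10 * y00 + x11 * y10, x10 * y01 + x11 * y11], axis=-1)
    return np.stack([row0, row1], axis=-2)  # rows go on the second-to-last axis

rng = np.random.default_rng(1)
X = rng.random((3, 4, 2, 2)) + 1j * rng.random((3, 4, 2, 2))
Y = rng.random((4, 2, 2)) + 1j * rng.random((4, 2, 2))
Z = mul2x2(X, Y)
print(Z.shape)                # (3, 4, 2, 2)
print(np.allclose(Z, X @ Y))  # True
```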


I'm on Julia 1.9-beta3. You're probably seeing a type instability because the array sizes aren't being constant-propagated; if I recall correctly, that propagation is a fairly recent change.

I’ve had a good speedup by calculating:

@tullio AB[k,m] := A[k]*B[k,m]
@tullio AC[k,l] := A[k]*C[k,l]

and then doing the cross product (which is much larger):

@tullio X[k,l,m] := AB[k,m] * AC[k,l]
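In NumPy terms the caching trick looks like the sketch below, using the transposed layout `B[k,m]`, `C[k,l]` suggested above and small hypothetical sizes (`crand` is an illustrative helper). The direct chain does three 2x2 products for every `(k, l, m)`; the cached version does `K*(M+L)` products up front and then one product per `(k, l, m)`, at the cost of storing `AB` and `AC`.

```python
import numpy as np

rng = np.random.default_rng(2)
K, M, L = 4, 5, 6  # hypothetical small sizes (the thread uses 36, 512, 512)

def crand(*shape):
    """Hypothetical helper: random complex array of the given shape."""
    return rng.random(shape) + 1j * rng.random(shape)

A = crand(K, 2, 2)     # A[k]
B = crand(K, M, 2, 2)  # B[k, m], with A_dim first
C = crand(K, L, 2, 2)  # C[k, l]

# Direct chain: X[k, l, m] = A[k] B[k, m] A[k] C[k, l]
Ak = A[:, None, None]                               # (K, 1, 1, 2, 2)
X_direct = Ak @ B[:, None, :] @ Ak @ C[:, :, None]  # (K, L, M, 2, 2)

# Cached version: the two small products are computed once...
AB = A[:, None] @ B  # AB[k, m], shape (K, M, 2, 2)
AC = A[:, None] @ C  # AC[k, l], shape (K, L, 2, 2)
# ...then only one 2x2 product remains per (k, l, m)
X_cached = AB[:, None, :] @ AC[:, :, None]          # (K, L, M, 2, 2)

print(X_cached.shape)                   # (4, 6, 5, 2, 2)
print(np.allclose(X_direct, X_cached))  # True
```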

Good point: this is a classic space-time tradeoff. It's a good idea to cache these products with A, since they are reused many times. This sort of thing comes up a lot in broadcast-type operations, e.g. "Why is this Julia code considerably slower than Matlab" (#62 by stevengj).


For large arrays, a chain like this is Tullio’s worst nightmare. It will make 8 nested loops, O(N^8), whereas 3 separate multiplications will be O(N^5) if I’m counting right. For very small arrays, that’s not obviously the right way to think, but I wouldn’t hope for much.
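To put rough numbers on this for the sizes in the thread, here is a sketch counting scalar multiplications. It assumes Tullio runs the whole expression as one loop nest over `k, m, l, i, j, a, b, c`, and that the pairwise version caches `AB` and `AC` as in the earlier post. It ignores memory traffic, SIMD and complex-arithmetic cost, so it only compares the amount of arithmetic.

```python
# Count scalar multiplications for the thread's sizes
K, M, L = 36, 512, 512  # A_dim, B_dim, C_dim
n = 2                   # each matrix index (i, j, a, b, c) runs over 1:2

# One fused loop nest over 8 indices; each iteration multiplies 4 numbers (3 multiplications)
fused_iters = K * M * L * n**5
fused_mults = 3 * fused_iters

# Pairwise: AB[k,m] = A[k]*B[k,m], AC[k,l] = A[k]*C[k,l], then X[k,l,m] = AB[k,m]*AC[k,l]
pairwise_mults = K * M * n**3 + K * L * n**3 + K * L * M * n**3

print(fused_iters)                             # 301989888
print(fused_mults)                             # 905969664
print(pairwise_mults)                          # 75792384
print(round(fused_mults / pairwise_mults, 1))  # 12.0
```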

Something similar applies to the arrays of matrices in A[k]*B[k,m]*A[k]*C[k,l], as noted.

FWIW I see 2.5s for CastMatrixMultiplication but 160ms for CastRowMatrixMultiplication, on Julia nightly.

TensorCast has a (slightly weird) notation to provide the sizes, if you know they are always 2x2, like so: @cast B2[m,k]{i:2,j:2} := B[m,k,i,j]. This removes the instability.


Thanks to @stevengj, @Mason, @Dan and @mcabbott; these are all helpful suggestions. I've seen a lot of advice about embedding as much of the Julia processing as possible in a single for-loop for performance. The rest of the process involves:

@reduce XSum[mu,nu]  := sum(k,l,m)  imag.(X[k,l,m])
result = tr.(XSum)

This seems quite tidy, but given that computing the cross product separately gives a ~30% speed-up, should I be unvectorizing some of this for performance as well? What's the intuition here, coming from a Python background?

This is pretty wasteful: you are computing the whole XSum matrix just to take the trace? Also, there's no need for the dot in the imag call. Why not just sum tr(imag(X[…])) directly?
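Summing the traces directly can be sketched in NumPy as below, assuming the quantity of interest is the plain sum over `k, l, m` of `tr(imag(X[k,l,m]))` with `X[k,l,m] = AB[k,m] * AC[k,l]`; the sizes and `crand` helper are hypothetical. Since `tr(P Q) = sum_{i,j} P[i,j] Q[j,i]` and `imag` commutes with a sum, a single `einsum` gives the result without ever storing `X`. The last line goes one step further and is valid only for this unweighted sum: `m` appears only in `AB` and `l` only in `AC`, so both can be summed out first.

```python
import numpy as np

rng = np.random.default_rng(3)
K, M, L = 4, 5, 6  # hypothetical small sizes

def crand(*shape):
    """Hypothetical helper: random complex array of the given shape."""
    return rng.random(shape) + 1j * rng.random(shape)

AB = crand(K, M, 2, 2)  # cached products AB[k, m]
AC = crand(K, L, 2, 2)  # cached products AC[k, l]

# Wasteful: build every X[k,l,m], sum them into one matrix, then take the trace
X = AB[:, None, :] @ AC[:, :, None]  # (K, L, M, 2, 2)
via_full_sum = np.trace(X.sum(axis=(0, 1, 2)).imag)

# Direct: sum of tr(X[k,l,m]) as one contraction, then the imaginary part; X is never stored
direct = np.einsum("kmij,klji->", AB, AC, optimize=True).imag

# Only for the unweighted sum: sum out m and l first, leaving O(K) work
factored = np.einsum("kij,kji->", AB.sum(axis=1), AC.sum(axis=1)).imag

print(np.isclose(via_full_sum, direct))    # True
print(np.isclose(via_full_sum, factored))  # True
```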
