I'm interested in computing products of 2x2 matrices, where the matrices are sampled at a very large number of points. Here is my current benchmark:
using Tullio, TensorCast, LinearAlgebra, StaticArrays

function MatrixMultiplication(A, B, C)
    @tullio X[k,m,l,i,j] := A[k,i,a] * B[l,k,a,b] * A[k,b,c] * C[m,k,c,j] #- 2*(A[j,k,a] * C[i,j,a,b] * A[j,b,c] *B[m,j,c,l])
    return X
end

function RowMatrixMultiplication(A, B, C)
    @tullio X[i,j,k,m,l] := A[i,a,k] * B[a,b,m,k] * A[b,c,k] * C[c,j,l,k]
    return X
end

function CastMatrixMultiplication(A, B, C)
    @cast A2[k]{i,j} := A[k,i,j]
    @cast B2[m,k]{i,j} := B[m,k,i,j]
    @cast C2[m,k]{i,j} := C[m,k,i,j]
    @tullio X[m,l,k] := A2[k] * B2[m,k] * A2[k] * C2[l,k]
    return X
end
which I call via
function main()
    A_dim = 6*6
    B_dim = 2^9
    C_dim = B_dim
    # Layout 1: sample indices first, 2x2 matrix indices last
    A = rand(A_dim,2,2) + rand(A_dim,2,2)*1im
    B = rand(B_dim,A_dim,2,2) + rand(B_dim,A_dim,2,2)*1im
    C = rand(C_dim,A_dim,2,2) + rand(C_dim,A_dim,2,2)*1im
    ws = rand(B_dim)
    @time xa = MatrixMultiplication(A,B,C);
    @time xaa = MatrixMultiplication(A,B,C);
    @time xb = CastMatrixMultiplication(A,B,C);
    @time xbb = CastMatrixMultiplication(A,B,C);
    # Layout 2: 2x2 matrix indices first, sample indices last
    A = rand(2,2,A_dim) + rand(2,2,A_dim)*1im
    B = rand(2,2,B_dim,A_dim) + rand(2,2,B_dim,A_dim)*1im
    C = rand(2,2,C_dim,A_dim) + rand(2,2,C_dim,A_dim)*1im
    @time xc = RowMatrixMultiplication(A,B,C);
    @time xcc = RowMatrixMultiplication(A,B,C);
end
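and I get (timings from the second call of each function, so without compile time; in the order MatrixMultiplication, CastMatrixMultiplication, RowMatrixMultiplication):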
1.601916 seconds (3 allocations: 576.000 MiB, 9.17% gc time)
2.227399 seconds (47.19 M allocations: 4.078 GiB, 27.83% gc time)
1.286178 seconds (3 allocations: 576.000 MiB, 1.95% gc time)
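As a sanity check on the allocation figures, the 576 MiB reported for the two plain Tullio versions appears to be just the output array itself. A quick sketch of the arithmetic, assuming the entries are ComplexF64 (which is what rand(...) + rand(...)*1im produces) and using the sizes from main():

```julia
# Sizes taken from main() above
A_dim = 6 * 6
B_dim = 2^9
C_dim = B_dim

# One 2x2 complex matrix per (B sample, C sample, A sample) triple
n_elements = B_dim * C_dim * A_dim * 2 * 2
n_bytes = n_elements * sizeof(ComplexF64)   # 16 bytes per ComplexF64
output_mib = n_bytes / 2^20

println("output array: ", output_mib, " MiB")   # prints 576.0 MiB
```

So those two versions allocate essentially nothing beyond the result, while the TensorCast version allocates far more on top of it.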
My best try in Python is
def Matrix(A, B, C):
    X = A[None,None,:,:,:] @ B[:,None,:,:,:] @ A[None,None,:,:,:] @ C[None,:,:,:,:]
    return X
which gives ~1.6 seconds as well. The version with the 2x2 indices first (RowMatrixMultiplication) beats Python somewhat, but I'm used to seeing 10x speedups from Tullio. Is this down to BLAS speeds, or is there anything else I can do? Thanks in advance. I know I can speed it up somewhat by preallocating the output array as well, but I'm hoping for a more dramatic improvement.
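For the preallocation idea, here is a minimal sketch of what that could look like. It assumes Tullio's `=` form, which writes into an existing array instead of allocating a new one as `:=` does. The name RowMatrixMultiplication! is hypothetical, and the sizes are deliberately small so that the snippet runs quickly; it is not a benchmark.

```julia
using Tullio

# In-place variant of RowMatrixMultiplication (hypothetical name):
# `=` overwrites the existing X, whereas `:=` allocates a new array on every call.
function RowMatrixMultiplication!(X, A, B, C)
    @tullio X[i,j,k,m,l] = A[i,a,k] * B[a,b,m,k] * A[b,c,k] * C[c,j,l,k]
    return X
end

# Small sizes here, just to keep the sketch quick to run
A_dim, B_dim, C_dim = 4, 8, 8
A = rand(ComplexF64, 2, 2, A_dim)
B = rand(ComplexF64, 2, 2, B_dim, A_dim)
C = rand(ComplexF64, 2, 2, C_dim, A_dim)

# Allocate once, outside the timed region, and reuse across calls
X = Array{ComplexF64}(undef, 2, 2, A_dim, B_dim, C_dim)
RowMatrixMultiplication!(X, A, B, C)

# Spot check one slice against ordinary matrix products
k, m, l = 2, 3, 5
@assert X[:, :, k, m, l] ≈ A[:, :, k] * B[:, :, m, k] * A[:, :, k] * C[:, :, l, k]
```

This only removes the 576 MiB allocation and the associated gc time from each call; it does not change the amount of arithmetic or the number of bytes that have to be written to X.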