Optimizing Complex Batch Matrix Multiplication

Hello fellow Julians!
In my code the bottleneck step is a batched ComplexF64 matrix multiplication of the form C[i,k,n] = sum_j conj(A[j,i,n]) * B[j,k,n]. Naively, one loops over n and does an in-place multiplication per slice. However, in the real (Float64) case, it seems that Tullio.jl can speed things up considerably (see the Tullio sketch after the reference implementation below). Is it possible to do something similarly fast in the complex case? Can you help me optimize things further?

To start off, here’s a reference implementation:

using LinearAlgebra
using BenchmarkTools
BLAS.set_num_threads(1)
N, M, K = 500, 10, 200
A = rand(ComplexF64,N,M,K)
B = rand(ComplexF64,N,M,K)
C = zeros(ComplexF64,M,M,K);

function batch_zgemm1!(C, A, B)
    for k in axes(C, 3)
        # A[:,:,k]' is the adjoint (conjugate transpose), so each slice
        # computes conj(A)-transposed times B with a single ZGEMM call
        @views mul!(C[:,:,k], A[:,:,k]', B[:,:,k])
    end
    return C
end
@btime batch_zgemm1!($C, $A, $B) # 4.561 ms (0 allocations: 0 bytes)
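
For comparison, the Tullio approach mentioned above carries over to the complex case directly. A minimal sketch, assuming Tullio.jl is available (note that LoopVectorization does not handle ComplexF64, so Tullio falls back to plain multithreaded Julia loops here):

using Tullio

function batch_tullio!(C, A, B)
    # j appears only on the right-hand side, so Tullio sums over it;
    # conj plays the role of the adjoint in the mul! version
    @tullio C[i,k,n] = conj(A[j,i,n]) * B[j,k,n]
    return C
end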

A few notes:

  • I require complex matrices with full precision for my use case.
  • Typical problem sizes for me are N = 500-1500, M = 2-20, K = 100-1600.
  • I am aware that MKL has a batched matrix multiplication library, but it would be good to have a pure-Julia solution for my non-Intel computer.
  • The individual multiplications are large enough that BLAS benefits from multithreading, so parallelizing over K doesn’t offer any obvious improvements.

You want to parallelize over k and use single-threaded matmul. Matmul parallelizes, but doesn’t do so perfectly.

I see: you get a speedup of roughly Nthreads from parallelizing over the outer loop, while BLAS gets a multiplier of less than Nthreads from using more BLAS threads.

function batch_zgemm2!(C, A, B)
    # BLAS was set to one thread above, so each mul! runs single-threaded
    # and the k slices are distributed across the Julia threads instead
    Threads.@threads for k in axes(A, 3)
        @views mul!(C[:,:,k], A[:,:,k]', B[:,:,k])
    end
    return C
end
Threads.nthreads() # 6
@btime batch_zgemm2!($C, $A, $B) # 758.167 μs (35 allocations: 3.38 KiB)

However, is it possible to do better?
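
One direction that might do better, given that Tullio shines for real element types: split the arrays into real and imaginary parts so that LoopVectorization can kick in. A rough, untested sketch (batch_split! and the eager real/imag copies are just for illustration; in a hot loop the temporaries should be preallocated):

using Tullio, LoopVectorization

function batch_split!(C, A, B)
    # These copies allocate; hoist them out if this runs repeatedly
    Ar, Ai = real(A), imag(A)
    Br, Bi = real(B), imag(B)
    # conj(a)*b = (ar*br + ai*bi) + im*(ar*bi - ai*br), summed over j
    @tullio Cr[i,k,n] := Ar[j,i,n] * Br[j,k,n] + Ai[j,i,n] * Bi[j,k,n]
    @tullio Ci[i,k,n] := Ar[j,i,n] * Bi[j,k,n] - Ai[j,i,n] * Br[j,k,n]
    @. C = complex(Cr, Ci)
    return C
end

Whether this beats the threaded ZGEMM loop is worth benchmarking; for the small M values quoted above (2-20), fused loops over the real parts may well outperform per-slice BLAS calls.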