GEMM kernel, cache-oblivious recursion, and Strassen

I think many people are curious how a GEMM kernel written in pure Julia can be much more performant than a naively written one. One can inspect the source code of LoopVectorization.jl, of course, or even the more professional Octavian.jl, but those are monsters from which non-experts can hardly learn anything. Fortunately, some days ago I found the project GemmDemo.jl on GitHub, which only uses SIMD.jl and is quite simple to learn from. Its core is a 12x4 microkernel that employs all 16 YMM registers to reduce memory access (the register budget is worked out after the list below). I played with it, corrected some tiny bugs, removed its SIMD.jl dependency just for cleanliness, and combined it with stevengj's cache-oblivious matrix multiplication code to get even better performance for large matrices. Now:

  1. It is simple enough for everybody to play with;
  2. It has fairly good performance compared with MKL and Octavian on my four computers (ranging from an Intel m3-8100Y to an i9-11900K), tested on Julia 1.7.3 and 1.10.8;
  3. It uses no external packages, relying only on Julia's ability to produce AVX2-optimized code, so it will keep working as long as Julia itself exists.
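
A quick accounting of why the 12x4 shape uses exactly the 16 YMM registers that AVX2 provides: the 12x4 block of Float64 entries of C occupies 12 registers (4 doubles each), a 12-element column strip of A takes 3 more, and the broadcast element of B takes the last one, so 12 + 3 + 1 = 16.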
@inline vectorize(a) = (a, a, a, a) # broadcast a scalar into a 4-wide tuple (one YMM register)

# unaligned vector load of 4 consecutive elements starting at A[x, y] (column-major)
@inline function v4load(A, x, y)
    T = eltype(A)
    unsafe_load(Base.unsafe_convert(Ptr{NTuple{4,T}}, A) + (stride(A,2)*(y-1) + x-1)*sizeof(T))
end

# unaligned vector store of 4 consecutive elements starting at A[x, y]
@inline function v4store!(A, x, y, vT)
    T = eltype(A)
    unsafe_store!(Base.unsafe_convert(Ptr{NTuple{4,T}}, A) + (stride(A,2)*(y-1) + x-1)*sizeof(T), vT)
end

# Register-blocked microkernel: accumulate C[istart:iend, jstart:jend] +=
# A[istart:iend, pstart:pend] * B[pstart:pend, jstart:jend], walking the inner
# dimension in chunks of cacheK; the i/j range lengths must be multiples of 12 and 4.
@inline function gemm_nonpacking!(C,A,B, istart,iend, jstart,jend, pstart,pend, cacheK, ::Val{12},::Val{4})
    igroups = Iterators.partition(istart:iend, 12)
    jgroups = Iterators.partition(jstart:jend, 4)
    for macrok in Iterators.partition(pstart:pend, cacheK)
        for macroj in jgroups; j = macroj[1]
            for macroi in igroups; i = macroi[1]
                # load 12x4 block Cij into vector registers
                # load Cij[1:12, 1]
                Cij_r1c1vec = v4load(C, i, j); Cij_r2c1vec = v4load(C, i+4, j); Cij_r3c1vec = v4load(C, i+8, j)
                # load Cij[1:12, 2]
                Cij_r1c2vec = v4load(C, i, j+1); Cij_r2c2vec = v4load(C, i+4, j+1); Cij_r3c2vec = v4load(C, i+8, j+1)
                # load Cij[1:12, 3]
                Cij_r1c3vec = v4load(C, i, j+2); Cij_r2c3vec = v4load(C, i+4, j+2); Cij_r3c3vec = v4load(C, i+8, j+2)
                # load Cij[1:12, 4]
                Cij_r1c4vec = v4load(C, i, j+3); Cij_r2c4vec = v4load(C, i+4, j+3); Cij_r3c4vec = v4load(C, i+8, j+3)
                
                # update Cij with series of outer products
                # from the associated Aip and Bpj micropanels
                @inbounds for p in macrok
                    # update with A[i:(i+11), p] * B[p, j:(j+3)] outer product
                    Aip_vec1 = v4load(A, i, p); Aip_vec2 = v4load(A, i+4, p); Aip_vec3 = v4load(A, i+8, p)
                    # Cij[1:12, 1] += A[i:(i+11), p] * B[p, j]
                    Bpj_c1vec = vectorize(B[p, j])
                    Cij_r1c1vec = fma.(Aip_vec1, Bpj_c1vec, Cij_r1c1vec)
                    Cij_r2c1vec = fma.(Aip_vec2, Bpj_c1vec, Cij_r2c1vec)
                    Cij_r3c1vec = fma.(Aip_vec3, Bpj_c1vec, Cij_r3c1vec)
                    # Cij[1:12, 2] += A[i:(i+11), p] * B[p, j+1]
                    Bpj_c2vec = vectorize(B[p, j+1])
                    Cij_r1c2vec = fma.(Aip_vec1, Bpj_c2vec, Cij_r1c2vec)
                    Cij_r2c2vec = fma.(Aip_vec2, Bpj_c2vec, Cij_r2c2vec)
                    Cij_r3c2vec = fma.(Aip_vec3, Bpj_c2vec, Cij_r3c2vec)
                    # Cij[1:12, 3] += A[i:(i+11), p] * B[p, j+2]
                    Bpj_c3vec = vectorize(B[p, j+2])
                    Cij_r1c3vec = fma.(Aip_vec1, Bpj_c3vec, Cij_r1c3vec)
                    Cij_r2c3vec = fma.(Aip_vec2, Bpj_c3vec, Cij_r2c3vec)
                    Cij_r3c3vec = fma.(Aip_vec3, Bpj_c3vec, Cij_r3c3vec)
                    # Cij[1:12, 4] += A[i:(i+11), p] * B[p, j+3]
                    Bpj_c4vec = vectorize(B[p, j+3])
                    Cij_r1c4vec = fma.(Aip_vec1, Bpj_c4vec, Cij_r1c4vec)
                    Cij_r2c4vec = fma.(Aip_vec2, Bpj_c4vec, Cij_r2c4vec)
                    Cij_r3c4vec = fma.(Aip_vec3, Bpj_c4vec, Cij_r3c4vec)
                end

                # store Cij[1:12, 1]
                v4store!(C, i, j, Cij_r1c1vec); v4store!(C, i+4, j, Cij_r2c1vec); v4store!(C, i+8, j, Cij_r3c1vec)
                # store Cij[1:12, 2]
                v4store!(C, i, j+1, Cij_r1c2vec); v4store!(C, i+4, j+1, Cij_r2c2vec); v4store!(C, i+8, j+1, Cij_r3c2vec)
                # store Cij[1:12, 3]
                v4store!(C, i, j+2, Cij_r1c3vec); v4store!(C, i+4, j+2, Cij_r2c3vec); v4store!(C, i+8, j+2, Cij_r3c3vec)
                # store Cij[1:12, 4]
                v4store!(C, i, j+3, Cij_r1c4vec); v4store!(C, i+4, j+3, Cij_r2c4vec); v4store!(C, i+8, j+3, Cij_r3c4vec)
            end
        end
    end
    return nothing
end
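
A minimal sanity check for the microkernel alone; the sizes here are my own choice, picked as multiples of 12 and 4 so that every tile is full:

using LinearAlgebra
A = randn(24, 24); B = randn(24, 24); C = zeros(24, 24)
gemm_nonpacking!(C, A, B, 1, 24, 1, 24, 1, 24, 60, Val(12), Val(4))
@assert C ≈ A * B   # the kernel accumulates into C, which starts at zero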

# Cache-oblivious recursion (after stevengj's code): halve all three dimensions
# until the subproblem is small, then hand off to the microkernel; this keeps
# good locality at every cache level without tuning. m and n are the dimensions
# of the C block (rows, cols) and p is the inner dimension; the demo sizes
# (multiples of 120) guarantee every base-case tile is full.
function add_matmul_rec!(m,n,p, i0,j0,k0, C,A,B)
    if m+n+p <= 3*120
        gemm_nonpacking!(C,A,B, i0,i0+m-1, j0,j0+n-1, k0,k0+p-1, 60, Val(12),Val(4))
    else
        m2 = m ÷ 2; n2 = n ÷ 2; p2 = p ÷ 2
        add_matmul_rec!(m2,n2,p2, i0,j0,k0, C,A,B)
        add_matmul_rec!(m-m2,n2,p2, i0+m2,j0,k0, C,A,B)
        add_matmul_rec!(m2,n-n2,p2, i0,j0+n2,k0, C,A,B)
        add_matmul_rec!(m2,n2,p-p2, i0,j0,k0+p2, C,A,B)
        add_matmul_rec!(m-m2,n-n2,p2, i0+m2,j0+n2,k0, C,A,B)
        add_matmul_rec!(m2,n-n2,p-p2, i0,j0+n2,k0+p2, C,A,B)
        add_matmul_rec!(m-m2,n2,p-p2, i0+m2,j0,k0+p2, C,A,B)
        add_matmul_rec!(m-m2,n-n2,p-p2, i0+m2,j0+n2,k0+p2, C,A,B)
    end
    return C
end
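
Since add_matmul_rec! accumulates (C += A*B), C must be zeroed first. A quick check at a size of my choosing that the recursion tiles cleanly:

N = 240
A = randn(N, N); B = randn(N, N); C = zeros(N, N)
add_matmul_rec!(N, N, N, 1, 1, 1, C, A, B)
@assert C ≈ A * B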

using BenchmarkTools, LinearAlgebra

BLAS.set_num_threads(1) # single-threaded BLAS for a fair comparison with the single-threaded demo kernel

for i=2:6
    A=randn(60*2^i,60*2^i)
    B=randn(60*2^i,60*2^i)
    C=zeros(size(A)) # must start from zero: add_matmul_rec! accumulates into C
    print("Size: 60*$(2^i)"); @btime mul!($C,$A,$B)
    m,n=size(C); p=size(A,2)
    @btime add_matmul_rec!($m,$n,$p, 1,1,1, $C,$A,$B)
end

The benchmark results on an i9-11900K under Windows 11 (for each size, the first time is mul!, the second is add_matmul_rec!):

Size: 60*4  386.900 μs (0 allocations: 0 bytes)
  409.900 μs (0 allocations: 0 bytes)
Size: 60*8  3.019 ms (0 allocations: 0 bytes)
  3.433 ms (0 allocations: 0 bytes)
Size: 60*16  25.518 ms (0 allocations: 0 bytes)
  31.524 ms (0 allocations: 0 bytes)
Size: 60*32  212.340 ms (0 allocations: 0 bytes)
  267.145 ms (0 allocations: 0 bytes)
Size: 60*64  1.693 s (0 allocations: 0 bytes)
  2.206 s (0 allocations: 0 bytes)
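
To put the largest size in perspective: at N = 60*64 = 3840 the flop count is 2N^3 ≈ 1.13e11, so 2.206 s corresponds to roughly 51 GFLOPS for the recursive code versus roughly 67 GFLOPS for single-threaded mul!.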

If we add the Strassen algorithm on top (with preallocated workspaces cA, cB, and cC so that everything stays allocation-free)

clean!(A) = fill!(A,zero(eltype(A))) # zero a workspace in place

# Strassen with accumulation (C += A*B); cA, cB, cC are preallocated workspaces
# holding the half-size temporaries for every level of the recursion.
function Strassen!(C, A, B, cA, cB, cC)
    n,m=size(A) .÷ 2; _,p=size(B) .÷ 2
    if m+n+p <= 3*240
        add_matmul_rec!(2n,2p,2m, 1,1,1, C,A,B) # argument order: rows(C), cols(C), inner dim
    else
        A11=@view A[1:n,1:m]; A12=@view A[1:n,(m+1):end]; A21=@view A[(n+1):end,1:m]; A22=@view A[(n+1):end,(m+1):end]
        B11=@view B[1:m,1:p]; B12=@view B[1:m,(p+1):end]; B21=@view B[(m+1):end,1:p]; B22=@view B[(m+1):end,(p+1):end]
        C11=@view C[1:n,1:p]; C12=@view C[1:n,(p+1):end]; C21=@view C[(n+1):end,1:p]; C22=@view C[(n+1):end,(p+1):end]
        cA2=@view cA[1:n,1:m]; cB2=@view cB[1:m,1:p]; ccA=@view cA[1:n ÷ 2,(m+1):end]; ccB=@view cB[1:m ÷ 2,(p+1):end]
        cC2=@view cC[1:n,1:p]; ccC=@view cC[1:n ÷ 2,(p+1):end]
        # the seven Strassen products, accumulated directly into the blocks of C
        @. cA2=A11+A22; @. cB2=B11+B22; Strassen!(clean!(cC2),cA2,cB2,ccA,ccB,ccC); @. C11+=cC2; @. C22+=cC2
        @. cA2=A21+A22; Strassen!(clean!(cC2),cA2,B11,ccA,ccB,ccC); @. C21+=cC2; @. C22-=cC2
        @. cB2=B12-B22; Strassen!(clean!(cC2),A11,cB2,ccA,ccB,ccC); @. C12+=cC2; @. C22+=cC2
        @. cB2=B21-B11; Strassen!(clean!(cC2),A22,cB2,ccA,ccB,ccC); @. C11+=cC2; @. C21+=cC2
        @. cA2=A11+A12; Strassen!(clean!(cC2),cA2,B22,ccA,ccB,ccC); @. C11-=cC2; @. C12+=cC2
        @. cA2=A21-A11; @. cB2=B11+B12; Strassen!(C22,cA2,cB2,ccA,ccB,ccC)
        @. cA2=A12-A22; @. cB2=B21+B22; Strassen!(C11,cA2,cB2,ccA,ccB,ccC)
    end
    return C
end
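
A correctness check with the required workspace shapes; the size is my own choice and must be divisible by a high enough power of 2 that the recursion bottoms out on full tiles:

N = 960
A = randn(N, N); B = randn(N, N); C = zeros(N, N)
cA = zeros(N ÷ 2, N); cB = zeros(N ÷ 2, N); cC = zeros(N ÷ 2, N)
Strassen!(C, A, B, cA, cB, cC)
@assert C ≈ A * B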

for i=2:6
    A=randn(60*2^i,60*2^i)
    B=randn(60*2^i,60*2^i)
    C=zeros(size(A)) # must start from zero: Strassen! accumulates into C
    cA=zeros(size(A,1) ÷ 2, size(A,2))
    cB=zeros(size(B,1) ÷ 2, size(B,2))
    cC=zeros(size(C,1) ÷ 2, size(C,2))
    print("Size: 60*$(2^i)")
    @btime Strassen!($C,$A,$B,$cA,$cB,$cC)
end

the result is

Size: 60*4  408.900 μs (0 allocations: 0 bytes)
Size: 60*8  3.433 ms (0 allocations: 0 bytes)
Size: 60*16  30.694 ms (0 allocations: 0 bytes)
Size: 60*32  243.806 ms (0 allocations: 0 bytes)
Size: 60*64  1.879 s (0 allocations: 0 bytes)
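
Three levels of Strassen fire at the largest size (3840 → 1920 → 960, with the 480-sized subproblems handled by add_matmul_rec!), which in principle cuts the multiplication count to (7/8)^3 ≈ 67% of the classical count; the measured 1.879 s versus 2.206 s (≈ 85%) shows that the extra additions and memory traffic eat part of that gain.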

The efficiency gap between this demo code and MKL.jl remains around 10%.
I know there are many experts here, so I would be very happy to hear any suggestions for further optimization or generalization.
