GEMM kernel, cache oblivious, and Strassen

I think many people are curious about how a GEMM kernel written in pure Julia can be much more performant than a naively written one. One can inspect the source code of LoopVectorization.jl, of course, or even the more professional Octavian.jl, but those are really monsters that non-experts can hardly learn anything from. Fortunately, some days ago I found the project GemmDemo.jl on GitHub, which only uses SIMD.jl and is quite simple to learn. Its core is a 12x4 microkernel that employs all 16 YMM registers to reduce memory access. Then I played with it, corrected some tiny bugs in it, removed its SIMD.jl dependency just for cleanliness, and combined it with stevengj’s cache-oblivious matrix multiplication code to get even better performance for large matrices. Now

  1. It is simple enough for everybody to play with;
  2. It has fairly good performance compared with MKL and Octavian on my four computers (ranging from Intel m3-8100Y to i9-11900K), tested in Julia 1.7.3 and 1.10.8;
  3. It uses no external packages, only Julia’s own ability to produce AVX2-optimized code, and so can keep working as long as Julia itself exists.
# Replicate the scalar `a` into a 4-tuple `(a, a, a, a)` so it can be combined
# element-wise with other 4-tuples (e.g. via broadcasted `fma`).
@inline vectorize(a) = ntuple(_ -> a, Val(4))

# Load the four consecutive elements A[x:x+3, y] as an NTuple{4,T} through a raw
# pointer. No bounds check is performed: the caller must guarantee that rows
# x..x+3 of column y exist. The offset arithmetic uses only `stride(A, 2)`, i.e. it
# treats elements within a column as contiguous.
@inline function v4load(A::AbstractMatrix{T}, x, y) where T
    # Element offset of A[x, y] relative to the first element of A
    offset = stride(A, 2) * (y - 1) + (x - 1)
    # Keep A rooted while its raw pointer is being dereferenced
    return GC.@preserve A unsafe_load(
        Base.unsafe_convert(Ptr{NTuple{4,T}}, A) + offset * sizeof(T))
end

# Store the 4-tuple `vT` into the four consecutive elements A[x:x+3, y] through a
# raw pointer and return that pointer (the result of `unsafe_store!`). No bounds
# check is performed: the caller must guarantee that rows x..x+3 of column y exist.
# The offset arithmetic uses only `stride(A, 2)`, i.e. it treats elements within a
# column as contiguous.
@inline function v4store!(A::AbstractMatrix{T}, x, y, vT) where T
    # Element offset of A[x, y] relative to the first element of A
    offset = stride(A, 2) * (y - 1) + (x - 1)
    # Keep A rooted while its raw pointer is being written through
    return GC.@preserve A unsafe_store!(
        Base.unsafe_convert(Ptr{NTuple{4,T}}, A) + offset * sizeof(T), vT)
end

# Scalar fallback for the part of a C block that cannot be covered by complete 12x4
# microtiles. Performs C[i, j] += A[i, p] * B[p, j] for every i in istart:iend,
# j in jstart:jend and p in pstart:pend. Empty ranges make it a no-op. Regular
# (bounds-checked) indexing is used because this path is cold.
@inline function _gemm_edge!(C,A,B, istart,iend, jstart,jend, pstart,pend)
    for j in jstart:jend
        for p in pstart:pend
            Bpj = B[p, j]
            for i in istart:iend
                C[i, j] = fma(A[i, p], Bpj, C[i, j])
            end
        end
    end
    return nothing
end

# Accumulate
#     C[istart:iend, jstart:jend] += A[istart:iend, pstart:pend] * B[pstart:pend, jstart:jend]
# in place and return `nothing`.
#
# The bulk of the block is processed by a 12x4 microkernel that keeps the 12x4 tile
# of C in twelve 4-element tuples while it sweeps over `p`; the p-range is walked in
# chunks of `cacheK`. Rows/columns that do not fill a complete 12x4 tile are handled
# by `_gemm_edge!`, so the row count need not be a multiple of 12 nor the column
# count a multiple of 4.
#
# The microkernel reads and writes A and C through raw pointers (`v4load` /
# `v4store!`), so the index ranges must lie inside the matrices.
@inline function gemm_nonpacking!(C,A,B, istart,iend, jstart,jend, pstart,pend, cacheK, ::Val{12},::Val{4})
    # Last row / column that is still covered by a complete microtile. A trailing
    # partial tile must never reach the microkernel: it always touches 12x4 entries.
    ifull = istart + 12 * (max(iend - istart + 1, 0) ÷ 12) - 1
    jfull = jstart + 4 * (max(jend - jstart + 1, 0) ÷ 4) - 1

    for macrok in Iterators.partition(pstart:pend, cacheK)
        for j in jstart:4:jfull
            for i in istart:12:ifull
                # load 12x4 block Cij into vector registers
                # load Cij[1:12, 1]
                Cij_r1c1vec = v4load(C, i, j); Cij_r2c1vec = v4load(C, i+4, j); Cij_r3c1vec = v4load(C, i+8, j)
                # load Cij[1:12, 2]
                Cij_r1c2vec = v4load(C, i, j+1); Cij_r2c2vec = v4load(C, i+4, j+1); Cij_r3c2vec = v4load(C, i+8, j+1)
                # load Cij[1:12, 3]
                Cij_r1c3vec = v4load(C, i, j+2); Cij_r2c3vec = v4load(C, i+4, j+2); Cij_r3c3vec = v4load(C, i+8, j+2)
                # load Cij[1:12, 4]
                Cij_r1c4vec = v4load(C, i, j+3); Cij_r2c4vec = v4load(C, i+4, j+3); Cij_r3c4vec = v4load(C, i+8, j+3)

                # update Cij with series of outer products
                # from the associated Aip and Bpj micropanels
                @inbounds for p in macrok
                    # update with A[i:(i+11), p] * B[p, j:(j+3)] outer product
                    Aip_vec1 = v4load(A, i, p); Aip_vec2 = v4load(A, i+4, p); Aip_vec3 = v4load(A, i+8, p)
                    # Cij[1:12, 1] += A[i:(i+11), p] * B[p, j]
                    Bpj_c1vec = vectorize(B[p, j])
                    Cij_r1c1vec = fma.(Aip_vec1, Bpj_c1vec, Cij_r1c1vec)
                    Cij_r2c1vec = fma.(Aip_vec2, Bpj_c1vec, Cij_r2c1vec)
                    Cij_r3c1vec = fma.(Aip_vec3, Bpj_c1vec, Cij_r3c1vec)
                    # Cij[1:12, 2] += A[i:(i+11), p] * B[p, j+1]
                    Bpj_c2vec = vectorize(B[p, j+1])
                    Cij_r1c2vec = fma.(Aip_vec1, Bpj_c2vec, Cij_r1c2vec)
                    Cij_r2c2vec = fma.(Aip_vec2, Bpj_c2vec, Cij_r2c2vec)
                    Cij_r3c2vec = fma.(Aip_vec3, Bpj_c2vec, Cij_r3c2vec)
                    # Cij[1:12, 3] += A[i:(i+11), p] * B[p, j+2]
                    Bpj_c3vec = vectorize(B[p, j+2])
                    Cij_r1c3vec = fma.(Aip_vec1, Bpj_c3vec, Cij_r1c3vec)
                    Cij_r2c3vec = fma.(Aip_vec2, Bpj_c3vec, Cij_r2c3vec)
                    Cij_r3c3vec = fma.(Aip_vec3, Bpj_c3vec, Cij_r3c3vec)
                    # Cij[1:12, 4] += A[i:(i+11), p] * B[p, j+3]
                    Bpj_c4vec = vectorize(B[p, j+3])
                    Cij_r1c4vec = fma.(Aip_vec1, Bpj_c4vec, Cij_r1c4vec)
                    Cij_r2c4vec = fma.(Aip_vec2, Bpj_c4vec, Cij_r2c4vec)
                    Cij_r3c4vec = fma.(Aip_vec3, Bpj_c4vec, Cij_r3c4vec)
                end

                # store Cij[1:12, 1]
                v4store!(C, i, j, Cij_r1c1vec); v4store!(C, i+4, j, Cij_r2c1vec); v4store!(C, i+8, j, Cij_r3c1vec)
                # store Cij[1:12, 2]
                v4store!(C, i, j+1, Cij_r1c2vec); v4store!(C, i+4, j+1, Cij_r2c2vec); v4store!(C, i+8, j+1, Cij_r3c2vec)
                # store Cij[1:12, 3]
                v4store!(C, i, j+2, Cij_r1c3vec); v4store!(C, i+4, j+2, Cij_r2c3vec); v4store!(C, i+8, j+2, Cij_r3c3vec)
                # store Cij[1:12, 4]
                v4store!(C, i, j+3, Cij_r1c4vec); v4store!(C, i+4, j+3, Cij_r2c4vec); v4store!(C, i+8, j+3, Cij_r3c4vec)
            end
        end
    end

    # Leftover rows (fewer than 12) below the tiled region, over the tiled columns
    _gemm_edge!(C,A,B, ifull+1,iend, jstart,jfull, pstart,pend)
    # Leftover columns (fewer than 4) right of the tiled region, over all rows
    _gemm_edge!(C,A,B, istart,iend, jfull+1,jend, pstart,pend)
    return nothing
end

# Cache-oblivious recursive multiply-accumulate. Performs
#     C[i0:i0+m-1, j0:j0+n-1] += A[i0:i0+m-1, k0:k0+p-1] * B[k0:k0+p-1, j0:j0+n-1]
# and returns C.
#
#   m, n, p     -- extents of the row (i), column (j) and inner (k) index ranges
#   i0, j0, k0  -- first index of each range
#
# Problems with m+n+p <= 3*120 go straight to `gemm_nonpacking!`; larger ones are
# halved along all three dimensions and the eight sub-problems are solved
# recursively.
function add_matmul_rec!(m,n,p, i0,j0,k0, C,A,B)
    # An empty sub-problem contributes nothing; stop instead of recursing on it
    (m <= 0 || n <= 0 || p <= 0) && return C
    if m+n+p <= 3*120
        gemm_nonpacking!(C,A,B, i0,i0+m-1, j0,j0+n-1, k0,k0+p-1, 60, Val(12),Val(4))
    else
        # Round the row/column split points down to multiples of the 12x4 microtile
        # so that a block whose extents are multiples of 12 and 4 is only ever cut
        # into sub-blocks with the same property. A plain `÷ 2` can produce e.g.
        # 63-row sub-blocks out of 252 rows, which no longer fit the kernel tiling.
        m2 = 12 * ((m ÷ 2) ÷ 12); n2 = 4 * ((n ÷ 2) ÷ 4); p2 = p ÷ 2
        add_matmul_rec!(m2,n2,p2, i0,j0,k0, C,A,B)
        add_matmul_rec!(m-m2,n2,p2, i0+m2,j0,k0, C,A,B)
        add_matmul_rec!(m2,n-n2,p2, i0,j0+n2,k0, C,A,B)
        add_matmul_rec!(m2,n2,p-p2, i0,j0,k0+p2, C,A,B)
        add_matmul_rec!(m-m2,n-n2,p2, i0+m2,j0+n2,k0, C,A,B)
        add_matmul_rec!(m2,n-n2,p-p2, i0,j0+n2,k0+p2, C,A,B)
        add_matmul_rec!(m-m2,n2,p-p2, i0+m2,j0,k0+p2, C,A,B)
        add_matmul_rec!(m-m2,n-n2,p-p2, i0+m2,j0+n2,k0+p2, C,A,B)
    end
    return C
end

using BenchmarkTools, LinearAlgebra

# Restrict BLAS to one thread before timing it against the pure-Julia routine.
BLAS.set_num_threads(1)

# For square matrices of side 60*2^i, time `mul!` and then `add_matmul_rec!`.
for i in 2:6
    side = 60 * 2^i
    A = randn(side, side)
    B = randn(side, side)
    C = zeros(size(A))
    print("Size: 60*$(2^i)")
    @btime mul!($C, $A, $B)
    m, n = size(C)
    p = size(A, 2)
    @btime add_matmul_rec!($m, $n, $p, 1, 1, 1, $C, $A, $B)
end

The benchmark result on an i9-11900K with Win 11 is

Size: 60*4  386.900 μs (0 allocations: 0 bytes)
  409.900 μs (0 allocations: 0 bytes)
Size: 60*8  3.019 ms (0 allocations: 0 bytes)
  3.433 ms (0 allocations: 0 bytes)
Size: 60*16  25.518 ms (0 allocations: 0 bytes)
  31.524 ms (0 allocations: 0 bytes)
Size: 60*32  212.340 ms (0 allocations: 0 bytes)
  267.145 ms (0 allocations: 0 bytes)
Size: 60*64  1.693 s (0 allocations: 0 bytes)
  2.206 s (0 allocations: 0 bytes)

Using the Strassen algorithm on top of this kernel:

# Overwrite every entry of `A` with the zero of its element type and return `A`.
@inline function clean!(A)
    z = zero(eltype(A))
    return fill!(A, z)
end

# Strassen multiply-accumulate: adds A*B to C in place and returns C.
#
#   cA, cB, cC -- scratch matrices. At each level the leading n-by-m / m-by-p /
#                 n-by-p corner holds the operand sums and the intermediate product;
#                 a disjoint part of the same buffers is handed to the next level.
#
# Small problems (m+n+p <= 3*240 for the half-sizes n, m, p) are delegated to
# `add_matmul_rec!`. The recursive branch adds/subtracts quadrant views
# element-wise, so it assumes all dimensions are even at every level that recurses.
function Strassen!(C, A, B, cA, cB, cC)
    n,m=size(A) .÷ 2; _,p=size(B) .÷ 2
    if m+n+p <= 3*240
        # add_matmul_rec! takes (rows of C, columns of C, inner dimension). Using the
        # actual sizes keeps rectangular and odd-sized operands correct here.
        add_matmul_rec!(size(C,1),size(C,2),size(A,2), 1,1,1, C,A,B)
    else
        # Quadrants of the operands and of the result
        A11=@view A[1:n,1:m]; A12=@view A[1:n,(m+1):end]; A21=@view A[(n+1):end,1:m]; A22=@view A[(n+1):end,(m+1):end]
        B11=@view B[1:m,1:p]; B12=@view B[1:m,(p+1):end]; B21=@view B[(m+1):end,1:p]; B22=@view B[(m+1):end,(p+1):end]
        C11=@view C[1:n,1:p]; C12=@view C[1:n,(p+1):end]; C21=@view C[(n+1):end,1:p]; C22=@view C[(n+1):end,(p+1):end]
        # Scratch for this level (cA2, cB2, cC2) and for the levels below (ccA, ccB, ccC)
        cA2=@view cA[1:n,1:m]; cB2=@view cB[1:m,1:p]; ccA=@view cA[1:n ÷ 2,(m+1):end]; ccB=@view cB[1:m ÷ 2,(p+1):end]
        cC2=@view cC[1:n,1:p]; ccC=@view cC[1:n ÷ 2,(p+1):end]
        # M1 = (A11+A22)(B11+B22) -> C11, C22
        @. cA2=A11+A22; @. cB2=B11+B22; Strassen!(clean!(cC2),cA2,cB2,ccA,ccB,ccC); @. C11+=cC2; @. C22+=cC2
        # M2 = (A21+A22) B11 -> C21, -C22
        @. cA2=A21+A22; Strassen!(clean!(cC2),cA2,B11,ccA,ccB,ccC); @. C21+=cC2; @. C22-=cC2
        # M3 = A11 (B12-B22) -> C12, C22
        @. cB2=B12-B22; Strassen!(clean!(cC2),A11,cB2,ccA,ccB,ccC); @. C12+=cC2; @. C22+=cC2
        # M4 = A22 (B21-B11) -> C11, C21
        @. cB2=B21-B11; Strassen!(clean!(cC2),A22,cB2,ccA,ccB,ccC); @. C11+=cC2; @. C21+=cC2
        # M5 = (A11+A12) B22 -> -C11, C12
        @. cA2=A11+A12; Strassen!(clean!(cC2),cA2,B22,ccA,ccB,ccC); @. C11-=cC2; @. C12+=cC2
        # M6 and M7 each feed a single quadrant, so they accumulate straight into it
        @. cA2=A21-A11; @. cB2=B11+B12; Strassen!(C22,cA2,cB2,ccA,ccB,ccC)
        @. cA2=A12-A22; @. cB2=B21+B22; Strassen!(C11,cA2,cB2,ccA,ccB,ccC)
    end
    return C
end

# Time `Strassen!` on square matrices of side 60*2^i.
for i in 2:6
    side = 60 * 2^i
    A = randn(side, side)
    B = randn(side, side)
    C = zeros(size(A))
    # One scratch matrix per operand: half its rows, all of its columns
    cA, cB, cC = (zeros(size(M, 1) ÷ 2, size(M, 2)) for M in (A, B, C))
    print("Size: 60*$(2^i)")
    @btime Strassen!($C, $A, $B, $cA, $cB, $cC)
end

the result is

Size: 60*4  408.900 μs (0 allocations: 0 bytes)
Size: 60*8  3.433 ms (0 allocations: 0 bytes)
Size: 60*16  30.694 ms (0 allocations: 0 bytes)
Size: 60*32  243.806 ms (0 allocations: 0 bytes)
Size: 60*64  1.879 s (0 allocations: 0 bytes)

The efficiency gap between this demo code and MKL.jl remains around 10%.
I know there are many experts here, so I would be very happy to hear about any suggestions for further optimization or generalization.

16 Likes

The realization of the Strassen algorithm in the OP is simple but does not make use of the fact that the buffers for each layer of Strassen multiplication can have a contiguous memory layout. The following is an improved version:

# Overwrite every entry of `A` with the zero of its element type and return `A`.
@inline function clean!(A)
    z = zero(eltype(A))
    return fill!(A, z)
end

# Wrap the memory of `A` that begins at its `start`-th element (1-based) in an
# `Array` of shape `dims` via `unsafe_wrap`; no data is copied. Nothing checks that
# the requested window fits inside `A`, and the returned array does not keep `A`
# alive, so both are the caller's responsibility.
@inline function unsafe_reshape(A::AbstractArray{T}, start, dims) where T
    byteoffset = (start - 1) * sizeof(T)
    return unsafe_wrap(Array, pointer(A) + byteoffset, dims)
end

# Strassen multiply-accumulate with flat scratch buffers: adds A*B to C in place
# and returns C.
#
#   cA, cB, cC       -- scratch buffers shared by all recursion levels
#   indA, indB, indC -- 1-based position in each buffer where this level's scratch
#                       starts; deeper levels start right behind it
#
# Small problems (m+n+p <= 3*240 for the half-sizes n, m, p) are delegated to
# `add_matmul_rec!`. The recursive branch adds/subtracts quadrant views
# element-wise, so it assumes all dimensions are even at every level that recurses.
function Strassen!(C, A, B, cA, indA, cB, indB, cC, indC)
    n,m=size(A) .÷ 2; _,p=size(B) .÷ 2
    if m+n+p <= 3*240
        # add_matmul_rec! takes (rows of C, columns of C, inner dimension). Using the
        # actual sizes keeps rectangular and odd-sized operands correct here.
        add_matmul_rec!(size(C,1),size(C,2),size(A,2), 1,1,1, C,A,B)
    else
        # Quadrants of the operands and of the result
        A11=@view A[1:n,1:m]; A12=@view A[1:n,(m+1):end]; A21=@view A[(n+1):end,1:m]; A22=@view A[(n+1):end,(m+1):end]
        B11=@view B[1:m,1:p]; B12=@view B[1:m,(p+1):end]; B21=@view B[(m+1):end,1:p]; B22=@view B[(m+1):end,(p+1):end]
        C11=@view C[1:n,1:p]; C12=@view C[1:n,(p+1):end]; C21=@view C[(n+1):end,1:p]; C22=@view C[(n+1):end,(p+1):end]
        # This level's scratch: windows of n*m, m*p and n*p elements of the buffers
        cA2=unsafe_reshape(cA, indA, (n, m)); cB2=unsafe_reshape(cB, indB, (m, p))
        cC2=unsafe_reshape(cC, indC, (n, p))
        # M1 = (A11+A22)(B11+B22) -> C11, C22
        @. cA2=A11+A22; @. cB2=B11+B22; Strassen!(clean!(cC2),cA2,cB2, cA,indA + n*m, cB,indB + m*p, cC,indC + n*p); @. C11+=cC2; @. C22+=cC2
        # M2 = (A21+A22) B11 -> C21, -C22
        @. cA2=A21+A22; Strassen!(clean!(cC2),cA2,B11, cA,indA + n*m, cB,indB + m*p, cC,indC + n*p); @. C21+=cC2; @. C22-=cC2
        # M3 = A11 (B12-B22) -> C12, C22
        @. cB2=B12-B22; Strassen!(clean!(cC2),A11,cB2, cA,indA + n*m, cB,indB + m*p, cC,indC + n*p); @. C12+=cC2; @. C22+=cC2
        # M4 = A22 (B21-B11) -> C11, C21
        @. cB2=B21-B11; Strassen!(clean!(cC2),A22,cB2, cA,indA + n*m, cB,indB + m*p, cC,indC + n*p); @. C11+=cC2; @. C21+=cC2
        # M5 = (A11+A12) B22 -> -C11, C12
        @. cA2=A11+A12; Strassen!(clean!(cC2),cA2,B22, cA,indA + n*m, cB,indB + m*p, cC,indC + n*p); @. C11-=cC2; @. C12+=cC2
        # M6 and M7 each feed a single quadrant, so they accumulate straight into it
        @. cA2=A21-A11; @. cB2=B11+B12; Strassen!(C22,cA2,cB2, cA,indA + n*m, cB,indB + m*p, cC,indC + n*p)
        @. cA2=A12-A22; @. cB2=B21+B22; Strassen!(C11,cA2,cB2, cA,indA + n*m, cB,indB + m*p, cC,indC + n*p)
    end
    return C
end

# Time the flat-buffer `Strassen!` on square matrices of side 60*2^i.
for i in 3:6
    side = 60 * 2^i
    A = randn(side, side)
    B = randn(side, side)
    C = zeros(size(A))
    # One flat scratch vector per operand, a third of that operand's length
    cA, cB, cC = (zeros(length(M) ÷ 3) for M in (A, B, C))
    print("Size: 60*$(2^i)")
    @btime Strassen!($C, $A, $B, $cA, 1, $cB, 1, $cC, 1)
end

As a side advantage, the buffer size for a matrix M can be reduced from length(M)/2 to length(M)/3. I am not clear about why there are a few allocations, but anyway, it is faster:

Size: 60*8  3.499 ms (0 allocations: 0 bytes)
Size: 60*16  29.562 ms (6 allocations: 240 bytes)
Size: 60*32  235.413 ms (48 allocations: 1.88 KiB)
Size: 60*64  1.757 s (342 allocations: 13.36 KiB)

I’m surprised that the optimal threshold for Strassen is not larger than 1000x1000?

In my experiments, the optimal threshold is about 400x400 (definitely depending on the machine). There is probably quite some room to optimize my Strassen code. Once it is sufficiently optimized, I guess the threshold could be as low as 200x200 or so.