I think many people are curious about how a GEMM kernel written in pure Julia can be much more performant than a naively written one. One can inspect the source code of LoopVectorization.jl, of course, or the even more professional Octavian.jl, but those are real monsters, and it is hard for non-experts to learn anything from them. Fortunately, some days ago I found the project GemmDemo.jl on GitHub, which uses only SIMD.jl and is quite simple to learn from. Its core is a 12x4 microkernel that employs all 16 YMM registers to reduce memory access. I played with it, corrected some tiny bugs, removed its SIMD.jl dependency just for cleanliness, and combined it with stevengj’s cache-oblivious matrix multiplication code to get even better performance for large matrices. Now:
- It is simple enough for everybody to play with;
- It has fairly good performance compared with MKL and Octavian on my four computers (ranging from Intel m3-8100Y to i9-11900K), tested in Julia 1.7.3 and 1.10.8;
- It uses no external packages, relying only on Julia’s ability to produce AVX2-optimized code, so it can keep working as long as Julia itself exists.
"""
    vectorize(a)

Return the 4-tuple `(a, a, a, a)`, i.e. `a` replicated into every slot of a
4-element tuple so it can be combined element-wise with a tuple from `v4load`.
"""
@inline vectorize(a) = ntuple(_ -> a, Val(4))
"""
    v4load(A, x, y)

Load the four consecutive elements `A[x, y], A[x+1, y], A[x+2, y], A[x+3, y]` as an
`NTuple{4,eltype(A)}` with a single raw-pointer read.

The address is computed as `stride(A, 2) * (y - 1) + (x - 1)` elements past the start of
`A`, so `A` must have unit stride along its first dimension. No bounds check is performed:
the caller must guarantee that rows `x:x+3` of column `y` exist.
"""
@inline function v4load(A, x, y)
    T = eltype(A)
    # Byte offset of A[x, y] from the first element of A.
    offset = (stride(A, 2) * (y - 1) + (x - 1)) * sizeof(T)
    # Keep A rooted while its raw pointer is dereferenced.
    return GC.@preserve A unsafe_load(Base.unsafe_convert(Ptr{NTuple{4,T}}, A) + offset)
end
"""
    v4store!(A, x, y, vT)

Write the 4-tuple `vT` into `A[x, y], A[x+1, y], A[x+2, y], A[x+3, y]` with a single
raw-pointer store.

The address is computed as `stride(A, 2) * (y - 1) + (x - 1)` elements past the start of
`A`, so `A` must have unit stride along its first dimension. No bounds check is performed:
the caller must guarantee that rows `x:x+3` of column `y` exist. Returns whatever
`unsafe_store!` returns.
"""
@inline function v4store!(A, x, y, vT)
    T = eltype(A)
    # Byte offset of A[x, y] from the first element of A.
    offset = (stride(A, 2) * (y - 1) + (x - 1)) * sizeof(T)
    # Keep A rooted while its raw pointer is written through.
    return GC.@preserve A unsafe_store!(Base.unsafe_convert(Ptr{NTuple{4,T}}, A) + offset, vT)
end
"""
    gemm_nonpacking!(C, A, B, istart, iend, jstart, jend, pstart, pend, cacheK, Val(12), Val(4))

Accumulate, in place,
`C[istart:iend, jstart:jend] += A[istart:iend, pstart:pend] * B[pstart:pend, jstart:jend]`.

The `p` range is processed in chunks of at most `cacheK`; for every chunk the `i`/`j`
ranges are walked in 12x4 tiles. A full tile of `C` is loaded as twelve 4-element tuples,
updated with one `fma` outer-product step per `p`, and stored back.

Tiles that are cut short because the length of the `i` or `j` range is not a multiple of
12 or 4 are updated by a bounds-checked scalar loop instead, so the unchecked 4-wide
loads/stores never touch memory outside the requested block.

Full tiles go through `v4load`/`v4store!`, which address memory from `stride(X, 2)` alone;
`C` and `A` must therefore have unit stride along their first dimension.
Returns `nothing`.
"""
@inline function gemm_nonpacking!(C, A, B, istart, iend, jstart, jend, pstart, pend, cacheK,
                                  ::Val{12}, ::Val{4})
    igroups = Iterators.partition(istart:iend, 12)
    jgroups = Iterators.partition(jstart:jend, 4)
    for macrok in Iterators.partition(pstart:pend, cacheK)
        for macroj in jgroups
            j = macroj[1]
            for macroi in igroups
                i = macroi[1]

                if length(macroi) != 12 || length(macroj) != 4
                    # Ragged edge tile: `Iterators.partition` returns a shorter final
                    # chunk, and the 12x4 register kernel below would read and write
                    # past it. Fall back to a scalar update with the same per-element
                    # accumulation order over `p`.
                    for jj in macroj, p in macrok
                        Bpj = B[p, jj]
                        for ii in macroi
                            C[ii, jj] = fma(A[ii, p], Bpj, C[ii, jj])
                        end
                    end
                    continue
                end

                # load 12x4 block Cij, one column (three 4-tuples) at a time
                # Cij[1:12, 1]
                Cij_r1c1vec = v4load(C, i, j)
                Cij_r2c1vec = v4load(C, i + 4, j)
                Cij_r3c1vec = v4load(C, i + 8, j)
                # Cij[1:12, 2]
                Cij_r1c2vec = v4load(C, i, j + 1)
                Cij_r2c2vec = v4load(C, i + 4, j + 1)
                Cij_r3c2vec = v4load(C, i + 8, j + 1)
                # Cij[1:12, 3]
                Cij_r1c3vec = v4load(C, i, j + 2)
                Cij_r2c3vec = v4load(C, i + 4, j + 2)
                Cij_r3c3vec = v4load(C, i + 8, j + 2)
                # Cij[1:12, 4]
                Cij_r1c4vec = v4load(C, i, j + 3)
                Cij_r2c4vec = v4load(C, i + 4, j + 3)
                Cij_r3c4vec = v4load(C, i + 8, j + 3)

                # update Cij with a series of outer products
                # from the associated Aip and Bpj micropanels
                @inbounds for p in macrok
                    # A[i:(i+11), p], reused for all four columns of the tile
                    Aip_vec1 = v4load(A, i, p)
                    Aip_vec2 = v4load(A, i + 4, p)
                    Aip_vec3 = v4load(A, i + 8, p)
                    # Cij[1:12, 1] += A[i:(i+11), p] * B[p, j]
                    Bpj_c1vec = vectorize(B[p, j])
                    Cij_r1c1vec = fma.(Aip_vec1, Bpj_c1vec, Cij_r1c1vec)
                    Cij_r2c1vec = fma.(Aip_vec2, Bpj_c1vec, Cij_r2c1vec)
                    Cij_r3c1vec = fma.(Aip_vec3, Bpj_c1vec, Cij_r3c1vec)
                    # Cij[1:12, 2] += A[i:(i+11), p] * B[p, j+1]
                    Bpj_c2vec = vectorize(B[p, j + 1])
                    Cij_r1c2vec = fma.(Aip_vec1, Bpj_c2vec, Cij_r1c2vec)
                    Cij_r2c2vec = fma.(Aip_vec2, Bpj_c2vec, Cij_r2c2vec)
                    Cij_r3c2vec = fma.(Aip_vec3, Bpj_c2vec, Cij_r3c2vec)
                    # Cij[1:12, 3] += A[i:(i+11), p] * B[p, j+2]
                    Bpj_c3vec = vectorize(B[p, j + 2])
                    Cij_r1c3vec = fma.(Aip_vec1, Bpj_c3vec, Cij_r1c3vec)
                    Cij_r2c3vec = fma.(Aip_vec2, Bpj_c3vec, Cij_r2c3vec)
                    Cij_r3c3vec = fma.(Aip_vec3, Bpj_c3vec, Cij_r3c3vec)
                    # Cij[1:12, 4] += A[i:(i+11), p] * B[p, j+3]
                    Bpj_c4vec = vectorize(B[p, j + 3])
                    Cij_r1c4vec = fma.(Aip_vec1, Bpj_c4vec, Cij_r1c4vec)
                    Cij_r2c4vec = fma.(Aip_vec2, Bpj_c4vec, Cij_r2c4vec)
                    Cij_r3c4vec = fma.(Aip_vec3, Bpj_c4vec, Cij_r3c4vec)
                end

                # store the updated tile back, column by column
                # Cij[1:12, 1]
                v4store!(C, i, j, Cij_r1c1vec)
                v4store!(C, i + 4, j, Cij_r2c1vec)
                v4store!(C, i + 8, j, Cij_r3c1vec)
                # Cij[1:12, 2]
                v4store!(C, i, j + 1, Cij_r1c2vec)
                v4store!(C, i + 4, j + 1, Cij_r2c2vec)
                v4store!(C, i + 8, j + 1, Cij_r3c2vec)
                # Cij[1:12, 3]
                v4store!(C, i, j + 2, Cij_r1c3vec)
                v4store!(C, i + 4, j + 2, Cij_r2c3vec)
                v4store!(C, i + 8, j + 2, Cij_r3c3vec)
                # Cij[1:12, 4]
                v4store!(C, i, j + 3, Cij_r1c4vec)
                v4store!(C, i + 4, j + 3, Cij_r2c4vec)
                v4store!(C, i + 8, j + 3, Cij_r3c4vec)
            end
        end
    end
    return nothing
end
"""
    add_matmul_rec!(m, n, p, i0, j0, k0, C, A, B)

Accumulate the `m`-by-`n` block of `C` starting at `(i0, j0)` with the product of the
`m`-by-`p` block of `A` starting at `(i0, k0)` and the `p`-by-`n` block of `B` starting at
`(k0, j0)`.

When `m + n + p <= 3 * 120` the block is handed to `gemm_nonpacking!` (with `cacheK = 60`
and the 12x4 kernel). Otherwise every dimension is halved and the eight resulting
sub-products are processed recursively. Returns `C`.
"""
function add_matmul_rec!(m, n, p, i0, j0, k0, C, A, B)
    # Small enough: run the kernel directly on this block.
    if m + n + p <= 3 * 120
        gemm_nonpacking!(C, A, B,
                         i0, i0 + m - 1,
                         j0, j0 + n - 1,
                         k0, k0 + p - 1,
                         60, Val(12), Val(4))
        return C
    end

    # Extents of the low/high halves and the start index of each high half.
    mlo = m ÷ 2
    mhi = m - mlo
    ihi = i0 + mlo
    nlo = n ÷ 2
    nhi = n - nlo
    jhi = j0 + nlo
    plo = p ÷ 2
    phi = p - plo
    khi = k0 + plo

    add_matmul_rec!(mlo, nlo, plo, i0,  j0,  k0,  C, A, B)
    add_matmul_rec!(mhi, nlo, plo, ihi, j0,  k0,  C, A, B)
    add_matmul_rec!(mlo, nhi, plo, i0,  jhi, k0,  C, A, B)
    add_matmul_rec!(mlo, nlo, phi, i0,  j0,  khi, C, A, B)
    add_matmul_rec!(mhi, nhi, plo, ihi, jhi, k0,  C, A, B)
    add_matmul_rec!(mlo, nhi, phi, i0,  jhi, khi, C, A, B)
    add_matmul_rec!(mhi, nlo, phi, ihi, j0,  khi, C, A, B)
    add_matmul_rec!(mhi, nhi, phi, ihi, jhi, khi, C, A, B)
    return C
end
using BenchmarkTools, LinearAlgebra
# Restrict BLAS to one thread; the hand-written kernel above runs serially.
BLAS.set_num_threads(1)
for i in 2:6
    dim = 60 * 2^i
    A = randn(dim, dim)
    B = randn(dim, dim)
    # Zero-initialised rather than `similar(A)`: add_matmul_rec! accumulates into C.
    C = zeros(size(A))
    print("Size: 60*$(2^i)")
    # Reference timing: LinearAlgebra's in-place multiply.
    @btime mul!($C, $A, $B)
    m, n = size(C)
    p = size(A, 2)
    # Timing of the recursive pure-Julia implementation on the same operands.
    @btime add_matmul_rec!($m, $n, $p, 1, 1, 1, $C, $A, $B)
end
The benchmark results on an i9-11900K with Windows 11 are as follows (for each size, the first timing is `mul!` and the second is `add_matmul_rec!`):
Size: 60*4 386.900 μs (0 allocations: 0 bytes)
409.900 μs (0 allocations: 0 bytes)
Size: 60*8 3.019 ms (0 allocations: 0 bytes)
3.433 ms (0 allocations: 0 bytes)
Size: 60*16 25.518 ms (0 allocations: 0 bytes)
31.524 ms (0 allocations: 0 bytes)
Size: 60*32 212.340 ms (0 allocations: 0 bytes)
267.145 ms (0 allocations: 0 bytes)
Size: 60*64 1.693 s (0 allocations: 0 bytes)
2.206 s (0 allocations: 0 bytes)
The Strassen algorithm can be layered on top of this:
"""
    clean!(A)

Overwrite every element of `A` with `zero(eltype(A))` and return `A`.
"""
function clean!(A)
    z = zero(eltype(A))
    return fill!(A, z)
end
"""
    Strassen!(C, A, B, cA, cB, cC)

Accumulate `C += A * B` in place using Strassen's seven-product recursion, falling back to
`add_matmul_rec!` once the halved dimensions satisfy `m + n + p <= 3 * 240`.

Here `n, m = size(A) .÷ 2` and `p = size(B, 2) ÷ 2`, so `A` is `2n`-by-`2m`, `B` is
`2m`-by-`2p` and `C` is `2n`-by-`2p`. Dimensions are halved with integer division, so each
dimension is expected to be even at every level that recurses.

`cA`, `cB`, `cC` are scratch matrices. At each level the leading `n`-by-`m`, `m`-by-`p` and
`n`-by-`p` corners hold the operand sums and the product being formed, while the columns
past `m` (resp. `p`) are passed down as the next level's scratch. Returns `C`.
"""
function Strassen!(C, A, B, cA, cB, cC)
    n, m = size(A) .÷ 2
    _, p = size(B) .÷ 2
    if m + n + p <= 3 * 240
        # add_matmul_rec! expects (rows of C, cols of C, inner dimension), which in this
        # function's naming is (2n, 2p, 2m). Passing (2m, 2n, 2p) is only right when all
        # three are equal.
        add_matmul_rec!(2n, 2p, 2m, 1, 1, 1, C, A, B)
    else
        # Quadrants of the operands and of the result.
        A11 = @view A[1:n, 1:m]
        A12 = @view A[1:n, (m+1):end]
        A21 = @view A[(n+1):end, 1:m]
        A22 = @view A[(n+1):end, (m+1):end]
        B11 = @view B[1:m, 1:p]
        B12 = @view B[1:m, (p+1):end]
        B21 = @view B[(m+1):end, 1:p]
        B22 = @view B[(m+1):end, (p+1):end]
        C11 = @view C[1:n, 1:p]
        C12 = @view C[1:n, (p+1):end]
        C21 = @view C[(n+1):end, 1:p]
        C22 = @view C[(n+1):end, (p+1):end]
        # Scratch for this level (cA2, cB2, cC2) and for the next one (ccA, ccB, ccC).
        cA2 = @view cA[1:n, 1:m]
        cB2 = @view cB[1:m, 1:p]
        cC2 = @view cC[1:n, 1:p]
        ccA = @view cA[1:(n ÷ 2), (m+1):end]
        ccB = @view cB[1:(m ÷ 2), (p+1):end]
        ccC = @view cC[1:(n ÷ 2), (p+1):end]

        # M1 = (A11 + A22) * (B11 + B22):  C11 += M1, C22 += M1
        @. cA2 = A11 + A22
        @. cB2 = B11 + B22
        Strassen!(clean!(cC2), cA2, cB2, ccA, ccB, ccC)
        @. C11 += cC2
        @. C22 += cC2

        # M2 = (A21 + A22) * B11:  C21 += M2, C22 -= M2
        @. cA2 = A21 + A22
        Strassen!(clean!(cC2), cA2, B11, ccA, ccB, ccC)
        @. C21 += cC2
        @. C22 -= cC2

        # M3 = A11 * (B12 - B22):  C12 += M3, C22 += M3
        @. cB2 = B12 - B22
        Strassen!(clean!(cC2), A11, cB2, ccA, ccB, ccC)
        @. C12 += cC2
        @. C22 += cC2

        # M4 = A22 * (B21 - B11):  C11 += M4, C21 += M4
        @. cB2 = B21 - B11
        Strassen!(clean!(cC2), A22, cB2, ccA, ccB, ccC)
        @. C11 += cC2
        @. C21 += cC2

        # M5 = (A11 + A12) * B22:  C11 -= M5, C12 += M5
        @. cA2 = A11 + A12
        Strassen!(clean!(cC2), cA2, B22, ccA, ccB, ccC)
        @. C11 -= cC2
        @. C12 += cC2

        # M6 = (A21 - A11) * (B11 + B12) is accumulated straight into C22.
        @. cA2 = A21 - A11
        @. cB2 = B11 + B12
        Strassen!(C22, cA2, cB2, ccA, ccB, ccC)

        # M7 = (A12 - A22) * (B21 + B22) is accumulated straight into C11.
        @. cA2 = A12 - A22
        @. cB2 = B21 + B22
        Strassen!(C11, cA2, cB2, ccA, ccB, ccC)
    end
    return C
end
# Time Strassen! on square matrices of side 60*2^i.
for i in 2:6
    dim = 60 * 2^i
    A = randn(dim, dim)
    B = randn(dim, dim)
    # Zero-initialised rather than `similar(A)`.
    C = zeros(size(A))
    # Scratch buffers: half the rows and all the columns of the matching matrix.
    cA = zeros(size(A, 1) ÷ 2, size(A, 2))
    cB = zeros(size(B, 1) ÷ 2, size(B, 2))
    cC = zeros(size(C, 1) ÷ 2, size(C, 2))
    print("Size: 60*$(2^i)")
    @btime Strassen!($C, $A, $B, $cA, $cB, $cC)
end
The result is:
Size: 60*4 408.900 μs (0 allocations: 0 bytes)
Size: 60*8 3.433 ms (0 allocations: 0 bytes)
Size: 60*16 30.694 ms (0 allocations: 0 bytes)
Size: 60*32 243.806 ms (0 allocations: 0 bytes)
Size: 60*64 1.879 s (0 allocations: 0 bytes)
The efficiency gap between this demo code and MKL.jl remains around 10%.
I know there are many experts here, so I would be very happy to hear any suggestions for further optimization or generalization.