Batched matrix-multiplication optimization

How can matrix multiplication in a for loop be optimized? Here is an example from my code:
original case:

Z = Array{ComplexF64,3}(undef,512,8,24192);
X = Array{ComplexF64,3}(undef,512,16,24192);
Y = Array{ComplexF64,3}(undef,16,8,24192);
Len  = 24192;
@time begin
    @inbounds for n = 1 : Len
        Z[:,:,n] .= X[:,:,n] * Y[:,:,n] # the slices X[:,:,n] and Y[:,:,n] allocate copies, and * allocates the product
    end
end
# 4.271584 seconds (407.18 k allocations: 4.486 GiB, 23.31% gc time)

Using mul! (from LinearAlgebra):


using LinearAlgebra

Z = Array{ComplexF64,3}(undef,512,8,24192);
X = Array{ComplexF64,3}(undef,512,16,24192);
Y = Array{ComplexF64,3}(undef,16,8,24192);
Len  = 24192;
@time begin
    for n = 1 : Len
        mul!(view(Z,:,:,n) , view(X,:,:,n) , view(Y,:,:,n))
    end
end
# 2.296982 seconds (382.99 k allocations: 10.643 MiB)

I wonder if there are any methods to optimize this further. This loop will be run many times in my project.

View operations create matrix slices at very low cost: each view creates a SubArray with two allocations. But here’s the catch: in your code the loop runs 24,192 times with 3 views per iteration, so in theory there are 145,152 SubArray allocations in total. The performance impact of the view operations therefore cannot be ignored.

To avoid this, I suggest writing a matrix multiplication function by hand to do the loop. The bad news is that you can no longer use the OpenBLAS-backed mul! function; the good news is that the Julia language is fast enough on its own, so you can implement the underlying computation yourself without sacrificing performance.

I used ChatGPT to help me write the following code, which implements zero-allocation 3D matrix multiplication. However, I haven’t thoroughly checked whether it’s correct.

function batch_matmul(Z::Array{T,3}, X::Array{T,3}, Y::Array{T,3}) where {T}
    @views for n = 1:Len
        mul!(Z[:, :, n], X[:, :, n], Y[:, :, n])
    end
end
@time batch_matmul(Z, X, Y)
# 1.429842 seconds (382.99 k allocations: 10.643 MiB)

function batch_matmul!(Z::Array{T,3}, X::Array{T,3}, Y::Array{T,3}) where {T}
    M, P, Len = size(X) # size(X) = M×P×Len
    _, N, _ = size(Y)   # size(Y) = P×N×Len, size(Z) = M×N×Len
    @assert size(Z) == (M, N, Len)

    mP = M * P
    pN = P * N
    mN = M * N

    @inbounds for q = 1:Len
        offX = (q - 1) * mP
        offY = (q - 1) * pN
        offZ = (q - 1) * mN
        for j = 1:N
            idxZj = offZ + (j - 1) * M
            idxYj = offY + (j - 1) * P
            @simd for i = 1:M
                s = zero(T)
                baseX = offX + i
                for k = 1:P
                    # X[i,k,q] linear index = offX + (k-1)*M + i
                    # Y[k,j,q] linear index = offY + (j-1)*P + k
                    s += X[baseX+(k-1)*M] * Y[idxYj+k]
                end
                Z[idxZj+i] = s
            end
        end
    end

    return Z
end
@time batch_matmul!(Z, X, Y);
# 1.174894 seconds
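For what it’s worth, here is a quick way to spot-check such a hand-written kernel against the plain matrix product (a sketch; Xr, Yr, Zr are fresh, initialized arrays introduced just for the check):

Xr = randn(ComplexF64, 512, 16, 24192);
Yr = randn(ComplexF64, 16, 8, 24192);
Zr = similar(Xr, 512, 8, 24192);
batch_matmul!(Zr, Xr, Yr);
# compare a couple of slices against the ordinary matrix product
@assert Zr[:,:,1] ≈ Xr[:,:,1] * Yr[:,:,1]
@assert Zr[:,:,end] ≈ Xr[:,:,end] * Yr[:,:,end]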

I would go with Tullio.jl + LoopVectorization.jl

using Tullio, LoopVectorization, LinearAlgebra

Z = Array{ComplexF64,3}(undef,512,8,24192);
X = Array{ComplexF64,3}(undef,512,16,24192);
Y = Array{ComplexF64,3}(undef,16,8,24192);


function mul1!(Z,X,Y)
    Len  = size(Z,3)
    @inbounds for n = 1 : Len
        mul!(view(Z,:,:,n) , view(X,:,:,n) , view(Y,:,:,n))
    end
    return Z
end

function mul2!(Z,X,Y)
    @tullio Z[i,j,k] = X[i,l,k] * Y[l,j,k]
    return Z
end

using BenchmarkTools
@btime mul1!($Z,$X,$Y);
@btime mul2!($Z,$X,$Y);

gives 1.1 s and 0.6675 s. Never mind: LV doesn’t actually trigger here because of the complex numbers; Tullio wins because of its threading. But you can just add Threads.@threads to your own loop and win (0.4 s).
The best I got is this:

function mul2!(Z,X,Y)
    N1, N2, N3 = size(Z)
    N4 = size(X,2)
    Threads.@threads for n in 1:N3
        for j in 1:N2
            for i in 1:N1
                s = zero(ComplexF64)
                @inbounds @simd for l in 1:N4
                    s += X[i,l,n] * Y[l,j,n]
                end
                @inbounds Z[i,j,n] = s
            end
        end
    end
    return Z
end

This is incorrect: strided views are allocation-free. The allocations you’re seeing in your code come from your function referencing the untyped global Len. Here’s an implementation avoiding this problem:

julia> using LinearAlgebra, BenchmarkTools

julia> Z = Array{ComplexF64,3}(undef, 512, 8, 24192);

julia> X = randn(ComplexF64, 512, 16, 24192);

julia> Y = randn(ComplexF64, 16, 8, 24192);

julia> function batchmul!(Z, X, Y)
           indices = axes(Z, 3)
           if (axes(X, 3) != indices) || (axes(Y, 3) != indices)
               throw(DimensionMismatch("axes(Z, 3) == axes(X, 3) == axes(Y, 3) not satisfied"))
           end
           for n in indices
               mul!(view(Z, :, :, n) , view(X, :, :, n) , view(Y, :, :, n))
           end
           return nothing
       end
batchmul! (generic function with 1 method)

julia> @btime batchmul!($Z, $X, $Y)
  173.214 ms (0 allocations: 0 bytes)

No allocations!

Note that I initialized X and Y with random numbers. While uninitialized arrays won’t lead to crashes when the eltype is a floating-point type, such arrays will typically have an unusually high proportion of NaNs and Infs, for which hardware operations are often much slower than for normal numbers. For a realistic benchmark, the input arrays should be initialized.


You are right, the problem was that I was referencing the untyped global variable Len. Thanks for the reminder.

I like this one; it’s much cleaner than my version, and it also seems to be faster than the mul! version.

It’s because the last dimension is by far the largest, so threading over it pays off; you can do even better with

BLAS.set_num_threads(1)

function mul1!(Z,X,Y)
    @inbounds Threads.@threads for n in axes(Z,3)
        mul!(view(Z,:,:,n) , view(X,:,:,n) , view(Y,:,:,n))
    end
    return Z
end

to avoid BLAS using your threads, and instead use them on the last axis.
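If you don’t want that global setting to leak into the rest of the session, you can save and restore the BLAS thread count around the call (a sketch; BLAS.get_num_threads comes with LinearAlgebra):

nt = BLAS.get_num_threads()
BLAS.set_num_threads(1)   # keep BLAS single-threaded inside the @threads loop
try
    mul1!(Z, X, Y)
finally
    BLAS.set_num_threads(nt)  # restore whatever was set before
end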


No allocations!!! It seems that your machine performs much better than mine, but this method has been of great help to me. Thank you very much.

Adding this just for fun, because I tested it:

using CUDA, KernelAbstractions

T = ComplexF64
dev = CuArray
Z = Array{T,3}(undef,512,8,24192) |> dev;
X = randn(T,512,16,24192) |> dev;
Y = randn(T,16,8,24192) |> dev;

@kernel inbounds=true unsafe_indices=true function mul_kernel(Z,X,Y)
    i,j,n = @index(Global, NTuple)
    if i <= size(Z, 1) && j <= size(Z, 2) && n <= size(Z, 3)
        s = zero(eltype(Z))
        for k in axes(X, 2)
            s += X[i,k,n] * Y[k,j,n]
        end
        Z[i,j,n] = s
    end
end

function mul2!(Z,X,Y)
    dev = get_backend(Z)
    ker = mul_kernel(dev)
    ker(Z,X,Y,ndrange=size(Z))
    KernelAbstractions.synchronize(dev)
    return Z
end

@btime mul2!($Z,$X,$Y);

this is 0.09 s because it’s on the GPU, but if you set dev = identity instead you get the same performance as the best I got with the full loopy version, around 0.3 s. Oh, and this is GPU-agnostic: just replace CUDA with any other backend package.


It seems that my GPU can’t handle this: the arrays exceed its memory size :joy:

Oh no :cry: try with ComplexF32, i.e. set T = ComplexF32 in the snippet above (you need to relaunch Julia).

(Perhaps you changed your example after writing this warning, but I’m responding to your example as currently written.) EDIT: I simply misread what was being said.

You initialized them with randn, which never produces Inf or NaN (or even subnormals, I’m quite sure; though even if it could, you’d virtually never see one except in Float16). Your arrays were initialized, and the above warning does not apply. If you’d merely allocated them but left them truly uninitialized (via undef, similar, or a related mechanism), then you would risk a large number of NaNs and subnormals. Inf would be rarer: of the 2^64 possible Float64 bit patterns, 2^53 - 2 are NaN, 2^53 - 2 are subnormal, and only 2 are Inf (though we’re talking about recycled memory, so it depends more on what was there before than on uniform priors).
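Those counts follow directly from the Float64 bit layout (1 sign bit, 11 exponent bits, 52 mantissa bits); a quick sanity check:

julia> 2 * (2^52 - 1)  # NaNs: either sign, exponent all ones, mantissa != 0
9007199254740990

julia> ans == 2^53 - 2  # the subnormal count is the same: exponent all zeros, mantissa != 0
true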

That aside, Inf and NaN don’t cause any slowdown on the most common architectures. Generally, subnormal numbers are the only values you should expect to cause significant slowdowns.

julia> using Chairmarks, LinearAlgebra

julia> X = Matrix{ComplexF64}(undef, 512, 16); Y = similar(X, 16, 8); Z = similar(X, 512, 8);

julia> v = complex(1e0,1e0); fill!(X, v); fill!(Y, v); @b mul!($Z, $X, $Y) # normal numbers
12.972 μs

julia> v = complex(Inf,1e0); fill!(X, v); fill!(Y, v); @b mul!($Z, $X, $Y) # with Infs
12.595 μs

julia> v = complex(NaN,1e0); fill!(X, v); fill!(Y, v); @b mul!($Z, $X, $Y) # with NaNs
12.495 μs

julia> v = complex(1e-320,1e0); fill!(X, v); fill!(Y, v); @b mul!($Z, $X, $Y) # with subnormals
953.371 μs # yikes!!

This is exactly what NNlib.batched_mul does. (It will restore your thread setting afterwards, and not use @threads if the problem is very small.)

For CuArray it will also dispatch to the appropriate CUBLAS strided_batched_mul routine. This is basically the operation GPUs are happiest to do all day.
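For reference, a minimal usage sketch (assuming NNlib is installed; batched_mul! multiplies the matching slices pairwise, in place):

using NNlib

Z = Array{ComplexF64,3}(undef, 512, 8, 24192);
X = randn(ComplexF64, 512, 16, 24192);
Y = randn(ComplexF64, 16, 8, 24192);

batched_mul!(Z, X, Y);  # Z[:,:,n] = X[:,:,n] * Y[:,:,n] for every n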


Yes, that’s the point. The OP did not initialize the arrays; I did, so I made sure to point out that this was another difference between our implementations (in addition to the implementation itself).


You’re right. I was seeing a clear difference, but on more careful inspection it looks like this is explained by the effect of initialization on caching. Amortizing over a large number of evals per trial washes out the difference:

julia> @btime batchmul!($Z, X, Y) setup=begin
           X = Array{ComplexF64,3}(undef, 512, 16, 24192)
           Y = Array{ComplexF64,3}(undef, 16, 8, 24192)
       end evals=1
  307.308 ms (0 allocations: 0 bytes)

julia> @btime batchmul!($Z, X, Y) setup=begin
           X = randn(ComplexF64, 512, 16, 24192)
           Y = randn(ComplexF64, 16, 8, 24192)
       end evals=1
  209.252 ms (0 allocations: 0 bytes)

julia> @btime batchmul!($Z, X, Y) setup=begin
           X = Array{ComplexF64,3}(undef, 512, 16, 24192)
           Y = Array{ComplexF64,3}(undef, 16, 8, 24192)
       end evals=50
  190.275 ms (0 allocations: 0 bytes)

julia> @btime batchmul!($Z, X, Y) setup=begin
           X = randn(ComplexF64, 512, 16, 24192)
           Y = randn(ComplexF64, 16, 8, 24192)
       end evals=50
  188.018 ms (0 allocations: 0 bytes)

I took another look at this, and it turns out the uninitialized arrays don’t actually contain a lot of NaNs and Infs, at least not on my laptop (I guess this may depend on both the hardware and the libc/malloc). So here’s a more rigorous comparison, verifying that the arrays do or don’t contain non-finite numbers in the respective cases, while containing no subnormals in either case:

julia> begin
           real_finite = -6.0:6.0
           real_nonfin = [real_finite; -Inf; Inf; NaN]
           complex_finite = complex.(real_finite, real_finite')
           complex_nonfin = complex.(real_nonfin, real_nonfin')
       end;

julia> @btime batchmul!($Z, X, Y) setup=begin
           X = rand(complex_finite, 512, 16, 24192)
           Y = rand(complex_finite, 16, 8, 24192)
           @assert all(isfinite, X)
           @assert all(isfinite, Y)
       end evals=10
  185.307 ms (0 allocations: 0 bytes)

julia> @btime batchmul!($Z, X, Y) setup=begin
           X = rand(complex_nonfin, 512, 16, 24192)
           Y = rand(complex_nonfin, 16, 8, 24192)
           @assert !all(isfinite, X)
           @assert !all(isfinite, Y)
       end evals=10
  185.573 ms (0 allocations: 0 bytes)

No difference.

Super hacky use of LoopVectorization

using LoopVectorization
function turbo_batched_matmul!(
  Cc::AbstractArray{Complex{T},3},
  Ac::AbstractArray{Complex{T},3},
  Bc::AbstractArray{Complex{T},3},
) where {T}
  C = reinterpret(T, Cc)
  A = reinterpret(T, Ac)
  B = reinterpret(reshape, T, Bc)
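  # After the reinterprets: C and A interleave real/imaginary parts along dim 1
  # (A holds 2M x K x L reals), while B is reshaped to 2 x K x N x L with
  # B[1,k,n,b] the real part and B[2,k,n,b] the imaginary part.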
  @tturbo vectorize = 3 for b ∈ indices((A,B,C),(3,4,3)), n ∈ indices((C, B), (2, 3)), m ∈ indices((C, A), 1)
    Cmn = zero(T)
    for k ∈ indices((A, B), (2, 2))
      Amk = A[m, k, b]
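      # vpermilps177 swaps the real/imag halves of each complex pair in the SIMD
      # register; vfmaddsub alternates subtract/add across lanes, so the two
      # nested calls accumulate (re*re - im*im, re*im + im*re), i.e. a complex
      # multiply on the interleaved representation.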
      Aperm = vpermilps177(Amk)
      Cmn = vfmaddsub(Amk, B[1, k, n, b], vfmaddsub(Aperm, B[2, k, n, b], Cmn))
    end
    C[m, n, b] = Cmn
  end
  return Cc
end

using LinearAlgebra
BLAS.set_num_threads(1)

function batched_mul!(Z,X,Y)
    Threads.@threads for n in axes(Z,3)
        @inbounds mul!(view(Z,:,:,n) , view(X,:,:,n) , view(Y,:,:,n))
    end
    return Z
end


Z = Array{ComplexF64,3}(undef,512,8,24192);
X = rand(ComplexF64,512,16,24192);
Y = rand(ComplexF64,16,8,24192);

turbo_batched_matmul!(Z,X,Y) ≈ batched_mul!(similar(Z),X,Y) # true
@benchmark turbo_batched_matmul!($Z,$X,$Y)
@benchmark batched_mul!($Z,$X,$Y)

I get

julia> turbo_batched_matmul!(Z,X,Y) ≈ batched_mul!(similar(Z),X,Y)
true

julia> @benchmark turbo_batched_matmul!($Z,$X,$Y)
BenchmarkTools.Trial: 8 samples with 1 evaluation per sample.
 Range (min … max):  132.097 ms … 148.658 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     136.215 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   138.137 ms ±   5.062 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted]

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark batched_mul!($Z,$X,$Y)
BenchmarkTools.Trial: 10 samples with 1 evaluation per sample.
 Range (min … max):  109.267 ms … 111.225 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     110.093 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   110.167 ms ± 564.070 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted]

 Memory estimate: 16.81 KiB, allocs estimate: 162.

Looks like the obvious solution wins here.

Something weird is happening with LV’s threading

julia> @inline function LoopVectorization._choose_num_threads(
         C::T,
         NT::UInt,
         x::Base.BitInteger
       ) where {T<:Union{Float32,Float64}}
       15
       end

julia> @benchmark turbo_batched_matmul!($Z,$X,$Y)
BenchmarkTools.Trial: 10 samples with 1 evaluation per sample.
 Range (min … max):  109.023 ms … 109.778 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     109.225 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   109.313 ms ± 257.286 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted]

 Memory estimate: 0 bytes, allocs estimate: 0.

Lowering it from 16 to 15 on this 16-core computer boosts performance.

Also cool: LV does 20 batched multiplies with 2.95e10 instructions, while the @threads mul! loop requires 2.9x as many, at 8.6e10:

julia> using LinuxPerf # this is with LV hacked to use only 15 threads

julia> @time @pstats for _ in 1:20; turbo_batched_matmul!(Z,X,Y); end
┌ Warning: LinuxPerf.EventTypeExt(hw:stalled_cycles_backend, false, 0x0000000000000006) not supported, skipping
└ @ LinuxPerf ~/.julia/packages/LinuxPerf/Ylq05/src/LinuxPerf.jl:303
  2.212762 seconds (33.80 k allocations: 1.962 MiB, 1.10% compilation time)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
┌ cpu-cycles               1.63e+11  100.0%  #  5.0 cycles per ns
└ stalled-cycles-frontend  6.25e+08  100.0%  #  0.4% of cycles
┌ instructions             2.95e+10  100.0%  #  0.2 insns per cycle
│ branch-instructions      4.40e+08  100.0%  #  1.5% of insns
└ branch-misses            6.16e+05  100.0%  #  0.1% of branch insns
┌ task-clock               3.28e+10  100.0%  # 32.8 s
│ context-switches         0.00e+00  100.0%
│ cpu-migrations           0.00e+00  100.0%
└ page-faults              0.00e+00  100.0%
                 aggregated from 15 threads
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

julia> @time @pstats for _ in 1:20; batched_mul!(Z,X,Y); end
┌ Warning: LinuxPerf.EventTypeExt(hw:stalled_cycles_backend, false, 0x0000000000000006) not supported, skipping
└ @ LinuxPerf ~/.julia/packages/LinuxPerf/Ylq05/src/LinuxPerf.jl:303
  2.242579 seconds (37.03 k allocations: 2.291 MiB, 1.08% compilation time)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
┌ cpu-cycles               3.12e+11  100.0%  #  4.8 cycles per ns
└ stalled-cycles-frontend  2.56e+09  100.0%  #  0.8% of cycles
┌ instructions             8.60e+10  100.0%  #  0.3 insns per cycle
│ branch-instructions      3.65e+09  100.0%  #  4.2% of insns
└ branch-misses            4.03e+06  100.0%  #  0.1% of branch insns
┌ task-clock               6.52e+10  100.0%  # 65.2 s
│ context-switches         0.00e+00  100.0%
│ cpu-migrations           0.00e+00  100.0%
└ page-faults              1.00e+00  100.0%
                 aggregated from 32 threads
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━