Batched matrix multiplication in CUDA

I am thinking of performing the following batched matrix multiplication using CUDA, but I can’t find an appropriate function for it.

using LinearAlgebra

N = 50; Nt = 3000;
A = rand(Nt,2*N,2*N); C = rand(Nt,2*N,2*N); B = rand(Nt,Nt);

Abatch = [ transpose(A[:, 2*pq-1, :]) for pq = 1:N ];
Cbatch = [ C[:, :, 2*pq] for pq = 1:N ];
Bbatch = [ B for pq = 1:N ];
res = tr.( Abatch .* Bbatch .* Cbatch );

You can do this, which mostly avoids slices:

julia> using NNlib: ⊠  # batched_mul

julia> D = permutedims(A[:, 1:2:end, :],(3,1,2)) ⊠ B ⊠ C[:, :, 2:2:end];

julia> res ≈ tr.(eachslice(D; dims=3))
true

You can also try things like this; maybe the last way will be the most efficient?

julia> using Tullio  # and using KernelAbstractions, for CUDA

julia> res ≈ @tullio out[k] := D[i,i,k]
true

julia> AB = permutedims(A[:, 1:2:end, :],(3,1,2)) ⊠ B;

julia> C2 = C[:, :, 2:2:end];

julia> res ≈ @tullio out2[b] := AB[i,j,b] * C2[j,i,b]
true

julia> BC = B ⊠ @view C[:, :, 2:2:end];

julia> res ≈ @tullio out3[b] := A[i,2b-1,j] * BC[i,j,b]
true

Thanks, I have made a comparison of three methods. It seems the garden-variety, fully explicit matrix multiplication is the fastest on my system. I am using 16 threads on an i7-13700K. My GPU is an Nvidia RTX 3060.

using LinearAlgebra
using MKL
using NNlib, CUDA, cuDNN

N = 50; Nt = 3000;
A = rand(Nt,2*N,2*N); C = rand(Nt,2*N,2*N); B = rand(Nt,Nt);

nruns = 100;
ta = 0; tb = 0; tc = 0;

for nr = 1:nruns
	println(nr)

	t0a = time()
	Dtemp = permutedims(A[:, 1:2:end, :],(3,1,2)) ⊠ B ⊠ C[:, :, 2:2:end];
	res1 = tr.(eachslice(Dtemp; dims=3));
	t1a = time()
	global ta = ta + (t1a-t0a)/nruns;

	t0b = time()
	Abatch = [ transpose(A[:, 2*pq-1, :]) for pq = 1:N ];
	Cbatch = [ C[:, :, 2*pq] for pq = 1:N ];
	Bbatch = [ B for pq = 1:N ];
	res2 = tr.( Abatch .* Bbatch .* Cbatch );
	t1b = time()
	global tb = tb + (t1b-t0b)/nruns;

	res3 = zeros(Float64,N);
	t0c = time()
	Threads.@threads for pq = 1:N    #16 threads
		res3[pq] = tr( transpose(@view A[:, 2*pq-1, :]) * B * (@view C[:, :, 2*pq]) )
	end
	t1c = time()
	global tc = tc + (t1c-t0c)/nruns;
end

println("NNlib+CUDA time = ",t1a-t0a)
println("Vectorised batch time = ",t1b-t0b)
println("Explicit time = ",t1c-t0c)

Results, times averaged over 100 runs,

NNlib+CUDA time = 0.3775529026985167
Vectorised batch time = 0.36217491626739495
Explicit time = 0.23013512134552006

Any ideas on how to make it faster?

None of these use CUDA yet. On ordinary Arrays, batched_mul is just a loop over slices (although I believe it uses single-threaded BLAS inside @threads).
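On the CPU it is roughly a threaded loop like the sketch below (a simplified illustration, not the actual NNlib source):

using LinearAlgebra

# Simplified sketch of the CPU fallback: thread over the batch dimension
# and call an ordinary in-place matrix multiply on each pair of slices.
function batched_mul_sketch(A::Array{T,3}, B::Array{T,3}) where {T}
    C = similar(A, size(A, 1), size(B, 2), size(A, 3))
    Threads.@threads for k in axes(A, 3)
        @views mul!(C[:, :, k], A[:, :, k], B[:, :, k])
    end
    return C
end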

You need to move the data to the GPU with e.g. cu, to get a CuArray. This will probably mean working in Float32.

(I also recommend making functions & using BenchmarkTools, instead of your own timing loops.)
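For example, once each variant is wrapped in a function (here f is just a placeholder for any of them; the $ interpolates the global arrays so that their lookup is not part of the timing):

julia> using BenchmarkTools

julia> @btime f($A, $B, $C);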

Thanks, this is what I am getting now,

using LinearAlgebra
using MKL
using NNlib, NNlibCUDA
using CUDA
CUDA.allowscalar(false)
using BenchmarkTools

N = 50; Nt = 3000;
A = rand(Nt,2*N,2*N); C = rand(Nt,2*N,2*N); B = rand(Nt,Nt);

function NNlibnocu(A,B,C)
	Dtemp = permutedims(A[:, 1:2:end, :],(3,1,2)) ⊠ B ⊠ C[:, :, 2:2:end];
	res = tr.(eachslice(Dtemp; dims=3));
	
	return res
end

function NNlibcu(A,B,C)
	Acu = cu(permutedims(A[:, 1:2:end, :],(3,1,2))); Bcu = cu(B); Ccu = cu(C[:, :, 2:2:end]);
	Dtemp = Acu ⊠ Bcu ⊠ Ccu;
	res = Array( tr.(eachslice(Dtemp; dims=3)) );
	
	return res
end

function vecbatch(A,B,C)
	Abatch = [ transpose(A[:, 2*pq-1, :]) for pq = 1:N ];
	Cbatch = [ C[:, :, 2*pq] for pq = 1:N ];
	Bbatch = [ B for pq = 1:N ];
	res = tr.( Abatch .* Bbatch .* Cbatch );
	
	return res
end

function expl(A,B,C)
	res = zeros(Float64,N);
	Threads.@threads for pq = 1:N 
		res[pq] = tr( transpose(@view A[:, 2*pq-1, :]) * B * (@view C[:, :, 2*pq]) );
	end
	
	return res
end

@btime res0 = NNlibnocu(A,B,C);
@btime res1 = NNlibcu(A,B,C);
@btime res2 = vecbatch(A,B,C);
@btime res3 = expl(A,B,C);

Note that I have tried replacing the multithreaded for loop with a basic single-threaded one. I consistently find that the former is faster. In any case, the explicit multiplication still turns out to be faster, assuming there aren't any mistakes.

284.134 ms (239 allocations: 461.64 MiB)
241.266 ms (2081 allocations: 492.21 MiB)
295.904 ms (420 allocations: 347.15 MiB)
211.852 ms (301 allocations: 118.27 MiB)

The first piece of advice I have is to remember that the trace of a product of two matrices only uses the diagonal elements, so you don't need to calculate the whole product. If you find yourself calculating Tr(A*B), it is probably better not to do the full matrix multiplication, since all you actually need is the sum of the elementwise multiplication of one matrix by the transpose of the other. Working from your vecbatch function, it is faster to do something like this:

function vecbatch(A,B,C)
	Abatch = [A[:, 2*pq-1, :] for pq = 1:N ];
	Cbatch = [ C[:, :, 2*pq] for pq = 1:N ];
	res = [sum( Abatch[pq] .* (B * Cbatch[pq])) for pq=1:N];
	
	return res
end
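As a quick check of that identity (a minimal sketch with small random matrices):

julia> using LinearAlgebra

julia> X = rand(5, 5); Y = rand(5, 5);

julia> tr(X * Y) ≈ sum(X .* transpose(Y))
true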

The second piece of advice I have is that the cuBLAS library has a few functions for batched matrix multiplication that you might want to look into if you are planning to use CUDA for this problem.
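For example, something along these lines (a sketch, assuming the CUDA.jl wrapper CUBLAS.gemm_strided_batched with this signature; NNlib.batched_mul on CuArrays ends up calling the same routine):

using CUDA

# multiply 8 pairs of 100×100 matrices in one strided-batched GEMM call
A3 = CUDA.rand(Float32, 100, 100, 8)
B3 = CUDA.rand(Float32, 100, 100, 8)
C3 = CUDA.CUBLAS.gemm_strided_batched('N', 'N', A3, B3)   # C3[:,:,k] ≈ A3[:,:,k] * B3[:,:,k]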


On my PC, the explicit product with the trace still runs faster. While your suggestion is logical, perhaps MKL does something under the hood? I have even tried replacing the trace with a for loop, and it doesn't help, by quite some margin actually.

using LinearAlgebra
using MKL
using NNlib, NNlibCUDA
using CUDA
CUDA.allowscalar(false)
using BenchmarkTools

function expl(A,B,C)
	res = zeros(Float64,N);
	Threads.@threads for pq = 1:N 
		res[pq] = tr( transpose(A[:, 2*pq-1, :]) * B * (C[:, :, 2*pq]) );
	end
	
	return res
end

function explnotrace(A,B,C)
	res = zeros(Float64,N);
	Threads.@threads for pq = 1:N
		for ij = 1:2*N    # the slices are 2N x 2N, so sum over all 2N diagonal terms
			# note: every column slice here allocates a fresh vector
			res[pq] = res[pq] + transpose(A[:, 2*pq-1, ij]) * B * (C[:, ij, 2*pq]);
		end
	end
	
	return res
end

@btime res4 = expl(A,B,C);
@btime res5 = explnotrace(A,B,C);
197.631 ms (501 allocations: 347.16 MiB)
1.912 s (107652 allocations: 173.77 MiB)

Yes, this is what @tullio out3 above was trying to say, probably much too obscurely.

NNlib.batched_mul exists to wrap these functions. On the CPU it falls back to a loop over slices (so won’t beat a loop you write yourself).

Note that some of these are Float32 and some Float64. It also times the copying of the data to the GPU, which I’m not sure you want to include.
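If you want to time just the multiplications, you can move the arrays once and synchronize around the GPU work, e.g. (a sketch):

julia> Acu = cu(permutedims(A[:, 1:2:end, :], (3,1,2))); Bcu = cu(B); Ccu = cu(C[:, :, 2:2:end]);

julia> @btime CUDA.@sync $Acu ⊠ $Bcu ⊠ $Ccu;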

Below are my attempts to time this… not terribly conclusive though, much will depend on what sizes you pick, and on your machine.

julia> function loop1(A,B,C,N=size(C,3)÷2)
       res3 = zeros(N)
       for pq in 1:N
          res3[pq] = tr( transpose(@view A[:, 2*pq-1, :]) * B * (@view C[:, :, 2*pq]) )
       end
       res3
       end
loop1 (generic function with 2 methods)

julia> function loop2(A,B,C,N=size(C,3)÷2)
       res3 = zeros(N)
       th = BLAS.get_num_threads()
       BLAS.set_num_threads(1)
       Threads.@threads for pq in 1:N
          res3[pq] = tr( transpose(@view A[:, 2*pq-1, :]) * B * (@view C[:, :, 2*pq]) )
       end
       BLAS.set_num_threads(th)
       res3
       end
loop2 (generic function with 2 methods)

julia> using NNlib, TensorCore, Tullio

julia> function nnlib1(A,B,C)
         ABC = @views permutedims(A[:, 1:2:end, :],(3,1,2)) ⊠ B ⊠ C[:, :, 2:2:end];
         tr.(eachslice(ABC; dims=3))
       end
nnlib1 (generic function with 1 method)

julia> function tullio3(A,B,C)
         BC = batched_mul(B, @view C[:, :, 2:2:end])  # threaded loop (on Arrays)
         @tullio out3[b] := A[i,2b-1,j] * BC[i,j,b]
       end
tullio3 (generic function with 1 method)

julia> function tullio4(A,B,C)
         BC = boxdot(B, C[:, :, 2:2:end])  # reshape & mul
         @tullio out3[b] := A[i,2b-1,j] * BC[i,j,b]
       end
tullio4 (generic function with 1 method)

julia> function bcmichael_vecbatch(A,B,C,N=size(C,3)÷2)  # changed to make N local
               Abatch = [A[:, 2*pq-1, :] for pq = 1:N ];
               Cbatch = [ C[:, :, 2*pq] for pq = 1:N ];
               res = [sum( Abatch[pq] .* (B * Cbatch[pq])) for pq=1:N]
       end
bcmichael_vecbatch (generic function with 1 method)

julia> function bcmichael_vecbatch_views(A,B,C,N=size(C,3)÷2)
               @views Abatch = [A[:, 2*pq-1, :] for pq = 1:N ];
               @views Cbatch = [ C[:, :, 2*pq] for pq = 1:N ];
               res = [sum( Abatch[pq] .* (B * Cbatch[pq])) for pq=1:N]
       end
bcmichael_vecbatch_views (generic function with 1 method)

julia> using BenchmarkTools, LinearAlgebra

julia> let N = 50, Nt = 30, T = Float32, move=Array
         println(move, "{$T}, N=$N, Nt=$Nt")
         A = move(randn(T,Nt,2*N,2*N)); C = move(randn(T,Nt,2*N,2*N)); B = move(randn(T,Nt,Nt));
         println("loops over tr( * * )")
         res1 = @btime loop1($A,$B,$C)
         res2 = @btime loop2($A,$B,$C)
         println("batched_mul then tr")
         res3 = @btime nnlib1($A,$B,$C)
         println("one mul then tullio")
         res4 = @btime tullio3($A,$B,$C)
         res5 = @btime tullio4($A,$B,$C)
         println("bcmichael_vecbatch")
         res6 = @btime bcmichael_vecbatch($A,$B,$C)
         res6 = @btime bcmichael_vecbatch_views($A,$B,$C)
         res1 ≈ res2 ≈ res3 ≈ res4 ≈ res5 ≈ res6
       end
Array{Float32}, N=50, Nt=30
loops over tr( * * )
  min 499.458 μs, mean 957.305 μs (302 allocations, 4.97 MiB)
  min 543.333 μs, mean 823.141 μs (324 allocations, 4.97 MiB)
batched_mul then tr
  min 1.318 ms, mean 1.804 ms (181 allocations, 6.16 MiB)
one mul then tullio
  min 385.417 μs, mean 461.693 μs (41 allocations, 1.15 MiB)
  min 284.209 μs, mean 575.084 μs (10 allocations, 2.29 MiB)
bcmichael_vecbatch
  min 600.625 μs, mean 1.070 ms (606 allocations, 4.59 MiB)
  min 274.792 μs, mean 559.201 μs (308 allocations, 2.30 MiB)
true

julia> versioninfo()
Julia Version 1.11.0-DEV.901
Commit 4bc45a7f0a* (2023-11-14 16:49 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin21.6.0)
  CPU: 8 × Apple M1
  WORD_SIZE: 64
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
  Threads: 5 on 4 virtual cores
Environment:
  JULIA_NUM_THREADS = 4

# on another machine, xeon with openblas, and a GPU
Array{Float32}, N=50, Nt=30
loops over tr( * * )
  3.055 ms (251 allocations: 4.96 MiB)
  1.069 ms (222 allocations: 4.97 MiB)
batched_mul then tr
  5.805 ms (128 allocations: 6.15 MiB)
one mul then tullio
  729.539 μs (40 allocations: 1.15 MiB)
  1.360 ms (9 allocations: 2.29 MiB)
bcmichael_vecbatch
  2.095 ms (403 allocations: 4.59 MiB)
  1.449 ms (203 allocations: 2.30 MiB)
true

julia> using CUDA, KernelAbstractions

julia> function tullio3(A,B,C)  # NB @tullio must be after using CUDA, KernelAbstractions
         BC = batched_mul(B, @view C[:, :, 2:2:end])
         @tullio out3[b] := A[i,2b-1,j] * BC[i,j,b]
       end
tullio3 (generic function with 1 method)

julia> function tullio4(A,B,C)
         BC = boxdot(B, C[:, :, 2:2:end])  # reshape & mul
         @tullio out3[b] := A[i,2b-1,j] * BC[i,j,b]
       end
tullio4 (generic function with 1 method)

julia> let N = 50, Nt = 300, T = Float32, move=CuArray
         println(move, "{$T}, N=$N, Nt=$Nt")
         A = move(randn(T,Nt,2*N,2*N)); C = move(randn(T,Nt,2*N,2*N)); B = move(randn(T,Nt,Nt));
         println("loops over tr( * * )")
         res1 = @btime loop1($A,$B,$C)
         # res2 = @btime loop2($A,$B,$C)
         println("batched_mul then tr")
         res3 = @btime nnlib1($A,$B,$C)
         println("one mul then tullio")
         res4 = Array(@btime CUDA.@sync tullio3($A,$B,$C))  # the function returns a CuArray
         res5 = Array(@btime CUDA.@sync tullio4($A,$B,$C))
         println("bcmichael_vecbatch")
         res6 = @btime bcmichael_vecbatch($A,$B,$C)
         res6 = @btime bcmichael_vecbatch_views($A,$B,$C)
         res1 ≈ res1 ≈ res3 ≈ res4 ≈ res5 ≈ res6
       end
CuArray{Float32}, N=50, Nt=300
loops over tr( * * )
  9.277 ms (7753 allocations: 377.64 KiB)  # NB this is only CPU alloc, not GPU
batched_mul then tr
  3.959 ms (3705 allocations: 178.12 KiB)
one mul then tullio
  4.957 ms (103 allocations: 4.12 KiB)
  5.064 ms (194 allocations: 7.25 KiB)
bcmichael_vecbatch
  12.022 ms (9955 allocations: 496.25 KiB)
  12.407 ms (6755 allocations: 363.16 KiB)
true

julia> CUDA.device()
CuDevice(0): Tesla V100-PCIE-16GB

Thanks a lot.
On my PC, your tullio3 and the explicit multiplication with the trace turn out to be equally fast, and they are the fastest options. Probably MKL uses an optimised trace.