Here is a gist with a version that hits the main points of the code that is running into the problem:
The key insight I have found so far, based on the suggestions above, is to unpin the threads before any multithreaded BLAS call and to pin them again directly before any Julia-threaded loop that makes single-threaded BLAS calls.
# unpin threads before the multithreaded BLAS call
ThreadPinning.unpinthreads() # it turned out that unpinning was necessary when using multithreaded BLAS
BLAS.set_num_threads(64) # set the number of BLAS threads
BLAS.gemm!('T', 'N', -1.0, reshape(W, Q*n_occ, p), reshape(W, Q*n_occ, p), 0.0, correct_matrix)

# pin threads again directly before the threaded loop of single-threaded BLAS calls
ThreadPinning.pinthreads(:affinitymask)
BLAS.set_num_threads(1) # one BLAS thread per Julia thread
#... setup
time_for_matmul = @elapsed Threads.@threads for block_index in 1:55
    p_range = p_ranges[block_index]
    q_range = q_ranges[block_index]
    A_view = view(W_reshape, :, p_range)
    B_view = view(W_reshape, :, q_range)
    C_view = view(blocks, :, :, block_index)
    # Perform the matrix multiplication for the block
    LinearAlgebra.BLAS.gemm!('T', 'N', -1.0, A_view, B_view, 0.0, C_view)
end
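Just to make the pattern explicit, here is a minimal sketch that wraps the unpin/repin bookkeeping in a helper. This is not part of the gist: the name with_multithreaded_blas and the do-block usage are mine, and it assumes the surrounding code otherwise runs pinned with :affinitymask and BLAS.set_num_threads(1).

# minimal sketch, assuming the rest of the code runs pinned (:affinitymask) with single-threaded BLAS;
# the helper name is hypothetical and not part of the gist below
function with_multithreaded_blas(f, nthreads_for_blas::Int)
    previous = BLAS.get_num_threads()
    ThreadPinning.unpinthreads()                # unpin so the multithreaded BLAS call can spread over the cores
    BLAS.set_num_threads(nthreads_for_blas)
    try
        return f()
    finally
        BLAS.set_num_threads(previous)          # back to single-threaded BLAS for the Julia-threaded loops
        ThreadPinning.pinthreads(:affinitymask) # repin before the next threaded loop
    end
end

# hypothetical usage for the reference gemm above
with_multithreaded_blas(64) do
    BLAS.gemm!('T', 'N', -1.0, reshape(W, Q*n_occ, p), reshape(W, Q*n_occ, p), 0.0, correct_matrix)
end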
Full gist (for some reason I can't post a reply with a link):
using BLISBLAS
# using MKL #(if running on intel)
using LinearAlgebra
using MPI
using ThreadPinning
# hard-coded for the case of 10x10 blocks for simplicity (a generalized version of the range setup is sketched after the gist)
# a symmetric matrix multiplication producing a p x p matrix, done by computing only the blocks of the lower triangle and mirroring them into the upper triangle
function main_bottleneck(W::Array{Float64, 3})
    ThreadPinning.pinthreads(:affinitymask)
    BLAS.set_num_threads(1) # single-threaded BLAS inside the Julia-threaded loop
    p = size(W, 3)
    final_matrix = zeros(Float64, p, p)
    blocks_wide = 10
    M = p ÷ blocks_wide
    N = p ÷ blocks_wide
    K = size(W, 1) * size(W, 2)
    W_reshape = reshape(W, (K, p)) # reshape for matrix multiplication
    blocks = zeros(Float64, M, N, 55) # 10 x 10 grid of blocks; 55 is the number of blocks in the lower triangle
    block_width = p ÷ blocks_wide
    p_ranges = Array{UnitRange{Int}}(undef, 0)
    q_ranges = Array{UnitRange{Int}}(undef, 0)
    for iii in 1:blocks_wide
        for jjj in 1:iii
            p_range = (iii-1)*block_width+1:iii*block_width
            q_range = (jjj-1)*block_width+1:jjj*block_width
            push!(p_ranges, p_range)
            push!(q_ranges, q_range)
        end
    end
    time_for_matmul = @elapsed Threads.@threads for block_index in 1:55
        p_range = p_ranges[block_index]
        q_range = q_ranges[block_index]
        A_view = view(W_reshape, :, p_range)
        B_view = view(W_reshape, :, q_range)
        C_view = view(blocks, :, :, block_index)
        # Perform the matrix multiplication for the block
        LinearAlgebra.BLAS.gemm!('T', 'N', -1.0, A_view, B_view, 0.0, C_view)
    end
    # separated the copy step out to just focus on the timing of the matrix multiplication
    time_for_copy = @elapsed Threads.@threads for block_index in 1:55
        p_range = p_ranges[block_index]
        q_range = q_ranges[block_index]
        C_view = view(blocks, :, :, block_index)
        final_matrix[p_range, q_range] .= C_view
        if p_range != q_range
            # Copy the transpose for the upper triangle
            final_matrix[q_range, p_range] .= transpose(C_view)
        end
    end
    return final_matrix, time_for_matmul, time_for_copy
end
# Example usage
function main(w_example_index)
    println("Running example $w_example_index")
    n_ranks = MPI.Comm_size(MPI.COMM_WORLD)
    rank = MPI.Comm_rank(MPI.COMM_WORLD)
    # example sizes for p, Q, n_occ (one row per example)
    W_sizes = [
        510  1950  81;
        1010 3870 161;
        1510 5790 241;
        2010 7710 321;
    ]
    p = W_sizes[w_example_index, 1]           # p is the third dimension of W
    Q = W_sizes[w_example_index, 2] ÷ n_ranks # Q is the first dimension of W, split across ranks
    n_occ = W_sizes[w_example_index, 3]       # n_occ is the second dimension of W
    # Create a random 3D array W with dimensions (Q, n_occ, p)
    W = rand(Float64, Q, n_occ, p)
    final_matrix = zeros(Float64, p, p)
    final_matrix, time, copy_time = main_bottleneck(W) # warm-up/precompile run, don't print the time
    for test in 1:2
        final_matrix, time, copy_time = main_bottleneck(W)
        println("Test $test: Time for matrix multiplication: ", time, " seconds")
        # println("Time for copying results: ", copy_time, " seconds") # negligible for sufficiently large matrices
    end
    # compute the reference matrix with a single multithreaded BLAS call
    correct_matrix = zeros(Float64, p, p)
    # unpin threads before the multithreaded BLAS call
    ThreadPinning.unpinthreads() # it turned out that unpinning was necessary when using multithreaded BLAS
    BLAS.set_num_threads(64) # Set the number of threads for BLAS
    # run once untimed as a warm-up, then time the second call
    BLAS.gemm!('T', 'N', -1.0, reshape(W, Q*n_occ, p), reshape(W, Q*n_occ, p), 0.0, correct_matrix)
    multithreaded_blas_time = @elapsed BLAS.gemm!('T', 'N', -1.0, reshape(W, Q*n_occ, p), reshape(W, Q*n_occ, p), 0.0, correct_matrix)
    println("Time for multithreaded BLAS gemm: ", multithreaded_blas_time, " seconds")
    # reduce both matrices onto rank 0
    MPI.Reduce!(final_matrix, MPI.SUM, MPI.COMM_WORLD; root=0)
    MPI.Reduce!(correct_matrix, MPI.SUM, MPI.COMM_WORLD; root=0)
    MPI.Barrier(MPI.COMM_WORLD) # just to be sure
    if rank == 0
        # Check if the final matrix matches the correct matrix
        diff = final_matrix - correct_matrix
        max_diff = maximum(abs.(diff))
        println("Maximum difference: ", max_diff)
    end
end
MPI.Init()
println("cpuids for rank $(MPI.Comm_rank(MPI.COMM_WORLD)): ", sort(ThreadPinning.getcpuids()), "\n")
main(1)
main(2)
main(3)
#run.sh
# #!/bin/bash
# module load julia
# srun -N 1 -n 2 --cpu-bind=sockets --cpu-bind=v julia --project=$j_project_path --threads=64 './model_bottlenecks.jl'
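Since the gist hard-codes blocks_wide = 10 (hence the 55 = 10*11/2 lower-triangle blocks), here is a sketch of how the range setup could be generalized. The function name build_lower_triangle_ranges is hypothetical and it assumes blocks_wide divides p evenly, as in the example sizes above.

# hypothetical helper, not part of the gist: builds the (p_range, q_range) pairs for the
# lower-triangle blocks of the p x p result, split into blocks_wide x blocks_wide blocks;
# assumes blocks_wide divides p evenly, as in the examples above
function build_lower_triangle_ranges(p::Int, blocks_wide::Int)
    block_width = p ÷ blocks_wide
    n_blocks = blocks_wide * (blocks_wide + 1) ÷ 2   # 55 when blocks_wide == 10
    p_ranges = Vector{UnitRange{Int}}()
    q_ranges = Vector{UnitRange{Int}}()
    for iii in 1:blocks_wide, jjj in 1:iii
        push!(p_ranges, (iii-1)*block_width+1:iii*block_width)
        push!(q_ranges, (jjj-1)*block_width+1:jjj*block_width)
    end
    return p_ranges, q_ranges, n_blocks
end

With that, the loops in main_bottleneck could run for block_index in 1:n_blocks and blocks could be allocated as zeros(Float64, M, N, n_blocks) instead of using the hard-coded 55.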