Julia SLURM + BLAS + Multithreading, threads not mapping well leading to poor performance

I am benchmarking a code that uses:

Julia multithreading
multithreaded BLAS
single-threaded BLAS, called in parallel from Julia threads

on NERSC Perlmutter (2x AMD EPYC 7763, 64 cores per socket, 2 hyperthreads per core), with Cray MPICH.

From my testing, I am seeing the following. I request an interactive job with:

salloc --nodes 1 --qos interactive --time 04:00:00 --constraint "cpu" --account=$ACCT_NUM  --hint=nomultithread --ntasks-per-node=2 --cpus-per-task=64  --exclusive

When I then use ThreadPinning.jl to check which CPUs the Julia threads are running on:

using ThreadPinning
println(ThreadPinning.getcpuids())

The CPU IDs returned are a mix of the first hardware threads (0-63) and the hyperthreads (64-127).
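For reference, the per-rank check looks roughly like this (a minimal sketch; the 0-63 / 64-127 split simply reflects the ranges mentioned above):

using MPI, ThreadPinning

MPI.Init()
rank = MPI.Comm_rank(MPI.COMM_WORLD)

ids = sort(ThreadPinning.getcpuids())   # CPU ID of each Julia thread of this rank
low  = count(<(64), ids)                # IDs in 0-63
high = length(ids) - low                # IDs in 64-127
println("rank $rank: $(length(ids)) Julia threads on CPUs $ids ($low in 0-63, $high in 64-127)")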

When I pin the threads, for example:

if rank == 0
    pinthreads(:affinitymask)
    # pinthreads(0:63) has the same effect
else
    pinthreads(:affinitymask)
    # pinthreads(64:127) has the same effect
end
# ... more code

Some portions of the code show better performance, but not all.

My real question is:

Have others encountered performance issues when running Julia threads in a SLURM environment?

If it’s helpful, I can try to code up a small script that shows the various types of operations being performed. In general, the main bottlenecks are some multithreaded BLAS calls (e.g. with BLAS.set_num_threads(64)) and some threaded loops that call single-threaded BLAS (with BLAS.set_num_threads(1)), roughly as sketched below.
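For concreteness, the two patterns look roughly like this (a trimmed, self-contained sketch with placeholder sizes, not the real workload):

using LinearAlgebra

n, k, nblocks = 2048, 4096, 8   # placeholder sizes, not the real problem sizes
W = rand(k, n)
C = zeros(n, n)

# Pattern 1: one large multithreaded BLAS call
BLAS.set_num_threads(64)
t1 = @elapsed BLAS.gemm!('T', 'N', -1.0, W, W, 0.0, C)

# Pattern 2: Julia-threaded loop over column blocks, each calling single-threaded BLAS
BLAS.set_num_threads(1)
bw = n ÷ nblocks
blocks = zeros(bw, bw, nblocks)
t2 = @elapsed Threads.@threads for b in 1:nblocks
    cols = (b-1)*bw+1:b*bw
    BLAS.gemm!('T', 'N', -1.0, view(W, :, cols), view(W, :, cols), 0.0, view(blocks, :, :, b))
end

println("multithreaded gemm: $t1 s, threaded loop of single-threaded gemms: $t2 s")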

I suspect @carstenbauer would be the most knowledgeable (also thank you for all your wonderful work on ThreadPinning.jl and the documentation you have on your wiki)

Can you elaborate on which portions do better and which don’t? Does completely disabling BLAS parallelism help? Have you seen Pinning BLAS Threads · ThreadPinning.jl?


I did play with pinning the BLAS threads a bit but didn’t see a noticeable improvement.

It was the multithreaded BLAS calls that didn’t improve, if I recall correctly. Removing all BLAS multithreading doesn’t help, and it isn’t a viable solution anyway because I rely on it for a couple of matrix/LAPACK operations. I will play with that more and report what I see, though.
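To make that concrete, the kind of quick comparison I have in mind for the LAPACK side is something like this (a minimal sketch with a placeholder matrix, not the actual operations in the code):

using LinearAlgebra

A = rand(4000, 4000)
A = A' * A + 4000I          # placeholder SPD matrix

for nt in (1, 64)
    BLAS.set_num_threads(nt)
    cholesky(A)             # warm-up at this thread count
    t = @elapsed cholesky(A)
    println("cholesky with $nt BLAS thread(s): $t s")
end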

I think that to get to the bottom of this I will have to post an example script that replicates what I am seeing, so the complexity of the full code is separated from the essential bottleneck portions, which can be reproduced in spirit fairly easily.


My guess is that in your MPI program, where rank 0 has cores 0-63 and rank 1 has cores 64-127, the two ranks call BLAS simultaneously. But since BLAS typically performs heavily vectorized operations, it is hard to benefit from hyperthreads; there may not be enough spare execution resources on each core. It may actually slow things down, I’m not sure.

If that’s the case, it’s a tricky problem, because the rest of the program may benefit from hyperthreading.

I agree that is likely the case. I am trying to work with hyperthreading completely disabled. My understanding is that hyperthreading slows things down in the case where single-threaded BLAS calls are being used (the main bottleneck). I will post an example program soon.
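To double-check that the two ranks really end up on disjoint CPUs, I am planning to use a quick overlap check like the following (a minimal sketch, assuming both ranks run the same number of Julia threads; not yet tested on Perlmutter):

using MPI, ThreadPinning

MPI.Init()
comm = MPI.COMM_WORLD

ids = ThreadPinning.getcpuids()      # CPU IDs of this rank's Julia threads
all_ids = MPI.Allgather(ids, comm)   # concatenated CPU IDs from all ranks

if MPI.Comm_rank(comm) == 0
    n_shared = length(all_ids) - length(unique(all_ids))
    println("CPU IDs used by more than one thread/rank: $n_shared")
end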

Here is a gist with a version that hits the main points of the code that is running into the problems:

The key insight I have found so far, based on the suggestions above, is to unpin the threads before any multithreaded BLAS call and to pin them again directly before any Julia-threaded loop that makes single-threaded BLAS calls:

    # unpin threads: unpinning turned out to be necessary when using multithreaded BLAS
    ThreadPinning.unpinthreads()
    BLAS.set_num_threads(64)
    BLAS.gemm!('T', 'N', -1.0, reshape(W, Q*n_occ, p), reshape(W, Q*n_occ, p), 0.0, correct_matrix)

    # pin again and switch to single-threaded BLAS for the Julia-threaded block loop
    ThreadPinning.pinthreads(:affinitymask)
    BLAS.set_num_threads(1)

    #... setup

    time_for_matmul = @elapsed Threads.@threads for block_index in 1:55
        
        p_range = p_ranges[block_index]
        q_range = q_ranges[block_index]
        A_view = view(W_reshape, :, p_range)
        B_view = view(W_reshape, :, q_range)
        C_view = view(blocks, :, :, block_index)

        # Perform the matrix multiplication for the block
        LinearAlgebra.BLAS.gemm!('T', 'N', -1.0, A_view, B_view, 0.0, C_view)
    end

Full gist (for some reason I can’t post a reply with a link):

using BLISBLAS
# using MKL # (if running on Intel)
using LinearAlgebra
using MPI
using ThreadPinning

# hard-coded for the case of 10x10 blocks for simplicity:
# a symmetric matrix product forming a p x p matrix, computed block-by-block over the lower triangle
function main_bottleneck(W::Array{Float64, 3})
    ThreadPinning.pinthreads(:affinitymask)
    BLAS.set_num_threads(1) # single-threaded BLAS inside the Julia-threaded loop below
    
    p = size(W, 3)
    final_matrix = zeros(Float64, p,p)
    blocks_wide = 10
    M = p ÷ blocks_wide
    N = p ÷ blocks_wide   
    K = size(W, 1) * size(W, 2)
    W_reshape = reshape(W, (K, p)) # reshape for matrix multiplication

    blocks = zeros(Float64, M, N, 55) # 10 x 10 block grid; 55 = 10*11/2 blocks in the lower triangle (including the diagonal)
    block_width = p ÷ blocks_wide
    p_ranges = Array{UnitRange{Int}}(undef, 0)
    q_ranges = Array{UnitRange{Int}}(undef, 0)

    for iii in 1:blocks_wide
        for jjj in 1:iii
            p_range = (iii-1)*block_width+1:iii*block_width
            q_range = (jjj-1)*block_width+1:jjj*block_width
            push!(p_ranges, p_range)
            push!(q_ranges, q_range)
        end
    end

    time_for_matmul = @elapsed Threads.@threads for block_index in 1:55
        
        p_range = p_ranges[block_index]
        q_range = q_ranges[block_index]
        A_view = view(W_reshape, :, p_range)
        B_view = view(W_reshape, :, q_range)
        C_view = view(blocks, :, :, block_index)

        # Perform the matrix multiplication for the block
        LinearAlgebra.BLAS.gemm!('T', 'N', -1.0, A_view, B_view, 0.0, C_view)
    end

    # the copy step is timed separately so the matrix multiplication timing is isolated
    time_for_copy = @elapsed Threads.@threads for block_index in 1:55
        p_range = p_ranges[block_index]
        q_range = q_ranges[block_index]
        C_view = view(blocks, :, :, block_index)
        final_matrix[p_range, q_range] .= C_view
        if p_range != q_range
            # Copy the transpose for the upper triangle
            final_matrix[q_range, p_range] .= transpose(C_view)
        end
    end

    return final_matrix, time_for_matmul, time_for_copy
    
end

# Example usage
function main(w_example_index)
    println("Running example $w_example_index")
    n_ranks = MPI.Comm_size(MPI.COMM_WORLD)
    rank = MPI.Comm_rank(MPI.COMM_WORLD)

    #example sizes for p, Q, n_occ
    W_sizes = [
    510 1950 81;
    1010 3870 161;
    1510 5790 241;
    2010 7710 321;]

    p = W_sizes[w_example_index, 1]            # p: third dimension of W
    Q = W_sizes[w_example_index, 2] ÷ n_ranks  # Q: first dimension of W, split across ranks
    n_occ = W_sizes[w_example_index, 3]        # n_occ: second dimension of W

    # Create a random 3D array W with dimensions (Q, n_occ, p)
    W = rand(Float64, Q, n_occ, p)

    final_matrix =  zeros(Float64, p, p)
    
    final_matrix, time, copy_time = main_bottleneck(W) # warm-up/compilation run; don't print the time
    
    for test in 1:2
        final_matrix, time, copy_time = main_bottleneck(W)
        println("Test $test: Time for matrix multiplication: ", time, " seconds")
        # println("Time for copying results: ", copy_time, " seconds") #negligible for sufficiently large matrices
    end

    # Compute the reference result with one large multithreaded BLAS gemm
    correct_matrix = zeros(Float64, p, p)

    # unpin threads: unpinning turned out to be necessary when using multithreaded BLAS
    ThreadPinning.unpinthreads()
    BLAS.set_num_threads(64)
    BLAS.gemm!('T', 'N', -1.0, reshape(W, Q*n_occ, p), reshape(W, Q*n_occ, p), 0.0, correct_matrix) # warm-up
    multithreaded_blas_time = @elapsed BLAS.gemm!('T', 'N', -1.0, reshape(W, Q*n_occ, p), reshape(W, Q*n_occ, p), 0.0, correct_matrix)
    println("Time for multithreaded BLAS gemm: ", multithreaded_blas_time, " seconds")

    # Reduce both matrices onto rank 0 and compare
    MPI.Reduce!(final_matrix, MPI.SUM, MPI.COMM_WORLD; root=0)
    MPI.Reduce!(correct_matrix, MPI.SUM, MPI.COMM_WORLD; root=0)
    MPI.Barrier(MPI.COMM_WORLD) # just to be sure
    if rank == 0
       # Check if the final matrix matches the correct matrix
        diff =  final_matrix - correct_matrix
        max_diff = maximum(abs.(diff))
        println("Maximum difference: ", max_diff)
    end
   
end


MPI.Init()

println("cpuids for rank $(MPI.Comm_rank(MPI.COMM_WORLD)): ", sort(ThreadPinning.getcpuids()), "\n")

main(1)
main(2)
main(3)

# run.sh
# #!/bin/bash
# module load julia
# srun -N 1 -n 2 --cpu-bind=verbose,sockets julia --project=$j_project_path --threads=64 './model_bottlenecks.jl'