In my work I often have to perform operations on sparse matrices that mostly involve checking which entries are still important and removing those that are not. I thought that I could use multithreading for that, but my attempts didn't produce a noticeable improvement. Therefore I am looking for a way to improve the performance of this use case.
To give a simple example, imagine you have a sparse N x N matrix C and two vectors a and b of length N, and you want to make a new matrix with the entries of C that fulfill C[i,j] ≥ a[i] + b[j]. My serial implementation looks like this:
function keep_entries(C, a, b)
    n, m = size(C)
    I = Int[]
    J = Int[]
    V = Float64[]
    rows = rowvals(C) 
    Cval = nonzeros(C)
    for j in 1:m
        for r in nzrange(C, j)
            i = rows[r]
            if Cval[r] ≥ a[i] + b[j]
                push!(I, i)
                push!(J, j)
                push!(V, Cval[r])
            end
        end
    end
    K = sparse(I, J, V, n, m)
    return K
end
(I am using this answer to iterate over the columns of C).
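For reference, the same filter can also be written compactly with `findnz` from SparseArrays. This is only a sketch meant as a correctness baseline for the loop version, not a performance claim; the function name `keep_entries_findnz` and the small test matrix are made up for illustration.

```julia
using SparseArrays

# Hypothetical reference implementation: keep C[i,j] only where C[i,j] ≥ a[i] + b[j].
function keep_entries_findnz(C, a, b)
    I, J, V = findnz(C)              # row indices, column indices, stored values
    keep = V .>= a[I] .+ b[J]        # boolean mask over the stored entries
    return sparse(I[keep], J[keep], V[keep], size(C)...)
end

# Small example: entries (1,1)=0.9, (2,1)=0.1, (2,2)=0.5
C = sparse([1, 2, 2], [1, 1, 2], [0.9, 0.1, 0.5], 2, 2)
a = [0.2, 0.3]
b = [0.1, 0.1]
K = keep_entries_findnz(C, a, b)
println(nnz(K))   # 2: (2,1) is dropped because 0.1 < 0.3 + 0.1
```

Comparing the output of this sketch with the loop-based function on a small random matrix is a quick way to confirm that both apply the same criterion.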
My approach using Threads.@threads is the following: divide the columns of C into as many chunks as there are available threads, let each thread work on one chunk of columns, and concatenate the results. The code looks like this:
function keep_entries_threaded(C, a, b)
    m, n = size(C)
    
    Nchunks = Threads.nthreads()
    Is = [Int[] for i in 1:Nchunks]
    Js = [Int[] for i in 1:Nchunks]
    Vs = [Float64[] for i in 1:Nchunks]
    
    chunk_size = n ÷ Nchunks
    j_chunks = [(k-1)*chunk_size+1:k*chunk_size for k in 1:Nchunks]
    j_chunks[end] = (Nchunks-1)*chunk_size+1:n # Adjust last chunk
    rows = rowvals(C) 
    Cval = nonzeros(C)
    Threads.@threads for k in 1:Nchunks
        for j in j_chunks[k]
            for r in nzrange(C, j)
                i = rows[r]
                if Cval[r] ≥ a[i] + b[j]
                    push!(Is[k], i)
                    push!(Js[k], j)
                    push!(Vs[k], Cval[r])
                end
            end
        end
    end
    I = vcat(Is...)
    J = vcat(Js...)
    V = vcat(Vs...)
    return sparse(I, J, V, m, n)
end
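To make the column partition concrete, here is the same chunking arithmetic applied to a small case. This is a minimal sketch; `column_chunks` is a hypothetical helper name that only reproduces the index computation from the function above.

```julia
# Hypothetical helper reproducing the chunking arithmetic of keep_entries_threaded.
function column_chunks(n::Integer, nchunks::Integer)
    chunk_size = n ÷ nchunks
    chunks = [(k-1)*chunk_size+1:k*chunk_size for k in 1:nchunks]
    chunks[end] = (nchunks-1)*chunk_size+1:n   # last chunk absorbs the remainder
    return chunks
end

chunks = column_chunks(10, 4)
println(chunks)            # four ranges: 1:2, 3:4, 5:6, 7:10
println(length.(chunks))   # [2, 2, 2, 4]
```

The last chunk can receive up to `nchunks - 1` extra columns. For N = 100000 and 4 threads the columns divide evenly (25000 per chunk), so this imbalance does not affect the benchmark below.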
Starting Julia with 4 threads, I get the following results:
julia> N = 100000;
julia> C = sprand(N, N, 20/N); # On average 20 non-zero entries per column
julia> a = 0.5.*rand(N); b = 0.5.*rand(N); # On average 50% of entries satisfy `C[i,j] ≥ a[i] + b[j]`, 50% do not
julia> @benchmark keep_entries($C, $a, $b)
BenchmarkTools.Trial: 54 samples with 1 evaluation.
 Range (min … max):  86.917 ms … 107.341 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     91.397 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   92.826 ms ±   4.475 ms  ┊ GC (mean ± σ):  0.37% ± 0.97%
  86.9 ms         Histogram: frequency by time          104 ms <
 Memory estimate: 59.77 MiB, allocs estimate: 77.
julia> @benchmark keep_entries_threaded($C, $a, $b)
BenchmarkTools.Trial: 64 samples with 1 evaluation.
 Range (min … max):  68.283 ms … 91.087 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     78.294 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   78.989 ms ±  4.517 ms  ┊ GC (mean ± σ):  0.16% ± 0.32%
  68.3 ms         Histogram: frequency by time        89.2 ms <
 Memory estimate: 91.64 MiB, allocs estimate: 270.
So the improvement is very small: the median time drops from 91.4 ms to 78.3 ms, roughly 14%, with 4 threads. Raising the number of threads further has only a marginal effect on performance. Besides, the CPU activity looks like this:
which shows that the multi-threaded code is not taking advantage of all the cores' power.
Some things I have tried that didn't improve the performance:
- Checking the running time of vcat. For these vectors it runs in about 4 ms, so it doesn't seem to be the bottleneck.
- Checking vcat for type instability. I have tried annotating the output of vcat with the correct type, and replacing vcat with a self-made function (type stable and with approximately the same running time), but the results are basically identical.
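A check along these lines might look like the sketch below. The helper name `concat_chunks` and the synthetic chunk sizes are assumptions for illustration, not the exact function or data from the experiments described above.

```julia
# Hypothetical type-stable replacement for vcat(chunks...):
# preallocate the output once and copy each chunk into place.
function concat_chunks(chunks::Vector{Vector{T}}) where {T}
    out = Vector{T}(undef, sum(length, chunks; init=0))
    pos = 1
    for c in chunks
        copyto!(out, pos, c, 1, length(c))
        pos += length(c)
    end
    return out
end

# Synthetic per-thread buffers (sizes chosen arbitrarily for the sketch).
chunks = [rand(Int, 250_000) for _ in 1:4]

# Warm up both functions so compilation is not included in the timings.
vcat(chunks...); concat_chunks(chunks)

t_vcat = @elapsed vcat(chunks...)
t_own  = @elapsed concat_chunks(chunks)
println("vcat: ", t_vcat, " s, concat_chunks: ", t_own, " s")
```

For more reliable numbers, the two calls could be measured with `@benchmark` from BenchmarkTools instead of a single `@elapsed`.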
Any ideas/help making this run like a proper parallel algorithm are more than welcome! 
