Implementing parallel sum

In MATLAB, summing the entries of a vector in a for loop can be easily parallelized using parfor of Parallel Computing Toolbox. I remember @threads served a more or less similar purpose in Julia, but after reading Announcing composable multi-threaded parallelism in Julia, I felt @spawn is going to be the future standard for multithreading, so I decided to give it a try. I am using Julia 1.4.

Here are the contents of my parsum.jl file that implements a parallel sum and compares its performance with the built-in sum:

import Base.Threads: nthreads, @spawn

function parsum(v::Vector{Float64})
    Nt = nthreads()
    Lt = length(v) ÷ Nt

    # Calculate partial sums.
    s = Vector{Task}(undef, Nt-1)
    for j = 1:Nt-1  # last chunk is handled separately just in case length(v) is not multiple of Nt
        lo = (j-1) * Lt + 1
        hi = j * Lt
        s[j] = @spawn sum(@view(v[lo:hi]))
    end

    # Add partial sums.
    lo_last = (Nt-1) * Lt + 1
    stot = sum(@view(v[lo_last:end]))  # handle last chunk
    for j = 1:Nt-1
        stot += fetch(s[j])
    end

    return stot
end

# Check if parsum gives a correct result.
N = 1000000
v = rand(N)

sseq = sum(v); println("sequential sum = $sseq")
spar = parsum(v); println("parallel sum = $spar")
println("error between parallel and sequential = $(abs((spar-sseq) / sseq))")  # check if we get same result except for rounding errors
println()

# Compare the performance of the sequential and parallel sums.
using BenchmarkTools
print("sequential performance: "); @btime sum($v)
print("parallel performance: "); @btime parsum($v)  # faster!

Running this file from the shell gives the following result:

$ JULIA_NUM_THREADS=6 julia benchmark/parsum.jl
sequential sum = 499860.1362742374
parallel sum = 499860.1362742374
error between parallel and sequential = 0.0

sequential performance:   178.240 μs (0 allocations: 0 bytes)
parallel performance:   35.752 μs (57 allocations: 4.16 KiB)

So, I get a speedup of about a factor of 5 by using 6 threads, which is very nice. (The script was run on a 6-core machine.) I also see that the CPU usage soars to 600%, which means all 6 cores are used.

I am pretty happy with this result, but I am curious whether this is the best way to use @spawn for implementing a parallel sum, or if there is anything I am missing. Any comments or advice will be very much appreciated!

Very nice. How about a different number of threads? 2, 4? How does it scale?

Here is the performance gain plot. The ideal gain with N threads is a factor of N (blue line). The actual gain falls short of the ideal, but it keeps increasing with N until N reaches the number of physical cores (6 in my case). Using more threads than physical cores degrades the performance.
[Plot: performance gain versus number of threads]

Have you tried a pairwise recursive approach that simply spawns a lot of tasks and lets the runtime sort it out? Something like

function _psum(a, istart, iend)
    if iend < istart + 2048
        iend < istart && return zero(eltype(a))
        s = a[istart]
        for i = istart+1:iend
            @inbounds s += a[i]
        end
        return s
    else
        imid = (istart + iend) >> 1
        s1 = @spawn _psum(a, istart, imid)
        s2 = @spawn _psum(a, imid+1, iend)
        return fetch(s1) + fetch(s2)
    end
end
psum(a) = _psum(a, firstindex(a), lastindex(a))

where you replace the 2048 with some appropriate threshold (determined experimentally) to switch over to the serial sum.

Comparing this divide-and-conquer approach to your parsum’s “static” schedule of nthreads() equal-sized sums would be a good test of the scheduler. (The advantage of spawning many more tasks than threads is that it allows the runtime to load-balance if the loop iterations take unequal time or if you have other parallel tasks running.)

(It would be nice to have a threaded mapreduce function, but I don’t see an issue for this?)
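A threaded mapreduce along these lines could be sketched as follows. This is only an illustration: tmapreduce, _tmapreduce, and the basesize keyword are hypothetical names, not part of Base, and the default base-case size is a guess that would need to be tuned by benchmarking. The sketch assumes a non-empty vector and an associative op.

```julia
using Base.Threads: @spawn

# Hypothetical helper (not part of Base): divide-and-conquer mapreduce over a
# vector. `basesize` is the chunk length at or below which the work is serial.
function tmapreduce(f, op, a::AbstractVector; basesize::Int = 65_536)
    basesize >= 1 || throw(ArgumentError("basesize must be at least 1"))
    return _tmapreduce(f, op, a, firstindex(a), lastindex(a), basesize)
end

function _tmapreduce(f, op, a, lo, hi, basesize)
    if hi - lo < basesize
        # Serial base case; the view avoids copying the slice.
        return mapreduce(f, op, view(a, lo:hi))
    end
    mid = (lo + hi) >> 1
    left = @spawn _tmapreduce(f, op, a, lo, mid, basesize)  # left half as a new task
    right = _tmapreduce(f, op, a, mid + 1, hi, basesize)    # right half in this task
    return op(fetch(left), right)
end
```

For example, tmapreduce(identity, +, v; basesize = 100_000) plays the role of psum(v) with a larger serial threshold. Only one half is spawned per level, so each split creates one new task instead of two.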

Just FYI, Transducers.jl has a thread-based reduce and a process-based dreduce (so that you can do not only mapreduce in parallel but also filtering and flattening). The reduce implementation uses the divide-and-conquer approach to handle unequal processing time per element and also to support early termination. But my impression is that spawning tasks is still a bottleneck, so the base case size has to be close to length(input) / nthreads() for “lightweight” computations like sum. I haven’t done a systematic benchmark yet, though.

I tried a similar version a few hours ago and found that as @tkf says, spawning tasks is quite costly versus the underlying sum, so if you spawn too many it’ll dominate the computational cost.

Certainly you don’t want to recurse down to a small base case. The question is whether the cost of spawning is 100x more expensive than addition, or 1000x, or 1000000x, etcetera, as determined by the required size of the base case. (Base.mapreduce uses a base-case size of 1024 just to eliminate the overhead of recursion, so presumably a significantly larger base case is required for spawning.)

length(input) / nthreads() doesn’t sound right — the cost of spawning can’t depend on the size of the input. There has to be a fixed base-case size (for a given operation like sum) beyond which the cost of spawning is negligible.

I am assuming that (1) many operations users want to do with reduce do not need the auto-scheduling benefit of the divide-and-conquer approach, but (2) since users are calling a threaded reduce, the overall computation time is long enough. With those assumptions, I think length(input) / nthreads() is a sane default, at least for now. At a call site where you know the actual computation and input, so that the trade-off is clearer, I agree that it is definitely worth setting the base case size (there is a keyword argument for it).

I agree that not spawning too many more tasks than threads is a sensible default for unknown operations given that @spawn is fairly expensive. But I would tend to use at least something like length(input) ÷ 5nthreads() in a parallel mapreduce, so that you are spawning around 5nthreads() tasks at the leaves, in order to give the scheduler some leeway to load balance unequal costs.
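As a rough sketch of that idea for a plain sum, the vector can be cut into about 5nthreads() contiguous chunks with one task per chunk. The name chunked_sum and the tasks_per_thread keyword are made up for this illustration, and the factor 5 is just the figure suggested above, not a tuned value.

```julia
using Base.Threads: nthreads, @spawn

# Hypothetical variant of a chunked sum: split `v` into roughly
# `tasks_per_thread * nthreads()` contiguous chunks, one task per chunk.
function chunked_sum(v::AbstractVector; tasks_per_thread::Int = 5)
    tasks_per_thread >= 1 || throw(ArgumentError("tasks_per_thread must be at least 1"))
    isempty(v) && return zero(eltype(v))
    ntasks = min(length(v), tasks_per_thread * nthreads())
    chunklen = cld(length(v), ntasks)   # ceiling division, so no element is dropped
    tasks = map(firstindex(v):chunklen:lastindex(v)) do lo
        hi = min(lo + chunklen - 1, lastindex(v))
        @spawn sum(view(v, lo:hi))
    end
    return sum(fetch, tasks)            # add the partial sums
end
```

With tasks_per_thread = 1 this reduces to roughly the same static schedule as parsum; larger values give the scheduler more, smaller pieces to balance.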

Thanks for the suggestion. Decreasing base case size sounds like a good approach. I’ll tweak the parameter once I gather some practical use cases for the benchmark.

Also note that when the two sums differ, the parallel sum will be more nearly correct more often than not. It is correct (to Float64 accuracy) in the above example.
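The mechanism behind this can be made visible with a small experiment. The sketch below is an illustration under stated assumptions, not a proof: it uses Float32 data and a naive left-to-right loop so that the rounding error is large enough to see, and a Float64 accumulation as the reference. Base's sum already uses pairwise summation, so the difference between it and a chunked parallel sum is much smaller than what this naive loop shows. The function names are hypothetical.

```julia
# Naive left-to-right accumulation in the element type of `v`.
function naive_sum(v)
    s = zero(eltype(v))
    for x in v
        s += x
    end
    return s
end

# Same accumulation, but over `nchunks` contiguous chunks whose partial sums
# are added at the end (the serial analogue of a chunked parallel sum).
function chunked_naive_sum(v, nchunks)
    chunklen = cld(length(v), nchunks)
    partials = [naive_sum(view(v, lo:min(lo + chunklen - 1, length(v))))
                for lo in 1:chunklen:length(v)]
    return naive_sum(partials)
end

v = rand(Float32, 1_000_000)
reference = sum(Float64, v)   # accumulate in Float64 as a reference value
err_naive = abs(naive_sum(v) - reference)
err_chunked = abs(chunked_naive_sum(v, 8) - reference)
println("naive error   = ", err_naive)
println("chunked error = ", err_chunked)
```

Each chunk accumulates a smaller running total, so each addition tends to round less; that is why the chunked error is usually, though not always, the smaller of the two.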

I made a simple test case to find the threshold on my machine (i7 4720HQ). The threshold is around 35000 elements:

julia> x = rand(35_000);
julia> a = (@belapsed parsum($x))/length(x)
1.6750081380208332e-10 # 0.16 ns per element

julia> a = (@belapsed sum($x))/length(x)
1.8338216145833332e-10 # 0.18 ns per element

From theory alone, the cost of parsum can be modeled as follows:

  1. spawn nt tasks, which costs nt*t_thr
  2. perform an operation whose cost is linear in the number of elements (tf = ni*fi for the whole vector) on a subset of the elements, which costs tf/nt per task. Since the tasks run in parallel, only one of them counts
  3. the total time of parsum is therefore nt*t_thr + tf/nt
  4. parsum will be faster than sum if:
    nt*t_thr + tf/nt < tf
    at the break-even point the two sides are equal, so nt*t_thr = tf*(nt-1)/nt, and solving for tf:
    tf = (nt^2)/(nt-1)*t_thr

supposing optimal thread performance:

function ftr(n::Int)  # spawn and fetch n trivial tasks; total time / n = t_thr
    x = Vector{Task}(undef, n)
    for i = 1:n
        x[i] = @spawn 1
    end
    res = 0
    for i = 1:n
        res = fetch(x[i])  # wait for each task to finish
    end
    return res
end
(@belapsed ftr(nthreads()))*nthreads()/(nthreads()-1)
7.600285714285715e-6

So, in theory, sum(x) has to take more than 7.6 microseconds to be outperformed by parsum.
As a starting point, one can assume one microsecond to spawn a task and 0.2 nanoseconds to sum each element. The resulting threshold is 1*nt^2/(nt-1)/0.0002, or 5000*nt^2/(nt-1) elements, as a conservative limit. With the 8 threads I used, this gives 5000*64/7 ≈ 45700, which is the same order of magnitude as the roughly 35000 elements I measured.
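The break-even estimate can be wrapped in a small helper to see how it varies with the thread count. This is a sketch of the model above only: breakeven_length is a hypothetical name, and the default t_spawn and t_elem are the rough figures from this post (one microsecond per spawn, 0.2 nanoseconds per element), which will differ on other machines.

```julia
# Hypothetical helper: break-even vector length for a chunked parallel sum,
# from the model  nt*t_spawn + tf/nt < tf  with  tf = n * t_elem.
# Times are in seconds; the defaults are rough figures, not measurements.
function breakeven_length(nt::Integer; t_spawn = 1e-6, t_elem = 0.2e-9)
    nt > 1 || throw(ArgumentError("the model needs at least 2 threads"))
    tf = nt^2 / (nt - 1) * t_spawn   # serial time at the break-even point
    return round(Int, tf / t_elem)   # corresponding number of elements
end

for nt in (2, 4, 6, 8)
    println("nt = ", nt, ": about ", breakeven_length(nt), " elements")
end
```

The factor nt^2/(nt-1) grows roughly like nt, so under this model the vector length needed to benefit from threading increases with the number of threads.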
