Multithreading with shared memory caches

Thanks for this! I wasn’t aware of Iterators.partition.
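For anyone else landing here: it splits an iterator into consecutive chunks of a given length, with the last chunk picking up the remainder:

collect(Iterators.partition(1:10, 4))  # returns [1:4, 5:8, 9:10]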

I suppose that if you want the load-balancing benefit of multiple chunks per thread without having to spawn more tasks than there are threads, you could combine our approaches:

chunks_per_thread = 2

# cld (ceiling division) keeps the chunk size at least 1; with ÷, a small
# array (length(A) < nthreads * chunks_per_thread) would give a chunk size
# of 0 and make Iterators.partition throw.
chunks = Iterators.partition(eachindex(A), cld(length(A), Threads.nthreads() * chunks_per_thread))

# Load all chunks into an unbounded channel up front, then close it so the
# worker loops below terminate once the channel is drained.
ch = Channel{eltype(chunks)}(Inf)
foreach(chunk -> put!(ch, chunk), chunks)
close(ch)

# One task per thread; each task owns its own cache and repeatedly takes
# the next available chunk from the channel, which gives cheap dynamic
# load balancing.
tasks = map(1:Threads.nthreads()) do _
    Threads.@spawn begin
        cache = zeros(3)
        for chunk in ch
            for i in chunk
                f!(cache, A[i])
                B[i] = g(cache, A[i])
            end
        end
        return cache
    end
end

caches = fetch.(tasks)
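The snippet assumes the A, B, f!, and g from earlier in the thread. For a self-contained test, one could define hypothetical stand-ins beforehand, e.g.:

# Hypothetical stand-ins so the snippet above runs end to end;
# the real f! and g are whatever your application needs.
A = rand(10_000)
B = similar(A)
f!(cache, x) = (cache[1] = x; cache[2] = x^2; cache[3] = x^3; nothing)  # mutates the per-task cache
g(cache, x) = sum(cache) * x  # reads the cache, produces one output element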

For long-running, compute-bound work, it seems plausible that the overhead of the occasional lock when taking the next chunk from the channel would be smaller than the overhead of tasks contending for threads when there are more tasks than threads. Of course, one should profile to know for sure.
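A quick way to settle it would be to wrap each strategy in a function and compare with BenchmarkTools (the two function names below are hypothetical placeholders, one for the snippet above and one for the spawn-per-chunk version):

using BenchmarkTools

# channel_based! = the channel snippet above, wrapped in a function;
# spawn_per_chunk! = the one-task-per-chunk approach.
# Both names are placeholders, not real library functions.
@btime channel_based!($B, $A)
@btime spawn_per_chunk!($B, $A)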
