I want to parallelize a for loop across multiple threads. Each thread will update some shared memory, so I need to avoid concurrent access to that memory.
One way is to simply give each thread a copy of the shared memory, and aggregate at the end. What tools are there in Julia to make this very simple?
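The copy-per-task approach needs nothing beyond Base. Below is a minimal sketch under some assumptions: `histogram_private` is a hypothetical function, and the "shared memory" is a count array indexed by the values in `data`. Each spawned task fills its own private copy, and the copies are summed at the end, so no locks are needed.

```julia
using Base.Threads

# Hypothetical example: count how often each value in 1:nbins occurs in `data`.
# Each task gets a private copy of the "shared" array, so no locking is needed.
function histogram_private(data::Vector{Int}, nbins::Int; ntasks::Int = nthreads())
    # Split the indices into contiguous chunks, roughly one per task.
    chunk_len = max(1, cld(length(data), ntasks))
    tasks = map(Iterators.partition(eachindex(data), chunk_len)) do idxs
        @spawn begin
            local_counts = zeros(Int, nbins)   # private copy for this task
            for i in idxs
                local_counts[data[i]] += 1
            end
            local_counts                       # handed back to the caller via fetch
        end
    end
    # Aggregate the private copies at the end.
    return reduce(+, fetch.(tasks); init = zeros(Int, nbins))
end
```

For example, `histogram_private([1, 2, 2, 3, 3, 3], 3)` returns `[1, 2, 3]` regardless of how many threads Julia was started with. The memory cost is one copy per task, which is what motivates the pool idea below when the array is large.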
If the shared memory is quite large, access to it is a relatively small proportion of the overall computation, and I have N threads, then a better approach might be to keep n copies of the shared memory and a pool of n write locks, where n < N. Each thread requests a write lock when it needs to access the shared memory and receives the first copy that is currently available. The copies still need to be aggregated at the end. Is such a mechanism currently available in Julia?
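A mechanism like this can be hand-rolled from Base primitives. The following is only a sketch: `BufferPool` and `with_buffer` are hypothetical names, not an existing Julia API. Each of the n buffers is guarded by its own `ReentrantLock`; a task scans for the first lock it can take with `trylock`, works on that buffer, and releases the lock. If all n buffers are busy it calls `yield()` and retries, which is a simple busy-wait.

```julia
using Base.Threads

# Hypothetical pool: n buffers, each guarded by its own ReentrantLock.
struct BufferPool{T}
    buffers::Vector{T}
    locks::Vector{ReentrantLock}
end
BufferPool(buffers::Vector{T}) where {T} =
    BufferPool{T}(buffers, [ReentrantLock() for _ in buffers])

# Take the first buffer whose lock is free, call f on it while holding
# that lock, then release the lock. Retries (yielding) if all are busy.
function with_buffer(f, pool::BufferPool)
    while true
        for (buf, lk) in zip(pool.buffers, pool.locks)
            if trylock(lk)
                try
                    return f(buf)
                finally
                    unlock(lk)   # runs even though we return from inside try
                end
            end
        end
        yield()   # every buffer is busy: let other tasks run, then retry
    end
end

# Usage: n = 2 copies shared by all iterations of a threaded loop.
pool = BufferPool([zeros(4) for _ in 1:2])
@threads for i in 1:100
    with_buffer(pool) do buf
        buf[mod1(i, 4)] += 1.0
    end
end
total = sum(pool.buffers)   # aggregate the n copies at the end
```

Here `total` is `[25.0, 25.0, 25.0, 25.0]`. The spinning retry is the weak point of this design; a `Channel` used as a pool blocks instead of spinning when all buffers are taken.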
There are various techniques available in OhMyThreads.jl, though I don’t know if they fit your use case. You can, e.g., set up task-local memory quite easily.
I’m not aware of anything premade for your second suggestion (keeping a pool around), but there are locks (`ReentrantLock`, `@lock`), semaphores (`Base.Semaphore`, `Base.acquire`, `Base.release`) and atomic operations (`Threads.Atomic`, `@atomic`) for handling low-level synchronization. It’s also possible to handle the synchronization via a `Channel`:
```julia
mem = [fill(0.0, 1000) for _ in 1:10]  # memory buffers
store = Channel{eltype(mem)}(length(mem))
# fill channel
for m in mem
    put!(store, m)
end
```
Then, when you need a memory buffer, just take! it, and put it back when finished.
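```julia
mymem = take!(store)     # blocks until a buffer is free
try
    # ... do stuff to mymem ...
finally
    put!(store, mymem)   # always return the buffer, even on error
end
```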
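Put together with a threaded loop and the final aggregation, the `Channel` approach might look like the sketch below. `channel_pool_demo` and the particular update it performs are hypothetical stand-ins for the real computation; the buffer count `nbuf` plays the role of n in the question.

```julia
using Base.Threads

# Sketch: nbuf buffers shared by all loop iterations through a Channel pool.
function channel_pool_demo(niter::Int; nbuf::Int = 2, len::Int = 4)
    mem = [zeros(len) for _ in 1:nbuf]          # the n copies
    store = Channel{eltype(mem)}(length(mem))
    foreach(m -> put!(store, m), mem)           # fill the pool
    @threads for i in 1:niter
        mymem = take!(store)                    # blocks until a buffer is free
        try
            mymem[mod1(i, len)] += 1.0          # stand-in for the real update
        finally
            put!(store, mymem)                  # always return the buffer
        end
    end
    return sum(mem)                             # aggregate the copies
end
```

`channel_pool_demo(100)` returns `[25.0, 25.0, 25.0, 25.0]`. Because `take!` blocks when the channel is empty, tasks simply wait for a free buffer instead of spinning.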
OhMyThreads.jl is the most versatile way to go. The simplest, from the point of view of basic Julia usage, is I think to use ChunkSplitters.jl directly:
```julia
julia> using ChunkSplitters, Base.Threads

julia> function parallel_sum(x; ntasks=Threads.nthreads())
           partial_sum = zeros(ntasks)
           @threads for (ichunk, chunk) in enumerate(chunks(x; n=ntasks))
               for val in chunk
                   partial_sum[ichunk] += val
               end
           end
           return sum(partial_sum)
       end
parallel_sum (generic function with 1 method)

julia> x = rand(1000);

julia> sum(x)
489.82387373881164

julia> parallel_sum(x)
489.8238737388118
```

The two results differ only in the last digits because floating-point addition is not associative, so summing in chunks rounds slightly differently than a sequential sum.
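A possible refinement, sketched here without any packages: writing to `partial_sum[ichunk]` on every iteration means different tasks repeatedly write to adjacent elements of the same array, which can cause false sharing. Accumulating into a task-local variable and only combining at the end avoids that. `parallel_sum_local` is a hypothetical name, and `Iterators.partition` stands in for ChunkSplitters here.

```julia
using Base.Threads

# Package-free variant: split the indices with Iterators.partition and
# accumulate in a task-local variable instead of a shared array.
function parallel_sum_local(x::AbstractVector; ntasks::Int = nthreads())
    chunk_len = max(1, cld(length(x), ntasks))
    tasks = map(Iterators.partition(eachindex(x), chunk_len)) do idxs
        @spawn begin
            s = zero(eltype(x))   # local accumulator, no shared writes
            for i in idxs
                s += x[i]
            end
            s
        end
    end
    # Combine the per-task results; init handles an empty input.
    return sum(fetch, tasks; init = zero(eltype(x)))
end
```

As with `parallel_sum`, the result can differ from `sum(x)` in the last digits for general floating-point input.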