I am trying to write a good thread-parallel implementation of a problem that can be described quickly like this:
The user has to do n tasks, each of which requires a workspace.
They can provide any number of workspaces they want—say they provide m of them.
Finally, they might be running it with any number of threads, say l, and it is possible that l > m.
What’s the correct way to utilize multi-threading to the highest degree possible while remaining aware of the limit of m workspaces?
Here’s a somewhat motivating example: say that you wanted the log-determinant of some number of giant matrices. You don’t have enough RAM to create all of them at once, but perhaps you have enough RAM for three or four of the matrices. But other parts of your program benefit from using many more threads than that, so you’re using more than three or four threads. Here’s an MWE that works when l ≤ m:
using LinearAlgebra, StableRNGs
BLAS.set_num_threads(1)
# Threads.Atomic doesn't seem to be designed for this type.
const workspaces = [zeros(100,100) for _ in 1:3]

function dostuff(n, spaces)
    logdets = zeros(n)
    Threads.@threads for i in 1:n
        # Choose some _available_ workspace. If length(spaces) is at least
        # Threads.nthreads() this works, but if not then obviously this fails,
        # which makes me think this isn't the right way to do this.
        w = spaces[Threads.threadid()]
        s = size(w, 1)
        # Fill the buffer with something:
        x = randn(StableRNG(i), s)
        for k in 1:s
            xk = x[k]
            @simd for j in 1:s
                xj = x[j]
                @inbounds w[j,k] = exp(-abs(xj - xk))
            end
        end
        # Put the computed value in your output array:
        w_fact = cholesky!(w)
        logdets[i] = logdet(w_fact)
    end
    logdets
end

dostuff(30, workspaces)
But the fact that this breaks for l > m makes me think that this is not the correct way to write something like this. I’m aware of the Threads.Atomic structure, but I gather that a collection of matrices is not really the intended use case for that object.
Can somebody more knowledgeable than me provide any guidance? Thanks for reading!
Maybe this is useful. What I do is use the ThreadPools package to control exactly which thread each computation goes to, and to control manually how many threads are in use at any time. This is useful for trivially parallelizable tasks, which yours seems to be. One detail is that one needs to submit at most Threads.nthreads() - 1 tasks, leaving thread 1 free for the main task to manage the execution of the others. Otherwise, if one of the long calculations lands on thread 1, the submission of new calculations to the other threads will stall.
using ThreadPools  # provides @tspawnat

nspawn = 4  # number of threads that will be used (typically Threads.nthreads() - 1)
itask = 0
ndone = 0
free = ones(Bool, nspawn)        # stores whether each slot is free
t = Vector{Task}(undef, nspawn)  # manages the tasks
while ndone < ntasks  # ntasks: total number of calculations (defined elsewhere)
    # Launch one computation on each free thread
    while itask < ntasks && count(free) > 0
        ifree = findfirst(free)
        # Submit the next task
        itask += 1
        # ifree + 1 avoids using thread number 1
        t[ifree] = ThreadPools.@tspawnat ifree + 1 run_the_calculation()
        free[ifree] = false
    end
    # Wait a little bit before checking
    sleep(0.1)
    # Check task status
    for ispawn in 1:nspawn
        if !free[ispawn]
            if istaskfailed(t[ispawn])
                error("Computation failed in thread: $ispawn", fetch(t[ispawn]))
            end
            if istaskdone(t[ispawn])
                ndone += 1
                free[ispawn] = true
                # fetch the data for this run...
                # Run garbage collection if memory is low (I needed this in my case)
                if options.GC && (Sys.free_memory() / Sys.total_memory() < 0.1)
                    GC.gc()
                end
            end
        end
    end
end
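Since run_the_calculation, ntasks, and options in the snippet above are placeholders from my own code, here is a hypothetical, self-contained sketch of the same polling pattern using only Base threading (Threads.@spawn in place of @tspawnat, and a dummy_calculation stand-in for the real work), so the control flow can be run as-is:

```julia
# Stand-in for the real long-running calculation:
dummy_calculation(i) = sum(abs2, 1:i)

function run_pool(ntasks, nspawn)
    results = zeros(ntasks)
    itask = 0
    ndone = 0
    free = ones(Bool, nspawn)          # whether each slot is free
    t = Vector{Task}(undef, nspawn)    # the task running in each slot
    id = zeros(Int, nspawn)            # which task index each slot is running
    while ndone < ntasks
        # Fill every free slot with the next pending task
        while itask < ntasks && count(free) > 0
            ifree = findfirst(free)
            itask += 1
            id[ifree] = itask
            t[ifree] = Threads.@spawn dummy_calculation($itask)
            free[ifree] = false
        end
        sleep(0.01)  # wait a little before polling
        # Collect results from finished slots and mark them free again
        for is in 1:nspawn
            if !free[is] && istaskdone(t[is])
                istaskfailed(t[is]) && error("computation $(id[is]) failed")
                results[id[is]] = fetch(t[is])
                ndone += 1
                free[is] = true
            end
        end
    end
    results
end

run_pool(10, 3)
```

The extra id vector is just bookkeeping so each slot knows where to store its result; the polling structure is the same as above.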
You can create one worker (one Task) for each workspace, where each worker receives inputs from a Channel:
function eachstuff(input, w)
    s = size(w, 1)
    # Fill the buffer with something:
    x = randn(StableRNG(input), s)
    for k in 1:s
        xk = x[k]
        @simd for j in 1:s
            xj = x[j]
            @inbounds w[j,k] = exp(-abs(xj - xk))
        end
    end
    # Put the computed value in your output array:
    w_fact = cholesky!(w)
    logdet(w_fact)
end

function dostuff(n, spaces)
    outputs = zeros(n)
    chn = Channel{Int}(n) do chn
        for input = 1:n
            put!(chn, input)
        end
    end
    workers = collect(Threads.@spawn(for input = chn
        outputs[input] = eachstuff(input, spaces[$k])
    end) for k = eachindex(spaces))
    foreach(wait, workers)
    outputs
end
If you wish, you can create only min(length(spaces), Threads.nthreads()) workers to guarantee that you do not create more workers than can run in parallel.
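As a concrete illustration of that suggestion, here is a sketch only: compute_one is a trivial hypothetical stand-in for eachstuff, and the names dostuff_capped and nw are mine.

```julia
# Sketch: cap the worker count at min(length(spaces), Threads.nthreads()).
# `compute_one` is a trivial stand-in for the real per-task work.
function compute_one(input, w)
    fill!(w, 1.0 / input)  # pretend to use the workspace
    sum(w)
end

function dostuff_capped(n, spaces)
    outputs = zeros(n)
    chn = Channel{Int}(n) do chn
        foreach(i -> put!(chn, i), 1:n)
    end
    # Never spawn more workers than can actually run in parallel:
    nw = min(length(spaces), Threads.nthreads())
    workers = [Threads.@spawn(for input = chn
        outputs[input] = compute_one(input, spaces[$k])
    end) for k in 1:nw]
    foreach(wait, workers)
    outputs
end

spaces = [zeros(50, 50) for _ in 1:3]
dostuff_capped(10, spaces)  # outputs[i] ≈ 2500 / i
```

Each worker still owns exactly one workspace for its whole lifetime, so no locking is needed; the cap only avoids idle workers when there are more workspaces than threads.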
Hey @lmiq and @lucas711642—thank you so much for your responses! They are both incredibly helpful. These examples and the docs for some keywords that I didn’t know to be looking for—like Task—are all wonderful. I wish I could mark them both as solutions. Thanks again!