If you use multi-threading, you can simply combine an atomic counter with the reduction infrastructure in JuliaFolds. In particular, FLoops’ support for `break` in parallel reductions is useful here:
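```julia
using FLoops
using BangBang

function sample_somethings(f, nitems)
    len = Threads.Atomic{Int}(0)  # number of accepted items, shared by all threads
    @floop for x in 1:typemax(Int)
        # Stop once more than `nitems` items have been accepted; other threads
        # may still add a few more before the loop terminates.
        if len[] > nitems
            break
        end
        y = f()
        if y !== nothing
            Threads.atomic_add!(len, 1)
            items = (something(y),)
            # Start from an empty `Union{}[]` and let `append!!` widen the
            # element type to whatever `f` produces.
            @reduce() do (results = Union{}[]; items)
                results = append!!(results, items)
            end
        end
    end
    return results
end
```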
```julia
julia> sample_somethings(() -> rand() > 0.5 ? 1 : nothing, 100)
101-element Vector{Int64}:
 1
 1
 1
 1
 ⋮
 1
 1
 1
```
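For comparison, here is a minimal sketch of the same idea that uses only `Base` threading primitives, without FLoops or BangBang. `sample_somethings_tasks` and its `ntasks` keyword are hypothetical names chosen for this illustration. Each task keeps its own buffer, the shared `Threads.Atomic{Int}` counter decides when to stop, and the "reduction" is a final `vcat` of the buffers. As in the FLoops version, tasks can race past the limit, so the result has at least `nitems` elements and possibly a few more.

```julia
using Base.Threads: @spawn, Atomic, atomic_add!, nthreads

# Hypothetical Base-only variant of `sample_somethings` (no FLoops/BangBang).
function sample_somethings_tasks(f, nitems; ntasks = nthreads())
    len = Atomic{Int}(0)  # items accepted so far, shared by all tasks
    tasks = map(1:ntasks) do _
        @spawn begin
            buffer = Any[]  # per-task buffer: no lock needed for push!
            # Tasks can race past the limit, so the total may exceed `nitems`.
            while len[] < nitems
                y = f()
                if y !== nothing
                    atomic_add!(len, 1)
                    push!(buffer, something(y))
                end
            end
            buffer
        end
    end
    # "Reduction" step: concatenate the per-task buffers (assumes ntasks >= 1),
    # then broadcast `identity` to narrow the element type (e.g. to Int).
    return identity.(reduce(vcat, map(fetch, tasks)))
end
```

Per-task buffers avoid taking a lock on every `push!`. The price is an untyped `Any[]` buffer that has to be narrowed at the end, which FLoops' `@reduce` combined with `append!!` handles more elegantly.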
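Unfortunately, this strategy is rather tricky to carry over to a distributed setting. Perhaps the easiest way to implement it there is to use a `RemoteChannel`. Each worker repeatedly evaluates `f` in batches and `put!`s the batches into the channel. The calling process `take!`s batches until it has enough items and then closes the channel, so the workers' next `put!` fails and they stop: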
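```julia
using Distributed
using BangBang

function distributed_sample_somethings(f, nitems, batchsize = 20)
    chan = RemoteChannel()  # owned by the calling process
    # Keep `return` outside `@sync`: a `return` inside its block would exit
    # the function before `@sync` waits for the spawned tasks.
    return @sync try
        for w in workers()
            Distributed.@spawnat w try
                while true
                    # One message carries `batchsize` evaluations of `f`.
                    ys = map(1:batchsize) do _
                        f()
                    end
                    try
                        put!(chan, ys)
                    catch
                        break  # `chan` was closed: enough items collected
                    end
                end
            finally
                close(chan)
            end
        end
        results = Union{}[]
        while true
            ys = take!(chan)::Vector
            results = append!!(results, Iterators.filter(!isnothing, ys))
            if length(results) > nitems
                close(chan)
                break
            end
        end
        results  # value of the `try` block, returned once `@sync` is done
    finally
        close(chan)  # unblocks workers waiting in `put!` so `@sync` can finish
    end
end
```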
```julia
julia> distributed_sample_somethings(() -> rand() > 0.5 ? 1 : nothing, 100)
108-element Vector{Int64}:
 1
 1
 1
 1
 ⋮
 1
 1
 1
```
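Here, `batchsize` is an extra parameter that trades communication cost against possibly wasted computation. Larger batches mean fewer `put!`/`take!` round trips, but workers may compute more items than needed, and the result can overshoot `nitems` by more (108 above versus 101 in the threaded version).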
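To experiment with this producer/consumer protocol without starting any workers, here is a single-process sketch. A local `Channel` stands in for the `RemoteChannel`, and tasks started with `Threads.@spawn` stand in for the workers. `channel_sample_somethings` and its `nproducers` keyword are hypothetical. The stopping rule is the same: the consumer closes the channel once it has enough items, and each producer's next `put!` throws, so every producer stops after at most one more batch.

```julia
using Base.Threads: @spawn, nthreads

# Hypothetical single-process analogue of `distributed_sample_somethings`:
# a local Channel replaces the RemoteChannel, spawned tasks replace workers.
function channel_sample_somethings(f, nitems, batchsize = 20; nproducers = nthreads())
    chan = Channel{Vector{Any}}(1)  # small buffer: producers wait for the consumer
    producers = map(1:nproducers) do _
        @spawn while true
            ys = Any[f() for _ in 1:batchsize]  # one batch = one message
            try
                put!(chan, ys)
            catch
                break  # consumer closed `chan`: stop producing
            end
        end
    end
    results = Any[]
    try
        while length(results) < nitems
            ys = take!(chan)
            append!(results, Iterators.filter(!isnothing, ys))
        end
    finally
        close(chan)  # wakes producers blocked in put! with an exception
    end
    foreach(wait, producers)  # each producer exits after at most one more batch
    return results
end
```

With `f = () -> 1` and `batchsize = 20`, the consumer stops after exactly five batches, so `channel_sample_somethings(() -> 1, 100)` returns 100 items. Any batches the producers computed after that are discarded. That discarded work is the cost a larger `batchsize` adds.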
As a side note, I’d point out that we are forced to use an ugly catch-all pattern that can hide potential bugs:
```julia
try
    put!(chan, ys)
catch
    break
end
```
This is probably impossible to avoid using only the public API. It’d be nice if we could have something like `maybetake!` from https://github.com/JuliaLang/julia/pull/41966.
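For a local `Channel`, the catch can be narrowed using only public API. `put!` on a channel closed with plain `close(ch)` throws an `InvalidStateException`, so any other error can be rethrown. `tryput!` below is a hypothetical helper, not part of `Base`. As far as I know, this does not carry over directly to a `RemoteChannel` used from a worker, because there the owner's exception arrives wrapped in a `RemoteException`. That is exactly the case the catch-all above has to handle.

```julia
# Hypothetical helper (not a Base API): try to put! into a local Channel and
# report whether it succeeded, catching only the "channel is closed" error.
function tryput!(ch::Channel, x)
    try
        put!(ch, x)
        return true
    catch err
        # put! on a Channel closed with `close(ch)` throws InvalidStateException.
        # Anything else, such as an exception passed to `close(ch, excp)`, is rethrown.
        err isa InvalidStateException || rethrow()
        return false
    end
end
```

Checking `isopen(ch)` before calling `put!` would not replace this, because another task can close the channel between the check and the `put!`.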