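`rand()` is not thread-safe, at least in Julia versions before 1.3, where all threads share a single global RNG and concurrent calls can corrupt its state. You can solve that by creating a separate RNG per thread. See the sample code below.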
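Also note that you should use `@btime` from BenchmarkTools instead of `@time` if you want reasonably accurate timings. `@time` runs the expression only once, and on a first call the reported time also includes compilation.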
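A minimal sketch of the difference, assuming BenchmarkTools is installed (`] add BenchmarkTools`); the seed `1234` is arbitrary. The `$` interpolates the global `rng` so that looking up a non-constant global is not part of the measurement.

```julia
using BenchmarkTools  # third-party package, not a stdlib
using Random

rng = MersenneTwister(1234)  # arbitrary seed, only for reproducibility

# @time runs the expression once; on a first call the reported time
# also includes compilation.
@time rand(rng, 5)

# @btime runs the expression many times and prints the minimum time.
# `$rng` interpolates the non-constant global so accessing it is not timed.
# Like @time, it returns the value of the expression.
x = @btime rand($rng, 5)
```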
I also got rid of the `Distributed` and `@everywhere` stuff, since it's not needed for multi-threading. And I'll repeat my earlier question: are you looking for multi-threading (a single computer with multiple cores) or distributed computing (a cluster of several machines)? What does your cluster look like? The code below does multi-threading only.
Finally, note that since your function `main` is so short in this example, the per-iteration overhead of the threaded loop is large compared to the actual work done. So you can't expect a speedup that is linear in the number of threads. If the work per iteration really is that small, there are ways around this. A simple one is to process elements in batches rather than one at a time.
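A minimal sketch of the batching idea, assuming Julia 1.3 or later for `Threads.@spawn`: split `1:nT` into a few contiguous chunks, spawn one task per chunk, and give each task its own RNG. Because every task owns its RNG, no lookup by `Threads.threadid()` is needed. `batched` and `nchunks` are hypothetical names made up for illustration; the per-element work is the same as `main` in the sample code.

```julia
using Random

# Same per-element work as `main` in the sample code
function main(rng)
    roll = rand(rng)
    pars = rand(rng, 5)
    status = roll < 0.5 ? "cure" : "progression"
    return (roll, status, pars)
end

# Hypothetical batched driver: one task per contiguous chunk of 1:nT,
# each task with its own RNG, so no RNG is ever shared between tasks.
function batched(f, nT; nchunks = Threads.nthreads())
    results = Vector{Tuple{Float64,String,Vector{Float64}}}(undef, nT)
    chunksize = max(1, cld(nT, nchunks))  # elements per batch
    tasks = map(Iterators.partition(1:nT, chunksize)) do chunk
        Threads.@spawn begin
            rng = MersenneTwister()  # fresh, independently seeded RNG per batch
            for n in chunk
                results[n] = f(rng)
            end
        end
    end
    foreach(wait, tasks)  # block until every batch has finished
    return results
end

res = batched(main, 10^5)
```

With `nchunks = Threads.nthreads()` each thread gets roughly one batch. Raising `nchunks` gives the scheduler more room to balance uneven batches, at the cost of a few more tasks.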
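```julia
using Random  # MersenneTwister lives in the Random stdlib

# One RNG per thread, so threads never share (and corrupt) an RNG's state
const RNG = [MersenneTwister() for _ in 1:Threads.nthreads()]

# Toy simulation step: one random roll plus 5 random parameters
function main(rng)
    roll = rand(rng)
    pars = rand(rng, 5)
    status = roll < 0.5 ? "cure" : "progression"
    return (roll, status, pars)
end

function serial(f, nT)
    results = Vector{Tuple{Float64,String,Vector{Float64}}}(undef, nT)
    rng = MersenneTwister()
    for n in 1:nT
        results[n] = f(rng)
    end
    return results
end

function threaded(f, nT)
    results = Vector{Tuple{Float64,String,Vector{Float64}}}(undef, nT)
    # :static (Julia 1.5+) keeps each iteration on one thread, so indexing
    # RNG by threadid() is safe; with dynamic scheduling a task could
    # migrate between threads mid-loop
    Threads.@threads :static for n in 1:nT
        rng = RNG[Threads.threadid()]
        results[n] = f(rng)
    end
    return results
end

nT = 10^5
@time serialresults = serial(main, nT)
@time threadedresults = threaded(main, nT)
nothing
```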