There are a few separate but easily conflated issues here:

- `rand()` is now thread-safe, but this comes with an overhead (see The need for rand speed | Blog by Bogumił Kamiński, and the timing sketch just below).
- You can explicitly pass an RNG to `rand()` to bypass this overhead, but beware: if you do this you also bypass the thread-safety, as the unsafe example below shows.
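A minimal sketch of the overhead point (timings depend on your machine and Julia version):

```julia
using BenchmarkTools, Random

rng = MersenneTwister()

@btime rand();      # default RNG: thread-safe, but each call pays a small extra cost
@btime rand($rng);  # explicit RNG: skips that cost, but sharing `rng` across threads is not safe
```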
Unsafe threaded code:
```julia
using Random

function pi_serial(rng, n::Int)
    s = 0
    for _ in 1:n
        s += rand(rng)^2 + rand(rng)^2 < 1
    end
    return 4 * s / n
end

function pi_threaded(rng, n::Int)
    s = 0
    Threads.@threads for _ in 1:n
        # Data race: every thread mutates the shared accumulator `s`
        # and the single shared MersenneTwister `rng`.
        s += rand(rng)^2 + rand(rng)^2 < 1
    end
    return 4 * s / n
end

const n = 10^8
const rng = MersenneTwister()

println("Num threads: ", Threads.nthreads())
pi_serial(rng, n)    # correct
pi_threaded(rng, n)  # incorrect
```
- To gain the full benefit from threading with the OP's example (`pi_threaded` above is incorrect because every thread mutates both the shared accumulator `s` and the single shared RNG):
  - explicitly supply a separate RNG for each thread
  - accumulate into a thread-local variable, rather than accumulating directly into the array

  The sketch at the end of this post applies the same two fixes to the π example.
Safe threaded code:
```julia
using Random, Future, BenchmarkTools

function sum_rand_serial(rng, n)
    s = 0.0
    for i in 1:n
        s += rand(rng)
    end
    s
end

function sum_rand_parallel(rngs, n)
    nthreads = Threads.nthreads()
    s = zeros(nthreads)
    n_per_thread = n ÷ nthreads
    Threads.@threads for i in 1:nthreads
        rng = rngs[i]       # each thread uses its own RNG
        si = 0.0            # thread-local accumulator
        for j in 1:n_per_thread
            si += rand(rng)
        end
        s[i] = si           # each thread writes only to its own slot
    end
    sum(s)
end

# One MersenneTwister per thread, each jumped far ahead of the previous
# one so their streams do not overlap.
function parallel_rngs(rng::MersenneTwister, n::Integer)
    step = big(10)^20
    rngs = Vector{MersenneTwister}(undef, n)
    rngs[1] = copy(rng)
    for i = 2:n
        rngs[i] = Future.randjump(rngs[i-1], step)
    end
    return rngs
end

println("Num threads: ", Threads.nthreads())

const N = Threads.nthreads() * 10^8   # divisible by the number of threads
const rng = MersenneTwister();
const rngs = parallel_rngs(MersenneTwister(), Threads.nthreads());

@btime sum_rand_serial(rng, N)
@btime sum_rand_parallel(rngs, N)
```
On my machine, using 12 threads, I get a ~11x speed-up.
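For completeness, here is a sketch of the π example from the unsafe block with both fixes applied. The name `pi_threaded_safe` is just illustrative; it assumes `rngs` holds one independent RNG per thread (e.g. built with `parallel_rngs` above) and that `n` is a multiple of the thread count:

```julia
function pi_threaded_safe(rngs, n::Int)
    nt = Threads.nthreads()
    counts = zeros(Int, nt)      # one partial count per thread
    n_per_thread = n ÷ nt
    Threads.@threads for t in 1:nt
        rng = rngs[t]            # thread-local RNG: no shared mutable state
        c = 0                    # thread-local accumulator
        for _ in 1:n_per_thread
            c += rand(rng)^2 + rand(rng)^2 < 1
        end
        counts[t] = c
    end
    return 4 * sum(counts) / (n_per_thread * nt)
end

pi_threaded_safe(parallel_rngs(MersenneTwister(), Threads.nthreads()), 10^8)
```

Each thread only touches its own RNG and its own slot of `counts`, so the only shared work is the final `sum`.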