Hi!
I’ve been learning Turing.jl for the last couple of weeks, and while trying to make the most of the 8 cores in my computer, I realized performance was not scaling as I expected when sampling multiple chains on multiple cores. So I set up the coin-flip example from the Turing introduction to run single-threaded, multi-threaded, and distributed. I did the runs with 6 cores just to be sure nothing else on my system was acting as a bottleneck. For the single- and multi-threaded cases I used a common setup for the model and sampler parameters:
# Using Base modules.
using Random
# Load Turing
using Turing
# Load the distributions library.
using Distributions
# Create a model for a set of coin flips.
@model function coinflip(y)
    # Our prior belief about the probability of heads in a coin.
    p ~ Beta(1, 1)
    # The number of observations.
    N = length(y)
    for n in 1:N
        # Heads or tails of a coin are drawn from a Bernoulli distribution.
        y[n] ~ Bernoulli(p)
    end
end;
# Create artificial data for coin flips.
# Iterate from having seen 0 observations to 100 observations.
Ns = 0:100;
# Set the true probability of heads in a coin.
p_true = 0.5
# Draw data from a Bernoulli distribution, i.e. draw heads or tails.
Random.seed!(12)
data = rand(Bernoulli(p_true), last(Ns));
# Settings of the Hamiltonian Monte Carlo (HMC) sampler.
iterations = 5000000
ϵ = 0.05
τ = 10
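(One sanity check worth doing before timing anything — a sketch on my part, not part of the original runs — is to confirm that the Julia session actually sees the expected number of threads, since JULIA_NUM_THREADS only takes effect at startup:)

# Sanity check (sketch): verify the session really started with 6 threads.
# Threads.nthreads() reflects the JULIA_NUM_THREADS setting at startup.
using Base.Threads
println("Julia is running with ", Threads.nthreads(), " thread(s)")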
and then sampled with one thread (I used @time rather than @benchmark, but the numbers were consistent across a couple of repetitions):
# SINGLE-THREAD EXAMPLE
function perform_sampling_solo(ϵ, τ, iterations)
    chains = sample(coinflip(data), HMC(ϵ, τ), iterations)
    return chains
end
# Start sampling: run six chains sequentially and concatenate them.
@time chains = mapreduce(c -> perform_sampling_solo(ϵ, τ, iterations), chainscat, 1:6)
and then sampled with multiple threads (Julia started with JULIA_NUM_THREADS=6):
# MULTI-THREADED EXAMPLE
function perform_sampling(ϵ, τ, iterations)
    chains = sample(coinflip(data), HMC(ϵ, τ), MCMCThreads(), iterations, 6)
    return chains
end
# Start sampling.
@time perform_sampling(ϵ, τ, iterations)
To perform the distributed test I did the following:
# DISTRIBUTED EXAMPLE
# Load Distributed to add processes and the @everywhere macro.
using Distributed
addprocs(6)

@everywhere begin
    # Using Base modules.
    using Random
    # Load Turing.
    using Turing
    # Load the Distributions library.
    using Distributions
    using BenchmarkTools

    # Create artificial data for coin flips.
    # Iterate from having seen 0 observations to 100 observations.
    Ns = 0:100
    # Set the true probability of heads in a coin.
    p_true = 0.5
    # Draw data from a Bernoulli distribution, i.e. draw heads or tails.
    Random.seed!(12)
    data = rand(Bernoulli(p_true), last(Ns))
    # Settings of the Hamiltonian Monte Carlo (HMC) sampler.
    iterations = 5000000
    ϵ = 0.05
    τ = 10
end

# Create a model for a set of coin flips.
@everywhere @model function coinflip(y)
    # Our prior belief about the probability of heads in a coin.
    p ~ Beta(1, 1)
    # The number of observations.
    N = length(y)
    for n in 1:N
        # Heads or tails of a coin are drawn from a Bernoulli distribution.
        y[n] ~ Bernoulli(p)
    end
end;

function perform_sampling(ϵ, τ, iterations)
    chains = sample(coinflip(data), HMC(ϵ, τ), MCMCDistributed(), iterations, 6)
    return chains
end
# Start sampling.
@time perform_sampling(ϵ, τ, iterations)
Additional info: I’m using Julia 1.7.1 with the following packages in my environment:
[6e4b80f9] BenchmarkTools v0.7.0
[31c24e10] Distributions v0.24.18
[7073ff75] IJulia v1.23.2
[c7f686f2] MCMCChains v4.14.1
[c3e4b0f8] Pluto v0.17.5
[f3b207a7] StatsPlots v0.14.30
[fce5fe82] Turing v0.15.1
So the results I get from these tests are:
- single thread = 1590 seconds
- multi threaded = 1130 seconds
- distributed = 360 seconds
During both the multi-threaded and the distributed runs I see 6 of my CPU cores running at almost 100%, yet there is still a very significant difference in performance. Compared with the single-threaded run, multi-threading on 6 cores is only about 1.4× faster (1590 s → 1130 s), while the distributed run does far better at roughly 4.4× faster (1590 s → 360 s).
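One thing that might help narrow this down (a sketch on my part, not something I measured in the runs above): `@timed` returns the elapsed time together with the time spent in garbage collection, so comparing the GC fraction between the threaded and distributed runs would show whether allocation/GC pressure, rather than the sampler itself, is eating into the multi-threaded speedup.

# Sketch: capture GC time alongside wall time for a run.
# @timed returns a NamedTuple with (at least) value, time, bytes, gctime.
stats = @timed perform_sampling(ϵ, τ, iterations)
println("elapsed = ", stats.time, " s, gc = ", stats.gctime,
        " s (", round(100 * stats.gctime / stats.time; digits = 1), "% of total)")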
Am I missing something fundamental in the implementation of the multi-threaded case here?