I’m implementing a clustering algorithm where N is the number of samples, D the number of features and K the number of clusters. For each cluster 1 \leq k \leq K, I have to compute the probability that an individual belongs to that cluster, p(C_n = k | \dots). In a simplified formulation, this amounts to computing:
p(C_n = k | \dots) = \prod_{d = 1}^D p(x_{dn} | \theta_{dk})
where the features are indexed by d, x_{dn} is the data corresponding to the d-th feature of the n-th sample, and p is the probability density function of some distribution with \theta_{dk} the parameters associated to the distribution. To simplify, let’s say that we consider only normal distributions, this reduces to (taking the log):
\log p(C_n = k | \dots) = \sum_{d = 1}^D \log p_{\mathcal{N}(\mu_{dk}, \sigma_{dk})}(x_{dn} | \mu_{dk}, \sigma_{dk})
So my first try at implementing this was the following:
# First attempt: evaluate the scalar log-pdf for every (feature, sample, cluster)
# triple and sum over the features.
for k = 1:K, n = 1:N
    p[n,k] = sum([log_normal_pdf(x[d,n], μ[d,k], σ[d,k]) for d = 1:D])
end
where `x` is a D \times N data matrix, `μ` and `σ` are D \times K matrices holding the mean and standard-deviation vectors of each cluster, and `log_normal_pdf` computes the log-pdf of a normal distribution. Unfortunately this was very slow, so I rewrote the log-probability as follows:
\log p(C_n = k | \dots) = - \sum_{d = 1}^D \log \sigma_{dk} - \frac{D}{2} \log(2 \pi) - \frac{1}{2} \sum_{d=1}^D \left(\frac{x_{dn} - \mu_{dk}}{\sigma_{dk}} \right)^2
and made use of the `mapslices` function:
# Per-cluster sum of log σ over the feature dimension.
sum_log_sigma = vec(sum(log.(σ), 1))
for k = 1:K
    μk = μ[:,k]
    σk = σ[:,k]
    # Squared standardized residuals: the closure is applied to each column of x.
    sq = mapslices(z -> ((z - μk) ./ σk) .^ 2, x, 1)
    p[:,k] = - sum_log_sigma[k] - D / 2 * log(2 * π) - 1 / 2 * vec(sum(sq, 1))
end
to perform vector operations for each cluster k for all the different samples 1 \leq n \leq N at once. It’s better than the previous formulation, in particular with large number of features D, but still too slow for what I’d like to do. I was therefore wondering if there were a cleverer implementation to perform such a task, involving different matrices and where it’s not a straightforward matrix product.
Would it be easy to parallelize the code w.r.t. the different clusters? I must say I don’t know much about parallelism in Julia.
Anyway, thank you for taking the time to read all this and for any possible suggestions. If something is not clear in my message, I’d be glad to clarify it.
EDIT: following @kristoffer.carlsson’s sensible remark, I provide the following minimal example.
EDIT2: updated the code by also adding a parallel version.
EDIT3: added all the versions proposed in this discussion + another one based on @fastmath
using Base.Test
# Fix the seed of the global random number generator before drawing the inputs below.
srand(10)
# Sizes used to build the inputs: x is D×N, μ and σ are D×K (see the rand calls).
N = 100
D = 10000
K = 5
# Inputs are drawn in this order (x, then μ, then σ) from the seeded generator.
x = rand(D, N)
μ = rand(D, K)
σ = rand(D, K)
# Top-left 2×2 corners of the three inputs.
x0 = x[1:2,1:2]
μ0 = μ[1:2,1:2]
σ0 = σ[1:2,1:2]
"""
    log_normal_pdf(x, μ, σ)

Return the log-density at `x` of a normal distribution with mean `μ` and
standard deviation `σ`.
"""
function log_normal_pdf(x, μ, σ)
    log_normalizer = log(σ * sqrt(2 * π))
    deviation = x - μ
    return -log_normalizer - deviation ^ 2 / (2 * σ ^ 2)
end
"""
    f1(x, μ, σ)

Straightforward version: return the N×K matrix whose `(n, k)` entry is the sum
over features `d` of `log_normal_pdf(x[d,n], μ[d,k], σ[d,k])`.

# Arguments
- `x`: D×N matrix, one column per sample.
- `μ`, `σ`: D×K matrices, one column per cluster.

# Throws
- `DimensionMismatch` if `μ` does not have as many rows as `x`, or if `σ` and
  `μ` differ in size.
"""
function f1(x::Matrix, μ::Matrix, σ::Matrix)
    D, N = size(x)
    K = size(μ, 2)
    # Without these checks a mismatch is either an out-of-bounds access or, when
    # x has fewer rows than μ, a silently wrong result.
    size(μ, 1) == D || throw(DimensionMismatch(
        "x has $D rows but μ has $(size(μ, 1)) rows"))
    size(σ) == size(μ) || throw(DimensionMismatch(
        "σ has size $(size(σ)) but μ has size $(size(μ))"))
    p = zeros(N, K)
    for k = 1:K
        for n = 1:N
            # One scalar log-pdf evaluation per feature, then a sum over features.
            p[n,k] = sum(map(d -> log_normal_pdf(x[d,n], μ[d,k], σ[d,k]), 1:D))
        end
    end
    return p
end
"""
    f2(x, μ, σ)

Vectorised-per-cluster version: for each cluster `k`, the squared standardized
residuals of all samples are computed in one broadcast and summed per sample.

Return the N×K matrix with entries
`-Σ_d log σ[d,k] - D/2·log(2π) - ½·Σ_d ((x[d,n] - μ[d,k]) / σ[d,k])²`.

# Arguments
- `x`: D×N matrix, one column per sample.
- `μ`, `σ`: D×K matrices, one column per cluster.

# Throws
- `DimensionMismatch` if `μ` does not have as many rows as `x`, or if `σ` and
  `μ` differ in size.
"""
function f2(x::Matrix, μ::Matrix, σ::Matrix)
    D, N = size(x)
    K = size(μ, 2)
    size(μ, 1) == D || throw(DimensionMismatch(
        "x has $D rows but μ has $(size(μ, 1)) rows"))
    size(σ) == size(μ) || throw(DimensionMismatch(
        "σ has size $(size(σ)) but μ has size $(size(μ))"))
    p = zeros(N, K)
    norm_const = D / 2 * log(2 * π)
    for k = 1:K
        # Views avoid copying column k of the parameter matrices.
        μk = view(μ, :, k)
        σk = view(σ, :, k)
        offset = -sum(log.(σk)) - norm_const
        # D×N matrix: the length-D parameter vectors broadcast across all samples.
        z = ((x .- μk) ./ σk) .^ 2
        for n = 1:N
            p[n,k] = offset - sum(view(z, :, n)) / 2
        end
    end
    return p
end
# Variant of f2 whose loop over clusters runs under `@sync @parallel`.
# x is read as D×N and μ, σ as D×K (see the size calls). The N×K SharedArray `p`
# is returned as-is; it is not converted to a plain Array.
# NOTE(review): `@parallel` and the positional dimension argument of `sum` /
# `mapslices` look like pre-1.0 Julia APIs — confirm the targeted Julia version.
function f2_parallel(x::Matrix, μ::Matrix, σ::Matrix)
D, K = size(μ)
# D is overwritten with the row count of x; the two values are not compared.
D, N = size(x)
# Shared storage for the result; presumably so that writes made by worker
# processes are visible to the caller — TODO confirm the worker setup.
p = SharedArray{Float64}(N, K)
# Sum of log σ over the feature dimension, one value per cluster.
sum_log_sigma = vec(sum(log.(σ), 1))
# @sync wraps the @parallel loop so that `return p` is only reached afterwards.
@sync @parallel for k = 1:K
# Column k for all samples at once: mapslices applies the closure to each
# column of x, and the outer sum adds the squared terms over the features.
p[:,k] = - sum_log_sigma[k] - D / 2 * log(2 * π) - 1 / 2 * vec(sum(mapslices(z -> ((z - vec(μ[:,k])) ./ vec(σ[:,k])) .^ 2, x, 1), 1))
end
return p
end
"""
    f3(x, μ, σ)

Vectorised-per-sample version: for each sample `j`, the squared standardized
residuals against all clusters are computed in one broadcast and summed per
cluster.

The result is filled as a K×N matrix (one contiguous column per sample) and
returned transposed, i.e. with N rows and K columns. Entry `(n, k)` is
`-Σ_d log σ[d,k] - D/2·log(2π) - ½·Σ_d ((x[d,n] - μ[d,k]) / σ[d,k])²`.

# Arguments
- `x`: D×N matrix, one column per sample.
- `μ`, `σ`: D×K matrices, one column per cluster.

# Throws
- `DimensionMismatch` if `μ` does not have as many rows as `x`, or if `σ` and
  `μ` differ in size.
"""
function f3(x::Matrix, μ::Matrix, σ::Matrix)
    D, N = size(x)
    K = size(μ, 2)
    size(μ, 1) == D || throw(DimensionMismatch(
        "x has $D rows but μ has $(size(μ, 1)) rows"))
    size(σ) == size(μ) || throw(DimensionMismatch(
        "σ has size $(size(σ)) but μ has size $(size(μ))"))
    p = zeros(K, N)
    # Sample-independent part of each cluster's log-density.
    c = [-sum(log.(view(σ, :, k))) - D / 2 * log(2 * π) for k = 1:K]
    for j = 1:N
        # The view avoids copying column j; the result z is D×K.
        z = ((view(x, :, j) .- μ) ./ σ) .^ 2
        for k = 1:K
            p[k,j] = c[k] - sum(view(z, :, k)) / 2
        end
    end
    return p'
end
"""
    f3_threads(x, μ, σ)

Multi-threaded vectorised-per-sample version: the loop over samples runs under
`Threads.@threads`, and each iteration writes only its own column of the K×N
work matrix.

The work matrix is returned transposed, i.e. with N rows and K columns. Entry
`(n, k)` is
`-Σ_d log σ[d,k] - D/2·log(2π) - ½·Σ_d ((x[d,n] - μ[d,k]) / σ[d,k])²`.

# Arguments
- `x`: D×N matrix, one column per sample.
- `μ`, `σ`: D×K matrices, one column per cluster.

# Throws
- `DimensionMismatch` if `μ` does not have as many rows as `x`, or if `σ` and
  `μ` differ in size.
"""
function f3_threads(x::Matrix, μ::Matrix, σ::Matrix)
    D, N = size(x)
    K = size(μ, 2)
    size(μ, 1) == D || throw(DimensionMismatch(
        "x has $D rows but μ has $(size(μ, 1)) rows"))
    size(σ) == size(μ) || throw(DimensionMismatch(
        "σ has size $(size(σ)) but μ has size $(size(μ))"))
    p = zeros(K, N)
    # Sample-independent part of each cluster's log-density.
    c = [-sum(log.(view(σ, :, k))) - D / 2 * log(2 * π) for k = 1:K]
    Threads.@threads for j = 1:N
        # The view avoids copying column j; z is a fresh D×K array per iteration.
        z = ((view(x, :, j) .- μ) ./ σ) .^ 2
        for k = 1:K
            p[k,j] = c[k] - sum(view(z, :, k)) / 2
        end
    end
    return p'
end
"""
    f4b_mt(x, μ, σ)

Explicit-loop, single-threaded version. For each cluster the sum of `log σ` and
the factors `0.5 / σ²` are computed once, then reused for every sample.

Return the N×K matrix with entries
`-Σ_d log σ[d,k] - D/2·log(2π) - ½·Σ_d ((x[d,n] - μ[d,k]) / σ[d,k])²`.

# Arguments
- `x`: D×N matrix, one column per sample.
- `μ`, `σ`: D×K matrices, one column per cluster.

# Throws
- `DimensionMismatch` if `μ` does not have as many rows as `x`, or if `σ` and
  `μ` differ in size.
"""
function f4b_mt(x::Matrix, μ::Matrix, σ::Matrix)
    D, N = size(x)
    K = size(μ, 2)
    # Explicit checks instead of an assertion: the @inbounds loop below relies on them.
    size(μ, 1) == D || throw(DimensionMismatch(
        "x has $D rows but μ has $(size(μ, 1)) rows"))
    size(σ) == size(μ) || throw(DimensionMismatch(
        "σ has size $(size(σ)) but μ has size $(size(μ))"))
    p = zeros(N, K)
    cc = 0.5 * log(2.0 * pi) * D
    # Scratch buffer, fully overwritten for every cluster, so one allocation suffices.
    iσ = zeros(D)
    for k = 1:K
        logσ = cc
        @simd for d = 1:D
            logσ += log(σ[d,k])
        end
        @simd for d = 1:D
            iσ[d] = 0.5 / (σ[d,k] * σ[d,k])
        end
        @inbounds for n = 1:N
            pnk = -logσ
            @simd for d = 1:D
                δ = x[d,n] - μ[d,k]
                pnk -= δ * δ * iσ[d]
            end
            p[n,k] = pnk
        end
    end
    return p
end
"""
    f4b_mt_threads(x, μ, σ)

Explicit-loop version with the loop over clusters run under `Threads.@threads`.
Each iteration writes only column `k` of the result.

Return the N×K matrix with entries
`-Σ_d log σ[d,k] - D/2·log(2π) - ½·Σ_d ((x[d,n] - μ[d,k]) / σ[d,k])²`.

# Arguments
- `x`: D×N matrix, one column per sample.
- `μ`, `σ`: D×K matrices, one column per cluster.

# Throws
- `DimensionMismatch` if `μ` does not have as many rows as `x`, or if `σ` and
  `μ` differ in size.
"""
function f4b_mt_threads(x::Matrix, μ::Matrix, σ::Matrix)
    D, N = size(x)
    K = size(μ, 2)
    # Explicit checks instead of an assertion: the @inbounds loop below relies on them.
    size(μ, 1) == D || throw(DimensionMismatch(
        "x has $D rows but μ has $(size(μ, 1)) rows"))
    size(σ) == size(μ) || throw(DimensionMismatch(
        "σ has size $(size(σ)) but μ has size $(size(μ))"))
    p = zeros(N, K)
    cc = 0.5 * log(2.0 * pi) * D
    Threads.@threads for k = 1:K
        logσ = cc
        # Allocated inside the loop so that every iteration has its own scratch buffer.
        iσ = zeros(D)
        @simd for d = 1:D
            logσ += log(σ[d,k])
        end
        @simd for d = 1:D
            iσ[d] = 0.5 / (σ[d,k] * σ[d,k])
        end
        @inbounds for n = 1:N
            pnk = -logσ
            @simd for d = 1:D
                δ = x[d,n] - μ[d,k]
                pnk -= δ * δ * iσ[d]
            end
            p[n,k] = pnk
        end
    end
    return p
end
"""
    f4b_mt_threads_fast(x, μ, σ)

Explicit-loop version with the loop over clusters run under `Threads.@threads`
and the arithmetic loops wrapped in `@fastmath`. Each iteration writes only
column `k` of the result.

Return the N×K matrix with entries
`-Σ_d log σ[d,k] - D/2·log(2π) - ½·Σ_d ((x[d,n] - μ[d,k]) / σ[d,k])²`.

# Arguments
- `x`: D×N matrix, one column per sample.
- `μ`, `σ`: D×K matrices, one column per cluster.

# Throws
- `DimensionMismatch` if `μ` does not have as many rows as `x`, or if `σ` and
  `μ` differ in size.
"""
function f4b_mt_threads_fast(x::Matrix, μ::Matrix, σ::Matrix)
    D, N = size(x)
    K = size(μ, 2)
    # Explicit checks instead of an assertion: the @inbounds loop below relies on them.
    size(μ, 1) == D || throw(DimensionMismatch(
        "x has $D rows but μ has $(size(μ, 1)) rows"))
    size(σ) == size(μ) || throw(DimensionMismatch(
        "σ has size $(size(σ)) but μ has size $(size(μ))"))
    p = zeros(N, K)
    cc = 0.5 * log(2.0 * pi) * D
    Threads.@threads for k = 1:K
        logσ = cc
        # Allocated inside the loop so that every iteration has its own scratch buffer.
        iσ = zeros(D)
        @fastmath @simd for d = 1:D
            logσ += log(σ[d,k])
        end
        @fastmath @simd for d = 1:D
            iσ[d] = 0.5 / (σ[d,k] * σ[d,k])
        end
        @fastmath @inbounds for n = 1:N
            pnk = -logσ
            @simd for d = 1:D
                δ = x[d,n] - μ[d,k]
                pnk -= δ * δ * iσ[d]
            end
            p[n,k] = pnk
        end
    end
    return p
end
# Call every variant once on the 2×2 inputs before the timed runs.
# Presumably this triggers compilation so that the @time figures below do not
# include it — TODO confirm. The results stored here are not read again in the
# visible source.
p10 = f1(x0, μ0, σ0)
p20 = f2(x0, μ0, σ0)
p20_parallel = f2_parallel(x0, μ0, σ0)
p30 = f3(x0, μ0, σ0)
p3_threads0 = f3_threads(x0, μ0, σ0)
p4b_mt0 = f4b_mt(x0, μ0, σ0)
p4b_mt_threads0 = f4b_mt_threads(x0, μ0, σ0)
p4b_mt_threads_fast0 = f4b_mt_threads_fast(x0, μ0, σ0)
# Print `name`, then one space for each position still missing to reach a column
# width of `space`; nothing is appended when `name` is already that long.
function print_name(name, space = 30)
    print(name)
    foreach(_ -> print(" "), 1:space-length(name))
end
# Time each variant on the full-size inputs; print_name writes a padded label
# first so that the @time outputs line up.
print_name("f1"); @time p1 = f1(x, μ, σ)
print_name("f2"); @time p2 = f2(x, μ, σ)
print_name("f2_parallel"); @time p2_parallel = f2_parallel(x, μ, σ)
print_name("f3"); @time p3 = f3(x, μ, σ)
print_name("f3_threads"); @time p3_threads = f3_threads(x, μ, σ)
print_name("f4b_mt"); @time p4b_mt = f4b_mt(x, μ, σ)
print_name("f4b_mt_threads"); @time p4b_mt_threads = f4b_mt_threads(x, μ, σ)
print_name("f4b_mt_threads_fast"); @time p4b_mt_threads_fast = f4b_mt_threads_fast(x, μ, σ)
# Chained comparison: every adjacent pair of results must be approximately equal.
@test p1 ≈ p2 ≈ p2_parallel ≈ p3 ≈ p3_threads ≈ p4b_mt ≈ p4b_mt_threads ≈ p4b_mt_threads_fast