What you are seeing is most likely the overhead of creating and scheduling tasks. Assuming `length(nodes) == 10`, your nested loops run 100 times, so each iteration takes ~400 µs. One thread measured the overhead of threading at roughly 20 µs per task. That is smaller, but it was a somewhat idealized scenario with a single thread, and I suppose the scheduling overhead will be somewhat larger with multiple threads.
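To check the order of magnitude on your own machine, a rough measurement could look like the sketch below. `spawn_overhead` is a hypothetical helper. It spawns an empty task and immediately waits on it, one after another, so it only approximates the cost of creating many concurrent tasks. The numbers will depend on hardware, Julia version and thread count.

```julia
# Hypothetical helper: average cost (in µs) of Threads.@spawn + fetch on an empty task.
function spawn_overhead(n::Integer)
    fetch(Threads.@spawn nothing)          # warm-up, so compilation is not timed
    t0 = time_ns()
    for _ in 1:n
        fetch(Threads.@spawn nothing)      # create, schedule, and wait for an empty task
    end
    return (time_ns() - t0) / n / 1e3      # nanoseconds -> microseconds per task
end

println("≈ ", round(spawn_overhead(10_000); digits = 2), " µs per spawned task")
```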
Generally, spawning Julia tasks allocates and costs scheduling time, so each task needs a substantial amount of work to amortize that overhead. A few large chunks are therefore usually better than many small ones, which is why your second version is much faster. Note, however, that it is also dangerously close to being incorrect, since it indexes per-thread storage with `Threads.threadid()` and tasks can migrate between threads (for example after yielding). I am actually not sure whether you are completely safe in this case.
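As a generic illustration of the safe pattern (a few large chunks, each with its own private accumulator), here is a sketch that uses only Base Julia, with `Iterators.partition` and `Threads.@spawn` swapped in for ChunkSplitters. `chunked_sum` is a hypothetical helper, not a library function. Each task returns its partial sum, so nothing is indexed by `Threads.threadid()`.

```julia
# Hypothetical helper: sum f(x) over xs with one task per chunk.
function chunked_sum(f, xs::AbstractVector; nchunks::Integer = Threads.nthreads())
    isempty(xs) && throw(ArgumentError("xs must be non-empty"))
    chunksize = cld(length(xs), nchunks)   # ceil(length / nchunks) elements per chunk
    tasks = map(Iterators.partition(eachindex(xs), chunksize)) do idxs
        # each task accumulates privately and returns its partial result
        Threads.@spawn sum(f(xs[i]) for i in idxs)
    end
    return sum(fetch, tasks)               # combine the per-chunk partial sums
end

println(chunked_sum(x -> x^2, collect(1:1000)))
```

The result does not depend on the number of threads: `chunked_sum(x -> x^2, collect(1:1000))` returns `333833500` whether it runs on one thread or many.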
You could rewrite it like so:
```julia
using ChunkSplitters
using LogExpFunctions  # provides log1pexp and logsumexp

function nll(mR::T, sigmaR::T, mC::T, sigmaC::T, Y = Y, pit = pit, nodes = nodes, weights = weights) where {T <: Real}
    nchunks = Threads.nthreads()
    ll = zeros(T, nchunks)  # one accumulator per chunk, not per thread
    # Older ChunkSplitters API: yields (index range, chunk id) pairs.
    # Newer releases changed `chunks`, so check the docs of your installed version.
    Threads.@threads for (Yrange, chunkid) in ChunkSplitters.chunks(Y, nchunks)
        # preallocated once per chunk; element type T so dual numbers (e.g. ForwardDiff) fit
        workspace = zeros(T, length(nodes), length(nodes))
        for i in Yrange
            ll[chunkid] += indll!(workspace, mR, sigmaR, mC, sigmaC, Y[i], pit, nodes, weights)
        end
    end
    return -sum(ll)
end

function indll!(indll, mR::T, sigmaR::T, mC::T, sigmaC::T, yi, pit = pit, nodes = nodes, weights = weights) where {T <: Real}
    for k1 ∈ eachindex(nodes)
        for k2 ∈ eachindex(nodes)
            r = exp(mR + sqrt(2) * sigmaR * nodes[k1])
            c = exp(mC + sqrt(2) * sigmaC * nodes[k2])
            EMAX = calcEMAX(r, c)  # your own function, not shown here
            indllnode = zero(T)
            for t ∈ eachindex(yi)
                u0 = pit[t] * r - c + (1 - pit[t]) * EMAX[t]
                if yi[t] == 0
                    indllnode += u0 - log1pexp(u0)
                else
                    indllnode += -log1pexp(u0)
                end
            end
            indllnode += log(weights[k1] * weights[k2] / pi)
            indll[k2, k1] = indllnode  # k2 is the inner loop, matching column-major storage
        end
    end
    return logsumexp(vec(indll))  # vec reshapes the array to a vector without copying
end
```
I took the liberty of also preallocating a workspace for the inner function, which I changed to a 2D array for easier indexing. This is safer than your original version because it does not rely on `Threads.threadid()`, and it should be faster as well.
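To illustrate why the 2D workspace and the `vec` call are cheap, here is a small standalone sketch (the sizes and values are made up for illustration). It fills a matrix with the same `[k2, k1]` indexing as `indll!` and shows that `vec` returns a vector sharing the matrix's memory rather than a copy.

```julia
# A preallocated 2D workspace, filled with the same [k2, k1] indexing as indll!.
n = 3
workspace = zeros(n, n)
for k1 in 1:n, k2 in 1:n               # k2 varies fastest, matching column-major storage
    workspace[k2, k1] = 10 * k1 + k2
end

v = vec(workspace)                      # same memory, viewed as a length-9 vector
println(v[1:4])                         # → [11.0, 12.0, 13.0, 21.0] (down column 1, then column 2)

v[1] = -1.0                             # writing through v ...
println(workspace[1, 1])                # ... is visible in workspace: → -1.0
```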
EDIT: Fixed usage of ChunkSplitters