> would OhMyThreads.jl primitives work here?
Yes! Once the `x[i] ~ dist` statements are removed from the threaded code, what remains is just arbitrary Julia code that the Turing `@model` macro doesn't touch. So my belief is that any threading library will work.
I tested the following with 5 threads, and to substantiate the above claim I also threw in a FLoops.jl example.

Versions: Turing@0.40.3, DynamicPPL@0.37.5, OhMyThreads@0.8.3, FLoops@0.2.2, Julia 1.11.7

```julia
using Turing, Random
using LinearAlgebra  # for the identity matrix `I`

# Threads.@threads, not recommended!
@model function f1(x)
    m ~ MvNormal(zeros(4), I)
    Threads.@threads for i in eachindex(x)
        x[i] ~ Normal(m[i])
    end
end
```
```julia
# Threads.@spawn
@model function f2(x)
    m ~ MvNormal(zeros(4), I)
    lls = map(eachindex(x)) do i
        Threads.@spawn begin
            logpdf(Normal(m[i]), x[i])
        end
    end
    @addlogprob! sum(fetch.(lls))
end
```
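
For reference, here is the non-threaded pattern that `f2` through `f5` all parallelise: replacing `x[i] ~ Normal(m[i])` with a manual `logpdf` sum fed to `@addlogprob!`. This is a minimal sketch (the name `f0` is mine, not from the tests above) showing that the threaded models compute the same joint density as `f1`:

```julia
# Non-threaded reference: manual log-likelihood accumulation,
# equivalent to writing x[i] ~ Normal(m[i]) in the model body.
@model function f0(x)
    m ~ MvNormal(zeros(4), I)
    @addlogprob! sum(logpdf(Normal(m[i]), x[i]) for i in eachindex(x))
end
```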
```julia
# OhMyThreads.tmapreduce
using OhMyThreads
@model function f3(x)
    m ~ MvNormal(zeros(4), I)
    let m = m
        lls = tmapreduce(+, eachindex(x)) do i
            logpdf(Normal(m[i]), x[i])
        end
        @addlogprob! lls
    end
end
```
```julia
# OhMyThreads.@tasks
@model function f4(x)
    m ~ MvNormal(zeros(4), I)
    let m = m
        lls = @tasks for i in eachindex(x)
            @set begin
                reducer = +
            end
            logpdf(Normal(m[i]), x[i])
        end
        @addlogprob! lls
    end
end
```
```julia
# FLoops.@floop
using FLoops
@model function f5(x)
    m ~ MvNormal(zeros(4), I)
    let m = m
        @floop for i in eachindex(x)
            @reduce lls = 0 + logpdf(Normal(m[i]), x[i])
        end
        @addlogprob! lls
    end
end
```
```julia
x = randn(Xoshiro(468), 4)
# replace the model constructor with `f1` through `f5`
mean(sample(Xoshiro(468), f1(x), NUTS(), MCMCThreads(), 1000, 3; progress=false))
```
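
One way to check the reproducibility claims below (a sketch of what I mean by "reproducible", not the exact comparison script) is to sample twice with identically seeded rngs and compare the draws:

```julia
# Sketch: sample the same model twice with freshly seeded rngs.
chain_a = sample(Xoshiro(468), f2(x), NUTS(), MCMCThreads(), 1000, 3; progress=false)
chain_b = sample(Xoshiro(468), f2(x), NUTS(), MCMCThreads(), 1000, 3; progress=false)

# A fully reproducible sampler should produce identical draws,
# so the flattened arrays of samples should match exactly.
Array(chain_a) == Array(chain_b)
```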
Conclusions:
- `f1` is not reproducible; it gives different results on every run despite using an explicit rng.
- `f2`, `f3`, `f4`, and `f5` are all reproducible.
- `f2`, `f3`, and `f4` all yield the same result, so I suppose that under the hood they behave the same.
- `f5` gives a reproducible result that is different from the others. (Note that the results were still qualitatively correct in that they satisfied `mean(m[i]) ≈ x[i] / 2`.)
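
For reference, the `mean(m[i]) ≈ x[i] / 2` check follows from the standard conjugate normal-normal update: the prior is `m[i] ~ Normal(0, 1)` (from `MvNormal(zeros(4), I)`) and the likelihood is `x[i] | m[i] ~ Normal(m[i], 1)`, both with unit variance, so

```latex
% Conjugate normal-normal update with unit variances:
% prior m_i ~ N(0, 1), likelihood x_i | m_i ~ N(m_i, 1)
p(m_i \mid x_i) \propto
  \exp\!\Big(-\tfrac{m_i^2}{2}\Big)
  \exp\!\Big(-\tfrac{(x_i - m_i)^2}{2}\Big)
\;\Longrightarrow\;
m_i \mid x_i \sim \mathcal{N}\!\Big(\tfrac{x_i}{2},\, \tfrac{1}{2}\Big)
```

i.e. the posterior mean of each `m[i]` is exactly `x[i] / 2`.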