Parallelism within Turing.jl model

Would OhMyThreads.jl primitives work here?

Yes! Once the x ~ dist statements are removed from the threaded code, what remains is arbitrary Julia code that the Turing @model macro doesn’t touch. So my belief is that any threading library will work.
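To make that point concrete, here is a minimal sketch of the threaded part on its own, outside any model. It uses only Base, so `normal_logpdf` is a hand-written, hypothetical stand-in for `logpdf(Normal(μ), x)` from Distributions.jl, and the values of `m` and `x` are made up for illustration.

```julia
using Base.Threads: @spawn

# Hypothetical stand-in for `logpdf(Normal(μ), x)` (unit variance),
# written out by hand so that this sketch needs no packages.
normal_logpdf(μ, x) = -0.5 * log(2π) - 0.5 * (x - μ)^2

# Plain threaded Julia: one task per observation, then sum the fetched results.
function threaded_loglik(m, x)
    tasks = map(eachindex(x)) do i
        @spawn normal_logpdf(m[i], x[i])
    end
    return sum(fetch.(tasks))
end

m = [0.0, 0.5, -0.5, 1.0]   # made-up parameter values
x = [0.1, 0.4, -0.7, 1.2]   # made-up observations

serial = sum(normal_logpdf.(m, x))
threaded = threaded_loglik(m, x)
println(threaded ≈ serial)
```

Nothing here involves Turing: the function just returns a number. Inside a model, that number is what gets passed to @addlogprob!.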

I tested the following with 5 threads. To substantiate the claim above, I also threw in a FLoops.jl example.

Versions: Turing@0.40.3, DynamicPPL@0.37.5, OhMyThreads@0.8.3, FLoops@0.2.2, Julia 1.11.7

using Turing, Random

# Threads.@threads, not recommended!
@model function f1(x)
    m ~ MvNormal(zeros(4), I)
    Threads.@threads for i in eachindex(x)
        x[i] ~ Normal(m[i])
    end
end

# Threads.@spawn
@model function f2(x)
    m ~ MvNormal(zeros(4), I)
    lls = map(eachindex(x)) do i
        Threads.@spawn begin
            logpdf(Normal(m[i]), x[i])
        end
    end
    @addlogprob! sum(fetch.(lls))
end

# OhMyThreads.tmapreduce
using OhMyThreads
@model function f3(x)
    m ~ MvNormal(zeros(4), I)
    let m = m
        lls = tmapreduce(+, eachindex(x)) do i
            logpdf(Normal(m[i]), x[i])
        end
        @addlogprob! lls
    end
end

# OhMyThreads.@tasks
@model function f4(x)
    m ~ MvNormal(zeros(4), I)
    let m=m
        lls = @tasks for i in eachindex(x)
            @set begin
                reducer=+
            end
            logpdf(Normal(m[i]), x[i])
        end
        @addlogprob! lls
    end
end

# FLoops.@floop
using FLoops
@model function f5(x)
    m ~ MvNormal(zeros(4), I)
    let m=m
        @floop for i in eachindex(x)
            @reduce lls = 0 + logpdf(Normal(m[i]), x[i])
        end
        @addlogprob! lls
    end
end

x = randn(Xoshiro(468), 4)
# replace model constructor with `f1` through `f5`
mean(sample(Xoshiro(468), f1(x), NUTS(), MCMCThreads(), 1000, 3; progress=false))
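One way to check reproducibility programmatically, rather than comparing printed means by eye, is to sample twice from the same seed and compare the raw draws. The sketch below is an assumption-laden illustration, not the procedure used for the results in this post: `is_reproducible` and `demo` are hypothetical names, `demo` is a serial stand-in for the models above, and `Array(chain)` is used to pull the parameter draws out of the returned chain.

```julia
using Turing, Random, LinearAlgebra

# Stand-in model (an assumption for this sketch): same structure as the
# models above, with the likelihood computed serially and added via @addlogprob!.
@model function demo(x)
    m ~ MvNormal(zeros(length(x)), I)
    @addlogprob! sum(logpdf.(Normal.(m), x))
end

# Hypothetical helper: sample twice from the same seed and compare all draws.
function is_reproducible(model; seed=468, n=1000, nchains=3)
    draw() = sample(Xoshiro(seed), model, NUTS(), MCMCThreads(), n, nchains; progress=false)
    return Array(draw()) == Array(draw())
end

x = randn(Xoshiro(468), 4)
println(is_reproducible(demo(x)))
```

Swap `demo(x)` for `f1(x)` through `f5(x)` to test each variant, and start Julia with several threads (for example `julia --threads 5`) so that the threaded code paths are actually exercised.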

Conclusions:

  • f1 is not reproducible: it gives different results on every run, despite the explicit RNG.
  • f2, f3, f4, and f5 are all reproducible.
  • f2, f3, and f4 all yield the same result, so I suppose they behave the same under the hood. f5 gives a reproducible result that differs from the others. (Note that all of the results were still qualitatively correct, in that they satisfied mean(m[i]) ≈ x[i] / 2.)
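For reference, the sanity check mean(m[i]) ≈ x[i] / 2 follows from conjugacy. This is a short derivation, assuming (as in the models above) a standard normal prior on each m[i] and a unit-variance normal likelihood with a single observation x[i] per component:

```latex
\begin{aligned}
m_i &\sim \mathcal{N}(0, 1), \qquad x_i \mid m_i \sim \mathcal{N}(m_i, 1), \\
p(m_i \mid x_i) &\propto \exp\!\left(-\tfrac{1}{2} m_i^2\right)
                   \exp\!\left(-\tfrac{1}{2} (x_i - m_i)^2\right) \\
                &\propto \exp\!\left(-\left(m_i - \tfrac{x_i}{2}\right)^2\right), \\
m_i \mid x_i &\sim \mathcal{N}\!\left(\tfrac{x_i}{2}, \tfrac{1}{2}\right).
\end{aligned}
```

So the posterior mean of each m[i] is x[i] / 2, with posterior variance 1/2.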