Parallelism within a Turing.jl model

Hi there,

just wondering how safe it is to use Threads.@threads for loops within turing models e.g.

@model function my_func(Y)
    alpha ~ Normal(0,1)
    sigma ~ Normal(0,1)

    Threads.@threads for i in 1:size(Y, 2)
        Y[:, i] .~ Normal(alpha, sigma)
    end
end

This seems to work on my laptop, but I don’t currently have more than one thread available to me, so I can’t test it out properly. Is there any reason to avoid this in Turing?

It would also be nice to get a general view of how well Turing plays with parallel and distributed processing within models. For instance, I’m working on a Bayesian neural network using Turing and Flux, and it would be nice to utilise the GPU.

Cheers,
Arthur

This is threadsafe. We actually also used threads for our COVID-19 replication study, see: Covid19/models.jl at master · cambridge-mlg/Covid19 · GitHub.

Getting Turing to work smoothly with the GPU, however, is not so nice and is currently a big issue. If you are only interested in using Turing for a BNN, then I would recommend writing the log joint yourself and using AdvancedVI or AdvancedHMC directly. Then using the GPU should work (more or less). This is how I’m doing it at the moment, but I haven’t actually used the GPU with HMC yet, so I would be a bit cautious with this as you might run into some weird problems.

@Kai_Xu will be able to say more about the HMC on GPU stuff.

Turing models are thread-safe. Even on a laptop you should have more than one thread available. Did you set the environment variable (Multi-Threading · The Julia Language) correctly to let Julia use more than one thread?
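
For example (the --threads flag requires Julia 1.5 or later):

# Start Julia with the environment variable set, e.g.
#   JULIA_NUM_THREADS=4 julia
# or, on Julia 1.5 and later, with the command-line flag
#   julia --threads 4
# Then check inside the session that it took effect:
Threads.nthreads()   # should report more than 1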

Following Martin’s point, if you write your own log-density (computed by a Flux model), you can use the static HMC methods in AHMC, either on CPU or GPU. Vectorization is also supported to run multiple chains in parallel, but you need to make sure your Flux model is coded to work with vectorization.
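
To make this concrete, here is a minimal sketch (not from the original posts) of the hand-written log joint approach, applied to the toy model from the first post and sampling sigma on the log scale. The AdvancedHMC names below follow its README from around this time and may have changed in later releases:

using AdvancedHMC, ForwardDiff, Distributions

Y = randn(10, 100)                        # fake data, one column per observation

# Unconstrained parameters: θ = (alpha, log σ)
function ℓπ(θ)
    alpha, logσ = θ
    σ = exp(logσ)
    lp  = logpdf(Normal(0, 1), alpha)             # prior on alpha
    lp += logpdf(Normal(0, 1), σ) + logσ          # prior on σ, plus log-Jacobian of the exp transform
    lp += sum(logpdf.(Normal(alpha, σ), Y))       # likelihood
    return lp
end

D = 2
metric = DiagEuclideanMetric(D)
hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)   # gradients via ForwardDiff
initial_θ = randn(D)
integrator = Leapfrog(find_good_stepsize(hamiltonian, initial_θ))
# NUTS shown here; the post above suggests the static (fixed-length) trajectories
# when running on the GPU.
proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
samples, stats = sample(hamiltonian, proposal, initial_θ, 1_000, adaptor, 500)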

Thank you both. I will have a go with my own log density. How silly, I wasn’t setting the threading environment variable when using Julia on my laptop. The COVID-19 model is a useful example for moving beyond the basics presented in the tutorials.

Thanks for all your work on Turing. It really is brilliant!

I don’t think this should be so hard in Soss, and I’d love to have a nice example to dig into. Ideally it would be something that works well with Flux for a point estimate but currently requires some hand-tuning for a BNN. Any suggestions?

Just a BNN will do. The main reason this is nontrivial in Turing is that 1) the bookkeeping in Turing would currently need some refactoring, and 2) you would need to build an optimised computation graph out of the model to reduce memory mapping between CPU and GPU. Knet.jl does this, but only for deep nets. Flux is, to my (probably outdated) understanding, still not very performant on the GPU because of this issue. But maybe that has changed.

As a Turing.jl developer in 2025 I’d just like to give a bit of an update on this, particularly for anybody who comes here via Google.

The official line is: Threads.@threads still works in Turing models, but it only works for likelihood terms, i.e. x ~ dist where x is observed data (either specified in the model arguments or conditioned upon). If you attempt to use it with prior terms, i.e. where x is a random variable, the behaviour is undefined; last time I looked into it, doing this would either error or give wrong results.
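
To illustrate the distinction with two made-up models (not from the original post): in the first, x is observed data passed as a model argument, so the threaded statements are likelihood terms; in the second, y is a local random variable, so they are prior terms.

using Turing

# OK: x is a model argument (observed data), so x[i] ~ ... is a likelihood term.
@model function threaded_likelihood(x)
    m ~ Normal()
    Threads.@threads for i in eachindex(x)
        x[i] ~ Normal(m)
    end
end

# Undefined behaviour: y is a local variable, so y[i] ~ ... is a prior term.
@model function threaded_prior(x)
    m ~ Normal()
    y = Vector{Float64}(undef, length(x))
    Threads.@threads for i in eachindex(x)
        y[i] ~ Normal(m)   # don't do this
    end
end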

Now, my (strong) personal opinion is that Threads.@threads, as well as the code in Turing that exists to handle this gracefully (DynamicPPL.ThreadSafeVarInfo for those in the know), is an antipattern. For example, see this Julia blog post. I would like to get rid of all of it, but I don’t yet know how to do it in a nice way. If you’re interested, you can see this issue and this PR.

It is a bit less “friendly” and more manual, but for those who are comfortable with Julia / Turing, I would like to gently suggest that instead of using

Threads.@threads for i in eachindex(x)
    some_other_computation()
    x[i] ~ dist
end

you explicitly calculate the log-likelihoods using logpdf and then add the total log-likelihood using @addlogprob! (see docs):

loglikes = map(x) do xi
    Threads.@spawn begin
        some_other_computation()
        return logpdf(dist, xi)
    end
end
@addlogprob! sum(fetch.(loglikes))

For Turing users, there are several benefits to this latter approach.

  1. You guard against future changes to the way parallelism works in Julia / Turing. As explained in the linked issue above, the threadsafe handling code in Turing is hacky and only happens to work in practice because we don’t typically run into the threadid issues described here. If Julia’s behaviour changes, we will have to change things in Turing accordingly, which may end up breaking Threads.@threads.
  2. Sampling a model that uses Threads.@spawn together with MCMCThreads() will yield reproducible results when the RNG is seeded. On the other hand, the combination of Threads.@threads and MCMCThreads() is not reproducible, even when seeding the RNG.
  3. Instead of directly using Threads.@spawn you can also use any parallelism library of your choice to calculate the likelihoods. (See subsequent comments in this thread for examples of OhMyThreads.jl and FLoops.jl.)
  4. You can also more easily extract the results of some_other_computation() in a thread-safe manner, should you need them outside the threaded loop: just add them to the return value of the Threads.@spawn block, as in the sketch below.
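
A hypothetical sketch of point 4, using the same placeholder names as above (some_other_computation and dist), that returns both the log-likelihood and the auxiliary result from each task:

results = map(x) do xi
    Threads.@spawn begin
        aux = some_other_computation()            # per-observation work
        (loglike = logpdf(dist, xi), aux = aux)   # return both from the task
    end
end
fetched = fetch.(results)
@addlogprob! sum(r.loglike for r in fetched)
aux_values = [r.aux for r in fetched]             # usable later in the model body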

Would OhMyThreads.jl primitives work here?

loglikes = OhMyThreads.tmapreduce(+, x) do xi
    some_other_computation()
    logpdf(dist, xi)
end
@addlogprob! loglikes

or

loglikes = @tasks for xi in x
    @set begin
        reducer=+
    end
    some_other_computation()
    logpdf(dist, xi)
end
@addlogprob! loglikes

Would OhMyThreads.jl primitives work here?

Yes! By removing the x ~ dist from the threaded code, it becomes just arbitrary Julia code that the Turing @model macro doesn’t touch. So my belief is that any threading library will work.

I tested the following with 5 threads, and to substantiate the above claim I also threw in a FLoops.jl example.

Turing@0.40.3, DynamicPPL@0.37.5, OhMyThreads@0.8.3, FLoops@0.2.2, Julia 1.11.7

using Turing, Random, LinearAlgebra

# Threads.@threads, not recommended!
@model function f1(x)
    m ~ MvNormal(zeros(4), I)
    Threads.@threads for i in eachindex(x)
        x[i] ~ Normal(m[i])
    end
end

# Threads.@spawn
@model function f2(x)
    m ~ MvNormal(zeros(4), I)
    lls = map(eachindex(x)) do i
        Threads.@spawn begin
            logpdf(Normal(m[i]), x[i])
        end
    end
    @addlogprob! sum(fetch.(lls))
end

# OhMyThreads.tmapreduce
using OhMyThreads
@model function f3(x)
    m ~ MvNormal(zeros(4), I)
    let m = m
        lls = tmapreduce(+, eachindex(x)) do i
            logpdf(Normal(m[i]), x[i])
        end
        @addlogprob! lls
    end
end

# OhMyThreads.@tasks
@model function f4(x)
    m ~ MvNormal(zeros(4), I)
    let m=m
        lls = @tasks for i in eachindex(x)
            @set begin
                reducer=+
            end
            logpdf(Normal(m[i]), x[i])
        end
        @addlogprob! lls
    end
end

# FLoops.@floop
using FLoops
@model function f5(x)
    m ~ MvNormal(zeros(4), I)
    let m=m
        @floop for i in eachindex(x)
            @reduce lls = 0 + logpdf(Normal(m[i]), x[i])
        end
        @addlogprob! lls
    end
end

x = randn(Xoshiro(468), 4)
# replace model constructor with `f1` through `f5`
mean(sample(Xoshiro(468), f1(x), NUTS(), MCMCThreads(), 1000, 3; progress=false))

Conclusions:

  • f1 is not reproducible; it gives different results on every run despite using explicit rng.
  • f2, f3, f4, and f5 are all reproducible.
  • f2, f3, and f4 all yield the same result, so I suppose under the hood they behave the same. f5 gives a reproducible result that is different from the others. (Note that the results were still qualitatively correct in that they satisfied mean(m[i]) ≈ x[i] / 2.)