Chains vs iterations behavior in Turing sampler (conditional priors)

I wrote a simple model that looks like this:

using Turing

@model function stam()
    # Prior distributions
    op = [1.0, 2.0, 3.0, 4.0]
    α ~ DiscreteUniform(1, length(op))
    β ~ truncated(Normal(op[α], 0.1); lower=0.01, upper=10.0)
    return nothing
end

Basically, I want each iteration of the sampler to “choose” the mean of the parameter β from op. However, when I ran it using:

model2 = stam()
chain2 = sample(model2, MH(), MCMCSerial(), 20_000, 3; progress=true)

The result is:

Is there a way to make the sampler choose a different α on each iteration? I don’t understand why it chooses a single α for each chain instead of a new one on each iteration.

Thank you!

There used to be an example of mixture models in the Turing tutorials, but the link no longer works. Unfortunately, I do not remember the approach they used. You could potentially use MCMCTempering.jl, which deals with multimodal distributions. However, I don’t know how well it integrates with Turing.

Thank you! However, I still don’t understand what’s going on under the hood such that α is not being resampled on each iteration. Any clarification on this would be useful.

Possibly, on each iteration the sampler resamples either β or α. Resampling β while keeping the same α works as expected. But resampling α while keeping β fixed has a very low acceptance probability, because the standard deviation of 0.1 is small: β sits close to op[α], many standard deviations away from every other entry of op. Essentially, the state space is split into several weakly connected basins. A simple way to test this hypothesis is to increase the standard deviation to 0.5 and see whether the chains switch α values.
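To put rough numbers on that hypothesis, here is a minimal sketch in plain Julia. It assumes the sampler proposes a new α symmetrically while holding β fixed, so the acceptance probability reduces to a ratio of β's densities under the two means. The truncation bounds are ignored as an approximation, and the helper names normal_logpdf and switch_prob are made up for this illustration; none of this is Turing's actual internals.

```julia
# Log-density of Normal(μ, σ) at x, written out so the snippet needs no packages.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Metropolis acceptance probability for moving α from `from` to `to` with β fixed.
# Assumes a symmetric proposal on α; the uniform prior on α cancels in the ratio.
function switch_prob(β, op, from, to, σ)
    logratio = normal_logpdf(β, op[to], σ) - normal_logpdf(β, op[from], σ)
    return min(1.0, exp(logratio))
end

op = [1.0, 2.0, 3.0, 4.0]
β = 1.0   # a typical value of β while α == 1

p_narrow = switch_prob(β, op, 1, 2, 0.1)   # equals exp(-50), roughly 2e-22
p_wide = switch_prob(β, op, 1, 2, 0.5)     # equals exp(-2), roughly 0.135

println("σ = 0.1: ", p_narrow)
println("σ = 0.5: ", p_wide)
```

If the hypothesis holds, the first number would explain why α practically never changes within a chain, and the second suggests that switches should become visible with the wider standard deviation.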

I think Dan is probably right. The sampler seems to get stuck on one of the mixture components. If you are estimating parameters from data, you could marginalize the mixture components in this case. For example, the likelihood for one observation could be:

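    using Distributions

    y = 2.0                          # a single observation
    μs = rand(Normal(0, 1), 4)       # component means
    αs = rand(Dirichlet(ones(4)))    # component weights
    # weighted density of y under each component
    likelihoods = αs .* pdf.(Normal.(μs, 1), y)
    likelihood = sum(likelihoods)    # marginal likelihood of y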

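To work on the log scale, apply logsumexp to the per-component log terms, log.(αs) .+ logpdf.(Normal.(μs, 1), y), which gives the log of the summed likelihoods. Marginalizing is preferred when possible because it leads to more efficient sampling.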
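As a rough sketch of the log-scale version, the snippet below computes the marginal log-likelihood of one observation with a hand-written log-sum-exp and checks it against the direct calculation. The means and weights are made-up fixed values instead of random draws, and stable_logsumexp and normal_logpdf are illustrative helpers written here so the snippet needs no packages; LogExpFunctions.jl provides a ready-made logsumexp.

```julia
# Log-density of Normal(μ, σ) at x, written out so the snippet needs no packages.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Numerically stable log(sum(exp.(v))): subtract the maximum before exponentiating.
function stable_logsumexp(v)
    m = maximum(v)
    return m + log(sum(exp.(v .- m)))
end

y = 2.0                          # a single observation
μs = [-0.5, 0.2, 1.0, 2.5]       # made-up component means
αs = [0.1, 0.2, 0.3, 0.4]        # made-up component weights, summing to 1

# log(α_k) + log N(y | μ_k, 1) for each component k
log_terms = log.(αs) .+ normal_logpdf.(y, μs, 1.0)

# marginal log-likelihood: the component label has been summed out
loglik = stable_logsumexp(log_terms)

# direct calculation on the probability scale, for comparison
naive = log(sum(αs .* exp.(normal_logpdf.(y, μs, 1.0))))
println(loglik ≈ naive)
```

For several observations the per-observation values are summed. Inside a Turing model, a hand-computed term like this could presumably be added to the log density with Turing.@addlogprob!.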

Thanks, both of you! Very insightful comments!
Any chance you could expand the marginalization example a little? I have never done that in Turing.jl and couldn’t find good examples. I’m not sure how to define the variables and the likelihood for the simple case I wrote.
Thank you!!

No problem. I am working on a solution using MixtureModel, which does not require you to handle the logpdf yourself.

using Distributions
using Turing
using Random

# number of observations
n_obs = 100
# number of components
n_c = 4
# sample true μ for each component
μs = randn(n_c)
# true standard deviation of each component
σ = 1
# true component probabilities 
θs = rand(Dirichlet(ones(n_c)))
mixture = MixtureModel(Normal.(μs, σ), θs)
# generate some simulated data 
data = rand(mixture, n_obs)

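@model function my_model(y, n_c, n_obs)
    μs ~ filldist(Normal(0, 1), n_c)
    θs ~ Dirichlet(ones(n_c))
    σ = 1.0
    # each observation comes from the mixture; the component label is marginalized out
    y .~ MixtureModel(Normal.(μs, σ), θs)
end

# NUTS with 1000 adaptation steps and target acceptance rate 0.65; 4 chains of 1000 draws
chains = sample(my_model(data, n_c, n_obs), NUTS(1000, 0.65), MCMCThreads(), 1000, 4)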

Thank you!! it’s really helpful!

No problem. The example above now works. The error was a silly one: I intended to pass data to the model, but instead passed an old variable y that was lurking in my session. I was discussing parameter recovery with someone on Slack, and they said that parameters are often difficult to recover in mixture models.
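One well-known reason recovery is hard is label switching: relabeling the components leaves the likelihood unchanged, so the data cannot say which component is "first". The sketch below illustrates this in plain Julia with made-up observations, means, and weights; mixture_loglik and normal_logpdf are helpers written for this illustration, not part of Turing or Distributions.

```julia
# Log-density of Normal(μ, σ) at x, written out so the snippet needs no packages.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Log-likelihood of data y under a Gaussian mixture with means μs,
# weights θs, and a common standard deviation σ.
function mixture_loglik(y, μs, θs, σ)
    return sum(log(sum(θs .* exp.(normal_logpdf.(yi, μs, σ)))) for yi in y)
end

y = [-1.2, 0.3, 0.8, 2.1, 2.4]   # made-up observations
μs = [-1.0, 0.5, 2.0, 3.0]       # made-up component means
θs = [0.1, 0.4, 0.3, 0.2]        # made-up weights, summing to 1

perm = [3, 1, 4, 2]              # relabel the components
ll_original = mixture_loglik(y, μs, θs, 1.0)
ll_permuted = mixture_loglik(y, μs[perm], θs[perm], 1.0)

# The two values agree up to floating-point rounding.
println(ll_original ≈ ll_permuted)
```

Because every permutation of the labels fits equally well, different chains can settle on different labelings, which makes per-component summaries of μs and θs hard to interpret.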

Evidently, the Turing website has changed. You can find some more info here:

You can also find information about identifiability here: