Turing.jl programmatically set number of clusters in a mixture model

I am trying to model some data using a mixture model with Turing.jl, but I want to set the number of Gaussians in the mixture programmatically, without having to edit the model. For example, if I want ten Gaussians I don't want to manually write ten mu ~ Normal() statements.

Currently my working ‘manual’ model looks like:

using Turing, Distributions, LinearAlgebra

@model function gmm()
    try
        ω ~ Dirichlet(3, 1.0)

        μ1 ~ MvNormal(zeros(2), I)
        Σ1 ~ InverseWishart(2, [1.0 0.0; 0.0 1.0])

        μ2 ~ MvNormal(zeros(2), I)
        Σ2 ~ InverseWishart(2, [1.0 0.0; 0.0 1.0])

        data ~ MixtureModel(
            [
                Product(fill(Uniform(-10, 10), 2)),
                MvNormal(μ1, Σ1),
                MvNormal(μ2, Σ2)
            ],
            ω
        )
    catch me
        # reject any proposal whose covariance comes out non-positive-definite
        if me isa PosDefException
            Turing.@addlogprob! -Inf
        end
    end
end

I saw others recommend using something like the following:

using Turing, LinearAlgebra, Distributions

@model function test_mixture_model(; K = 3)
    try
        ω ~ Dirichlet(K, 1.0)

        μ ~ filldist(MvNormal(5 .* ones(2), 5 * I), K - 1)

        obs ~ MixtureModel(
            [
                id == 1 ? Product(fill(Uniform(-10, 10), 2)) : MvNormal(μ[id - 1], I) for id in 1:K
            ],
            ω
        )
    catch me
        if me isa PosDefException
            Turing.@addlogprob! -Inf
        end
    end
end

model = test_mixture_model()

ϕ = MixtureModel([Product(fill(Uniform(-10, 10), 2)), MvNormal(zeros(2), I), MvNormal([7.0, 7.0], I)])
data_synth = rand(ϕ, 1_000)
model_cond = model | (; obs = data_synth)
chain = sample(model_cond, NUTS(; adtype = AutoForwardDiff()), 1000)

However, I ran into an issue with all the μs becoming tied together, i.e. they contain similar values even though the ground-truth clusters are far apart. And if I then try to add Σ as a parameter, it doesn't show up among the inferred values.

Summary Statistics
  parameters      mean       std      mcse    ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64     Float64    Float64   Float64       Float64 

        ω[1]    0.3448    0.2338    0.0073   1037.2750   695.0540    1.0017     2160.9896
        ω[2]    0.3301    0.2335    0.0078    864.5741   679.1482    1.0019     1801.1959
        ω[3]    0.3251    0.2315    0.0077    845.4391   809.9245    1.0006     1761.3315
     μ[1, 1]    4.9690    2.0700    0.0555   1356.5816   901.3844    1.0024     2826.2116
     μ[2, 1]    4.9583    2.1976    0.0582   1411.3363   785.8980    1.0046     2940.2841
     μ[1, 2]    5.0194    2.2868    0.0686   1111.6677   861.5815    0.9997     2315.9744
     μ[2, 2]    5.0070    2.1828    0.0701    967.7597   786.9313    0.9993     2016.1660

To summarize, I would like a mixture model where the user can select the number of Gaussians to include. Additionally, I would like the means and covariances of those Gaussians to be inferred by Turing from the data. Is there a recommended way of doing this?

As a quick amateur response: you have probably already checked this tutorial, Gaussian Mixture Models – Turing.jl? I noticed your second approach looks similar to it; you just need another filldist containing the covariance matrices and to index it alongside the means. Mixture models can also suffer from a label-switching problem, which the tutorial solves using Bijectors.ordered. Finally, it would be easier to help if you added the code showing how you tried to add Σ as a parameter without success.
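
For the label-switching part, the tutorial's trick (roughly from memory, so treat this as a sketch rather than verified code; gmm_ordered is my own name) is to constrain scalar means to be increasing:

using Turing, Bijectors, LinearAlgebra

@model function gmm_ordered(x; K = 3)
    ω ~ Dirichlet(K, 1.0)
    # ordered(...) forces μ[1] < μ[2] < ..., so components cannot swap labels
    μ ~ ordered(MvNormal(zeros(K), I))
    x ~ filldist(MixtureModel([Normal(μ[k], 1.0) for k in 1:K], ω), length(x))
end

This only applies directly to scalar means, though; for your 2-D means you would have to order on a single coordinate, or skip it if you only care about the overall density.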

EDIT: Updated how I index mu and Sigma, but it doesn’t seem to help with inferring the parameters.

When I add the covariance to the example, the updated model looks like:

@model function test_mixture_model(; K = 3)
    try
        ω ~ Dirichlet(K, 1.0)

        μ ~ filldist(MvNormal(5 .* ones(2), 5 * I), K - 1)
        Σ ~ filldist(InverseWishart(2, [1.0 0.0; 0.0 1.0]), K - 1)

        obs ~ MixtureModel(
            [
                id == 1 ? Product(fill(Uniform(-10, 10), 2)) : MvNormal(μ[:, id - 1], Σ[:, :, id - 1]) for id in 1:K
            ],
            ω
        )
    catch me
        if me isa PosDefException
            Turing.@addlogprob! -Inf
        end
    end
end

Then by running:

model = test_mixture_model()

ϕ = MixtureModel([Product(fill(Uniform(-10, 10), 2)), MvNormal(zeros(2), I), MvNormal([7.0, 7.0], I)])
data_synth = rand(ϕ, 1_000)
model_cond = model | (; obs = data_synth)
chain = sample(model_cond, NUTS(; adtype = AutoForwardDiff()), 1000)

I get:

julia> chain = sample(model_cond, NUTS(; adtype=AutoForwardDiff()), 1000)

┌ Info: Found initial step size
└   ϵ = 1.6
Chains MCMC chain (1000×19×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 7.25 seconds
Compute duration  = 7.25 seconds
parameters        = ω[1], ω[2], ω[3], μ[1, 1], μ[2, 1], μ[1, 2], μ[2, 2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64     Float64    Float64   Float64       Float64 

        ω[1]    0.3270    0.2311    0.0066   1129.8189   515.3923    0.9997      155.7512
        ω[2]    0.3299    0.2365    0.0069   1058.7846   738.7187    1.0032      145.9587
        ω[3]    0.3431    0.2415    0.0069   1149.3102   721.1299    1.0005      158.4381
     μ[1, 1]    5.0655    2.2270    0.0613   1319.3656   789.2345    1.0007      181.8811
     μ[2, 1]    5.0261    2.1903    0.0578   1485.8065   757.1258    0.9995      204.8258
     μ[1, 2]    5.0966    2.0562    0.0670    939.1151   713.4876    0.9990      129.4617
     μ[2, 2]    4.9521    2.1826    0.0628   1209.1306   843.9087    0.9993      166.6847

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        ω[1]    0.0160    0.1326    0.2861    0.4718    0.8387
        ω[2]    0.0152    0.1284    0.2739    0.4994    0.8348
        ω[3]    0.0168    0.1329    0.3077    0.5112    0.8547
     μ[1, 1]    0.7715    3.5757    5.0487    6.6641    9.2216
     μ[2, 1]    0.6594    3.6700    4.9968    6.3735    9.4630
     μ[1, 2]    1.0718    3.7447    5.1711    6.4348    9.0233
     μ[2, 2]    0.8104    3.4796    4.9707    6.4078    9.2173

I am not too worried about using Bijectors.ordered for this code, since I am just trying to fit an arbitrary number of Gaussians to my dataset to approximate the pdf, so I am not concerned with which cluster contains which data. Although if I update my code to use MCMCThreads() I'll have to modify things a bit.
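
For reference, the multi-chain call I'd eventually move to looks like this, as I understand the sample API; with label switching, each chain could then settle on a different labeling:

# 4 independent chains in parallel; Julia must be started with multiple threads
chain = sample(model_cond, NUTS(; adtype = AutoForwardDiff()), MCMCThreads(), 1000, 4)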

But currently, even when I use filldist, my model parameters never match the ground truth. I am not sure if I need to index the parameters in a particular way.
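
One thing I did check is the shape that filldist produces, to rule out an indexing mistake; a quick experiment outside the model:

using Turing, Distributions, LinearAlgebra

# filldist of a 2-D MvNormal over K - 1 = 2 components stacks draws as columns,
# so μ[:, i] should indeed be the mean vector of component i + 1
μ_draw = rand(filldist(MvNormal(5 .* ones(2), 5 * I), 2))
size(μ_draw)  # (2, 2)

What strikes me is that the posterior above looks just like the prior (μ has mean ≈ 5 and sd ≈ √5 ≈ 2.2, and ω matches Dirichlet(3, 1.0)), which makes me suspect the likelihood term is never actually reached, e.g. some exception other than PosDefException being thrown inside the try block and silently swallowed by the catch.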

I think I see what you're referring to in the GMM example (last section, "inferring assignments"). That does help move things in the right direction.

Updated Model

@model function test_mixture_model(; K = 3)
    try
        ω ~ Dirichlet(K, 1.0)

        # flatten the K - 1 component means into one long vector
        μ ~ MvNormal(zeros(2 * (K - 1)), I)

        obs ~ MixtureModel(
            [
                id == 1 ? Product(fill(Uniform(-10, 10), 2)) : MvNormal(μ[(id - 1) * 2 - 1:(id - 1) * 2], I) for id in 1:K
            ],
            ω
        )
    catch me
        if me isa PosDefException
            Turing.@addlogprob! -Inf
        end
    end
end

Still working on getting the covariance to work…
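
One idea I want to try is parameterizing each covariance through an unconstrained Cholesky-style factor, so the matrix is positive definite by construction and the try/catch becomes unnecessary. A rough sketch (my own naming, untested on the real data):

@model function test_mixture_model_chol(; K = 3)
    ω ~ Dirichlet(K, 1.0)

    μ ~ MvNormal(zeros(2 * (K - 1)), I)
    # three unconstrained numbers per component parameterize a 2×2 lower-triangular factor
    L ~ MvNormal(zeros(3 * (K - 1)), I)

    comps = map(1:K) do id
        if id == 1
            Product(fill(Uniform(-10, 10), 2))
        else
            j = (id - 1) * 3
            # exp(...) keeps the diagonal positive, so Lk * Lk' is positive definite
            Lk = [exp(L[j - 2]) 0.0; L[j - 1] exp(L[j])]
            MvNormal(μ[(id - 1) * 2 - 1:(id - 1) * 2], Lk * Lk')
        end
    end

    obs ~ MixtureModel(comps, ω)
end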

So I was getting some weird results before, i.e. incorrect parameter values and seemingly some parameters not being inferred at all. The updated model below seems to work, though I had to bump my sample call up to 10_000 samples for the means to converge.

This new model does complain about not being able to find initial values:

┌ Warning: failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword
└ @ Turing.Inference ~/.julia/packages/Turing/r3Hmj/src/mcmc/hmc.jl:192

Didn’t have this issue with the ‘manual’ model, so I’ll have to do some reading…
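
If I understand the warning, one workaround is to pass explicit starting values. I believe initial_params takes a vector with a value for every parameter, in the order they appear in the model and in their original (constrained) space, so something like this (untested):

K = 3
init = vcat(
    fill(1 / K, K),                 # ω: uniform weights on the simplex
    zeros(2 * (K - 1)),             # μ: all means at the origin
    repeat([1.0, 0.0, 1.0], K - 1)  # Sigma entries giving an identity matrix per component
)
chain = sample(model_cond, NUTS(; adtype = AutoForwardDiff()), 1000; initial_params = init)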

Updated Model with Covariance

@model function test_mixture_model(; K = 3)
    try
        ω ~ Dirichlet(K, 1.0)

        μ ~ MvNormal(zeros(2 * (K - 1)), I)
        # three entries per component: the two diagonal terms and the shared off-diagonal term
        Sigma ~ MvNormal(zeros(3 * (K - 1)), I)

        obs ~ MixtureModel(
            [
                id == 1 ? Product(fill(Uniform(-10, 10), 2)) :
                    MvNormal(
                        μ[(id - 1) * 2 - 1:(id - 1) * 2],
                        [Sigma[(id - 1) * 3 - 2] Sigma[(id - 1) * 3 - 1];
                         Sigma[(id - 1) * 3 - 1] Sigma[(id - 1) * 3]]
                    ) for id in 1:K
            ],
            ω
        )
    catch me
        # the Normal prior on the entries can yield a non-positive-definite matrix;
        # reject those proposals
        if me isa PosDefException
            Turing.@addlogprob! -Inf
        end
    end
end
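
And the fit that finally worked, same setup as before but with more draws:

model = test_mixture_model()

ϕ = MixtureModel([Product(fill(Uniform(-10, 10), 2)), MvNormal(zeros(2), I), MvNormal([7.0, 7.0], I)])
data_synth = rand(ϕ, 1_000)
model_cond = model | (; obs = data_synth)

# 1_000 draws wasn't enough for the means to converge; 10_000 was
chain = sample(model_cond, NUTS(; adtype = AutoForwardDiff()), 10_000)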