# Turing Mixture Models with Dirichlet weightings

Hi all - I’ve been trying to use `MixtureModel` in Turing to sample from a multi-modal dataset, generating the probability for each mode using a `Dirichlet` distribution. Yet somehow I’ve been getting a `BoundsError` and can’t see why.

Would appreciate it if someone could spot where I’ve gone wrong in the example below.

```julia
using Pkg
Pkg.activate(".")
using Distributions, Turing, DataFrames, Optim

# generate data
# group 1
g1 = rand(Normal(7.5, 0.2), 20)
# group 2
g2 = zeros(40)
# group 3
g3 = rand(Normal(-1, 0.2), 10)

v1 = Vector{Float64}(vcat(g1, g2, g3))

w1 = [length(g1) / length(v1),
      length(g2) / length(v1),
      length(g3) / length(v1)]

m1 = [7, 0, -1]

# model
@model function mx1(data, mu, wgt, n_data)
    # generate probability per group based on input prior
    w ~ arraydist([LogNormal(wgt[w_i], 0.001) for w_i in 1:3])  # log to force positive..for now
    dirichlet_prob ~ Dirichlet(w)

    μ ~ arraydist([Normal(mu[g_i], 0.01) for g_i in 1:3])

    for i in 1:n_data
        data[i] ~ MixtureModel(
            [truncated(Normal(μ[1], 0.2), 0.001, 100),
             truncated(Normal(μ[2], 0.001), -0.001, 0.001),
             truncated(Normal(μ[3], 0.2), -100, -0.001)],
            [dirichlet_prob[1], dirichlet_prob[2], dirichlet_prob[3]]
        )
    end
end

fmla1 = mx1(v1, m1, w1, length(v1))

map_estimate = optimize(fmla1, MAP())
```

The above is giving me the error
`BoundsError: attempt to access 2-element Vector{Float64} at index [1:3]`

However, if I replace the `dirichlet_prob ~ Dirichlet(w)` line with `dirichlet_prob = rand(Dirichlet([wgt[1], wgt[2], wgt[3]]), 1)`, then it runs without issues.

If instead I change the `rand(Dirichlet(...))` line to use the `w` weights generated from the LogNormal, i.e. `dirichlet_prob = rand(Dirichlet([w[1], w[2], w[3]]), 1)` instead of `wgt[1]` etc., then it gives me another error: `DomainError with Dual{Tag{DynamicPPLTag, Float64}, Float64, ...`

Any pointers would be greatly welcomed…

Have now tried another variation and got a different error:

```julia
# model
@model function mx2(data, mu, wgt, n_data)
    # generate probability per group based on input prior
    # w ~ arraydist([LogNormal(wgt[w_i], 0.001) for w_i in 1:3])  # log to force positive..for now
    # dirichlet_prob ~ Dirichlet(w[:, 1])
    # dirichlet_prob = rand(Dirichlet([w[1], w[2], w[3]]), 1)
    w1 ~ LogNormal(wgt[1], 0.01)
    w2 ~ LogNormal(wgt[2], 0.01)
    w3 ~ LogNormal(wgt[3], 0.01)

    # display(w1)
    μ ~ arraydist([Normal(mu[g_i], 0.01) for g_i in 1:3])

    for i in 1:n_data
        w = Vector{Float64}([w1, w2, w3])
        dirichlet_prob ~ Dirichlet(w)
        # dirichlet_prob ~ Dirichlet(w[:, 1])

        data[i] ~ MixtureModel(
            [truncated(Normal(μ[1], 0.2), 0.001, 100),
             truncated(Normal(μ[2], 0.001), -0.001, 0.001),
             truncated(Normal(μ[3], 0.2), -100, -0.001)],
            [dirichlet_prob[1], dirichlet_prob[2], dirichlet_prob[3]]
        )
    end
end

fmla2 = mx2(v1, m1, w1, length(v1))

map_estimate = optimize(fmla2, MAP())
```

and the error is now: `MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 8}) `

@torfjelde Is that a similar issue as for ADVI here?

For the second model version, the issue is the line

```julia
w = Vector{Float64}([w1, w2, w3])
```

when used in combination with the autodiff library ForwardDiff.jl (which is what we use by default). Basically, the type that ForwardDiff.jl uses (which `w1`, etc. will be here) is not a standard `Float64`, and so this will break.

You can either use `Real` (which works with both `Float64` and ForwardDiff’s numbers), or use the trick described in the Performance Tips section of the Turing docs.

But in this particular scenario, you don’t even need any of that :) You can just do:

```julia
w = [w1, w2, w3]
```

and Julia will infer the correct type! That should do the trick here.
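To see why the typed constructor breaks, here is a minimal sketch (no ForwardDiff needed): a toy `Real` subtype stands in for ForwardDiff’s `Dual` number. The `ToyDual` name is purely illustrative, not from any library.

```julia
# ToyDual plays the role of ForwardDiff.Dual: a Real subtype
# with no conversion to Float64 defined.
struct ToyDual <: Real
    val::Float64
end

x = ToyDual(1.0)

# Letting Julia infer the element type works fine:
w = [x, x]
eltype(w)  # ToyDual

# Forcing Float64 storage fails, because Julia has no
# Float64(::ToyDual) method to convert the elements with:
err = try
    Vector{Float64}(w)
catch e
    e
end
err isa MethodError  # true, mirroring the "no method matching Float64(::Dual...)" error
```

This is the same failure mode as the `MethodError: no method matching Float64(::ForwardDiff.Dual{...})` above: the container type demands a conversion that doesn’t exist.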

For the first model: lemme check!

As for the first model: this is indeed a bug!

Should be a quick merge though:)


Thanks for the quick fix!

There seems to be an occasional error, however, for both models 1 and 2:

`ERROR: DomainError with Dual{Tag{DynamicPPLTag, Float64}, Float64, 8}[Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}}(0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0), Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}}(4.27874e57,0.0,4.27874e57,0.0,0.0,0.0,0.0,0.0,0.0), Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}}(2.23784e-73,0.0,0.0,2.23784e-73,0.0,0.0,0.0,0.0,0.0)]: Dirichlet: alpha must be a positive vector.`

The exact same code runs without problems most of the time, but will sometimes churn out the above. Feels like this is purely down to the random seed/starting point hitting some edge cases?

Code:

```julia
# model
@model function mx2(data, mu, wgt, n_data)
    w1 ~ LogNormal(wgt[1], 0.001)
    w2 ~ LogNormal(wgt[2], 0.001)
    w3 ~ LogNormal(wgt[3], 0.001)

    μ ~ arraydist([Normal(mu[g_i], 0.01) for g_i in 1:3])

    for i in 1:n_data
        w = [w1, w2, w3]
        dirichlet_prob ~ Dirichlet(w)

        data[i] ~ MixtureModel(
            [truncated(Normal(μ[1], 0.2), 0.001, 100),
             truncated(Normal(μ[2], 0.001), -0.001, 0.001),
             truncated(Normal(μ[3], 0.2), -100, -0.001)],
            [dirichlet_prob[1], dirichlet_prob[2], dirichlet_prob[3]]
        )
    end
end

fmla2 = mx2(v1, m1, w1, length(v1))

map_estimate = optimize(fmla2, MAP())
```

That is unfortunately just numerical instability, somewhat obfuscated by the dual numbers that ForwardDiff uses to perform AD. You can see that there are some very large values involved, e.g. `4.27874e57`, which leads to certain checks failing. And the fact that it only occurs for certain random seeds suggests it might just come down to the initial parameters used during optimization. Might be worth providing the initial parameters. I believe you can do that as follows:

```julia
map_estimate = optimize(fmla2, MAP(), initparams)
```

Hmm interesting. I’m currently using that `map_estimate` to form the initial parameters in the turing sampling step. Not ideal to have to specify initial params to the step that is meant to find the initial params…! And definitely don’t want that to fail the whole sampling step too.

Have applied a `truncated(...)` on the LogNormals for now to avoid these extremes - seems to be working.
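For reference, that workaround can be sketched like this; the bounds `1e-3` and `1e3` are illustrative choices of mine, not values from the thread:

```julia
using Distributions

# Bound the LogNormal prior so the Dirichlet concentration parameters
# can neither collapse towards 0 nor explode towards ~1e57, which is
# what triggered the "alpha must be a positive vector" DomainError.
d = truncated(LogNormal(0.25, 0.001), 1e-3, 1e3)

minimum(d)  # 0.001
maximum(d)  # 1000.0

# Inside the model this becomes, e.g.:
#   w1 ~ truncated(LogNormal(wgt[1], 0.001), 1e-3, 1e3)
```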

Thanks!
