I’ve been trying to replicate the results from @rikh's great blog post on Bayesian Latent Profile Analysis and the related post on this forum.
This was my attempt:
using Turing
x = randn(200);
x[1:100] .+= 10
@model function model(k, Y)
    w ~ Dirichlet(k, 1)
    μ1 = Uniform(-5, 5)
    μ2 = Uniform(5, 15)
    μ ~ arraydist([μ1, μ2])
    n = length(Y)
    Y ~ filldist(MixtureModel(Normal, μ, w), n)
end;
m = model(2, x);
advi = ADVI(10, 1000)
q = vi(m, advi);
In the last line, I get the error input length mismatch (3 != 4).
The two lengths in the message seem to depend on k: for models with k components, rather than the two Gaussians used above, the numbers in the parentheses are always 2k-1 and 2k.
Using the Turing version from the Pluto notebook in the blog post linked above works well, but switching to the current version of Turing brought the above error message back.
Thank you for your help!
Hi, I tried a lot of variations on this. In the end, I think something like the following is a basis to work from:
x = randn(200);
x[1:100] .+= 10
@model function model(Y)
    # w ~ Dirichlet([1.0, 2.0])
    w ~ Beta(1, 1)
    # distribution_assignments = Categorical([w, 1 - w])
    μ1 = Uniform(-5, 5)
    μ2 = Uniform(5, 15)
    μ ~ arraydist([μ1, μ2])
    n = length(Y)
    # for i in 1:n
    #     k = rand(distribution_assignments)
    #     Y[i] ~ Normal(μ[k], 1)
    # end
    Y ~ filldist(MixtureModel(Normal, μ, [w, 1 - w]), n)
end;
m = model(x);
advi = ADVI(10, 1000)
q = vi(m, advi);
That seems to be a bug. Here is a minimal non-working example:
using Turing
@model function foo(k)
    q ~ Dirichlet(k, 1.2)
end
vi(foo(3), ADVI(10, 1_000)) # fails with: input length mismatch (2 != 3)
Skimming the code, it seems that the line
num_params = length(varinfo[DynamicPPL.SampleFromPrior()])
from Turing.Variational.meanfield is wrong: it takes the dimension from a sample in the constrained space (dimension k for the Dirichlet) instead of the unconstrained space (dimension k-1 for the Dirichlet). Unfortunately, fixing this seems to require more work than simply reworking this function.
Probably best to open an issue.
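To make the dimension count concrete, here is a minimal sketch in plain Julia (no Turing or Bijectors involved). The helper names to_unconstrained and to_simplex are hypothetical, and the additive log-ratio map is an assumption chosen only for illustration; it is not necessarily the transform Turing applies internally. It shows that a length-k probability vector has only k-1 free coordinates, which matches the 2 != 3 in the error for k = 3.

```julia
# Illustration only: an additive log-ratio map between the simplex and R^(k-1).
# Hypothetical helpers, not part of Turing.jl or Bijectors.jl.

# Map a probability vector (length k, entries sum to 1) to k-1 free reals,
# using the last entry as the reference coordinate.
to_unconstrained(x::AbstractVector) = log.(x[1:end-1] ./ x[end])

# Inverse map: k-1 free reals back to a length-k probability vector.
function to_simplex(y::AbstractVector)
    z = exp.(vcat(y, 0.0))  # the reference coordinate has log-ratio 0
    return z ./ sum(z)
end

w = [0.2, 0.3, 0.5]       # constrained space: k = 3 numbers
y = to_unconstrained(w)   # unconstrained space: k - 1 = 2 numbers

println("constrained length:   ", length(w))
println("unconstrained length: ", length(y))
@assert to_simplex(y) ≈ w  # the round trip recovers the original weights
```

A variational family defined on the unconstrained space would therefore need k-1 parameters per mean vector for the Dirichlet variable, not k, so sizing it from a constrained sample overcounts by one.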
Thank you for the investigation! I will file an issue then.