Complicated GMM in Turing.jl - StackOverflowError

Unfortunately, Turing.jl can’t properly handle cases where the same variable name is sampled multiple times, e.g.

# BIG NO-NO!
for i = 1:n
    x ~ Normal()
end

# BIG YES-YES!
for i = 1:n
    x[i] ~ Normal()
end

This also means that you need to pre-allocate a container for x.
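For instance, a minimal sketch of that pattern (the model name demo and the length argument n are just placeholders here) looks like:

using Turing

@model function demo(n, ::Type{TV} = Vector{Float64}) where {TV}
    # Pre-allocate the container, then give each draw its own indexed name.
    x = TV(undef, n)
    for i in 1:n
        x[i] ~ Normal()
    end
end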

Your particular example therefore needs to be changed to:

@model function MGMM2(data, ::Type{TV} = Vector{Float64}) where {TV}
    
    K, T = 3, 2

    # draw α and χ
    α ~ Dirichlet(T,2)
    χ1 ~ Dirichlet(K,2)
    χ2 ~ Dirichlet(K,2)
    χ = hcat(χ1, χ2)

    # Pre-allocate.
    μ1 = [TV(undef, size(Gm, 1)) for Gm in data]
    μ2 = [TV(undef, size(Gm, 1)) for Gm in data]
    μ3 = [TV(undef, size(Gm, 1)) for Gm in data]
    σ = TV(undef, length(data))

    for (data_idx, Gm) in enumerate(data)
        D, N = size(Gm)

        # choose topic
        Y ~ Categorical(α)
        θ = χ[:,Y]

        # sample mean values and variance
        μ1[data_idx] ~ MvNormal(zeros(D),1)
        μ2[data_idx] ~ MvNormal(zeros(D),1)
        μ3[data_idx] ~ MvNormal(zeros(D),1)
        μ = [μ1[data_idx], μ2[data_idx], μ3[data_idx]]
        σ[data_idx] ~ truncated(Normal(0, 10), 0, Inf)
        
        # draw assignments for each point in x and generate it from a multivariate normal
        for i in 1:N
            k ~ Categorical(θ[:])
            Gm[:,i] ~ MvNormal(μ[k], sqrt(σ[data_idx]))
        end
    end
end

A couple of notes here:

  1. The weird TV argument I’m passing to the model is to tell Turing.jl “hey, I want to use a Vector{Float64} by default, but if you want to use something else, go ahead!” This ensures that the model is type-stable, which significantly improves performance. There’s some more information about it in the Performance Tips section of the docs. The resulting TV(undef, ...) call is then just replaced with Vector{Float64}(undef, ...), or whatever type you pass in.
    • When using particle samplers in Turing.jl, you need to use the TArray type for the variables you want to sample. I’m pretty sure that when we use this approach of providing the type of the container (the TV argument), Turing.jl will automatically replace it with the correct TArray. But I might be wrong, so if you still get weird results, it might be fixed by replacing ::Type{TV} = Vector{Float64} with ::Type{TV} = TArray{Float64, 1} (see the sketch after this list).
  2. I replaced the begin with function just because it reads a bit nicer when using the where {TV} statement.
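To make that concrete, the container type is just a positional argument, so you can override it when constructing the model (a sketch; whether you actually need to do this by hand depends on your Turing.jl version):

# Default: pre-allocated containers are Vector{Float64}.
model = MGMM2(data)

# Explicitly ask for TArray containers, e.g. for particle samplers.
model = MGMM2(data, TArray{Float64, 1})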

Hope this helps! :slight_smile:

EDIT: I completely missed that you were doing PG and HMC within Gibbs! In that case you need one type argument for the variables sampled with HMC (those arrays need to be able to hold ForwardDiff.Dual, since we take the gradient through them) and one for the variables sampled with PG (those need TArray, as mentioned above). That is:

@model function MGMM2(data, ::Type{TVC} = Array{Float64}, ::Type{TVD} = TArray{Float64}) where {TVC, TVD}
    
    K, T = 3, 2

    # draw α and χ
    α ~ Dirichlet(T,2)
    χ1 ~ Dirichlet(K,2)
    χ2 ~ Dirichlet(K,2)
    χ = hcat(χ1, χ2)

    # Pre-allocate.
    # Continuous variables.
    μ1 = [TVC(undef, size(Gm, 1)) for Gm in data]
    μ2 = [TVC(undef, size(Gm, 1)) for Gm in data]
    μ3 = [TVC(undef, size(Gm, 1)) for Gm in data]
    σ = TVC(undef, length(data))

    # Discrete variables.
    Y = TVD(undef, length(data))
    k = [TVD(undef, size(Gm, 2)) for Gm in data]


    for (data_idx, Gm) in enumerate(data)
        D, N = size(Gm)

        # choose topic
        Y[data_idx] ~ Categorical(α)
        θ = χ[:, Y[data_idx]]

        # sample mean values and variance
        μ1[data_idx] ~ MvNormal(zeros(D),1)
        μ2[data_idx] ~ MvNormal(zeros(D),1)
        μ3[data_idx] ~ MvNormal(zeros(D),1)
        μ = [μ1[data_idx], μ2[data_idx], μ3[data_idx]]
        σ[data_idx] ~ truncated(Normal(0, 10), 0, Inf)
        
        # draw assignments for each point in x and generate it from a multivariate normal
        for i in 1:N
            k[data_idx][i] ~ Categorical(θ[:])
            Gm[:,i] ~ MvNormal(μ[k[data_idx][i]], sqrt(σ[data_idx]))
        end
    end
end
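And to show how this fits together with the compositional sampler (the step size, number of leapfrog steps, particle count, and iteration count below are placeholders, not tuned values):

chain = sample(
    MGMM2(data),
    Gibbs(
        HMC(0.01, 10, :α, :χ1, :χ2, :μ1, :μ2, :μ3, :σ),   # continuous variables
        PG(20, :Y, :k)                                      # discrete variables
    ),
    1_000,
)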

Oh, and just in case it’s not clear: the choice of the names TV, TVC, TVD, etc. doesn’t matter. As long as an argument is a Type{...}, Turing.jl will replace the type used when it deems it necessary.
