Easier use of generated_quantities

Hi,

I’ve had a problem where I had to deal with generated_quantities, and I’ve come up with a solution myself, but I’m pretty sure it’s not the optimal way of doing it. In any case, I think it would be a great help if Turing provided a helper function for this.

I am in a situation where I am experiencing label switching when performing Bayesian inference.
I’ll set up a small MWE.

First we generate some data x from a two-component mixture of exponential distributions:

using Turing: Turing, Model, @model, NUTS, sample, filldist, MCMCThreads, namesingroup
using Distributions: Exponential, Uniform, Rayleigh, MixtureModel

function generate_data()
    x1 = rand(Exponential(0.1), 1000)
    x2 = rand(Exponential(0.5), 500)
    return vcat(x1, x2)
end

x = generate_data();

A first approach would be the model below:

@model function no_ordering(x)

    # prior rs
    rs ~ filldist(Exponential(0.1), 2)

    dists = [Exponential(r) for r in rs]

    # prior θ
    Θ ~ Uniform(0, 1)
    w = [Θ, 1 - Θ]

    # mixture distribution
    distribution = MixtureModel(dists, w)

    # likelihood
    x ~ filldist(distribution, length(x))

end

model_no_ordering = no_ordering(x);
chains_no_ordering = sample(model_no_ordering, NUTS(0.65), MCMCThreads(), 100, 4);

Here we have two priors for the rate r and combine them in a mixture of exponential distributions. However, when we plot the chains, we see the label switching problem: the chains disagree on which component is rs[1] and which is rs[2], so the traces swap between the two modes.
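For reference, the trace plots can be produced with something like this (a minimal sketch, assuming StatsPlots.jl is installed):

using StatsPlots

# trace and density plots for rs[1], rs[2] and Θ
plot(chains_no_ordering)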

A solution to this would be to change the way we set up the priors, to enforce that r_1 < r_2:

@model function with_ordering(x)

    # prior Δr
    Δr ~ filldist(Exponential(0.1), 2)
    rs = cumsum(Δr)

    dists = [Exponential(r) for r in rs]

    # prior θ
    Θ ~ Uniform(0, 1)
    w = [Θ, 1 - Θ]
    f = w[2]

    # mixture distribution
    distribution = MixtureModel(dists, w)

    # likelihood
    x ~ filldist(distribution, length(x))

    return (; rs, f)
end

model_with_ordering = with_ordering(x)
chains_with_ordering = sample(model_with_ordering, NUTS(0.65), MCMCThreads(), 100, 4)

This works well for the actual inference; however, we now have Δr[1] and Δr[2] in our chain instead of the actual values of r_1 and r_2.
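We could of course reconstruct rs by hand from the chain, but this has to be repeated for every transformed variable; a sketch of the manual route:

# manual route: pull out the Δr columns and redo the cumsum transformation
Δr_names = namesingroup(chains_with_ordering, :Δr)   # [Symbol("Δr[1]"), Symbol("Δr[2]")]
Δr_samples = Array(chains_with_ordering[Δr_names])   # (samples × chains) rows, 2 columns
rs_samples = cumsum(Δr_samples, dims = 2)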

A nicer option is to return the values we want from the model function, which is the reason for the line return (; rs, f). Here, f is just included as an example of another derived value we might be interested in.

We can now use Turing’s generated_quantities function. Its output, however, is of type 100×4 Matrix{NamedTuple{(:rs, :f), Tuple{Vector{Float64}, Float64}}}, which I’ve had trouble working with.
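To illustrate, even extracting just the samples of rs[1] from that matrix takes something like:

chains_params = Turing.MCMCChains.get_sections(chains_with_ordering, :parameters)
generated = Turing.generated_quantities(model_with_ordering, chains_params)

generated[1, 1]                        # NamedTuple for sample 1 of chain 1
rs1 = [nt.rs[1] for nt in generated]   # 100×4 Matrix{Float64} for rs[1] alone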

I’ve therefore implemented the following helper functions, which first turn the generated quantities into a chain and then merge it with the original chain:

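""" Extract the :parameters section of the chain and compute the generated quantities """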
function get_generated_quantities(model::Turing.Model, chains::Turing.Chains)
    chains_params = Turing.MCMCChains.get_sections(chains, :parameters)
    generated_quantities = Turing.generated_quantities(model, chains_params)
    return generated_quantities
end


function get_generated_quantities(dict::Dict)
    return get_generated_quantities(dict[:model], dict[:chains])
end


""" Get the number of dimensions (K) for the specific variable """
function get_K(dict::Dict, variable::Union{Symbol,String})
    K = length(first(dict[:generated_quantities])[variable])
    return K
end


function get_variables(dict::Dict)
    return dict[:generated_quantities] |> first |> keys
end


function get_N_samples(dict::Dict)
    return length(dict[:chains])
end


function get_N_chains(dict::Dict)
    return length(Turing.MCMCChains.chains(dict[:chains]))
end


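""" Convert the generated quantities for a single variable into a Chains object """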
function generated_quantities_to_chain(dict::Dict, variable::Union{Symbol,String})

    K = get_K(dict, variable)

    matrix = zeros(dict[:N_samples], K, dict[:N_chains])
    for chain = 1:dict[:N_chains]
        for (i, xi) in enumerate(dict[:generated_quantities][:, chain])
            matrix[i, :, chain] .= xi[variable]
        end
    end

    if K == 1
        chain_names = [Symbol("$variable")]
    else
        chain_names = [Symbol("$variable[$i]") for i = 1:K]
    end
    generated_chain = Turing.Chains(matrix, chain_names, info = dict[:chains].info)

    return generated_chain

end


function generated_quantities_to_chains(dict::Dict)
    return hcat(
        [generated_quantities_to_chain(dict, variable) for variable in dict[:variables]]...,
    )
end


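""" Append the generated chains to the original chain, reusing its iteration range """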
function merge_generated_chains(dict::Dict)
    return hcat(
        dict[:chains],
        Turing.MCMCChains.setrange(dict[:generated_chains], range(dict[:chains])),
    )
end


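""" Compute the generated quantities for a model and merge them into the original chain """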
function get_merged_chains(model::Turing.Model, chains::Turing.Chains)

    dict = Dict{Symbol,Any}(:model => model, :chains => chains)

    dict[:generated_quantities] = get_generated_quantities(dict)
    dict[:variables] = get_variables(dict)
    dict[:N_samples] = get_N_samples(dict)
    dict[:N_chains] = get_N_chains(dict)

    dict[:generated_chains] = generated_quantities_to_chains(dict)
    return merge_generated_chains(dict)

end


merged_chains = get_merged_chains(model_with_ordering, chains_with_ordering)

Now we can easily plot the new merged chains (plotting only the three variables we need):
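For example, with StatsPlots.jl again (the indexing below is just one way to select the variables):

using StatsPlots

# plot only Θ and the recovered, ordered rates
plot(merged_chains[[Symbol("rs[1]"), Symbol("rs[2]"), :Θ]])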

And we get a nice, combined output:

Chains MCMC chain (100×18×4 Array{Float64, 3}):

Iterations        = 51:1:150
Number of chains  = 4
Samples per chain = 100
Wall duration     = 0.74 seconds
Compute duration  = 2.53 seconds
parameters        = Δr[1], Δr[2], Θ, rs[1], rs[2], f
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64       Float64 

       Δr[1]    0.0978    0.0061     0.0003    0.0005   213.0475    1.0144       84.2418
       Δr[2]    0.3756    0.0330     0.0016    0.0024   209.5259    1.0140       82.8493
           Θ    0.6603    0.0352     0.0018    0.0034   127.9808    1.0325       50.6053
       rs[1]    0.0978    0.0061     0.0003    0.0005   213.0475    1.0144       84.2418
       rs[2]    0.4734    0.0356     0.0018    0.0028   191.3401    1.0173       75.6584
           f    0.3397    0.0352     0.0018    0.0034   127.9808    1.0325       50.6053

I am quite confident that this is neither the optimal nor the prettiest way of doing this, but I couldn’t find any examples of easier ways to accomplish the same thing.

Am I missing something here? Or is this really the way to do this?

I guess an ordering constraint in the actual inference step could remove the need for generated_quantities in this case, but I assume I am not the only one who sometimes wants to extract further information about transformed variables from the inference step.
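For completeness, Bijectors.jl does provide an ordered helper that can put such a constraint directly in the prior. As far as I understand it is only guaranteed to be valid for distributions with unconstrained support, so the sketch below (purely illustrative, with the ordering applied on the log scale) uses an MvNormal:

using Turing
using Bijectors: ordered
using Distributions: MvNormal, Uniform, Exponential, MixtureModel
using LinearAlgebra: I

@model function with_ordered_prior(x)
    # ordered prior on log(rs); rs[1] < rs[2] then holds by construction
    log_rs ~ ordered(MvNormal(fill(log(0.1), 2), I))
    rs = exp.(log_rs)

    Θ ~ Uniform(0, 1)
    w = [Θ, 1 - Θ]

    x ~ filldist(MixtureModel([Exponential(r) for r in rs], w), length(x))

    return (; rs, f = w[2])
end

Note that this changes the prior on rs compared to the Exponential priors above, so it is not a drop-in replacement for the cumsum trick.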

Thanks a lot!
