Hi,
I’ve run into a problem where I had to deal with generated_quantities, and I’ve come up with a solution myself, but I’m pretty sure it’s not the optimal way of doing it. In any case, I think it would be a great help if Turing provided a helper function for this.
I am in a situation where I experience label switching problems when performing Bayesian inference.
I’ll set up a small MWE.
First we generate some data x coming from a bimodal distribution:
using Turing: Turing, Model, @model, NUTS, sample, filldist, MCMCThreads, namesingroup
using Distributions: Exponential, Uniform, Rayleigh, MixtureModel
function generate_data()
    x1 = rand(Exponential(0.1), 1000)
    x2 = rand(Exponential(0.5), 500)
    return vcat(x1, x2)
end
x = generate_data();
A first approach would be the code below:
@model function no_ordering(x)
    # prior rs
    rs ~ filldist(Exponential(0.1), 2)
    dists = [Exponential(r) for r in rs]
    # prior θ
    Θ ~ Uniform(0, 1)
    w = [Θ, 1 - Θ]
    f = w[2]
    # mixture distribution
    distribution = MixtureModel(dists, w)
    # likelihood
    x ~ filldist(distribution, length(x))
end
model_no_ordering = no_ordering(x);
chains_no_ordering = sample(model_no_ordering, NUTS(0.65), MCMCThreads(), 100, 4);
Here we have two priors for the rates and combine them in a mixture of exponential distributions. However, when we plot the chains for this, we see the label switching problem.
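For reference, the plot can be produced with something like the code below; StatsPlots is an assumption on my side, any MCMCChains-compatible plotting backend will do:
using StatsPlots
plot(chains_no_ordering)
In the trace plots, the chains for rs[1] and rs[2] jump between the two modes, since the likelihood is invariant under swapping the mixture components.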
A solution to this would be to change the way we set up the priors to enforce that r_1 < r_2:
@model function with_ordering(x)
    # prior Δr (increments, so that rs = cumsum(Δr) is ordered)
    Δr ~ filldist(Exponential(0.1), 2)
    rs = cumsum(Δr)
    dists = [Exponential(r) for r in rs]
    # prior θ
    Θ ~ Uniform(0, 1)
    w = [Θ, 1 - Θ]
    f = w[2]
    # mixture distribution
    distribution = MixtureModel(dists, w)
    # likelihood
    x ~ filldist(distribution, length(x))
    return (; rs, f)
end
model_with_ordering = with_ordering(x)
chains_with_ordering = sample(model_with_ordering, NUTS(0.65), MCMCThreads(), 100, 4)
This works well for the actual inference; however, now we have Δr[1] and Δr[2] in our chain, not the actual values of r_1 and r_2.
We can return the values that we want to extract later on from the model function, which is the reason for the line return (; rs, f). Here, f is just included as an example of another derived quantity that we might also be interested in.
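To see the problem, we can list the parameters that actually end up in the chain (a quick check; the commented output is what I would expect given the model):
names(chains_with_ordering, :parameters)  # [Symbol("Δr[1]"), Symbol("Δr[2]"), :Θ]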
We can now use Turing’s generated_quantities function. The output of this function, however, is of type 100×4 Matrix{NamedTuple{(:rs, :f), Tuple{Vector{Float64}, Float64}}}, which I’ve had trouble working with.
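For illustration, indexing the raw output per iteration and per chain looks like this (a small sketch; the two calls mirror the helper functions below):
chains_params = Turing.MCMCChains.get_sections(chains_with_ordering, :parameters)
gq = Turing.generated_quantities(model_with_ordering, chains_params)
gq[1, 1].rs  # ordered rates for iteration 1 of chain 1
gq[1, 1].f   # corresponding weight of the second mixture component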
I’ve therefore implemented the following helper functions to first turn the generated quantities into a chain and then merge it with the original chain:
function get_generated_quantities(model::Turing.Model, chains::Turing.Chains)
    # strip internals (lp, step sizes, etc.) so only the model parameters remain
    chains_params = Turing.MCMCChains.get_sections(chains, :parameters)
    generated_quantities = Turing.generated_quantities(model, chains_params)
    return generated_quantities
end

function get_generated_quantities(dict::Dict)
    return get_generated_quantities(dict[:model], dict[:chains])
end

"""Get the number of dimensions (K) for the specific variable."""
function get_K(dict::Dict, variable::Union{Symbol,String})
    K = length(first(dict[:generated_quantities])[variable])
    return K
end

function get_variables(dict::Dict)
    return dict[:generated_quantities] |> first |> keys
end

function get_N_samples(dict::Dict)
    return length(dict[:chains])
end

function get_N_chains(dict::Dict)
    return length(Turing.chains(dict[:chains]))
end

function generated_quantities_to_chain(dict::Dict, variable::Union{Symbol,String})
    K = get_K(dict, variable)
    # samples × dimensions × chains array holding the extracted values
    matrix = zeros(dict[:N_samples], K, dict[:N_chains])
    for chain = 1:dict[:N_chains]
        for (i, xi) in enumerate(dict[:generated_quantities][:, chain])
            matrix[i, :, chain] .= xi[variable]
        end
    end
    # scalar variables keep their name; vector variables get indexed names
    if K == 1
        chain_names = [Symbol("$variable")]
    else
        chain_names = [Symbol("$variable[$i]") for i = 1:K]
    end
    generated_chain = Turing.Chains(matrix, chain_names, info = dict[:chains].info)
    return generated_chain
end

function generated_quantities_to_chains(dict::Dict)
    return hcat(
        [generated_quantities_to_chain(dict, variable) for variable in dict[:variables]]...,
    )
end

function merge_generated_chains(dict::Dict)
    # align the iteration range with the original chain before concatenating
    return hcat(
        dict[:chains],
        Turing.setrange(dict[:generated_chains], range(dict[:chains])),
    )
end

function get_merged_chains(model::Turing.Model, chains::Turing.Chains)
    dict = Dict{Symbol,Any}(:model => model, :chains => chains)
    dict[:generated_quantities] = get_generated_quantities(dict)
    dict[:variables] = get_variables(dict)
    dict[:N_samples] = get_N_samples(dict)
    dict[:N_chains] = get_N_chains(dict)
    dict[:generated_chains] = generated_quantities_to_chains(dict)
    return merge_generated_chains(dict)
end
merged_chains = get_merged_chains(model_with_ordering, chains_with_ordering)
Now we can easily plot the new merged chains (plotting only the three variables of interest), as sketched below.
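A minimal sketch of that plotting call (StatsPlots again assumed; namesingroup, imported above, collects the rs[1] and rs[2] names):
using StatsPlots
plot(merged_chains[vcat(namesingroup(merged_chains, :rs), :f)])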
The merged chain itself shows a nice, combined output:
Chains MCMC chain (100×18×4 Array{Float64, 3}):

Iterations        = 51:1:150
Number of chains  = 4
Samples per chain = 100
Wall duration     = 0.74 seconds
Compute duration  = 2.53 seconds
parameters        = Δr[1], Δr[2], Θ, rs[1], rs[2], f
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat   ess_per_sec
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64       Float64

       Δr[1]    0.0978    0.0061     0.0003    0.0005   213.0475    1.0144       84.2418
       Δr[2]    0.3756    0.0330     0.0016    0.0024   209.5259    1.0140       82.8493
           Θ    0.6603    0.0352     0.0018    0.0034   127.9808    1.0325       50.6053
       rs[1]    0.0978    0.0061     0.0003    0.0005   213.0475    1.0144       84.2418
       rs[2]    0.4734    0.0356     0.0018    0.0028   191.3401    1.0173       75.6584
           f    0.3397    0.0352     0.0018    0.0034   127.9808    1.0325       50.6053
I am quite confident that this is neither the optimal nor the prettiest way of doing this, but I couldn’t find any examples of easier ways of accomplishing the same.
Am I missing something here? Or is this really the way to do it?
I guess an ordering constraint in the actual inference step could alleviate the need for generated_quantities in this case (see the sketch below), but I assume I am not the only one who sometimes wants to extract further information from transformed variables in the inference step.
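For completeness, here is roughly how such a constraint could look with Bijectors’ ordered transform. This is a hypothetical sketch: with_ordered_prior is my own name, and whether ordered composes correctly with this particular product distribution may depend on the Bijectors version, so treat it as an idea rather than a tested solution.
using Bijectors: ordered

@model function with_ordered_prior(x)
    # hypothetical: enforce rs[1] < rs[2] directly in the prior, so the rates
    # themselves end up in the chain without post-processing; support for this
    # product distribution in `ordered` is an assumption on my side
    rs ~ ordered(filldist(Exponential(0.1), 2))
    dists = [Exponential(r) for r in rs]
    Θ ~ Uniform(0, 1)
    w = [Θ, 1 - Θ]
    distribution = MixtureModel(dists, w)
    x ~ filldist(distribution, length(x))
end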
Thanks a lot!