How to use Bijectors.jl for a model with a simplex: am I doing it right?

Ah, okay, my apologies for misunderstanding. To clarify, is your main goal to figure out how to encode a model like the one above or to gain an understanding of how to use Bijectors?

Most users never need to touch Bijectors directly: it's used transparently in the background when needed, transforming your model so that the parameters are sampled in an unconstrained latent space and then transformed back before being returned to you.
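If you do want to see what Bijectors does under the hood, here is a minimal sketch for the Dirichlet case (assuming a recent Bijectors release; in older versions inverse is called inv, and the exact dimension of the unconstrained vector can differ between versions):

using Bijectors, Distributions

d = Dirichlet(6, 1)      # prior over a 6-dimensional simplex
b = bijector(d)          # map from the simplex to unconstrained space
θ = rand(d)              # a point on the simplex (sums to 1)
y = b(θ)                 # unconstrained representation used internally during sampling
θ_back = inverse(b)(y)   # map back to the simplex; θ_back ≈ θ

But again, you don't need any of this for the model in question.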

So for your model, you can just do this:

using Turing, Distributions, StatsPlots

@model function mn(Y; N = sum(Y), k = length(Y))
    θ ~ Dirichlet(k, 1)      # symmetric Dirichlet prior on the k-dimensional simplex
    Y ~ Multinomial(N, θ)    # observed counts, given the total N and probabilities θ
end

Y = [15, 155, 4, 0, 29, 1]
m = mn(Y)
chn = sample(m, NUTS(500, 0.9), MCMCThreads(), 1000, 4)   # 500 adaptation steps, target acceptance δ = 0.9; 4 chains of 1000 draws

Note that I’ve switched the sampler to NUTS. It should be your go-to when all parameters are continuous: it’s typically more efficient, and when it runs into trouble it surfaces diagnostics (e.g. divergences) that help pin down problems which would make other samplers fail silently.

In a first run (using the default NUTS()), I saw lots of numerical errors (divergences; count them with sum(chn[:numerical_error])). This seems to happen because of the very low counts of 0 and 1: they produce regions of high curvature in the posterior density, which biases the samples. Increasing the target acceptance rate to δ = 0.9 adapts a smaller step size, and we get no numerical errors.
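For reference, the check looks like this (a minimal sketch; chn_default is just an illustrative name for the first run with the default settings):

chn_default = sample(m, NUTS(), MCMCThreads(), 1000, 4)   # first attempt with the default target acceptance
sum(chn_default[:numerical_error])                        # > 0: divergent transitions occurred
sum(chn[:numerical_error])                                # with δ = 0.9 this is 0

Here’s the result of the run with δ = 0.9: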

julia> chn
Chains MCMC chain (1000×18×4 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 4.55 seconds
Compute duration  = 4.53 seconds
parameters        = θ[1], θ[2], θ[3], θ[4], θ[5], θ[6]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

        θ[1]    0.0762    0.0185     0.0003    0.0002   5182.5234    1.0000     1145.0560
        θ[2]    0.7429    0.0307     0.0005    0.0005   4792.4908    0.9998     1058.8800
        θ[3]    0.0240    0.0107     0.0002    0.0002   4574.2447    0.9996     1010.6595
        θ[4]    0.0047    0.0045     0.0001    0.0001   5322.6232    1.0001     1176.0104
        θ[5]    0.1429    0.0247     0.0004    0.0004   4348.0419    0.9998      960.6809
        θ[6]    0.0094    0.0066     0.0001    0.0001   4931.6204    0.9997     1089.6201

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        θ[1]    0.0441    0.0624    0.0752    0.0881    0.1161
        θ[2]    0.6818    0.7228    0.7435    0.7642    0.8020
        θ[3]    0.0077    0.0163    0.0224    0.0300    0.0490
        θ[4]    0.0001    0.0014    0.0033    0.0066    0.0165
        θ[5]    0.0972    0.1257    0.1417    0.1595    0.1946
        θ[6]    0.0011    0.0046    0.0079    0.0128    0.0262

The rhat values close to 1 look good. Also note that the ess (effective sample size) values are all greater than the number of requested draws (4 × 1000 = 4000). That indicates antithetic sampling; i.e., NUTS sampled so well that the resulting draws yield better estimates than even exact (independent) draws from the posterior would.
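If you prefer pulling these diagnostics out programmatically rather than reading the printed summary, the MCMCChains functions re-exported by Turing should do it (a sketch; the exact columns can differ a bit between versions):

summarystats(chn)   # the Summary Statistics table above (mean, std, ess, rhat, ...)
quantile(chn)       # the Quantiles table above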

julia> plot(chn)

Note the characteristic “fuzzy caterpillar” look of the trace. This is what you expect to see (and what is absent in the OP, which tells me something is wrong there; I was unable to run your original model, however). A better check is to use ArviZ.plot_rank:

julia> using ArviZ

julia> plot_rank(chn)

Approximately uniform rank histograms are what you hope to see, and that’s what we get here.
