How to use Bijectors.jl for a model with a simplex constraint: am I doing it right?

Hello all,

I am trying to translate a textbook written in R/Stan into Turing. Simple models such as logistic regression are easy for me to write. However, when faced with more advanced models with ordered/simplex constraints, I got stuck. It seems to me that such constraints would be implemented using Bijectors.jl. Is this right?

I managed to write some minimal code:

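using Turing, Bijectors, Distributions, StatsPlots

@model function mn(Y)
	N = sum(Y)                        # total number of trials
	dist = Dirichlet(length(Y), 0.1)  # Dirichlet prior over the category probabilities
	binv = dist |> bijector |> inv    # inverse of the Dirichlet's bijector
	θ ~ transformed(dist, binv)       # distribution of binv(x) for x ~ dist
	Y ~ Multinomial(N, θ)
end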

Y = [15, 155, 4, 0, 29, 1]

m = mn(Y)

chn = sample(m, PG(100), MCMCThreads(), 1000, 4)

describe(chn) and StatsPlots.plot(chn) gave the following summary and traceplot.

[Figures: describe(chn) summary output; traceplot from StatsPlots.plot(chn)]

I am not sure whether these results are correct: the traceplots look bounded from below and above, the medians of theta[3], theta[4], theta[5], and theta[6] are almost identical, and median(theta[1]) > median(theta[5]), even though the observed counts are Y[1] = 15 < Y[5] = 29.

I’m afraid I missed something rudimentary.
How can I improve the code above? Any suggestions are appreciated.

Thanks!

At first glance, I don’t think the ordered bijector b is applicable here, because it maps from an unconstrained vector in \mathbb{R}^n to an ordered vector in \mathbb{R}^n. But to sample on the simplex itself, one uses a bijector s that maps from an unconstrained vector in \mathbb{R}^n to a point on the (n-1)-simplex. One would want to compose the bijectors as b \circ s, but outputs of s have constraints that violate the assumptions of b.
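To illustrate the point, here is a hand-written sketch of an ordered transform that builds a strictly increasing vector from positive increments via exp. The function name is hypothetical and this is not Bijectors.jl's implementation. Applying it to a point on the simplex (i.e. composing it after a simplex map) gives an ordered vector, but one that no longer sums to 1, so the simplex constraint is lost.

```julia
# Hypothetical sketch of an "ordered" transform: ℝ^n -> strictly increasing vectors in ℝ^n.
# Not Bijectors.jl's code; it only illustrates the idea.
function ordered_from_unconstrained(y::AbstractVector{<:Real})
    x = similar(y, float(eltype(y)))
    x[1] = y[1]
    for k in 2:length(y)
        x[k] = x[k-1] + exp(y[k])   # exp(.) > 0, so each entry exceeds the previous one
    end
    return x
end

θ = [0.2, 0.3, 0.5]                    # a point on the 2-simplex
x = ordered_from_unconstrained(θ)      # ordered, but no longer on the simplex
println(x, "  sum = ", sum(x))
```

The output is increasing, but its sum is far from 1, which is why the ordered transform cannot simply be stacked on top of a simplex transform.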


Hi Seth,

I truly appreciate your advice!

I don’t think the ordered bijector b is applicable here

You are right; my post was misleading. Ordered constraints are outside the scope of my question.

But to sample on the simplex itself, one uses a bijector s that maps from an unconstrained vector in \mathbb{R}^n to a point on the (n-1)-simplex.

Could you share some sample code so I can understand what you mean?
I naively thought that, since rand(transformed(dist, binv)) returns a simplex-like vector, transformed(dist, binv) would be the distribution to sample from.

Ah, okay, my apologies for the misunderstanding. To clarify, is your main goal to figure out how to encode a model like the one above, or to understand how to use Bijectors?

Most users never need to touch Bijectors directly: Turing uses it transparently in the background, transforming constrained parameters so that they are sampled in a latent unconstrained space, and then transforming them back before returning them to you.
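A minimal stdlib-only sketch of that idea for a simplex, assuming a stick-breaking map similar to the one described in the Stan reference manual (the function names are hypothetical and this is not Bijectors.jl's code). The sampler would work with an unconstrained y ∈ ℝ^(K-1), and the model only ever sees the simplex point x; the inverse map recovers y from x.

```julia
# Hypothetical helpers; a sketch of stick-breaking, not Bijectors.jl's API.
logistic(t) = 1 / (1 + exp(-t))
logit(p) = log(p / (1 - p))

"""
    simplex_from_unconstrained(y)

Map y ∈ ℝ^(K-1) to a point on the (K-1)-simplex in ℝ^K by stick-breaking.
"""
function simplex_from_unconstrained(y::AbstractVector{<:Real})
    K = length(y) + 1
    x = zeros(float(eltype(y)), K)
    remaining = one(eltype(x))            # length of stick not yet broken off
    for k in 1:K-1
        z = logistic(y[k] - log(K - k))   # offset makes y = 0 map to the uniform point
        x[k] = remaining * z              # break off a fraction z of what remains
        remaining -= x[k]
    end
    x[K] = remaining                      # last entry takes whatever is left
    return x
end

"""
    unconstrained_from_simplex(x)

Inverse of `simplex_from_unconstrained`: map a simplex point back to ℝ^(K-1).
"""
function unconstrained_from_simplex(x::AbstractVector{<:Real})
    K = length(x)
    y = zeros(float(eltype(x)), K - 1)
    remaining = one(eltype(y))
    for k in 1:K-1
        z = x[k] / remaining              # fraction of the remaining stick used by x[k]
        y[k] = logit(z) + log(K - k)
        remaining -= x[k]
    end
    return y
end

y = [0.3, -1.2, 2.0, 0.0, 0.5]            # unconstrained coordinates (K = 6)
x = simplex_from_unconstrained(y)         # positive entries summing to 1
y_back = unconstrained_from_simplex(x)    # round trip back to ℝ^5
```

Any real vector y lands on the simplex, so a sampler such as NUTS can move freely in ℝ^(K-1) without ever proposing an invalid θ. A real implementation also needs the log-Jacobian of this map, which the sketch omits.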

So for your model, you can just do this:

using Turing, Distributions, StatsPlots
@model function mn(Y; N = sum(Y), k = length(Y))
	θ ~ Dirichlet(k, 1)    # uniform prior on the simplex; Turing handles the constraint
	Y ~ Multinomial(N, θ)
end
Y = [15, 155, 4, 0, 29, 1]
m = mn(Y)
chn = sample(m, NUTS(500, 0.9), MCMCThreads(), 1000, 4)  # 500 adaptation steps, target acceptance 0.9

Note that I’ve switched the sampler to NUTS. This should be the go-to when your parameters are all continuous, since it’s typically more efficient, and when it fails, it can help diagnose problems that cause other samplers to fail silently.

In a first run (using NUTS()), I saw lots of numerical errors (divergences; sum(chn[:numerical_error])). This seems to happen because of the very low counts of 0 and 1, which give the posterior density high curvature; in such regions the sampler struggles to explore, and the samples can be biased. Raising the target acceptance rate δ from its default of 0.65 to 0.9 makes adaptation choose a smaller step size, and then there are no numerical errors. Here's the result:
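To see why a smaller step size helps, here is a toy 1-D illustration (not Turing's NUTS implementation): a Gaussian with a small standard deviation σ stands in for a high-curvature direction of the posterior. Leapfrog integration, which NUTS uses internally, is stable only when the step size ε is below 2σ; beyond that the energy error explodes, which is roughly the kind of event reported as a numerical error or divergence. The function name and the values of σ and ε are illustrative assumptions.

```julia
# Leapfrog integration for H(q, p) = U(q) + p^2/2 with U(q) = q^2 / (2σ^2),
# i.e. a Gaussian target with standard deviation σ (curvature 1/σ^2).
function leapfrog_energy_error(σ, ε, L; q0 = σ, p0 = 0.0)
    U(q) = q^2 / (2σ^2)
    dU(q) = q / σ^2
    H(q, p) = U(q) + p^2 / 2
    q, p = q0, p0
    H0 = H(q, p)
    for _ in 1:L
        p -= ε / 2 * dU(q)   # half step in momentum
        q += ε * p           # full step in position
        p -= ε / 2 * dU(q)   # half step in momentum
    end
    return abs(H(q, p) - H0)  # absolute energy error after L steps
end

σ = 0.1                                     # narrow (high-curvature) direction
small = leapfrog_energy_error(σ, 0.05, 20)  # ε/σ = 0.5: stable, small energy error
large = leapfrog_energy_error(σ, 0.21, 20)  # ε/σ = 2.1: unstable, energy blows up
```

Raising δ does not change the posterior; it only pushes adaptation toward step sizes that stay in the stable regime for the narrowest directions.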

julia> chn
Chains MCMC chain (1000×18×4 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 4.55 seconds
Compute duration  = 4.53 seconds
parameters        = θ[1], θ[2], θ[3], θ[4], θ[5], θ[6]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

        θ[1]    0.0762    0.0185     0.0003    0.0002   5182.5234    1.0000     1145.0560
        θ[2]    0.7429    0.0307     0.0005    0.0005   4792.4908    0.9998     1058.8800
        θ[3]    0.0240    0.0107     0.0002    0.0002   4574.2447    0.9996     1010.6595
        θ[4]    0.0047    0.0045     0.0001    0.0001   5322.6232    1.0001     1176.0104
        θ[5]    0.1429    0.0247     0.0004    0.0004   4348.0419    0.9998      960.6809
        θ[6]    0.0094    0.0066     0.0001    0.0001   4931.6204    0.9997     1089.6201

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        θ[1]    0.0441    0.0624    0.0752    0.0881    0.1161
        θ[2]    0.6818    0.7228    0.7435    0.7642    0.8020
        θ[3]    0.0077    0.0163    0.0224    0.0300    0.0490
        θ[4]    0.0001    0.0014    0.0033    0.0066    0.0165
        θ[5]    0.0972    0.1257    0.1417    0.1595    0.1946
        θ[6]    0.0011    0.0046    0.0079    0.0128    0.0262
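Because the Dirichlet prior is conjugate to the multinomial likelihood, this particular posterior is also available in closed form as Dirichlet(α .+ Y), which gives an independent check of the NUTS summary. A stdlib-only sketch (the helper name is hypothetical):

```julia
# Exact posterior moments for a Dirichlet(α) prior with Multinomial counts Y.
# By conjugacy the posterior is Dirichlet(α .+ Y).
function dirichlet_posterior_moments(α::AbstractVector{<:Real}, Y::AbstractVector{<:Integer})
    a = α .+ Y                                      # posterior concentration parameters
    a0 = sum(a)
    m = a ./ a0                                     # posterior means
    s = sqrt.(a .* (a0 .- a) ./ (a0^2 * (a0 + 1)))  # posterior standard deviations
    return m, s
end

Y = [15, 155, 4, 0, 29, 1]
α = ones(length(Y))          # the Dirichlet(k, 1) prior used in the model above
m, s = dirichlet_posterior_moments(α, Y)
for i in eachindex(m)
    println("θ[$i]: mean = ", round(m[i]; digits = 4), ", std = ", round(s[i]; digits = 4))
end
```

With α = ones(6) and sum(Y) = 204, the exact posterior means are (1 .+ Y) ./ 210, e.g. 156/210 ≈ 0.7429 for θ[2], which agree closely with the NUTS means above.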

The rhat values close to 1 look good. Also note that the ess (effective sample size) values are all greater than the number of requested draws (4 × 1000 = 4000). That indicates antithetic sampling, i.e. negatively autocorrelated draws: NUTS sampled so well that these draws give better estimates of the posterior means than the same number of exact (independent) draws from the posterior would.
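The claim that ESS can exceed the number of draws can be checked on a toy chain whose autocorrelation is known. This stdlib-only sketch uses an AR(1) process with lag-1 correlation ρ and estimates ESS as (variance of one draw) / (variance of the chain mean) over many replicate chains. The AR(1) model and function names are illustrative assumptions; this is not how MCMCChains computes ess. Theory gives ESS ≈ N(1 - ρ)/(1 + ρ), so about 3N for ρ = -0.5 and about N/3 for ρ = 0.5.

```julia
using Random, Statistics

# Draw an AR(1) chain x[t] = ρ x[t-1] + sqrt(1 - ρ^2) ε[t], with stationary variance 1.
function ar1_chain(rng, ρ, N)
    x = Vector{Float64}(undef, N)
    x[1] = randn(rng)                     # start in the stationary distribution
    for t in 2:N
        x[t] = ρ * x[t-1] + sqrt(1 - ρ^2) * randn(rng)
    end
    return x
end

# ESS estimate: variance of a single draw (1 here) divided by the variance of the
# chain mean, the latter estimated from many independent replicate chains.
function simulated_ess(ρ; N = 1000, reps = 2000, rng = MersenneTwister(42))
    means = [mean(ar1_chain(rng, ρ, N)) for _ in 1:reps]
    return 1.0 / var(means)
end

N = 1000
ess_neg = simulated_ess(-0.5; N = N)   # antithetic chain: ESS well above N
ess_pos = simulated_ess(0.5; N = N)    # positively correlated chain: ESS below N
println("ρ = -0.5: ESS ≈ ", round(Int, ess_neg), " for N = ", N)
println("ρ = +0.5: ESS ≈ ", round(Int, ess_pos), " for N = ", N)
```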

julia> plot(chn)

Note the characteristic “fuzzy caterpillar” look of the traces. This is what you expect to see, and it is absent in the original post, which tells me something is wrong there (I was unable to run your original model, however). A better check is to use ArviZ.plot_rank:

julia> using ArviZ

julia> plot_rank(chn)

Approximately uniform histograms are what you hope to see, and that is what we get.
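The idea behind a rank plot can be sketched with the standard library alone (the function name and the simulated chains are illustrative assumptions, not ArviZ's implementation): pool the draws of one parameter from all chains, rank them, and histogram each chain's ranks. Well-mixed chains give roughly flat histograms, while a chain stuck in a different region piles its ranks into a few bins.

```julia
using Random

# Rank-plot data for one parameter. `draws` is an (iterations × chains) matrix.
# Returns an (nbins × chains) matrix of rank counts.
function rank_histograms(draws::AbstractMatrix{<:Real}; nbins = 20)
    n_iter, n_chains = size(draws)
    total = length(draws)
    ranks = similar(draws, Int)
    ranks[sortperm(vec(draws))] = 1:total   # rank 1 = smallest pooled draw
    counts = zeros(Int, nbins, n_chains)
    for c in 1:n_chains, i in 1:n_iter
        bin = min(nbins, 1 + (ranks[i, c] - 1) * nbins ÷ total)
        counts[bin, c] += 1
    end
    return counts
end

rng = MersenneTwister(1)
good = randn(rng, 1000, 4)                 # four chains from the same distribution
bad = randn(rng, 1000, 4) .+ [0 0 0 3]     # chain 4 stuck in a shifted region
counts_good = rank_histograms(good)
counts_bad = rank_histograms(bad)
```

Here every entry of `counts_good` is near 50 (1000 draws per chain spread over 20 bins), while in `counts_bad` the fourth chain's ranks pile up in the top bins.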


Thank you so much for your wonderful response, Seth!!

Yes, my goal is to learn how to encode a constrained model, as you correctly assumed.

Wow, that's the point I had overlooked. I'm amazed at how simple the model became.

I also appreciate your showing how to tune the sampler and check the output, which is very helpful for me.
