As a quick follow-up, I just remembered the issue “CorrBijector makes posterior improper” (TuringLang/Bijectors.jl #228), and it seems likely to me that it’s the cause of these problems. For LKJ(2, 1), there’s just 1 degree of freedom, but Bijectors will instead sample 4. Three of those don’t appear in any log density term, so they’re completely unbounded on the reals (a quick check after the model below demonstrates this). Let’s sample this with an explicit y:
@model function LKJ_latent_bijector_test2()
    y ~ filldist(Flat(), 2, 2)
    U = Bijectors._inv_link_chol_lkj(y)  # Bijectors uses upper triangular factors
    b = U[1, 2]                          # the single used coordinate, b = tanh(y[1, 2])
    X = U'U
    logJ = log1p(-b^2)                   # log-Jacobian of the tanh link, log(1 - b^2)
    Turing.@addlogprob! logpdf(LKJ(2, 1), X) + logJ
end
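As a quick check of the claim that only one coordinate matters, here’s a minimal sketch using the same internal function (_inv_link_chol_lkj is private Bijectors API, so this may differ in other versions): changing the three “unused” entries of y leaves the factor U untouched.
y1 = [0.0 0.3; 0.0 0.0]
y2 = [5.0 0.3; -2.0 7.0]  # same y[1, 2], different values everywhere else
# If only y[1, 2] enters the transform, both give the same upper triangular factor:
Bijectors._inv_link_chol_lkj(y1) ≈ Bijectors._inv_link_chol_lkj(y2)  # expected: true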
julia> chns_bijector2 = sample(LKJ_latent_bijector_test2(), sampler, MCMCThreads(), ndraws, nchains)
Chains MCMC chain (1000×16×4 Array{Float64, 3}):
Iterations = 1001:1:2000
Number of chains = 4
Samples per chain = 1000
Wall duration = 19.43 seconds
Compute duration = 67.24 seconds
parameters = y[1,1], y[2,1], y[1,2], y[2,2]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
y[1,1] -11147904746169815040.0000 103813424022984933376.0000 1641434358067363072.0000 13068514454427486208.0000 8.5923 3.5233 0.1278
y[2,1] -5848611813323702272.0000 56976721607190446080.0000 900881069440256640.0000 6841438468043346944.0000 10.2138 1.9909 0.1519
y[1,2] 0.0226 0.8944 0.0141 0.0136 3895.7260 1.0001 57.9342
y[2,2] 4836710533717538816.0000 28310642859682942976.0000 447630567300904064.0000 3509525907373429760.0000 10.7624 1.6892 0.1601
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
y[1,1] -248511673432142577664.0000 -12349585123783217152.0000 5015527916168183808.0000 26492620967308754944.0000 188906933481250881536.0000
y[2,1] -129198750965275377664.0000 -43501092142952873984.0000 -360217999869619840.0000 27753645829781962752.0000 114326851817075408896.0000
y[1,2] -1.8147 -0.5439 0.0171 0.5837 1.8161
y[2,2] -38932511857583849472.0000 -7583069140425722880.0000 962254295777056512.0000 12060787746128771072.0000 101517804871764230144.0000
julia> chns_bijector2[:numerical_error] |> sum
29.0
So we can see how the unused unconstrained parameters explode (note that y[1,2], the only coordinate that enters the log density, stays well behaved), and we get a similar amount of post-warmup numerical error as with the original model.
EDIT: Let’s test the theory by using the same model but putting a standard normal prior on the unused degrees of freedom:
@model function LKJ_latent_bijector_test3()
    y ~ filldist(Flat(), 2, 2)
    U = Bijectors._inv_link_chol_lkj(y)
    b = U[1, 2]
    X = U'U
    logJ = log1p(-b^2)
    Turing.@addlogprob! logpdf(LKJ(2, 1), X) + logJ
    # standard normal log density (up to an additive constant) on the three unused entries
    Turing.@addlogprob! -(y[1, 1]^2 + y[2, 1]^2 + y[2, 2]^2)/2
end
julia> chns_bijector3 = sample(LKJ_latent_bijector_test3(), sampler, MCMCThreads(), ndraws, nchains)
Chains MCMC chain (1000×16×4 Array{Float64, 3}):
Iterations = 1001:1:2000
Number of chains = 4
Samples per chain = 1000
Wall duration = 0.98 seconds
Compute duration = 3.6 seconds
parameters = y[1,1], y[2,1], y[1,2], y[2,2]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
y[1,1] 0.0011 0.9983 0.0158 0.0120 7042.3000 0.9995 1954.5656
y[2,1] 0.0037 0.9945 0.0157 0.0118 5927.0604 0.9996 1645.0348
y[1,2] -0.0028 0.9163 0.0145 0.0118 4972.3063 1.0001 1380.0461
y[2,2] -0.0013 1.0046 0.0159 0.0121 6517.8198 0.9997 1808.9980
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
y[1,1] -1.9008 -0.6818 0.0110 0.6896 1.9938
y[2,1] -1.9811 -0.6736 0.0007 0.6845 1.9273
y[1,2] -1.8420 -0.5662 0.0129 0.5413 1.8464
y[2,2] -1.9892 -0.6589 0.0024 0.6556 2.0052
julia> chns_bijector3[:numerical_error] |> sum
0.0
So I think this confirms the theory: once the unused degrees of freedom get a proper prior, the divergences disappear entirely.
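As an optional extra check (not part of the runs above; a minimal sketch assuming the same chain indexing used earlier and the parameter name y[1,2] from the summary table): under LKJ(2, 1) the off-diagonal correlation b = tanh(y[1, 2]) is marginally Uniform(-1, 1), so its posterior draws should have mean ≈ 0 and std ≈ 1/√3 ≈ 0.577.
using Statistics
# Implied correlation from the single used coordinate
b_draws = tanh.(vec(chns_bijector3[Symbol("y[1,2]")]))
mean(b_draws), std(b_draws)  # expect roughly (0.0, 0.577) if the marginal is Uniform(-1, 1)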