Case study: Speeding up a logistic regression with RHS prior (Turing vs Numpyro) - any tricks I'm missing?

As a quick follow-up, I just remembered CorrBijector makes posterior improper · Issue #228 · TuringLang/Bijectors.jl · GitHub, and it seems likely to me that it’s the cause of these problems. For LKJ(2, 1), there’s just 1 degree of freedom, but Bijectors will instead sample 4. The other 3 don’t appear in any log-density term, so they’re completely unbounded on the reals.
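A quick way to see this (a minimal sketch, with arbitrary y values picked just for illustration): perturbing y[1,1], y[2,1], and y[2,2] leaves the implied correlation matrix, and hence the log density, unchanged, since as far as I can tell the inverse link only reads the strict upper triangle of y.

using Bijectors, Distributions

y1 = [0.0  0.3;  0.0  0.0]
y2 = [5.0  0.3; -2.0  7.0]   # same y[1,2], arbitrary values in the 3 unused entries

U1 = Bijectors._inv_link_chol_lkj(y1)
U2 = Bijectors._inv_link_chol_lkj(y2)

logpdf(LKJ(2, 1), U1'U1) ≈ logpdf(LKJ(2, 1), U2'U2)   # true: the 3 extra dof are flat directions

With that in mind, let’s sample this with an explicit y: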

@model function LKJ_latent_bijector_test2()
    y ~ filldist(Flat(), 2, 2)
    U = Bijectors._inv_link_chol_lkj(y)  # bijectors uses upper triangular factors
    b = U[1, 2]
    X = U'U
    logJ = log1p(-b^2)
    Turing.@addlogprob! logpdf(LKJ(2, 1), X) + logJ
end
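For reference, my reading of the logJ term: the only coordinate that actually enters is y[1,2], the implied correlation is b = tanh(y[1,2]), so the change-of-variables correction is log|db/dy| = log(1 - tanh(y[1,2])^2) = log(1 - b^2), i.e. log1p(-b^2). A quick finite-difference check with an arbitrary test value:

y12 = 0.7
h = 1e-6
b(y) = tanh(y)                                  # correlation implied by y[1,2] in the 2×2 case
db_dy = (b(y12 + h) - b(y12 - h)) / (2h)        # numerical db/dy
isapprox(log(abs(db_dy)), log1p(-b(y12)^2); atol = 1e-6)   # true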
julia> chns_bijector2 = sample(LKJ_latent_bijector_test2(), sampler, MCMCThreads(), ndraws, nchains)
Chains MCMC chain (1000×16×4 Array{Float64, 3}):

Iterations        = 1001:1:2000
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 19.43 seconds
Compute duration  = 67.24 seconds
parameters        = y[1,1], y[2,1], y[1,2], y[2,2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters                         mean                          std                   naive_se                        mcse         ess      rhat   ess_per_sec 
      Symbol                      Float64                      Float64                    Float64                     Float64     Float64   Float64       Float64 

      y[1,1]   -11147904746169815040.0000   103813424022984933376.0000   1641434358067363072.0000   13068514454427486208.0000      8.5923    3.5233        0.1278
      y[2,1]    -5848611813323702272.0000    56976721607190446080.0000    900881069440256640.0000    6841438468043346944.0000     10.2138    1.9909        0.1519
      y[1,2]                       0.0226                       0.8944                     0.0141                      0.0136   3895.7260    1.0001       57.9342
      y[2,2]     4836710533717538816.0000    28310642859682942976.0000    447630567300904064.0000    3509525907373429760.0000     10.7624    1.6892        0.1601

Quantiles
  parameters                          2.5%                        25.0%                      50.0%                       75.0%                        97.5% 
      Symbol                       Float64                      Float64                    Float64                     Float64                      Float64 

      y[1,1]   -248511673432142577664.0000   -12349585123783217152.0000   5015527916168183808.0000   26492620967308754944.0000   188906933481250881536.0000
      y[2,1]   -129198750965275377664.0000   -43501092142952873984.0000   -360217999869619840.0000   27753645829781962752.0000   114326851817075408896.0000
      y[1,2]                       -1.8147                      -0.5439                     0.0171                      0.5837                       1.8161
      y[2,2]    -38932511857583849472.0000    -7583069140425722880.0000    962254295777056512.0000   12060787746128771072.0000   101517804871764230144.0000

julia> chns_bijector2[:numerical_error] |> sum
29.0

So we can see how the three unused unconstrained parameters explode (means and standard deviations on the order of 1e19, ESS around 10), while y[1,2], the one coordinate that enters the density, stays well behaved, and we get a similar amount of post-warmup numerical error to the original model.

EDIT: Let’s test the theory by keeping the same model but putting a standard normal prior on the three unused degrees of freedom:

@model function LKJ_latent_bijector_test3()
    y ~ filldist(Flat(), 2, 2)
    U = Bijectors._inv_link_chol_lkj(y)
    b = U[1, 2]
    X = U'U
    logJ = log1p(-b^2)
    Turing.@addlogprob! logpdf(LKJ(2, 1), X) + logJ
    Turing.@addlogprob! -(y[1, 1]^2 + y[2, 1]^2 + y[2, 2]^2)/2  # standard-normal log density (up to a constant) on the 3 unused entries
end
julia> chns_bijector3 = sample(LKJ_latent_bijector_test3(), sampler, MCMCThreads(), ndraws, nchains)
Chains MCMC chain (1000×16×4 Array{Float64, 3}):

Iterations        = 1001:1:2000
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 0.98 seconds
Compute duration  = 3.6 seconds
parameters        = y[1,1], y[2,1], y[1,2], y[2,2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

      y[1,1]    0.0011    0.9983     0.0158    0.0120   7042.3000    0.9995     1954.5656
      y[2,1]    0.0037    0.9945     0.0157    0.0118   5927.0604    0.9996     1645.0348
      y[1,2]   -0.0028    0.9163     0.0145    0.0118   4972.3063    1.0001     1380.0461
      y[2,2]   -0.0013    1.0046     0.0159    0.0121   6517.8198    0.9997     1808.9980

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

      y[1,1]   -1.9008   -0.6818    0.0110    0.6896    1.9938
      y[2,1]   -1.9811   -0.6736    0.0007    0.6845    1.9273
      y[1,2]   -1.8420   -0.5662    0.0129    0.5413    1.8464
      y[2,2]   -1.9892   -0.6589    0.0024    0.6556    2.0052


julia> chns_bijector3[:numerical_error] |> sum
0.0

So I think this confirms the theory.
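As a further sanity check (my addition, not part of the runs above): the extra normal terms only touch coordinates that never affect X, so the implied marginal on the actual correlation should still be the LKJ(2, 1) one, i.e. uniform on (-1, 1), with mean ≈ 0 and std ≈ 1/√3 ≈ 0.577. A rough way to eyeball that from the chains (assuming the column order of Array(chns_bijector3) matches the parameters listing above, so column 3 is y[1,2]):

using Statistics

ys = Array(chns_bijector3)          # (ndraws * nchains) × 4 matrix of y draws
corrs = map(ys[:, 3]) do y12        # only y[1,2] determines the correlation
    U = Bijectors._inv_link_chol_lkj([0.0 y12; 0.0 0.0])
    (U'U)[1, 2]
end
mean(corrs), std(corrs)             # should be roughly (0, 0.577)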
