Case study: Speeding up a logistic regression with RHS prior (Turing vs Numpyro) - any tricks I'm missing?

In lieu of offering solutions, I’d like to try to narrow down where the problem is coming from.

TLDR: I suspect the culprit is the bijector for correlation matrices.

I’ll focus on LKJ(2, 1) because this case is really simple (only one free parameter) and lets us work some things out analytically. Already when I sample it with NUTS as in the above example, I see some issues:

using LinearAlgebra, Turing, StatsPlots  # LinearAlgebra is needed for `cholesky` below

sampler = NUTS(1000, 0.65)  # 1,000 adaptation steps, target acceptance rate 0.65
nchains, ndraws = 4, 1_000
@model function LKJ_basic()
    X ~ LKJ(2, 1)
end
chns_basic = sample(LKJ_basic(), sampler, MCMCThreads(), ndraws, nchains);

Now we check the post-warmup numerical error:

julia> chns_basic[:numerical_error] |> sum
38.0

38 divergent transitions post-warmup is far too many. We could raise the target acceptance rate to reduce that number (though probably not all the way to 0), but as we’ll see below, this shouldn’t be necessary. It indicates that either the geometry of the induced unconstrained distribution is bad or there is numerical instability somewhere else in the model.
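For reference, raising the target acceptance rate is just the second argument to NUTS; 0.9 below is an arbitrary illustrative value, not a recommendation:

# Sketch: a stricter sampler; a higher target acceptance rate shrinks the
# adapted step size and usually reduces (but may not eliminate) divergences.
sampler_strict = NUTS(1000, 0.9)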

For some analytical results, if X \sim \mathrm{LKJ}(2, 1), then X_{11} = X_{22} = 1 and X_{21} = X_{12} = b \sim \mathrm{Uniform}(-1, 1). We can easily check this with exact draws:

julia> Xs = rand(LKJ(2, 1), 1_000);

julia> bs = [X[2, 1] for X in Xs];

julia> plot(Uniform(-1, 1); func=cdf, label="Uniform(-1, 1)")

julia> ecdfplot!(bs; label="b");

[Figure: empirical CDF of b from exact LKJ(2, 1) draws overlaid on the Uniform(-1, 1) CDF]

For this 2x2 case, the lower Cholesky factor of X is L = \begin{pmatrix}1 & 0 \\ b & \sqrt{1 - b^2}\end{pmatrix}. b is transformed from an unconstrained parameter y with b = \tanh(y), whose derivative is \frac{\mathrm{d}b}{\mathrm{d}y} = 1 - \tanh(y)^2 = 1 - b^2.
So the induced distribution on y is p(y) = \frac{1}{2}(1 - \tanh(y)^2), which is not a hard distribution to sample:

julia> bs ≈ [cholesky(X).L[2, 1] for X in Xs]
true

julia> @model function LKJ_latent()
           y ~ Flat()
           Turing.@addlogprob! log1p(-tanh(y)^2)
       end;

julia> chns_latent = sample(LKJ_latent(), sampler, MCMCThreads(), ndraws, nchains);

julia> chns_latent[:numerical_error] |> sum
0.0

julia> ys = atanh.(bs);  # exact draws of y, since b = tanh(y)

julia> ecdfplot(ys; label="y exact")

julia> @df chns_latent ecdfplot!(:y; label="y NUTS")

[Figure: empirical CDF of exact y draws vs. y sampled with NUTS from LKJ_latent]
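As an extra check on the derivation above, we can overlay the analytic CDF of y on the empirical CDF of the exact draws (a quick sketch; F below is just the antiderivative of p(y)):

# Sketch: p(y) = (1 - tanh(y)^2)/2 integrates to the CDF F(y) = (1 + tanh(y))/2.
F(y) = (1 + tanh(y)) / 2
plot(F, -4, 4; label="analytic CDF of y")
ecdfplot!(ys; label="y exact")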

Since the geometry of the induced unconstrained distribution (what HMC actually samples) is fine, the source of the numerical instability must be in the bijector itself (computing X from y), its logdetjac, the logpdf of LKJ, or the gradients of any of these.
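Before isolating the individual pieces, one way to poke at the combined quantity Turing evaluates on the unconstrained scale (the LKJ logpdf with the log-det-Jacobian correction included) is Bijectors.logpdf_with_trans. Here is a sketch, assuming the Bijectors.jl API, evaluated at correlations approaching ±1 where any instability should be easiest to spot:

# Sketch: logpdf(LKJ(2, 1), X) with the transform correction included, at |b|
# close to 1. logpdf_with_trans is Bijectors.jl API; behavior may vary by version.
using Bijectors, Distributions

d = LKJ(2, 1)
for b in (0.0, 0.9, 0.999, 1 - 1e-9)
    X = [1.0 b; b 1.0]
    @show b, Bijectors.logpdf_with_trans(d, X, true)
end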

We’ll test whether the source is the Bijectors.jl bijector (or its gradient) by computing the bijection manually, bypassing Bijectors.jl entirely:

@model function LKJ_latent_bijector_test()
    y ~ Flat()
    b = tanh(y)
    X = [1 b; b 1]
    logJ = log1p(-b^2)  # log|db/dy| for b = tanh(y)
    Turing.@addlogprob! logpdf(LKJ(2, 1), X) + logJ
end
chns_bijector = sample(LKJ_latent_bijector_test(), sampler, MCMCThreads(), ndraws, nchains)

Now let’s see the results:

julia> chns_bijector[:numerical_error] |> sum
0.0

julia> ecdfplot(ys; label="y exact")

julia> @df chns_bijector ecdfplot!(:y; label="y bijector NUTS")

[Figure: empirical CDF of exact y draws vs. y sampled with NUTS from LKJ_latent_bijector_test]

No sampling problems, and we are indeed targeting the correct distribution! Since this model includes the logpdf of LKJ (and its gradient), the logpdf can’t be our problem.

By process of elimination, this would imply that the numerical instability for this simple case is in the Bijectors.jl bijector itself (or its gradient). Perhaps that’s also where the problem is for higher-dimensional correlation matrices.
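If someone wants to chase this further, a natural next probe is a round trip through the Bijectors.jl transform at correlations near ±1. A sketch, assuming Bijectors.link/invlink (the unconstrained representation differs between Bijectors.jl versions, but the round trip should not):

# Sketch: map X to the unconstrained space and back, checking whether the
# transform itself loses accuracy or produces NaNs as |b| -> 1.
using Bijectors, Distributions

d = LKJ(2, 1)
for b in (0.9, 0.999, 1 - 1e-8, 1 - 1e-12)
    X = [1.0 b; b 1.0]
    y = Bijectors.link(d, X)      # constrained -> unconstrained
    X2 = Bijectors.invlink(d, y)  # unconstrained -> constrained
    @show b, maximum(abs, X2 .- X)
end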
