I’ve seen Pigeons.jl advocated a few times here as a promising tool for parallel tempering. The package makes use of AutoMALA as a general-purpose sampler, instead of the popular NUTS sampler. The AutoMALA paper claims to outperform NUTS when the target exhibits varying local geometry. However, I’ve found AutoMALA to be inefficient when the target exhibits strong correlations.
Practical aside: I often encounter targets where varying local geometry manifests as strong local correlations. The general wisdom is that such a model should be re-parametrized, but that isn't practical in my case because (1) it's not clear a priori what the re-parametrization should be, and (2) the correlation appears in the tails rather than the bulk.
Here’s a MWE based on a multivariate t-distribution with 4 degrees of freedom and strong correlations:
```julia
using DynamicPPL, Pigeons, LogDensityProblems, MCMCChains
using Distributions, LinearAlgebra
import ForwardDiff

# Wrap a Distribution as a log-potential for Pigeons
struct DistLogPotential
    dist::Distribution
    dim::Int
end
function (distlp::DistLogPotential)(x)
    return logpdf(distlp.dist, x)
end

LogDensityProblems.dimension(distlp::DistLogPotential) = distlp.dim;
LogDensityProblems.logdensity(distlp::DistLogPotential, x) = distlp(x);
Pigeons.initialization(distlp::DistLogPotential, rng, idx::Int) = zeros(distlp.dim);

dim = 6
dist = MvTDist(4, [
    1    0.98 0.96 0    0    0;
    0.98 1    0.98 0    0    0;
    0.96 0.98 1    0    0    0;
    0    0    0    1    0.99 0;
    0    0    0    0.99 1    0;
    0    0    0    0    0    1
]);
```
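As a side note (my own addition, not part of the MWE): the condition number of this covariance makes "strong correlations" quantitative. The widest principal direction is roughly 17× longer, in standard deviations, than the narrowest, which is exactly the regime where a single isotropic step size struggles:

```julia
using LinearAlgebra

Σ = [1    0.98 0.96 0    0    0;
     0.98 1    0.98 0    0    0;
     0.96 0.98 1    0    0    0;
     0    0    0    1    0.99 0;
     0    0    0    0.99 1    0;
     0    0    0    0    0    1]
cond(Σ)   # ≈ 295: no single step size fits all directions well
```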
```julia
# AutoMALA
distlp = DistLogPotential(dist, dim);
n_rounds = 12;
@time AM = pigeons(
    target = distlp,
    reference = distlp, # not used because n_chains = 1
    seed = 1,
    n_rounds = n_rounds,
    n_chains = 1,
    explorer = AutoMALA(),
    record = [traces; record_default()],
    show_report = false
);
AM_chn = Chains(AM);
AM_trace = AM_chn.value[:, 1:dim, 1].data;
```
```julia
# NUTS
using Turing

@model function turing_model(distlp::DistLogPotential)
    θ ~ filldist(Turing.Flat(), distlp.dim)
    Turing.@addlogprob! distlp(θ)
    return nothing
end

@time NUTS_chn = sample(
    turing_model(distlp), Turing.NUTS(0.9), 2^n_rounds;
    nadapts = 2^n_rounds
);
NUTS_trace = NUTS_chn.value[:, 1:dim, 1].data;
```
Note the use of n_chains = 1 in pigeons: this is because I’m interested in the performance of AutoMALA itself, not parallel tempering.
Both AutoMALA and NUTS ran in a similar amount of time, but AutoMALA gave a minimum ESS of 33, whereas NUTS gave 1492 (each from 2^12 samples). The AutoMALA trace is visibly autocorrelated.
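To make the ESS numbers concrete, here is a stdlib-only sketch of the usual formula ESS ≈ N / (1 + 2 Σₖ ρₖ) with a crude initial-positive-sequence truncation (not the estimator MCMCChains actually uses), applied to a toy AR(1) chain rather than the sampler output:

```julia
using Random, Statistics

# ESS ≈ N / (1 + 2 Σ ρ_k), truncating the autocorrelation sum at the
# first non-positive lag (a crude initial-positive-sequence estimator).
function simple_ess(x::AbstractVector)
    n = length(x)
    xc = x .- mean(x)
    v = sum(abs2, xc) / n
    s = 0.0
    for k in 1:(n - 1)
        ρ = sum(@view(xc[1:n-k]) .* @view(xc[k+1:n])) / (n * v)
        ρ <= 0 && break
        s += ρ
    end
    return n / (1 + 2s)
end

rng = MersenneTwister(1)
x = zeros(4096)
for t in 2:length(x)             # AR(1) chain with strong autocorrelation
    x[t] = 0.9 * x[t-1] + randn(rng)
end
simple_ess(x)                    # far below 4096
simple_ess(randn(rng, 4096))     # close to 4096
```

The point is just that a lag-one autocorrelation near 1 collapses ESS by an order of magnitude or more, which matches the 33-vs-1492 gap above.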
Marginal means (true mean is 0):

```
# AutoMALA
6-element Vector{Float64}:
 -0.3524143084722495
 -0.3461356921430497
 -0.34182607700807704
 -0.1729245255081515
 -0.17271370633000674
 -0.28447159893009116

# NUTS
6-element Vector{Float64}:
 -0.025387573118296494
 -0.026794540401409394
 -0.02634109011562243
 -0.04530981170877012
 -0.04523217013823499
  0.0005582516537453444
```
The sample means from NUTS are clearly more accurate. AutoMALA also doesn't reach as far into the tails as NUTS.
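One stdlib-only way to put a number on the tail claim (my own check, not from the MWE): since the covariance has unit diagonal, each marginal of this MvTDist is a Student t with 4 df, which can be simulated as N(0,1)/sqrt(χ²₄/4). The empirical fraction of each trace beyond, say, |x| = 4 can then be compared against the true tail mass:

```julia
using Random, Statistics

rng = MersenneTwister(1)
# True marginal: Student t with ν = 4 df, simulated as N(0,1) / sqrt(χ²_ν / ν)
ν, n = 4, 10^6
t = randn(rng, n) ./ sqrt.(vec(sum(abs2, randn(rng, ν, n); dims = 1)) ./ ν)
p_true = mean(abs.(t) .> 4)   # true two-sided tail mass beyond |x| = 4 (≈ 1–2%)
# Compare against e.g. mean(abs.(AM_trace[:, j]) .> 4) for each coordinate j
```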
I imagine using multiple chains + AutoMALA would alleviate some of these issues, but it then loses the computational edge over NUTS. FWIW, Pigeons.jl also implements slice sampling, but that will be less effective on high-dimensional targets.
Some questions:
- AutoMALA supposedly uses fewer leapfrog steps than NUTS. How can I set up a fairer comparison?
- What’s the underlying reason for AutoMALA’s worse autocorrelation? Is it the step size selector, or the round-based adaptation, or something else?
- If multiple chains were used, do the chains use the same AutoMALA parameters (e.g. step size guess, preconditioner)? Does adaptation use samples from all chains, not just the target chain?
- Pigeons.jl’s parallel tempering functionality still looks great. Is there existing work on getting Pigeons to work with NUTS?
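On the second question, a crude experiment is at least suggestive (this is a from-scratch plain MALA with a fixed step size, not AutoMALA, so it only demonstrates the generic MALA failure mode): on the ill-conditioned 3×3 block of the covariance above, a step size small enough to keep acceptance high moves very slowly along the wide principal direction, producing exactly this kind of autocorrelation regardless of how the step size is tuned:

```julia
using LinearAlgebra, Random, Statistics

# Plain MALA with fixed step size ϵ (not AutoMALA): propose a Langevin
# step and Metropolis-correct it.
function mala(rng, logp, gradlogp, x0, ϵ, n)
    x, lx = copy(x0), logp(x0)
    out = zeros(length(x0), n)
    acc = 0
    for i in 1:n
        μ = x .+ (ϵ^2 / 2) .* gradlogp(x)
        y = μ .+ ϵ .* randn(rng, length(x))
        ly = logp(y)
        μy = y .+ (ϵ^2 / 2) .* gradlogp(y)
        lqyx = -sum(abs2, y .- μ) / (2ϵ^2)   # log q(y | x) up to a constant
        lqxy = -sum(abs2, x .- μy) / (2ϵ^2)  # log q(x | y) up to a constant
        if log(rand(rng)) < ly - lx + lqxy - lqyx
            x, lx, acc = y, ly, acc + 1
        end
        out[:, i] .= x
    end
    return out, acc / n
end

Σ = [1.0 0.98 0.96; 0.98 1.0 0.98; 0.96 0.98 1.0]  # the correlated block
P = inv(Σ)
logp(x) = -0.5 * dot(x, P * x)
gradlogp(x) = -(P * x)

rng = MersenneTwister(1)
trace, rate = mala(rng, logp, gradlogp, zeros(3), 0.05, 4096)
rate                                         # acceptance stays high...
cor(trace[1, 1:end-1], trace[1, 2:end])      # ...but lag-1 autocorrelation ≈ 1
```

Whether AutoMALA's step size selector or round-based adaptation makes this better or worse than hand-tuned MALA, I can't tell from this alone, hence the question.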
