[ANN] MuseInference.jl (drop-in replacement for HMC / VI, try it on your problem!)

Earlier this week we put out a paper describing a new method for hierarchical Bayesian inference we dubbed the Marginal Unbiased Score Expansion (MUSE), which can pretty drastically outperform HMC / VI for a wide class of high-dimensional problems. The accompanying Julia package is

pkg> add https://github.com/marius311/MuseInference.jl

It works when you have a high-dimensional latent space you want to marginalize out, which can be non-Gaussian, but the final constraints on the parameters you care about are fairly Gaussian (often automatically the case if the latent space is high-dimensional, due to the central limit theorem). It's approximate, but seems to often perform very well. It scales with dimension b/c no dense operators of the dimensionality of the latent space are ever formed (unlike e.g. the Laplace approximation or full-rank VI).

It has an interface into Turing, so, following the example from the documentation, it's as easy as e.g.:

using MuseInference, Turing

@model function funnel()
    θ ~ Normal(0, 3)
    z ~ MvNormal(zeros(512), exp(θ/2))
    x ~ MvNormal(z, 1)
end

x = (funnel() | (θ=0,))() # draw sample of `x` to use as simulated data
model = funnel() | (;x)

muse(model, 0, get_covariance=true)

and you get speedups like ~40X vs. HMC, even more (+ accuracy) vs. MFVI (the paper has more detailed benchmarking discussion):

image
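To make the funnel model above concrete without Turing, here is a minimal sketch of its generative process in plain Julia, using only the standard library. The helper `simulate_funnel` and the seed are illustrative assumptions, not part of MuseInference; the sketch also assumes the scalar second argument of `MvNormal` is a standard deviation, so `z` has standard deviation `exp(θ/2)`.

```julia
using Random

# Hypothetical helper (not part of MuseInference): draw one data set
# from the funnel model at a fixed value of θ.
function simulate_funnel(rng::AbstractRNG, θ::Real; n::Int=512)
    z = exp(θ / 2) .* randn(rng, n)  # latent vector, standard deviation exp(θ/2)
    x = z .+ randn(rng, n)           # data: latent vector plus unit-variance noise
    return (; x, z)
end

rng = MersenneTwister(1)             # arbitrary seed, for reproducibility only
sim = simulate_funnel(rng, 0.0)

# Per-element second moment of the data, averaged over the 512 elements
mean_sq = sum(abs2, sim.x) / length(sim.x)
```

Marginalizing `z` analytically gives each element of `x` a variance of `1 + exp(θ)`, so at `θ = 0` the value of `mean_sq` should land near 2. That is a handy sanity check for whatever estimate of `θ` an inference method returns.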

It should be ready to go on any existing Turing model (modulo some API caveats) and if anyone is willing / interested, I'm quite curious to hear feedback either about the accuracy of MUSE or the Turing interface itself when applied to real-world or other toy problems. In the paper we use it on a problem in cosmology we're working on with about 6 million dimensions, so it definitely scales.

On the roadmap is getting this into other Julia PPLs. There are also, I think, some cool extensions which higher-order AD like Diffractor will make possible, which I'm looking forward to trying. Please don't hesitate to get in touch on GitHub / privately if any of this is interesting to you!

Layman question: is this the number of independent points in the temperature maps of the cosmic microwave background (CMB) from the sky, like these ones from CERN? And what are you looking for there?

Yea exactly, it corresponds to the number of pixels in those maps of the CMB. The biggest runs we did there are 1024x2048x3 (=~6 million), where the 3 is because we actually infer maps of the polarization (not temperature), which is 2 numbers per pixel, plus one more for the "gravitational lensing potential". This is actually only about 5% of the sky, but it corresponds to some really low-noise observations made with the South Pole Telescope.
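As a quick check of that dimension count (the variable names here are purely illustrative):

```julia
ny, nx = 1024, 2048      # pixels along each axis of one map
nfields = 3              # two polarization components plus the lensing potential
ndims_latent = ny * nx * nfields   # total latent dimensions, about 6.3 million
```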

The main idea is that we see a distorted view of the original CMB because it's been gravitationally lensed, and we try to reconstruct and marginalize over the original unlensed CMB and the distortion. The biggest thing we look for in the unlensed CMB polarization is potential signatures of gravitational waves from inflation (you may remember some of my colleagues thinking they found them a few years back, only to have it turn out to be dust in the Milky Way :grimacing:), and in the distortion field we're learning how the matter in the universe has evolved (since that's what's doing the distorting), learning about the properties of things like dark matter and neutrinos, and even testing gravity itself.

Hi Marius, thanks for making this code available. I'm keen to try out the approximation on a couple of my own GP-related problems.

One minor comment on the manuscript: it took me a while to twig exactly what approximation you were making to eqn 7. You state in the para at the top of page 4 that "The key insight of MUSE is not to attempt to perform the remaining integral at all, but rather approximate it with its data-averaged value." It's clear to me in hindsight what you mean by this, but it would have been helpful to have an equation that explicitly states what this sentence means, i.e. if we refer to the log of the integral in eqn 7 as t(x, θ), writing that the approximation assumes that t(x, θ) ≈ <t(x, θ)>_{x∼p(x|θ)} or something (sorry, bad notation, but hopefully the point is clear), and maybe also writing out the full approximation to eqn 7 in terms of this for the avoidance of doubt.
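In cleaner notation, the statement suggested above might read as follows. Here t(x, θ) is the log of the integral in eqn 7 as defined above, and the prime on x' is an assumption added only to separate simulated data from the observed x:

```latex
t(x, \theta) \;\approx\; \big\langle t(x', \theta) \big\rangle_{x' \sim p(x' \mid \theta)}
```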

Thanks for the feedback, I'll look into making that clearer in the revised version.

And please don't hesitate to get in touch if you run into anything as you're applying it to your problem, curious to see how it works.
