Debugging and Visualizing High Dimensional Models (Turing/MCMCChains)

Hi all. I’m working on a tutorial in which I use Turing to fit a time series model to some economic data (nonfarm payrolls).

The model is basically a radial basis function expansion, with parameters for center locations, amplitudes, and a sigmoidal shock for the COVID pandemic… With sufficient debugging, it works reasonably well when I do an optimization:

[image: plot of the optimization fit to the nonfarm payrolls series]

Now I’m trying to sample the model using Turing.jl’s implementation of NUTS:

chain = sample(bmodel, NUTS(100, 0.85, max_depth=8, init_ϵ=2e-4), MCMCThreads(), 150, 2; init_theta = op.values.array)

The sampling completes on two chains, but the summary doesn’t look particularly good. As you can see, the effective sample size is about 4 from 50 post-warmup samples per chain, and r_hat is quite large as well.

Summary Statistics
  parameters      mean     std  naive_se     mcse       ess   r_hat
  ──────────  ────────  ──────  ────────  ───────  ────────  ──────
       Ac[1]   -0.0251  0.1477    0.0148  missing    4.3478  5.2349
       Ac[2]   -0.1174  0.0974    0.0097  missing    4.3478  2.6193
       Ac[3]    0.6177  0.1791    0.0179  missing    4.3478  5.6394
      Arb[1]   13.4318  0.1326    0.0133  missing    4.3478  2.9643
      Arb[2]    0.7188  0.1091    0.0109  missing    4.3478  3.3800
      Arb[3]   47.0745  0.4903    0.0490  missing    4.3478  7.1403
      Arb[4]    1.4287  0.4006    0.0401  missing    4.3478  4.4474
      Arb[5]    0.0224  0.0851    0.0085  missing    4.4266  2.1120
      Arb[6]  -26.9109  0.7010    0.0701  missing    4.3478  8.8497
      Arb[7]    6.1686  0.1536    0.0154  missing    4.4141  2.4206
      Arb[8]  -26.8714  0.0891    0.0089  missing    4.3478  3.1630
      Arb[9]    0.6893  0.1505    0.0151  missing    4.3478  4.4602
     Arb[10]    3.3334  0.1706    0.0171  missing    4.3478  4.2515
     Arb[11]    7.5663  0.0646    0.0065  missing    6.5470  1.4389
     Arb[12]   -9.2274  0.3382    0.0338  missing    4.3478  4.2497
...

Obviously, I’d like to do some plots to see what’s going on. For example, a simple time series plot of several parameters. Of course, MCMCChains has this built in, right? :wink:

plot(chain)

leads to the following window opening:

As you might well expect: the model has ~100 parameters, and you can’t possibly plot them all meaningfully on one plot.

In general, most of my Bayesian models have well over 10 parameters, and I’ve worked on models with tens of thousands. For example, a model of economic outcomes in each of the Public Use Microdata Areas (PUMAs) from the census. Each PUMA has ~100k people in it, and given there are 330M people in the US, there are about 3300 PUMAs, each needing 5 or 10 parameters to describe the local economic conditions… with geographic partial pooling, etc.

To me, the delight of Bayesian methods is that you can construct meaningful models like this which inevitably involve many parameters. What mechanisms are there to do meaningful plots by slicing and dicing the parameters into subsets etc? The examples in the MCMCChains website don’t really go beyond the basics.

Also for example: MCMCChains says:

" By convention, MCMCChains assumes that parameters with names of the form "name[index]" belong to one group of parameters called :name . You can access the names of all parameters in a chain that belong to the group :name by running

namesingroup(chain, :name)

If the chain contains a parameter of name :name it will be returned as well.

The function group(chain, :name) returns a subset of the chain chain with all parameters in the group :name ."

But that doesn’t seem to be the case… For example in my model there’s Arb[n] for n from 1 to 40, which is the coefficients of the radial basis functions… but doing:

julia> group(chain,:Arb)
ERROR: UndefVarError: group not defined
Stacktrace:
 [1] top-level scope at REPL[162]:1
 [2] eval(::Module, ::Any) at ./boot.jl:331
 [3] eval_user_input(::Any, ::REPL.REPLBackend) at /build/julia-pifKTc/julia-1.4.1+dfsg/usr/share/julia/stdlib/v1.4/REPL/src/REPL.jl:86
 [4] run_backend(::REPL.REPLBackend) at /home/dlakelan/.julia/packages/Revise/tV8FE/src/Revise.jl:1165
 [5] top-level scope at none:0

I did manage to do this:

plot(chain[:,:Arb,:])

which is apparently a way to subset the variables to those in the Arb group… but then I get:


Which is still not a usable plot.

It looks like this is useful:

plot(chain[:,[Symbol("As[$i]") for i in 1:10],:])

which then lets me get sort-of-reasonable plots (I still don’t know why it lays out with weird margins…), and shows that the NUTS sampler is not moving well at all.
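Generalizing that, a small helper can batch any named group of parameters into separately plottable chunks. This is just a sketch: `param_batches` is my own made-up helper, not an MCMCChains function, and the only MCMCChains behavior it assumes is the `chain[:, syms, :]` indexing already used above.

```julia
# Hypothetical helper (not part of MCMCChains): collect the parameter names
# starting with `prefix` and split them into batches of at most `n`, so each
# batch can go on its own readable figure.
param_batches(syms, prefix; n = 10) =
    [collect(b) for b in Iterators.partition(
        filter(s -> startswith(string(s), prefix), syms), n)]

# Assumed usage, with StatsPlots loaded and `chain` an MCMCChains.Chains:
# for batch in param_batches(names(chain), "Arb")
#     display(plot(chain[:, batch, :]))
# end
```

That way a 40-element group like Arb turns into four 10-panel figures instead of one unreadable grid.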

So, perhaps some updated info on how best to use Turing to get better mixing etc. would help. I have read through the site a fair amount, and am still not clear on how the different samplers work… like NUTS vs AdvancedHMC.NUTS etc?

2 Likes

As can be seen in your last plot, you have bimodality present for all parameters. NUTS and HMC, powerful as they are, cannot help you in this case; you have to help them. There are two quick wins here: 1. increase the inductive bias in your model structure, and/or 2. set stronger priors. Just my two cents. Good luck.

1 Like

I’m not sure it’s really bimodality, but rather that the sampler is unable to move effectively through the parameter space. If you look at the traceplots, they are hardly moving at all at each step. This suggests to me that there is something I can do to make the sampler happier, but I’m not sure what that is! I don’t know what knobs are available to tune, etc. The online docs seem somewhat divergent from the latest software releases, for example.

1 Like

While NUTS/HMC can be tuned to a certain extent, this rarely results in drastic improvements. As @DoktorMike suggested, you should rethink your model a bit (cf the folk theorem of statistical computing).

I would suggest

  1. starting with a (much) simpler model, and extending that gradually via posterior predictive checks,
  2. if that fails, understand the precise reason (which may entail days of good, clean fun, but you learn a lot), and then use more informative priors.

This is for a Tutorial. The goal here is to learn how to debug models :wink: If I just wanted a model that fits with highly orthogonal terms, I’d use a Chebyshev series.

Unfortunately what I’ve discovered so far is:

  1. Diagnostic plots are hinky, leaving you with weird margins etc
  2. There isn’t much in the way of documentation on the samplers. And a fair amount of it seems to be behind the development (yay for development… but the docs need to follow… I’d be happy to improve docs but I don’t know where to look)
  3. The sampling interface is a little odd in that you get totally different return types from an optimization, sampling, and variational sampling attempt… leading you to have to write three totally separate diagnostic routines to diagnose each case.
  4. There are no real docs on how to access MCMC chains, and what docs there are are out of date including functions that don’t exist anymore etc.

I’m considering myself as kind of free QA for the Turing team. so I’m hoping what I find here can be both helpful for them, and lead to a tutorial that actually looks like a real world data analysis problem, and not like a textbook problem where all you have to do is type a few lines of code and then magically everything works.

I chose this model because it has high expressive power and it’s expected to be somewhat challenging to sample. One major reason it’s challenging to sample is that the time series has all kinds of little hinky bits in it (economists call them “recessions” :wink:), and each one of those provides an opportunity to locally improve the function and get stuck in a local optimum. It is a tutorial after all, what good would it be if it were easy :rofl: What I didn’t expect was that the sampler would simply get completely stuck and not be able to move at all… My experience with Stan made me think it’d probably work ok. I might wind up implementing the model in Stan to compare.

This actually makes for a great Tutorial case. But when a model fits as well as this one does near the MAP estimate… it doesn’t make sense to invoke the folk theorem right away and abandon it…

One strategy I’m considering is to work with a tempered version of this distribution. However, it’s not at all clear from the docs how Turing defines the model, and whether I can easily define a tempered version of it (basically, find the function that evaluates the log density and then wrap it like x -> mymodel(x)/2.0 and sample that).

It seems like this might be a case where by taking a LARGE sample from a tempered distribution, and then sub-sampling according to the un-tempered distribution I could do well.
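The subsampling step is just self-normalized importance resampling: if the tempered target is p(θ)^(1/2), draws from it get weights proportional to p(θ)/p(θ)^(1/2) = p(θ)^(1/2), i.e. log-weight = logp(θ)/2. A minimal sketch in plain Julia, assuming I can evaluate the untempered log density `logps[i]` for each tempered draw `draws[i]` (`resample_untempered` is my own made-up name, not a Turing API):

```julia
using Random

# Sketch: resample sqrt-tempered draws back toward the untempered target.
# Log-weights are logp/2 because p / p^(1/2) = p^(1/2).
function resample_untempered(rng, draws, logps; n = length(draws))
    lw = logps ./ 2
    w = exp.(lw .- maximum(lw))   # subtract the max for numerical stability
    c = cumsum(w ./ sum(w))       # normalized cumulative weights
    pick() = min(searchsortedfirst(c, rand(rng)), length(draws))
    [draws[pick()] for _ in 1:n]
end
```

With only ~50 post-warmup draws this would mostly return duplicates, which is why the tempered sample needs to be LARGE for the idea to pay off.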

1 Like

So, for the edification of those keeping track… I was able to define a tempered distribution by doing the following (where bmodel is my existing model).

tmodel = DynamicPPL.Model((x...) -> bmodel.f(x...) / 2.0, bmodel.args, bmodel.modelgen)

I have no idea how much that affects speed, with the varargs and the splatting. But it did run.

1 Like

If you’re following along at home… Here’s what I found today:

I decided that the main reason it was so hard to sample the original model was that there was an enormous amount of correlation between all the dimensions involving the radial basis function expansion… Because the RBFs have global support, if you change any given coefficient, or center location, it changes the value of the function everywhere in the time domain. So if the sampler needs to fit a particular region of time better, while keeping other regions of time relatively constant, it will need to move in a very specific direction in 80 dimensional space (40 coefficients and 40 centers).

The alternative was to fix the centers, to reduce the dimensionality, and decouple the coefficients by using RBFs with compact support… the specific one is:

exp(1 - 1/(1 - ((x-c)/s)^2))   (for abs((x-c)/s) < 1, and 0 otherwise)

This function is called the “bump” function and it’s got compact support in the range c ± s and it’s infinitely smooth and goes to zero with zero derivative of all orders at the boundary. So you can adjust a particular bump, and it only affects the function in a smallish region. Therefore the sampler is far less constrained in terms of where it can move in space.
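For concreteness, here is a direct transcription of that formula into Julia (just a sketch of the basis function itself, not the tutorial’s model code):

```julia
# Compactly supported "bump" RBF: equals 1 at the center x = c, decays
# smoothly to 0 as |x - c| approaches s, and is exactly 0 outside c ± s.
function bump(x, c, s)
    u = (x - c) / s
    abs(u) < 1 ? exp(1 - 1 / (1 - u^2)) : 0.0
end
```

Because each bump is identically zero outside c ± s, perturbing one coefficient only changes the fitted curve on that interval, which is exactly the decoupling described above.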

Sampling that with NUTS produced a dramatically improved Effective Sample Size and R-hat.

I will be including these results in the Tutorial; it’s a useful example of how what works well in a deterministic context (global RBFs converge exponentially as interpolants) doesn’t necessarily work well in a Bayesian statistical context.

Still finding that the documentation, and tools to debug could use some improvements. Would be happy to submit bug reports against the docs.

2 Likes

@dlakelan really looking forward to these tutorials once you publish them! Regarding the problems you had with visualization/diagnostics: have you tried out the ArviZ.jl package? It’s Python on the backend, but really useful nonetheless.

Thanks @joshualeond, I haven’t looked at ArviZ but will do so now.

So far I have one fairly well developed tutorial (BasicDataAndPlots, which is a tutorial plus an associated discussion; it’s in jmd, but I should probably compile and push the notebook as well, which I’ll do today) and one somewhat rough splat of data analysis of COVID in the US (because people kept asking me what I thought was going on). That one already has a pushed notebook because I wanted people to be able to go to Binder and run it.

You can check out the current status here: https://github.com/dlakelan/JuliaDataTutorials

Also, when it comes to working with high dimensional models, how do you think ArviZ does? I’m more used to working in R, so I haven’t worked with ArviZ previously. The website looks nice, though. But my basic assumption going into any real-world data analysis is that my model is going to have anywhere from 15 to 15000 parameters. For example, I have a project I’m working on with a group where we’re looking at 12 or so different genetic classifications of tumors across about 25 different tissues, and each tumor/tissue combination has several parameters that describe survival/severity. So that’s maybe 25 * 12 * 5 = 1500 parameters.

In my mind, this is just a typical situation, and any visualization and exploration tool should really be written with it in mind. If I call “plot” on such a model, I should get something usable; I should not lock up my machine until it runs out of memory and crashes, and/or produces 1 million plots, each of which fits into 3 pixels.

Mostly, even in R, the plotting/viz tools seem to be written around the textbook/toy model scenarios: 3 to 10 parameters. There are good real-world problems where you can solve them with 3 to 10 parameters. But they’re far less common than the 1500 messy parameter situations.

wdyt about ArviZ in this context?

I’m probably not the best person to speak to this. I’m basically a beginner Bayesian, so most of my work has been with minimal parameters (<10 or so). Maybe @sethaxen could give you a better idea on this. However, you said you came from R, so I’m assuming you were using Stan before. What sorts of visualizations did you use in the past for high dimensional work in Stan? It seems likely that the ArviZ package would have similar plots, but I am unsure.

Thanks for the links to the tutorials btw.

I often used bayesplot, and had to use the options that extract samples of parameters by name or regex match. Or I created my own visualizations in raw form with ggplot2.

R actually had all the problems I mentioned :wink: