Turing NUTS chains getting stuck at the parameter bounds

Yes, I’m almost sure these are going to have very different log-density (lp) values. The key here is to initialize within a reasonable neighborhood of the solution: not too tight, because then you have no real convergence diagnostic, but not so loose that chains become trapped in local modes well away from the real high-mass region.
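A minimal sketch of that middle ground, assuming you have a rough point estimate `θ0` (say, from an optimizer) and the bounds `lb`, `ub`: jitter each chain's start by a fraction of each parameter's range and keep it strictly inside the bounds. `jittered_inits` and `frac` are hypothetical names, and `frac = 0.1` is an arbitrary default, not a recommendation. The resulting vector of starting points can then be handed to `sample`; the keyword has been `init_params` in older Turing versions and `initial_params` in newer ones, so check the docs for the version you have.

```julia
using Random

# Hypothetical helper: draw `nchains` starting points around a rough guess `θ0`,
# spread by `frac` times each parameter's range, then clamped back inside
# (lb, ub) with a small margin so no chain starts exactly on a bound.
function jittered_inits(θ0, lb, ub, nchains; frac = 0.1, rng = Random.default_rng())
    width = ub .- lb
    margin = 1e-3 .* width                      # stay strictly inside the bounds
    return [clamp.(θ0 .+ frac .* width .* randn(rng, length(θ0)),
                   lb .+ margin, ub .- margin) for _ in 1:nchains]
end
```

Making `frac` very small recreates the "too tight" problem (chains agree trivially, so R-hat tells you little); making it close to 1 is essentially initializing from the whole prior box.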

Parallel tempering is mainly useful when there is more than one high-mass region and those regions are separated by low-density “gulfs” that a single chain can’t cross (though I suppose it would probably also help here). I suspect that in this model the actual solution is fairly compact and tightly peaked. It’s sometimes funny that a highly identified model can’t be sampled very well because the region of space you have to get into is so small 🙂


Regarding using a transformation, a method I have used is to take bounded priors and transform them to ℝⁿ using `bijector` (from Bijectors.jl, which Turing builds on):

```julia
using Turing, Bijectors

transf = bijector(@Prior)                        # maps the prior's support (a box) to ℝⁿ
transformed_prior = transformed(@Prior, transf)  # the prior expressed on ℝⁿ
```

where the original prior is something like

```julia
# the prior: one distribution per parameter, combined into one multivariate distribution by arraydist
lb, ub = PriorSupport() # lower/upper bounds (your own function); needed inside Prior
macro Prior()
    return :( arraydist([Uniform(lb[i], ub[i]) for i = 1:size(lb, 1)]) )
end
```

Then one can sample from the transformed prior using something like

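```julia
@model function MSM(m, S, model)
    θt ~ transformed_prior  # θt is unconstrained (lives in ℝⁿ)
    # <etcetera>: the rest of the model, presumably mapping θt back to the bounded scale with the inverse of transf
end
```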
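For a single `Uniform(a, b)` component, my understanding is that the bijector is a scaled logit, so the sketch below spells out the map, its inverse, and the log-density of the transformed draw by hand (plain Julia, no Turing; all function names are hypothetical). The Jacobian term exactly cancels the `1/(b - a)` of the uniform, leaving the standard logistic density on ℝ, which is smooth and has no edges.

```julia
# Hypothetical helpers illustrating a scaled-logit transform for x ~ Uniform(a, b).
logistic(y) = 1 / (1 + exp(-y))

# Bounded (a, b) -> unconstrained ℝ, and back.
to_unconstrained(x, a, b) = log((x - a) / (b - x))
to_bounded(y, a, b) = a + (b - a) * logistic(y)

# log p_Y(y) = log p_X(x(y)) + log|dx/dy|
#            = -log(b - a) + log(b - a) + log σ(y) + log(1 - σ(y))
#            = log σ(y) + log(1 - σ(y)),
# written in a form that does not overflow for large |y|.
logpdf_unconstrained(y) = -abs(y) - 2 * log1p(exp(-abs(y)))
```

Note that `logpdf_unconstrained` no longer depends on `a` or `b`: the bounds only enter through `to_bounded` when you map a draw back.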

I have discarded the previous runs, but here is a plot of the `:lp` values I posted above. The y-axis is unfortunately not labelled, but it is `plot(chain[:lp])`. The values are very different across chains, suggesting that they got stuck in different local modes.
[Figure: logposterior, i.e. `plot(chain[:lp])` for each chain]
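To put a number on what the plot shows, here is a minimal sketch (hypothetical helper, plain Julia plus the Statistics stdlib) that compares chains by their median lp. It assumes the `:lp` draws have been pulled out into an iterations × chains matrix; I believe `Array(chain[:lp])` does that with MCMCChains, but check. The `tol` threshold is arbitrary; the point is that chains exploring the same mode should have overlapping lp traces, not medians tens or hundreds of units apart.

```julia
using Statistics

# Hypothetical helper: `lp` is an (iterations × chains) matrix of log-density values.
# Returns the indices of chains whose median lp is more than `tol` below the best
# chain's median, a rough sign they are parked in a lower-mass local mode.
function flag_stuck_chains(lp::AbstractMatrix; tol = 10.0)
    meds = vec(median(lp; dims = 1))   # one median per chain (column)
    best = maximum(meds)
    return findall(m -> best - m > tol, meds)
end
```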

If you want a “tempering-like” solution to this, something I’ve done in the past is to create a “tempering parameter”. After doing all the other calculations in the model, you can look in `__varinfo__` to grab the current lp value, then:

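```julia
tempering ~ Normal(0, 1)
# ... the rest of the model ...
# a, b, c are tuning constants you choose; logistic is the sigmoid (e.g. from LogExpFunctions)
temperval = 1 + a * logistic(b * (tempering + c))
curlp = __varinfo__... # current log density; I don't remember exactly how to access it.
@addlogprob! -curlp            # cancel the untempered log density...
@addlogprob! curlp / temperval # ...and add it back divided by the temperature
return temperval
```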

When `tempering` is very negative, `temperval` is (close to) 1, so this is just `curlp/1` and you’re sampling from the model you want. When it’s further to the right, `temperval` grows toward `1 + a` and you’re sampling from a tempered (flattened) version of your problem.
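The two `@addlogprob!` lines together replace `curlp` with `curlp / temperval`. Below is a plain-Julia sketch of the temperature schedule and of that net effect, with hypothetical defaults `a = 9`, `b = 1`, `c = 0` (so the temperature runs from 1 up to 10) and `logistic` written out by hand rather than imported:

```julia
# Hypothetical stand-ins mirroring the model snippet: `a`, `b`, `c` are
# user-chosen tuning constants, and `logistic` is the standard sigmoid.
logistic(x) = 1 / (1 + exp(-x))

# Temperature as a function of the sampled `tempering` parameter:
# tends to 1 as tempering -> -Inf and to 1 + a as tempering -> +Inf (for a > 0, b > 0).
temperature(t; a = 9.0, b = 1.0, c = 0.0) = 1 + a * logistic(b * (t + c))

# Net effect of `@addlogprob! -curlp` followed by `@addlogprob! curlp / temperval`.
tempered_lp(lp, T) = lp / T
```

Dividing by T shrinks every lp difference by the same factor, so at T = 10 a 300-unit gap between a trapped mode and the main mode becomes a 30-unit gap, which is far easier for the sampler to cross before cooling back down.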

Post sampling, you just subset the samples that have temperval sufficiently close to 1.
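A sketch of that subsetting step, assuming you have already collected each iteration's draw and its `temperval` into two vectors of equal length. The model returns `temperval`; Turing's `generated_quantities` (called `returned` in newer versions) is one way to recover return values from a chain, but check your version. `untempered_draws` and the cutoff `eps` are hypothetical:

```julia
# Hypothetical post-processing: `draws` holds one draw per iteration and
# `temps` the matching `temperval` returned by the model at that iteration.
function untempered_draws(draws::AbstractVector, temps::AbstractVector; eps = 0.01)
    length(draws) == length(temps) ||
        throw(DimensionMismatch("draws and temps must have the same length"))
    keep = findall(T -> T <= 1 + eps, temps)  # iterations sampled at (nearly) T = 1
    return draws[keep]
end
```

A smaller `eps` keeps draws closer to the untempered target but leaves fewer of them, so it is worth checking how many iterations survive the cut.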
