Hi there,
I continue to learn Turing. The tutorial for Bayesian Logistic Regression (https://turing.ml/dev/tutorials/2-logisticregression/) has a section that checks the trace plots of the chains in order to “do a spot check to make sure each chain converges around similar points”. However, the tutorial for Bayesian Neural Network (https://turing.ml/dev/tutorials/3-bayesnn/) has no such section for checking the convergence of the chains. I sampled three chains using
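```julia
# Draw three independent HMC chains (1500 samples each) and concatenate them
chains = mapreduce(
    _ -> sample(bayes_nn(hcat(xs...), ts), HMC(0.05, 4), 1500),
    chainscat,
    1:3,
)
```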
According to the trace plots, the chains do not converge.
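Trace plots are a visual check; a common numeric complement is the Gelman-Rubin potential scale reduction factor (R-hat), which compares the variance between chains with the variance within each chain. MCMCChains, the package that holds Turing's sampling output, reports R-hat in its summary statistics and also provides `gelmandiag`. The stdlib-only function below is a minimal sketch of the classic (non-split) formula for a single parameter, assuming the draws are arranged as an n×m matrix; the simulated `mixed` and `stuck` matrices are hypothetical stand-ins for real sampler output.

```julia
using Statistics
using Random

# Classic (non-split) Gelman-Rubin R-hat for one parameter.
# `draws` is an n×m matrix: n post-warm-up draws from each of m chains.
function rhat(draws::AbstractMatrix{<:Real})
    n, m = size(draws)
    m >= 2 || throw(ArgumentError("need at least two chains"))
    chain_means = vec(mean(draws; dims = 1))
    W = mean(vec(var(draws; dims = 1)))   # mean within-chain variance
    B = n * var(chain_means)              # between-chain variance
    var_plus = (n - 1) / n * W + B / n    # pooled estimate of the posterior variance
    return sqrt(var_plus / W)             # close to 1 when the chains agree
end

rng = MersenneTwister(1)
mixed = randn(rng, 1000, 3)                    # three chains exploring the same distribution
stuck = randn(rng, 1000, 3) .+ [0.0 3.0 -3.0]  # three chains stuck near different modes
println(rhat(mixed))   # close to 1
println(rhat(stuck))   # well above 1
```

For a BNN, a large R-hat on individual weights can also reflect symmetric, equivalent modes rather than a genuine failure, so comparing the chains' predictions can be more informative than comparing raw weights.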
Would you please help me understand how to make the chains converge in a BNN model? I understand that I can use “MAP estimation to classify our population”, but I would prefer to use the average values of the chains as parameter estimates.
Thanks,
Chuan
Hi,
In the case of the BNN example, the posterior is expected to be multimodal. Multiple chains sampled with HMC will likely not converge to the same values, although they should in expectation. I’m not certain this will happen after so few iterations; you might want to increase the number of iterations by a factor of 10 or more. You might also want to check the autocorrelation plot when analysing the inference results.
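One concrete source of this multimodality is symmetry inside the network: with an odd activation such as `tanh`, flipping the sign of a hidden unit's incoming weights, bias, and outgoing weight leaves the network's output unchanged. The sketch below uses a small hypothetical one-hidden-layer network (`predict`, `W`, `b`, `v`, `c` are illustrative names, not the tutorial's `bayes_nn`) to show that two such parameter vectors define the same function, while their element-wise average does not. This is why averaging raw weights across chains that sit in different modes can be misleading, whereas averaging the predictions of equivalent modes is not.

```julia
using Random

# Hypothetical one-hidden-layer network: y = v' * tanh.(W * x .+ b) + c
predict(W, b, v, c, x) = v' * tanh.(W * x .+ b) + c

rng = MersenneTwister(42)
W = randn(rng, 3, 2); b = randn(rng, 3); v = randn(rng, 3); c = 0.5
x = randn(rng, 2)

# Flip the sign of hidden unit 1: its incoming weights, its bias and its outgoing weight.
# Because tanh(-z) == -tanh(z), the two sign flips cancel.
S = [-1.0, 1.0, 1.0]
W2, b2, v2 = S .* W, S .* b, S .* v

y1 = predict(W, b, v, c, x)
y2 = predict(W2, b2, v2, c, x)
println(y1 ≈ y2)   # true: two different parameter vectors, same function

# Averaging the two "modes" zeroes out hidden unit 1, which changes the function:
# the averaged network loses unit 1's contribution v[1] * tanh(W[1, :]' * x + b[1]).
Wavg, bavg, vavg = (W .+ W2) ./ 2, (b .+ b2) ./ 2, (v .+ v2) ./ 2
y_avg = predict(Wavg, bavg, vavg, c, x)
println(y1 - y_avg)
```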
If sampling-based inference becomes too inefficient, you might want to use variational inference instead. This gives you a variational approximation of the posterior which, with the currently available algorithms, will be a unimodal approximation.
I hope that helps.
Thanks for your response, Martin. It is very helpful.
“It is to be expected that posterior is multi modal” explains why a single chain does not converge. For the convergence of multiple chains, I will try to run more iterations. I will also check the autocorrelation plot.
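MCMCChains provides an `autocorplot` recipe for Turing's output (with StatsPlots loaded). To show what such a plot measures, the sketch below computes the sample autocorrelation of a single chain and a crude effective sample size in plain Julia. The AR(1) series is an assumed stand-in for a slowly mixing chain, and the truncation rule in `ess_crude` (a hypothetical helper) is a simplification of the estimators that MCMC libraries actually use.

```julia
using Statistics
using Random

# Sample autocorrelation of one chain at lags 0..maxlag (biased estimator, normalised by lag 0)
function autocorrelation(x::AbstractVector{<:Real}, maxlag::Integer)
    n = length(x)
    xc = x .- mean(x)
    denom = sum(abs2, xc)
    return [sum(xc[1:n-k] .* xc[1+k:n]) / denom for k in 0:maxlag]
end

# Crude effective sample size: n / (1 + 2 * sum of autocorrelations),
# summing lags 1, 2, ... until the first non-positive value
function ess_crude(x::AbstractVector{<:Real}; maxlag::Integer = 200)
    ρ = autocorrelation(x, maxlag)
    s = 0.0
    for k in 2:length(ρ)   # ρ[1] is lag 0
        ρ[k] <= 0 && break
        s += ρ[k]
    end
    return length(x) / (1 + 2s)
end

# Demo: AR(1) series x_t = φ * x_{t-1} + ε_t, a stand-in for a slowly mixing chain
rng = MersenneTwister(7)
φ = 0.9
x = zeros(10_000)
for t in 2:length(x)
    x[t] = φ * x[t-1] + randn(rng)
end

ρ = autocorrelation(x, 5)
println(round.(ρ; digits = 2))   # decays roughly like φ^k
println(ess_crude(x))            # far fewer effective draws than length(x)
```

Slowly decaying autocorrelation means consecutive draws carry little new information, which is one reason more iterations are needed before the chains can be compared meaningfully.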
Variational inference is new to me, and I need to learn about it. Thanks for introducing it.
Thanks,
Chuan
Happy to help.
You can find the tutorial on variational inference here: https://turing.ml/dev/tutorials/9-variationalinference/