Sampling gives chains that converge to two different distributions in the ODE example from TuringTutorials

I was trying to sample the posterior for the parameters and the initial condition of a very simple ODE, taking the example from
https://turing.ml/v0.21/tutorials/10-bayesian-differential-equations/
as a model, and I kept obtaining quite different results. Sometimes the chains sampled with NUTS converged to distributions around the ‘true’ parameter values with a low observational noise variance, but sometimes a chain converged to strange values with a high observational noise variance. There was no indication of which chains were the ‘true’, low-variance ones, except that the initial step size found by NUTS was large and a lot of warnings like this appeared:

 Warning: The current proposal will be rejected due to numerical error(s).
isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)

At first I believed that the problem was mine: my model was ill-conditioned or the priors were not well chosen. But then I tried the original example, which I copy here again as a MWE, and one in every two or three chains also displayed this odd behavior.

using Turing, DifferentialEquations, StatsPlots, LinearAlgebra, Random
Random.seed!(14)

# Define Lotka-Volterra model.
function lotka_volterra(du, u, p, t)
    α, β, γ, δ = p
    x, y = u
    du[1] = (α - β * y) * x # prey
    du[2] = (δ * x - γ) * y # predator
    return nothing
end

# Define initial-value problem.
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
tspan = (0.0, 10.0)
prob = ODEProblem(lotka_volterra, u0, tspan, p)
# Generate noisy observations from the solution (true σ = 0.8)
sol = solve(prob, Tsit5(); saveat=0.1)
odedata = Array(sol) + 0.8 * randn(size(Array(sol)))

# Turing model
@model function fitlv(data, prob)
    # Prior distributions.
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5, 0.5); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.2, 0.5); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.5); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.5); lower=0, upper=2)

    # Simulate the Lotka-Volterra model with the sampled parameters.
    p = [α, β, γ, δ]
    predicted = solve(prob, Tsit5(); p=p, saveat=0.1)

    # Observations: each data column is Gaussian around the prediction.
    for i in 1:length(predicted)
        data[:, i] ~ MvNormal(predicted[i], σ^2 * I)
    end

    return nothing
end

model = fitlv(odedata, prob)

# Sample 3 independent chains with forward-mode automatic differentiation (the default).
chain = sample(model, NUTS(0.65), MCMCThreads(), 1000, 3; progress=false)
plot(chain)

This is a typical output of the last line:

where the first chain (blue) converges to a distribution that is far from the ‘true’ original values, with a much higher estimated standard deviation \sigma of the observational noise. Maybe the sampling process gets stuck in some local mode? But the time series obtained by data retrodiction from this odd distribution do not even have the same period as the original data.
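(By data retrodiction I mean essentially the ‘Data retrodiction’ code from the tutorial, assuming the chain, prob, sol and odedata objects defined above:)

# Solve the ODE for a subsample of posterior draws and overlay the
# trajectories on the true solution and the noisy data.
plot(; legend=false)
posterior_samples = sample(chain[[:α, :β, :γ, :δ]], 300; replace=false)
for p in eachrow(Array(posterior_samples))
    sol_p = solve(prob, Tsit5(); p=p, saveat=0.1)
    plot!(sol_p; alpha=0.1, color="#BBBBBB")
end
plot!(sol; color=[1 2], linewidth=1)
scatter!(sol.t, odedata'; color=[1 2])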
These odd chains are less common if the prior for \sigma is truncated around the original 0.8 value, but they still appear. Is there some way to prevent this, or at least to discard these chains using some criterion?
Thanks in advance!

I’m not certain what the problem is with this particular model, but looking at the built docs page, the statistics on the samples don’t look fantastic either:

  • The recommended cutoff for ESS is ESS > nchains * 100 = 300 in this case, which 3 of the parameters don’t satisfy.
  • The R-hat recommended cutoff is R-hat < 1.01, which 2 parameters don’t satisfy.
  • The trace plots deviate slightly from the “fuzzy caterpillar” one should see from well-behaved MCMC.

It’s possible the docs build just picked a lucky seed for which the positions of the different chains are similar. But on my machine that seed produces results similar to yours.
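If you want to check these diagnostics programmatically rather than reading them off the summary table, MCMCChains (which ships with Turing) provides them directly; a minimal sketch:

using MCMCChains

# ESS and R-hat for every parameter, computed across the 3 chains;
# compare against the ESS > 300 and R-hat < 1.01 guidelines above.
display(ess_rhat(chain))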

Paging @cpfiffer @torfjelde and @devmotion, who have contributed to this tutorial.

The first thing to try is to not use MCMCThreads: it seems you sample the chains in parallel, whereas in the tutorial we sample them sequentially with MCMCSerial.
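I.e., use the call from the tutorial:

# Sample 3 independent chains sequentially instead of with MCMCThreads.
chain = sample(model, NUTS(0.65), MCMCSerial(), 1000, 3; progress=false)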

Hi sethaxen,
Thanks for your reply. Yes, I agree, the overall summary statistics are awful, but when repeating with different seeds, sometimes all chains get close to the ‘true’ values, and in that case I obtain:

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat    ⋯
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64    ⋯

           σ    0.2937    0.0155     0.0003    0.0004   1999.1084    1.0013    ⋯
           α    1.4875    0.0184     0.0003    0.0005    833.9254    1.0013    ⋯
           β    0.9999    0.0169     0.0003    0.0005   1045.3177    1.0005    ⋯
           γ    3.0465    0.0559     0.0010    0.0016    857.0627    1.0018    ⋯
           δ    1.0167    0.0201     0.0004    0.0006    843.1132    1.0014    ⋯

It is a pity that when even a single chain converges to a distribution with values very far from the original ones, the whole set of summary statistics is already ruined. That is why my question was more about whether it is possible to establish a criterion to reject certain chains and keep the others. At least for this case, the outcomes seem to cluster into two situations: (a) low \sigma and parameter posteriors that include the original values, (b) high \sigma and parameter posteriors that do not include the original values.
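Just to illustrate what I mean by a criterion, something along these lines (a rough sketch; the 1.0 cutoff on the per-chain mean of \sigma is completely arbitrary and hand-picked for this example):

using Statistics

# Per-chain mean of σ: chains stuck in the high-noise mode have a much larger value.
σ_draws = Array(chain[:σ])                  # iterations × chains
per_chain_σ = vec(mean(σ_draws; dims=1))

# Keep only the chains below the (arbitrary) cutoff and re-summarize.
keep = findall(<(1.0), per_chain_σ)
good_chain = chain[:, :, keep]
display(summarystats(good_chain))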

Hi devmotion,
Thanks for your reply. Yes, I changed that afterwards to speed up the computation, but I was obtaining exactly the same behavior when sampling with MCMCSerial (and I have just checked that again).