I was trying to sample the posterior for the parameters and initial condition of a very simple ODE, taking the example from
https://turing.ml/v0.21/tutorials/10-bayesian-differential-equations/
as a model, and I kept getting quite different results from run to run. Sometimes the chain sampled with NUTS converged to distributions around the ‘true’ parameter values, with a low observational noise variance; other times the chain converged to strange values, with a high observational noise variance. Nothing indicated which chains were closer to the ‘true’ values and had less variance, except that for the odd chains the initial NUTS step size was large and there were many warnings like this:
Warning: The current proposal will be rejected due to numerical error(s).
isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
At first I thought that my problem was ill-conditioned or that the priors were poorly chosen. But then I tried the original example, which I copy here again as an MWE, and roughly one in every two or three chains also displayed this odd behavior.
using Turing, DifferentialEquations, StatsPlots, LinearAlgebra, Random
Random.seed!(14)

# Define Lotka-Volterra model.
function lotka_volterra(du, u, p, t)
    α, β, γ, δ = p
    x, y = u
    du[1] = (α - β * y) * x # prey
    du[2] = (δ * x - γ) * y # predator
    return nothing
end

# Define initial-value problem.
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
tspan = (0.0, 10.0)
prob = ODEProblem(lotka_volterra, u0, tspan, p)

# noisy data
sol = solve(prob, Tsit5(); saveat=0.1)
odedata = Array(sol) + 0.8 * randn(size(Array(sol)))

# Turing model
@model function fitlv(data, prob)
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5, 0.5); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.2, 0.5); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.5); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.5); lower=0, upper=2)
    p = [α, β, γ, δ]
    predicted = solve(prob, Tsit5(); p=p, saveat=0.1)
    for i in 1:length(predicted)
        data[:, i] ~ MvNormal(predicted[i], σ^2 * I)
    end
    return nothing
end

model = fitlv(odedata, prob)

# Sample 3 independent chains with forward-mode automatic differentiation (the default).
chain = sample(model, NUTS(0.65), MCMCThreads(), 1000, 3; progress=false)
plot(chain)
This is a typical output of the last line:
where the first chain (blue) converges to a distribution that is far from the ‘true’ original values, with a much higher estimated standard deviation σ of the observational noise. Maybe the sampler gets stuck in a local minimum? But the time series obtained by retrodiction from samples of this odd distribution do not even have the same period as the original data.
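For what it is worth, the disagreement between chains can at least be put into a number. Below is a minimal sketch of the classic (non-split) Gelman-Rubin statistic, written by hand so that it only needs the `Statistics` standard library. The function name `rhat_classic` and the synthetic draws are my own illustration, not part of Turing; I believe MCMCChains reports a more refined version of this diagnostic in its summary statistics, but I have not checked how it behaves on this example.

```julia
using Statistics

# Classic Gelman-Rubin statistic for one parameter.
# `draws` is an iterations × chains matrix of samples (hypothetical layout).
function rhat_classic(draws::AbstractMatrix)
    n, _ = size(draws)
    chain_means = vec(mean(draws; dims=1))
    chain_vars = vec(var(draws; dims=1))
    W = mean(chain_vars)                 # mean within-chain variance
    B = n * var(chain_means)             # between-chain variance
    var_plus = (n - 1) / n * W + B / n   # pooled variance estimate
    return sqrt(var_plus / W)
end

# Synthetic illustration: three chains that agree, then one chain shifted away.
base = collect(range(-1.0, 1.0; length=100))
agree = hcat(base, base, base)
disagree = hcat(base .+ 5.0, base, base)

println("agreeing chains:    ", rhat_classic(agree))
println("one shifted chain:  ", rhat_classic(disagree))
```

Values close to 1 suggest that the chains explore the same distribution, while values well above 1 (a common rule of thumb is 1.1) suggest that at least one chain is somewhere else, which is what the plot seems to show for the blue chain.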
These odd chains are less common if the prior for σ is truncated around the original value of 0.8, but they still appear. Is there a way to prevent this, or at least to discard these chains using some criterion?
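One criterion I have considered, as a rough sketch only: since the odd chains end up with a much larger σ, their log-density should sit well below that of the good chains, so chains could be flagged by comparing the median log-density per chain. The helper `flag_stuck_chains` and the threshold `tol` are hypothetical names and values of my own, and the numbers below are synthetic, not output from the model above.

```julia
using Statistics

# Hypothetical helper: return the indices of chains whose median log-density
# lies more than `tol` below that of the best chain.
# `lp` is an iterations × chains matrix of log-density values;
# `tol` is an assumed threshold in log-density units.
function flag_stuck_chains(lp::AbstractMatrix; tol::Real=10.0)
    med = vec(median(lp; dims=1))
    best = maximum(med)
    return findall(m -> best - m > tol, med)
end

# Synthetic illustration: chain 1 sits far below chains 2 and 3.
lp_demo = hcat(fill(-250.0, 100), fill(-120.0, 100), fill(-121.0, 100))
stuck = flag_stuck_chains(lp_demo)
println("flagged chains: ", stuck)
```

With a real run I would expect to pass something like `Array(chain[:lp])` as `lp`, assuming that `chain[:lp]` returns the per-iteration log-density as an iterations × chains array, and then keep only the remaining chains. I am not sure whether discarding chains this way is statistically justified, which is part of my question.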
Thanks in advance!