Turing NUTS chains getting stuck at the parameter bounds

Hi all

I have a medium-sized ODE model with four estimated parameters, and I am having problems with chain convergence when using Turing (NUTS) for inference. Depending on the initial draws, some chains converge and some get stuck, mostly at the bounds of some of the parameters. Extending the burn-in phase up to 6000 iterations does not change anything; the stuck chains just remain stuck. I observed this with my previous model as well. Is there a way to mitigate this behaviour, @yebai @torfjelde? Could it be related to the Bijector (see also this post)?

# Household stuff
using DifferentialEquations
using CSV, DataFrames
using StatsPlots
using Turing
using Distributions
import Random
using LabelledArrays
using UnPack
using DelimitedFiles
using Serialization
using BenchmarkTools
using ModelingToolkit

Random.seed!(36541)

# ODE model: MSIR model with seasonally forced contact rate
# Model structure: (agegroups=4 * levels=3 * states=5, incl. cumulative incidence C)
function MSIR_als!(du, u, p, t)

    ## params
    β = p[2]
    η = p[3]
    φ = p[4]
    σ = p[5]
    ω = 1.0 / p[6]
    μ = 1.0 / p[7]
    b = p[8]
    nₘ = p[9] # number of women in child-bearing age group
    d = p[10:13]
    ϵ = p[14:17]
    n = p[18:21]
    c = reshape(p[22:end], (4,4))

    ## Create views

    # current states
    N = @view u[:,:,1:end-1] # M, S, I, R (excludes cumulative incidence C)
    M = @view u[:,:,1]
    S = @view u[:,:,2]
    I = @view u[:,:,3]
    R = @view u[:,:,4]

    ## differentials
    dN = @view du[:,:,1:end-1]
    dM = @view du[:,:,1] 
    dS = @view du[:,:,2]
    dI = @view du[:,:,3]
    dR = @view du[:,:,4]
    dC = @view du[:,:,5]

    # Transitions

    ## Ageing (includes births)
    ageing_out = zeros(promote_type(eltype(u), eltype(p)), size(N))
    @. ageing_out = ϵ * N
    ageing_in = zeros(promote_type(eltype(u), eltype(p)), size(N))
    ageing_in[2:end,:,:] += ageing_out[1:end-1,:,:] # ageing in from the age group below

    ## Births
    pₘ = sum(R[3,:]) / nₘ # calculate the proportion of women of child-bearing age who are in the R compartments
    ageing_in[1,1,1] = b * pₘ # births into M compartment
    ageing_in[1,1,2] = b * (1.0-pₘ) # births into S compartment

    ## Deaths (calculated from births and transitions in and out, to maintain a stable N per age group)
    props_a = zeros(promote_type(eltype(u), eltype(p)), size(N))
    @. props_a = N / n # calculate proportions of n in each state
    deaths = zeros(promote_type(eltype(u), eltype(p)), size(N)) # distribute the age-specific deaths among the states
    @. deaths = d * props_a
    
    ## FOI

    # effective per contact transmission probability
    βeff = β * (1.0 + η * cos(2.0 * π * (t-φ) / 365.0))

    # total number of infectious agents by age group
    I_tot = sum(I, dims=2)

    λ = βeff .* vec(sum(c .* I_tot, dims=1))

    ## infections
    infection = λ .* S

    ## clearance
    clearance = σ .* I

    ## waning immunity
    waning_out_M = μ .* M # waning out of M into S
    waning_out_R = zeros(promote_type(eltype(u), eltype(p)), size(R)) # waning out of R
    @. waning_out_R = ω * R

    waning_from_R = zeros(promote_type(eltype(u), eltype(p)), size(S)) # waning from R into S
    waning_from_R[:,2:end] += waning_out_R[:,1:end-1] 
    waning_from_R[:,end] += waning_out_R[:,end]
    
    # Equations
    ## births
    ## transitions between age groups 
    @. dN = ageing_in - ageing_out - deaths
    ## transitions between states (accumulated on top of the demographic terms,
    ## which would otherwise be overwritten)
    @. dM += - waning_out_M
    @. dS += waning_out_M + waning_from_R - infection
    @. dI += infection - clearance
    @. dR += clearance - waning_out_R

    ## cumulative incidence
    @. dC = infection 

    return nothing
end


# Input

# demographics
pop = 1_000_000
t_y = [10.0, 10.0, 20.0, 30.0] # years spent in each age group
lifespan = sum(t_y)
props = t_y ./ lifespan
n = pop .* props

# Parameters
ψ = 0.05
β = 0.15
η = 0.05
φ = 180
σ = 1.0 / 5.0
ω = 365.0
μ = 180.0
b = 40.0 # daily births
nₘ = n[3]/2.0 # total women of child-bearing age
n_age = 4
n_level = 3

# ageing vector
ϵ = 1.0 ./ (365.0 .* props .* lifespan)

# calculate theoretical daily deaths in each age group from births and transitions in/out to achieve net 0 change
n_out = n .* ϵ
d = vcat(b, n_out[1:end-1]) - n_out

# contacts
c =  [1.4e-6  5.6e-6    2.275e-6   5.44444e-7
        5.6e-6  1.05e-5   2.8875e-6  1.08889e-6
        9.1e-6  1.155e-5  5.25e-6    2.13889e-6
        4.9e-6  9.8e-6    4.8125e-6  1.86667e-6]

# parameter vector
p = vcat(ψ, β, η, φ, σ, ω, μ, b, nₘ, d, ϵ, n, vec(c))

# Inits
u0 = [hcat(zeros(4), zeros(4), zeros(4)) ;;; # M
        hcat(pop .* props .- [0.0, 0.0, 10.0, 0.0], zeros(4), zeros(4)) ;;; # S
        hcat([0.0, 0.0, 10.0, 0.0], zeros(4), zeros(4)) ;;; # I
        hcat(zeros(4), zeros(4), zeros(4)) ;;; # R
        hcat(zeros(4), zeros(4), zeros(4))] # C (cumulative incidence)

# Solver settings
tmin = 0.0
tmax = 365.0 #5.0*365.0
tspan = (tmin, tmax)
solvsettings = (abstol = 1.0e-8, 
                reltol = 1.0e-8, 
                saveat = 1.0,
                solver = Tsit5())

# Initiate ODE problem
problem = ODEProblem(MSIR_als!, u0, tspan, p)
problem_mtk = ODEProblem(modelingtoolkitize(problem), [], tspan, jac=true)

sol = solve(problem, 
            solvsettings.solver, 
            abstol=solvsettings.abstol, 
            reltol=solvsettings.reltol, 
            isoutofdomain = (u,p,t)->any(<(0),u),
            saveat=solvsettings.saveat)

sol_array = Array(sol);
# sum over levels (data cannot distinguish between levels)
inc = dropdims(sum(sol_array, dims=2), dims=2)
inc = diff(inc[:,end,:], dims=2)
#plot(inc')

# observation model
function NegativeBinomial2(ψ, incidence)
    p = 1.0/(1.0 + ψ*incidence)
    r = 1.0/ψ
    return NegativeBinomial(r, p)
end
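
Under this parameterisation, ψ acts as an overdispersion parameter: the distribution has mean equal to incidence and variance incidence * (1 + ψ * incidence). A quick sanity check (the numbers are purely illustrative):

# NegativeBinomial(r, p) has mean r(1-p)/p and variance r(1-p)/p^2,
# so with r = 1/ψ and p = 1/(1 + ψm) the mean is m and the variance m(1 + ψm)
d = NegativeBinomial2(0.05, 100.0)
mean(d) # 100.0
var(d)  # 600.0 = 100.0 * (1 + 0.05 * 100.0)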


# Fake some data from model
data = rand.(NegativeBinomial2.(ψ, inc))
#scatter(data',legend = false);
#plot!(inc', legend = false) 

# Fit model to fake data

# Set up as Turing model
Turing.setadbackend(:forwarddiff)

@model function prior()
    ψ ~ Uniform(1e-6, 0.1) #Beta(1.1, 50.0) 
    β ~ Uniform(1e-6, 1.0) #Beta(1.5, 5.0) # 
    η ~ Beta(1.5, 10.0) #Uniform(0.0,1.0) 
    φ ~ Uniform(0.0,364.0)
    return [ψ, β, η, φ]
end

# Define prior and fixed theta
theta_fix = p[5:end]


@model function turingmodel_mtk(prior, theta_fix, problem, n_age, n_level, solvsettings) 
    
    issuccess = true

    # Sample prior parameters.
    theta_est = @submodel prior()
    # Update `p`.
    #theta_fix=convert.(eltype(theta_est), theta_fix)
    #promote_type(eltype(theta_fix), eltype(theta_est))

    p = vcat(theta_est, theta_fix) 
    # Update problem and solve ODEs
    problem_new = remake(problem; p=p) 

    sol = solve(problem_new, 
                    solvsettings.solver, 
                    abstol=solvsettings.abstol, 
                    reltol=solvsettings.reltol, 
                    isoutofdomain = (u,p,t)->any(<(0),u),
                    saveat=solvsettings.saveat);

    # Return early if integration failed
    issuccess &= (sol.retcode === :Success)
    if !issuccess
        Turing.@addlogprob! -Inf
        return nothing
    end
    
    sol_array = Array(sol);
    sol_array = sol_array[end-(n_age*n_level-1):end,:] # the last n_age *n_level rows correspond to the C compartment
    sol_array = diff(sol_array, dims=2)
    # sum over levels (data cannot distinguish between levels)
    incidence = reshape(sol_array, (n_age,n_level,size(sol_array,2)))
    incidence = dropdims(sum(incidence, dims=2), dims=2)
    # avoid numerical instability issue
    incidence = max.(eltype(incidence)(1e-9), incidence) 

    # likelihood
    obs_ts ~ arraydist(@. NegativeBinomial2(theta_est[1], incidence))

    return(; sol, incidence, p=p, obs_ts) 

end

# Setup and condition model
model = turingmodel_mtk(prior, 
                    theta_fix,
                    problem_mtk,
                    n_age,
                    n_level, 
                    solvsettings) | (obs_ts = data,);
                     
retval = model(); # quick check that the model runs; returns (; sol, incidence, p, obs_ts)

# Fit 
rng = Random.MersenneTwister(66)
chain = sample(rng, model, NUTS(3000, 0.65), MCMCThreads(), 1000, 6, progress=true)

plot(chain)

I noticed that the log likelihood surfaces were strange for some parameters. For example, the log likelihood for beta increases and then becomes flat, and the log likelihood for psi is periodic.

One disclaimer: I am not familiar with your model, so I am not sure I implemented the log likelihood function correctly. Please see the file plot_loglikelihood.jl in the following repo: GitHub - itsdfish/turing_de_model_sandbox
Feel free to submit corrections.

If this log likelihood is implemented correctly, NUTS might be having trouble because the topology of the posterior distribution is irregular.
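
For reference, here is a minimal sketch of how such a profile can be computed (it mirrors the idea of plot_loglikelihood.jl but is not necessarily identical to it; loglike and the β grid are my own names, and p, data, problem and solvsettings from the original script are assumed to be in scope):

# Profile the log likelihood for β, holding the other parameters at the
# values used to simulate the data
function loglike(θ, data, problem, solvsettings)
    parms = copy(p)   # copy, so the global parameter vector is not mutated
    parms[1:4] .= θ   # θ = [ψ, β, η, φ]
    prob = remake(problem; p = parms)
    sol = solve(prob, solvsettings.solver;
                abstol = solvsettings.abstol, reltol = solvsettings.reltol,
                saveat = solvsettings.saveat)
    sol_array = Array(sol)
    inc = dropdims(sum(sol_array, dims = 2), dims = 2)  # sum over levels
    inc = max.(1e-9, diff(inc[:, end, :], dims = 2))    # daily incidence from C
    return sum(logpdf.(NegativeBinomial2.(θ[1], inc), data))
end

βs = range(0.01, 0.9, length = 100)
lls = [loglike([ψ, b, η, φ], data, problem, solvsettings) for b in βs]
# plot(βs, lls, xlabel = "β", ylabel = "log likelihood")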
[Log-likelihood surface plots for psi and beta]


Assuming your data model is correct: if you know, based on a priori reasoning, that the region near the boundary is incorrect, then you should use a prior that forces the probability mass away from the boundary, along the lines of the sketch below.
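
For example (a sketch: prior_bounded is a hypothetical name, the ψ and β choices reuse the Beta alternatives already commented out in the original prior, and the Beta(2,2) shape for φ is purely illustrative):

@model function prior_bounded()
    # Beta densities vanish at 0 and 1, pulling mass away from the bounds
    # where the chains were getting stuck
    ψ ~ Beta(1.1, 50.0)
    β ~ Beta(1.5, 5.0)
    η ~ Beta(1.5, 10.0)
    φ01 ~ Beta(2.0, 2.0)           # φ on (0, 1)
    return [ψ, β, η, 364.0 * φ01]  # rescale φ back to days
end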

Hi Chris

Thanks a lot for looking into this. I just remembered we talked about a similar problem a year ago, but I had forgotten about it. Yes, you are right, the log-likelihood surface seems to be odd for some of the parameters. Your implementation of the log-likelihood function is correct as far as I can see.

I tried to tighten the priors, but it does not make a difference. These are the trace plots for 6 fitted chains. The log-posterior plot (second plot below) clearly confirms that the “best” result is the green chain, but it seems that the other chains cannot get out of a local mode. The question is how to get away from these local modes. Adjusting the priors doesn't seem to help much, and neither does adjusting the target acceptance rate or extending the adaptation phase.


[Trace plots for 6 chains and log-posterior plot]

FYI, here are the trajectory fits for two selected chains (green = “true” posterior, blue = local mode). The “poor” posterior chain fits about 50% of the data reasonably well, so I can see why the chain would go there.
[Trajectory fit plots: ppi_good (green chain) and ppi_poor (blue chain)]


No problem, I'm glad to help. In the last set of plots, you can see that the posterior is multimodal. One way to deal with that is to use parallel tempering with NUTS, but unfortunately a PR needs to be finished before that can happen.


Those don’t seem like the kind of priors you want at all.

I suggest rescaling Phi so it’s on the [0,1] interval and then use a strong Beta prior.

Basically use Phi*365 in your model where you used Phi before, and then start with a Beta(50,50) prior or something similar.

That might eliminate your problem, and if not, continue to use Beta priors everywhere, for the other parameters as well.

The reason to use rescaled Phi is that it will be easier for the sampler to adapt if the parameters all have similar scale.
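
In code, that suggestion might look like this (a sketch; prior_rescaled is a hypothetical name, and the submodel returns φ already multiplied back to days so the ODE code stays unchanged):

@model function prior_rescaled()
    ψ ~ Uniform(1e-6, 0.1)
    β ~ Uniform(1e-6, 1.0)
    η ~ Beta(1.5, 10.0)
    φ01 ~ Beta(50.0, 50.0)        # φ on [0,1], tightly concentrated around 0.5
    return [ψ, β, η, 365.0 * φ01] # rescale back to days for the ODE
end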


So I rescaled phi to [0,1] with a Beta() prior, but unfortunately this did not make much of a difference (see trace plots and log posterior below). With 3000 iterations for the adaptation phase, I get 2/6 chains converged. I have had this discussion with @Christopher_Fisher before too, but it seems the scale of the parameters does not play a large role here. I am now trying a longer adaptation phase, but I don't think this will solve the issue. Parallel tempering is probably what I need here.


[Trace plots and log-posterior plot]

What Beta prior did you use? I suggest Beta(70,70) to begin with; that tells the model that only Phi values in the vicinity of 0.5 are meaningful. It should keep any of the other modes, near 0, near 0.75, or near 1, from mattering…


This is probably the case. I suspect the convergence problem is due to the periodicity of the phi parameter. One way to confirm this is to fix phi at the data-generating value and see whether the problem is eliminated (see the sketch below).

I encountered the same problem with a quantum model.
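
A minimal sketch of that check, given the submodel structure above (prior_fixed_phi is a hypothetical name; 180.0 is the value of φ used to simulate the data in the original script):

@model function prior_fixed_phi()
    ψ ~ Uniform(1e-6, 0.1)
    β ~ Uniform(1e-6, 1.0)
    η ~ Beta(1.5, 10.0)
    return [ψ, β, η, 180.0]  # φ hardcoded at the data-generating value
end

model_fixed = turingmodel_mtk(prior_fixed_phi, theta_fix, problem_mtk,
                              n_age, n_level, solvsettings) | (obs_ts = data,)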


So when I fix phi, the outcome is even worse: all chains converge to a local mode, which is far from the truth and also not a great trajectory fit. I think there is no way around finding a way to “jump” more effectively between the modes.


[Trajectory fit with phi fixed]

Sorry, I misread your graphs. You don't show what you've fixed phi's value to… so it's hard to evaluate.

You need to figure out which mode for Phi fits well, like it did above:

What value of Phi was used there? Set a prior so that there is negligible density outside a region around that value of Phi.

No, this plot is from when phi was inferred with a Beta(10.0, 10.0) prior. As you can see in the trace plot, one chain converged to the correct marginal posteriors, which is confirmed by this trajectory fit. The other chains converged to a local mode, which produced the other, poor trajectory fit.

I added vertical lines to the plots corresponding to the true parameter values. As you can see, the maximum log likelihood is far away from the true value for each parameter. However, the true parameter value of phi is near the maximum log likelihood (although there are three similar maxima). This might explain why your estimates are far away from the ground truth. Here are some possible explanations:

  1. I introduced an error when modifying your code to produce the plots, in which case something else is wrong.
  2. I reproduced an error in the original code, and the plots reflect the error.
  3. There is a lot of uncertainty due to the number of data points, which moves the maximum away from the true parameter values.
  4. There is something abnormal with the model, such as poor identifiability.

I tried testing explanation 3 by generating different random sets of data. The log likelihood plots did not change much, which suggests uncertainty might be low. A better way to test this, however, is to generate more data points, but I am not sure how to do that with your model.

[Log-likelihood profiles for psi, beta, eta, and phi, with true values marked]

Yes, but what value of phi did the correctly converged chain converge to? You had given the impression it was phi near 0.5, but you need to figure out which phi it really was and then make a stronger prior that only admits phi near that value.

@Christopher_Fisher I think I know what caused the log-likelihood surface to be wobbly in your code: parms = p also modifies p during the mapping, because the assignment aliases the array rather than copying it. If you do parms = copy(p) instead, the surfaces look more like they should. Beta still shows a flat surface, with the maximum LL at the true value of 0.15, but values > 0.15 also have a high LL, so I can see why that is difficult for the sampler.
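
For anyone following along, this is standard Julia assignment semantics for arrays; a minimal illustration (x and y are throwaway names):

x = [1.0, 2.0, 3.0]
y = x          # no copy: y and x are the same array
y[1] = 99.0    # x is now [99.0, 2.0, 3.0] too
y = copy(x)    # independent copy
y[2] = -1.0    # x is unchanged this time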
[Corrected log-likelihood profiles for beta, eta, phi, and psi]


Yes, that one chain had converged to 0.5, which was the value used to simulate the data. As I mentioned above, tightening the prior did not improve convergence of the other chains.

Your corrected log-likelihood code shows a model that should converge, particularly if you give a reasonable prior for beta. Something like a Beta(2,8) should be fine, I would think.

In theory, yes; in practice, no :slight_smile: . I tried fitting with Beta(2.0, 8.0) for the beta parameter, but only 2/6 chains converge to the correct posterior. The posterior landscape just seems too complex, even for this small toy model. The LL surface plots above were produced with the other parameters fixed at their true values, but during inference these vary too, so the real LL landscape probably looks much more wobbly and difficult.

Good catch! It’s a silly error, but easy to overlook. :laughing:

Nonetheless, I think it's good to verify that the model is properly specified. With that being the case, one possible way to improve the performance of the sampler is to transform your parameter space from [0,1] to the real line. Typically, NUTS works best when the posterior distribution is close to a standard multivariate normal. You can get close to this with a probit transformation and by rescaling phi by 365, as dlakelan recommended before.

The first step would be to define the probit transformation:

Φ(x) = cdf(Normal(), x)

Next, define your priors in real space with normal distributions. You can consider alternative parameters, but a standard normal might suffice.

@model function prior()
    ψ ~ Normal()
    β ~ Normal()
    η ~ Normal()
    φ ~ Normal()
    return [ψ, β, η, φ]
end

Finally, inside your function turingmodel_mtk, map your parameters to the proper space:

...

theta_est = Φ.(theta_est_real)

theta_est[4] = theta_est[4] * 365
...
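
For concreteness, a sketch of the assembled mapping inside turingmodel_mtk (the extra rescaling of ψ to its original (0, 0.1) range is my assumption, mirroring the Uniform(1e-6, 0.1) prior used earlier; everything else follows the snippets above):

theta_est_real = @submodel prior()  # standard normal draws on the real line
theta_est = Φ.(theta_est_real)      # probit-map all four parameters to (0, 1)
theta_est[1] *= 0.1                 # rescale ψ to (0, 0.1) (my assumption)
theta_est[4] *= 365.0               # rescale φ back to days
p = vcat(theta_est, theta_fix)      # proceed as before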