Inferring missing values in a Turing.jl model

I have a very simple model that doesn't seem to work when I try to infer values for missing variables. I will try to explain the model and provide fully reproducible, self-contained code. I know it's a long post, but I've broken it down into its basic steps.

Context: I have meningitis case counts y_{a, t}, stratified by age group a and time t (in years).

The model is as follows

\begin{aligned} y(a, t) &\sim \text{Poisson}\left(N_{a, t} \exp(\rho_a + \beta_t)\right) \\ \rho_a &\sim \text{Normal}(-10, 1) \\ \beta_t &\sim \text{Normal}(0, 0.2) \end{aligned}

The interpretation of this model is as follows: for any age group a and time t, the expected number of cases y(a, t) is the population size N_{a, t} (given) multiplied by the disease incidence rate e^{\rho_a + \beta_t} (i.e. the log-rate is additive in the age and year effects, exactly as in the code below). Since case counts are integers, we use a Poisson distribution for the model.

Here is the code that implements this model (as well as the data)

using Turing, MCMCChains, MCMCDiagnosticTools
using Gnuplot
using JLD2

# Poisson log-linear model for case counts.
#   y   : case counts, one row per age group and one column per year
#   psz : population sizes with the same row/column layout as `y`
# Each count has log-mean log(psz[ag, t]) + ρ[ag] + b[t].
@model function forecast_model(y, psz)
    n_groups, n_years = size(y)

    # Age-group log-rate offsets and year effects (names kept: the chain is queried by them).
    ρ ~ filldist(Normal(-10, 1), n_groups)
    b ~ filldist(Normal(0, 0.2), n_years)

    # `ag` is the inner loop, so the first array index varies fastest (column-major).
    for t in 1:n_years, ag in 1:n_groups
        logmean = log(psz[ag, t]) + ρ[ag] + b[t]
        y[ag, t] ~ Poisson(exp(logmean))
    end
end

# Fit `forecast_model` to the observed case counts with NUTS and return the chain.
function run_model()
    # Population sizes: age groups as rows, years as columns.
    pop_size = [
        3751141 3762809 3795762 3855956 4012658 3951461 3975871 4014258 4004393
        3755827 3768112 3784001 3798691 3855407 4004674 3936139 3953063 3987032
        11725703 11614002 11555781 11523646 11430152 11473057 11680436 11818564 11925975
        51683766 52283584 52810816 53197896 53372958 53507265 53508312 53511850 53606269
        66823693 66815272 66863445 67119771 67464174 67830354 68160541 68568735 68700193
        62915055 63862517 64730884 65388370 65750735 65892937 65876882 65865537 65922709
        37590179 39128649 40701638 42208513 43792580 45443238 47106223 48869972 50720230
        34401561 34619159 34797841 35069568 35290291 35522207 35863529 36203319 36649798
    ]

    # Fractional case counts (same layout); rounded below because the Poisson
    # likelihood needs integer observations.
    raw_cases = [
        161.299063 116.647079 208.76691 77.11912 72.227844 79.02922 31.806968 22.078419 24.026358; 
        135.209772 71.594128 37.84001 26.590837 15.421628 8.009348 23.616834 7.5108197 15.948128;
        152.434139 150.982026 127.113591 34.570938 68.580912 45.892228 16.3526104 29.54641 25.0445475;
        465.153894 418.268672 422.486528 159.593688 160.118874 160.521795 133.77078 80.267775 96.4912842;
        400.942158 267.261088 334.317225 134.239542 202.392522 156.0098142 197.6655689 109.709976 54.9601544;
        125.83011 127.725034 194.192652 130.77674 131.50147 131.785874 105.4030112 52.6924296 118.6608762;
        112.770537 156.514596 97.6839312 92.8587286 175.17032 90.886476 103.6336906 34.2089804 81.152368;
        309.614049 276.953272 313.180569 245.486976 247.032037 177.611035 147.0404689 104.9896251 106.2844142;
    ]
    case_cnts = round.(Int, raw_cases)

    # NUTS(1000, 0.65): 1000 adaptation steps, 0.65 target acceptance; 2000 draws.
    return sample(forecast_model(case_cnts, pop_size), NUTS(1000, 0.65), 2000;
                  progress = true)
end

And this works great! Here is some code to pull out the relevant posteriors and plot the fitted model.

modelresults = run_model()

# Posterior draws as (iterations × parameters) matrices. Plain arrays are used
# instead of `DataFrame`, which is not loaded by the `using` lines of this script.
# NOTE(review): assumes the columns come out ordered ρ[1], ρ[2], … / b[1], b[2], …
# (MCMCChains' stored order) — re-check if a parameter group has 10+ entries.
ρdraws = Array(group(modelresults, :ρ))
bdraws = Array(group(modelresults, :b))

# Marginal posterior means and central 95% intervals of ρ and β (for reporting only;
# see below for why they must not be summed to build prediction bands).
rmeans = vec(mean(ρdraws; dims = 1))
rupper = [quantile(col, 0.975) for col in eachcol(ρdraws)]
rlower = [quantile(col, 0.025) for col in eachcol(ρdraws)]

bmeans = vec(mean(bdraws; dims = 1))
bupper = [quantile(col, 0.975) for col in eachcol(bdraws)]
blower = [quantile(col, 0.025) for col in eachcol(bdraws)]

## work "backwards" to predict the case counts from the posterior draws
pop_size = [
3751141 3762809 3795762 3855956 4012658 3951461 3975871 4014258 4004393
3755827 3768112 3784001 3798691 3855407 4004674 3936139 3953063 3987032
11725703 11614002 11555781 11523646 11430152 11473057 11680436 11818564 11925975
51683766 52283584 52810816 53197896 53372958 53507265 53508312 53511850 53606269
66823693 66815272 66863445 67119771 67464174 67830354 68160541 68568735 68700193
62915055 63862517 64730884 65388370 65750735 65892937 65876882 65865537 65922709
37590179 39128649 40701638 42208513 43792580 45443238 47106223 48869972 50720230
34401561 34619159 34797841 35069568 35290291 35522207 35863529 36203319 36649798
]

case_cnts = round.(Int, [
    161.299063 116.647079 208.76691 77.11912 72.227844 79.02922 31.806968 22.078419 24.026358; 
    135.209772 71.594128 37.84001 26.590837 15.421628 8.009348 23.616834 7.5108197 15.948128;
    152.434139 150.982026 127.113591 34.570938 68.580912 45.892228 16.3526104 29.54641 25.0445475;
    465.153894 418.268672 422.486528 159.593688 160.118874 160.521795 133.77078 80.267775 96.4912842;
    400.942158 267.261088 334.317225 134.239542 202.392522 156.0098142 197.6655689 109.709976 54.9601544;
    125.83011 127.725034 194.192652 130.77674 131.50147 131.785874 105.4030112 52.6924296 118.6608762;
    112.770537 156.514596 97.6839312 92.8587286 175.17032 90.886476 103.6336906 34.2089804 81.152368;
    309.614049 276.953272 313.180569 245.486976 247.032037 177.611035 147.0404689 104.9896251 106.2844142;
])

# Posterior of the expected count N[a,t] * exp(ρ[a] + β[t]), evaluated draw by draw.
# ρ and β enter the likelihood only through the sum ρ[a] + β[t] (shifting every ρ
# up and every β down by the same amount leaves it unchanged), so they tend to be
# anti-correlated a posteriori. Adding their *marginal* quantiles
# (rupper[ag] + bupper[t]) would therefore overstate the interval; instead the
# quantiles are taken over the per-draw expected counts.
# `poster` is the posterior mean of the expected count (the Poisson mean equals its
# rate, so no Poisson object is needed).
poster = similar(pop_size, Float64)
poster_upper = similar(pop_size, Float64)
poster_lower = similar(pop_size, Float64)
for t in axes(pop_size, 2), ag in axes(pop_size, 1)
    expected = exp.(log(pop_size[ag, t]) .+ view(ρdraws, :, ag) .+ view(bdraws, :, t))
    poster[ag, t] = mean(expected)
    poster_upper[ag, t] = quantile(expected, 0.975)
    poster_lower[ag, t] = quantile(expected, 0.025)
end

## PLOT poster, poster_upper, poster_lower with case_cnts  using your favourite plotting package


Here, each plot is a specific year and the x-axis shows the 8 age groups. If it helps, the time span is 1997 to 2005. The black curve is the data and the blue curve is the posterior mean.

PROBLEM
I want to infer what would have happened the next year (2006), i.e., make a forecast. To do this, I append a column of missing values to my case_cnts:

# Append one all-`missing` column (the forecast year) to the observed counts,
# with one row per age group, whatever the number of groups.
# NOTE: Turing treats every `missing` observation as an unknown parameter. Under a
# Poisson likelihood that parameter is discrete, which a gradient-based sampler such
# as NUTS cannot move (non-integer values have zero Poisson probability); this is the
# likely cause of the collapsing step size reported below.
fmiss = fill(missing, size(case_cnts, 1), 1)
case_cnts = hcat(case_cnts, fmiss)

julia> case_cnts = hcat(case_cnts, fmiss)
8×10 Matrix{Union{Missing, Int64}}:
 161  117  209   77   72   79   32   22   24  missing
 135   72   38   27   15    8   24    8   16  missing
 152  151  127   35   69   46   16   30   25  missing
 465  418  422  160  160  161  134   80   96  missing
 401  267  334  134  202  156  198  110   55  missing
 126  128  194  131  132  132  105   53  119  missing
 113  157   98   93  175   91  104   34   81  missing
 310  277  313  245  247  178  147  105  106  missing

And of course, for the sake of reproducibility, here is an updated pop_size that includes the corresponding additional column (i.e., population sizes in 2006 per age group):

 # 2006 population sizes, one per age group, written as a 1×n row literal.
 _fpops = [4041738 3972124 11925021 53818831 68998018 65959087 52500986 37164107]
 # Turn the row into an n×1 column (no hard-coded group count) and append it as the
 # new year; `hcat` throws a DimensionMismatch if the group counts disagree.
 fpops = permutedims(_fpops)
 pop_size = hcat(pop_size, fpops)

julia> pop_size = hcat(pop_size, fpops)
8×10 Matrix{Int64}:
  3751141   3762809   3795762   3855956   4012658   3951461   3975871   4014258   4004393   4041738
  3755827   3768112   3784001   3798691   3855407   4004674   3936139   3953063   3987032   3972124
 11725703  11614002  11555781  11523646  11430152  11473057  11680436  11818564  11925975  11925021
 51683766  52283584  52810816  53197896  53372958  53507265  53508312  53511850  53606269  53818831
 66823693  66815272  66863445  67119771  67464174  67830354  68160541  68568735  68700193  68998018
 62915055  63862517  64730884  65388370  65750735  65892937  65876882  65865537  65922709  65959087
 37590179  39128649  40701638  42208513  43792580  45443238  47106223  48869972  50720230  52500986
 34401561  34619159  34797841  35069568  35290291  35522207  35863529  36203319  36649798  37164107

I am able to run the model again, but it doesn't produce good results. Forget the forecasting: the model fit itself goes crazy for the previous years, 1997 to 2005 (see image below). Not to mention that the step size becomes really small:

julia> run_model()
running updated model 2
┌ Info: Found initial step size
└   ϵ = 5.684341886080802e-15

Does anyone know what's going on?

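I'm not sure exactly what's going on with the sampling issues. The most likely cause is that Turing treats each `missing` entry of `y` as an unknown parameter with a Poisson prior, i.e. a discrete variable. Gradient-based NUTS cannot sample such a variable, which would explain the initial step size collapsing to ~1e-15. Either way, you can avoid putting missing values in the data and calculate the forecast directly using generated_quantities: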

# Poisson log-linear model for case counts, with optional forecast years.
#   y         : observed case counts, one row per age group, one column per year
#   psz       : population sizes with the same layout; it may carry extra columns for
#               the forecast years (if fewer are supplied, the last available column
#               is reused, i.e. population is assumed constant)
#   nforecast : number of years to forecast beyond the last column of `y`
# Returns a (age groups × nforecast) matrix of simulated forecast counts, so that
# `generated_quantities(model, chain)` yields one forecast per posterior draw.
@model function forecast_model(y, psz, nforecast=0)
    max_ag, max_t = size(y)
    ρ ~ filldist(Normal(-10, 1), max_ag)
    # One β per observed year plus one per forecast year. The forecast-year βs do not
    # appear in the likelihood, so their posterior is exactly the Normal(0, 0.2) prior.
    b ~ filldist(Normal(0, 0.2), max_t + nforecast)

    for t = 1:max_t
        for ag = 1:max_ag
            λ = log(psz[ag, t]) + ρ[ag] + b[t]
            y[ag, t] ~ Poisson(exp(λ))
        end
    end

    horizon = (max_t + 1):(max_t + nforecast)
    # Population column for each forecast year: the matching column when supplied,
    # otherwise the last available one.
    cols = [min(t, size(psz, 2)) for t in horizon]
    psz_forecast = psz[:, cols]
    # Same mean as the likelihood: log λ = log N[a, t] + ρ[a] + β[t]. (The previous
    # `exp.(ρ * b')` took a product of the effects and dropped the population.)
    λ_forecast = exp.(log.(psz_forecast) .+ ρ .+ b[horizon]')
    y_forecast = rand.(Poisson.(λ_forecast))
    return y_forecast
end

# Fit `forecast_model` with `nforecast` extra (unobserved) years and return both the
# chain and the model, so `generated_quantities(model, chain)` can replay the draws.
#
# Keyword `forecast_pop_size` (optional): population sizes for the forecast years, as
# a vector (one year) or a matrix with one row per age group. When given, it is
# appended as extra columns of the population matrix passed to the model.
# Throws DimensionMismatch if its row count differs from the number of age groups.
function run_model(nforecast = 0; forecast_pop_size = nothing)
    # case cnts (age groups rows x time columns)
    case_cnts = round.(Int, [
        161.299063 116.647079 208.76691 77.11912 72.227844 79.02922 31.806968 22.078419 24.026358; 
        135.209772 71.594128 37.84001 26.590837 15.421628 8.009348 23.616834 7.5108197 15.948128;
        152.434139 150.982026 127.113591 34.570938 68.580912 45.892228 16.3526104 29.54641 25.0445475;
        465.153894 418.268672 422.486528 159.593688 160.118874 160.521795 133.77078 80.267775 96.4912842;
        400.942158 267.261088 334.317225 134.239542 202.392522 156.0098142 197.6655689 109.709976 54.9601544;
        125.83011 127.725034 194.192652 130.77674 131.50147 131.785874 105.4030112 52.6924296 118.6608762;
        112.770537 156.514596 97.6839312 92.8587286 175.17032 90.886476 103.6336906 34.2089804 81.152368;
        309.614049 276.953272 313.180569 245.486976 247.032037 177.611035 147.0404689 104.9896251 106.2844142;
    ])

    # population sizes (age groups rows x time columns)
    pop_size = [
        3751141 3762809 3795762 3855956 4012658 3951461 3975871 4014258 4004393
        3755827 3768112 3784001 3798691 3855407 4004674 3936139 3953063 3987032
        11725703 11614002 11555781 11523646 11430152 11473057 11680436 11818564 11925975
        51683766 52283584 52810816 53197896 53372958 53507265 53508312 53511850 53606269
        66823693 66815272 66863445 67119771 67464174 67830354 68160541 68568735 68700193
        62915055 63862517 64730884 65388370 65750735 65892937 65876882 65865537 65922709
        37590179 39128649 40701638 42208513 43792580 45443238 47106223 48869972 50720230
        34401561 34619159 34797841 35069568 35290291 35522207 35863529 36203319 36649798
    ]

    if !isnothing(forecast_pop_size)
        n_groups = size(pop_size, 1)
        size(forecast_pop_size, 1) == n_groups || throw(DimensionMismatch(
            "forecast_pop_size must have $n_groups rows (one per age group)"))
        pop_size = hcat(pop_size, forecast_pop_size)
    end

    model_fun = forecast_model(case_cnts, pop_size, nforecast)
    sampler = NUTS(1000, 0.65)
    chain = sample(model_fun, sampler, 2000; progress=true)
    return chain, model_fun
end

# Sample with one extra forecast year, then replay each posterior draw through the
# model; every draw contributes the model's return value (simulated forecast counts).
(model_results, model) = run_model(1)
forecast = generated_quantities(model, model_results)