I have a very simple model that doesn’t seem to work when I try to infer values for missing variables. I will explain the model and provide fully reproducible, self-contained code. I know it’s a long post, but I’ve broken it down into its basic steps.
Context: I have case counts of meningitis y_{a, t}, stratified by age group a and time t (in years).
The model is as follows:
y_{a, t} \sim \text{Poisson}(N_{a, t} e^{\rho_a + \beta_t}), with priors \rho_a \sim \text{Normal}(-10, 1) and \beta_t \sim \text{Normal}(0, 0.2).
The interpretation of this model is as follows: for any age group a and time t, the expected number of cases y_{a, t} is the population size N_{a, t} (given) multiplied by the disease incidence rate e^{\rho_a + \beta_t}. Since case counts are integer-valued, we use a Poisson distribution for the model.
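To make the rate term concrete, here is a minimal sketch of the Poisson mean for a single cell. It uses the 1997 population of the first age group from the data, together with illustrative values for ρ and β. The values -10.0 and 0.1 are assumptions chosen only for the arithmetic, not estimates.

```julia
# Expected count for one (age group, year) cell: N * exp(ρ + β).
# ρ and β are made-up illustrative values, not fitted estimates.
N = 3_751_141          # population of age group 1 in 1997 (from the data)
ρ = -10.0              # hypothetical age effect on the log-rate scale
β = 0.1                # hypothetical year effect on the log-rate scale

log_mean = log(N) + ρ + β        # the linear predictor used in the model code
expected_cases = exp(log_mean)   # identical to N * exp(ρ + β)

println(round(expected_cases; digits = 1))
```

Working on the log scale, as the model code does, keeps the population offset and the two effects additive.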
Here is the code that implements this model (as well as the data):
using Turing, MCMCChains, MCMCDiagnosticTools
using DataFrames  # needed for the DataFrame conversions further down
using Gnuplot
using JLD2
@model function forecast_model(y, psz)
    max_ag, max_t = size(y)
    ρ ~ filldist(Normal(-10, 1), max_ag)
    b ~ filldist(Normal(0, 0.2), max_t)
    for t = 1:max_t
        for ag = 1:max_ag
            λ = log(psz[ag, t]) + ρ[ag] + b[t]
            y[ag, t] ~ Poisson(exp(λ))
        end
    end
end
function run_model()
    # case counts (age groups in rows x time in columns)
    case_cnts = round.(Int, [
        161.299063 116.647079 208.76691 77.11912 72.227844 79.02922 31.806968 22.078419 24.026358;
        135.209772 71.594128 37.84001 26.590837 15.421628 8.009348 23.616834 7.5108197 15.948128;
        152.434139 150.982026 127.113591 34.570938 68.580912 45.892228 16.3526104 29.54641 25.0445475;
        465.153894 418.268672 422.486528 159.593688 160.118874 160.521795 133.77078 80.267775 96.4912842;
        400.942158 267.261088 334.317225 134.239542 202.392522 156.0098142 197.6655689 109.709976 54.9601544;
        125.83011 127.725034 194.192652 130.77674 131.50147 131.785874 105.4030112 52.6924296 118.6608762;
        112.770537 156.514596 97.6839312 92.8587286 175.17032 90.886476 103.6336906 34.2089804 81.152368;
        309.614049 276.953272 313.180569 245.486976 247.032037 177.611035 147.0404689 104.9896251 106.2844142;
    ])
    # population sizes (age groups in rows x time in columns)
    pop_size = [
        3751141 3762809 3795762 3855956 4012658 3951461 3975871 4014258 4004393
        3755827 3768112 3784001 3798691 3855407 4004674 3936139 3953063 3987032
        11725703 11614002 11555781 11523646 11430152 11473057 11680436 11818564 11925975
        51683766 52283584 52810816 53197896 53372958 53507265 53508312 53511850 53606269
        66823693 66815272 66863445 67119771 67464174 67830354 68160541 68568735 68700193
        62915055 63862517 64730884 65388370 65750735 65892937 65876882 65865537 65922709
        37590179 39128649 40701638 42208513 43792580 45443238 47106223 48869972 50720230
        34401561 34619159 34797841 35069568 35290291 35522207 35863529 36203319 36649798
    ]
    model_fun = forecast_model(case_cnts, pop_size);
    sampler = NUTS(1000, 0.65)
    chain = sample(model_fun, sampler, 2000; progress=true)
    return chain
end
And this works great! Here is some code to pull out the relevant posteriors and plot the fitted model.
modelresults = run_model()
# get the posterior means/quantiles of rho, beta
rho, rhoquants = describe(group(modelresults, :ρ)) .|> DataFrame
rmeans = rho.mean
rupper = rhoquants[!, Symbol("97.5%")]
rlower = rhoquants[!, Symbol("2.5%")]
beta, betaquants = describe(group(modelresults, :b)) .|> DataFrame
bmeans = beta.mean
bupper = betaquants[!, Symbol("97.5%")]
blower = betaquants[!, Symbol("2.5%")]
## work "backwards" to predict the case counts using the posterior mean/quants
pop_size = [
3751141 3762809 3795762 3855956 4012658 3951461 3975871 4014258 4004393
3755827 3768112 3784001 3798691 3855407 4004674 3936139 3953063 3987032
11725703 11614002 11555781 11523646 11430152 11473057 11680436 11818564 11925975
51683766 52283584 52810816 53197896 53372958 53507265 53508312 53511850 53606269
66823693 66815272 66863445 67119771 67464174 67830354 68160541 68568735 68700193
62915055 63862517 64730884 65388370 65750735 65892937 65876882 65865537 65922709
37590179 39128649 40701638 42208513 43792580 45443238 47106223 48869972 50720230
34401561 34619159 34797841 35069568 35290291 35522207 35863529 36203319 36649798
]
case_cnts = round.(Int, [
161.299063 116.647079 208.76691 77.11912 72.227844 79.02922 31.806968 22.078419 24.026358;
135.209772 71.594128 37.84001 26.590837 15.421628 8.009348 23.616834 7.5108197 15.948128;
152.434139 150.982026 127.113591 34.570938 68.580912 45.892228 16.3526104 29.54641 25.0445475;
465.153894 418.268672 422.486528 159.593688 160.118874 160.521795 133.77078 80.267775 96.4912842;
400.942158 267.261088 334.317225 134.239542 202.392522 156.0098142 197.6655689 109.709976 54.9601544;
125.83011 127.725034 194.192652 130.77674 131.50147 131.785874 105.4030112 52.6924296 118.6608762;
112.770537 156.514596 97.6839312 92.8587286 175.17032 90.886476 103.6336906 34.2089804 81.152368;
309.614049 276.953272 313.180569 245.486976 247.032037 177.611035 147.0404689 104.9896251 106.2844142;
])
poster = similar(pop_size, Float64)
poster_upper = similar(pop_size, Float64)
poster_lower = similar(pop_size, Float64)
for ag in 1:8
    for t in 1:9
        λ = log(pop_size[ag, t]) + rmeans[ag] + bmeans[t]
        λᵤ = log(pop_size[ag, t]) + rupper[ag] + bupper[t]
        λₗ = log(pop_size[ag, t]) + rlower[ag] + blower[t]
        poster[ag, t] = mean(Poisson(exp(λ)))
        poster_upper[ag, t] = mean(Poisson(exp(λᵤ)))
        poster_lower[ag, t] = mean(Poisson(exp(λₗ)))
    end
end
## PLOT poster, poster_upper, poster_lower with case_cnts using your favourite plotting package
In the resulting figure, each panel is a specific year and the x-axis shows the 8 age groups. The years run from 1997 to 2005. The black curve is the data and the blue curve is the posterior mean.
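One caveat about the bands: plugging the 2.5% and 97.5% quantiles of ρ and b into λ separately is not the same as taking quantiles of the expected count itself. Here is a minimal sketch of the latter, using synthetic draws in place of the real chain. The arrays `ρ_draws` and `b_draws` are hypothetical stand-ins; with a real chain they would be extracted from `modelresults`.

```julia
using Random, Statistics

Random.seed!(1)
n_draws = 2_000

# Synthetic stand-ins for posterior draws of one ρ[ag] and one b[t].
# These are NOT from the fitted model; they only illustrate the computation.
ρ_draws = -10 .+ 0.05 .* randn(n_draws)
b_draws = 0.1 .+ 0.05 .* randn(n_draws)

N = 3_751_141  # population for one (age group, year) cell, from pop_size

# Expected count for every draw, then summarise across draws.
μ_draws = N .* exp.(ρ_draws .+ b_draws)
μ_mean = mean(μ_draws)
μ_lower, μ_upper = quantile(μ_draws, [0.025, 0.975])

# For comparison: the interval obtained by adding the separate quantiles.
naive_lower = N * exp(quantile(ρ_draws, 0.025) + quantile(b_draws, 0.025))
naive_upper = N * exp(quantile(ρ_draws, 0.975) + quantile(b_draws, 0.975))

println((μ_lower, μ_mean, μ_upper))
println((naive_lower, naive_upper))
```

In this synthetic case the draws are independent, so the interval built from separate quantiles comes out wider than the per-draw interval. With a real posterior the two can differ in either direction depending on how ρ and b are correlated.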
PROBLEM
I want to infer what would have happened in the next year, 2006 (i.e., forecast it). To do so, I append a column of missing values to case_cnts:
fmiss = reshape(fill(missing, 8), (8, 1))
case_cnts = hcat(case_cnts, fmiss)
julia> case_cnts = hcat(case_cnts, fmiss)
8×10 Matrix{Union{Missing, Int64}}:
161 117 209 77 72 79 32 22 24 missing
135 72 38 27 15 8 24 8 16 missing
152 151 127 35 69 46 16 30 25 missing
465 418 422 160 160 161 134 80 96 missing
401 267 334 134 202 156 198 110 55 missing
126 128 194 131 132 132 105 53 119 missing
113 157 98 93 175 91 104 34 81 missing
310 277 313 245 247 178 147 105 106 missing
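Turing's documentation describes inputs with value `missing` as being treated as random variables to be sampled rather than as observations, so it may help to confirm exactly which entries of the matrix are affected. Here is a minimal sketch using a small stand-in matrix. The 2×3 `obs` matrix is a hypothetical example, not the real data.

```julia
# Small stand-in for case_cnts; the values are illustrative only.
obs = [161 117 209;
       135  72  38]

# Append one column of missing values, as done for the forecast year.
padded = hcat(obs, fill(missing, size(obs, 1)))

println(eltype(padded))                    # element type now allows Missing
missing_idx = findall(ismissing, padded)   # CartesianIndex of each missing cell
n_observed = count(!ismissing, padded)     # cells that still act as data

println(missing_idx)
println(n_observed)
```

Applied to the real 8×10 matrix, the same calls should report 8 missing cells, all in column 10.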
And of course, for the sake of reproducibility, here is an updated pop_size that includes the corresponding additional column (i.e., the population sizes in 2006 per age group):
_fpops = [4041738 3972124 11925021 53818831 68998018 65959087 52500986 37164107]
fpops = reshape(repeat(_fpops', 1), (8, 1))
pop_size = hcat(pop_size, fpops)
julia> pop_size = hcat(pop_size, fpops)
8×10 Matrix{Int64}:
3751141 3762809 3795762 3855956 4012658 3951461 3975871 4014258 4004393 4041738
3755827 3768112 3784001 3798691 3855407 4004674 3936139 3953063 3987032 3972124
11725703 11614002 11555781 11523646 11430152 11473057 11680436 11818564 11925975 11925021
51683766 52283584 52810816 53197896 53372958 53507265 53508312 53511850 53606269 53818831
66823693 66815272 66863445 67119771 67464174 67830354 68160541 68568735 68700193 68998018
62915055 63862517 64730884 65388370 65750735 65892937 65876882 65865537 65922709 65959087
37590179 39128649 40701638 42208513 43792580 45443238 47106223 48869972 50720230 52500986
34401561 34619159 34797841 35069568 35290291 35522207 35863529 36203319 36649798 37164107
With these inputs I am able to run the model again, but it doesn’t produce good results. Forecasting aside, the model fit itself goes haywire for the observed years 1997 to 2005 (see image below). The initial step size also becomes extremely small:
julia> run_model()
running updated model 2
┌ Info: Found initial step size
└ ϵ = 5.684341886080802e-15
Does anyone know what’s going on?