How to implement a model with missing observations in Turing?

I would like to write in Julia the following toy Bernoulli model written in JAGS:

model{
  p0 ~ dbeta(1,1)
  for(i in 1:n){
    y[i] ~ dbern(p0)
  }     
}

Here’s the complete R driver:

library(R2jags)

true_p0 <- 0.2
n <- 600
y <- sample(c(0,1), size=n, prob = c(1-true_p0, true_p0), replace=TRUE)
y[500:600] <- NA # Some observations are missing!
data <- list(y=y, n=n)
n_iter <- 1000
chain_count <- 8  # number of chains, matching the Julia example below
out <- jags.parallel(data=data, inits=NULL, n.iter=n_iter,
                     parameters.to.save = c('p0', 'y'),
                     model.file = 'basic_missing.bug',
                     n.chains = chain_count)

# The missing values are treated as parameters, so we can extract them: 
print(sum(out$BUGSoutput$sims.list$y[,550])/nrow(out$BUGSoutput$sims.list$y))
# > [1] 0.232375
# The non-missing values are treated as constants, so even if we extract them, there is no variability:
print(sum(out$BUGSoutput$sims.list$y[,200])/nrow(out$BUGSoutput$sims.list$y))
# [1] 0

Here’s my Julia model:

using Random
using Distributions
using Turing
using StatsPlots
using Plots

Random.seed!(3)
true_p0=0.2
v=cat(rand(Bernoulli(true_p0), 500), [missing for _ in 1:100], dims=1)

@model function bernoulli_model(v) 
    p0 ~ Beta(1,1)
    v .~ Bernoulli(p0)
end
model=bernoulli_model(v)

nsamples = 1000
nchains = 8
sampler = MH()
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = 200, thinning=50);
> TaskFailedException
>
>    nested task error: TaskFailedException
>    
>        nested task error: DomainError with 0.0:
>        Beta: the condition α > zero(α) is not satisfied.

What am I doing wrong?

I have cross-posted this question to Stack Overflow ("How to implement a most basic missing data model in Turing?"). I pledge to synchronize the answers between both places (here and on SO).

I figured out the answer myself.

There are two issues here:

  1. Bernoulli(p0) expects the values to be integers and throws a cryptic error when they are Bools. So the first thing to do is to convert the observed values to Int, e.g. by multiplying the Bool values by 1 (a one-line version of this conversion is sketched just after this list).
  2. v .~ Bernoulli(p0), the vectorized (broadcasting) form of ~, is the wrong construct here, probably because the dispatch happens only once for the whole broadcast, while we actually need two completely different operations, served by two completely different overloads of ~ over the Bernoulli distribution. I suspect that one variant is a plain observation variant that just updates the target log-likelihood, applied when v[i] is known, and the other defines a latent random variable representing the missing value, applied when v[i] is missing. To force Julia to dispatch to the right ~ overload dynamically, based on the type of each element, we need an actual loop.
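
For point 1, a minimal one-line sketch of the conversion (the name v_int is mine, purely for illustration):

# convert observed Bool values to Int, leaving missing entries untouched
v_int = [ismissing(x) ? missing : Int(x) for x in v]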

Here’s the corrected code:

using Random
using Distributions
using Turing

Random.seed!(3)
true_p0=0.2
v=cat(rand(Bernoulli(true_p0), 500), [missing for _ in 1:100], dims=1)

@model function bernoulli_model(v) 
    p0 ~ Beta(1,1)
    for i in eachindex(v)
        v[i] ~ Bernoulli(p0)
    end
end

model = bernoulli_model(v)

nsamples = 1000
nchains = 8
sampler = MH()
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = 200, thinning=50);
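
The imputed values can then be extracted from the chain, analogously to the JAGS extraction above. A minimal sketch, assuming the usual MCMCChains indexing by parameter name:

# posterior mean of one imputed observation, e.g. v[550], pooled over all chains
mean(Array(chains[:, Symbol("v[550]"), :]))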

The dot-tilde form won’t work for missing data imputation. Use a for loop:

julia> @model function bernoulli_model(v) 
    p0 ~ Beta(1,1)
    for i in eachindex(v)
        v[i] ~ Bernoulli(p0)
    end
end;

julia> model = bernoulli_model(v);

julia> chains = sample(model, sampler, MCMCThreads(), nsamples, nchains, discard_initial = 200, thinning=50)
Chains MCMC chain (1000×102×8 Array{Float64, 3}):

Iterations        = 201:50:50151
Number of chains  = 8
Samples per chain = 1000
Wall duration     = 76.01 seconds
Compute duration  = 582.33 seconds
parameters        = p0, v[501], v[502], v[503], v[504], v[505], v[506], v[507], v[508], v[509], v[510], v[511], v[512], v[513], v[514], v[515], v[516], v[517], v[518], v[519], v[520], v[521], v[522], v[523], v[524], v[525], v[526], v[527], v[528], v[529], v[530], v[531], v[532], v[533], v[534], v[535], v[536], v[537], v[538], v[539], v[540], v[541], v[542], v[543], v[544], v[545], v[546], v[547], v[548], v[549], v[550], v[551], v[552], v[553], v[554], v[555], v[556], v[557], v[558], v[559], v[560], v[561], v[562], v[563], v[564], v[565], v[566], v[567], v[568], v[569], v[570], v[571], v[572], v[573], v[574], v[575], v[576], v[577], v[578], v[579], v[580], v[581], v[582], v[583], v[584], v[585], v[586], v[587], v[588], v[589], v[590], v[591], v[592], v[593], v[594], v[595], v[596], v[597], v[598], v[599], v[600]
internals         = lp

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64 

          p0    0.2952    0.0427    0.0101    19.3203    35.3858    2.3602        0.0332
      v[501]    0.5673    0.4955    0.0986    25.2290        NaN    1.5912        0.0433
      v[502]    0.4684    0.4990    0.1064    22.0161        NaN    1.8401        0.0378
      v[503]    0.5005    0.5000    0.1044    22.9345        NaN    1.7526        0.0394
      v[504]    0.5336    0.4989    0.1026    23.6514        NaN    1.6685        0.0406
      v[505]    0.6086    0.4881    0.0990    24.3035        NaN    1.6603        0.0417
      v[506]    0.6449    0.4786    0.0938    26.0161        NaN    1.5213        0.0447
      v[507]    0.6859    0.4642    0.0957    23.5415        NaN    1.6857        0.0404
      v[508]    0.5038    0.5000    0.1072    21.7374        NaN    1.8830        0.0373
      v[509]    0.4109    0.4920    0.1034    22.6482        NaN    1.7320        0.0389
      v[510]    0.5479    0.4977    0.1089    20.8990        NaN    1.9917        0.0359
      v[511]    0.6957    0.4601    0.0955    23.2239        NaN    1.8137        0.0399
      v[512]    0.4079    0.4915    0.1049    21.9421        NaN    1.9167        0.0377
      v[513]    0.4316    0.4953    0.1111    19.8852        NaN    2.1379        0.0341
      ⋮           ⋮         ⋮         ⋮         ⋮          ⋮          ⋮           ⋮
                                                                           87 rows omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          p0    0.2177    0.2681    0.2828    0.3439    0.3715
      v[501]    0.0000    0.0000    1.0000    1.0000    1.0000
      v[502]    0.0000    0.0000    0.0000    1.0000    1.0000
      v[503]    0.0000    0.0000    1.0000    1.0000    1.0000
      v[504]    0.0000    0.0000    1.0000    1.0000    1.0000
      v[505]    0.0000    0.0000    1.0000    1.0000    1.0000
      v[506]    0.0000    0.0000    1.0000    1.0000    1.0000
      v[507]    0.0000    0.0000    1.0000    1.0000    1.0000
      v[508]    0.0000    0.0000    1.0000    1.0000    1.0000
      v[509]    0.0000    0.0000    0.0000    1.0000    1.0000
      v[510]    0.0000    0.0000    1.0000    1.0000    1.0000
      v[511]    0.0000    0.0000    1.0000    1.0000    1.0000
      v[512]    0.0000    0.0000    0.0000    1.0000    1.0000
      v[513]    0.0000    0.0000    0.0000    1.0000    1.0000
      ⋮           ⋮         ⋮         ⋮         ⋮         ⋮
                                                 87 rows omitted

Two more notes. First, the stacktrace is much more informative when not sampling multiple chains in parallel. It’s often a good idea to try sampling a single chain first to make sure no exceptions are thrown and then restart with multiple chains.
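
For example, a quick single-chain debug run could look like this (a sketch only):

# one chain, no threading: the stacktrace is easier to read if something throws
chains_debug = sample(model, sampler, nsamples)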

Second, using the MH sampler, we can see sampling didn’t perform very well. (rhat of p0 much greater than 1.01, ESS much lower than 100*nchains). I’m guessing this is because with 100 missing data points being sampled, the model has effectively 101 parameters, and this is too high-dimensional for MH to sample well, but something else might be going on here. @yebai @torfjelde any suggestions for a model with this structure?

EDIT: I see you posted the same solution while I was writing this. While it fixes the error, it still does not fix the lack of convergence in sampling.


Thank you @sethaxen for your insight! Indeed, the sampling did not perform well, and I agree with you. In theory, I could run the chain for much longer - but the real model I am developing would have about 500 binary parameters (also from missing data), not 100, so that would definitely take a while to compute.

Thank you also for your tip about multiple chains! Most of my experience comes from Stan, where the error message you get is mostly unaffected by the number of chains.

Is my intuition right about the reason why the dot-tilde form does not work for data imputation?

Perhaps completely off topic, but for imputing data, random forests are also very effective…

Is my intuition right about the reason why the dot-tilde form does not work for data imputation?

Very much so :)

any suggestions for a model with this structure?

I genuinely don’t have much experience when it comes to discrete models, but I would suspect particle methods, e.g. PG, would do better than something like MH.
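
A minimal sketch of what that could look like (assuming PG's constructor takes the number of particles; I haven't checked how well it actually mixes on this model):

# particle Gibbs with 20 particles instead of MH, same parallel-chain setup as before
pg_chains = sample(model, PG(20), MCMCThreads(), nsamples, nchains);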

Can you elaborate a bit more? Since the type of v[i] is Union{Missing, Int}, the ~ statement should choose the right specialization according to normal Julia dispatch rules (perhaps while being slower and not type stable).

In any case, the original .~ form looks natural and it would be a shame (and surprise) if it didn’t work.