[Turing] Using a random variable as an index

HMC requires gradients, so all sampled parameters need to be continuous, but τ is not. By default Turing uses ForwardDiff to compute the gradient, which won’t work for discrete parameters, hence the error.

There are (at least) 3 ways you can proceed:

  1. Use a sampler that doesn’t require gradients
  2. Replace the discrete parameter with a continuous one
  3. Marginalize out the discrete parameter and recover draws of it in post-processing

Throughout this, I’ll be using the following dataset:

using Distributions, Random, StatsPlots

rng = MersenneTwister(42)
n = 100
λ₁ = 10
λ₂ = 20
τ = 30
data = [rand(rng, Poisson(i ≤ τ ? λ₁ : λ₂)) for i in 1:n]

scatter(data; xlabel="Day", ylabel="Count", label="data")
plot!(1:n, [fill(λ₁, τ); fill(λ₂, n - τ)], label="rate")

[Figure: simulated daily counts and the true rate function]

1. Use a sampler that doesn’t require gradients

In Turing, this is probably the easiest option. You can, for example, use a particle sampler like SMC or use a Gibbs sampler to alternate updates of the discrete and continuous parameters. There are a number of examples of this in the Turing docs.

using Turing

@model function textmodel(data, n = length(data), α = mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    τ ~ DiscreteUniform(1, n)  # discrete change point
    for i ∈ 1:τ
        data[i] ~ Poisson(λ₁)
    end
    for j ∈ τ+1:n
        data[j] ~ Poisson(λ₂)
    end
end
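
If you want to go the Gibbs route instead, a compositional sampler along these lines should also work, alternating a particle-Gibbs update for τ with an HMC update for the rates (this is a sketch following the pattern in the Turing docs; the sampler settings are illustrative, not tuned, and the exact Gibbs syntax depends on your Turing version):

# particle Gibbs for the discrete τ, HMC for the continuous rates (illustrative settings)
gibbs_sampler = Gibbs(PG(20, :τ), HMC(0.05, 10, :λ₁, :λ₂))
chns_gibbs = sample(rng, textmodel(data), gibbs_sampler, MCMCThreads(), 10_000, 4)

Here I'll just use SMC on its own: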

rng = MersenneTwister(11)
@time chns_discrete = sample(rng, textmodel(data), SMC(), MCMCThreads(), 10_000, 4)
Chains MCMC chain (10000×5×4 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 4
Samples per chain = 10000
Wall duration     = 20.27 seconds
Compute duration  = 75.09 seconds
parameters        = λ₁, τ, λ₂
internals         = lp, weight

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64       Float64 

          λ₁   10.2698    0.5376     0.0027    0.0258    99.6680    1.2293        1.3273
          λ₂   18.9207    1.0455     0.0052    0.0494   109.6145    1.1654        1.4598
           τ   29.4185    2.2262     0.0111    0.1083    92.1613    1.4315        1.2274

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          λ₁    9.8500    9.9879   10.1530   10.4519   12.3396
          λ₂   16.7302   18.2020   18.7936   19.3083   22.2498
           τ   23.0000   29.0000   30.0000   30.0000   32.0000

Here we have yet to converge (note the small ESS values and R-hat values far from 1), so we would likely need many more than these 40,000 draws to reliably infer the parameters.

Appropriately tuned HMC (specifically the dynamic HMC implementation called NUTS in Turing) is typically more efficient and has better diagnostics, so it's worth considering the other options.

2. Replace the discrete parameter with a continuous one

Consider making the following change to your model (I assume you meant α = mean(data), since otherwise the exponential prior puts essentially no density on rate parameters consistent with your data):

@model function textmodel_continuous(data, n = length(data), α=mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    τ ~ Uniform(1, n)  # this is now continuous!
    λ = ifelse.((1:n) .≤ τ, λ₁, λ₂)
    data .~ Poisson.(λ)
end

rng = MersenneTwister(26)
chns_continuous = sample(rng, textmodel_continuous(data), NUTS(), MCMCThreads(), 1_000, 4)
Chains MCMC chain (1000×15×4 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 59.41 seconds
Compute duration  = 213.43 seconds
parameters        = λ₁, τ, λ₂
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

          λ₁   10.7436    0.5977     0.0094    0.0186   1020.0549    1.0020        4.7794
          λ₂   19.1539    0.5432     0.0086    0.0204    696.8815    1.0047        3.2652
           τ   30.3473    0.6903     0.0109    0.0212    912.8391    1.0043        4.2771

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          λ₁    9.6240   10.3286   10.7331   11.1331   11.9780
          λ₂   18.1422   18.7793   19.1410   19.5075   20.2732
           τ   29.0716   30.0593   30.3800   30.7231   31.8695

This assigns a day-specific rate parameter that follows a step function, jumping from \lambda_1 to \lambda_2 at time \tau, which now generally falls between days. Since all parameters are continuous, sampling works, but the gradient of the log density with respect to \tau is zero almost everywhere, so we haven't given HMC all of the information available, and it probably had to adapt small step sizes to keep acceptance rates high, which makes sampling slow.

To give more gradient info, instead of a step function, we can use a steep sigmoid, e.g.

f(i) = (\lambda_2 - \lambda_1) \operatorname{logistic}(k (i - \tau)) + \lambda_1.

The parameter k controls the steepness of the sigmoid: low values make it very shallow, while high values make it very steep. A value of k = 10 causes it to transition from its minimum to its maximum value over essentially a one-day period.
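
To get a feel for how sharp this is, you can overlay the k = 10 sigmoid on the true step function (this plot isn't part of the model; it just reuses λ₁, λ₂, and τ from the simulated data above):

using StatsFuns  # for logistic

xs = 1:0.1:n
rate_step = [fill(λ₁, τ); fill(λ₂, n - τ)]
rate_sigmoid = @. (λ₂ - λ₁) * logistic(10 * (xs - τ)) + λ₁
plot(1:n, rate_step; label="step", xlabel="Day", ylabel="rate")
plot!(xs, rate_sigmoid; label="sigmoid, k = 10")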

using StatsFuns

@model function textmodel_sigmoid(data, k, n = length(data), α=mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    τ ~ Uniform(1, n)
    λ = @. (λ₂ - λ₁) * logistic(k * ((1:n) - τ)) + λ₁
    data .~ Poisson.(λ)
end

rng = MersenneTwister(26)
chns_sigmoid = sample(rng, textmodel_sigmoid(data, 10), NUTS(), MCMCThreads(), 1_000, 4)
Chains MCMC chain (1000×15×4 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 0.3 seconds
Compute duration  = 1.12 seconds
parameters        = λ₁, τ, λ₂
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

          λ₁   10.7355    0.6105     0.0097    0.0136   2429.0148    1.0007     2176.5366
          λ₂   19.1455    0.5341     0.0084    0.0144   1931.2172    1.0018     1730.4814
           τ   30.3281    0.6541     0.0103    0.0236    696.7093    1.0010      624.2914

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          λ₁    9.5666   10.3147   10.7276   11.1380   11.9544
          λ₂   18.0941   18.7797   19.1430   19.5070   20.2085
           τ   29.1121   30.0500   30.3556   30.6909   31.6976

This runs blazingly fast (0.3 seconds!), and both ESS and R-hat look good. Let's look at the ECDF of τ:

using StatsPlots

@df chns_sigmoid ecdfplot(:τ; xlabel="τ", ylabel="Frequency", legend=false)

[Figure: ECDF of τ from the sigmoid model]

So the vast majority of the probability mass is between days 30 and 31, which is very close to the true value of 30.

You can experiment with different k values or put a prior on k and sample it.
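
For the latter, a sketch could look like the following (the model name and the truncated exponential prior on k are illustrative choices of mine, not from the original model):

@model function textmodel_sigmoid_k(data, n = length(data), α = mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    τ ~ Uniform(1, n)
    k ~ truncated(Exponential(10), 1, Inf)  # illustrative prior keeping the sigmoid reasonably steep
    λ = @. (λ₂ - λ₁) * logistic(k * ((1:n) - τ)) + λ₁
    data .~ Poisson.(λ)
end

chns_sigmoid_k = sample(rng, textmodel_sigmoid_k(data), NUTS(), MCMCThreads(), 1_000, 4)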

3. Marginalize out the discrete parameter and recover draws of it in post-processing

When you have a discrete parameter with finite support, you can marginalize it out, leaving only continuous parameters that can be sampled with HMC. After drawing the continuous parameters from the posterior, you can then draw the discrete parameter exactly, conditional on those draws, as though everything had been sampled jointly; in fact the resulting estimates should be better due to Rao-Blackwellization.

Marginalizing

The likelihood with \tau is

p(y | \lambda_1, \lambda_2, \tau) = \left(\prod_{j=1}^\tau p(y_j | \lambda_1) \right) \left( \prod_{j=\tau+1}^n p(y_j | \lambda_2) \right).

So the marginal likelihood is

p(y | \lambda_1, \lambda_2) = \sum_{i=1}^n p(\tau=i) p(y | \lambda_1, \lambda_2, \tau=i) = \frac{1}{n} \sum_{i=1}^n \left(\prod_{j=1}^i p(y_j | \lambda_1) \right) \left( \prod_{j=i+1}^n p(y_j | \lambda_2) \right)

At first glance, this looks hard to compute, but it actually isn’t. We can rewrite this as

p(y | \lambda_1, \lambda_2) = \frac{1}{n} \sum_{i=1}^n a_i(y) b_i(y),

where a_i(y) = \prod_{j=1}^i p(y_j | \lambda_1) and b_i(y) = \prod_{j=i+1}^n p(y_j | \lambda_2) = \frac{\prod_{j=1}^n p(y_j | \lambda_2)}{\prod_{j=1}^i p(y_j | \lambda_2)}.
These terms can be computed using cumprod, or, given log-probabilities, cumsum:

function compute_terms(data, λ₁, λ₂)
    lp1 = cumsum(logpdf.(Poisson(λ₁), data))  # log aᵢ(y) = ∑_{j ≤ i} log p(yⱼ | λ₁)
    lp2 = cumsum(logpdf.(Poisson(λ₂), data))
    return lp1 .+ (lp2[end] .- lp2)           # log aᵢ(y) + log bᵢ(y), for i = 1, …, n
end
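
As a quick sanity check (this isn't in the original post), you can compare it against the brute-force sum over change points for some fixed rates:

λ₁_test, λ₂_test = 10.0, 20.0
brute = [sum(logpdf.(Poisson(λ₁_test), data[1:i])) + sum(logpdf.(Poisson(λ₂_test), data[(i + 1):end]))
         for i in 1:length(data)]
maximum(abs.(brute .- compute_terms(data, λ₁_test, λ₂_test)))  # ≈ 0 up to floating-point error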

Then the log marginal likelihood is logsumexp(compute_terms(data, λ₁, λ₂)) - log(n), where the - log(n) comes from the uniform prior p(\tau = i) = \frac{1}{n}. In Turing, you would write the model as:

@model function textmodel_marginal(data, n = length(data), α=mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    lp = compute_terms(data, λ₁, λ₂)
    Turing.@addlogprob! logsumexp(lp) - log(n)
end

Now, to recover posterior draws of \tau, note that we can factorize the posterior to:

p(\lambda_1, \lambda_2, \tau | y) = p(\lambda_1, \lambda_2 | y) p(\tau | y, \lambda_1, \lambda_2).

On the right-hand side, the first factor is the marginal posterior we get by marginalizing out \tau, and the second factor is what we use to recover posterior draws of \tau. Since we already have draws from the first factor, we want to condition on those draws and the data and then draw from the second factor.

Note that we can also factorize the joint distribution of y and \tau, given the rates, in two ways:

p(y, \tau | \lambda_1, \lambda_2) = p(y | \lambda_1, \lambda_2, \tau) p(\tau | \lambda_1, \lambda_2) = p(\tau | y, \lambda_1, \lambda_2) p(y | \lambda_1, \lambda_2)

Since p(\tau|\lambda_1, \lambda_2) = p(\tau) = \frac{1}{n}, we can solve for p(\tau | y, \lambda_1, \lambda_2):

p(\tau | y, \lambda_1, \lambda_2) = \frac{p(y| \lambda_1, \lambda_2, \tau)}{n p(y | \lambda_1, \lambda_2)} = \frac{p(y| \lambda_1, \lambda_2, \tau)}{\sum_{i=1}^n p(y| \lambda_1, \lambda_2, \tau=i)}.

Note that for each value \tau can take, the numerator is the exponential of the corresponding element returned by compute_terms, and the denominator just normalizes that vector of probabilities. So for each posterior draw of \lambda_1 and \lambda_2, we can draw \tau using

lp = compute_terms(data, λ₁, λ₂)
p = exp.(lp .- logsumexp(lp))
τ = rand(Categorical(p))

Since this is a little more complicated than the other options, I’ll give a more complete example:

function compute_terms(data, λ₁, λ₂)
    lp1 = cumsum(logpdf.(Poisson(λ₁), data))
    lp2 = cumsum(logpdf.(Poisson(λ₂), data))
    lp = lp1 .+ (lp2[end] .- lp2)
end

@model function textmodel_marginal(data, n = length(data), α=mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    lp = compute_terms(data, λ₁, λ₂)
    Turing.@addlogprob! logsumexp(lp) - log(n)
end

rng = MersenneTwister(87)
chns_marginal = sample(rng, textmodel_marginal(data), NUTS(), MCMCThreads(), 1_000, 4)
Chains MCMC chain (1000×14×4 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 0.25 seconds
Compute duration  = 0.87 seconds
parameters        = λ₁, λ₂
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

          λ₁   10.7435    0.6116     0.0097    0.0114   3901.3490    1.0003     4494.6417
          λ₂   19.1458    0.5200     0.0082    0.0076   4247.3706    0.9999     4893.2841

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          λ₁    9.5381   10.3272   10.7371   11.1548   11.9891
          λ₂   18.1305   18.7943   19.1389   19.4991   20.1602

Again, this is blazingly fast, and notice that our ESS values are higher than with the sigmoid model.
Now we sample τ:

# a dummy model we use to predict τ given λ₁ and λ₂
@model function textmodel_marginal_pred(data, α=mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    lp = compute_terms(data, λ₁, λ₂)
    p = exp.(lp .- logsumexp(lp))
    τ ~ Categorical(p)
end

chns_τ = predict(rng, textmodel_marginal_pred(data), chns_marginal)
Chains MCMC chain (1000×1×4 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 4
Samples per chain = 1000
parameters        = τ
internals         = 

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64 

           τ   29.8180    0.7470     0.0118    0.0097   4066.6011    0.9993

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           τ   29.0000   30.0000   30.0000   30.0000   31.0000

We end up with almost 6 times the ESS we got from the sigmoid model. And let’s compare the ECDFs:

@df chns_τ ecdfplot(:τ; label="τ discrete")
@df chns_sigmoid ecdfplot!(:τ; label="τ continuous")
@df chns_sigmoid ecdfplot!(round.(:τ, RoundDown); label="τ continuous rounded", xlabel="Day", ylabel="Frequency")

[Figure: ECDFs of τ from the marginalized (discrete) and sigmoid (continuous) models]

So the vast majority of the probability mass is on the true value of 30. And both the discrete and continuous models yield very similar inferences if we round the continuous parameter down to the nearest day.
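
As a rough numerical check (not shown in the original output), you can also compare posterior means directly; the rounded continuous τ and the marginalized τ should be close:

mean(chns_τ[:τ]), mean(round.(chns_sigmoid[:τ], RoundDown))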
