HMC requires gradients, so all sampled parameters need to be continuous, but τ
is not. By default Turing uses ForwardDiff to compute the gradient, which won’t work for discrete parameters, hence the error.
There are (at least) 3 ways you can proceed:
- Use a sampler that doesn’t require gradients
- Replace the discrete parameter with a continuous one
- Marginalize out the discrete parameter and recover draws of it in post-processing
Throughout this, I’ll be using the following dataset:
using Distributions, Random, StatsPlots, Turing
rng = MersenneTwister(42)
n = 100
λ₁ = 10
λ₂ = 20
τ = 30
data = [rand(rng, Poisson(i ≤ τ ? λ₁ : λ₂)) for i in 1:n]
scatter(data; xlabel="Day", ylabel="Count", label="data")
plot!(1:n, [fill(λ₁, τ); fill(λ₂, n - τ)], label="rate")
1. Use a sampler that doesn’t require gradients
In Turing, this is probably the easiest option. You can, for example, use a particle sampler like SMC or use a Gibbs sampler to alternate updates of the discrete and continuous parameters. There are a number of examples of this in the Turing docs.
@model function textmodel(data, n = length(data), α = mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    τ ~ DiscreteUniform(1, n)
    for i ∈ 1:τ
        data[i] ~ Poisson(λ₁)
    end
    for j ∈ τ+1:n
        data[j] ~ Poisson(λ₂)
    end
end
rng = MersenneTwister(11)
@time chns_discrete = sample(rng, textmodel(data), SMC(), MCMCThreads(), 10_000, 4)
Chains MCMC chain (10000×5×4 Array{Float64, 3}):
Iterations = 1:1:10000
Number of chains = 4
Samples per chain = 10000
Wall duration = 20.27 seconds
Compute duration = 75.09 seconds
parameters = λ₁, τ, λ₂
internals = lp, weight
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
λ₁ 10.2698 0.5376 0.0027 0.0258 99.6680 1.2293 1.3273
λ₂ 18.9207 1.0455 0.0052 0.0494 109.6145 1.1654 1.4598
τ 29.4185 2.2262 0.0111 0.1083 92.1613 1.4315 1.2274
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
λ₁ 9.8500 9.9879 10.1530 10.4519 12.3396
λ₂ 16.7302 18.2020 18.7936 19.3083 22.2498
τ 23.0000 29.0000 30.0000 30.0000 32.0000
Here we have yet to converge (note the small ESS values and R-hat values far from 1), so we would likely need many more than these 40,000 draws to infer the parameters well.
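Another gradient-free route mentioned above is to combine a particle sampler for the discrete τ with HMC for the rates via Gibbs. Here is a minimal sketch, assuming the textmodel above; the number of particles and the HMC tuning parameters are illustrative, and the compositional-sampler syntax varies between Turing versions:
# Sketch: particle-Gibbs updates for τ, HMC updates for λ₁ and λ₂.
# The 20 particles and the HMC step size / leapfrog count are illustrative choices.
chns_gibbs = sample(rng, textmodel(data), Gibbs(PG(20, :τ), HMC(0.05, 10, :λ₁, :λ₂)), 1_000)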
Appropriately tuned HMC (specifically the dynamic HMC implementation called NUTS in Turing) is typically more efficient and has better diagnostics, so it is often worth considering the other options.
2. Replace the discrete parameter with a continuous one
Consider making the following change to your model. (I assume you mean α = mean(data), as otherwise the exponential prior will put essentially no density on rate parameters consistent with your data):
@model function textmodel_continuous(data, n = length(data), α = mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    τ ~ Uniform(1, n)  # this is now continuous!
    λ = ifelse.((1:n) .≤ τ, λ₁, λ₂)
    data .~ Poisson.(λ)
end
rng = MersenneTwister(26)
chns_continuous = sample(rng, textmodel_continuous(data), NUTS(), MCMCThreads(), 1_000, 4)
Chains MCMC chain (1000×15×4 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 4
Samples per chain = 1000
Wall duration = 59.41 seconds
Compute duration = 213.43 seconds
parameters = λ₁, τ, λ₂
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
λ₁ 10.7436 0.5977 0.0094 0.0186 1020.0549 1.0020 4.7794
λ₂ 19.1539 0.5432 0.0086 0.0204 696.8815 1.0047 3.2652
τ 30.3473 0.6903 0.0109 0.0212 912.8391 1.0043 4.2771
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
λ₁ 9.6240 10.3286 10.7331 11.1331 11.9780
λ₂ 18.1422 18.7793 19.1410 19.5075 20.2732
τ 29.0716 30.0593 30.3800 30.7231 31.8695
This assigns a day-specific rate parameter that follows a step function, jumping from $\lambda_1$ to $\lambda_2$ at time $\tau$, which will generally fall between days. Since this is continuous, everything works, but because the gradient of the log density with respect to $\tau$ is zero, we haven't given HMC all of the information available, and it probably had to adapt small step sizes to get high acceptance rates, which makes the sampling slow.
To give HMC more gradient information, instead of a step function we can use a steep sigmoid, e.g.

$$\lambda_i = \lambda_1 + (\lambda_2 - \lambda_1) \operatorname{logistic}\bigl(k (i - \tau)\bigr).$$

The parameter $k$ controls the steepness of the sigmoid. Low values make it very shallow, while high values make it very steep. A value of $k = 10$ causes it to transition from its minimum to its maximum value over essentially a one-day period.
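To get a feel for this, here is a quick sketch that plots the sigmoid rate curve for a few values of k, using the true λ₁, λ₂, and τ from the simulated dataset above:
using StatsFuns  # provides logistic
# Sketch: the sigmoid rate curve for a few steepness values k;
# λ₁, λ₂, and τ here are the true values used to simulate the data.
curves = hcat([(λ₂ - λ₁) .* logistic.(k .* ((1:n) .- τ)) .+ λ₁ for k in (1, 5, 10)]...)
plot(1:n, curves; label=["k = 1" "k = 5" "k = 10"], xlabel="Day", ylabel="Rate")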
using StatsFuns
@model function textmodel_sigmoid(data, k, n = length(data), α = mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    τ ~ Uniform(1, n)
    λ = @. (λ₂ - λ₁) * logistic(k * ((1:n) - τ)) + λ₁
    data .~ Poisson.(λ)
end
rng = MersenneTwister(26)
chns_sigmoid = sample(rng, textmodel_sigmoid(data, 10), NUTS(), MCMCThreads(), 1_000, 4)
Chains MCMC chain (1000×15×4 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 4
Samples per chain = 1000
Wall duration = 0.3 seconds
Compute duration = 1.12 seconds
parameters = λ₁, τ, λ₂
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
λ₁ 10.7355 0.6105 0.0097 0.0136 2429.0148 1.0007 2176.5366
λ₂ 19.1455 0.5341 0.0084 0.0144 1931.2172 1.0018 1730.4814
τ 30.3281 0.6541 0.0103 0.0236 696.7093 1.0010 624.2914
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
λ₁ 9.5666 10.3147 10.7276 11.1380 11.9544
λ₂ 18.0941 18.7797 19.1430 19.5070 20.2085
τ 29.1121 30.0500 30.3556 30.6909 31.6976
This runs blazingly fast (0.3 s!), and both ESS and R-hat look good. Let's look at the ECDF for τ:
using StatsPlots
@df chns_sigmoid ecdfplot(:τ; xlabel="τ", ylabel="Frequency", legend=false)
So the vast majority of the probability mass is between days 30 and 31, which is very close to the true value of 30.
You can experiment with different k values or put a prior on k and sample it.
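For example, putting a prior on k might look like the following sketch (the truncated-normal prior here is just an illustration, not a recommendation):
# Sketch: sample the steepness k as well; the prior on k is only illustrative.
@model function textmodel_sigmoid_k(data, n = length(data), α = mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    τ ~ Uniform(1, n)
    k ~ truncated(Normal(10, 5), 0, Inf)
    λ = @. (λ₂ - λ₁) * logistic(k * ((1:n) - τ)) + λ₁
    data .~ Poisson.(λ)
end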
3. Marginalize out the discrete parameter and recover draws of it in post-processing
When you have a discrete parameter with finite support, you can marginalize it out, leaving only continuous parameters that can be sampled with HMC. After drawing the continuous parameters from the posterior, you can then draw the discrete parameter exactly, as though you had sampled it jointly with the continuous ones, except that the resulting estimates should be better due to Rao-Blackwellization.
Marginalizing
The likelihood with $\tau$ is

$$p(y \mid \lambda_1, \lambda_2, \tau) = \prod_{j=1}^{\tau} p(y_j \mid \lambda_1) \prod_{j=\tau+1}^{n} p(y_j \mid \lambda_2).$$

So the marginal likelihood is

$$p(y \mid \lambda_1, \lambda_2) = \sum_{\tau=1}^{n} p(\tau)\, p(y \mid \lambda_1, \lambda_2, \tau) = \frac{1}{n} \sum_{\tau=1}^{n} \prod_{j=1}^{\tau} p(y_j \mid \lambda_1) \prod_{j=\tau+1}^{n} p(y_j \mid \lambda_2).$$

At first glance, this looks hard to compute, but it actually isn't. We can rewrite it as

$$p(y \mid \lambda_1, \lambda_2) = \frac{1}{n} \sum_{i=1}^{n} a_i(y)\, b_i(y),$$

where $a_i(y) = \prod_{j=1}^{i} p(y_j \mid \lambda_1)$ and $b_i(y) = \prod_{j=i+1}^{n} p(y_j \mid \lambda_2) = \frac{\prod_{j=1}^{n} p(y_j \mid \lambda_2)}{\prod_{j=1}^{i} p(y_j \mid \lambda_2)}$.
These terms can be computed using cumprod, or, given log-probabilities, cumsum:
function compute_terms(data, λ₁, λ₂)
    lp1 = cumsum(logpdf.(Poisson(λ₁), data))  # lp1[i] = log aᵢ(y)
    lp2 = cumsum(logpdf.(Poisson(λ₂), data))
    lp = lp1 .+ (lp2[end] .- lp2)             # lp[i] = log(aᵢ(y) bᵢ(y))
end
Then the log marginal likelihood is logsumexp(compute_terms(data, λ₁, λ₂)) - log(n), where the -log(n) term comes from the uniform prior p(τ) = 1/n. In Turing, you would write the model as:
@model function textmodel_marginal(data, n = length(data), α = mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    lp = compute_terms(data, λ₁, λ₂)
    Turing.@addlogprob! logsumexp(lp) - log(n)
end
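As a quick sanity check on the marginalization, you can compare compute_terms against a brute-force sum over all possible changepoints. This is just a sketch; direct_logmarginal is a hypothetical helper introduced only for this check:
# Sketch: brute-force log marginal likelihood, looping over every possible changepoint.
function direct_logmarginal(data, λ₁, λ₂)
    n = length(data)
    lps = [sum(logpdf.(Poisson(λ₁), data[1:τ])) + sum(logpdf.(Poisson(λ₂), data[(τ + 1):end])) for τ in 1:n]
    return logsumexp(lps) - log(n)
end
# The two should agree up to floating-point error:
logsumexp(compute_terms(data, 10.0, 20.0)) - log(length(data)) ≈ direct_logmarginal(data, 10.0, 20.0)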
Now, to recover posterior draws of $\tau$, note that we can factorize the posterior as

$$p(\lambda_1, \lambda_2, \tau \mid y) = p(\lambda_1, \lambda_2 \mid y)\, p(\tau \mid y, \lambda_1, \lambda_2).$$

On the right-hand side, the first factor is the marginal posterior that we get by marginalizing out $\tau$, while the second factor is the one we would use to recover posterior draws of $\tau$. Since we already have draws from the first factor, we now want to condition on those draws and the data to draw from the second.

Note that we can also write the posterior as

$$p(\lambda_1, \lambda_2, \tau \mid y) \propto p(y \mid \lambda_1, \lambda_2, \tau)\, p(\tau \mid \lambda_1, \lambda_2)\, p(\lambda_1, \lambda_2).$$

Since $p(\tau \mid \lambda_1, \lambda_2) = p(\tau) = \frac{1}{n}$, we can solve for $p(\tau \mid y, \lambda_1, \lambda_2)$:

$$p(\tau \mid y, \lambda_1, \lambda_2) = \frac{p(y \mid \lambda_1, \lambda_2, \tau)}{\sum_{\tau'=1}^{n} p(y \mid \lambda_1, \lambda_2, \tau')} = \frac{a_\tau(y)\, b_\tau(y)}{\sum_{i=1}^{n} a_i(y)\, b_i(y)}.$$
Note that for each value $\tau$ can take, the numerator corresponds to an element of the vector returned by compute_terms (on the log scale), and the denominator just normalizes the resulting vector of probabilities. The result is that for each posterior draw of $\lambda_1$ and $\lambda_2$, we can draw $\tau$ using
lp = compute_terms(data, λ₁, λ₂)
p = exp.(lp .- logsumexp(lp))
τ = rand(Categorical(p))
Since this is a little more complicated than the other options, I’ll give a more complete example:
function compute_terms(data, λ₁, λ₂)
    lp1 = cumsum(logpdf.(Poisson(λ₁), data))
    lp2 = cumsum(logpdf.(Poisson(λ₂), data))
    lp = lp1 .+ (lp2[end] .- lp2)
end
@model function textmodel_marginal(data, n = length(data), α = mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    lp = compute_terms(data, λ₁, λ₂)
    Turing.@addlogprob! logsumexp(lp) - log(n)
end
rng = MersenneTwister(87)
chns_marginal = sample(rng, textmodel_marginal(data), NUTS(), MCMCThreads(), 1_000, 4)
Chains MCMC chain (1000×14×4 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 4
Samples per chain = 1000
Wall duration = 0.25 seconds
Compute duration = 0.87 seconds
parameters = λ₁, λ₂
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
λ₁ 10.7435 0.6116 0.0097 0.0114 3901.3490 1.0003 4494.6417
λ₂ 19.1458 0.5200 0.0082 0.0076 4247.3706 0.9999 4893.2841
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
λ₁ 9.5381 10.3272 10.7371 11.1548 11.9891
λ₂ 18.1305 18.7943 19.1389 19.4991 20.1602
Again, this is blazingly fast, and notice that our ESS values are higher than with the sigmoid model.
Now we sample τ:
# a dummy model we use to predict τ given λ₁ and λ₂
@model function textmodel_marginal_pred(data, α = mean(data))
    λ₁ ~ Exponential(α)
    λ₂ ~ Exponential(α)
    lp = compute_terms(data, λ₁, λ₂)
    p = exp.(lp .- logsumexp(lp))
    τ ~ Categorical(p)
end
chns_τ = predict(rng, textmodel_marginal_pred(data), chns_marginal)
Chains MCMC chain (1000×1×4 Array{Float64, 3}):
Iterations = 1:1:1000
Number of chains = 4
Samples per chain = 1000
parameters = τ
internals =
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Float64 Float64 Float64
τ 29.8180 0.7470 0.0118 0.0097 4066.6011 0.9993
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
τ 29.0000 30.0000 30.0000 30.0000 31.0000
We end up with almost 6 times the ESS we got from the sigmoid model. And let’s compare the ECDFs:
@df chns_τ ecdfplot(:τ; label="τ discrete")
@df chns_sigmoid ecdfplot!(:τ; label="τ continuous")
@df chns_sigmoid ecdfplot!(round.(:τ, RoundDown); label="τ continuous rounded", xlabel="Day", ylabel="Frequency")
So the vast majority of the probability mass is on the true value of 30. And both the discrete and continuous models yield very similar inferences if we round the continuous parameter down to the nearest day.
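If you want a single number to back that up, here is one way to compare how much posterior mass each approach puts exactly on day 30 (a sketch):
# Sketch: posterior probability assigned to the true changepoint (day 30) by each approach.
mean(vec(Array(chns_τ[:τ])) .== 30)
mean(round.(vec(Array(chns_sigmoid[:τ])), RoundDown) .== 30)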