I have a change-point detection model that I’m converting from JAGS to Turing. It’s an interesting case because it raises the question of whether to use indicator variables (common in other PPLs) or control flow.
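For comparison, here is a minimal plain-Julia sketch (no Turing) showing that, for a fixed τ, the two formulations compute the same log-likelihood. If that holds, the choice should affect only speed and readability, not the posterior. `normal_logpdf`, `loglik_indicator` and `loglik_ifelse` are hypothetical helper names; `normal_logpdf` stands in for `logpdf(Normal(μ, σ), x)` from Distributions.jl, and the data are made-up toy values.

```julia
# Assumption: hand-written Normal log-density so the sketch needs no packages.
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Indicator-variable version: z[t] selects which mean applies at time t.
function loglik_indicator(c, μ, σ, τ)
    t = 1:length(c)
    z = ifelse.(t .< τ, 1, 2)   # 1 before the change point, 2 from τ onwards
    return sum(normal_logpdf(c[i], μ[z[i]], σ) for i in eachindex(c))
end

# Control-flow version: branch on t < τ inside the loop.
function loglik_ifelse(c, μ₁, μ₂, σ, τ)
    total = 0.0
    for t in eachindex(c)
        total += t < τ ? normal_logpdf(c[t], μ₁, σ) : normal_logpdf(c[t], μ₂, σ)
    end
    return total
end

# Toy data (made up): three points near 45, then three near 30.
c = [45.1, 44.2, 46.0, 30.5, 29.8, 31.2]
ll_ind = loglik_indicator(c, [45.0, 30.0], 4.0, 3.7)
ll_if  = loglik_ifelse(c, 45.0, 30.0, 4.0, 3.7)
println(ll_ind ≈ ll_if)   # true
```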
I’ve implemented both, and they both ‘work’ in the sense that sampling does not throw an error. However, with each of the couple of samplers I’ve tried, the resulting chains are extremely poor. Any tips on why this is the case, which sampler would be best, or whether one should favour indicator variables over control flow?
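One plausible contributor, sketched below in plain Julia under the same assumptions as before (hand-written `normal_logpdf`, hypothetical `loglik` helper, toy data): `τ` is continuous but only enters through the comparison `t < τ`, so the likelihood is a step function of `τ`. It is constant on each interval (k, k+1] and jumps at the integers, so its derivative with respect to `τ` is zero almost everywhere, which gives gradient-based samplers such as `NUTS` or `HMC` nothing to follow. This applies equally to the indicator and if/else versions. Separately, my understanding of the Turing documentation is that `MH()` with no arguments proposes new values from the prior; with `Normal(0, 100)` and `Uniform(0, 100)` priors and 1000 observations, such proposals would almost never be accepted.

```julia
# Assumption: hand-written Normal log-density so the sketch needs no packages.
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Hypothetical helper: log-likelihood of the change-point model for a given τ.
function loglik(c, μ₁, μ₂, σ, τ)
    total = 0.0
    for t in eachindex(c)
        total += normal_logpdf(c[t], t < τ ? μ₁ : μ₂, σ)
    end
    return total
end

# Toy data (made up).
c = [45.1, 44.2, 46.0, 30.5, 29.8, 31.2]

# Every τ in (3, 4] assigns t = 1:3 to μ₁, so these two values are identical...
ll_a = loglik(c, 45.0, 30.0, 4.0, 3.1)
ll_b = loglik(c, 45.0, 30.0, 4.0, 3.9)
# ...while crossing 4 moves t = 4 (≈ 30.5) under μ₁ = 45, and the value drops.
ll_c = loglik(c, 45.0, 30.0, 4.0, 4.5)

# Central finite difference in τ inside the interval (3, 4): exactly zero.
h = 1e-6
grad_τ = (loglik(c, 45.0, 30.0, 4.0, 3.5 + h) - loglik(c, 45.0, 30.0, 4.0, 3.5 - h)) / (2h)

println(ll_a == ll_b)   # true
println(ll_c < ll_a)    # true
println(grad_τ)         # 0.0
```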
Indicator variable model
using Turing, StatsPlots

@model function model(c)
    tₘₐₓ = length(c)
    t = 1:tₘₐₓ
    # priors
    μ = Vector(undef, 2)
    μ[1] ~ Normal(0, 100)
    μ[2] ~ Normal(0, 100)
    σ ~ Uniform(0, 100)
    τ ~ Uniform(1, tₘₐₓ)
    # indicator variable: 1 before the change point, 2 from τ onwards
    z = ifelse.(t .< τ, 1, 2)
    # likelihood
    for i in 1:tₘₐₓ
        c[i] ~ Normal(μ[z[i]], σ)
    end
end

# generate data
τ_true, μ₁, μ₂, σ_true = 500, 45, 30, 4
c = vcat(rand(Normal(μ₁, σ_true), τ_true),
         rand(Normal(μ₂, σ_true), 1000 - τ_true))

chain = sample(model(c), MH(), 5000)
plot(chain)

# data
plot(1:length(c), c, xlabel="Time", ylabel="Count", title="Data space", legend=false)
# overlay the posterior-mean fit (posterior means of μ[1], μ[2] and τ)
plot!([1, mean(chain[:τ])], [mean(chain["μ[1]"]), mean(chain["μ[1]"])], lw=6, color=:black)
plot!([mean(chain[:τ]), length(c)], [mean(chain["μ[2]"]), mean(chain["μ[2]"])], lw=6, color=:black)
This results in terrible chains.

If/else model
using Turing, StatsPlots

@model function model(c)
    tₘₐₓ = length(c)
    # priors
    μ₁ ~ Normal(0, 100)
    μ₂ ~ Normal(0, 100)
    σ ~ Uniform(0, 100)
    τ ~ Uniform(1, tₘₐₓ)
    # likelihood: mean μ₁ before the change point, μ₂ from τ onwards
    for t in 1:tₘₐₓ
        if t < τ
            c[t] ~ Normal(μ₁, σ)
        else
            c[t] ~ Normal(μ₂, σ)
        end
    end
end

# generate data
τ_true, μ₁, μ₂, σ_true = 500, 45, 30, 4
c = vcat(rand(Normal(μ₁, σ_true), τ_true),
         rand(Normal(μ₂, σ_true), 1000 - τ_true))

chain = sample(model(c), MH(), 5000)
plot(chain)

# data
plot(1:length(c), c, xlabel="Time", ylabel="Count", title="Data space", legend=false)
# overlay the posterior-mean fit (posterior means of μ₁, μ₂ and τ)
plot!([1, mean(chain[:τ])], [mean(chain[:μ₁]), mean(chain[:μ₁])], lw=6, color=:black)
plot!([mean(chain[:τ]), length(c)], [mean(chain[:μ₂]), mean(chain[:μ₂])], lw=6, color=:black)
This gives similarly bad chains.
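One commonly used alternative, not something tried above, is to sum the change point out of the likelihood. The likelihood depends on τ only through which interval (k-1, k] contains it, and under `Uniform(1, tₘₐₓ)` each such interval, k = 2, …, tₘₐₓ, has prior probability 1/(tₘₐₓ-1). Averaging over k therefore gives the same marginal likelihood for μ₁, μ₂ and σ, and that function is smooth in those parameters. A minimal plain-Julia sketch follows; `normal_logpdf`, `logsumexp`, `changepoint_terms`, `marginal_loglik` and `changepoint_posterior` are hypothetical helpers, and the data are toy values.

```julia
# Assumption: hand-written Normal log-density so the sketch needs no packages.
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Numerically stable log(sum(exp.(v))).
function logsumexp(v)
    m = maximum(v)
    return m + log(sum(exp.(v .- m)))
end

# Log-likelihood for each change point k ∈ 2:T, where observations t < k use μ₁
# and t ≥ k use μ₂. Cumulative sums make this O(T) rather than O(T²).
function changepoint_terms(c, μ₁, μ₂, σ)
    T = length(c)
    ℓ₁ = cumsum(normal_logpdf.(c, μ₁, σ))   # ℓ₁[j] = Σ_{t ≤ j} log N(c[t] | μ₁, σ)
    ℓ₂ = cumsum(normal_logpdf.(c, μ₂, σ))   # same for μ₂
    return [ℓ₁[k-1] + (ℓ₂[T] - ℓ₂[k-1]) for k in 2:T]
end

# Marginal log-likelihood with a uniform prior of 1/(T-1) on each k.
marginal_loglik(c, μ₁, μ₂, σ) = logsumexp(changepoint_terms(c, μ₁, μ₂, σ)) - log(length(c) - 1)

# Posterior probabilities of k = 2, …, T for fixed μ₁, μ₂, σ.
function changepoint_posterior(c, μ₁, μ₂, σ)
    terms = changepoint_terms(c, μ₁, μ₂, σ)
    return exp.(terms .- logsumexp(terms))
end

# Toy data (made up): three points near 45, then three near 30.
c = [45.1, 44.2, 46.0, 30.5, 29.8, 31.2]
p = changepoint_posterior(c, 45.0, 30.0, 4.0)
println(argmax(p) + 1)   # 4, i.e. observations 1 to 3 assigned to μ₁
```

Inside Turing, one option would be to keep the `μ₁`, `μ₂` and `σ` priors, replace the likelihood loop with `Turing.@addlogprob! marginal_loglik(c, μ₁, μ₂, σ)`, and sample with `NUTS()`; τ could then be recovered afterwards by evaluating `changepoint_posterior` at the posterior draws.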