Turing: indicator variables vs control flow

I have a change point detection model which I’m converting from JAGS to Turing, and this one is interesting as it raises the question of whether you should use indicator variables (common in other PPLs) or control flow.

I’ve implemented both, and they both ‘work’ in that sampling does not throw an error. However, having tried a couple of different samplers, the sampling is extremely bad. Any tips on why this is the case, which sampler would be best, or whether one should favour indicator variables over control flow?

Indicator variable model

using Turing, StatsPlots

@model function model(c)
    tₘₐₓ = length(c) 
    t = [1:tₘₐₓ;]
    # priors
    μ = Vector(undef, 2)
    μ[1] ~ Normal(0, 100)
    μ[2] ~ Normal(0, 100)
    σ ~ Uniform(0, 100)
    τ ~ Uniform(1, tₘₐₓ)
    # indicator variable
    z = Vector(undef, tₘₐₓ)
    z[t.<τ] .= 1
    z[t.≥τ] .= 2
    # likelihood
    for i in 1:tₘₐₓ
        c[i] ~ Normal(μ[z[i]], σ)
    end
end

# generate data
τ_true, μ₁, μ₂, σ_true = 500, 45, 30, 4
c = vcat(rand(Normal(μ₁,σ_true), τ_true), 
         rand(Normal(μ₂,σ_true), 1000-τ_true))

chain = sample(model(c), MH(), 5000)

plot(chain)

# data
plot([1:length(c);], c, xlabel="Time", ylabel="Count", title="Data space", legend=false)
# mean posterior predictive
plot!([1, mean(chain[:τ])], [mean(chain["μ[1]"]), mean(chain["μ[1]"])], lw=6, color=:black)
plot!([mean(chain[:τ]), length(c)], [mean(chain["μ[2]"]), mean(chain["μ[2]"])], lw=6, color=:black)

This results in terrible chains.

If-else model

using Turing, StatsPlots

@model function model(c)
    tₘₐₓ = length(c)
    # priors
    μ₁ ~ Normal(0, 100)
    μ₂ ~ Normal(0, 100)
    σ ~ Uniform(0, 100)
    τ ~ Uniform(1, tₘₐₓ)
    # likelihood
    for t in 1:tₘₐₓ
        if t < τ        
            c[t] ~ Normal(μ₁, σ)
        else
            c[t] ~ Normal(μ₂, σ)
        end
    end
end

# generate data
τ_true, μ₁, μ₂, σ_true = 500, 45, 30, 4
c = vcat(rand(Normal(μ₁,σ_true), τ_true), 
         rand(Normal(μ₂,σ_true), 1000-τ_true))

chain = sample(model(c), MH(), 5000)

plot(chain)

# data
plot([1:length(c);], c, xlabel="Time", ylabel="Count", title="Data space", legend=false)
# mean posterior predictive
plot!([1, mean(chain[:τ])], [mean(chain[:μ₁]), mean(chain[:μ₁])], lw=6, color=:black)
plot!([mean(chain[:τ]), length(c)], [mean(chain[:μ₂]), mean(chain[:μ₂])], lw=6, color=:black)

This gives similarly bad chains.

If the posterior is otherwise equivalent, go with what’s simpler/faster — possibly your second option.

Bad mixing may indicate a misfit of your model to the data, or just a model with multiple modes — MH is not particularly good at dealing with that.
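
For instance, trying a gradient-based sampler is a one-line change (whether it helps with this particular posterior is another question):

# swap MH for NUTS (gradient-based); just a sketch, it may still mix poorly here
chain = sample(model(c), NUTS(), 2000)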

Hi Ben,

I have 2 thoughts.

  1. There are some Stan examples here that might help with the coding details.
  2. You might be better off with a smooth changepoint function (a sigmoid). I don’t know the internals well enough to be sure, but I suspect that might provide a better gradient for τ. In particular, with incorrect values of τ near 0, you’d expect μ₂ to be something like an average over the whole data, which is what you see, so changepoint sampling issues seem a pretty good explanation. I remember a post somewhere about this that I can’t find, but IIRC it was something like:
cp ~ Uniform(0.0, length) # Real uniform
switch = sigmoid(cp, bandwidth)     
# ^ center = cp, scale = bandwidth, exercise for reader to code :)
data ~ Normal(switch * mu_1 + (1 - switch) * mu_2, 1)

Here bandwidth controls how “strict” your cutoff could be. This function gives you a gradient at all values of switch. I’d love to hear any thoughts on this approach from someone with more experience!
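
In case it helps, a minimal version of that sigmoid could look like the following (the names and the direction of the switch are just one choice):

# a standard logistic centred at cp; `bandwidth` controls how sharp the cutoff is
sigmoid(t, cp, bandwidth) = 1 / (1 + exp(-(t - cp) / bandwidth))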

Take care,
Brad

Here’s a comment from Daniel Lakeland on Gelman’s blog that suggests a similar approach.


Thanks both. Doing more experimentation, it does seem that the chain for the switch point τ is all over the place. Even when tightening up the priors around the true values of the other parameters, it still does a poor job. So yes, I suspect you’re right about the switch point.

Just visualising it in my mind, I’d assume that as the proposed switch point gets closer to the true switch point the likelihood would improve if the current sample for μ₁ is higher than μ₂; otherwise it would presumably get worse. So possibly the switch point is just a hard problem to sample, even though it looks superficially trivial.
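
One quick way to check that intuition is to plot the log-likelihood of the hard-threshold model as a function of τ, with the other parameters held at their true values (a rough sketch):

using Distributions, StatsPlots

# log-likelihood as a function of τ with μ₁, μ₂, σ fixed; it only changes when τ
# crosses an integer t, so it is flat (zero gradient) in between
loglik(τ, c, μ₁, μ₂, σ) =
    sum(logpdf(Normal(t < τ ? μ₁ : μ₂, σ), c[t]) for t in eachindex(c))

τ_grid = 1:0.25:length(c)
plot(τ_grid, [loglik(τ, c, 45, 30, 4) for τ in τ_grid],
     xlabel="τ", ylabel="log likelihood", legend=false)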

What does not help is constraining μ₂ > μ₁, for example.

It also does not help to change the model to a linear-discontinuity type model, which adds some gradient.

So yep, I’ll experiment with your suggestion of a sigmoid around the switch point.

Hi Ben.

I took a stab myself and identified a few things:

  1. Perhaps not surprisingly, standardizing the data really helps. Alternatively, we could have centered the data in the model instead. Here’s what I did:
τ_true, μ₁, μ₂, σ_true = 500, 45, 30, 4
c = vcat(rand(Normal(μ₁,σ_true), τ_true), 
         rand(Normal(μ₂,σ_true), 1000-τ_true));

std_c = (c .- mean(c)) ./ std(c); 
  2. The sigmoid approach also worked; here’s the code:
logit = bijector(Beta())   # bijection:  (0, 1)  →  ℝ
inv_logit = inv(logit)     # bijection:       ℝ  →  (0, 1)

function sigm(μ, σ, x)   # scaled, shifted sigmoid
    return inv_logit(σ*(x - μ))
end;


@model function changepoint(data)
    # priors
    spec = 0.01
    μ_1 ~ Normal(0, 2)
    μ_2 ~ Normal(0, 2)
    τ ~ Uniform(1, 1000)
    σ ~ truncated(Normal(1, 2), 0, 20)
    
    # likelihood
    for i in 1:length(data)
        switch = sigm(τ, spec, i)
        z = (1-switch) * μ_1 + switch * μ_2
        data[i] ~ Normal(z, σ)
    end
end;


# Settings of the Hamiltonian Monte Carlo (HMC) sampler.
iterations = 2000
ϵ = 0.005   # leapfrog step size
τ = 10;     # number of leapfrog steps (not the model's changepoint τ)

cp_chain = sample(
    changepoint(std_c), 
    HMC(ϵ, τ), iterations, 
    progress=true, drop_warmup=false);

StatsPlots.plot(cp_chain)

Note that the traces are against the standardized data.


An excellent answer!

Thanks for this excellent answer @BradGroff. It’s also spurred me to add bijectors to my “to learn” list.

I tried to delve into this a bit further. It’s still confusing to me why estimation of change points in JAGS works fine, but not here. I get why the sigmoid might help things, but I still can’t quite intuit why a change point model would fail here.

To experiment with that, I tried a change point model where all the parameters other than the change point were known. Inferring the change point alone does indeed work; it also works with MH.

using Turing, StatsPlots

@model function model(c, μ₁, μ₂, σ)
    tₘₐₓ = length(c)
    # prior
    τ ~ Uniform(1, tₘₐₓ) 
    # likelihood
    for t in 1:tₘₐₓ
        if t < τ        
            c[t] ~ Normal(μ₁, σ)
        else
            c[t] ~ Normal(μ₂, σ)
        end
    end
end

# generate data
τ, μ₁, μ₂, σ = 350, 45, 30, 4
c = vcat(rand(Normal(μ₁,σ), τ), 
         rand(Normal(μ₂,σ), 1000-τ))

chain = sample(model(c,μ₁, μ₂, σ), HMC(0.005, 10), 2000)

plot(chain)

Where the problem seems to start is when inferring a change point and a mean together. With hand-picked priors it’s doable, but the convergence is very slow…

@model function model(c, μ₂, σ)
    tₘₐₓ = length(c)
    # prior
    τ ~ Uniform(1, tₘₐₓ) 
    μ₁ ~ Normal(45, σ)
    # likelihood
    for t in 1:tₘₐₓ
        if t < τ        
            c[t] ~ Normal(μ₁, σ)
        else
            c[t] ~ Normal(μ₂, σ)
        end
    end
end

# generate data
τ, μ₁, μ₂, σ = 350, 45, 30, 4
c = vcat(rand(Normal(μ₁,σ), τ), 
         rand(Normal(μ₂,σ), 1000-τ))

chain = sample(model(c, μ₂, σ), HMC(0.005, 10), 5000)
plot(chain)

And the convergence gets unworkably slow with priors that aren’t hand-picked.

So it looks like the problem is not the change point / step function alone, but estimating it in conjunction with one or both means. I still don’t quite get the intuition behind the problem, but I certainly get why the sigmoid solution would work. I’ll be going forward with that… thanks again for the thorough reply 🙂

It’s also a bit odd… if you standardise the data as in @BradGroff’s example, it works fine. But if you instead try to adjust the priors to the scale of the data, that does not work well, even with a ton of samples.

@model function changepoint(c)
    tₘₐₓ = length(c)
    spec = 0.01
    # priors
    μ_1 ~ Normal(mean(c), std(c)*2)
    μ_2 ~ Normal(mean(c), std(c)*2)
    σ ~ TruncatedNormal(0, std(c)*2, 0, std(c)*20)
    # σ ~ Uniform(0, std(c)*20)   # alternative prior tried
    τ ~ Uniform(1, tₘₐₓ)
    # likelihood
    for t in 1:length(c)
        switch = sigm(τ, spec, t)   # sigm as defined in @BradGroff's post above
        z = (1-switch) * μ_1 + switch * μ_2
        c[t] ~ Normal(z, σ)
    end
end

# generate data
τ_true, μ₁, μ₂, σ_true = 350, 45, 30, 4
c = vcat(rand(Normal(μ₁,σ_true), τ_true), 
         rand(Normal(μ₂,σ_true), 1000-τ_true))

chain = sample(changepoint(c), HMC(0.005, 10), 10000)
plot(chain)

EDIT: Seems to work way better using NUTS, compared to HMC.

This is to be expected, as inference in a model with very wide priors is more difficult for HMC. You might want to switch to NUTS, but in any case a model reparameterisation would be a good idea here.
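
For example, one reparameterisation in the spirit of the standardization above would be to sample unit-scale parameters and rescale them inside the model by the data’s mean and standard deviation (just a sketch, reusing the sigm function defined earlier; untested):

using Statistics

@model function changepoint_reparam(c)
    tₘₐₓ = length(c)
    m, s = mean(c), std(c)
    spec = 0.01
    # unit-scale parameters; the actual means and σ are obtained by rescaling
    α₁ ~ Normal(0, 2)
    α₂ ~ Normal(0, 2)
    σ_std ~ truncated(Normal(1, 2), 0, 20)
    τ ~ Uniform(1, tₘₐₓ)
    μ_1 = m + s * α₁
    μ_2 = m + s * α₂
    σ = s * σ_std
    # likelihood with a smooth switch, as before
    for t in 1:tₘₐₓ
        switch = sigm(τ, spec, t)
        c[t] ~ Normal((1 - switch) * μ_1 + switch * μ_2, σ)
    end
end

chain = sample(changepoint_reparam(c), NUTS(), 2000)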
