Ordinal probit model with Turing - difficulty sampling

I am trying to use Turing to estimate an ordinal multinomial model with partially fixed thresholds/cutpoints, and I’m wondering whether I’m doing something wrong: Metropolis-Hastings sampling works great, but the autodiff-based samplers (NUTS/HMC) can’t seem to complete a single iteration, even when left running overnight, and they also report errors when searching for starting values for the sampler.

Do any Turing users have ideas about what’s going wrong here, or what I’m doing wrong?

I’m wondering if it’s some kind of type instability (perhaps from combining the fixed cutpoints with the estimated ones into a single vector), but (a) I had a lot of trouble finding the correct syntax to check for that in current versions of Turing, and (b) even if that is the problem, I still wouldn’t know how to solve it. 🙂

EDIT: Just after posting, I found the correct code for checking type instability. The lines involving the multinomial distribution and the ps vector show up in red.

Here’s an MWE:

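using Turing, Distributions
using LinearAlgebra  # diagm, used for the MH proposal covariance
using Random         # Random.default_rng, used in the type-stability check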

responses = [1, 5, 23, 21, 27] # Number of people choosing responses 1-5
total_n = 77

@model function ordered_multinomial(responses, counts, threshold_lower, threshold_upper)

    # model mean and std of the latent normal distribution
    mu ~ Normal(0, 2)
    sigma ~ Exponential(2)

    # Model cutpoints/thresholds. There are K-1 thresholds 
    # (where K is the number of response options).
    # Here, two thresholds are fixed (to allow estimating the 
    # latent normal location and scale). To estimate the two 
    # "free" thresholds, I partition the space between the upper 
    # and lower thresholds into three parts, using a Dirichlet prior
    threshold_partitions ~ Dirichlet(3, 1)

    # Make a vector of all thresholds/cutpoints, combining known 
    # (fixed) values and estimated values
    threshold_points = cumsum(threshold_partitions) * (threshold_upper - threshold_lower) .+ threshold_lower
    c = vcat(-Inf, threshold_lower, threshold_points, Inf)

    # Using the latent normal distribution and set of cutpoints, get probabilities of 
    # people choosing a given response option (1-5)
    ps = [cdf(Normal(0, 1), (i[2] - mu) / sigma) - cdf(Normal(0, 1), (i[1] - mu) / sigma ) for i in zip(c[1:(end-1)], c[2:end])]

    # Model likelihood of response pattern given set of response 
    # probabilities and a total sample size
    responses ~ Multinomial(counts, ps)
end

mh_chain = sample(ordered_multinomial(responses, total_n, 1.5, 4.5), MH(diagm(ones(4) .* [0.1, 0.1, 0.02, 0.02])), 10000)

# NUTS/HMC versions just sit there, even overnight. CPU chart shows it doing SOMETHING, but no idea what
# nuts_chain = sample(ordered_multinomial(responses, total_n, 1.5, 4.5), NUTS(), 1000)

# Check type instability
model = ordered_multinomial(responses, total_n, 1.5, 4.5)

@code_warntype model.f(
    model,
    Turing.VarInfo(model),
    Turing.SamplingContext(
        Random.default_rng(), Turing.SampleFromPrior(), Turing.DefaultContext()
    ),
    model.args...,
)

# Multinomial distribution shows up in red, although 
# I'm not sure if that's the source of the problem I'm having
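To make the threshold construction above concrete, here is a package-free sketch that runs the same two lines on a fixed, hand-picked partition vector in place of a Dirichlet(3, 1) draw (the values are illustrative and chosen to be exact in floating point). Because the partitions sum to 1, the last cumulative sum lands on threshold_upper, so c contains both fixed thresholds and has six entries, i.e. five intervals for the five response options.

```julia
# Illustrative fixed simplex; in the model this comes from Dirichlet(3, 1).
threshold_partitions = [0.25, 0.25, 0.5]
threshold_lower, threshold_upper = 1.5, 4.5

# Same construction as in the model: scale the cumulative partitions into
# [threshold_lower, threshold_upper], then pad with -Inf/Inf.
threshold_points = cumsum(threshold_partitions) * (threshold_upper - threshold_lower) .+ threshold_lower
c = vcat(-Inf, threshold_lower, threshold_points, Inf)

println(c)              # [-Inf, 1.5, 2.25, 3.0, 4.5, Inf]
println(length(c) - 1)  # 5 intervals, one per response option
```

With a random Dirichlet draw the last cumulative sum is only approximately 1, so the upper threshold is reproduced up to floating-point rounding.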

I did manage to remove all type instabilities. I resolved the instability in the ps vector by replacing the comprehension with a simple loop, e.g.:

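# T is presumably the model's element-type argument (::Type{T} = Float64); d holds the 6 cutpoints
ps = Array{T}(undef, 5)
for i in 1:5
    ps[i] = cdf(Normal(0, 1), (d[i+1] - mu) / sigma) - cdf(Normal(0, 1), (d[i] - mu) / sigma)
end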
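The loop above relies on T and d from the revised model, which isn't shown. As a standalone sketch under that assumption, the hypothetical helper interval_probs below (not from the thread) derives the element type from its inputs with promote_type, so it would widen to an autodiff number type if one were passed in, and fills a preallocated ps the same way. It still evaluates the CDF at the -Inf/Inf endpoints, so it only addresses type stability, not the sampling problem.

```julia
using Distributions

# Hypothetical helper mirroring the loop above: preallocate ps with an element
# type wide enough for the cutpoints, mu and sigma, then fill it per interval.
function interval_probs(d::AbstractVector, mu::Real, sigma::Real)
    T = promote_type(eltype(d), typeof(mu), typeof(sigma))
    ps = Vector{T}(undef, length(d) - 1)
    for i in 1:(length(d) - 1)
        ps[i] = cdf(Normal(0, 1), (d[i+1] - mu) / sigma) - cdf(Normal(0, 1), (d[i] - mu) / sigma)
    end
    return ps
end

# Cutpoints including the -Inf/Inf endpoints, as in the loop above.
ps = interval_probs([-Inf, 1.5, 2.25, 3.0, 4.5, Inf], 3.0, 1.0)
println(ps)
println(sum(ps))  # telescoping sum, approximately 1.0
```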

However, removing the type instabilities doesn’t fix the underlying problem: NUTS/HMC are still so slow that they appear frozen.

I also forgot to post my version numbers: this is on Julia 1.10.2 and Turing 0.30.7.

I’m not exactly sure what the problem is, but it seems to be a numerical issue that breaks the autodiff. When I recode the model as follows, it samples nicely for me:

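# Only the finite cutpoints: the fixed lower threshold plus the scaled partitions
c = vcat(threshold_lower, threshold_points)
# Latent normal CDF evaluated at each finite cutpoint
tmp = cdf.(Normal(0, 1), (c .- mu) ./ sigma)
# Pad the CDF values with 0 and 1 (instead of evaluating at -Inf/Inf); diff gives the 5 probabilities
ps = diff(vcat(0, tmp, 1))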
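A plausible mechanism for the broken gradients, offered as an illustration rather than a diagnosis confirmed in this thread: with c padded by -Inf and Inf, the chain rule for the two outer intervals multiplies a zero density by an infinite inner derivative, and in IEEE arithmetic 0 * Inf is NaN. The package-free snippet below writes out the derivative of Φ((c - mu) / sigma) with respect to sigma by hand (stdnormal_pdf is a helper defined here, not a library function); autodiff can end up forming the same kind of product. The recoded version sidesteps this by padding the probabilities with 0 and 1 rather than evaluating the CDF at infinite arguments.

```julia
# Standard normal density, written out to avoid any package dependency.
stdnormal_pdf(z) = exp(-z^2 / 2) / sqrt(2π)

# Hand-written derivative of Φ((c - mu) / sigma) with respect to sigma:
# pdf(z) * dz/dsigma, where z = (c - mu) / sigma and dz/dsigma = -(c - mu) / sigma^2.
dcdf_dsigma(c, mu, sigma) = stdnormal_pdf((c - mu) / sigma) * (-(c - mu) / sigma^2)

println(dcdf_dsigma(2.25, 3.0, 1.0))  # finite for a finite cutpoint
println(dcdf_dsigma(Inf, 3.0, 1.0))   # NaN: pdf(Inf) == 0.0 times -Inf
println(dcdf_dsigma(-Inf, 3.0, 1.0))  # NaN: pdf(-Inf) == 0.0 times Inf
```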

Thanks! I suspect you are right about the numeric issues (I got the same results with both forward and reverse diff, fwiw). Your fix works quite well.

Your suggested fix worked great as long as I was modeling a single multinomial observation.

However, looping over a series of such observations (my ultimate use case: I want to model Likert responses across a large set of variables) gives a StackOverflowError.


using Turing

@model function test(x)
    counts = sum(x; dims=1)  # 1×N matrix of per-variable totals
    J, N = size(x)

    # One J-dimensional simplex per column (variable)
    y ~ filldist(Dirichlet(J, 1), N)

    for i in 1:N
        x[:, i] ~ Multinomial(counts[i], y[:, i])
    end
end


responses = [sample(10:25) for i in 1:5, j in 1:3]

m1 = sample(test(responses), HMC(0.01, 10), 1000)

ERROR: StackOverflowError:
     [1] Array
       @ ./boot.jl:479 [inlined]
     [2] Array
       @ ./boot.jl:487 [inlined]
     [3] similar
       @ ./array.jl:421 [inlined]
     [4] _simplex_bijector
       @ ~/.julia/packages/Bijectors/gR4QV/src/bijectors/simplex.jl:22 [inlined]
     [5] transform
       @ ~/.julia/packages/Bijectors/gR4QV/src/bijectors/simplex.jl:16 [inlined]
     [6] with_logabsdet_jacobian(b::Bijectors.SimplexBijector, x::Matrix{Float64})
       @ Bijectors ~/.julia/packages/Bijectors/gR4QV/src/bijectors/simplex.jl:14
     [7] logabsdetjac(b::Bijectors.SimplexBijector, x::Matrix{Float64})
       @ Bijectors ~/.julia/packages/Bijectors/gR4QV/src/interface.jl:119
--- the last 2 lines are repeated 39989 more times ---
 [79986] with_logabsdet_jacobian(b::Bijectors.SimplexBijector, x::Matrix{Float64})
       @ Bijectors ~/.julia/packages/Bijectors/gR4QV/src/bijectors/simplex.jl:14

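The stack trace points to the bijector calculation: Bijectors.SimplexBijector’s with_logabsdet_jacobian and logabsdetjac keep calling each other on a Matrix{Float64} until the stack overflows. And indeed, the model runs just fine with Metropolis-Hastings.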

It appears to be something specific to slicing the array of observations and/or counts. If I reduce the example further to:

@model function test(x)
    counts = vec(sum(x; dims=1))
    J, N = size(x)

    y ~ filldist(Dirichlet(J, 1), N)

    x[:, 1] ~ Multinomial(counts[1], y[:, 1])
end

responses = [sample(10:25) for i in 1:5, j in 1:3]

m1 = sample(test(responses), HMC(0.01, 10), 1000)

I get the same error.
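Regarding the hypothesis that slicing the counts is to blame: the first model indexes the 1×N matrix returned by sum(x; dims=1) linearly, while the reduced model wraps it in vec, and both give the same numbers. The plain-Julia check below uses a fixed, made-up response matrix in place of the random one. Together with the Matrix{Float64} argument in the SimplexBijector frames of the trace, this suggests the matrix-valued filldist(Dirichlet(J, 1), N) is the likelier suspect.

```julia
# Fixed, made-up 5×3 response matrix (response options × variables).
x = [10 12 14;
     11 13 15;
     20 21 22;
     16 17 18;
     25 24 23]

counts_matrix = sum(x; dims = 1)    # 1×3 Matrix, as in the first model
counts_vector = vec(counts_matrix)  # plain Vector, as in the reduced model

println(size(counts_matrix))  # (1, 3)
println(counts_vector)        # [82, 87, 92]
# Linear indexing on the 1×N matrix matches the vector entry by entry.
println(all(counts_matrix[i] == counts_vector[i] for i in 1:3))  # true
```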

I thought I’d post here again before filing an issue on GitHub, in case I’m doing something known to be wrong in my code.
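One workaround that may be worth trying in the meantime, sketched under assumptions and not verified in this thread: the single-observation model, whose Dirichlet is vector-valued, sampled fine after the recode, so declaring one Dirichlet(J, 1) per column inside the loop keeps every simplex a Vector and never hands the SimplexBijector a Matrix. The model name test_loop is hypothetical, and the ::Type{T} = Float64 argument follows Turing’s convention for containers that need to hold autodiff numbers.

```julia
using Turing

# Hypothetical variant of `test`: one vector-valued Dirichlet per column
# instead of a matrix-valued filldist, so each simplex is transformed as a Vector.
@model function test_loop(x, ::Type{T} = Float64) where {T}
    J, N = size(x)
    counts = vec(sum(x; dims = 1))

    # Container for the N simplexes; T lets it hold autodiff numbers.
    y = Vector{Vector{T}}(undef, N)
    for i in 1:N
        y[i] ~ Dirichlet(J, 1)
        x[:, i] ~ Multinomial(counts[i], y[i])
    end
end

responses = rand(10:25, 5, 3)
m2 = sample(test_loop(responses), HMC(0.01, 10), 1000)
```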