Turing, Ordered Variables, and the Dirichlet Distribution

Hi all,

I am trying to fit some diffusion data which we expect to be generated from two or more Rayleigh distributions.

To make some data for the MWE, take the following example:

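using Distributions

# Simulate a two-component Rayleigh mixture; μ₁ and μ₂ are the Rayleigh scale parameters σ
function get_data(; μ₁ = 1, μ₂ = 2, N = 1000, fraction = 0.5)
    dist1 = Rayleigh(μ₁)
    dist2 = Rayleigh(μ₂)
    N1 = round(Int, fraction * N)
    N2 = N - N1  # guarantees N1 + N2 == N regardless of rounding
    data = vcat(rand(dist1, N1), rand(dist2, N2))
    return data
end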
data = get_data(; fraction = 0.7);
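A side note on parameterization: in Distributions.jl the single argument of Rayleigh is the scale σ, not the mean, so μ₁ = 1 and μ₂ = 2 above are scales, and the component means are σ√(π/2) ≈ 1.25σ. A quick check in plain Julia, sketched without Distributions by drawing Rayleigh variates via inverse-transform sampling (rayleigh_draws is a hypothetical helper, not a library function):

```julia
using Random, Statistics

# Hypothetical helper: if E ~ Exponential(1), then σ * sqrt(2E) has the
# Rayleigh(σ) distribution, since P(X > x) = exp(-x^2 / (2σ^2))
rayleigh_draws(σ, n) = σ .* sqrt.(2 .* randexp(n))

Random.seed!(0)
x = rayleigh_draws(2.0, 200_000)

# The sample mean lands close to 2.0 * sqrt(π / 2) ≈ 2.507, not 2.0
println(mean(x))
```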

The data thus come from a finite mixture of Rayleigh distributions. At first I ran into label switching, so I wanted to apply the ordered vector type used in Stan. The closest equivalent I could find in Turing was this thread, which correctly enforces the ordering, but at the cost of somewhat more complex code (the ordered parameters have to be recovered with generated_quantities).
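The ordering trick itself does not depend on the number of components K: a cumulative sum of strictly positive increments is strictly increasing. A minimal sketch in plain Julia, using randexp from the Random standard library in place of Exponential() draws (ordered_from_increments is a hypothetical name for illustration):

```julia
using Random

# Hypothetical helper: map K positive increments to an increasing vector
ordered_from_increments(Δ::AbstractVector{<:Real}) = cumsum(Δ)

Random.seed!(42)
for K in (2, 3, 5)
    Δ = randexp(K)                  # K iid Exponential(1) draws, all positive
    D = ordered_from_increments(Δ)
    @assert all(diff(D) .> 0)       # D[1] < D[2] < ... < D[K]
end
```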

My function thus became:

using Turing

@model function diffusion_ordered(Δr)

    N = length(Δr)

    # Positive increments; their cumulative sum gives ordered scales D[1] < D[2]
    ΔD ~ filldist(Exponential(), 2)
    D = cumsum(ΔD)

    θ ~ Uniform(0, 1)
    w = [θ, 1 - θ]

    dists = [Rayleigh(d) for d in D]
    Δr ~ filldist(MixtureModel(dists, w), N)
    return (; D)
end

This works well and produces 1000 samples using 4 threads in 2-3 seconds (not counting compilation). The R-hat values also look fine:

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

       ΔD[1]    0.9419    0.0521     0.0008    0.0014   1451.6843    1.0020       41.0069
       ΔD[2]    0.8301    0.0744     0.0012    0.0019   1696.2559    1.0016       47.9155
           θ    0.6139    0.0726     0.0011    0.0021   1380.5319    1.0016       38.9970
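Since the table reports ΔD rather than D, the posterior means of the ordered scales can be read off by hand: expectation is linear, so the mean of D = cumsum(ΔD) is the cumulative sum of the ΔD means. This shortcut holds for means only; the standard deviations of D also depend on the correlation between ΔD[1] and ΔD[2], which is why the per-draw values from generated_quantities are still needed.

```julia
# Posterior means of ΔD copied from the summary table above
ΔD_mean = [0.9419, 0.8301]

# E[cumsum(ΔD)] == cumsum(E[ΔD]) by linearity of expectation
D_mean = cumsum(ΔD_mean)   # approximately [0.9419, 1.7720]
println(D_mean)
```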

However, this does not scale well to more mixture components (e.g. K = 3), because the θ and w definitions are hard-coded for two components. Therefore, I tried a version with w ~ Dirichlet(K, 1) instead:

@model function diffusion_ordered_dirichlet(Δr, K = 2)

    N = length(Δr)

    ΔD ~ filldist(Exponential(), K)
    D = cumsum(ΔD)

    # Flat prior on the weights: uniform over the (K-1)-simplex
    w ~ Dirichlet(K, 1)

    dists = [Rayleigh(d) for d in D]
    Δr ~ filldist(MixtureModel(dists, w), N)
    return (; D)
end

This produces equally well-mixed chains:

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

       ΔD[1]    0.9431    0.0518     0.0008    0.0010   2706.9331    1.0005        8.6276
       ΔD[2]    0.8292    0.0724     0.0011    0.0013   3277.2988    1.0008       10.4455
        w[1]    0.6152    0.0726     0.0011    0.0015   2516.3084    1.0008        8.0201
        w[2]    0.3848    0.0726     0.0011    0.0015   2516.3084    1.0008        8.0201

but it is much, much slower: this time the model takes 80 seconds to fit (ess_per_sec drops from about 39-48 to about 8-10).
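When profiling the slowdown, it may help to have the likelihood term written out explicitly. The statement Δr ~ filldist(MixtureModel(dists, w), N) adds Σᵢ log Σₖ wₖ f(Δrᵢ | Dₖ) to the log joint, where f is the Rayleigh density. A sketch of that sum in plain Julia with a numerically stable log-sum-exp (rayleigh_logpdf and mixture_loglik are hypothetical helpers). Such a function could in principle be added to a model with Turing's @addlogprob! in place of the MixtureModel statement; whether that is actually faster here is untested.

```julia
# Log-density of Rayleigh(σ) at x > 0: log(x) - 2log(σ) - x^2 / (2σ^2)
rayleigh_logpdf(x, σ) = log(x) - 2 * log(σ) - x^2 / (2 * σ^2)

# Hypothetical helper: Σᵢ log(Σₖ w[k] * f(x[i] | D[k])) for a K-component
# Rayleigh mixture; log-sum-exp keeps exp from underflowing
function mixture_loglik(x::AbstractVector, D::AbstractVector, w::AbstractVector)
    logw = log.(w)
    return sum(x) do xi
        terms = logw .+ rayleigh_logpdf.(xi, D)   # log(w[k]) + log f(xi | D[k])
        m = maximum(terms)
        m + log(sum(exp.(terms .- m)))
    end
end

# Usage on a few points with K = 2 ordered scales and weights summing to 1
println(mixture_loglik([0.5, 1.0, 2.5], [1.0, 2.0], [0.7, 0.3]))
```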

I also tried another way:

@model function diffusion_ordered_Δw(Δr, K = 2)

    N = length(Δr)

    ΔD ~ filldist(Exponential(), K)
    D = cumsum(ΔD)

    # Normalized iid Exponential(1) draws: a priori, w ~ Dirichlet(1, ..., 1)
    Δw ~ filldist(Exponential(), K)
    w = Δw / sum(Δw)

    dists = [Rayleigh(d) for d in D]
    Δr ~ filldist(MixtureModel(dists, w), N)
    return (; D, w)
end

which is a lot faster (6 seconds) and also scales to more mixture components. However, it also requires extra code afterwards to extract the w's using generated_quantities.
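The normalization trick in this model gives the same prior on w as Dirichlet(K, 1): if Δw[1], ..., Δw[K] are iid Exponential(1), then Δw / sum(Δw) has a Dirichlet(1, ..., 1) distribution. The overall scale sum(Δw) does not enter the likelihood and is constrained only by its (proper) prior. A quick Monte Carlo check in plain Julia (simplex_from_exponentials is a hypothetical helper), comparing the component means with the Dirichlet(1, ..., 1) value 1/K:

```julia
using Random, Statistics

# Hypothetical helper: project positive draws onto the probability simplex
simplex_from_exponentials(Δw) = Δw ./ sum(Δw)

Random.seed!(1)
K = 3
W = [simplex_from_exponentials(randexp(K)) for _ in 1:100_000]

# Every draw lies on the simplex (sums to 1 up to rounding)
@assert all(w -> isapprox(sum(w), 1.0; atol = 1e-12), W)

# Elementwise mean over draws; Dirichlet(1, 1, 1) has mean 1/K per component
m = mean(W)
println(m)
```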


My two main questions are:

  1. Is this really the recommended way to enforce an ordering constraint? It also requires extracting the relevant parameters afterwards with generated_quantities.
  2. Why is Dirichlet so slow? I know there was a type instability a few years back, but that should have been fixed long ago.

Thanks a lot in advance!
