Factor mixture model in Turing - too slow?

I have been trying to make a factor mixture model in Turing, but have run into some performance issues.

To see where the slowdown comes from, I started by splitting the model into two separate models: a simple factor model, and a simple mixture model.

The simple factor model (1 factor, 200 observations, each with 3 dimensions) takes under 6s to do 2000 NUTS iterations. A simple mixture model (2 Gaussian components, 200 observations) takes under 2s for the same 2000 iterations (total iterations, including burn-in).

So, perfectly fast. (Well, using ForwardDiff. With ReverseDiff and no caching, those numbers go to 59s for the factor model and 57s for the mixture; with the cache, I get 9s and 10s respectively, which is also extremely good.)
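
For reference, the “cache” here is Turing’s compiled ReverseDiff tape, toggled like this:

using Turing, ReverseDiff

Turing.setadbackend(:reversediff)
Turing.setrdcache(true)   # compile and reuse the tape; false gives the uncached timings above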

However, when I try to run the factor mixture model (which is basically a straightforward combination of the two simpler models, with the big change that it is a mixture of multivariate normals), it takes ~15 minutes to do 2000 NUTS iterations. ReverseDiff (without caching) takes 14 minutes just to sample 200 iterations. And this is a toy problem on a smallish data set.

So, it seems there’s something about the factor mixture model below that’s making it take a couple of orders of magnitude longer than one might (naively) expect, relative to the underlying models. Since it’s fitting a mixture of multivariates, I expect it to be substantially slower, but this seems excessive…

My code is below. A few things to note:

(1) The recommended logsumexp from StatsFuns doesn’t play well with the caching ReverseDiff backend, so I use a simpler version (not as fast, but it works with the cache).

(2) The usual MvNormal distribution requires a positive-definite (PD) covariance matrix, but NUTS/HMC regularly proposes parameter values whose implied covariance comes out numerically non-PD. I use PositiveFactorizations to relax this, which also means I don’t need an if check that would cause issues with caching. Seems to work well; there’s a small toy example right below.
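
To make the point in (2) concrete, here’s a tiny toy example (a made-up matrix, nothing from the model):

using LinearAlgebra, PositiveFactorizations

B = [1.0 2.0; 2.0 1.0]   # symmetric but not positive definite (eigenvalues 3 and -1)

# cholesky(Symmetric(B)) would throw a PosDefException here;
# cholesky(Positive, Symmetric(B)) instead factors a nearby positive (semi)definite
# matrix, so no isposdef() branch is needed inside the model.
F = cholesky(Positive, Symmetric(B))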

Any idea why this code is so slow? I have looked at @code_warntype, and nothing jumps out at me (admittedly, I find the output of that for Turing models a bit hard to understand). I tried to avoid operations that are known to be slow, and the individual models are quite fast.

All this is on Julia 1.5.4.

I’m happy to post the separate factor and mixture models, in case anyone’s curious about those.
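
In the meantime, a minimal 2-component univariate Gaussian mixture in Turing looks roughly like this (a generic sketch for context only, not my exact benchmark code; the full factor mixture model I’m asking about follows right after):

using Turing

# Generic sketch of a 2-component Gaussian mixture
# (uses the slow_logsumexp helper defined further down)
@model simple_mixture(y) = begin
    mu ~ filldist(Normal(0, 10), 2)
    sd ~ filldist(truncated(Normal(0, 5), 0, Inf), 2)
    # Two mixture proportions from a single Beta (see the follow-up about Beta vs Dirichlet)
    w1 ~ Beta(1, 1)
    w = [w1, 1 - w1]
    for yn in y
        Turing.@addlogprob! slow_logsumexp([logpdf(Normal(mu[k], sd[k]), yn) + log(w[k]) for k in 1:2])
    end
end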

using Turing, BenchmarkTools, ReverseDiff, PositiveFactorizations, DistributionsAD, LinearAlgebra  # LinearAlgebra for Symmetric, Diagonal, cholesky

# Generate fake data - one factor, 2 clusters, 3 factor loadings
fake_factor = hcat(rand(Normal(-1, 0.5), 1, 100), rand(Normal(1, 0.5), 1, 100))
data_noise = 0.1
fake_data = [0., 0.5, -0.5] .+ [0.8, 0.6, 0.9] * fake_factor .+ rand(Normal(0, data_noise), 3, 200)


# A version of logsumexp that works with caching
function slow_logsumexp(x::Vector{T}) where T<:Real
    A = maximum(x)
    log(mapreduce(y -> exp(y - A), +, x)) + A
end


# Get the LL of a single (multivariate) observation, 
# given a mixture of multivariate normal distributions.
# mu, sigmaChol, and weight are vectors of length K.
# sigmaChol is the cholesky factor of the model-implied covariance.
function logpdf_mix(y, mu, sigmaChol, weight)
    slow_logsumexp(map((m, s, w) -> logpdf(DistributionsAD.TuringDenseMvNormal(m, s), y) + log(w), mu, sigmaChol, weight))
end 


# Given a P x N matrix of data, get the LL 
# for all observations
function ll_mixture(y, mu, sigmaChol, weight)
    mapreduce(x->logpdf_mix(x, mu, sigmaChol, weight), +, eachcol(y))
end

# data = observed matrix of data
# F = number of factors (only use one with this code)
# K = number of mixture components
@model mixture_efa(data, F, K) = begin
    P, N = size(data)

    # Factor loadings. 
    # Model identification achieved by constraining the average loading to 1
    # and the free loadings to be positive (to keep things simple)
    loadings_part ~ filldist(truncated(Normal(0, 5), 0, Inf), P-1, F)
    loadings_fixed = P .- sum(loadings_part; dims=1)
    loadings_full = vcat(loadings_part, loadings_fixed)

    # Item intercepts. 
    # Model identification achieved by making intercepts sum to 0
    int_part ~ filldist(Normal(0, 5), P-1)
    int_fixed = -sum(int_part)
    int_full = vcat(int_part, int_fixed)

    # Item uniquenesses
    err_sd ~ filldist(truncated(Normal(0, 5), 0, Inf), P)

    # Mean and variance of the latent factor
    latent_mu ~ filldist(Normal(0, 10), F, K)
    latent_var ~ filldist(truncated(Normal(0, 10), 0, Inf), F, K)

    # Mixture weights
    w ~ Dirichlet(K, 1)

    # The model-implied mean and covariance of the observed data for 
    # each of the K mixture components
    mu = map(k -> int_full .+ loadings_full * view(latent_mu, :, k), 1:K)
    Sigma = map(k -> cholesky(Positive, Symmetric(loadings_full * view(latent_var, :, k) * loadings_full' .+ Diagonal(err_sd .^ 2))), 1:K)

    # the LL of the observed data, given the mixture components
    Turing.@addlogprob! ll_mixture(data, mu, Sigma, w)

end

# Instantiate the model: 1 latent factor, 2 mixture components
mixture_test = mixture_efa(fake_data, 1, 2)

Turing.setadbackend(:forwarddiff)
@time m1 = sample(mixture_test, NUTS(100, 0.65), MCMCThreads(), 100, 4)

Looking at the code a bit more, I think I found one reason for the slowness, and it’s something I should have caught before I posted:

In the non-factor mixture model (the fast one), I actually used a Beta prior to get the two proportions, whereas I used a Dirichlet in the factor mixture model.

I had looked at the @code_warntype output before, but I somehow missed the fact that the weight variable w was showing up as type Any. That doesn’t happen when I use a Beta prior, but it does happen when I use the Dirichlet.

If I change the code above to use a Beta, the speed does improve a lot: ReverseDiff + caching takes 420s (7 min) to do 2000 iterations, and the samples look great.
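
Concretely, the change inside the model is along these lines (only valid for K == 2; a sketch rather than an exact copy of what I ran):

    # Mixture weights (K == 2 only): Beta-based replacement for the Dirichlet
    # w ~ Dirichlet(K, 1)      # old version: w gets inferred as Any
    w1 ~ Beta(1, 1)            # weight of the first component
    w = [w1, 1 - w1]           # the second component takes the remainder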

ForwardDiff’s performance is rather weird. In a good number of runs (all after a short initial run, of course), it took anywhere between 200 and 700s, but only the long runs (600-700s) actually produced good samples. The quicker ForwardDiff runs all had one thread producing weird/wrong samples, along with a lot of numerical errors during sampling.

I don’t know why the Dirichlet isn’t producing a type-stable variable in the model. Fixing it does seem to improve performance a lot (~30% in ForwardDiff, if I only look at the “correct” runs).

However, I still hope there are things I can improve in the model. 7-11 minutes to handle a rather small problem (200 observations and 12 parameters) seems a bit slow, especially since the other individual models were each so very fast.

Edit:
Apparently, the type instability of the Dirichlet in Turing is a known (if maybe not well-known) issue, although in my case I don’t think it’s the whole problem.
