I have been trying to make a factor mixture model in Turing, but have run into some performance issues.
To see where the slowdown comes from, I started by splitting the model into two separate models: a simple factor model, and a simple mixture model.
The simple factor model (1 factor, 200 observations, each with 3 dimensions) takes under 6s to do 2000 NUTS iterations. A simple mixture model (2 Gaussian components, 200 observations) takes under 2s to do 2000 NUTS iterations (total iterations, including burn-in).
So, perfectly fast. (Well, using ForwardDiff. If I use ReverseDiff without caching, those numbers go to 59s for the factor model and 57s for the mixture; if I use the cache, I get 9s/10s respectively, which is also extremely good.)
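For concreteness, the standalone mixture model is along these lines (a rough sketch rather than my exact benchmark code; simple_mixture and y_fake are just illustrative names):

using Turing, StatsFuns

# Rough sketch of the standalone two-component Gaussian mixture benchmark
@model simple_mixture(y, K) = begin
    mu ~ filldist(Normal(0, 10), K)
    sd ~ filldist(truncated(Normal(0, 5), 0, Inf), K)
    w ~ Dirichlet(K, 1)
    for n in eachindex(y)
        # marginalize over the component indicator for observation n
        Turing.@addlogprob! logsumexp(logpdf.(Normal.(mu, sd), y[n]) .+ log.(w))
    end
end

y_fake = vcat(rand(Normal(-1, 0.5), 100), rand(Normal(1, 0.5), 100))
@time sample(simple_mixture(y_fake, 2), NUTS(1000, 0.65), 2000)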
However, when I try to run code for the factor mixture model (which is basically a straightforward combination of the two simpler models, with the big change that it is a mixture of multivariates), it takes ~15 minutes to do 2000 NUTS iterations. ReverseDiff (without caching) takes 14 minutes just to sample 200 iterations. And this is a toy problem on a smallish data set.
So, it seems there’s something about the factor mixture model below that’s making it take a couple of orders of magnitude longer than one might (naively) expect, relative to the underlying models. Since it’s fitting a mixture of multivariates, I expect it to be substantially slower, but this seems excessive…
My code is below. A few things to note:
(1) The recommended logsumexp from StatsFuns doesn't play well with the caching ReverseDiff backend, so I use a simpler version (not as fast, but it works with the cache).
(2) The usual MvNormal distribution requires a PD matrix, but NUTS/HMC often seem to have a little trouble with that, producing non-PD matrices with some regularity. I use PositiveFactorizations to relax things a bit, so I don't need an if check that would cause issues with caching. Seems to work well (see the small illustration just below).
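To make (2) concrete, here's a tiny standalone illustration (the matrix A below is made up, not from the model): with a covariance that's numerically just outside PD, the plain cholesky would throw a PosDefException, while cholesky(Positive, ...) still returns a usable factor.

using LinearAlgebra, PositiveFactorizations

# A made-up covariance with a tiny negative eigenvalue (the kind of
# numerically non-PD matrix that HMC proposals occasionally produce)
c = 1.0 + 1e-8
A = Symmetric([1.0 c; c 1.0])

isposdef(A)                # false, so cholesky(A) would throw a PosDefException
F = cholesky(Positive, A)  # still returns a valid Cholesky factor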
Any idea why this code is so slow? I have looked at @code_warntype, and nothing jumps out at me (admittedly, I find its output for Turing models a bit hard to understand). I tried to avoid operations that are known to be slow, and the individual models are quite fast.
All this is on Julia 1.5.4.
I’m happy to post the separate factor and mixture models, in case anyone’s curious about those.
using Turing, BenchmarkTools, ReverseDiff, PositiveFactorizations, DistributionsAD, LinearAlgebra
# Generate fake data - one factor, 2 clusters, 3 factor loadings
fake_factor = hcat(rand(Normal(-1, 0.5), 1, 100), rand(Normal(1, 0.5), 1, 100))
data_noise = 0.1
fake_data = [0., 0.5, -0.5] .+ [0.8, 0.6, 0.9] * fake_factor .+ rand(Normal(0, data_noise), 3, 200)
# A version of logsumexp that works with caching
function slow_logsumexp(x::Vector{T}) where T<:Real
    A = maximum(x)
    log(mapreduce(y -> exp(y - A), +, x)) + A
end
# Get the LL of a single (multivariate) observation,
# given a mixture of multivariate normal distributions.
# mu, sigmaChol, and weight are vectors of length K.
# sigmaChol is the cholesky factor of the model-implied covariance.
function logpdf_mix(y, mu, sigmaChol, weight)
    slow_logsumexp(map((m, s, w) -> logpdf(DistributionsAD.TuringDenseMvNormal(m, s), y) + log(w), mu, sigmaChol, weight))
end
# Given a P x N matrix of data, get the LL
# for all observations
function ll_mixture(y, mu, sigmaChol, weight)
    mapreduce(x -> logpdf_mix(x, mu, sigmaChol, weight), +, eachcol(y))
end
# data = observed matrix of data
# F = number of factors (only use one with this code)
# K = number of mixture components
@model mixture_efa(data, F, K) = begin
    P, N = size(data)
    # Factor loadings.
    # Model identification achieved by constraining the average loading to 1
    # and constraining the free loadings to be positive (to keep things simple)
    loadings_part ~ filldist(truncated(Normal(0, 5), 0, Inf), P-1, F)
    loadings_fixed = P .- sum(loadings_part; dims=1)
    loadings_full = vcat(loadings_part, loadings_fixed)
    # Item intercepts.
    # Model identification achieved by making intercepts sum to 0
    int_part ~ filldist(Normal(0, 5), P-1)
    int_fixed = -sum(int_part)
    int_full = vcat(int_part, int_fixed)
    # Item uniquenesses
    err_sd ~ filldist(truncated(Normal(0, 5), 0, Inf), P)
    # Mean and variance of the latent factor
    latent_mu ~ filldist(Normal(0, 10), F, K)
    latent_var ~ filldist(truncated(Normal(0, 10), 0, Inf), F, K)
    # Mixture weights
    w ~ Dirichlet(K, 1)
    # The model-implied mean and covariance of the observed data for
    # each of the K mixture components
    mu = map(k -> int_full .+ loadings_full * view(latent_mu, :, k), 1:K)
    Sigma = map(k -> cholesky(Positive, Symmetric(loadings_full * view(latent_var, :, k) * loadings_full' .+ Diagonal(err_sd .^ 2))), 1:K)
    # The LL of the observed data, given the mixture components
    Turing.@addlogprob! ll_mixture(data, mu, Sigma, w)
end
# Instantiate the model and run the sampler
mixture_test = mixture_efa(fake_data, 1, 2)
Turing.setadbackend(:forwarddiff)
@time m1 = sample(mixture_test, NUTS(100, 0.65), MCMCThreads(), 100, 4)
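In case it helps with diagnosing this: the raw (non-AD) likelihood can be timed on its own by plugging in parameter values close to the ones used to generate the fake data. A sketch (the test_* names are just for this illustration):

# Time a single evaluation of the raw mixture likelihood outside Turing,
# using made-up parameter values that match the data-generating process above
test_loadings = reshape([0.8, 0.6, 0.9], 3, 1)
test_int      = [0.0, 0.5, -0.5]
test_err      = fill(0.1, 3)
test_w        = [0.5, 0.5]
test_mu       = [test_int .+ test_loadings * [m] for m in (-1.0, 1.0)]
# both components get latent variance 0.25 (sd 0.5), as in the fake data
test_Sigma    = [cholesky(Positive, Symmetric(test_loadings * [0.25] * test_loadings' .+ Diagonal(test_err .^ 2))) for _ in 1:2]
@btime ll_mixture($fake_data, $test_mu, $test_Sigma, $test_w)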