Help with Bayesian Latent Variable Model

Hi everyone,

I am currently running into some issues with a latent variable model I am trying to set up. I have already simplified the model a few times, because I first want to get the simplest version working and then build up from there.
In essence, imagine we have a large number of observations, and each observation has multiple different scores. We believe these scores might be highly related and thus stem from some underlying latent variable that can explain certain correlations. That is basically the gist of it. In reality there are more factors, but at the moment I am working with fully simulated data from a fictional model, leaving all real-world complications out of it, and I still cannot get it to work.

I have looked for sources, but it is hard to find one solid, reliable reference, both in the literature and as concrete code examples. Many of the sources I did find also approach the latent variable model slightly differently, which does not give me confidence in picking the right setup.

Code and model

```julia
using Turing, Distributions, MCMCChains, ReverseDiff, Memoization, Random
import LinearAlgebra


function construct_loadings(L, L_d, L_t, nr_latent_factor, nr_manifest_variables)
    idx = 1
    # construct loadings matrix with identification constraints (as specified in 
    # "Bayesian latent variable models for the analysis of experimental psychology data")
    # Since we have Φ=I, we fix F*(F-1)/2 parameters in L (where F is nr_latent_factor)
    # So one column can have no zeros, one column has one zero ... one column must have F-1 zeros.
    # which we achieve, by fixing upper triangular to 0.
    # An additional constraint is that each column must contain a nonzero parameter forced to assume only positive or negative values
    # which is what we do by setting the diagonal to positive values only (L_d)
    for i in 1:nr_manifest_variables
        for j in (i+1):nr_latent_factor
            L[i,j] = 0.0
        end
    end
    # fill the diagonal from L_d and the strictly lower triangle, column by column, from L_t
    for j in 1:nr_latent_factor
        L[j,j] = L_d[j]
        for k in (j+1):nr_manifest_variables
            L[k, j] = L_t[idx]
            idx = idx + 1
        end
    end
    return L
end

# nr_observations: the number of total observations we have
# nr_manifest_variables: for each observation, the number of manifest variables, i.e. observed dimensions we have
# nr_latent_factors: the number of latent factors we expect to be underlying all of the VAS factors
@model function latent_variable_model(observed_variables, nr_observations, nr_manifest_variables, nr_latent_factors, ::Type{T} = Float64) where {T}
    # number of free entries strictly below the diagonal of the loadings matrix
    non_zero = Int(nr_latent_factors*(nr_manifest_variables-nr_latent_factors) + nr_latent_factors*(nr_latent_factors-1)/2)

    # mu_L_t ~ Normal(0, 1)
    # sigma_L_t ~ truncated(Normal(0, 1), 0, Inf)
    # parameters for the lower triangular part of the loadings matrix
    L_t ~ filldist(Normal(0, 1), non_zero)
    # positive parameters for the diagonal of the loading matrix
    L_d ~ filldist(truncated(Normal(0, 1), 0, Inf), nr_latent_factors) 
    L = construct_loadings(Array{T, 2}(undef, nr_manifest_variables, nr_latent_factors), L_d, L_t, nr_latent_factors, nr_manifest_variables)
    # basically just the variation of each of the manifest variables
    psi ~ filldist(truncated(Normal(0, 2), 0, Inf), nr_manifest_variables)
    Q = LinearAlgebra.Diagonal(psi)

    # mean of each of the latent variables
    latent_mu ~ filldist(Normal(0, 1), nr_latent_factors)
    # parameters for the actual latent variables for each observations
    # latent_factors ~ filldist(MvNormal(latent_mu, 1), nr_observations)
    latent_influence = L*latent_mu
    # loop through each test
    for sample_index in 1:nr_observations
        # observed_variables[:, sample_index] ~ MvNormal(latent_influence[:, sample_index], Q)
        observed_variables[:, sample_index] ~ MvNormal(latent_influence, Q)
    end

    return observed_variables
end

nr_observations = 500
nr_manifest_variables = 2
nr_latent_factors = 2

missing_observed_variables = zeros(nr_manifest_variables, nr_observations)
model_simulated = latent_variable_model(missing_observed_variables, nr_observations, nr_manifest_variables, nr_latent_factors)
model_simulated_mis = DynamicPPL.Model{(:observed_variables,)}(:model_simulated_mis, model_simulated.f, model_simulated.args, model_simulated.defaults)
chain_simulated = sample(model_simulated_mis, Prior(), 1)

observed_variables = generated_quantities(model_simulated, chain_simulated)[1]

chain_except_excluded_variables = set_section(chain_simulated, merge(Dict(:internals => chain_simulated.name_map[:internals]),
Dict(variable => namesingroup(chain_simulated, variable) for variable in [:observed_variables])))

model = latent_variable_model(observed_variables, nr_observations, nr_manifest_variables, nr_latent_factors)

n_threads = Threads.nthreads()
n_samples = Int(round(2000/n_threads))  # number of MCMC samples
n_adapt = n_samples
target_accept_ratio = 0.65
NUTS_algorithm = NUTS(n_adapt, target_accept_ratio)

chain, t, bytes, gctime, memallocs = @timed sample(model, NUTS_algorithm, MCMCThreads(), n_samples, n_threads)
```

I have tried many variations, and at this point I am not sure which basis makes the most sense. I also considered explicitly having the actual latent variables as parameters for each observation, but I now realize that does not make sense. In my real model, the mean of the MvNormal likelihood for the observed variables would probably also depend on other factors that influence each observation.
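For comparison, a common alternative to per-observation latent parameters is to integrate them out analytically. The following is a minimal sketch under an assumption about the intended model, namely the standard linear factor model `y = L*η + ε` with `η ~ N(mu, I)` and `ε ~ N(0, Diagonal(psi))`; it is not what the script above does. The helper name `marginal_moments` and the example values are hypothetical.

```julia
using LinearAlgebra

# Hypothetical helper: mean and covariance of y after integrating out η,
# assuming y = L*η + ε, η ~ N(mu, I), ε ~ N(0, Diagonal(psi)).
function marginal_moments(L::AbstractMatrix, mu::AbstractVector, psi::AbstractVector)
    mean_y = L * mu                               # E[y] = L * mu
    cov_y = Symmetric(L * L' + Diagonal(psi))     # Cov[y] = L*L' + Diagonal(psi)
    return mean_y, cov_y
end

# Illustrative values only (lower-triangular loadings with a positive diagonal)
L = [0.3 0.0; -0.2 0.8]
mu = [0.0, 0.4]
psi = [2.5, 0.6]
mean_y, cov_y = marginal_moments(L, mu, psi)
```

In the script above the likelihood is `MvNormal(L*latent_mu, Q)`, so `L` enters only through the product `L*latent_mu`. Under the assumed marginal form, the loadings would also be informed by the covariance of the scores.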

Additionally, with some of my latest attempts certain diagnostics have been improving, but when I look at the results, the estimates still seem further from the true loadings than I would hope, or at least the posterior standard deviations are very large (as in the parameters I attached here). Maybe I am being naive and this is simply the precision one can expect from a latent variable model such as mine. Intuitively, though, it feels like too much variation for such a simple model, so I believe I am missing something, but I am not sure what.

Generated and posterior parameters from above script
| Parameter | Generated (true) value | Posterior mean | std | naive_se | mcse | ess | rhat |
|---|---|---|---|---|---|---|---|
| L_d[1] | 0.2994 | 0.3231 | 0.4091 | 0.0091 | 0.0322 | 64.3507 | 1.0504 |
| L_d[2] | 0.8163 | 0.9319 | 0.7233 | 0.0162 | 0.1276 | 7.1730 | 1.5240 |
| L_t[1] | -0.1774 | -0.0985 | 0.9073 | 0.0203 | 0.1004 | 21.5255 | 1.1299 |
| latent_mu[1] | 0.0196 | -0.0318 | 0.3969 | 0.0089 | 0.0380 | 37.6252 | 1.0836 |
| latent_mu[2] | 0.4481 | 0.5334 | 0.5728 | 0.0128 | 0.0341 | 409.0588 | 1.0174 |
| psi[1] | 2.4593 | 2.2931 | 0.1478 | 0.0033 | 0.0088 | 110.1069 | 1.0401 |
| psi[2] | 0.5688 | 0.5097 | 0.0312 | 0.0007 | 0.0018 | 356.1487 | 1.0289 |

Note that the runtime with 2000 iterations and 500 data points is already quite long, at close to 2 hours. That is not ideal, but it is useful to know if you try to run this script yourself.

I haven’t looked at your model closely, but when I run your example, it evaluates in less than a minute for me. Is this same example taking 2 hours on your machine?

Anyways, since it ran fast for me, I tried increasing n_adapt and n_samples each to 1_000 and target_accept_ratio to 0.8. For this seed, we see a lot of divergent transitions still. They disappear when increasing the target_accept_ratio to 0.95, which is a good sign, but the model, or at least its parameterization, can still likely be improved.

Here’s the state that generated the data:

```
julia> chain_except_excluded_variables |> summarystats
Summary Statistics
    parameters      mean       std   naive_se      mcse       ess      rhat   ess_per_sec 
        Symbol   Float64   Float64    Float64   Missing   Missing   Missing       Missing 

        L_d[1]    0.2609       NaN        NaN   missing   missing   missing       missing
        L_d[2]    1.1341       NaN        NaN   missing   missing   missing       missing
        L_t[1]    1.4643       NaN        NaN   missing   missing   missing       missing
  latent_mu[1]    0.6621       NaN        NaN   missing   missing   missing       missing
  latent_mu[2]   -0.3112       NaN        NaN   missing   missing   missing       missing
        psi[1]    0.8150       NaN        NaN   missing   missing   missing       missing
        psi[2]    2.8243       NaN        NaN   missing   missing   missing       missing
```

And here’s the summary of the posterior samples:

```
julia> chain |> summarystats
Summary Statistics
    parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
        Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

        L_d[1]    0.5023    0.4306     0.0068    0.0112   1419.2850    1.0007        3.0618
        L_d[2]    0.7401    0.5153     0.0081    0.0122   2059.5792    1.0022        4.4431
        L_t[1]    0.4376    0.8619     0.0136    0.0260   1546.2667    1.0012        3.3357
  latent_mu[1]    0.5009    0.4088     0.0065    0.0121    876.1555    1.0045        1.8901
  latent_mu[2]    0.5644    0.7359     0.0116    0.0185   1530.3438    1.0001        3.3014
        psi[1]    0.9105    0.0568     0.0009    0.0012   2402.9814    1.0006        5.1839
        psi[2]    2.5921    0.1601     0.0025    0.0027   3090.7690    1.0002        6.6676
```

If we check using ArviZ.jl, we find that the state that generated the data is within the 94% HDI of the marginal posterior for each variable.
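To make that check concrete, here is a minimal hand-rolled sketch of a highest-density interval test for a single parameter. It does not use the ArviZ.jl API; the helpers `hdi_interval` and `in_hdi` are hypothetical names, and the sketch assumes the marginal posterior is unimodal, so that the shortest interval covering a fraction `prob` of the draws is a sensible HDI.

```julia
# Hypothetical helper: shortest interval that contains a fraction `prob` of the draws.
function hdi_interval(draws::AbstractVector{<:Real}; prob::Real = 0.94)
    sorted = sort(draws)
    n = length(sorted)
    k = ceil(Int, prob * n)                  # number of draws the interval must cover
    # width of every window of k consecutive sorted draws
    widths = [sorted[i + k - 1] - sorted[i] for i in 1:(n - k + 1)]
    i = argmin(widths)                       # narrowest window
    return sorted[i], sorted[i + k - 1]
end

# Hypothetical helper: does the generating value fall inside the interval?
function in_hdi(truth::Real, draws::AbstractVector{<:Real}; prob::Real = 0.94)
    lower, upper = hdi_interval(draws; prob = prob)
    return lower <= truth <= upper
end
```

With a chain from the script, something like `in_hdi(0.2609, vec(chain[Symbol("L_d[1]")].data))` would be the per-parameter check, assuming the draws are extracted that way.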

Hmm, weird, when rerunning the sample I posted it takes about 100 seconds for me. I must have gotten things mixed up, I honestly don’t recall how else that could have happened.

Anyway, yeah, as can be seen in your run as well, the standard deviations are all pretty wide. That is exactly the issue I am encountering. So even though the generating values do fall within the intervals, with these standard deviations far too much remains plausible, including solutions with the factors flipped around.
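The "flipped around" ambiguity can be illustrated numerically. This is a small sketch, assuming the standard factor model in which the loadings enter the implied covariance as `L*L'` (an assumption, since the script above only uses `L` in the mean); the matrices are illustrative values, not estimates.

```julia
using LinearAlgebra

L = [0.3 0.0; -0.2 0.8]                       # illustrative lower-triangular loadings
flip = Diagonal([1.0, -1.0])                  # flips the sign of the second factor
theta = pi / 6
rot = [cos(theta) -sin(theta); sin(theta) cos(theta)]   # rotates the two factors

# Sign flips and rotations of the factors leave L*L' unchanged
same_after_flip = isapprox((L * flip) * (L * flip)', L * L')
same_after_rot = isapprox((L * rot) * (L * rot)', L * L')

# The constraints in construct_loadings rule these alternatives out:
flipped_diag_negative = (L * flip)[2, 2] < 0      # violates the positive diagonal
rotated_upper_nonzero = (L * rot)[1, 2] != 0      # violates the zero upper triangle
```

This is why the positive diagonal and the zeroed upper triangle are imposed: without them, every flipped or rotated loading matrix would explain the covariance equally well.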

I also read somewhere that 2 factors with 2 observables doesn’t really work, and that you need 2 to 3 observables per latent factor. Doing that indeed changed the problem I was experiencing, and it turned into straight-up non-identifiability.
I found multiple Stan posts with similar models as well, but in most threads, when people asked whether the problem was ever fixed, the answer was sadly no. It seems that exploratory factor analysis with MCMC is very tricky for some reason. So I am currently trying confirmatory factor analysis as an alternative, and so far it doesn’t seem to suffer from the same issues, only occasionally from BFMI issues.
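A rough way to see why 2 factors with 2 observables is problematic is the usual counting rule: compare the number of unique covariance entries with the number of free covariance parameters. The sketch below assumes an identity factor covariance, a lower-triangular loading matrix as in `construct_loadings`, and one residual variance per manifest variable; the helper name `efa_degrees_of_freedom` is hypothetical. A non-negative count is only a necessary condition, not a guarantee of identifiability.

```julia
# Hypothetical helper: unique covariance entries minus free covariance parameters
# for M manifest variables and F latent factors.
function efa_degrees_of_freedom(M::Integer, F::Integer)
    moments = div(M * (M + 1), 2)                 # unique entries of an M×M covariance
    loadings = M * F - div(F * (F - 1), 2)        # free loadings once the upper triangle is zeroed
    uniquenesses = M                              # one residual variance per manifest variable
    return moments - (loadings + uniquenesses)
end

df_small = efa_degrees_of_freedom(2, 2)    # the setup in the script above
df_larger = efa_degrees_of_freedom(6, 2)   # three observables per factor
```

A negative result means the covariance alone cannot pin down all parameters, which is the case for the 2-by-2 setup.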