Getting Turing to work with a hierarchical multinomial regression model

Hi everyone,

I’m sorry if this question is too specific. I tried following the advice in previous similar threads but couldn’t get anything to work, so I figured I should ask directly.

For context, I’m modeling 5000 Hangman games, in which players revealed a hidden word by guessing letters. Each player played games in two conditions, and my goal is to compare their behaviour across the two conditions. My approach is to fit a multinomial logistic regression model that predicts letter selections from three predictors: prior, post and eig. Each predictor is represented as a vector of length 26 (one entry per letter). Since the same letter cannot be selected more than once in a game, I also have a mask vector with 1s for letters that can still be selected and 0s for letters that can’t.
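A minimal sketch of the masked choice probabilities for a single guess, in plain Julia with no Turing dependency. The function name `masked_softmax` and the coefficient and predictor values are hypothetical. Adding `log.(mask)` to the logits sends unavailable letters to `-Inf`, which gives them probability exactly 0 and is mathematically equivalent to taking a full softmax and then renormalizing over the available letters.

```julia
using Random

# Probability of choosing each of the 26 letters on one guess, restricted to
# letters that are still available (mask[j] == 1). Adding log(mask) to the
# logits sends unavailable letters to -Inf, so they get probability exactly 0;
# this equals a full softmax renormalized over the available letters.
function masked_softmax(eta::AbstractVector{<:Real}, mask::AbstractVector{<:Integer})
    length(eta) == length(mask) ||
        throw(DimensionMismatch("eta and mask must have the same length"))
    any(==(1), mask) || throw(ArgumentError("at least one letter must be available"))
    z = eta .+ log.(mask)      # log(1) = 0.0, log(0) = -Inf
    w = exp.(z .- maximum(z))  # subtract the max for numerical stability
    return w ./ sum(w)
end

# Hypothetical coefficients and predictor values for a single guess
rng = MersenneTwister(1)
b_prior, b_post, b_eig = 0.5, 2.0, 1.0
prior, post, eig = rand(rng, 26), rand(rng, 26), rand(rng, 26)
mask = ones(Int, 26)
mask[[1, 5, 20]] .= 0          # say 'a', 'e' and 't' were already guessed

eta = b_prior .* prior .+ b_post .* post .+ b_eig .* eig
p = masked_softmax(eta, mask)
```

Working with masked logits rather than multiplying probabilities by the mask also avoids dividing by a sum that can underflow when the available letters have very low probability under the unmasked softmax.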

Below is my model. The arguments are:

  • priors: n by 26 Float64 matrix
  • posts: n by 26 Float64 matrix
  • eig: n by 26 Float64 matrix
  • masks: n by 26 Int64 matrix (0 or 1)
  • subj: length-n Int64 vector (the subject number)
  • condition: length-n Int64 vector (0 or 1)
  • y: length-n Int64 vector (between 1 and 26, the index of the selected letter)
```julia
using Turing
using StatsFuns: softmax

@model function logistic_regression(priors, posts, eig, masks, subj, condition, y)

    n = size(priors, 1) # number of rows
    length(y) == length(subj) == length(condition) ==
        size(posts, 1) == size(eig, 1) == size(masks, 1) == n ||
        throw(DimensionMismatch("number of observations is not equal"))

    n_subj = maximum(subj)  # assumes subjects are coded 1, 2, ..., n_subj

    # Hyperparameters for main effects
    mu_prior ~ Normal(0, 10)
    sigma_prior ~ Gamma(2, 1)

    mu_post ~ Normal(0, 10)
    sigma_post ~ Gamma(2, 1)

    mu_eig ~ Normal(0, 10)
    sigma_eig ~ Gamma(2, 1)

    # Hyperparameters for condition effects
    mu_prior_delta ~ Normal(0, 10)
    sigma_prior_delta ~ Gamma(2, 1)

    mu_post_delta ~ Normal(0, 10)
    sigma_post_delta ~ Gamma(2, 1)

    mu_eig_delta ~ Normal(0, 10)
    sigma_eig_delta ~ Gamma(2, 1)

    # Subject-level parameters: main effects
    coef_prior ~ filldist(Normal(mu_prior, sigma_prior), n_subj)
    coef_post ~ filldist(Normal(mu_post, sigma_post), n_subj)
    coef_eig ~ filldist(Normal(mu_eig, sigma_eig), n_subj)

    # Subject-level parameters: condition effects
    coef_prior_delta ~ filldist(Normal(mu_prior_delta, sigma_prior_delta), n_subj)
    coef_post_delta ~ filldist(Normal(mu_post_delta, sigma_post_delta), n_subj)
    coef_eig_delta ~ filldist(Normal(mu_eig_delta, sigma_eig_delta), n_subj)

    for i in 1:n
        # Softmax over the 26 letters: main effects plus condition offsets
        v_unmasked = softmax((coef_prior[subj[i]]*priors[i,:] + coef_post[subj[i]]*posts[i,:] + coef_eig[subj[i]]*eig[i,:]) +
            condition[i]*(coef_prior_delta[subj[i]]*priors[i,:] + coef_post_delta[subj[i]]*posts[i,:] + coef_eig_delta[subj[i]]*eig[i,:]))
        # Zero out letters that were already guessed and renormalize
        v = v_unmasked .* masks[i,:] / sum(v_unmasked .* masks[i,:])
        y[i] ~ Categorical(v)
    end
end;
```
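One possible speed-up, not benchmarked on this data: the loop above slices `priors[i,:]` and the other matrices (each slice allocates a copy) and builds one `Categorical` per row. Because `coef*x + condition*(delta*x)` equals `(coef + condition*delta)*x`, all n × 26 logits can be computed at once with broadcasting, and the total log-likelihood follows from a row-wise log-sum-exp. The sketch below is plain Julia; the function name `masked_mnl_loglik` and the simulated data are hypothetical. Inside the model, the loop could then be replaced by something like `Turing.@addlogprob! masked_mnl_loglik(priors, posts, eig, masks, subj, condition, y, coef_prior, coef_post, coef_eig, coef_prior_delta, coef_post_delta, coef_eig_delta)`, with the caveat that `y` is then no longer declared as an observed variable with `~`.

```julia
using Random

"""
    masked_mnl_loglik(priors, posts, eig, masks, subj, condition, y,
                      b_prior, b_post, b_eig, d_prior, d_post, d_eig)

Total log-likelihood of all observed letter choices. `b_*` are per-subject
main effects and `d_*` per-subject condition offsets (vectors indexed by `subj`).
"""
function masked_mnl_loglik(priors, posts, eig, masks, subj, condition, y,
                           b_prior, b_post, b_eig, d_prior, d_post, d_eig)
    # Effective per-row coefficients: main effect plus condition offset (length n)
    c_prior = b_prior[subj] .+ condition .* d_prior[subj]
    c_post  = b_post[subj]  .+ condition .* d_post[subj]
    c_eig   = b_eig[subj]   .+ condition .* d_eig[subj]
    # n × 26 logits; each row is scaled by its own coefficients, and
    # unavailable letters are pushed to -Inf via log(0)
    eta = c_prior .* priors .+ c_post .* posts .+ c_eig .* eig .+ log.(masks)
    # Row-wise log-sum-exp, stabilized by the row maximum (n × 1)
    m = maximum(eta; dims = 2)
    lse = m .+ log.(sum(exp.(eta .- m); dims = 2))
    # Logit of the letter actually chosen in each row
    chosen = eta[CartesianIndex.(axes(eta, 1), y)]
    return sum(chosen .- vec(lse))
end

# Hypothetical simulated data with the shapes described in the post
rng = MersenneTwister(42)
n, n_subj = 200, 5
priors, posts, eig = rand(rng, n, 26), rand(rng, n, 26), rand(rng, n, 26)
masks = rand(rng, 0:1, n, 26)
masks[:, 1] .= 1                 # make sure every row has at least one available letter
subj = rand(rng, 1:n_subj, n)
condition = rand(rng, 0:1, n)
y = [rand(rng, findall(==(1), masks[i, :])) for i in 1:n]   # choices among available letters
coefs = [randn(rng, n_subj) for _ in 1:6]                   # hypothetical coefficient values

ll = masked_mnl_loglik(priors, posts, eig, masks, subj, condition, y, coefs...)
```

Whole-array operations like these tend to suit ReverseDiff better than many small per-row operations, but whether this helps enough here would need to be measured.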

I tried fitting the model with both HMC and NUTS, and neither works well. HMC is slow (over 5 hours for one chain of 1500 steps) and returns chains that are stuck at a single set of parameter values, and NUTS is much slower (the current ETA is 8.5 hours to complete a chain of 100 steps). When I try to run more than one chain using MCMCThreads(), sampling doesn’t even start.

I followed previous advice and am using Turing.setrdcache(true) and Turing.setadbackend(:reversediff) but this doesn’t seem to help.
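A general observation rather than a diagnosis of this particular model: hierarchical models written in the centered form used above (`coef_prior ~ filldist(Normal(mu_prior, sigma_prior), n_subj)`, with `sigma_prior` also sampled) often produce a "funnel" posterior in which HMC and NUTS stall or need very small step sizes. A commonly suggested alternative is the non-centered parameterization, which samples standard-normal offsets and rebuilds the coefficients deterministically. In Turing that would look like `z_prior ~ filldist(Normal(0, 1), n_subj)` followed by `coef_prior = mu_prior .+ sigma_prior .* z_prior` (the name `z_prior` is hypothetical). Both forms imply the same prior on the subject-level coefficients, with $\sigma$ the standard deviation as in `Normal(mu, sigma)`:

```latex
\beta_s \sim \mathcal{N}(\mu, \sigma^2)
\quad\Longleftrightarrow\quad
z_s \sim \mathcal{N}(0, 1), \qquad \beta_s = \mu + \sigma z_s,
\qquad s = 1, \dots, n_{\text{subj}}
```

This changes only the geometry the sampler sees, not the model itself; whether it fixes the stuck chains here is untested.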

Any advice would be greatly appreciated. Many thanks!

You might find my post on this topic helpful. I’m working on a hierarchical multinomial problem, as well. I was able to improve the performance with a few tricks from the Turing developers.


Thanks Jason! This does seem very relevant.
In the meantime I found a faster solution that I’m satisfied with: I estimate the MLE coefficients for each subject using optimize, and then do frequentist statistics at the group level. I don’t get posterior probabilities over group-level coefficients, but then again I’m not sure that’s what I’m after in the first place 🙂
Thanks again,
-Matan
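The thread doesn't include the code for this two-stage approach, so the following is only a sketch of what it could look like under my own assumptions: a per-subject negative log-likelihood with six coefficients (three main effects plus three condition offsets), and a one-sample t statistic as one example of a group-level frequentist test. The names `subject_nll` and `one_sample_t` and the simulated data are hypothetical.

```julia
using Random, Statistics

# Negative log-likelihood for ONE subject's guesses. theta holds six
# coefficients: (prior, post, eig) main effects, then their condition offsets.
function subject_nll(theta, priors, posts, eig, masks, condition, y)
    bp, bq, be, dp, dq, de = theta
    # Per-row coefficients: main effect plus offset when condition == 1
    c_prior = bp .+ condition .* dp
    c_post  = bq .+ condition .* dq
    c_eig   = be .+ condition .* de
    # n × 26 logits with unavailable letters at -Inf
    eta = c_prior .* priors .+ c_post .* posts .+ c_eig .* eig .+ log.(masks)
    m = maximum(eta; dims = 2)
    lse = m .+ log.(sum(exp.(eta .- m); dims = 2))
    chosen = eta[CartesianIndex.(axes(eta, 1), y)]
    return -sum(chosen .- vec(lse))
end

# Group level: one-sample t statistic for whether the mean of per-subject
# estimates (e.g. the condition offset on eig) differs from 0.
function one_sample_t(estimates::AbstractVector{<:Real})
    k = length(estimates)
    return mean(estimates) / (std(estimates) / sqrt(k))
end

# Hypothetical data for a single subject
rng = MersenneTwister(3)
n = 150
priors, posts, eig = rand(rng, n, 26), rand(rng, n, 26), rand(rng, n, 26)
masks = rand(rng, 0:1, n, 26)
masks[:, 1] .= 1                 # at least one available letter per guess
condition = rand(rng, 0:1, n)
y = [rand(rng, findall(==(1), masks[i, :])) for i in 1:n]

nll0 = subject_nll(zeros(6), priors, posts, eig, masks, condition, y)
```

Each subject's estimates could then be obtained with Optim.jl, for example `res = optimize(th -> subject_nll(th, priors, posts, eig, masks, condition, y), zeros(6), LBFGS())` followed by `Optim.minimizer(res)`, and the resulting per-subject offsets passed to `one_sample_t`. With `theta = zeros(6)` every available letter is equally likely, so `nll0` equals the sum over guesses of the log of the number of available letters, which makes a handy sanity check.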