Porting the multilevel regression and poststratification (MRP) tutorial from rstanarm

I made some progress by reviewing the Julia source code for TuringGLM (single-language implementation for the win!). Following its example, I removed the loop and standardized the predictors (as the rstanarm example does). I also switched the priors on the group-level SDs to Exponential(0.5), as in the Stan MRP example, which also addresses the truncated() speed concern.
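One detail worth checking when copying priors from Stan: Stan's `exponential(λ)` is parameterized by a rate λ (mean 1/λ), whereas Distributions.jl's `Exponential(θ)` takes a scale θ (mean θ). So Stan's `exponential(0.5)` corresponds to `Exponential(2)` in Julia, while Julia's `Exponential(0.5)` matches Stan's `exponential(2)`. The sketch below compares the two log densities in plain Julia; `stan_exponential_lpdf`, `julia_exponential_logpdf` and `rate_to_scale` are hypothetical helpers, not part of either library.

```julia
# Stan's exponential(λ): rate parameterization, log p(x) = log(λ) - λx
stan_exponential_lpdf(x::Real, λ::Real) = log(λ) - λ * x

# Distributions.jl's Exponential(θ): scale parameterization, log p(x) = -log(θ) - x/θ
julia_exponential_logpdf(x::Real, θ::Real) = -log(θ) - x / θ

# Converting a Stan rate into a Distributions.jl scale is just a reciprocal
rate_to_scale(λ::Real) = inv(λ)

x = 1.3
# Stan's exponential(0.5) matches Exponential(2.0), not Exponential(0.5)
println(stan_exponential_lpdf(x, 0.5) ≈ julia_exponential_logpdf(x, rate_to_scale(0.5)))  # true
println(stan_exponential_lpdf(x, 0.5) ≈ julia_exponential_logpdf(x, 0.5))                 # false
```

With Distributions.jl loaded, `mean(Exponential(0.5))` returns `0.5`, a quick way to confirm the scale convention. Whether the port should use `Exponential(2)` depends on what the Stan example actually specifies.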

I’m not sure I’ve fully optimized the port (it’s still noticeably slower than the rstanarm version), but it’s getting there. For anyone interested, the goal is to use MRP in a multinomial choice model of vehicle ownership, poststratified against both the human and vehicle populations.
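For readers new to MRP, the poststratification step is a population-weighted average of the model's cell-level predictions: with cells j, population counts N_j and predicted probabilities π_j, the estimate is Σ_j N_j π_j / Σ_j N_j, computed once per posterior draw. A minimal sketch in plain Julia, assuming the predictions are already arranged as a draws × cells matrix and the counts come from a census-style table; the function name and the numbers are hypothetical.

```julia
# Poststratify: population-weighted average of cell-level predictions.
# P is a (draws × cells) matrix of predicted probabilities, one row per posterior draw;
# N holds the population count of each cell (e.g. from a census table).
function poststratify(P::AbstractMatrix{<:Real}, N::AbstractVector{<:Real})
    size(P, 2) == length(N) ||
        throw(DimensionMismatch("P has $(size(P, 2)) cells but N has $(length(N))"))
    return (P * N) ./ sum(N)   # one poststratified estimate per draw
end

# Toy example: 2 posterior draws, 2 cells with 100 and 300 people
P = [0.2 0.6;
     0.4 0.8]
N = [100, 300]
est = poststratify(P, N)   # ≈ [0.5, 0.7]
```

Subgroup (e.g. state-level) estimates apply the same formula to that subgroup's cells only, e.g. `poststratify(P[:, cells], N[cells])` where `cells` indexes the cells belonging to one state.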

For reference, my updated model function and call are:

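using Turing, LazyArrays, DataFrames
using TuringGLM: standardize_predictors  # only needed when standardize = true

@model function varying_intercept(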
    state_idx, eth_idx, educ_idx, male_eth_idx, educ_age_idx, educ_eth_idx, y, X, n; 
        n_gr_state=length(unique(state_idx)), n_gr_eth=length(unique(eth_idx)), n_gr_educ=length(unique(educ_idx)),
        n_gr_educ_age=length(unique(educ_age_idx)), n_gr_educ_eth=length(unique(educ_eth_idx)), 
        n_gr_male_eth=length(unique(male_eth_idx)), predictors=size(X, 2), standardize::Bool=false
)
    
    if standardize
        μ_X, σ_X, X = standardize_predictors(X) # Use the TuringGLM function
    end
    
    #priors
    inter ~ Normal(0, 1)         # Overall intercept term
    β ~ filldist(Normal(), predictors)  # population-level coefficients
    # priors for the SDs of the varying intercepts
    τ_state ~ Exponential(0.5)    # group-level SDs intercepts for state
    τ_eth ~ Exponential(0.5)    # group-level SDs intercepts for ethnicity
    τ_educ ~ Exponential(0.5)    # group-level SDs intercepts for education
    τ_male_eth ~ Exponential(0.5)    # group-level SDs intercepts for male:ethnicity
    τ_educ_age ~ Exponential(0.5)    # group-level SDs intercepts for education:age
    τ_educ_eth ~ Exponential(0.5)    # group-level SDs intercepts for education:ethnicity
    # non-centered parameterization: standard-normal offsets for each group level
    zⱼ_state ~ filldist(Normal(), n_gr_state)       # offsets for state
    zⱼ_eth ~ filldist(Normal(), n_gr_eth)       # offsets for ethnicity
    zⱼ_educ ~ filldist(Normal(), n_gr_educ)       # offsets for education
    zⱼ_male_eth ~ filldist(Normal(), n_gr_male_eth)       # offsets for male:ethnicity interaction
    zⱼ_educ_age ~ filldist(Normal(), n_gr_educ_age)       # offsets for education:age interaction
    zⱼ_educ_eth ~ filldist(Normal(), n_gr_educ_eth)       # offsets for education:ethnicity interaction
    # scale the offsets by their SDs and look up each observation's group level
    αⱼ_state = τ_state .* zⱼ_state[state_idx]       # state intercept for each observation
    αⱼ_eth = τ_eth .* zⱼ_eth[eth_idx]       # ethnicity intercept for each observation
    αⱼ_educ = τ_educ .* zⱼ_educ[educ_idx]       # education intercept for each observation
    αⱼ_male_eth = τ_male_eth .* zⱼ_male_eth[male_eth_idx]       # male:ethnicity intercept for each observation
    αⱼ_educ_age = τ_educ_age .* zⱼ_educ_age[educ_age_idx]       # education:age intercept for each observation
    αⱼ_educ_eth = τ_educ_eth .* zⱼ_educ_eth[educ_eth_idx]       # education:ethnicity intercept for each observation
    v = inter .+ αⱼ_state .+ αⱼ_eth .+ αⱼ_educ .+ αⱼ_male_eth .+ αⱼ_educ_age .+ αⱼ_educ_eth .+ X * β
    #likelihood
    y ~ arraydist(LazyArray(@~ BernoulliLogit.(v)))
    
end;
# Retrieve the number of observations.
n, _ = size(cces_df)
y = cces_df[:, :abortion];
X = Matrix{Float64}(cces_df[:, [:repvote, :region_NE, :region_S, :region_W, :male]]);
state_idx = cces_df[:, :state_idx];
eth_idx = cces_df[:, :eth_idx];
educ_idx = cces_df[:, :educ_idx];
male = cces_df[:, :male];
educ_age_idx = cces_df[:, :educ_age_idx];
educ_eth_idx = cces_df[:, :educ_eth_idx];
male_eth_idx = cces_df[:, :male_eth_idx];

Nadapt = 1000
delta = 0.99

model_intercept = varying_intercept(
        state_idx, eth_idx, educ_idx, male_eth_idx, educ_age_idx, educ_eth_idx, y, X, n;  # same order as the model signature
        standardize = true
)
chain_intercept = sample(model_intercept, NUTS(Nadapt,delta), MCMCThreads(), 2000, 4)
println(DataFrame(summarystats(chain_intercept)))
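Two data-dependent pieces of the model could be computed once before sampling instead of inside it. First, the body of a Turing `@model` runs on every log-density evaluation, so the `standardize = true` branch repeats the same standardization at each step; doing it once up front gives the same matrix. Second, the `length(unique(idx))` defaults equal the number of groups only when the codes are contiguous 1 to K with every level observed; if a level never appears in the sample, the zⱼ vector comes out too short and indexing throws a `BoundsError`, which `maximum(idx)` avoids. A sketch using only the Statistics standard library; `standardize_columns` and `n_groups` are hypothetical helpers (not TuringGLM functions), and the column mean / sample SD convention is assumed to match the TuringGLM helper.

```julia
using Statistics

# Standardize each column once, outside the model: returns column means,
# column (corrected) standard deviations, and the standardized matrix.
function standardize_columns(X::AbstractMatrix{<:Real})
    μ = mean(X; dims=1)
    σ = std(X; dims=1)
    return μ, σ, (X .- μ) ./ σ
end

# Number of levels for 1-based integer group codes. Unlike length(unique(idx)),
# this stays correct when some level never appears in the sample.
n_groups(idx::AbstractVector{<:Integer}) = maximum(idx)

X = [1.0 0.0;
     2.0 1.0;
     3.0 1.0]
μ_X, σ_X, Xs = standardize_columns(X)

state_idx = [1, 3, 3]                 # level 2 is never observed
println(length(unique(state_idx)))    # 2: a zⱼ vector this long cannot be indexed at 3
println(n_groups(state_idx))          # 3
```

The standardized `Xs` could then be passed with `standardize = false`, and the group counts as keywords such as `n_gr_state = n_groups(state_idx)`. Keeping `μ_X` and `σ_X` matters for MRP, since the poststratification table's predictors must be put on the same scale before computing cell predictions.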