I made some progress by reviewing the Julia source code for TuringGLM (a single-language implementation for the win!). Following their example, I removed the loop, and I also standardized the predictors (as done in the rstanarm example). I changed the priors on the group-level SDs to Exponential(0.5), as in the Stan MRP example, which also addresses the speed concern with truncated(). (Caveat: Distributions.jl's `Exponential(θ)` is parameterized by scale, so `Exponential(0.5)` has mean 0.5. Stan's `exponential(0.5)` is parameterized by rate and has mean 2, which corresponds to `Exponential(2)` in Julia.)
I’m not sure I’ve fully optimized the ported example (it is still noticeably slower than the rstanarm version), but it’s getting there. For anyone interested, the goal is to use MRP in a multinomial choice model of vehicle ownership, stratified against both the human and vehicle populations.
For reference, my updated model function and call are:
# Bernoulli-logit regression of `y` on the predictor matrix `X` with six non-centered
# varying intercepts: state, ethnicity, education, male:ethnicity, education:age and
# education:ethnicity.
#
# Positional arguments:
#   *_idx : integer group index per observation; used directly to index the matching
#           `zⱼ_*` vector (assumed 1-based — confirm with the data preparation).
#   y     : outcome vector, modelled as BernoulliLogit.
#   X     : predictor matrix, one row per observation (multiplied by `β`).
#   n     : number of observations; not used inside the model body.
# Keywords:
#   n_gr_* : number of levels per grouping factor. Defaults to the largest index seen,
#            so every index stays in bounds even when some level has no observations
#            (`length(unique(idx))` undercounts in that case and causes a BoundsError).
#            Pass it explicitly if the poststratification frame has more levels than
#            the survey.
#   predictors  : number of columns of `X` (length of `β`).
#   standardize : if true, standardize `X` with TuringGLM's `standardize_predictors`.
@model function varying_intercept(
    state_idx, eth_idx, educ_idx, male_eth_idx, educ_age_idx, educ_eth_idx, y, X, n;
    n_gr_state=maximum(state_idx), n_gr_eth=maximum(eth_idx), n_gr_educ=maximum(educ_idx),
    n_gr_educ_age=maximum(educ_age_idx), n_gr_educ_eth=maximum(educ_eth_idx),
    n_gr_male_eth=maximum(male_eth_idx), predictors=size(X, 2), standardize::Bool=false,
)
    if standardize
        # TuringGLM helper; μ_X/σ_X (presumably column means/SDs) are not used below.
        # NOTE(review): the model body runs on every log-density evaluation, so this
        # re-standardizes X each time; standardizing once before building the model
        # avoids that repeated work.
        μ_X, σ_X, X = standardize_predictors(X)
    end

    # priors
    inter ~ Normal(0, 1)                 # overall intercept
    β ~ filldist(Normal(), predictors)   # population-level coefficients

    # priors for the SDs of the varying intercepts.
    # NOTE(review): Distributions.jl's Exponential(θ) takes a *scale* θ (mean 0.5 here);
    # Stan's exponential(0.5) takes a *rate* (mean 2). Confirm which prior is intended.
    τ_state ~ Exponential(0.5)      # group-level SD, state
    τ_eth ~ Exponential(0.5)        # group-level SD, ethnicity
    τ_educ ~ Exponential(0.5)       # group-level SD, education
    τ_male_eth ~ Exponential(0.5)   # group-level SD, male:ethnicity
    τ_educ_age ~ Exponential(0.5)   # group-level SD, education:age
    τ_educ_eth ~ Exponential(0.5)   # group-level SD, education:ethnicity

    # standard-normal offsets (non-centered parameterization)
    zⱼ_state ~ filldist(Normal(), n_gr_state)
    zⱼ_eth ~ filldist(Normal(), n_gr_eth)
    zⱼ_educ ~ filldist(Normal(), n_gr_educ)
    zⱼ_male_eth ~ filldist(Normal(), n_gr_male_eth)
    zⱼ_educ_age ~ filldist(Normal(), n_gr_educ_age)
    zⱼ_educ_eth ~ filldist(Normal(), n_gr_educ_eth)

    # per-observation varying intercepts: scale each group's offset by its SD
    αⱼ_state = τ_state .* zⱼ_state[state_idx]
    αⱼ_eth = τ_eth .* zⱼ_eth[eth_idx]
    αⱼ_educ = τ_educ .* zⱼ_educ[educ_idx]
    αⱼ_male_eth = τ_male_eth .* zⱼ_male_eth[male_eth_idx]
    αⱼ_educ_age = τ_educ_age .* zⱼ_educ_age[educ_age_idx]
    αⱼ_educ_eth = τ_educ_eth .* zⱼ_educ_eth[educ_eth_idx]

    # linear predictor on the logit scale
    v = inter .+ αⱼ_state .+ αⱼ_eth .+ αⱼ_educ .+ αⱼ_male_eth .+ αⱼ_educ_age .+
        αⱼ_educ_eth .+ X * β

    # likelihood
    y ~ arraydist(LazyArray(@~ BernoulliLogit.(v)))
end;
# Build the inputs from `cces_df`, fit `varying_intercept` with NUTS on 4 threaded
# chains, and print the posterior summary.

# Number of observations (passed through to the model as `n`).
n, _ = size(cces_df)
y = cces_df[:, :abortion];
X = Matrix{Float64}(cces_df[:, [:repvote, :region_NE, :region_S, :region_W, :male]]);

# Integer group indices, one per observation.
state_idx = cces_df[:, :state_idx];
eth_idx = cces_df[:, :eth_idx];
educ_idx = cces_df[:, :educ_idx];
male = cces_df[:, :male];  # not used below (`:male` already enters through X)
educ_age_idx = cces_df[:, :educ_age_idx];
educ_eth_idx = cces_df[:, :educ_eth_idx];
male_eth_idx = cces_df[:, :male_eth_idx];

Nadapt = 1000  # NUTS adaptation steps
delta = 0.99   # NUTS target acceptance rate; a high target forces smaller step sizes,
               # which can slow sampling

# The positional order must match the model signature exactly:
# state, eth, educ, male_eth, educ_age, educ_eth. Swapping two index vectors raises
# no error but silently mislabels the corresponding τ/zⱼ parameters in the chain.
model_intercept = varying_intercept(
    state_idx, eth_idx, educ_idx, male_eth_idx, educ_age_idx, educ_eth_idx, y, X, n;
    standardize=true,
)
chain_intercept = sample(model_intercept, NUTS(Nadapt, delta), MCMCThreads(), 2000, 4)
println(DataFrame(summarystats(chain_intercept)))