Porting a multilevel regression with poststratification (MRP) tutorial from rstanarm

I am working to port a tutorial on MRP from rstanarm to Turing. I’m also just getting started on Julia, so it’s an opportunity to learn the language at the same time.

I’ve estimated a few models in Stan and have some familiarity with the details (though I’m by no means a Bayesian master). I am working from this tutorial: Chapter 1 Introduction to Mister P | Multilevel Regression and Poststratification Case Studies

I can get the tutorial model to run without any problems in R; it completes in about 30 minutes. The rstanarm specification is:

# Fit in stan_glmer
fit <- stan_glmer(abortion ~ (1 | state) + (1 | eth) + (1 | educ) + male +
                    (1 | male:eth) + (1 | educ:age) + (1 | educ:eth) +
                    repvote + factor(region),
  family = binomial(link = "logit"),
  data = cces_df,
  prior = normal(0, 1, autoscale = TRUE),
  prior_covariance = decov(scale = 0.50),
  adapt_delta = 0.99,
  refresh = 0,
  seed = 1010)

I’ve run through the data processing and confirmed it produces the same number of groups for each variable with hierarchical parameters. To get oriented, I started from another tutorial that had already been ported to Julia/Turing. I’m not copying the Stan model exactly: prior = normal(0, 1, autoscale = TRUE) rescales the priors based on the predictors, and prior_covariance = decov(scale = 0.50) specifies a prior for the covariance (I think it works out to exponential(0.5), based on this discussion: Prior_covariance for stan_glmer - Other - The Stan Forums).
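
One translation gotcha I want to flag for myself (my own note, not from the tutorial): Stan’s exponential(β) is parameterized by rate, while Distributions.jl’s Exponential(θ) is parameterized by scale (the mean), so the same literal 0.5 gives two different priors:

using Distributions

# Stan: tau ~ exponential(0.5)  =>  rate 0.5, mean 2.0
# The Distributions.jl equivalent uses the scale = 1/rate:
τ_stan_equiv = Exponential(2.0)

# Exponential(0.5) in Julia has mean 0.5, i.e. rate 2.0
τ_julia_literal = Exponential(0.5)

mean(τ_stan_equiv), mean(τ_julia_literal)  # (2.0, 0.5)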

It’s been running for a few hours now and hasn’t finished, so either the model is wrong or my specification is much slower than the Stan one (both could also be true: wrong and slow). My model to reproduce the MRP tutorial is:

@model function varying_intercept(
    state_idx, eth_idx, educ_idx, male_eth_idx, educ_age_idx, educ_eth_idx, y, X, n; 
        n_gr_state=length(unique(state_idx)), n_gr_eth=length(unique(eth_idx)), n_gr_educ=length(unique(educ_idx)),
        n_gr_educ_age=length(unique(educ_age_idx)), n_gr_educ_eth=length(unique(educ_eth_idx)), 
        n_gr_male_eth=length(unique(male_eth_idx)), predictors=size(X, 1)
)
    #priors
    inter ~ Normal(0, 1)         # Overall intercept term
    β ~ filldist(Normal(0, 1), predictors)  # population-level coefficients
    # priors for the SDs of the varying intercepts
    τ_state ~ truncated(Cauchy(0, 2); lower=0)    # group-level SDs intercepts for state
    τ_eth ~ truncated(Cauchy(0, 2); lower=0)    # group-level SDs intercepts for ethnicity
    τ_educ ~ truncated(Cauchy(0, 2); lower=0)    # group-level SDs intercepts for education
    τ_male_eth ~ truncated(Cauchy(0, 2); lower=0)    # group-level SDs intercepts for male:ethnicity
    τ_educ_age ~ truncated(Cauchy(0, 2); lower=0)    # group-level SDs intercepts for education:age
    τ_educ_eth ~ truncated(Cauchy(0, 2); lower=0)    # group-level SDs intercepts for education:ethnicity
    αⱼ_state ~ filldist(Normal(0, τ_state), n_gr_state)       # group-level intercepts for state
    αⱼ_eth ~ filldist(Normal(0, τ_eth), n_gr_eth)       # group-level intercepts for ethnicity
    αⱼ_educ ~ filldist(Normal(0, τ_educ), n_gr_educ)       # group-level intercepts for education
    αⱼ_male_eth ~ filldist(Normal(0, τ_male_eth), n_gr_male_eth)       # group-level intercepts for male:ethnicity interaction
    αⱼ_educ_age ~ filldist(Normal(0, τ_educ_age), n_gr_educ_age)       # group-level intercepts for education:age interaction
    αⱼ_educ_eth ~ filldist(Normal(0, τ_educ_eth), n_gr_educ_eth)       # group-level intercepts for education:ethnicity interaction

    #likelihood
#     v = logistic(α .+ X * β .+ αⱼ[idx])
    for i in 1:n
        v = inter + αⱼ_state[state_idx[i]] + αⱼ_eth[eth_idx[i]]
        v += αⱼ_educ[educ_idx[i]] + αⱼ_male_eth[male_eth_idx[i]]
        v += αⱼ_educ_age[educ_age_idx[i]] + αⱼ_educ_eth[educ_eth_idx[i]]
        v += X[:, i]' * β
        v = logistic(v)
        y[i] ~ Bernoulli(v)
    end
    
end;

I call it with:

# Retrieve the number of observations.
n, _ = size(cces_df)
y = cces_df[:, :abortion];
X = Matrix{Float64}(cces_df[:, [:repvote, :region_NE, :region_S, :region_W, :male]]);
X = X' # transpose so columns are observations; avoids transposing on every loop iteration
state_idx = cces_df[:, :state_idx];
eth_idx = cces_df[:, :eth_idx];
educ_idx = cces_df[:, :educ_idx];
male = cces_df[:, :male];
educ_age_idx = cces_df[:, :educ_age_idx];
educ_eth_idx = cces_df[:, :educ_eth_idx];
male_eth_idx = cces_df[:, :male_eth_idx];

Nadapt = 1000
delta = 0.99 # MRP tutorial says they had to increase adapt_delta to 0.99 to avoid divergent transitions

model_intercept = varying_intercept(
        state_idx, eth_idx, educ_idx, male_eth_idx, educ_age_idx, educ_eth_idx, y, X, n
)
chain_intercept = sample(model_intercept, NUTS(Nadapt, delta), MCMCThreads(), 2000, 4)
println(DataFrame(summarystats(chain_intercept)))

One immediate improvement is to change the AD backend in Turing from ForwardDiff.jl (the default) to ReverseDiff.jl:

using Turing, ReverseDiff

Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
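
(The cached tape is reused across gradient evaluations; that’s safe for a model like this, where the control flow doesn’t depend on the parameter values.)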

See also Automatic Differentiation

In my experience Turing will still be slower than Stan for these types of models where the number of parameters is large.

We had some discussions in the past which might also be helpful to you:

It looks like you put together a good MRE for the problem.

I’ll go through these threads and see what improvements I can make in the model.

Pretty sure the truncated priors are one really slow thing in Turing.

I’ve implemented reverse-mode AD. @dlakelan, is there anything I can do about the truncated priors in Turing?

I’m trying to adapt the solution from @torfjelde given in the linked thread from @p-gw to my problem. I have the following (I have to keep the loop, since I build up v from indexed vectors with multiple observations per index/individual):

    inter ~ Normal(0, 1)
    v_vect = ~ filldist(Normal(0, 1), n) # trying to make a generic vector of distributions to fill with loop result. 
    for i in 1:n
        v = inter # v has additional terms in the full model - e.g., v = inter .+ αⱼ_state[state_idx[i]] .+ αⱼ_eth[eth_idx[i]]
        v_vect[i] = v
    end
    y ~ arraydist(BroadcastArray(BernoulliLogit, v_vect));

This gives me: nested task error: MethodError: no method matching ~(::DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}) at the line where I generate v_vect. (Looking at it again, I suspect the stray = before the ~ is the culprit: Julia parses v_vect = ~ filldist(...) as the unary bitwise-not operator applied to the distribution, which has no such method, rather than as a tilde statement.)
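
Writing out what I was actually after: I think the container should just be a plain preallocated vector, not something sampled. A sketch under that assumption (group terms elided as in my snippet above):

    inter ~ Normal(0, 1)
    # Plain container for the linear predictor; taking the element type from
    # inter keeps it AD-friendly (it becomes a tracked/dual type as needed).
    v_vect = Vector{typeof(inter)}(undef, n)
    for i in 1:n
        v = inter # v has additional terms in the full model
        v_vect[i] = v
    end
    y ~ arraydist(BroadcastArray(BernoulliLogit, v_vect));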

I think I need to look closely at the Stan code that rstanarm generates in the MRP tutorial, and at the output of TuringGLM.jl (which, as I recall, can’t handle my full model yet because it requires this many varying-intercept terms, but I could be wrong).

I was able to make some progress by reviewing the Julia source code for TuringGLM (single-language implementation for the win!). Based on their example I was able to remove the loop, and I now standardize the predictors (as the rstanarm example does). Their formulation uses the non-centered parameterization: sample standard-normal z’s and scale them by τ, rather than sampling Normal(0, τ) intercepts directly, which tends to behave much better under NUTS. I also changed my priors for the SDs to Exponential(0.5) (keeping the rate-vs-scale gotcha from earlier in mind), which matches the MRP example’s Stan code and also addresses the truncated() speed concern.
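
For anyone who wants to skip the TuringGLM dependency, here is a minimal hand-rolled stand-in for how I call standardize_predictors below (the real TuringGLM function may differ in its details):

using Statistics

# Center and scale each predictor column; return the means and SDs so that
# new data (e.g., the poststratification frame) can be transformed identically.
function standardize_predictors(X::AbstractMatrix)
    μ_X = mean(X; dims=1)
    σ_X = std(X; dims=1)
    X_std = (X .- μ_X) ./ σ_X
    return μ_X, σ_X, X_std
end

One thing I may revisit: standardizing inside the model body means it gets recomputed on every evaluation, so doing it once outside the model would be marginally cheaper.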

I’m not sure I’ve fully optimized the ported example (it’s still noticeably slower than the rstanarm version), but it’s getting there. For anyone interested, the eventual goal is to use MRP in a multinomial choice model of vehicle ownership, stratified against both the human and vehicle populations.

For reference, my model function/call are now updated to:

@model function varying_intercept(
    state_idx, eth_idx, educ_idx, male_eth_idx, educ_age_idx, educ_eth_idx, y, X, n; 
        n_gr_state=length(unique(state_idx)), n_gr_eth=length(unique(eth_idx)), n_gr_educ=length(unique(educ_idx)),
        n_gr_educ_age=length(unique(educ_age_idx)), n_gr_educ_eth=length(unique(educ_eth_idx)), 
        n_gr_male_eth=length(unique(male_eth_idx)), predictors=size(X, 2), standardize::Bool=false
)
    
    if standardize
        μ_X, σ_X, X = standardize_predictors(X) # Use the TuringGLM function
    end
    
    #priors
    inter ~ Normal(0, 1)         # Overall intercept term
    β ~ filldist(Normal(), predictors)  # population-level coefficients
    # priors for the SDs of the varying intercepts
    τ_state ~ Exponential(0.5)    # group-level SDs intercepts for state
    τ_eth ~ Exponential(0.5)    # group-level SDs intercepts for ethnicity
    τ_educ ~ Exponential(0.5)    # group-level SDs intercepts for education
    τ_male_eth ~ Exponential(0.5)    # group-level SDs intercepts for male:ethnicity
    τ_educ_age ~ Exponential(0.5)    # group-level SDs intercepts for education:age
    τ_educ_eth ~ Exponential(0.5)    # group-level SDs intercepts for education:ethnicity
    zⱼ_state ~ filldist(Normal(), n_gr_state)       # standardized (non-centered) group effects for state
    zⱼ_eth ~ filldist(Normal(), n_gr_eth)       # standardized group effects for ethnicity
    zⱼ_educ ~ filldist(Normal(), n_gr_educ)       # standardized group effects for education
    zⱼ_male_eth ~ filldist(Normal(), n_gr_male_eth)       # standardized group effects for male:ethnicity interaction
    zⱼ_educ_age ~ filldist(Normal(), n_gr_educ_age)       # standardized group effects for education:age interaction
    zⱼ_educ_eth ~ filldist(Normal(), n_gr_educ_eth)       # standardized group effects for education:ethnicity interaction
    αⱼ_state = τ_state .* getindex.((zⱼ_state,), state_idx)       # scaled state intercepts, expanded per observation
    αⱼ_eth = τ_eth .* getindex.((zⱼ_eth,), eth_idx)       # scaled ethnicity intercepts, expanded per observation
    αⱼ_educ = τ_educ .* getindex.((zⱼ_educ,), educ_idx)       # scaled education intercepts, expanded per observation
    αⱼ_male_eth = τ_male_eth .* getindex.((zⱼ_male_eth,), male_eth_idx)       # scaled male:ethnicity intercepts, expanded per observation
    αⱼ_educ_age = τ_educ_age .* getindex.((zⱼ_educ_age,), educ_age_idx)       # scaled education:age intercepts, expanded per observation
    αⱼ_educ_eth = τ_educ_eth .* getindex.((zⱼ_educ_eth,), educ_eth_idx)       # scaled education:ethnicity intercepts, expanded per observation
    v = inter .+ αⱼ_state .+ αⱼ_eth .+ αⱼ_educ .+ αⱼ_male_eth .+ αⱼ_educ_age .+ αⱼ_educ_eth .+ X * β
    #likelihood
    y ~ arraydist(LazyArray(@~ BernoulliLogit.(v)))
    
end;
# Retrieve the number of observations.
n, _ = size(cces_df)
y = cces_df[:, :abortion];
X = Matrix{Float64}(cces_df[:, [:repvote, :region_NE, :region_S, :region_W, :male]]);
state_idx = cces_df[:, :state_idx];
eth_idx = cces_df[:, :eth_idx];
educ_idx = cces_df[:, :educ_idx];
male = cces_df[:, :male];
educ_age_idx = cces_df[:, :educ_age_idx];
educ_eth_idx = cces_df[:, :educ_eth_idx];
male_eth_idx = cces_df[:, :male_eth_idx];

Nadapt = 1000
delta = 0.99

model_intercept = varying_intercept(
        state_idx, eth_idx, educ_idx, male_eth_idx, educ_age_idx, educ_eth_idx, y, X, n, standardize = true
)
chain_intercept = sample(model_intercept, NUTS(Nadapt, delta), MCMCThreads(), 2000, 4)
println(DataFrame(summarystats(chain_intercept)))
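
The next step on my end is the actual poststratification. Here is a sketch of how I expect it to work with Turing’s predict; it is untested, and it assumes a poststratification table poststrat_df with one row per cell, a population count column :n, and the same *_idx columns used in fitting:

# The PS predictors must be standardized with the *training* means/SDs,
# so standardize outside the model here instead of passing standardize = true.
cols = [:repvote, :region_NE, :region_S, :region_W, :male]
μ_X, σ_X, _ = standardize_predictors(Matrix{Float64}(cces_df[:, cols]))
X_ps = (Matrix{Float64}(poststrat_df[:, cols]) .- μ_X) ./ σ_X

# Instantiate the model on the PS cells with the outcome missing, then draw
# from the posterior predictive using the fitted chain.
model_ps = varying_intercept(
    poststrat_df.state_idx, poststrat_df.eth_idx, poststrat_df.educ_idx,
    poststrat_df.male_eth_idx, poststrat_df.educ_age_idx, poststrat_df.educ_eth_idx,
    missing, X_ps, nrow(poststrat_df)
)
preds = predict(model_ps, chain_intercept)  # 0/1 Bernoulli draws per cell

# The mean over draws is each cell's posterior P(abortion = 1); the MRP
# estimate is the population-weighted average over cells.
cell_p = vec(mean(Array(preds); dims=1))
mrp_estimate = sum(cell_p .* poststrat_df.n) / sum(poststrat_df.n)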

I ran my code above through a proper benchmark and Turing actually ran in a time comparable to rstanarm (about 45 minutes). I’ve marked my last post as the solution; any further improvements would simply add an extra speed boost.
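
(For anyone repeating the comparison: given the runtime I just timed the single sampling call rather than reaching for BenchmarkTools, something like:)

stats = @timed sample(model_intercept, NUTS(Nadapt, delta), MCMCThreads(), 2000, 4)
println("sampling took $(round(stats.time / 60; digits=1)) minutes")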
