Multi-level varying slopes with two clusters, cross classification

More information about centered parameterization

This is equivalent to

symmetrize(A) = Symmetric((A .+ A') ./ 2)
quad_form_diag(M, v) = symmetrize(Diagonal(v) * M * Diagonal(v))

but instead of 4 allocations, it does everything in a single allocation. The symmetrize step projects the matrix onto the nearest exactly symmetric matrix. This is only necessary because one of the functions called inside MvNormal (maybe cholesky) performs a very (perhaps overly) strict check that the covariance matrix is exactly symmetric. If M is a correlation matrix and v is a vector of standard deviations, then quad_form_diag produces the covariance matrix. It may not be the best way to do this numerically.
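
For reference, here is a minimal sketch of what a fused, single-allocation version could look like (my own illustration with a hypothetical name, not the exact implementation referred to above):

using LinearAlgebra

function quad_form_diag_fused(M, v)
    S = similar(M)  # the only array allocation
    @inbounds for j in axes(M, 2), i in axes(M, 1)
        # scale by v[i] * v[j] and symmetrize (average with the transposed entry) in one pass
        S[i, j] = v[i] * v[j] * (M[i, j] + M[j, i]) / 2
    end
    return Symmetric(S)
end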

If it helps, these two lines are more-or-less equivalent:

logit_p = broadcast(treatment, actor, block) do t, a, b
    return @inbounds g[t] + α[t, a] + β[t, b]
end
logit_p = @inbounds getindex.(Ref(g), treatment) .+ getindex.(Ref(α), treatment, actor) .+ getindex.(Ref(β), treatment, block)

but in my view the first is much easier to read. We could also do this with map or a for loop (a loop version is sketched below). I didn’t benchmark map, but the broadcasted version was faster than the for loop; which one is fastest will be AD-specific.
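
For completeness, the explicit loop version might look something like this (a sketch only):

# preallocate; with AD the element type may need more care (broadcast handles that for us)
logit_p = similar(g, length(treatment))
for i in eachindex(treatment, actor, block)
    t, a, b = treatment[i], actor[i], block[i]
    logit_p[i] = g[t] + α[t, a] + β[t, b]
end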

LKJ Cholesky parameterization

You’re absolutely right that something like Stan’s lkj_corr_cholesky should speed things up a lot. MvNormal’s constructor (IIRC) performs a Cholesky decomposition, which is expensive, and the LKJ density computes the determinant of the correlation matrix, which can be computed cheaply from a Cholesky factor (see the small illustration below). Unfortunately, Distributions.jl doesn’t have such a distribution, so we can’t use one in Turing. I took a stab at creating a custom distribution, but I’m bad at using Bijectors, so I wasn’t successful. Perhaps one of the other folks here will have better luck. We definitely should have something like this, and I’d advise opening an issue on the Turing.jl repo to discuss it. We can get something working for Soss.jl, though.
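
To see why the Cholesky factor helps, here is a small standalone example (not part of the model) showing that the log-determinant of a correlation matrix falls out of the diagonal of its Cholesky factor:

using LinearAlgebra

R = [1.0 0.3; 0.3 1.0]  # a 2×2 correlation matrix
U = cholesky(R).U       # upper-triangular Cholesky factor
# logdet(R) needs only the diagonal of U, which is O(n) once U is available
2 * sum(log, diag(U)) ≈ logdet(R)  # true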

LKJ Cholesky model with Soss

MeasureTheory.jl has an LKJL distribution for the lower-triangular Cholesky factor of a correlation matrix drawn from the LKJ distribution. It has some issues, but we can work around them. Soss.jl works with MeasureTheory, so we’ll use Soss.

# using Pkg
# Pkg.add(["Soss", "MeasureTheory", "TuringModels", "CSV", "DataFrames", "DynamicHMC", "LogDensityProblems"])

using Soss, MeasureTheory, LinearAlgebra, Random
using Soss: TransformVariables

# work around LKJL issues, see https://github.com/cscherrer/MeasureTheory.jl/issues/100
# note that we're making LKJL behave like an LKJU
Soss.xform(d::LKJL, _data::NamedTuple=NamedTuple()) = TransformVariables.as(d);
function Base.rand(rng::AbstractRNG, ::Type, d::LKJL{k}) where {k}
    return cholesky(rand(rng, MeasureTheory.Dists.LKJ(k, d.η))).U
end

# get data
using TuringModels, CSV, DataFrames
data_path = joinpath(TuringModels.project_root, "data", "chimpanzees.csv");
df = CSV.read(data_path, DataFrame; delim=';');
df.block_id = df.block;
df.treatment = 1 .+ df.prosoc_left .+ 2 * df.condition;

# build the model (non-centered parameterization of the varying effects)
model = @model actor, block, treatment, n_treatment, n_actor, n_block begin
    # per-treatment standard deviations of the block and actor effects
    σ_block ~ Exponential(1) |> iid(n_treatment)
    σ_actor ~ Exponential(1) |> iid(n_treatment)
    # upper-triangular Cholesky factors of the correlation matrices
    # (LKJL made to behave like an LKJU by the workaround above)
    U_ρ_block ~ LKJL(n_treatment, 2)
    U_ρ_actor ~ LKJL(n_treatment, 2)
    # average log-odds for each treatment
    g ~ Normal(0, 1) |> iid(n_treatment)
    # standardized effects, scaled and correlated below
    z_block ~ Normal(0, 1) |> iid(n_treatment, n_block)
    z_actor ~ Normal(0, 1) |> iid(n_treatment, n_actor)
    β = σ_block .* (U_ρ_block' * z_block)
    α = σ_actor .* (U_ρ_actor' * z_actor)
    p = broadcast(treatment, actor, block) do t, a, b
        return @inbounds logistic(g[t] + α[t, a] + β[t, b])
    end
    # Bernoulli likelihood (Binomial with a single trial)
    pulled_left .~ Binomial.(1, p)
end

# condition the model on our data
params = (
    actor=df.actor,
    block=df.block,
    treatment=df.treatment,
    n_treatment=4,
    n_actor=7,
    n_block=6,
)
model_cond = model(params) | (pulled_left=df.pulled_left,)

# get ready to run DynamicHMC
using DynamicHMC, LogDensityProblems
ad_backend = Val(:ForwardDiff)
ℓ(x) = Soss.logdensity(model_cond, x)
trans = Soss.xform(model_cond)
P = LogDensityProblems.TransformedLogDensity(trans, ℓ)
∇P = LogDensityProblems.ADgradient(ad_backend, P)

rng = MersenneTwister(3702)
nchains = 4
results = map(1:nchains) do _
    @time mcmc_with_warmup(rng, ∇P, 1_000; reporter=ProgressMeterReport())
end;
# get posterior samples
posterior = map(r -> P.transformation.(r.chain), results)

Note that we’re using ForwardDiff here, so we expect it to be slow, but we still see major speed-ups: each chain runs in ~3 minutes. This PR adds ReverseDiff support to LogDensityProblems; swapping in ad_backend = Val(:ReverseDiff) brings the runtime down to ~24s/chain. Not quite Stan’s ~6s/chain on my machine, but not bad.
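
Concretely, with that ReverseDiff-enabled LogDensityProblems, the only lines that change are the AD backend and the gradient wrapper:

ad_backend = Val(:ReverseDiff)  # instead of Val(:ForwardDiff)
∇P = LogDensityProblems.ADgradient(ad_backend, P)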

EDIT: MeasureTheory.Binomial has a parameterization in terms of logitp. If we use that instead of calling logistic, the runtime with ReverseDiff drops to ~16s/chain.
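
A sketch of that change in the model body follows; the Binomial(n=1, logitp=...) keyword form is my assumption about how MeasureTheory spells this parameterization, so check its docs for the exact constructor:

    logit_p = broadcast(treatment, actor, block) do t, a, b
        return @inbounds g[t] + α[t, a] + β[t, b]
    end
    # assumed keyword constructor; replaces the Binomial.(1, p) line above
    pulled_left .~ map(lp -> Binomial(n=1, logitp=lp), logit_p)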

You probably also want to check diagnostics and posterior predictions. Here’s how to get everything you need to use ArviZ.jl:

# extract useful statistics
# adapted from https://github.com/arviz-devs/ArviZ.jl/pull/131
sample_stats = map(results) do r
    map(r.chain, r.tree_statistics) do chn, stat
        term = stat.termination
        return (
            lp=LogDensityProblems.logdensity(P, chn),
            energy=stat.π,
            tree_depth=stat.depth,
            acceptance_rate=stat.acceptance_rate,
            n_steps=stat.steps,
            diverging=term.left == term.right,
            turning=term.left < term.right,
        )
    end
end

# collect posterior predictions and log likelihood
varkeys = keys(posterior[1][1])
model_pred = predictive(model, varkeys...)
posterior_predictive_log_likelihood = map(posterior) do post
    map(post) do draw
        pred = rand(rng, model_pred(merge(params, draw)))
        return (
            pulled_left=pred.pulled_left,
            log_likelihood=logdensity.(Binomial.(1, pred.p), df.pulled_left),
        )
    end
end
posterior_predictive = map(posterior_predictive_log_likelihood) do chn
    map(chn) do draw
        (pulled_left=draw.pulled_left,)
    end
end
log_likelihood = map(posterior_predictive_log_likelihood) do chn
    map(chn) do draw
        (pulled_left=draw.log_likelihood,)
    end
end

# gather samples together for analysis with ArviZ
using ArviZ
idata = from_namedtuple(
    posterior;
    sample_stats=sample_stats,
    posterior_predictive=posterior_predictive,
    log_likelihood=log_likelihood,
    observed_data=(pulled_left=df.pulled_left,),
    dims=Dict(
        :σ_block => [:treatment],
        :σ_actor => [:treatment],
        :g => [:treatment],
        :U_ρ_block => [:treatment, :treatment2],
        :U_ρ_actor => [:treatment, :treatment2],
        :z_block => [:treatment, :block],
        :z_actor => [:treatment, :actor],
    ),
    coords=Dict(:treatment => 1:4, :treatment2 => 1:4, :actor => 1:7, :block => 1:6),
)
summarystats(idata)
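
From here, a couple of follow-up checks you might run with ArviZ.jl (a sketch; see the ArviZ.jl docs for the full API):

loo(idata)         # PSIS-LOO, using the stored log_likelihood group
plot_trace(idata)  # trace plots for the posterior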
