More information about centered parameterization
This is equivalent to

```julia
symmetrize(A) = Symmetric((A .+ A') ./ 2)
quad_form_diag(M, v) = symmetrize(Diagonal(v) * M * Diagonal(v))
```

but instead of 4 allocations, it does it all in one allocation. The `symmetrize` operation projects the matrix to the nearest exactly symmetric matrix. This is only necessary because one of the functions within `MvNormal` (maybe `cholesky`) performs a very (perhaps overly) strict check for exact symmetry of the covariance matrix. If `M` is a correlation matrix and `v` is a vector of standard deviations, then `quad_form_diag(M, v)` produces the covariance matrix. It may not be the best way to do it numerically.
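For example, here's a small usage sketch (made-up numbers, and assuming the `symmetrize`/`quad_form_diag` definitions above are in scope):

```julia
# usage sketch: build a covariance matrix from a correlation matrix and
# standard deviations, then pass it to MvNormal (illustrative values only)
using LinearAlgebra, Distributions

ρ = [1.0 0.3; 0.3 1.0]        # correlation matrix
σ = [0.5, 2.0]                # standard deviations
Σ = quad_form_diag(ρ, σ)      # covariance matrix, exactly symmetric
d = MvNormal(zeros(2), Σ)     # passes MvNormal's strict symmetry check
```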
If it helps, these two snippets are more-or-less equivalent:

```julia
logit_p = broadcast(treatment, actor, block) do t, a, b
    return @inbounds g[t] + α[t, a] + β[t, b]
end

logit_p = @inbounds getindex.(Ref(g), treatment) .+ getindex.(Ref(α), treatment, actor) .+ getindex.(Ref(β), treatment, block)
```
but in my view the first is much easier to read. We could also do this with `map` or a `for` loop. I didn’t check `map`, but the broadcasted version was faster than the `for` loop. Which is fastest will be AD-specific.
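For reference, the `map` and `for`-loop versions I have in mind look roughly like this (a sketch, assuming the same `g`, `α`, `β`, and index vectors as above):

```julia
# map version: same computation as the broadcast
logit_p = map(treatment, actor, block) do t, a, b
    g[t] + α[t, a] + β[t, b]
end

# for-loop version: preallocate the output, then fill it
logit_p = Vector{eltype(g)}(undef, length(treatment))
for i in eachindex(treatment)
    t, a, b = treatment[i], actor[i], block[i]
    logit_p[i] = g[t] + α[t, a] + β[t, b]
end
```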
LKJ Cholesky parameterization
You’re absolutely right that something like Stan’s `lkj_corr_cholesky` should speed things up a lot. This is because `MvNormal`’s constructor IIRC performs a Cholesky decomposition, which is expensive, plus the `LKJ` density computes the determinant of the correlation matrix, which can be done cheaply from a Cholesky factor. Unfortunately Distributions.jl doesn’t have such a distribution, so we can’t use it in Turing. I took a stab at creating a custom distribution, but I’m bad at using Bijectors, so I wasn’t successful. Perhaps one of the other folks here will have better luck. We definitely should have something like this, and I’d advise opening an issue on the Turing.jl repo to discuss it. We can get something working for Soss.jl, though.
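To illustrate the determinant point (a standalone sketch, not part of the model): if `U` is the upper Cholesky factor of a correlation matrix `Ω`, so that `Ω = U'U`, then `logdet(Ω) = 2 * sum(log, diag(U))`, and having the factor in hand makes the determinant essentially free:

```julia
using LinearAlgebra

Ω = [1.0 0.3 0.1; 0.3 1.0 0.2; 0.1 0.2 1.0]  # made-up correlation matrix
U = cholesky(Symmetric(Ω)).U                  # upper Cholesky factor, Ω = U'U

2 * sum(log, diag(U)) ≈ logdet(Ω)             # true; left side needs no new factorization
```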
LKJ Cholesky model with Soss
MeasureTheory.jl has an `LKJL` distribution for the lower-triangular Cholesky factor of a correlation matrix drawn from the `LKJ` distribution. It has some issues, but we can try working around them. Soss.jl works with MeasureTheory, so we’ll use Soss.
```julia
# using Pkg
# Pkg.add(["Soss", "MeasureTheory", "TuringModels", "CSV", "DataFrames", "DynamicHMC", "LogDensityProblems"])
using Soss, MeasureTheory, LinearAlgebra, Random
using Soss: TransformVariables

# work around LKJL issues, see https://github.com/cscherrer/MeasureTheory.jl/issues/100
# note that we're making LKJL behave like an LKJU
Soss.xform(d::LKJL, _data::NamedTuple=NamedTuple()) = TransformVariables.as(d);
function Base.rand(rng::AbstractRNG, ::Type, d::LKJL{k}) where {k}
    return cholesky(rand(rng, MeasureTheory.Dists.LKJ(k, d.η))).U
end

# get data
using TuringModels, CSV, DataFrames

data_path = joinpath(TuringModels.project_root, "data", "chimpanzees.csv");
df = CSV.read(data_path, DataFrame; delim=';');
df.block_id = df.block;
df.treatment = 1 .+ df.prosoc_left .+ 2 * df.condition;

# build the model
model = @model actor, block, treatment, n_treatment, n_actor, n_block begin
    σ_block ~ Exponential(1) |> iid(n_treatment)
    σ_actor ~ Exponential(1) |> iid(n_treatment)
    U_ρ_block ~ LKJL(n_treatment, 2)
    U_ρ_actor ~ LKJL(n_treatment, 2)
    g ~ Normal(0, 1) |> iid(n_treatment)
    z_block ~ Normal(0, 1) |> iid(n_treatment, n_block)
    z_actor ~ Normal(0, 1) |> iid(n_treatment, n_actor)
    β = σ_block .* (U_ρ_block' * z_block)
    α = σ_actor .* (U_ρ_actor' * z_actor)
    p = broadcast(treatment, actor, block) do t, a, b
        return @inbounds logistic(g[t] + α[t, a] + β[t, b])
    end
    pulled_left .~ Binomial.(1, p)
end

# condition the model on our data
params = (
    actor=df.actor,
    block=df.block,
    treatment=df.treatment,
    n_treatment=4,
    n_actor=7,
    n_block=6,
)
model_cond = model(params) | (pulled_left=df.pulled_left,)

# get ready to run DynamicHMC
using DynamicHMC, LogDensityProblems

ad_backend = Val(:ForwardDiff)
ℓ(x) = Soss.logdensity(model_cond, x)
trans = Soss.xform(model_cond)
P = LogDensityProblems.TransformedLogDensity(trans, ℓ)
∇P = LogDensityProblems.ADgradient(ad_backend, P)
rng = MersenneTwister(3702)
nchains = 4
results = map(1:nchains) do _
    @time mcmc_with_warmup(rng, ∇P, 1_000; reporter=ProgressMeterReport())
end;

# get posterior samples
posterior = map(r -> P.transformation.(r.chain), results)
```
Note that we’re using ForwardDiff here, so we expect it to be slow, but we still see major speed-ups, running in ~3 minutes/chain. This PR adds ReverseDiff support to LogDensityProblems and brings the runtime down to ~24s/chain (just swap in `ad_backend = Val(:ReverseDiff)`). Not quite Stan’s ~6s/chain on my machine, but not bad.
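Concretely, with that PR the only change to the code above is the backend passed to `ADgradient`:

```julia
# same pipeline as above, just a different AD backend
ad_backend = Val(:ReverseDiff)
∇P = LogDensityProblems.ADgradient(ad_backend, P)
```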
EDIT: `MeasureTheory.Binomial` has a parameterization in terms of `logitp`; if we use that instead of calling `logistic`, with ReverseDiff the runtime drops to ~16s/chain.
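Roughly, the change amounts to working on the logit scale in the model and handing that straight to `Binomial`; here's a sketch (the exact keyword-constructor syntax for MeasureTheory's logit parameterization is an assumption on my part, so double-check it):

```julia
# sketch: replace the last lines of the model body with a logit-scale version
# (MeasureTheory keyword syntax assumed, not verified)
logit_p = broadcast(treatment, actor, block) do t, a, b
    return @inbounds g[t] + α[t, a] + β[t, b]
end
pulled_left .~ [Binomial(n=1, logitp=lp) for lp in logit_p]
```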
You probably also want to check diagnostics and posterior predictions. Here’s how to get everything you need to use ArviZ.jl:
```julia
# extract useful statistics
# adapted from https://github.com/arviz-devs/ArviZ.jl/pull/131
sample_stats = map(results) do r
    map(r.chain, r.tree_statistics) do chn, stat
        term = stat.termination
        return (
            lp=LogDensityProblems.logdensity(P, chn),
            energy=stat.π,
            tree_depth=stat.depth,
            acceptance_rate=stat.acceptance_rate,
            n_steps=stat.steps,
            diverging=term.left == term.right,
            turning=term.left < term.right,
        )
    end
end

# collect posterior predictions and log likelihood
varkeys = keys(posterior[1][1])
model_pred = predictive(model, varkeys...)
posterior_predictive_log_likelihood = map(posterior) do post
    map(post) do draw
        pred = rand(rng, model_pred(merge(params, draw)))
        return (
            pulled_left=pred.pulled_left,
            log_likelihood=logdensity.(Binomial.(1, pred.p), df.pulled_left),
        )
    end
end
posterior_predictive = map(posterior_predictive_log_likelihood) do chn
    map(chn) do draw
        (pulled_left=draw.pulled_left,)
    end
end
log_likelihood = map(posterior_predictive_log_likelihood) do chn
    map(chn) do draw
        (pulled_left=draw.log_likelihood,)
    end
end

# gather samples together for analysis with ArviZ
using ArviZ
idata = from_namedtuple(
    posterior;
    sample_stats=sample_stats,
    posterior_predictive=posterior_predictive,
    log_likelihood=log_likelihood,
    observed_data=(pulled_left=df.pulled_left,),
    dims=Dict(
        :σ_block => [:treatment],
        :σ_actor => [:treatment],
        :g => [:treatment],
        :U_ρ_block => [:treatment, :treatment2],
        :U_ρ_actor => [:treatment, :treatment2],
        :z_block => [:treatment, :block],
        :z_actor => [:treatment, :actor],
    ),
    coords=Dict(:treatment => 1:4, :treatment2 => 1:4, :actor => 1:7, :block => 1:6),
)
summarystats(idata)
```
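From here the usual ArviZ functions apply to `idata`; for example (illustrative, any function that takes an `InferenceData` should work):

```julia
# a couple of follow-up checks on the InferenceData object
loo(idata)                                                # PSIS-LOO, uses the log_likelihood group
plot_trace(idata; var_names=["g", "σ_actor", "σ_block"])  # trace plots for a few parameters
```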