OK, here are some attempts.
Setup
using Turing
using FillArrays
using LinearAlgebra
struct HorseShoePrior{T}
    X::T
end
Original Horseshoe (Carvalho et al., 2009)
Based on The Horseshoe+ Prior: Normal Vector Mean:
@model function (m::HorseShoePrior)(::Val{:original})
    # Independent variable
    X = m.X
    J = size(X, 2)
    # Priors
    halfcauchy = truncated(Cauchy(0, 1); lower=0)
    τ ~ halfcauchy # global shrinkage
    λ ~ filldist(halfcauchy, J) # local shrinkage
    α ~ TDist(3) # Intercept
    β ~ MvNormal(Zeros(J), Diagonal((λ .* τ).^2)) # Coefficients (zero-mean, as in the Finnish model below)
    σ ~ Exponential(1) # Errors
    # Dependent variable
    y ~ MvNormal(α .+ X * β, σ^2 * I)
    return (; τ, λ, α, β, σ, y)
end
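For reference, the hierarchy this code implements (Carvalho et al., 2009) is:

$$
\begin{aligned}
\tau &\sim \mathrm{C}^+(0, 1), \\
\lambda_j &\sim \mathrm{C}^+(0, 1), \\
\beta_j \mid \lambda_j, \tau &\sim \mathrm{Normal}(0, \lambda_j^2 \tau^2),
\end{aligned}
$$

where $\mathrm{C}^+$ is the half-Cauchy, i.e. the truncated(Cauchy(0, 1); lower=0) above.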
Horseshoe+ (Bhadra et al., 2015)
Also based on The Horseshoe+ Prior: Normal Vector Mean:
@model function (m::HorseShoePrior)(::Val{:+})
    # Independent variable
    X = m.X
    J = size(X, 2)
    # Priors
    τ ~ truncated(Cauchy(0, 1/J); lower=0) # global shrinkage
    η ~ truncated(Cauchy(0, 1); lower=0) # extra global mixing scale
    λ ~ filldist(Cauchy(0, 1), J) # local shrinkage (only λ² enters the covariance, so the sign is irrelevant)
    β ~ MvNormal(Zeros(J), Diagonal(((η * τ) .* λ).^2)) # Coefficients
    α ~ TDist(3) # Intercept
    σ ~ Exponential(1) # Errors
    # Dependent variable
    y ~ MvNormal(α .+ X * β, σ^2 * I)
    return (; τ, η, λ, α, β, σ, y)
end
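What this version implements (note that Bhadra et al. attach a half-Cauchy mixing scale to each coefficient, whereas here a single scalar η is shared across all of them) is:

$$
\begin{aligned}
\tau &\sim \mathrm{C}^+(0, 1/J), \qquad \eta \sim \mathrm{C}^+(0, 1), \\
\lambda_j &\sim \mathrm{Cauchy}(0, 1), \\
\beta_j \mid \lambda_j, \eta, \tau &\sim \mathrm{Normal}\big(0, (\eta \tau \lambda_j)^2\big).
\end{aligned}
$$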
Finnish Horseshoe Prior (Piironen & Vehtari, 2017)
From Appendix C.1 of the arXiv paper, using the default parameters as given in the docstring of brms's horseshoe prior:
@model function (m::HorseShoePrior)(::Val{:finnish}; τ₀=3, ν_local=1, ν_global=1, slab_df=4, slab_scale=2)
    # Independent variable
    X = m.X
    J = size(X, 2)
    # Priors
    z ~ MvNormal(Zeros(J), I) # standard Normal for Coefs (non-centered parameterization)
    α ~ TDist(3) # Intercept
    σ ~ Exponential(1) # Errors
    λ ~ filldist(truncated(TDist(ν_local); lower=0), J) # local shrinkage
    τ ~ (τ₀ * σ) * truncated(TDist(ν_global); lower=0) # global shrinkage
    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df) # auxiliary slab variable
    # Transformations
    c = slab_scale * sqrt(c_aux) # slab scale
    λtilde = λ ./ hypot.(1, (τ / c) .* λ) # regularized local shrinkage
    β = τ .* z .* λtilde # Regression coefficients
    # Dependent variable
    y ~ MvNormal(α .+ X * β, σ^2 * I)
    return (; τ, σ, λ, λtilde, z, c, c_aux, α, β, y)
end
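The regularized local scale λtilde is what distinguishes the Finnish horseshoe: it caps the scale of large coefficients at the slab scale c instead of letting them escape shrinkage entirely:

$$
\tilde{\lambda}_j^2 = \frac{c^2 \lambda_j^2}{c^2 + \tau^2 \lambda_j^2},
\qquad
\beta_j = \tau \, \tilde{\lambda}_j \, z_j, \quad z_j \sim \mathrm{Normal}(0, 1).
$$

Sampling z and computing β = τ .* z .* λtilde deterministically is the non-centered parameterization: equivalent to drawing β_j ~ Normal(0, (τ λ̃_j)²) directly, but much friendlier to NUTS in the funnel-shaped posteriors these hierarchies create. It is also why β does not show up in summarystats below and has to be recovered via generated_quantities.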
R2-D2 Prior (Zhang et al., 2020)
Translated from a brms model with make_stancode:
@model function (m::HorseShoePrior)(::Val{:R2D2}; mean_R2=0.5, prec_R2=2, cons_D2=1)
    # Independent variable
    X = m.X
    J = size(X, 2)
    # Priors
    z ~ filldist(Normal(), J) # standard Normal for Coefs (non-centered parameterization)
    α ~ TDist(3) # Intercept
    σ ~ Exponential(1) # Errors
    R2 ~ Beta(mean_R2 * prec_R2, (1 - mean_R2) * prec_R2) # R² parameter
    ϕ ~ Dirichlet(J, cons_D2) # variance decomposition across coefficients
    τ2 = σ^2 * R2 / (1 - R2) # implied global scale
    # Coefficients
    β = z .* sqrt.(ϕ * τ2)
    # Dependent variable
    y ~ MvNormal(α .+ X * β, σ^2 * I)
    return (; σ, z, ϕ, τ2, R2, α, β, y)
end
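The trick here is that the prior is placed on the model's R² directly, and the implied signal variance τ² is then split across coefficients by a Dirichlet. With μ = mean_R2, φ = prec_R2, and ξ = cons_D2:

$$
\begin{aligned}
R^2 &\sim \mathrm{Beta}\big(\mu \varphi,\; (1 - \mu) \varphi\big), \\
\phi &\sim \mathrm{Dirichlet}(J, \xi), \\
\tau^2 &= \sigma^2 \frac{R^2}{1 - R^2},
\qquad
\beta_j = z_j \sqrt{\phi_j \tau^2}.
\end{aligned}
$$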
Comparison
# data
X = randn(100, 2)
X = hcat(X, randn(100) * 2) # give the third variable twice the scale of the others
y = X[:, 3] .+ (randn(100) * 0.1) # y is X[:, 3] plus Gaussian noise with σ = 0.1
# models
model_original = HorseShoePrior(X)(Val(:original));
model_plus = HorseShoePrior(X)(Val(:+));
model_finnish = HorseShoePrior(X)(Val(:finnish));
model_R2D2 = HorseShoePrior(X)(Val(:R2D2));
# Condition on data y
model_original_y = model_original | (; y);
model_plus_y = model_plus | (; y);
model_finnish_y = model_finnish | (; y);
model_R2D2_y = model_R2D2 | (; y);
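(Aside: the | operator is sugar for conditioning the model on data; assuming Turing's re-export of DynamicPPL's condition, an equivalent spelling would be:)

model_original_y = condition(model_original, (; y));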
# sample
fit_original = sample(model_original_y, NUTS(), MCMCThreads(), 2_000, 4)
fit_plus = sample(model_plus_y, NUTS(), MCMCThreads(), 2_000, 4)
fit_finnish = sample(model_finnish_y, NUTS(), MCMCThreads(), 2_000, 4)
fit_R2D2 = sample(model_R2D2_y, NUTS(), MCMCThreads(), 2_000, 4)
Results
summarystats(fit_original)
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
τ 0.3967 0.4614 0.0052 0.0083 3073.1052 0.9999 351.7749
λ[1] 0.4326 0.8874 0.0099 0.0144 3623.3156 1.0001 414.7568
λ[2] 0.4775 1.0191 0.0114 0.0150 3979.4685 1.0004 455.5252
λ[3] 14.1351 54.4997 0.6093 1.2510 2405.6428 1.0011 275.3712
α 0.0131 0.0100 0.0001 0.0001 5136.7337 0.9999 587.9961
β[1] 0.0015 0.0088 0.0001 0.0001 5177.8127 1.0006 592.6983
β[2] -0.0048 0.0088 0.0001 0.0001 5085.0760 1.0001 582.0829
β[3] 0.9971 0.0050 0.0001 0.0001 6347.7815 0.9997 726.6233
σ 0.0978 0.0071 0.0001 0.0001 5847.1149 1.0005 669.3126
summarystats(fit_plus)
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
τ 0.3568 0.6939 0.0078 0.0358 352.2589 1.0124 4.7969
η 0.9706 1.4499 0.0162 0.0606 382.3926 1.0194 5.2072
λ[1] -0.0444 1.0414 0.0116 0.0701 160.3538 1.0261 2.1836
λ[2] -0.1429 1.5455 0.0173 0.1279 91.6935 1.0493 1.2486
λ[3] 7.9526 56.0588 0.6268 4.1094 62.7527 1.1085 0.8545
β[1] 0.0005 0.0080 0.0001 0.0004 187.2131 1.0303 2.5494
β[2] -0.0048 0.0086 0.0001 0.0005 304.4416 1.0053 4.1457
β[3] 0.9967 0.0049 0.0001 0.0002 508.4565 1.0124 6.9239
α 0.0146 0.0101 0.0001 0.0005 177.0852 1.0228 2.4115
σ 0.0975 0.0071 0.0001 0.0003 285.9428 1.0109 3.8938
summarystats(fit_finnish)
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
z[1] -0.2310 0.5224 0.0058 0.0109 2873.7349 1.0009 8.3275
z[2] 0.1238 0.5840 0.0065 0.0105 2917.6985 1.0003 8.4549
z[3] 1.2096 0.5956 0.0067 0.0099 2955.8443 1.0005 8.5655
α -0.0063 0.0115 0.0001 0.0002 6179.4768 0.9998 17.9069
σ 0.1124 0.0081 0.0001 0.0001 5128.4019 1.0014 14.8611
λ[1] 0.6621 2.2831 0.0255 0.0378 3509.4786 1.0005 10.1698
λ[2] 3.7349 276.3015 3.0891 3.2077 7586.3461 1.0000 21.9837
λ[3] 61.0742 1637.4816 18.3076 23.4661 4975.3094 1.0002 14.4175
τ 0.2609 0.3701 0.0041 0.0062 4024.7332 1.0001 11.6629
c_aux 1.8302 2.9149 0.0326 0.0520 2469.9089 1.0007 7.1573
# β in the Finnish model is a deterministic transformation, not a sampled parameter, so reconstruct it from the chain
using DataFramesMeta
finnish_chain = generated_quantities(model_finnish_y, MCMCChains.get_sections(fit_finnish, :parameters))
finnish_df = reduce(vcat, DataFrame(finnish_chain[:, i]) for i in 1:size(finnish_chain, 2))
@chain finnish_df begin
@rselect @astable :β = :z * :τ .* :λtilde
@combine :β_mean = mean(:β)
end
3×1 DataFrame
Row │ β_mean
│ Float64
─────┼─────────────
1 │ -0.00739686
2 │ 0.00389713
3 │ 1.00772
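Since the model's return value already includes β, the same posterior mean can also be read straight off the generated quantities, with the same pattern used for R2-D2 below:

@combine finnish_df :β_mean = mean(:β)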
summarystats(fit_R2D2)
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
z[1] -0.0236 0.1753 0.0020 0.0035 3209.7503 1.0006 13.4425
z[2] -0.1035 0.2117 0.0024 0.0051 1705.6669 1.0016 7.1434
z[3] 2.1017 0.6558 0.0073 0.0128 2565.3850 1.0008 10.7439
α 0.0138 0.0100 0.0001 0.0001 5088.3175 1.0005 21.3100
σ 0.0966 0.0073 0.0001 0.0001 4218.4459 1.0011 17.6670
R2 0.9682 0.0214 0.0002 0.0005 2228.9746 1.0012 9.3350
ϕ[1] 0.1284 0.1395 0.0016 0.0034 1766.5938 1.0017 7.3985
ϕ[2] 0.1312 0.1454 0.0016 0.0028 3297.8737 1.0022 13.8116
ϕ[3] 0.7404 0.1870 0.0021 0.0044 2247.4954 1.0040 9.4126
# R2-D2's β likewise has to be reconstructed from the generated quantities
using DataFramesMeta
R2D2_chain = generated_quantities(model_R2D2_y, MCMCChains.get_sections(fit_R2D2, :parameters))
R2D2_df = reduce(vcat, DataFrame(R2D2_chain[:, i]) for i in 1:size(R2D2_chain, 2))
@combine R2D2_df :β_mean = mean(:β)
3×1 DataFrame
Row │ β_mean
│ Float64
─────┼─────────────
1 │ -0.00219241
2 │ -0.00975412
3 │ 1.00053
So there you go! Immense thanks to @devmotion! Learned tons!
References
- Carvalho, C. M., Polson, N. G., & Scott, J. G. (2009). Handling sparsity via the horseshoe. In International Conference on Artificial Intelligence and Statistics (pp. 73-80).
- Bhadra, A., Datta, J., Polson, N. G., & Willard, B. (2015). The Horseshoe+ Estimator of Ultra-Sparse Signals. arXiv:1502.00560 [math, stat].
- Piironen, J., & Vehtari, A. (2017). Sparsity information and regularization in the horseshoe and other shrinkage priors. arXiv:1707.01694 [stat].
- Zhang, Y. D., Naughton, B. P., Bondell, H. D., & Reich, B. J. (2020). Bayesian regression using a prior on the model fit: The R2-D2 shrinkage prior. Journal of the American Statistical Association.
Cc @sethaxen, @devmotion.
EDIT: Don't forget to standardize your X and y. The sampler will thank you.
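A minimal sketch of that standardization (plain z-scoring with the Statistics stdlib, reusing X and y from the Comparison section above):

using Statistics

X_std = (X .- mean(X; dims=1)) ./ std(X; dims=1) # z-score each column of X
y_std = (y .- mean(y)) ./ std(y) # z-score the response
model_finnish_std = HorseShoePrior(X_std)(Val(:finnish)) | (; y = y_std)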
EDIT2: Suggestions from @devmotion.
EDIT3: Thank you again @devmotion! There is also the R2-D2 prior, a new method for Bayesian sparse regression with shrinkage; I've added it to the references.
EDIT4: Added the R2-D2 prior and fixed typos in the Finnish model.