Regularized horseshoe prior

OK, here are a few attempts.

Setup

using Turing

using FillArrays
using LinearAlgebra

"""
    HorseShoePrior(X)

Callable wrapper around a predictor matrix `X` (one column per predictor; the
models built from it use `size(X, 2)` as the number of coefficients).

Calling an instance with `Val(:original)`, `Val(:+)`, `Val(:finnish)` or
`Val(:R2D2)` builds the corresponding Turing model, e.g.
`HorseShoePrior(X)(Val(:original))`; condition that model on `y` before sampling.
"""
struct HorseShoePrior{T}
    X::T
end

Original Horseshoe (Carvalho et al., 2009)

Based on The Horseshoe+ Prior: Normal Vector Mean:

@model function (m::HorseShoePrior)(::Val{:original})
  # Classic horseshoe (Carvalho et al., 2009):
  #   β_j ~ N(0, (τ λ_j)²),  τ ~ C⁺(0, 1),  λ_j ~ C⁺(0, 1).
  X = m.X
  p = size(X, 2)  # number of predictors / coefficients

  halfC = truncated(Cauchy(0, 1); lower=0)  # half-Cauchy(0, 1)
  τ ~ halfC                 # global shrinkage
  λ ~ filldist(halfC, p)    # local shrinkage, one per coefficient
  α ~ TDist(3)              # intercept
  scales = τ .* λ           # per-coefficient prior standard deviations
  β ~ MvNormal(Diagonal(abs2.(scales)))
  σ ~ Exponential(1)        # residual scale

  # Likelihood: Gaussian errors around the linear predictor.
  μ = α .+ X * β
  y ~ MvNormal(μ, σ^2 * I)

  return (; τ, λ, α, β, σ, y)
end

Horseshoe+ (Bhadra et al., 2015)

Also based on The Horseshoe+ Prior: Normal Vector Mean

@model function (m::HorseShoePrior)(::Val{:+})
  # Horseshoe+ (Bhadra et al., 2015):
  #   β_j ~ N(0, (τ η_j λ_j)²),  τ ~ C⁺(0, 1/J),  η_j ~ C⁺(0, 1),  λ_j ~ C⁺(0, 1).
  # Both η and λ are *local* scales (one per coefficient) and must be half-Cauchy.
  # A full-Cauchy λ leaves its sign unidentified (only λ² enters the prior), which
  # makes the λ posterior symmetric about 0 and degrades NUTS mixing. A single
  # scalar η would just rescale τ, collapsing the model to a plain horseshoe.

  # Independent variable
  X = m.X
  J = size(X, 2)

  # Priors
  halfcauchy = truncated(Cauchy(0, 1); lower=0)
  τ ~ truncated(Cauchy(0, 1/J); lower=0) # global shrinkage
  η ~ filldist(halfcauchy, J)            # local "plus" layer
  λ ~ filldist(halfcauchy, J)            # local shrinkage
  β ~ MvNormal(Diagonal((τ .* η .* λ).^2)) # Coefficients
  α ~ TDist(3) # Intercept
  σ ~ Exponential(1) # Errors

  # Dependent variable
  y ~ MvNormal(α .+ X * β, σ^2 * I)

  # α and σ appended (after the original fields) for parity with the other models.
  return (; τ, η, λ, β, y, α, σ)
end

Finnish Horseshoe Prior (Piironen & Vehtari, 2017)

From Appendix C.1 of the arXiv paper, using the default parameters listed in the docstring of brms's horseshoe prior:

@model function (m::HorseShoePrior)(
  ::Val{:finnish}; τ₀=3, ν_local=1, ν_global=1, slab_df=4, slab_scale=2
)
  # Regularized ("Finnish") horseshoe, Piironen & Vehtari (2017), appendix C.1.
  # NOTE(review): brms's `horseshoe()` appears to document `scale_global = 1` as its
  # default, whereas τ₀ = 3 here — confirm which value is intended.
  X = m.X
  J = size(X, 2)

  # Priors (non-centred: β is built from the standard-normal z below)
  z ~ MvNormal(Zeros(J), I)
  α ~ TDist(3)                                           # intercept
  σ ~ Exponential(1)                                     # residual scale
  λ ~ filldist(truncated(TDist(ν_local); lower=0), J)    # local scales
  τ ~ (τ₀ * σ) * truncated(TDist(ν_global); lower=0)     # global scale, ∝ σ
  half_df = 0.5 * slab_df
  c_aux ~ InverseGamma(half_df, half_df)

  # Slab width c and regularized local scales λ̃_j = c λ_j / sqrt(c² + τ² λ_j²),
  # written as λ_j / hypot(1, τ λ_j / c) for numerical stability.
  c = slab_scale * sqrt(c_aux)
  λtilde = @. λ / hypot(1, (τ / c) * λ)
  β = @. τ * z * λtilde

  μ = α .+ X * β
  y ~ MvNormal(μ, σ^2 * I)
  return (; τ, σ, λ, λtilde, z, c, c_aux, α, β, y)
end

R2-D2 Prior (Zhang et al., 2020)

Translated from a brms model with make_stancode:

@model function (m::HorseShoePrior)(::Val{:R2D2}; mean_R2=0.5, prec_R2=2, cons_D2=1)
  # R2-D2 prior (Zhang et al., 2020), following brms's generated Stan code:
  #   R² ~ Beta(mean·prec, (1-mean)·prec),  ϕ ~ Dirichlet(J, cons),
  #   τ² = σ² R² / (1 - R²),  β_j = z_j · sqrt(ϕ_j τ²).
  X = m.X
  J = size(X, 2)

  z ~ filldist(Normal(), J)  # standard-normal raw coefficients
  α ~ TDist(3)               # intercept
  σ ~ Exponential(1)         # residual scale
  a_R2 = mean_R2 * prec_R2
  b_R2 = (1 - mean_R2) * prec_R2
  R2 ~ Beta(a_R2, b_R2)
  ϕ ~ Dirichlet(J, cons_D2)  # split of the total variance across coefficients

  # Total prior variance of the coefficients implied by R².
  τ2 = σ^2 * R2 / (1 - R2)
  β = @. z * sqrt(ϕ * τ2)

  y ~ MvNormal(α .+ X * β, σ^2 * I)
  return (; σ, z, ϕ, τ2, R2, α, β, y)
end

Comparison

# Simulated data: 100 observations of three standard-normal predictors, the third
# rescaled to sd 2; only that third predictor drives the response.
X = hcat(randn(100, 2), randn(100) * 2)
y = X[:, 3] .+ randn(100) * 0.1 # y = X[:, 3] plus Gaussian noise with sd 0.1

# One model per prior variant, all built on the same design matrix.
hs = HorseShoePrior(X)
model_original = hs(Val(:original));
model_plus = hs(Val(:+));
model_finnish = hs(Val(:finnish));
model_R2D2 = hs(Val(:R2D2));

# Condition every model on the observed response y.
model_original_y = model_original | (; y);
model_plus_y = model_plus | (; y);
model_finnish_y = model_finnish | (; y);
model_R2D2_y = model_R2D2 | (; y);

# NUTS, 4 chains of 2_000 draws each, run in parallel via MCMCThreads().
fit_original = sample(model_original_y, NUTS(), MCMCThreads(), 2_000, 4)
fit_plus = sample(model_plus_y, NUTS(), MCMCThreads(), 2_000, 4)
fit_finnish = sample(model_finnish_y, NUTS(), MCMCThreads(), 2_000, 4)
fit_R2D2 = sample(model_R2D2_y, NUTS(), MCMCThreads(), 2_000, 4)

Results

fit_original |> summarystats # posterior summary, original horseshoe

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64

           τ    0.3967    0.4614     0.0052    0.0083   3073.1052    0.9999      351.7749
        λ[1]    0.4326    0.8874     0.0099    0.0144   3623.3156    1.0001      414.7568
        λ[2]    0.4775    1.0191     0.0114    0.0150   3979.4685    1.0004      455.5252
        λ[3]   14.1351   54.4997     0.6093    1.2510   2405.6428    1.0011      275.3712
           α    0.0131    0.0100     0.0001    0.0001   5136.7337    0.9999      587.9961
        β[1]    0.0015    0.0088     0.0001    0.0001   5177.8127    1.0006      592.6983
        β[2]   -0.0048    0.0088     0.0001    0.0001   5085.0760    1.0001      582.0829
        β[3]    0.9971    0.0050     0.0001    0.0001   6347.7815    0.9997      726.6233
           σ    0.0978    0.0071     0.0001    0.0001   5847.1149    1.0005      669.3126
fit_plus |> summarystats # posterior summary, horseshoe+

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat   ess_per_sec
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64       Float64

           τ    0.3568    0.6939     0.0078    0.0358   352.2589    1.0124        4.7969
           η    0.9706    1.4499     0.0162    0.0606   382.3926    1.0194        5.2072
        λ[1]   -0.0444    1.0414     0.0116    0.0701   160.3538    1.0261        2.1836
        λ[2]   -0.1429    1.5455     0.0173    0.1279    91.6935    1.0493        1.2486
        λ[3]    7.9526   56.0588     0.6268    4.1094    62.7527    1.1085        0.8545
        β[1]    0.0005    0.0080     0.0001    0.0004   187.2131    1.0303        2.5494
        β[2]   -0.0048    0.0086     0.0001    0.0005   304.4416    1.0053        4.1457
        β[3]    0.9967    0.0049     0.0001    0.0002   508.4565    1.0124        6.9239
           α    0.0146    0.0101     0.0001    0.0005   177.0852    1.0228        2.4115
           σ    0.0975    0.0071     0.0001    0.0003   285.9428    1.0109        3.8938
fit_finnish |> summarystats # posterior summary, Finnish horseshoe
Summary Statistics
  parameters      mean         std   naive_se      mcse         ess      rhat   ess_per_sec
      Symbol   Float64     Float64    Float64   Float64     Float64   Float64       Float64

        z[1]   -0.2310      0.5224     0.0058    0.0109   2873.7349    1.0009        8.3275
        z[2]    0.1238      0.5840     0.0065    0.0105   2917.6985    1.0003        8.4549
        z[3]    1.2096      0.5956     0.0067    0.0099   2955.8443    1.0005        8.5655
           α   -0.0063      0.0115     0.0001    0.0002   6179.4768    0.9998       17.9069
           σ    0.1124      0.0081     0.0001    0.0001   5128.4019    1.0014       14.8611
        λ[1]    0.6621      2.2831     0.0255    0.0378   3509.4786    1.0005       10.1698
        λ[2]    3.7349    276.3015     3.0891    3.2077   7586.3461    1.0000       21.9837
        λ[3]   61.0742   1637.4816    18.3076   23.4661   4975.3094    1.0002       14.4175
           τ    0.2609      0.3701     0.0041    0.0062   4024.7332    1.0001       11.6629
       c_aux    1.8302      2.9149     0.0326    0.0520   2469.9089    1.0007        7.1573

# In the Finnish model β is a deterministic transform (not a `~` variable), so it is
# not stored in the chain; recover it by re-running the model on each posterior draw.
using DataFramesMeta

finnish_chain = generated_quantities(
  model_finnish_y, MCMCChains.get_sections(fit_finnish, :parameters)
)
# `vec` stacks the chains one after another (column-major), same order as a per-chain vcat.
finnish_df = DataFrame(vec(finnish_chain))
# Use the β the model itself returns rather than re-deriving it from z, τ and λtilde,
# so this summary cannot drift from the model's definition of β.
@combine finnish_df :β_mean = mean(:β)

3×1 DataFrame
 Row │ β_mean
     │ Float64
─────┼─────────────
   1 │ -0.00739686
   2 │  0.00389713
   3 │  1.00772
fit_R2D2 |> summarystats # posterior summary, R2-D2
Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64

        z[1]   -0.0236    0.1753     0.0020    0.0035   3209.7503    1.0006       13.4425
        z[2]   -0.1035    0.2117     0.0024    0.0051   1705.6669    1.0016        7.1434
        z[3]    2.1017    0.6558     0.0073    0.0128   2565.3850    1.0008       10.7439
           α    0.0138    0.0100     0.0001    0.0001   5088.3175    1.0005       21.3100
           σ    0.0966    0.0073     0.0001    0.0001   4218.4459    1.0011       17.6670
          R2    0.9682    0.0214     0.0002    0.0005   2228.9746    1.0012        9.3350
        ϕ[1]    0.1284    0.1395     0.0016    0.0034   1766.5938    1.0017        7.3985
        ϕ[2]    0.1312    0.1454     0.0016    0.0028   3297.8737    1.0022       13.8116
        ϕ[3]    0.7404    0.1870     0.0021    0.0044   2247.4954    1.0040        9.4126

# In the R2-D2 model β = z .* sqrt.(ϕ * τ2) is deterministic, so it is not in the
# chain; pull it out of the model's return value for every posterior draw.
using DataFramesMeta

R2D2_params = MCMCChains.get_sections(fit_R2D2, :parameters)
R2D2_chain = generated_quantities(model_R2D2_y, R2D2_params)
R2D2_df = DataFrame(vec(R2D2_chain)) # all chains stacked, chain 1 first
@combine R2D2_df :β_mean = mean(:β)

3×1 DataFrame
 Row │ β_mean
     │ Float64
─────┼─────────────
   1 │ -0.00219241
   2 │ -0.00975412
   3 │  1.00053

So there you go! Immense thanks to @devmotion! Learned tons!

References

  • Carvalho, C. M., Polson, N. G., & Scott, J. G. (2009). Handling sparsity via the horseshoe. In International Conference on Artificial Intelligence and Statistics (pp. 73-80).
  • Bhadra, A., Datta, J., Polson, N. G., & Willard, B. (2015). The Horseshoe+ Estimator of Ultra-Sparse Signals. arXiv:1502.00560 [math, stat].
  • Piironen, J., & Vehtari, A. (2017). Sparsity information and regularization in the horseshoe and other shrinkage priors. arXiv:1707.01694 [stat].
  • Zhang, Y. D., Naughton, B. P., Bondell, H. D., & Reich, B. J. (2020). Bayesian regression using a prior on the model fit: The R2-D2 shrinkage prior. Journal of the American Statistical Association.

Cc @sethaxen, @devmotion.

EDIT: Don’t forget to standardize your X and y. The sampler will thank you 🙂
EDIT2: Suggestions from @devmotion.
EDIT3: Thank you again @devmotion! There is also the R2-D2 prior, a recent method for Bayesian sparse regression with shrinkage; I’ve added it to the references.
EDIT4: Added the R2-D2 prior and fixed typos in the Finnish horseshoe.

9 Likes