Regularized horseshoe prior

Does anyone here have working Turing.jl (or Soss.jl) code for the regularized horseshoe prior (e.g., Piironen & Vehtari, 2017)?

The horseshoe prior is a regularization technique (it reduces the chance of overfitting) for settings where the number of features is very large compared to the number of observations.
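The name comes from the shape of the implied shrinkage factor κ = 1/(1 + τ²λ²) (Carvalho et al., 2009): with τ = 1 and λ ~ C⁺(0, 1), κ follows a Beta(1/2, 1/2) distribution, which piles up mass near 0 (large signals left nearly untouched) and near 1 (noise shrunk towards zero). The stdlib-only sketch below simulates this; the helper `draw_halfcauchy` is a hypothetical name, and fixing τ = 1 is an assumption made purely for illustration.

```julia
using Random, Statistics

# Draw λ ~ C⁺(0, 1) by inverting the half-Cauchy CDF F(x) = (2/π) atan(x)
draw_halfcauchy(rng) = tan(π * rand(rng) / 2)

rng = MersenneTwister(42)
τ = 1.0                                      # global scale fixed at 1 (assumption for this demo)
λ = [draw_halfcauchy(rng) for _ in 1:1_000_000]
κ = @. 1 / (1 + τ^2 * λ^2)                   # 0 = no shrinkage, 1 = full shrinkage

near_edges  = count(k -> k < 0.1 || k > 0.9, κ) / length(κ)
near_middle = count(k -> 0.45 < k < 0.55, κ) / length(κ)
println("mean κ ≈ ", round(mean(κ); digits=3))
println("mass with κ < 0.1 or κ > 0.9: ", round(near_edges; digits=3))
println("mass with 0.45 < κ < 0.55:    ", round(near_middle; digits=3))
```

The histogram of κ is U-shaped, hence the "horseshoe".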


OK, here are some attempts.

Setup

using Turing

using FillArrays
using LinearAlgebra

struct HorseShoePrior{T}
    X::T
end

Original Horseshoe Carvalho et al. (2009)

Based on The Horseshoe+ Prior: Normal Vector Mean:

@model function (m::HorseShoePrior)(::Val{:original})
  # Independent variable
  X = m.X
  J = size(X, 2)

  # Priors
  halfcauchy  = truncated(Cauchy(0, 1); lower=0)
  τ ~ halfcauchy
  λ ~ filldist(halfcauchy, J)
  α ~ TDist(3) # Intercept
  β ~ MvNormal(Diagonal((λ .* τ).^2)) # Coefficients
  σ ~ Exponential(1) # Errors

  # Dependent variable
  y ~ MvNormal(α .+ X * β, σ^2 * I)

  return (; τ, λ, α, β, σ, y)
end
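To make the `~` statements concrete, here is a stdlib-only sketch of one draw from the generative process that the `:original` model encodes (roughly what `rand(model)` returns). It is not how Turing samples internally, and `draw_halfcauchy`, `draw_t3`, and `horseshoe_prior_draw` are hypothetical helpers. Note that `Diagonal((λ .* τ).^2)` is a covariance, so its entries are variances and the standard deviation of βⱼ is λⱼτ.

```julia
using Random

draw_halfcauchy(rng) = tan(π * rand(rng) / 2)                       # C⁺(0, 1)
draw_t3(rng) = randn(rng) / sqrt(sum(abs2, randn(rng, 3)) / 3)      # Student-t, 3 dof

function horseshoe_prior_draw(rng, X::AbstractMatrix)
    J = size(X, 2)
    τ = draw_halfcauchy(rng)                         # global shrinkage
    λ = [draw_halfcauchy(rng) for _ in 1:J]          # local shrinkage
    β = (λ .* τ) .* randn(rng, J)                    # βⱼ ~ N(0, (λⱼ τ)²)
    α = draw_t3(rng)                                 # intercept ~ TDist(3)
    σ = randexp(rng)                                 # noise std ~ Exponential(1)
    y = α .+ X * β .+ σ .* randn(rng, size(X, 1))    # y ~ N(α + Xβ, σ² I)
    return (; τ, λ, α, β, σ, y)
end

rng = MersenneTwister(1)
X = randn(rng, 5, 3)
draw = horseshoe_prior_draw(rng, X)
```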

Horseshoe+ Bhadra et al. (2015)

Also based on The Horseshoe+ Prior: Normal Vector Mean

@model function (m::HorseShoePrior)(::Val{:+})
  # Independent variable
  X = m.X
  J = size(X, 2)

  # Priors
  τ ~ truncated(Cauchy(0, 1/J); lower=0)
  η ~ truncated(Cauchy(0, 1); lower=0)
  λ ~ filldist(Cauchy(0, 1), J)
  β ~ MvNormal(Diagonal(((η * τ) .* λ).^2)) # Coefficients
  α ~ TDist(3) # Intercept
  σ ~ Exponential(1) # Errors
  
  # Dependent variable
  y ~ MvNormal(α .+ X * β, σ^2 * I)

  return (; τ, η, λ, β, y)
end

Finnish Horseshoe Prior Piironen & Vehtari (2017)

From appendix C.1 of the arXiv paper, using the default parameters given in the docstring of brms's horseshoe():

@model function (m::HorseShoePrior)(::Val{:finnish}; τ₀=3, ν_local=1, ν_global=1, slab_df=4, slab_scale=2)
  # Independent variable
  X = m.X
  J = size(X, 2)

  # Priors
  z ~ MvNormal(Zeros(J), I) # Standard Normal for Coefs
  α ~ TDist(3) # Intercept
  σ ~ Exponential(1) # Errors
  λ ~ filldist(truncated(TDist(ν_local); lower=0), J)  # local shrinkage
  τ ~ (τ₀ * σ) * truncated(TDist(ν_global); lower=0)  # global shrinkage
  c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
  
  # Transformations
  c = slab_scale * sqrt(c_aux)
  λtilde = λ ./ hypot.(1, (τ / c) .* λ)
  β = τ .* z .* λtilde # Regression coefficients

  # Dependent variable
  y ~ MvNormal(α .+ X * β,  σ^2 * I)
  return (; τ, σ, λ, λtilde, z, c, c_aux, α, β, y)
end
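The `hypot` line is where the regularization happens. This stdlib sketch (with arbitrary illustrative values τ = 0.1 and c = 2) checks that `λ / hypot(1, (τ / c) * λ)` equals the paper's form λ̃ = cλ / sqrt(c² + τ²λ²), and that the effective prior scale τλ̃ behaves like the plain horseshoe scale τλ for small λ but saturates at the slab scale c for large λ.

```julia
# Regularized local scale, written exactly as in the model above
λtilde(λ, τ, c) = λ / hypot(1, (τ / c) * λ)

τ, c = 0.1, 2.0
for λ in (1e-3, 1.0, 1e6)
    # τ * λ̃ is the effective prior standard deviation of βⱼ (since zⱼ ~ N(0, 1))
    println("λ = $λ  →  τ λ̃ = ", τ * λtilde(λ, τ, c))
end
```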

R2-D2 Prior (Zhang et al., 2020)

Translated from a brms model with make_stancode:

@model function (m::HorseShoePrior)(::Val{:R2D2}; mean_R2=0.5, prec_R2=2, cons_D2=1)
  # Independent variable
  X = m.X
  J = size(X, 2)

  # Priors
  z ~ filldist(Normal(), J)
  α ~ TDist(3) # Intercept
  σ ~ Exponential(1) # Errors
  R2 ~ Beta(mean_R2 * prec_R2, (1 - mean_R2) * prec_R2) # R2 parameter
  ϕ ~ Dirichlet(J, cons_D2)
  τ2 = σ^2 * R2 / (1 - R2)
  
  # Coefficients
  β = z .* sqrt.(ϕ * τ2)

  # Dependent variable
  y ~ MvNormal(α .+ X * β,  σ^2 * I)
  return (; σ, z, ϕ, τ2, R2, α, β, y)
end
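The R2-D2 parameterization is easiest to see with numbers. In this stdlib sketch (σ = 0.5 and R2 = 0.8 are arbitrary illustrative values), τ2 = σ²R2/(1 − R2) is the total prior variance, the Dirichlet weights ϕ split it across the coefficients (Var(βⱼ) = ϕⱼτ2 because βⱼ = zⱼ sqrt(ϕⱼτ2)), and inverting gives back R2 = τ2/(τ2 + σ²). The helper `dirichlet_ones` is hypothetical and only valid for concentration 1, which matches the default `cons_D2=1`.

```julia
using Random

# Dirichlet(J, 1) draw via normalized Exponential(1) variates (concentration 1 only)
dirichlet_ones(rng, J) = (e = randexp(rng, J); e ./ sum(e))

rng = MersenneTwister(7)
J, σ, R2 = 3, 0.5, 0.8
ϕ  = dirichlet_ones(rng, J)
τ2 = σ^2 * R2 / (1 - R2)          # total prior variance allotted to the coefficients
prior_var_β = ϕ .* τ2             # Var(βⱼ) = ϕⱼ τ2

println("sum of Var(βⱼ) = ", sum(prior_var_β), ", τ2 = ", τ2)
println("implied R2 = ", τ2 / (τ2 + σ^2))
```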

Comparison

# data
X = randn(100, 2)
X = hcat(X, randn(100) * 2) # third predictor has a larger scale (std 2)
y = X[:, 3] .+ (randn(100) * 0.1) # y is X[:, 3] plus Gaussian noise with std 0.1

# models
model_original = HorseShoePrior(X)(Val(:original));
model_plus = HorseShoePrior(X)(Val(:+));
model_finnish = HorseShoePrior(X)(Val(:finnish));
model_R2D2 = HorseShoePrior(X)(Val(:R2D2));

# Condition on data y
model_original_y = model_original | (; y);
model_plus_y = model_plus | (; y);
model_finnish_y = model_finnish | (; y);
model_R2D2_y = model_R2D2 | (; y);

# sample
fit_original = sample(model_original_y, NUTS(), MCMCThreads(), 2_000, 4)
fit_plus = sample(model_plus_y, NUTS(), MCMCThreads(), 2_000, 4)
fit_finnish = sample(model_finnish_y, NUTS(), MCMCThreads(), 2_000, 4)
fit_R2D2 = sample(model_R2D2_y, NUTS(), MCMCThreads(), 2_000, 4)
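As a sanity check that is not part of the original comparison, ordinary least squares on data simulated the same way should also recover coefficients near (0, 0, 1) and a residual standard deviation near 0.1. With only three predictors and 100 observations there is little for shrinkage to do, so one would expect all four priors to land close to this baseline. The seed and variable names here are illustrative.

```julia
using Random, Statistics, LinearAlgebra

rng = MersenneTwister(123)
X = hcat(randn(rng, 100, 2), 2 .* randn(rng, 100))   # same design as above
y = X[:, 3] .+ 0.1 .* randn(rng, 100)                # true β = (0, 0, 1), noise std 0.1

A = hcat(ones(100), X)           # prepend an intercept column
b_ols = A \ y                    # least-squares fit: [α, β₁, β₂, β₃]
σ_ols = std(y .- A * b_ols)      # residual standard deviation
println("OLS: ", round.(b_ols; digits=3), ", σ ≈ ", round(σ_ols; digits=3))
```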

Results

summarystats(fit_original)

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64

           τ    0.3967    0.4614     0.0052    0.0083   3073.1052    0.9999      351.7749
        λ[1]    0.4326    0.8874     0.0099    0.0144   3623.3156    1.0001      414.7568
        λ[2]    0.4775    1.0191     0.0114    0.0150   3979.4685    1.0004      455.5252
        λ[3]   14.1351   54.4997     0.6093    1.2510   2405.6428    1.0011      275.3712
           α    0.0131    0.0100     0.0001    0.0001   5136.7337    0.9999      587.9961
        β[1]    0.0015    0.0088     0.0001    0.0001   5177.8127    1.0006      592.6983
        β[2]   -0.0048    0.0088     0.0001    0.0001   5085.0760    1.0001      582.0829
        β[3]    0.9971    0.0050     0.0001    0.0001   6347.7815    0.9997      726.6233
           σ    0.0978    0.0071     0.0001    0.0001   5847.1149    1.0005      669.3126

summarystats(fit_plus)

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat   ess_per_sec
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64       Float64

           τ    0.3568    0.6939     0.0078    0.0358   352.2589    1.0124        4.7969
           η    0.9706    1.4499     0.0162    0.0606   382.3926    1.0194        5.2072
        λ[1]   -0.0444    1.0414     0.0116    0.0701   160.3538    1.0261        2.1836
        λ[2]   -0.1429    1.5455     0.0173    0.1279    91.6935    1.0493        1.2486
        λ[3]    7.9526   56.0588     0.6268    4.1094    62.7527    1.1085        0.8545
        β[1]    0.0005    0.0080     0.0001    0.0004   187.2131    1.0303        2.5494
        β[2]   -0.0048    0.0086     0.0001    0.0005   304.4416    1.0053        4.1457
        β[3]    0.9967    0.0049     0.0001    0.0002   508.4565    1.0124        6.9239
           α    0.0146    0.0101     0.0001    0.0005   177.0852    1.0228        2.4115
           σ    0.0975    0.0071     0.0001    0.0003   285.9428    1.0109        3.8938

summarystats(fit_finnish)

Summary Statistics
  parameters      mean         std   naive_se      mcse         ess      rhat   ess_per_sec
      Symbol   Float64     Float64    Float64   Float64     Float64   Float64       Float64

        z[1]   -0.2310      0.5224     0.0058    0.0109   2873.7349    1.0009        8.3275
        z[2]    0.1238      0.5840     0.0065    0.0105   2917.6985    1.0003        8.4549
        z[3]    1.2096      0.5956     0.0067    0.0099   2955.8443    1.0005        8.5655
           α   -0.0063      0.0115     0.0001    0.0002   6179.4768    0.9998       17.9069
           σ    0.1124      0.0081     0.0001    0.0001   5128.4019    1.0014       14.8611
        λ[1]    0.6621      2.2831     0.0255    0.0378   3509.4786    1.0005       10.1698
        λ[2]    3.7349    276.3015     3.0891    3.2077   7586.3461    1.0000       21.9837
        λ[3]   61.0742   1637.4816    18.3076   23.4661   4975.3094    1.0002       14.4175
           τ    0.2609      0.3701     0.0041    0.0062   4024.7332    1.0001       11.6629
       c_aux    1.8302      2.9149     0.0326    0.0520   2469.9089    1.0007        7.1573

# β is a deterministic quantity (not stored in the chain), so reconstruct it from the sampled parameters
using DataFramesMeta

finnish_chain = generated_quantities(model_finnish_y, MCMCChains.get_sections(fit_finnish, :parameters))
finnish_df = reduce(vcat, DataFrame(finnish_chain[:, i]) for i in 1:size(finnish_chain, 2))
@chain finnish_df begin
  @rselect @astable :β = :z * :τ .* :λtilde
  @combine :β_mean = mean(:β)
end

3×1 DataFrame
 Row │ β_mean
     │ Float64
─────┼─────────────
   1 │ -0.00739686
   2 │  0.00389713
   3 │  1.00772

summarystats(fit_R2D2)

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64

        z[1]   -0.0236    0.1753     0.0020    0.0035   3209.7503    1.0006       13.4425
        z[2]   -0.1035    0.2117     0.0024    0.0051   1705.6669    1.0016        7.1434
        z[3]    2.1017    0.6558     0.0073    0.0128   2565.3850    1.0008       10.7439
           α    0.0138    0.0100     0.0001    0.0001   5088.3175    1.0005       21.3100
           σ    0.0966    0.0073     0.0001    0.0001   4218.4459    1.0011       17.6670
          R2    0.9682    0.0214     0.0002    0.0005   2228.9746    1.0012        9.3350
        ϕ[1]    0.1284    0.1395     0.0016    0.0034   1766.5938    1.0017        7.3985
        ϕ[2]    0.1312    0.1454     0.0016    0.0028   3297.8737    1.0022       13.8116
        ϕ[3]    0.7404    0.1870     0.0021    0.0044   2247.4954    1.0040        9.4126

# β is a deterministic quantity (not stored in the chain); recover it with generated_quantities
using DataFramesMeta

R2D2_chain = generated_quantities(model_R2D2_y, MCMCChains.get_sections(fit_R2D2, :parameters))
R2D2_df = reduce(vcat, DataFrame(R2D2_chain[:, i]) for i in 1:size(R2D2_chain, 2))
@combine R2D2_df :β_mean = mean(:β)

3×1 DataFrame
 Row │ β_mean
     │ Float64
─────┼─────────────
   1 │ -0.00219241
   2 │ -0.00975412
   3 │  1.00053
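If you only need the posterior mean of β, DataFramesMeta is optional: the output of `generated_quantities` is an iterations × chains matrix of named tuples, and `Statistics.mean` can average the `β` vectors directly. The sketch below uses a small hypothetical stand-in matrix `gq` (3 iterations × 2 chains of made-up values) in place of the real `finnish_chain` or `R2D2_chain`.

```julia
using Statistics

# Hypothetical stand-in for generated_quantities output (3 iterations × 2 chains)
gq = [(β = [0.0, 0.0, 1.0] .+ 0.01 * (i + j),) for i in 1:3, j in 1:2]

# Element-wise posterior mean of β across all iterations and chains
β_mean = mean(nt.β for nt in gq)
println(β_mean)
```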

So there you go! Immense thanks to @devmotion! Learned tons!

References

  • Carvalho, C. M., Polson, N. G., & Scott, J. G. (2009). Handling sparsity via the horseshoe. In International Conference on Artificial Intelligence and Statistics (pp. 73-80).
  • Bhadra, A., Datta, J., Polson, N. G., & Willard, B. (2015). The Horseshoe+ Estimator of Ultra-Sparse Signals. arXiv:1502.00560 [math, stat].
  • Piironen, J., & Vehtari, A. (2017). Sparsity information and regularization in the horseshoe and other shrinkage priors. arXiv:1707.01694 [stat].
  • Zhang, Y. D., Naughton, B. P., Bondell, H. D., & Reich, B. J. (2020). Bayesian regression using a prior on the model fit: The R2-D2 shrinkage prior. Journal of the American Statistical Association.

Cc @sethaxen, @devmotion.

EDIT: Don’t forget to standardize your X and y. The sampler will thank you 🙂
EDIT2: Suggestions from @devmotion.
EDIT3: Thank you again @devmotion! There is also the R2-D2 prior which is a new method for Bayesian sparse regression with shrinkage. I’ve put in the references.
EDIT4: R2-D2 Prior and typos in finnish.
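A stdlib-only sketch of that standardization, assuming column-wise z-scoring (the helper `standardize` is a hypothetical name). It also shows how slopes fitted on the standardized scale map back to the original scale, βⱼ = βⱼ_std · sd(y) / sd(Xⱼ), using least squares as a quick check.

```julia
using Random, Statistics, LinearAlgebra

# Z-score along `dims` (columns by default): mean 0, standard deviation 1
standardize(A; dims=1) = (A .- mean(A; dims)) ./ std(A; dims)

rng = MersenneTwister(3)
X = hcat(randn(rng, 100, 2), 2 .* randn(rng, 100))
y = X[:, 3] .+ 0.1 .* randn(rng, 100)

Xs, ys = standardize(X), standardize(y)

b_std = Xs \ ys                                  # no intercept needed after centering
b = b_std .* std(y) ./ vec(std(X; dims=1))       # slopes back on the original scale
```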


I’m away from my computer and only had a quick look. However, here are some quick comments before I forget them:

  • The relatively new preferred syntax in Distributions is truncated(dist; [lower], [upper]), thanks to @sethaxen. Both bounds default to nothing, which allows us to 1) dispatch on half-truncated distributions and 2) avoid weird and undesired promotions (this takes effect in the next breaking release; currently we still use infinite bounds internally).
  • Covariance matrices of MvNormal should be specified as an AbstractMatrix or a UniformScaling (if the number of components is clear from the mean). Using scalars or vectors is deprecated: dropping them simplifies the internals, reduces inconsistencies with Normal, and hopefully avoids user confusion about whether the values are standard deviations or variances.
  • LocationScale is deprecated; use + and * instead.
  • I’m surprised that Uniform(-Inf, Inf) works. This improper prior is available as Turing.Flat(), which should be better behaved (it is probably exported).
  • I don’t think you ever want to set J to anything but size(X, 2), so personally I wouldn’t make it a keyword argument but define it in the function body. This also ensures it’s always correct, and I don’t think caching it brings any relevant performance benefit.
  • The more modern approach is to not make y an argument of the model but to use model | (; y) (i.e., condition(model, (; y))) if it is observed. Then you can also sample from the unconditioned model without having to pass y = missing as an argument (IMO the missing approach is annoying, both for users and for us internally; it’s nice that there is an alternative and we can move away from it).
  • You could use a brand-new feature here: DynamicPPL, and hence Turing, supports functors, i.e., callable structs. You can make a horseshoe model struct with a field X, so X doesn’t have to be a model argument.
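A short sketch of the first three points, assuming a recent Distributions.jl 0.25.x where the keyword form of `truncated` and `UniformScaling` covariances are available (the variable names are illustrative):

```julia
using Distributions, LinearAlgebra

# Half-Cauchy via the keyword form of truncated (only a lower bound)
halfcauchy = truncated(Cauchy(0, 1); lower=0)

# Covariances as matrices: Diagonal holds variances; σ² * I takes its dimension from the mean
β_prior = MvNormal(zeros(3), Diagonal([0.1, 1.0, 4.0]))
noise   = MvNormal(zeros(3), 0.5^2 * I)

# Affine transformation instead of LocationScale: mean 3, standard deviation 2
shifted = 3 + 2 * Normal()
```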

Thanks for the amazing quick comments.

Yes you are right! I will update the code.

I don’t know how to do that yet, but it seems like a very nice and flexible way to define your model.

Same as above: I don’t know how to do that yet, but it seems powerful and handy.


There are still some problems with the code that you might want to fix:

  • It still uses the deprecated MvNormal constructors.
  • The Horseshoe+ model seems to be wrong. I didn’t study all the details, but y does not depend on X: I think you only specified the prior of the coefficients and forgot the intercept, the noise, and putting them together (as in the original model). Similar updates may be needed in the third variant; I didn’t check it carefully.
  • The InverseGamma distribution in the third variant is only supported on the positive real line, and hence it is not necessary to truncate it.
  • There are some redundant computations and allocations that can be avoided and might lead to minor performance improvements.

That I don’t know how to do it yet.

Hopefully the following code illustrates what I mean. It also contains some other fixes, but the main point is the use of functors without X and y arguments. Of course, one could use different types, but I wanted to show that multiple definitions plus dispatch work just fine (a recent bug fix also made this work for Val etc.):

using Turing

using FillArrays
using LinearAlgebra

struct HorseShoePrior{T}
    X::T
end

@model function (m::HorseShoePrior)(::Val{:original})
    # Independent variable
    X = m.X
    J = size(X, 2)

    # Priors
    halfcauchy  = truncated(Cauchy(0, 1); lower=0)
    τ ~ halfcauchy
    λ ~ filldist(halfcauchy, J)
    β ~ MvNormal(Diagonal((λ .* τ).^2)) # Coefficients
    α ~ TDist(3) # Intercept
    σ ~ Exponential(1) # Errors
  
    # Dependent variable
    y ~ MvNormal(α .+ X * β, σ^2 * I)

    return (; τ, λ, α, β, σ, y)
end

@model function (m::HorseShoePrior)(::Val{:+})
    # Independent variable
    X = m.X
    J = size(X, 2)

    # Priors
    τ ~ truncated(Cauchy(0, 1/J); lower=0)
    η ~ truncated(Cauchy(0, 1); lower=0)
    λ ~ filldist(Cauchy(0, 1), J)
    β ~ MvNormal(Diagonal(((η * τ) .* λ).^2)) # Coefficients
    α ~ TDist(3) # Intercept
    σ ~ Exponential(1) # Errors
  
    # Dependent variable
    y ~ MvNormal(α .+ X * β, σ^2 * I)

    return (; τ, η, λ, β, y)
end

@model function (m::HorseShoePrior)(::Val{:finish}; τ₀=3, ν_local=1, ν_global=1, half_slab_df=2, slab_scale=2)
    # Independent variable
    X = m.X
    J = size(X, 2)

    # Priors
    β₀ ~ Flat()
    logσ ~ Flat()
    σ = exp(logσ) # Noise std
    z ~ MvNormal(Zeros(J), I)
    λ ~ filldist(truncated(TDist(ν_local); lower=0), J) # Local shrinkage
    τ ~ (τ₀ * σ) * truncated(TDist(ν_global); lower=0) # Global shrinkage
    c_aux ~ InverseGamma(half_slab_df, half_slab_df)

    # Dependent variable
    c = slab_scale * sqrt(c_aux)
    λtilde = λ ./ hypot.(1, (τ / c) .* λ)
    β = τ .* z .* λtilde # Regression coefficients
    f = β₀ .+ X * β # Latent function values
    y ~ MvNormal(f, σ^2 * I) # Observation noise with std σ

    return (; τ, σ, logσ, λ, λtilde, z, c, c_aux, β₀, β, f, y)
end

Let us create some toy data:

julia> X = randn(2, 4);

We can create a model as usual:

julia> model_original = HorseShoePrior(X)(Val(:original));

The main point is that neither X nor y is an argument of the model:

julia> model_original.args
(##arg#405 = Val{:original}(),)

julia> DynamicPPL.inargnames(@varname(X), model_original)
false

julia> DynamicPPL.inargnames(@varname(y), model_original)
false

but X is captured by the model function:

julia> model_original.f
HorseShoePrior{Matrix{Float64}}([0.9008876947448639 0.9172982651499152 -1.7703753804116238 1.5009865604420913; 0.15786060957092293 0.6001936008984858 0.8759197705885826 0.9490750485288362])
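Underneath, this is plain Julia: a struct becomes callable once a method is defined on it, and `Val` lets each model variant be a separate method of the same callable. The sketch below leaves Turing out entirely (so no `@model`); `ToyPrior` and its string return values are hypothetical and only demonstrate the dispatch.

```julia
# Hypothetical stand-in for HorseShoePrior: the data lives in a field, not in the arguments
struct ToyPrior{T}
    X::T
end

# Each variant is its own method, selected by dispatching on Val
(m::ToyPrior)(::Val{:original}) = "original horseshoe, J = $(size(m.X, 2))"
(m::ToyPrior)(::Val{:+}) = "horseshoe+, J = $(size(m.X, 2))"

p = ToyPrior(randn(2, 4))
println(p(Val(:original)))
println(p(Val(:+)))
```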

Since y is not an argument of the model, we can sample from the prior without having to use another model instance with y = missing. Instead, we can just use

julia> priorsample_original = rand(model_original)
(τ = 3.695731988541993, λ = [0.7107394347599458, 1.0038839300698026, 0.026017515446747786, 0.7473783016558637], β = [-1.4881921127500568, 1.7428254134564192, -0.014855283604250714, -4.405199362018812], α = 1.9879125028306674, σ = 0.5943693342886811, y = [-3.9333591000191186, -1.432671943129296])

(BTW, it is also a recent addition that rand supports this, and it is the official way to sample from the prior.)
We can evaluate the log joint, the log prior, and the log likelihood of the model at the prior sample:

julia> logjoint(model_original, priorsample_original)
-12.434955305181779

julia> loglikelihood(model_original, priorsample_original)
0.0

julia> logprior(model_original, priorsample_original)
-12.434955305181779

There are no observations here and hence the loglikelihood is zero and the logjoint is equal to the logprior.

Now let us generate some targets:

julia> y = rand(2);

We can use these observations in our model by conditioning it on them. One can use

julia> model_original_y = model_original | (; y);

which is equivalent to

julia> model_original_y = DynamicPPL.condition(model_original, (; y));

or

julia> model_original_y = DynamicPPL.condition(model_original; y);

The conditioned model will use the provided value for y everywhere it appears during model execution. y did not become an argument

julia> model_original_y.args
(##arg#405 = Val{:original}(),)

but instead is stored in a special context of the model:

julia> model_original_y.context
ConditionContext((y = [-0.15666952029608294, 1.080560882554258],), DynamicPPL.DefaultContext())

Since y is now fixed, its value is not returned if we sample from the conditioned model:

julia> sample_original = rand(model_original_y)
(τ = 0.7153433622210013, λ = [3.8969988625444807, 108.0150249313839, 0.26170567941384487, 0.09822866602648996], β = [0.7437719510675507, -58.85226099400912, 0.02501194719031626, -0.18210821639807004], α = 0.5063561352542082, σ = 1.3204010967080868)

Of course, in the conditioned model the loglikelihood is non-zero and logjoint and logprior are different:

julia> logjoint(model_original_y, sample_original)
-1203.0847753639716

julia> loglikelihood(model_original_y, sample_original)
-1177.2934647548811

julia> logprior(model_original_y, sample_original)
-25.791310609090537

And, of course, we can now perform inference as usual, e.g., with NUTS:

julia> sample(model_original_y, NUTS(), 1_000);

The same also works for

julia> model_plus = HorseShoePrior(X)(Val(:+));

julia> model_finish = HorseShoePrior(X)(Val(:finish));

Thanks! I was struggling with how to construct it; the MvNormal docstring does not have any examples.

This is really neat! Thanks! I will start using these in my own coding and research!

That makes prior predictive checks really easy. Again, thank you!


I’ve edited and updated all models. Thank you for the tips and suggestions!


Just to notify everyone involved that I’ve added the R2-D2 prior based on Zhang, Y. D., Naughton, B. P., Bondell, H. D., & Reich, B. J. (2020). Bayesian regression using a prior on the model fit: The R2-D2 shrinkage prior. Journal of the American Statistical Association.


Great thread and great implementations. I’ll try them out as well. 💪
