# Regularized horseshoe prior

Does anyone here have working Turing.jl (or Soss.jl) code for the regularized horseshoe prior (e.g., Piironen & Vehtari, 2017)?

The horseshoe prior is a regularization technique (it reduces the chance of overfitting) for settings where the number of features is very large compared to the number of observations.
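
For reference, the original horseshoe of Carvalho et al. (2009) can be written as the following hierarchy, where $\mathrm{C}^+$ denotes a half-Cauchy distribution. This is a summary in my own notation, with the unit scales chosen to match the first model in this post:

$$
\beta_j \mid \lambda_j, \tau \sim \mathcal{N}\left(0,\ \lambda_j^2 \tau^2\right), \qquad
\lambda_j \sim \mathrm{C}^+(0, 1), \qquad
\tau \sim \mathrm{C}^+(0, 1).
$$

In the simplified normal-means setting with $\tau = \sigma = 1$, the shrinkage factor $\kappa_j = 1 / (1 + \lambda_j^2)$ follows a $\mathrm{Beta}(1/2, 1/2)$ distribution, whose U-shaped density gives the prior its name: most of the mass sits near $\kappa_j = 0$ (no shrinkage) and $\kappa_j = 1$ (full shrinkage).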

OK here are some tries.

## Setup

``````using Turing

using FillArrays
using LinearAlgebra

struct HorseShoePrior{T}
X::T
end
``````

## Original Horseshoe 2009

``````@model function (m::HorseShoePrior)(::Val{:original})
# Independent variable
X = m.X
J = size(X, 2)

# Priors
halfcauchy  = truncated(Cauchy(0, 1); lower=0)
τ ~ halfcauchy
λ ~ filldist(halfcauchy, J)
α ~ TDist(3) # Intercept
β ~ MvNormal(Diagonal((λ .* τ).^2)) # Coefficients
σ ~ Exponential(1) # Errors

# Dependent variable
y ~ MvNormal(α .+ X * β, σ^2 * I)

return (; τ, λ, α, β, σ, y)
end
``````

## Horseshoe+ Bhadra et al. (2015)

Also based on The Horseshoe+ Prior: Normal Vector Mean

``````@model function (m::HorseShoePrior)(::Val{:+})
# Independent variable
X = m.X
J = size(X, 2)

# Priors
τ ~ truncated(Cauchy(0, 1/J); lower=0)
η ~ truncated(Cauchy(0, 1); lower=0)
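# NOTE: λ is unconstrained here; only λ.^2 enters the model, so its sign is not identified.
# A half-Cauchy (as in the original model) would remove that ambiguity.
λ ~ filldist(Cauchy(0, 1), J)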
β ~ MvNormal(Diagonal(((η * τ) .* λ).^2)) # Coefficients
α ~ TDist(3) # Intercept
σ ~ Exponential(1) # Errors

# Dependent variable
y ~ MvNormal(α .+ X * β, σ^2 * I)

return (; τ, η, λ, β, y)
end
``````

## Finnish Horseshoe Prior Piironen & Vehtari (2017)

From appendix C.1 of the arXiv paper, using the default parameters given in the `brms` docstring for the `horseshoe` prior:

``````@model function (m::HorseShoePrior)(::Val{:finnish}; τ₀=3, ν_local=1, ν_global=1, slab_df=4, slab_scale=2)
# Independent variable
X = m.X
J = size(X, 2)

# Priors
z ~ MvNormal(Zeros(J), I) # Standard Normal for Coefs
α ~ TDist(3) # Intercept
σ ~ Exponential(1) # Errors
λ ~ filldist(truncated(TDist(ν_local); lower=0), J)  # local shrinkage
τ ~ (τ₀ * σ) * truncated(TDist(ν_global); lower=0)  # global shrinkage
c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)

# Transformations
c = slab_scale * sqrt(c_aux)
λtilde = λ ./ hypot.(1, (τ / c) .* λ)
β = τ .* z .* λtilde # Regression coefficients

# Dependent variable
y ~ MvNormal(α .+ X * β,  σ^2 * I)
return (; τ, σ, λ, λtilde, z, c, c_aux, α, β, y)
end
``````
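
As a sanity check on the `λtilde` line above, here is a minimal plain-Julia sketch (no Turing needed) comparing the `hypot` form with the regularized local scale written as $\tilde\lambda_j^2 = c^2 \lambda_j^2 / (c^2 + \tau^2 \lambda_j^2)$. The helper names `λtilde_hypot` and `λtilde_paper` and the numeric values are made up for illustration:

```julia
# Regularized local scale, written two algebraically equivalent ways.
# Helper names and numeric values are illustrative, not part of the model code.
λtilde_hypot(λ, τ, c) = λ / hypot(1, (τ / c) * λ)
λtilde_paper(λ, τ, c) = sqrt(c^2 * λ^2 / (c^2 + τ^2 * λ^2))

τ, c = 0.1, 2.0

# Small λ: almost no regularization, so λtilde stays close to λ.
small = λtilde_hypot(0.5, τ, c)

# Huge λ: the effective prior scale τ * λtilde saturates at the slab scale c.
large = τ * λtilde_hypot(1e8, τ, c)

println((small, large))
```

So for coefficients far from zero the prior behaves like a Gaussian slab with scale `c`, which is the regularization that distinguishes this variant from the original horseshoe.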

## R2-D2 Prior (Zhang et al., 2020)

Translated from a `brms` model with `make_stancode`:

``````@model function (m::HorseShoePrior)(::Val{:R2D2}; mean_R2=0.5, prec_R2=2, cons_D2=1)
# Independent variable
X = m.X
J = size(X, 2)

# Priors
z ~ filldist(Normal(), J)
α ~ TDist(3) # Intercept
σ ~ Exponential(1) # Errors
R2 ~ Beta(mean_R2 * prec_R2, (1 - mean_R2) * prec_R2) # R2 parameter
ϕ ~ Dirichlet(J, cons_D2)
τ2 = σ^2 * R2 / (1 - R2)

# Coefficients
β = z .* sqrt.(ϕ * τ2)

# Dependent variable
y ~ MvNormal(α .+ X * β,  σ^2 * I)
return (; σ, z, ϕ, τ2 , R2, α, β, y)
end
``````
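
To see what the `τ2` line does, here is a small plain-Julia sketch with arbitrary illustration values for `σ`, `R2` and `ϕ`. Assuming standardized predictors, `τ2 / (τ2 + σ^2)` plays the role of the explained share of variance, so inverting the map should give back `R2`:

```julia
# Map between R2 and the total prior variance τ2, as in the model above.
# The values of σ, R2 and ϕ are arbitrary illustration values.
σ = 0.5
R2 = 0.8
τ2 = σ^2 * R2 / (1 - R2)

# Inverting the map recovers R2.
R2_back = τ2 / (τ2 + σ^2)

# ϕ lies on the simplex, so the per-coefficient variances ϕ .* τ2 sum to τ2.
ϕ = [0.1, 0.2, 0.7]
total = sum(ϕ .* τ2)

println((τ2, R2_back, total))
```

In other words, the Beta prior on `R2` sets the total amount of signal, and the Dirichlet-distributed `ϕ` decides how that total is split across coefficients.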

## Comparison

``````# data
X = randn(100, 2)
X = hcat(X, randn(100) * 2) # third column has twice the scale of the others
y = X[:, 3] .+ (randn(100) * 0.1) # y is the third column plus Gaussian noise with std 0.1

# models
model_original = HorseShoePrior(X)(Val(:original));
model_plus = HorseShoePrior(X)(Val(:+));
model_finnish = HorseShoePrior(X)(Val(:finnish));
model_R2D2 = HorseShoePrior(X)(Val(:R2D2));

# Condition on data y
model_original_y = model_original | (; y);
model_plus_y = model_plus | (; y);
model_finnish_y = model_finnish | (; y);
model_R2D2_y = model_R2D2 | (; y);

# sample
fit_original = sample(model_original_y, NUTS(), MCMCThreads(), 2_000, 4)
fit_plus = sample(model_plus_y, NUTS(), MCMCThreads(), 2_000, 4)
fit_finnish = sample(model_finnish_y, NUTS(), MCMCThreads(), 2_000, 4)
fit_R2D2 = sample(model_R2D2_y, NUTS(), MCMCThreads(), 2_000, 4)
``````

### Results

``````summarystats(fit_original)

Summary Statistics
parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec
Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64

τ    0.3967    0.4614     0.0052    0.0083   3073.1052    0.9999      351.7749
λ    0.4326    0.8874     0.0099    0.0144   3623.3156    1.0001      414.7568
λ    0.4775    1.0191     0.0114    0.0150   3979.4685    1.0004      455.5252
λ   14.1351   54.4997     0.6093    1.2510   2405.6428    1.0011      275.3712
α    0.0131    0.0100     0.0001    0.0001   5136.7337    0.9999      587.9961
β    0.0015    0.0088     0.0001    0.0001   5177.8127    1.0006      592.6983
β   -0.0048    0.0088     0.0001    0.0001   5085.0760    1.0001      582.0829
β    0.9971    0.0050     0.0001    0.0001   6347.7815    0.9997      726.6233
σ    0.0978    0.0071     0.0001    0.0001   5847.1149    1.0005      669.3126
``````
``````summarystats(fit_plus)

Summary Statistics
parameters      mean       std   naive_se      mcse        ess      rhat   ess_per_sec
Symbol   Float64   Float64    Float64   Float64    Float64   Float64       Float64

τ    0.3568    0.6939     0.0078    0.0358   352.2589    1.0124        4.7969
η    0.9706    1.4499     0.0162    0.0606   382.3926    1.0194        5.2072
λ   -0.0444    1.0414     0.0116    0.0701   160.3538    1.0261        2.1836
λ   -0.1429    1.5455     0.0173    0.1279    91.6935    1.0493        1.2486
λ    7.9526   56.0588     0.6268    4.1094    62.7527    1.1085        0.8545
β    0.0005    0.0080     0.0001    0.0004   187.2131    1.0303        2.5494
β   -0.0048    0.0086     0.0001    0.0005   304.4416    1.0053        4.1457
β    0.9967    0.0049     0.0001    0.0002   508.4565    1.0124        6.9239
α    0.0146    0.0101     0.0001    0.0005   177.0852    1.0228        2.4115
σ    0.0975    0.0071     0.0001    0.0003   285.9428    1.0109        3.8938
``````
``````summarystats(fit_finnish)
Summary Statistics
parameters      mean         std   naive_se      mcse         ess      rhat   ess_per_sec
Symbol   Float64     Float64    Float64   Float64     Float64   Float64       Float64

z   -0.2310      0.5224     0.0058    0.0109   2873.7349    1.0009        8.3275
z    0.1238      0.5840     0.0065    0.0105   2917.6985    1.0003        8.4549
z    1.2096      0.5956     0.0067    0.0099   2955.8443    1.0005        8.5655
α   -0.0063      0.0115     0.0001    0.0002   6179.4768    0.9998       17.9069
σ    0.1124      0.0081     0.0001    0.0001   5128.4019    1.0014       14.8611
λ    0.6621      2.2831     0.0255    0.0378   3509.4786    1.0005       10.1698
λ    3.7349    276.3015     3.0891    3.2077   7586.3461    1.0000       21.9837
λ   61.0742   1637.4816    18.3076   23.4661   4975.3094    1.0002       14.4175
τ    0.2609      0.3701     0.0041    0.0062   4024.7332    1.0001       11.6629
c_aux    1.8302      2.9149     0.0326    0.0520   2469.9089    1.0007        7.1573

# Finnish needs to reconstruct the β
using DataFramesMeta

finnish_chain = generated_quantities(model_finnish_y, MCMCChains.get_sections(fit_finnish, :parameters))
finnish_df = reduce(vcat, DataFrame(finnish_chain[:, i]) for i in 1:size(finnish_chain, 2))
@chain finnish_df begin
@rselect @astable :β = :z * :τ .* :λtilde
@combine :β_mean = mean(:β)
end

3×1 DataFrame
Row │ β_mean
│ Float64
─────┼─────────────
1 │ -0.00739686
2 │  0.00389713
3 │  1.00772
``````
``````summarystats(fit_R2D2)
Summary Statistics
parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec
Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64

z   -0.0236    0.1753     0.0020    0.0035   3209.7503    1.0006       13.4425
z   -0.1035    0.2117     0.0024    0.0051   1705.6669    1.0016        7.1434
z    2.1017    0.6558     0.0073    0.0128   2565.3850    1.0008       10.7439
α    0.0138    0.0100     0.0001    0.0001   5088.3175    1.0005       21.3100
σ    0.0966    0.0073     0.0001    0.0001   4218.4459    1.0011       17.6670
R2    0.9682    0.0214     0.0002    0.0005   2228.9746    1.0012        9.3350
ϕ    0.1284    0.1395     0.0016    0.0034   1766.5938    1.0017        7.3985
ϕ    0.1312    0.1454     0.0016    0.0028   3297.8737    1.0022       13.8116
ϕ    0.7404    0.1870     0.0021    0.0044   2247.4954    1.0040        9.4126

# R2D2 needs to reconstruct the β
using DataFramesMeta

R2D2_chain = generated_quantities(model_R2D2_y, MCMCChains.get_sections(fit_R2D2, :parameters))
R2D2_df = reduce(vcat, DataFrame(R2D2_chain[:, i]) for i in 1:size(R2D2_chain, 2))
@combine R2D2_df :β_mean = mean(:β)

3×1 DataFrame
Row │ β_mean
│ Float64
─────┼─────────────
1 │ -0.00219241
2 │ -0.00975412
3 │  1.00053
``````

So there you go! Immense thanks to @devmotion! Learned tons!

## References

• Carvalho, C. M., Polson, N. G., & Scott, J. G. (2009). Handling sparsity via the horseshoe. In International Conference on Artificial Intelligence and Statistics (pp. 73-80).
• Bhadra, A., Datta, J., Polson, N. G., & Willard, B. (2015). The Horseshoe+ estimator of ultra-sparse signals. arXiv:1502.00560.
• Piironen, J., & Vehtari, A. (2017). Sparsity information and regularization in the horseshoe and other shrinkage priors. arXiv:1707.01694.
• Zhang, Y. D., Naughton, B. P., Bondell, H. D., & Reich, B. J. (2020). Bayesian regression using a prior on the model fit: The R2-D2 shrinkage prior. Journal of the American Statistical Association.

EDIT: Don’t forget to standardize your `X` and `y`. The sampler will thank you.
EDIT2: Suggestions from @devmotion.
EDIT3: Thank you again @devmotion! There is also the R2-D2 prior which is a new method for Bayesian sparse regression with shrinkage. I’ve put in the references.
EDIT4: R2-D2 Prior and typos in finnish.
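
A minimal sketch of the standardization mentioned in the first EDIT, using only the `Statistics` and `Random` standard libraries. The toy data and the helper name `standardize` are assumptions for illustration, not part of the models above:

```julia
using Statistics, Random

Random.seed!(1)

# Toy data with columns on very different scales and offsets (illustrative only).
X = randn(100, 3) .* [1.0 5.0 0.1] .+ [0.0 10.0 -3.0]
y = 2 .* X[:, 2] .+ randn(100)

# Center each column and scale it to unit standard deviation.
standardize(A) = (A .- mean(A; dims=1)) ./ std(A; dims=1)

Xs = standardize(X)
ys = (y .- mean(y)) ./ std(y)
```

Keep the original means and standard deviations around if you want to report coefficients on the original scale afterwards.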


I’m away from my computer and only had a quick look. However, here are some quick comments before I forget them:

• The relatively new preferred syntax in Distributions is `truncated(dist; [lower], [upper])`, thanks to @sethaxen. The default values of the bounds are `nothing` which allows us to 1) dispatch on half-truncated distributions and 2) avoid weird and undesired promotions (in the next breaking release, currently we still use infinite bounds internally).
• Covariance matrices of `MvNormal` should be specified as `AbstractMatrix` or `UniformScaling` (if the number of components is clear from the mean). Using scalars or vectors is deprecated - it will be nicer internally, reduce inconsistencies with `Normal`, and hopefully avoid confusion of users about whether the values are standard deviations or variances.
• `LocationScale` is deprecated, one should use `+` and `*` instead.
• I’m surprised that `Uniform(-Inf, Inf)` works. This improper prior is available as `Turing.Flat()`, which should be more well-behaved (probably exported).
• I don’t think you ever want to set `J` to anything but `size(X, 2)`, so personally I wouldn’t make it a keyword argument but would define it in the function body. This also ensures it’s always correct, and I don’t think there are any relevant performance benefits from caching it.
• The more modern approach is to make `y` not an argument of the model but to use `model | (; y)` (i.e., `condition(model, (; y))`) if it is observed. Then you can also sample from the unconditioned model without having to specify `y = missing` as an argument (IMO the `missing` stuff is annoying, both for users and for us internally; it’s nice that there is an alternative and we can move away from it).
• You could use a brand new feature here: DynamicPPL and hence Turing support functors, i.e., callable structs. You can make a horseshoe model struct with field `X` and don’t have to make it a model argument.
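
For readers who have not met callable structs before, here is a minimal plain-Julia sketch of the idea in the last bullet, without any Turing code. The type `Design` and the variant names `:ncols` and `:nrows` are hypothetical and only serve to show a struct that carries data and is called with a `Val` to select a method:

```julia
# Plain-Julia callable struct; `Design` and the variant names are hypothetical.
struct Design{T}
    X::T
end

# Two "variants" of the same functor, selected by dispatch on a Val type.
(d::Design)(::Val{:ncols}) = size(d.X, 2)
(d::Design)(::Val{:nrows}) = size(d.X, 1)

d = Design(zeros(5, 3))
println((d(Val(:ncols)), d(Val(:nrows))))
```

The data lives in the struct field rather than in the call arguments, which is the same pattern a model functor would use to hold `X`.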

Thanks for the amazing quick comments.

Yes you are right! I will update the code.

I don’t know how to do that yet, but it seems like a very nice, flexible way to define a model.

Same as before: I don’t know how to do it yet, but it seems powerful and handy.


There are still some problems with the code that you might want to fix:

• It still uses the deprecated `MvNormal` constructors.
• The Horseshoe+ model seems to be wrong. I didn’t study all the details, but `y` does not depend on `X`: I think you only specified the prior of the coefficients and forgot the intercept, the noise, and putting them together (similar to the original model). Maybe similar updates are needed in the third variant; I didn’t check it carefully.
• The `InverseGamma` distribution in the third variant is only supported on the positive real line, and hence it is not necessary to truncate it.
• There are some redundant computations and allocations that can be avoided and might lead to minor performance improvements.

That I don’t know how to do it yet.

Hopefully the following code illustrates what I mean. It also contains some other fixes, but the main point is the use of functors without `X` and `y` arguments. Of course, one could use different types, but I wanted to show that multiple definitions + dispatching works just fine (it is also a recent bug fix that this works for `Val` etc.):

``````using Turing

using FillArrays
using LinearAlgebra

struct HorseShoePrior{T}
X::T
end

@model function (m::HorseShoePrior)(::Val{:original})
# Independent variable
X = m.X
J = size(X, 2)

# Priors
halfcauchy  = truncated(Cauchy(0, 1); lower=0)
τ ~ halfcauchy
λ ~ filldist(halfcauchy, J)
β ~ MvNormal(Diagonal((λ .* τ).^2)) # Coefficients
α ~ TDist(3) # Intercept
σ ~ Exponential(1) # Errors

# Dependent variable
y ~ MvNormal(α .+ X * β, σ^2 * I)

return (; τ, λ, α, β, σ, y)
end

@model function (m::HorseShoePrior)(::Val{:+})
# Independent variable
X = m.X
J = size(X, 2)

# Priors
τ ~ truncated(Cauchy(0, 1/J); lower=0)
η ~ truncated(Cauchy(0, 1); lower=0)
λ ~ filldist(Cauchy(0, 1), J)
β ~ MvNormal(Diagonal(((η * τ) .* λ).^2)) # Coefficients
α ~ TDist(3) # Intercept
σ ~ Exponential(1) # Errors

# Dependent variable
y ~ MvNormal(α .+ X * β, σ^2 * I)

return (; τ, η, λ, β, y)
end

@model function (m::HorseShoePrior)(::Val{:finish}; τ₀=3, ν_local=1, ν_global=1, half_slab_df=2, slab_scale=2)
# Independent variable
X = m.X
J = size(X, 2)

# Priors
β₀ ~ Flat()
logσ ~ Flat()
σ = exp(logσ) # Noise std
z ~ MvNormal(Zeros(J), I)
λ ~ filldist(truncated(TDist(ν_local); lower=0), J) # Local shrinkage
τ ~ (τ₀ * σ) * truncated(TDist(ν_global); lower=0) # Global shrinkage
c_aux ~ InverseGamma(half_slab_df, half_slab_df)

# Dependent variable
c = slab_scale * sqrt(c_aux)
λtilde = λ ./ hypot.(1, (τ / c) .* λ)
β = τ .* z .* λtilde # Regression coefficients
f = β₀ .+ X * β # Latent function values
y ~ MvNormal(f, σ^2 * I) # use the noise std σ, as in the other models

return (; τ, σ, logσ, λ, λtilde, z, c, c_aux, β₀, β, f, y)
end
``````

Let us create some toy data:

``````julia> X = randn(2, 4);
``````

We can create a model as usual:

``````julia> model_original = HorseShoePrior(X)(Val(:original));
``````

The main point is that neither `X` nor `y` is an argument of the model

``````julia> model_original.args
(##arg#405 = Val{:original}(),)

julia> DynamicPPL.inargnames(@varname(X), model_original)
false

julia> DynamicPPL.inargnames(@varname(y), model_original)
false
``````

but `X` is captured by the model function:

``````julia> model_original.f
HorseShoePrior{Matrix{Float64}}([0.9008876947448639 0.9172982651499152 -1.7703753804116238 1.5009865604420913; 0.15786060957092293 0.6001936008984858 0.8759197705885826 0.9490750485288362])
``````

Since `y` is not an argument to the model we can sample from the prior without having to use another model instance with `y=missing`. Instead we can just use

``````julia> priorsample_original = rand(model_original)
(τ = 3.695731988541993, λ = [0.7107394347599458, 1.0038839300698026, 0.026017515446747786, 0.7473783016558637], β = [-1.4881921127500568, 1.7428254134564192, -0.014855283604250714, -4.405199362018812], α = 1.9879125028306674, σ = 0.5943693342886811, y = [-3.9333591000191186, -1.432671943129296])
``````

(BTW, it is also a recent addition that `rand` supports this; it is the official way to sample from the prior.)
We can evaluate `logjoint` and `logprior` probabilities and the `loglikelihood` of the model for the prior samples:

``````julia> logjoint(model_original, priorsample_original)
-12.434955305181779

julia> loglikelihood(model_original, priorsample_original)
0.0

julia> logprior(model_original, priorsample_original)
-12.434955305181779
``````

There are no observations here and hence the `loglikelihood` is zero and the `logjoint` is equal to the `logprior`.

Now let us generate some targets

``````julia> y = randn(2);
``````

We can use these observations in our model by conditioning it on the observations. One can use

``````julia> model_original_y = model_original | (; y);
``````

which is equivalent to

``````julia> model_original_y = DynamicPPL.condition(model_original, (; y));
``````

or

``````julia> model_original_y = DynamicPPL.condition(model_original; y);
``````

The conditioned model will use the provided value for `y` during model execution everywhere it appears in the model. `y` did not become an argument

``````julia> model_original_y.args
(##arg#405 = Val{:original}(),)
``````

but this is done using special context information:

``````julia> model_original_y.context
ConditionContext((y = [-0.15666952029608294, 1.080560882554258],), DynamicPPL.DefaultContext())
``````

Since `y` is fixed now, its value is not returned if we sample from the conditioned model:

``````julia> sample_original = rand(model_original_y)
(τ = 0.7153433622210013, λ = [3.8969988625444807, 108.0150249313839, 0.26170567941384487, 0.09822866602648996], β = [0.7437719510675507, -58.85226099400912, 0.02501194719031626, -0.18210821639807004], α = 0.5063561352542082, σ = 1.3204010967080868)
``````

Of course, in the conditioned model the `loglikelihood` is non-zero and `logjoint` and `logprior` are different:

``````julia> logjoint(model_original_y, sample_original)
-1203.0847753639716

julia> loglikelihood(model_original_y, sample_original)
-1177.2934647548811

julia> logprior(model_original_y, sample_original)
-25.791310609090537
``````
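
The three numbers above are tied together by the usual decomposition `logjoint = logprior + loglikelihood`. A quick check in plain Julia, with the values copied from the REPL output (the variable names `lj`, `ll` and `lp` are just local shorthand):

```julia
# Values copied from the REPL output above.
lj = -1203.0847753639716  # logjoint
ll = -1177.2934647548811  # loglikelihood
lp = -25.791310609090537  # logprior

# The joint log density is the sum of prior and likelihood terms.
@assert lj ≈ ll + lp
```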

And, of course, we can now perform inference as usual, e.g., with `NUTS`:

``````julia> sample(model_original_y, NUTS(), 1_000);
``````

The same also works for

``````julia> model_plus = HorseShoePrior(X)(Val(:+));

julia> model_finish = HorseShoePrior(X)(Val(:finish));
``````

Thanks! I was struggling with how to construct it, the `MvNormal` docstring does not have any examples.

This is really neat! Thanks! I will start using these in my own coding and research!

That makes prior predictive checks really easy. Again, thank you!

I’ve edited and updated all models. Thank you for the tips and suggestions!


Just to notify everyone involved that I’ve added the R2-D2 prior based on `Zhang, Y. D., Naughton, B. P., Bondell, H. D., & Reich, B. J. (2020). Bayesian regression using a prior on the model fit: The R2-D2 shrinkage prior. Journal of the American Statistical Association`.


Great thread and great implementations. I’ll try them out as well.