Custom priors in Soss

On Slack, Jonatan Werpers asks,

How can I use a custom prior with a model? Basically I want to do something like:

my_pdf(θ) = ....
m = @model n begin
    θ ~ my_pdf
    x ~ Bernoulli(θ) |> iid(n)
end

Generally, you’ll need three methods:

  1. Base.rand
  2. Distributions.logpdf
  3. Soss.xform

For example, say we want the logistic transform of a Normal. We could do this entirely in Soss, but for now (due to Distributions.jl overhead) there may be a performance advantage to implementing it directly.

Starting out is pretty easy:

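using Soss
import Distributions
using StatsFuns: logistic, logit   # assumes StatsFuns.jl is available

struct MyPrior end

# θ = logistic(z) with z ~ Normal(0, 1)
Base.rand(::MyPrior) = logistic(randn())

# Change of variables, up to an additive constant:
# log p(θ) = -logit(θ)^2/2 - log(θ) - log(1 - θ)
Distributions.logpdf(::MyPrior, θ) = -logit(θ)^2 / 2 - log(θ) - log1p(-θ)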
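For reference, the log density of θ = logistic(z) with z ~ Normal(0, 1) follows from a change of variables: log p(θ) = -logit(θ)²/2 - log θ - log(1 - θ), plus a constant. Here’s a quick plain-Julia check of that formula (logit is defined locally rather than taken from StatsFuns.jl, so only Base is needed): once the constant is restored, the density should integrate to about 1 over (0, 1).

```julia
# Change of variables for θ = logistic(z), z ~ Normal(0, 1):
#   p(θ) = φ(logit(θ)) / (θ * (1 - θ)), with φ the standard normal density
logit(θ) = log(θ / (1 - θ))

# Unnormalized log density (drops the constant -log(2π)/2)
logdens(θ) = -logit(θ)^2 / 2 - log(θ) - log1p(-θ)

# Normalized density, only needed for this check
dens(θ) = exp(logdens(θ)) / sqrt(2π)

# Midpoint rule on (0, 1); never evaluates the endpoints themselves
function integrate01(f, n)
    h = 1 / n
    return h * sum(f((i - 0.5) * h) for i in 1:n)
end

total = integrate01(dens, 100_000)   # should be very close to 1
```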

Note that the logpdf doesn’t need to be an actual logpdf: it only has to be right up to an additive constant, since normalization is irrelevant for many applications (MCMC included). We’ll soon be working in terms of measures instead, to avoid the extra computation when possible.
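One way to see why normalization doesn’t matter (and why MCMC only cares about the “shape” of the density) is a random-walk Metropolis sampler, used here as a simpler stand-in for the HMC that Soss actually runs. This is a minimal plain-Julia sketch; the step size, seed, and chain length are arbitrary. Only the difference of log densities enters the accept step, so shifting the log density by a constant should leave the chain unchanged.

```julia
using Random

# Unnormalized log density of θ = logistic(z), z ~ Normal(0, 1)
logit(θ) = log(θ / (1 - θ))
logdens(θ) = -logit(θ)^2 / 2 - log(θ) - log1p(-θ)

# Random-walk Metropolis on (0, 1). Only logp(θnew) - logp(θ) enters the
# accept step, so adding a constant to logp does not change the chain.
function metropolis(logp, θ0, n; step = 0.1, rng = Random.default_rng())
    chain = Vector{Float64}(undef, n)
    θ, lp = θ0, logp(θ0)
    for i in 1:n
        θnew = θ + step * randn(rng)
        # Outside (0, 1) the density is zero, so the proposal is rejected
        if 0 < θnew < 1
            lpnew = logp(θnew)
            if log(rand(rng)) < lpnew - lp
                θ, lp = θnew, lpnew
            end
        end
        chain[i] = θ
    end
    return chain
end

# Same seed, log density shifted by 100: the chains should match
chain1 = metropolis(logdens, 0.5, 5_000; rng = MersenneTwister(42))
chain2 = metropolis(θ -> logdens(θ) + 100.0, 0.5, 5_000; rng = MersenneTwister(42))
```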

Soss.xform needs to return a transform in the sense of @Tamas_Papp’s TransformVariables.jl. For univariate cases this is especially simple:

Soss.xform(::MyPrior, _data=NamedTuple()) = as𝕀

The optional _data argument is just there to make the dispatch work out, and it’s almost always fine to leave it exactly like this.
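Roughly speaking, the sampler never works with θ directly: it moves an unconstrained x ∈ ℝ, and the transform maps it into (0, 1). Below is a minimal plain-Julia sketch of that bookkeeping, not TransformVariables.jl’s actual implementation: map x through logistic and add the log-Jacobian log θ + log(1 - θ). For this particular prior the extra terms cancel and the sampler sees -x²/2, a standard normal.

```julia
logistic(x) = 1 / (1 + exp(-x))
logit(θ) = log(θ / (1 - θ))

# Unnormalized log density of the prior on the constrained scale θ ∈ (0, 1)
logdens(θ) = -logit(θ)^2 / 2 - log(θ) - log1p(-θ)

# Log density on the unconstrained scale x ∈ ℝ, with θ = logistic(x).
# Since dθ/dx = θ(1 - θ), the log-Jacobian is log(θ) + log(1 - θ).
function logdens_unconstrained(x)
    θ = logistic(x)
    return logdens(θ) + log(θ) + log1p(-θ)
end

# For this prior the Jacobian cancels the extra terms, leaving -x^2/2
vals = [logdens_unconstrained(x) for x in (-2.0, -0.7, 1.3)]
```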

Then we can do, e.g.,

julia> post = dynamicHMC(m(n=100), (x=x,));

julia> particles(post)
(θ = 0.245 ± 0.043,)

julia> m = @model n begin
           θ ~ MyPrior()
           x ~ Bernoulli(θ) |> iid(n)
       end;

julia> x = rand(100) .< 0.3;

julia> post = dynamicHMC(m(n=100), (x=x,));

julia> particles(post)
(θ = 0.314 ± 0.043,)
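To sanity-check numbers like these, a brute-force grid approximation of the same posterior is easy for a single parameter in (0, 1). This is a swapped-in technique (not what dynamicHMC does), sketched in plain Julia; the seed and grid size are arbitrary, and the data are simulated the same way as above.

```julia
using Random

# Unnormalized log prior: θ = logistic(z), z ~ Normal(0, 1)
logit(θ) = log(θ / (1 - θ))
logprior(θ) = -logit(θ)^2 / 2 - log(θ) - log1p(-θ)

# Bernoulli log likelihood for a vector of Bools
loglik(θ, x) = count(x) * log(θ) + count(!, x) * log1p(-θ)

# Grid approximation of the posterior mean and standard deviation
function grid_posterior(x; n = 10_000)
    θs = [(i - 0.5) / n for i in 1:n]       # midpoints of a uniform grid on (0, 1)
    lps = [logprior(θ) + loglik(θ, x) for θ in θs]
    w = exp.(lps .- maximum(lps))           # subtract the max to avoid underflow
    w ./= sum(w)
    μ = sum(w .* θs)
    σ = sqrt(sum(w .* (θs .- μ) .^ 2))
    return (mean = μ, sd = σ)
end

rng = MersenneTwister(1)
x = rand(rng, 100) .< 0.3
post = grid_posterior(x)
```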

I would recommend implementing Base.rand as follows:

import Random

Base.rand(rng::Random.AbstractRNG, ::MyPrior) = logistic(randn(rng))

Base.rand(prior::MyPrior) = Base.rand(Random.GLOBAL_RNG, prior)
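A self-contained sketch of this pattern, with logistic defined locally so it runs with only the Random standard library. It uses Random.default_rng() for the convenience method; Random.GLOBAL_RNG works too. The payoff is that a seeded generator makes draws reproducible.

```julia
using Random

logistic(x) = 1 / (1 + exp(-x))

struct MyPrior end

# The method taking an explicit generator does the real work...
Base.rand(rng::Random.AbstractRNG, ::MyPrior) = logistic(randn(rng))

# ...and the convenience method just forwards to the default generator
Base.rand(prior::MyPrior) = rand(Random.default_rng(), prior)

# A freshly seeded generator reproduces the same draws
function draws(seed, n)
    rng = MersenneTwister(seed)
    return [rand(rng, MyPrior()) for _ in 1:n]
end

draws(123, 3) == draws(123, 3)   # true
```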

Good point @dilumaluthge, that makes it much more flexible.


Thank you for your detailed answer!

After writing on Slack, I managed to get something working through a lot of trial and error. It’s not what you suggested, but I think it’s working correctly.

What I came up with was the following:

using Distributions, Random
struct MyPrior <: Distribution{Univariate,Continuous} end
Base.minimum(::MyPrior) = 0
Base.maximum(::MyPrior) = 1
Base.rand(rng::AbstractRNG, ::MyPrior) = rejection_sampling(p,Uniform(0,1),3.5)
Distributions.pdf(::MyPrior, π) = exp(-1 + λ - b*(1-π)^3 - c*(1-π)^12)
Distributions.logpdf(::MyPrior, π::Real) = -1 + λ - b*(1-π)^3 - c*(1-π)^12
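The rejection_sampling helper and the constants aren’t shown here. As a hypothetical sketch of one way such a helper could look while also using the rng argument: if b, c > 0, the density is largest at π = 1, so exp(λ - 1) bounds it and a Uniform(0, 1) proposal works as the envelope. The values of λ, b, c below are made up purely for illustration, and rejection_sample and mylogpdf are illustrative names.

```julia
using Random

# Hypothetical stand-ins for the elided constants (assumes b, c > 0)
const λ = 1.0
const b = 2.0
const c = 5.0

# Log density from the post, with p in place of π to avoid shadowing Base.π
mylogpdf(p) = -1 + λ - b * (1 - p)^3 - c * (1 - p)^12

# Rejection sampling with a Uniform(0, 1) proposal (density 1 on (0, 1)).
# Requires logM >= logf(p) for all p; accepts p with probability exp(logf(p) - logM).
function rejection_sample(rng::AbstractRNG, logf, logM)
    while true
        p = rand(rng)
        if log(rand(rng)) < logf(p) - logM
            return p
        end
    end
end

# With b, c > 0 the density peaks at p = 1, so mylogpdf(1.0) == λ - 1 bounds it
rng = MersenneTwister(7)
xs = [rejection_sample(rng, mylogpdf, mylogpdf(1.0)) for _ in 1:10_000]
```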

Since I’m pretty new to Bayesian statistics in general, and completely new to PPLs, I’d be very interested in hearing about any upsides or downsides of what I did compared to your solution, and whether it’s even correct.

From my limited knowledge of how MCMC and similar methods work, I was a little surprised that I needed to implement rand(). I thought the sampling methods only used the “shape” of the pdf when doing inference.

Also, what is the as𝕀 value returned by this function?

Again, thank you very much!

I guess you’re abbreviating here, right? Does MyPrior have λ, b, and c as parameters?

Otherwise, this looks good to me. You won’t need pdf for MCMC, but it can be nice to have around otherwise. Rather than copying the formula, I’d probably write it as

Distributions.pdf(dist::MyPrior, π) = exp(logpdf(dist, π))

There’s no overhead, and any change to logpdf will propagate automatically instead of needing to be typed twice.
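On the question of whether λ, b, and c are parameters: if they were, one option is to store them as fields so that a single method holds the formula. A plain-Julia sketch, where MyPriorParams, logdensity, and density are hypothetical names standing in for a Distributions.jl type with logpdf and pdf methods:

```julia
# Hypothetical parametrized prior: λ, b, c stored as fields
struct MyPriorParams{T<:Real}
    λ::T
    b::T
    c::T
end

# The formula lives in exactly one place...
logdensity(d::MyPriorParams, p::Real) = -1 + d.λ - d.b * (1 - p)^3 - d.c * (1 - p)^12

# ...and the density is derived from it, so edits propagate automatically
density(d::MyPriorParams, p::Real) = exp(logdensity(d, p))

d = MyPriorParams(1.0, 2.0, 5.0)
density(d, 1.0)   # exp(λ - 1), which is 1.0 for these values
```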

As for rand: needing it is just a hack, because Distributions.jl has no way of representing the type of the support. For example,

julia> d = MvNormal(ones(3))
dim: 3
μ: 3-element FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}} = 0.0
Σ: [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0]

julia> typeof(d)

We need to know that rand(d) returns a Float64 vector of length 3, but there’s not a 3 to be found in the type! So for now I just draw one sample with rand and use that to work out the type and shape. This will get better with MeasureTheory.jl, which we’ll be switching to when it’s mature enough.
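A minimal sketch of the “draw once and look” workaround, in plain Julia with hypothetical sampler types (IIDNormal, LocScale, and shape_of are made up for illustration). One draw recovers the value’s type and, for arrays, its size, and the same approach works for values that aren’t arrays at all:

```julia
using Random

# Two hypothetical samplers: one returns a vector, one a NamedTuple
struct IIDNormal
    n::Int
end
Base.rand(rng::AbstractRNG, d::IIDNormal) = randn(rng, d.n)

struct LocScale end
Base.rand(rng::AbstractRNG, ::LocScale) = (μ = randn(rng), σ = exp(randn(rng)))

# Draw once and report the value's type and, for arrays, its size
function shape_of(rng::AbstractRNG, d)
    x = rand(rng, d)
    return (T = typeof(x), dims = x isa AbstractArray ? size(x) : nothing)
end

rng = MersenneTwister(0)
shape_of(rng, IIDNormal(3))   # T is Vector{Float64}, dims is (3,)
shape_of(rng, LocScale())     # dims is nothing: no array shape to report
```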

as𝕀 is from TransformVariables.jl. The sampler needs a density over unconstrained ℝⁿ, and as𝕀 maps ℝ onto the unit interval (0, 1). Each such transform goes in as one component, and all the pieces compose to get the final result. Like this:

julia> using TransformVariables

julia> f = as𝕀

julia> f(-30.0)

julia> f(20.0)

julia> inverse(f)(0.2)

julia> inverse(f)(0.99)
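To make “all the pieces compose” concrete, here is a hand-rolled sketch of what a composed transform does for two parameters, θ ∈ (0, 1) and σ > 0. It’s plain Julia, not TransformVariables.jl’s implementation (which, as far as I know, would express this as as((θ = as𝕀, σ = asℝ₊))); the function names are illustrative. Each piece consumes one coordinate of a vector in ℝ², and the log-Jacobians add.

```julia
logistic(x) = 1 / (1 + exp(-x))

# Each piece maps one unconstrained coordinate to (value, log |Jacobian|)
function to_unit_interval(y)          # plays the role of as𝕀
    θ = logistic(y)
    return θ, log(θ) + log1p(-θ)      # dθ/dy = θ(1 - θ)
end

to_positive(y) = (exp(y), y)          # plays the role of asℝ₊; d exp(y)/dy = exp(y)

# Composition: each piece consumes one coordinate, log-Jacobians add up
function transform_with_logjac(y::AbstractVector{<:Real})
    θ, lj1 = to_unit_interval(y[1])
    σ, lj2 = to_positive(y[2])
    return (θ = θ, σ = σ), lj1 + lj2
end

# Inverse map back to ℝ²: logit for θ, log for σ
inverse_transform(v) = [log(v.θ / (1 - v.θ)), log(v.σ)]

v, lj = transform_with_logjac([0.3, -1.2])
inverse_transform(v)   # ≈ [0.3, -1.2]
```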

Yes, I’m abbreviating; λ, b, and c are all constants.

The rand thing does seem like a bit of a hack. Could you solve it by using the length function?

julia> d = MvNormal(ones(3))
dim: 3
μ: 3-element FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}} = 0.0
Σ: [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0]

julia> length(d)
3

Or perhaps you need it to be more “static” in some sense?

Thanks again, I’ll be keeping an eye on these packages to learn more!

Yep, it’s entirely a hack.

No, because there are plenty of values that aren’t arrays, and length doesn’t capture their structure.

Yes! That’s the plan for MeasureTheory.jl. Similar to how Gen goes about it, we’ll have a type parameter that specifies the type of the return value. For “built-in” measures, we’ll probably use SizedArrays so we can get the shapes right without the extra overhead and annoyance.
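A toy sketch of that idea, in plain Julia and not MeasureTheory.jl’s actual API: a type parameter records what rand returns, so the shape can be read off the type without drawing. AbstractPrior, sampletype, and the two concrete types are hypothetical, and NTuple stands in for a SizedArray to avoid depending on StaticArrays.jl.

```julia
using Random

# Hypothetical: the type parameter T records the type that rand returns
abstract type AbstractPrior{T} end
sampletype(::AbstractPrior{T}) where {T} = T

# N iid standard normals as an NTuple, so the length lives in the type
struct IIDStdNormal{N} <: AbstractPrior{NTuple{N,Float64}} end
Base.rand(rng::AbstractRNG, ::IIDStdNormal{N}) where {N} = ntuple(_ -> randn(rng), Val(N))

# The logistic-Normal prior from this thread, returning a Float64
struct LogisticNormal <: AbstractPrior{Float64} end
Base.rand(rng::AbstractRNG, ::LogisticNormal) = 1 / (1 + exp(-randn(rng)))

# Shapes are known from the types alone, before any draw
sampletype(IIDStdNormal{3}())   # == NTuple{3, Float64}
sampletype(LogisticNormal())    # == Float64
```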
