Custom Models For Bayesian Inference

Hello all. I’m hoping someone can help me navigate the web of probabilistic programming packages so that I can begin using the Julia ecosystem for my custom models. I am experienced with Bayesian methods and probabilistic programming, to the point of writing my own samplers 10 years ago, but I’m new to the Julia ecosystem for this.

I am trying to figure out a way to use some combination of Turing, AdvancedHMC, AbstractMCMC, etc. to build and sample from custom models. I work with complicated and somewhat expensive models, so I like to be able to build and optimize my own sum-log-likelihood functions.

Attempt 1: I tried to figure out how to do this in Turing, to no avail. The best I have found is to construct a custom distribution and attach logpdf methods to it. This is limiting, though, and in some cases I don’t want to have to work through the Distributions.jl package and so forth.

Attempt 2: I’ve tried to use AdvancedHMC with my own LogDensityModel / LogDensityProblems target. I have made some progress on this, though there are few examples in this space. The problem I run into is that the output of these samplers does not produce an appropriate Chains data structure (as far as I can tell), so the plotting and analysis you can perform on the output of Turing’s sample are not directly accessible.

While my models are complicated, my process is not. I just want to be able to construct my own custom models via a log-density function (i.e. both the sum-log-likelihood and the log-priors) and sample from it using HMC/NUTS. There has to be a way to do this, but I can’t find how Turing, AdvancedHMC, AbstractMCMC, MCMCChains, StatsPlots, etc. tie together.

Anyone have suggestions?

I am following this because I’d like to write some custom samplers (to do some research on sampling methodology), and if I understand how these things plug together I will have a much easier time doing it…

I don’t find that limiting. You get pretty much all the Turing-verse stuff for free if you define two methods, logpdf and rand, for your custom distribution.
Imagine doing that in Stan, PyMC, JAX, etc. It is a bunch more work.
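For example, something along these lines is usually all it takes (a minimal sketch; the NormalPrecision distribution and the demo model are made up here purely to show the two methods):

using Distributions, Random, Turing

# Hypothetical custom distribution: a Normal parametrized by mean and precision,
# defined only through logpdf and rand.
struct NormalPrecision{T<:Real} <: ContinuousUnivariateDistribution
    μ::T
    τ::T   # precision = 1 / variance
end
NormalPrecision(μ::Real, τ::Real) = NormalPrecision(promote(μ, τ)...)

Distributions.logpdf(d::NormalPrecision, x::Real) =
    (log(d.τ) - log(2π)) / 2 - d.τ * (x - d.μ)^2 / 2

Distributions.rand(rng::Random.AbstractRNG, d::NormalPrecision) =
    d.μ + randn(rng) / sqrt(d.τ)

@model function demo(y)
    μ ~ Normal(0, 5)
    τ ~ Gamma(2, 1)
    for i in eachindex(y)
        y[i] ~ NormalPrecision(μ, τ)   # the custom distribution, used like any other
    end
end

chain = sample(demo(randn(200)), NUTS(), 1_000)   # a regular MCMCChains.Chains object

Note that the custom distribution only appears on the data (observed) side here, so only logpdf is exercised during sampling; rand is what gives you prior and predictive simulation.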

However, even without that, you can pretty much take a Matrix of MCMC draws, where each column is a parameter and each row is an iteration, easily construct an MCMCChains.Chains object from it, and get diagnostics like ESS and R̂ for free:

julia> using MCMCChains

julia> draws = randn(50, 4);

julia> Chains(draws)
Chains MCMC chain (50×4×1 Array{Float64, 3}):

Iterations        = 1:1:50
Number of chains  = 1
Samples per chain = 50
parameters        = param_1, param_2, param_3, param_4

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Missing

     param_1   -0.0870    1.0446    0.1470    51.0843    58.4202    0.9838       missing
     param_2   -0.1164    1.0437    0.1306    66.7452    62.5367    1.0085       missing
     param_3   -0.0008    1.0607    0.1618    41.3179    43.7731    1.0396       missing
     param_4   -0.3235    0.9880    0.1343    53.6033    35.9954    1.0229       missing

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

     param_1   -2.0622   -0.5363   -0.0219    0.5629    1.5960
     param_2   -1.5460   -0.6366   -0.1992    0.4018    2.0178
     param_3   -2.1301   -0.7801    0.0772    0.7131    1.8766
     param_4   -1.8707   -1.0969   -0.2682    0.3361    1.2791

You can even name the parameters:

julia> Chains(draws, [:x, :y, :z, :w])
Chains MCMC chain (50×4×1 Array{Float64, 3}):

Iterations        = 1:1:50
Number of chains  = 1
Samples per chain = 50
parameters        = x, y, z, w

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Missing

           x   -0.0870    1.0446    0.1470    51.0843    58.4202    0.9838       missing
           y   -0.1164    1.0437    0.1306    66.7452    62.5367    1.0085       missing
           z   -0.0008    1.0607    0.1618    41.3179    43.7731    1.0396       missing
           w   -0.3235    0.9880    0.1343    53.6033    35.9954    1.0229       missing

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           x   -2.0622   -0.5363   -0.0219    0.5629    1.5960
           y   -1.5460   -0.6366   -0.1992    0.4018    2.0178
           z   -2.1301   -0.7801    0.0772    0.7131    1.8766
           w   -1.8707   -1.0969   -0.2682    0.3361    1.2791

You can also stack multi-threaded MCMC samples from several chains into a three-dimensional array that MCMCChains.jl will recognize:

julia> draws = randn(50, 4, 4); # 50 samples 4 params 4 chains

julia> Chains(draws, [:x, :y, :z, :w])
Chains MCMC chain (50×4×4 Array{Float64, 3}):

Iterations        = 1:1:50
Number of chains  = 4
Samples per chain = 50
parameters        = x, y, z, w

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Missing

           x   -0.0797    0.9477    0.0730   171.3384   194.1114    1.0091       missing
           y   -0.0304    0.9866    0.0669   218.1038   191.4629    1.0054       missing
           z   -0.0119    1.0029    0.0730   200.9825   236.7906    1.0195       missing
           w   -0.0526    1.0084    0.0649   245.2925   180.1219    1.0012       missing

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           x   -1.8766   -0.7210   -0.1164    0.6181    1.4645
           y   -1.7218   -0.6753   -0.0991    0.5224    2.2803
           z   -1.9002   -0.7389   -0.0070    0.7286    1.8079
           w   -1.8965   -0.7332   -0.0311    0.7306    1.9235
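And once you have a Chains object, the StatsPlots recipes apply to it directly, e.g. (a tiny sketch; chn is just the Chains object from above bound to a name):

using StatsPlots

chn = Chains(draws, [:x, :y, :z, :w])
plot(chn)   # trace and density plots for every parameter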

Do these examples help you in your endeavor?
Let me know.

Thanks for your suggestions @Storopoli. This does help in terms of using AdvancedHMC outside of the Turing landscape, though I still have a ways to go. Regarding your point about defining custom distributions through Distributions.jl to get full use of the Turing-verse: I would love to do this, but I am not sure exactly which methods Turing calls to evaluate sum-log-likelihoods, and that is what I need to overload. Consider the following example (it is a toy, but it illustrates the point).

This would be a standard way to define a mixture model in Distributions.

using Random
using Distributions

# Generate a standard Mixture using Distributions.jl

μ1 , μ2 , σ1 , σ2 = 0.5 , -0.5 , 1.0 , 0.4;
weight1 = 0.3;

N1 , N2 = Normal(μ1,σ1) , Normal(μ2,σ2);

True_dist = MixtureModel([N1 , N2] , [weight1 , 1.0-weight1]);
test_data = rand(True_dist , 1_000);

Now suppose I want to construct this model, but avoid using the Distributions.MixtureModel constructor (it is allocation heavy). Also, suppose I want my log-likelihood to batch evaluate on a set of iid data and return the sum-log-likelihood. I would start with the following.

# Create my own mixture
struct my_mixture{T<:Real} <: ContinuousUnivariateDistribution
    μ1::T
    μ2::T
    σ1::T
    σ2::T
    weight::T
end

# Assign pdf, logpdf, rand, etc.
import Distributions.pdf
function pdf(model::my_mixture , x::Real)
    μ1 , μ2 , σ1 , σ2 , weight = model.μ1 , model.μ2 , model.σ1 , model.σ2 , model.weight
    dist1 , dist2 = Normal(μ1 , σ1) , Normal(μ2 , σ2)
    pdf_val = weight * pdf(dist1,x) + (1.0 - weight) * pdf(dist2,x)
    return pdf_val
end

# Construct a function that performs custom batch evaluation for 
# sum-log-likelihood.
function fast_SLL(model::my_mixture , data::Vector{<:Real})

    μ1 , μ2 , σ1 , σ2 , weight = model.μ1 , model.μ2 , model.σ1 , model.σ2 , model.weight
    dist1 , dist2 = Normal(μ1 , σ1) , Normal(μ2 , σ2)
    SLL = sum(log.( weight .* pdf.(dist1,data) .+ (1.0 - weight) .* pdf.(dist2,data)))

    return SLL
end

Since this is a univariate distribution, all of the various methods (pdf, logpdf, etc.) need to be defined for scalar inputs. But for inference, batch evaluation is required, and I would like to be able to perform that myself without resorting to something like logpdf.(model, data).

The question is: how do I make Turing.jl call fast_SLL (which is defined on vectors of data) when it is performing inference, rather than something like pdf or logpdf (which are defined on scalars)?

I can appropriately wrap all of this and use it with AdvancedHMC (roughly as sketched below), but if Turing is flexible enough I would prefer that route.
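For reference, the wrapping looks something like the following. This is only a sketch that follows the LogDensityProblems/AdvancedHMC pattern; the MixtureTarget name, the flat priors, and the unconstrained parametrization (log scales, logit weight) are all just illustrative choices.

using LogDensityProblems, AdvancedHMC, ForwardDiff, MCMCChains

# Hypothetical target: θ = (μ1, μ2, log σ1, log σ2, logit weight), so the
# unconstrained space is all of R^5; priors are left flat to keep things short.
struct MixtureTarget{T<:Real}
    data::Vector{T}
end

function LogDensityProblems.logdensity(p::MixtureTarget, θ)
    μ1, μ2 = θ[1], θ[2]
    σ1, σ2 = exp(θ[3]), exp(θ[4])
    weight = 1 / (1 + exp(-θ[5]))
    return fast_SLL(my_mixture(μ1, μ2, σ1, σ2, weight), p.data)
end
LogDensityProblems.dimension(::MixtureTarget) = 5
LogDensityProblems.capabilities(::Type{<:MixtureTarget}) =
    LogDensityProblems.LogDensityOrder{0}()

target = MixtureTarget(test_data)
D = LogDensityProblems.dimension(target)
metric = DiagEuclideanMetric(D)
ham = Hamiltonian(metric, target, ForwardDiff)
θ0 = randn(D)
integrator = Leapfrog(find_good_stepsize(ham, θ0))
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
samples, stats = sample(ham, kernel, θ0, 2_000, adaptor, 1_000)

# Stack the vector-of-vectors into a draws × parameters matrix for MCMCChains
chain = Chains(permutedims(reduce(hcat, samples)),
               [:μ1, :μ2, :logσ1, :logσ2, :logit_weight])

The last two lines are the piece I was missing in Attempt 2: once the draws are in a matrix (or a draws × parameters × chains array), the Chains object gives the usual diagnostics and plots.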

Any advice is greatly appreciated.

Why not define the pdf/logpdf methods over a Vector{<:Real}, like the fast_SLL function?
Then you define your custom distribution as <: ContinuousMultivariateDistribution (Multivariate Distributions · Distributions.jl) and run the Turing model like this:

@model function my_model_multivariate(data)
    # ...
    data ~ my_mixture(...) # data is a vector of values
end

instead of the univariate:

@model function my_model_univariate(data)
    # ...
    for i in 1:length(data)
        data[i] ~ my_mixture(...) # data[i] is a scalar
    end
end
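To make that concrete, the multivariate wrapper could look something like this (the my_mixture_mv name and the priors are placeholders; the essential part is a logpdf method over a vector that simply delegates to your fast_SLL):

struct my_mixture_mv{T<:Real} <: ContinuousMultivariateDistribution
    μ1::T
    μ2::T
    σ1::T
    σ2::T
    weight::T
    n::Int   # length of the observed data vector
end

Base.length(d::my_mixture_mv) = d.n

# Joint log-density of the whole iid data vector: delegate to fast_SLL
function Distributions.logpdf(d::my_mixture_mv, x::AbstractVector{<:Real})
    return fast_SLL(my_mixture(d.μ1, d.μ2, d.σ1, d.σ2, d.weight), collect(x))
end

@model function my_model_multivariate(data)
    μ1 ~ Normal(0, 2)
    μ2 ~ Normal(0, 2)
    σ1 ~ truncated(Normal(0, 1); lower=0)
    σ2 ~ truncated(Normal(0, 1); lower=0)
    weight ~ Beta(2, 2)
    data ~ my_mixture_mv(μ1, μ2, σ1, σ2, weight, length(data))
end

Since data is observed, only logpdf is called on my_mixture_mv during inference, so the batched fast_SLL is what ends up being evaluated; rand/_rand! would only be needed if you also wanted to simulate data from it.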