Rand/logpdf semantic consistency

There’s a subtle issue I’ve mentioned before involving rand and logpdf that has gotten in the way of composability. I’m thinking in terms of Soss and MeasureTheory.jl, but I thought I’d post here since it may be helpful to others.

First, the problem. There’s an implicit law in Distributions.jl that for any distribution d,

logpdf(d, rand(d))

should be valid.

Now, suppose we have a Soss model

m = @model N begin
    μ ~ Normal()
    σ ~ HalfNormal()
    x ~ Normal(μ,σ) |> iid(N)
    return x
end

So x = rand(m(N=10)) would just give us a vector of 10 values, and evaluating logpdf(m(N=10), x) would require integration (by MCMC or other methods).
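To make the "would require integration" point concrete, here is a minimal sketch of a plain Monte Carlo estimate of the marginal log density of x, averaging the likelihood over prior draws of μ and σ. It uses only Base and Random; `normal_logpdf`, `logmeanexp`, and `marginal_logpdf_mc` are hypothetical helpers written for this illustration, not Soss or Distributions.jl functions.

```julia
using Random

# Log density of Normal(μ, σ) at x, written out so the sketch needs no packages.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Numerically stable log(mean(exp.(v))).
function logmeanexp(v)
    m = maximum(v)
    return m + log(sum(exp.(v .- m)) / length(v))
end

# Estimate log p(x) by averaging p(x | μ, σ) over draws of (μ, σ) from the prior.
function marginal_logpdf_mc(rng, x; ndraws = 10_000)
    logliks = map(1:ndraws) do _
        μ = randn(rng)        # μ ~ Normal()
        σ = abs(randn(rng))   # σ ~ HalfNormal()
        sum(xi -> normal_logpdf(xi, μ, σ), x)
    end
    return logmeanexp(logliks)
end

rng = MersenneTwister(1)
x = randn(rng, 10)                 # stand-in for a vector returned by rand(m(N=10))
est = marginal_logpdf_mc(rng, x)   # noisy estimate of the marginal log density
```

Even for this tiny model the estimate is noisy and costs thousands of likelihood evaluations, which is the price of having thrown away μ and σ.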

This is sometimes what we want, but often it’s not. The problem is that the internal state is discarded.

In other contexts (StatsModels, Turing, etc) we have a sample that’s similar in some ways to rand, but might have some additional context included. Also, in MeasureTheory.jl I’ve been working in terms of logdensity instead of logpdf, because many densities won’t be probability densities.

So currently I’m thinking we could keep logpdf(d, rand(d)) as is and add a similar law for logdensity(d, sample(d)), where the return value of sample includes that of rand in addition to a representation of the internal state. For example, here it could be a named tuple over (μ,σ,x). Then for PPL, we can work in terms of sample except for the special case where we want to intentionally discard the internal state.
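Here is a rough sketch of what that could look like for the model above, assuming the internal state is returned as a named tuple over (μ, σ, x). The names `sample_model`, `rand_model`, and `joint_logdensity` are hypothetical stand-ins for `sample`, `rand`, and `logdensity`, and the densities are written by hand so nothing beyond Random is needed.

```julia
using Random

normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# HalfNormal() with unit scale: twice the standard normal density on σ ≥ 0.
halfnormal_logpdf(σ) = σ < 0 ? -Inf : log(2) + normal_logpdf(σ, 0.0, 1.0)

# Hypothetical stand-in for sample(m(N = n)): keep everything that was drawn.
function sample_model(rng, n)
    μ = randn(rng)
    σ = abs(randn(rng))
    x = μ .+ σ .* randn(rng, n)
    return (μ = μ, σ = σ, x = x)
end

# Hypothetical stand-in for rand(m(N = n)): same draws, internal state discarded.
rand_model(rng, n) = sample_model(rng, n).x

# Joint log density over (μ, σ, x): a sum of three terms, no integration needed.
function joint_logdensity(s)
    return normal_logpdf(s.μ, 0.0, 1.0) + halfnormal_logpdf(s.σ) +
           sum(xi -> normal_logpdf(xi, s.μ, s.σ), s.x)
end

s = sample_model(MersenneTwister(1), 10)
lp = joint_logdensity(s)
```

With the state kept, evaluating the density at a sample is a closed-form sum, in contrast to the marginal density of x alone.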

Any thoughts/concerns?


So in the case of logdensity(d, sample(d)), sample(d) would be like a chain of MCMC samples or maybe some kind of iterator? I like it!

Though I would kind of also expect a method like logdensity(d, sample(d), x) where x is some overloaded value – for example if you wanted the marginal density over m=2 or whatever. Not sure if that makes any kind of sense, but maybe it’s a useful API function all the PPLs could talk to for posterior predictive checks.

The implicit assumption here is that Distributions.jl is a collection of distributions for which rand makes sense (being O(1)).

Generally this is not true for probability distributions — which motivates techniques like MCMC and variational inference. Mathematically, it makes sense to treat posteriors as distributions, but algorithmically, they are a totally different beast from Normal(0, 1).

Not an iterator in general. Ok, here’s a silly example. Say we have a model like

m = @model n begin
    μ ~ Normal()
    x ~ Normal(μ,1) |> iid(n)
end

but then we want some more flexibility in the prior, so we might change it to

m = @model dist,n begin
    μ ~ dist()
    x ~ Normal(μ,1) |> iid(n)
end

Now maybe dist could itself be a model, like

dist = @model begin
    u ~ Uniform()
    x ~ Normal()
    return x/u
end

[This is a slash distribution. We could work directly with the pdf, but let’s pretend we don’t know that]

Now the problem is that rand just returns a Float64, e.g.,

julia> rand(dist())
-6.5117466025256405

And this forgetfulness propagates to m,

julia> rand(m(dist=dist, n=3))
(μ = 2.4853748267304936, x = [2.0221493956435337, 1.5847322295942938, 3.2714506239967456])

Now even something really simple like likelihood weighting would require an additional integration step, because we’ve lost information about μ.
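For reference, here is a minimal sketch of likelihood weighting when the draws do carry their state. It assumes a hypothetical `sample_slash` that returns the value together with the internal (x, u), and it weights each prior draw of μ by the likelihood of some made-up observations `xobs`. None of these names come from Soss.

```julia
using Random

normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Hypothetical forward sampler for the slash prior that keeps its internal state.
function sample_slash(rng)
    u = rand(rng)    # u ~ Uniform()
    x = randn(rng)   # x ~ Normal()
    return (value = x / u, state = (x = x, u = u))
end

# Likelihood weighting: draw μ from the prior, weight each draw by p(xobs | μ).
function likelihood_weights(rng, xobs; ndraws = 1_000)
    draws = [sample_slash(rng) for _ in 1:ndraws]
    logw = [sum(xi -> normal_logpdf(xi, d.value, 1.0), xobs) for d in draws]
    w = exp.(logw .- maximum(logw))   # subtract the max for numerical stability
    return draws, w ./ sum(w)         # normalized weights
end

xobs = [2.0, 1.6, 3.3]                # made-up observations
draws, w = likelihood_weights(MersenneTwister(1), xobs)
post_mean_μ = sum(w .* [d.value for d in draws])
```

Because each weighted draw still has its `state`, the same weights can be reused to summarize x and u, not just μ.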

So what I’m proposing here is that sample with one argument should behave like rand, except that it should also expose the internal state.

So (to start) maybe something like

julia> sample(dist())
(value = 0.029393215580522436, state = (x = 0.008020548138383859, u = 0.2728707281587359))

The particular data type here isn’t so important; the main idea is that sample carries some information beyond what rand provides. And we probably want an additional law that, whenever both sides are defined,

rand(d) == value(sample(d))

where value can extract the return value, discarding the internal state.
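Since both sides are random, one way to read this law is "equal when driven by the same RNG state". A small sketch under that assumption, with hypothetical `rand_slash` and `sample_slash` written independently of each other:

```julia
using Random

# Forward sampling that discards the internal state (what rand would do).
function rand_slash(rng)
    u = rand(rng)
    x = randn(rng)
    return x / u
end

# Forward sampling that keeps the internal state (what sample would do).
# It draws u and x in the same order as rand_slash.
function sample_slash(rng)
    u = rand(rng)
    x = randn(rng)
    return (value = x / u, state = (x = x, u = u))
end

# Extract the return value, discarding the state.
value(s) = s.value

# Two RNGs with the same seed produce the same stream of draws.
a = rand_slash(MersenneTwister(42))
s = sample_slash(MersenneTwister(42))
b = value(s)
```

Here `a == b` holds only because both functions consume the RNG in the same order; an implementation that drew x before u would satisfy the law in distribution but not draw for draw.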

To @Tamas_Papp’s point,

We’ve discussed this before, and in principle I agree. But to pin things down a bit more, I’d argue it’s not O(1) that’s special: for example, for a d::MvNormal, rand(d) would be O(size(d)). And there’s no reason to disallow a distribution over iterators, in which case the cost of rand would depend on how many elements you evaluate.

So maybe the rule here is that what’s evaluated by rand is determined by the distribution alone, while the algorithm used by sample often depends on its arguments.

How it scales with the dimensions is not really relevant here — what matters is that you have an algorithm for IID draws.

Sure, but as collections, they would be IID with O(1) cost, for which (again) you have an algorithm trivially from the IID draws.

I think that with the current state of algorithms (i.e., until someone discovers a general way of getting IID samples from a wider class of distributions), unifying the two domains is not possible or helpful in practice.

I agree with the IID assumption for rand, but I don’t see how that’s relevant to this discussion. I’m not suggesting changing rand in any way. Maybe I’m missing your point, and some more detail could help?

My understanding is that you somehow want to introduce an abstraction for steps of MCMC & friends, but it is not clear how that would work, or even if that is possible.

We do this in AbstractMCMC.


Thanks @Tamas_Papp, I think my original post wasn’t so clear. I’ll try again :)

Say we have a model like

slash = @model begin
    u ~ Uniform()
    x ~ Normal()
    return x/u
end

Let’s set aside inference, and focus on forward sampling. What does it mean to “run the model”?

The most obvious behavior we’d want is that of the function

using Distributions

function runslash()
    u = rand(Uniform())
    x = rand(Normal())
    return x/u
end

If we want to generate fake data, this is exactly what we need. It also matches the usual sense of a return value. It’s a relatively simple thing, and rand should also be a relatively simple thing. So I think it follows the “principle of least surprise” to have rand(slash()) do the same thing as runslash().

But sometimes we use forward sampling as a component in some larger inference procedure. In this case it’s important to have the values of u and x. Those are discarded by runslash. In this case we need something that can “look under the hood” and store this information together with the return value.

This is exactly what we need for sampling - information about a given sample, together with some metadata. So it seems natural to add a sample method to do this.

As I’m writing this, I realize there’s a question that’s central to all of this that I haven’t yet asked…

The Real Question (I think)

Turing uses sample, and this seems to return an MCMCChains.Chains. The implementation of that looks like

"""
    Chains
Parameters:
- `value`: An `AxisArray` object with axes `iter` × `var` × `chains`
- `logevidence` : A field containing the logevidence.
- `name_map` : A `NamedTuple` mapping each variable to a section.
- `info` : A `NamedTuple` containing miscellaneous information relevant to the chain.
The `info` field can be set using `setinfo(c::Chains, n::NamedTuple)`.
"""
struct Chains{T,A<:AxisArray{T,3},L,K<:NamedTuple,I<:NamedTuple} <: AbstractMCMC.AbstractChains
    value::A
    logevidence::L
    name_map::K
    info::I
end

I really like the idea of having value and info fields. I’m much less certain about having the value in an AxisArray. In most cases, wouldn’t this lead to T==Any? And it doesn’t seem to support returning an iterable.

I think I’d prefer to have a struct for a single observation, and then maybe a StructArray to hold an array of them. Or sometimes we just want one, or sometimes a Base.Generator, etc. Would that be compatible with Turing’s approach?

Ah, right. Need to dig into that again. The Turing ecosystem is getting big :)

Us too – there’s been lots of discussion about ripping this backend out and making it more agnostic (maybe NamedTuples, for instance). See the issue here for additional thoughts on that.

In this case you would probably want to use step instead of sample, since step represents the creation of one parameter draw. If you want something like an array of raw parameter draws, you can use the chain_type = Nothing keyword in sample, which doesn’t wrap anything up in MCMCChains. This might work if you want to do some kind of integration.

If you wanted a generator, you could make an iterable sampler with steps, which returns an iterator that just spits out new samples forever. We also have a Transducers frontend that I haven’t really used but might also fit here.
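To illustrate the "spits out new samples forever" idea without relying on any AbstractMCMC API, here is a sketch using only Base iterators. `sample_once` and `sample_stream` are hypothetical names, and the model is the simple μ/x one from earlier in the thread.

```julia
using Random

# One forward draw that keeps the internal state alongside the data.
function sample_once(rng, n)
    μ = randn(rng)             # μ ~ Normal()
    x = μ .+ randn(rng, n)     # x ~ Normal(μ, 1), n iid values
    return (μ = μ, x = x)
end

# An endless lazy stream of draws; nothing is computed until it is iterated.
sample_stream(rng, n) = (sample_once(rng, n) for _ in Iterators.countfrom(1))

# Take as many draws as needed from the stream.
stream = sample_stream(MersenneTwister(1), 2)
first3 = collect(Iterators.take(stream, 3))
```

The caller decides how many draws to materialize, so the same stream can feed a fixed-size array, a stopping rule, or another lazy transformation.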


I was thinking of something like

struct Info{T, I<:NamedTuple}
    value :: T
    info :: I
end

value(m::Info) = m.value
value(x) = x

info(m::Info) = m.info
info(x) = NamedTuple()

Then this is pretty easy:

julia> using StructArrays

julia> x = Info(2, (a=10,))
Info{Int64,NamedTuple{(:a,),Tuple{Int64}}}(2, (a = 10,))

julia> y = Info(4, (a=20,))
Info{Int64,NamedTuple{(:a,),Tuple{Int64}}}(4, (a = 20,))

julia> s = StructArray([x,y])
2-element StructArray(::Array{Int64,1}, ::Array{NamedTuple{(:a,),Tuple{Int64}},1}) with eltype Info{Int64,NamedTuple{(:a,),Tuple{Int64}}}:
 Info{Int64,NamedTuple{(:a,),Tuple{Int64}}}(2, (a = 10,))
 Info{Int64,NamedTuple{(:a,),Tuple{Int64}}}(4, (a = 20,))

julia> s.value
2-element Array{Int64,1}:
 2
 4

julia> s.info
2-element Array{NamedTuple{(:a,),Tuple{Int64}},1}:
 (a = 10,)
 (a = 20,)

julia> s[1]
Info{Int64,NamedTuple{(:a,),Tuple{Int64}}}(2, (a = 10,))

julia> s[2]
Info{Int64,NamedTuple{(:a,),Tuple{Int64}}}(4, (a = 20,))

Ok, I have a basic proof of concept. Say you have a contrived model like

m = @model a begin
    b ~ Normal(a, 1)
    c ~ Normal(b, 1)
    return b^2 + c^2
end

You can now do

julia> s = sample(m(a=0),10)
10-element StructArray(::Array{Float64,1}, StructArray(::Array{Float64,1}, ::Array{Float64,1})) with eltype Soss.Noted:
 Noted(1.3100372912053935, (b = 1.1252094868062938, c = 0.20962085298583777))
 Noted(4.745078687991199, (b = 0.21881883567215604, c = 2.1673017798973637))
 Noted(0.8426863421316825, (b = 0.19119954091586658, c = 0.8978469121655676))
 Noted(1.7526245538219878, (b = 1.1731807422277698, c = 0.6134097324691604))
 Noted(3.763157032169782, (b = -1.5678066141033025, c = -1.142427001144371))
 Noted(6.61875619007766, (b = 1.849770623498493, c = 1.7880449744119011))
 Noted(2.066216785671172, (b = -0.9687481224849007, c = 1.0619528524624582))
 Noted(1.7922240724341494, (b = 1.3387206621188208, c = 0.007145715520187501))
 Noted(1.866307324626571, (b = -1.103350467081465, c = -0.8055588565819907))
 Noted(2.9976385226467617, (b = -1.7182291644728145, c = -0.21290152888557534))

Each Noted contains a return value (the value) with its internal state (the info). And it’s in a StructArray, so you can do

julia> values(s)
10-element Array{Float64,1}:
 1.3100372912053935
 4.745078687991199
 0.8426863421316825
 1.7526245538219878
 3.763157032169782
 6.61875619007766
 2.066216785671172
 1.7922240724341494
 1.866307324626571
 2.9976385226467617

julia> infos(s)
10-element StructArray(::Array{Float64,1}, ::Array{Float64,1}) with eltype NamedTuple{(:b, :c),Tuple{Float64,Float64}}:
 (b = 1.1252094868062938, c = 0.20962085298583777)
 (b = 0.21881883567215604, c = 2.1673017798973637)
 (b = 0.19119954091586658, c = 0.8978469121655676)
 (b = 1.1731807422277698, c = 0.6134097324691604)
 (b = -1.5678066141033025, c = -1.142427001144371)
 (b = 1.849770623498493, c = 1.7880449744119011)
 (b = -0.9687481224849007, c = 1.0619528524624582)
 (b = 1.3387206621188208, c = 0.007145715520187501)
 (b = -1.103350467081465, c = -0.8055588565819907)
 (b = -1.7182291644728145, c = -0.21290152888557534)

but also

julia> s[1]
Noted(1.3100372912053935, (b = 1.1252094868062938, c = 0.20962085298583777))

Beyond that, the infos are also a StructArray, so slicing and dicing these is easy:

julia> infos(s)[2]
(b = 0.21881883567215604, c = 2.1673017798973637)

julia> infos(s).b
10-element Array{Float64,1}:
  1.1252094868062938
  0.21881883567215604
  0.19119954091586658
  1.1731807422277698
 -1.5678066141033025
  1.849770623498493
 -0.9687481224849007
  1.3387206621188208
 -1.103350467081465
 -1.7182291644728145

Things are rolling again, so I’m hoping to have some quick progress over the next week or so :)
