SampleChains

I’m working on a new interface for chains from MCMC sampling:

I’ve copied the README below. I’d love any feedback on the interface, or ideas for making it more useful and/or powerful.

Gratuitous @-ing…

  • @Tamas_Papp there’s an interface to DynamicHMC here. More back-ends on the way
  • @theogf we should come up with a nice way to incorporate callbacks so this works well with Turkie.jl
  • @baggepinnen I think I had been stretching MonteCarloMeasurements.jl beyond its intended use. Now we should be able to load Particles into a column and easily compute row-by-row.
  • @sethaxen, @cpfiffer, and I talked a little bit about moving diagnostics from MCMCChains to a lighter-weight package here

Example Use

Using SampleChainsDynamicHMC.jl, Soss can sample with very little Soss-specific code. I want to tidy this up more, but here’s how it looks now:

using Soss
using SampleChains
using SampleChainsDynamicHMC

pr = @model k begin
    σ ~ Exponential()
    α ~ Cauchy()
    β ~ Normal() |> iid(k)
end;

ℓ(x) = logdensity(pr(k=5), x);

t = xform(pr(k=5));

chain1 = initialize!(DynamicHMCChain, ℓ, t);
drawsamples!(chain1, 1000)

chain2 = initialize!(DynamicHMCChain, ℓ, t);
drawsamples!(chain2, 1000)

chain3 = initialize!(DynamicHMCChain, ℓ, t);
drawsamples!(chain3, 1000)

chains = MultiChain(chain1, chain2, chain3)

From the README

julia> chains
3003-element MultiChain with 3 chains and schema (σ = Float64, α = Float64, β = Vector{Float64})
(σ = 0.9±0.88, α = -5.1±10.0, β = [-0.0±0.9, 0.04±0.99, 0.04±1.1, 0.02±0.92, 0.06±0.95])

Some features (many still in progress):

  • Simple visual representation
  • “Samples first”, though diagnostic information is easily available
  • Each Chain can be indexed as a Vector, or as a NamedTuple
  • Interrupting (CTRL-C) returns the current chain, including iterator information so work can be resumed
  • Built on ElasticArrays to make it easy to add new samples after the fact (see the sketch after this list)
  • Adaptable to many different sampling algorithms, including with or without (log-)weights
  • Easy summarization functions: expectations, per-dimension quantiles, etc.
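
On the ElasticArrays bullet, here’s a minimal sketch of the underlying behavior (plain ElasticArrays.jl, nothing SampleChains-specific): the sample matrix is fixed in every dimension but the last, so new draws can be appended cheaply after the fact.

using ElasticArrays

# An ElasticArray is fixed in every dimension except the last, so a
# (nparams × ndraws) matrix can grow as new draws arrive.
samples = ElasticArray{Float64}(undef, 3, 0)   # 3 parameters, no draws yet

for _ in 1:1000
    draw = randn(3)          # stand-in for one new sample
    append!(samples, draw)   # appends one column, amortized O(1)
end

size(samples)                # (3, 1000)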

In progress:

  • More back-ends (currently just DynamicHMC, using SampleChainsDynamicHMC)
  • Diagnostic warnings during sampling
  • Callbacks for plotting, etc
  • Sample count based on desired standard error of a specified expected value (see the sketch after this list)
  • Summarization by different functions
    • Highest posterior density intervals
    • R-hat statistics
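
On the sample-count item: roughly, the Monte Carlo standard error of a posterior mean is sd/√n_eff, so a target standard error pins down the effective sample size needed. A sketch of that calculation (not package code):

# Effective samples needed so the Monte Carlo standard error of E[θ]
# is at most `target_se`, given an estimate of the posterior sd.
required_ess(posterior_sd, target_se) = ceil(Int, (posterior_sd / target_se)^2)

required_ess(1.0, 0.01)   # 10_000 effective samples for SE ≈ 0.01
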
4 Likes

Thanks for the ping! I have been thinking about the need for this, thanks for writing it. I am currently very busy but plan to look at the details.

2 Likes

@cscherrer, a couple of thoughts:

First, I don’t like having a different Chain type for each sampler. I think keeping the sampler separate and using an iterator interface (like AbstractMCMC.jl) makes it easier to write samplers. I also found it pretty confusing that the actual DynamicHMC algorithm is defined with the DynamicHMCChain type. I think it is much more straightforward to see code like chain = sample(model, sampler, sample_options).

This also reminds me of the switch to immutable sampler state structs in AbstractMCMC v2. Apparently the mutable states can get hard to work with deep in the Turing modeling code (anecdotally; I don’t work on Turing internals).

Second, how would a system like this work with convergence-based sampling? E.g., nested sampling has convergence criteria and much prefers an interface like res, state = sample(model, nested_sampler). (As an aside, in nested sampling the state is very important since it contains the log-evidence estimate, so having an interface with convergence-based sampling AND easy access to the state would be great.)

Overall I like the way AbstractMCMC defines the iterator interface for samplers. I think it is straightforward for end users, and while it’s not perfect, it’s good enough for someone with little prior PPL knowledge like myself to implement NestedSamplers.jl and have it work. The biggest downside is that Turing.jl itself doesn’t interface perfectly with AbstractMCMC.jl and requires writing the sampler interface again internally.

That’s not the case; the algorithm is in DynamicHMC.jl. SampleChainsDynamicHMC is just a very thin wrapper that abstracts away some of the details specific to DynamicHMC.

There’s no free lunch. Either

  • Every sampler can have exactly the same interface (not going to happen), or
  • Users can look up the differences each time they change samplers (big hassle, so they won’t), or
  • There can be an abstraction like this providing a uniform interface.

I guess I’m missing something; I don’t see how this would be harder than any other back-end.

That’s great, but my goal is making this easier for end users. If the iterator is separate, an end user will have to write boilerplate code to wrap it. With this we can figure it out once and be done with it.

2 Likes

I hope this didn’t come off as dismissive. If there are real downsides, I’d like to understand them. But I’m not seeing them yet.

This design replaces a lot of very Soss-specific code with a much more general interface. In Soss we can now write

using Soss
using SampleChains
using SampleChainsDynamicHMC

pr = @model k begin
    σ ~ Exponential()
    α ~ Normal()
    β ~ Normal() |> iid(k)
end;

ℓ(x) = logdensity(pr(k=5), x);

t = xform(pr(k=5));

chains = initialize!(4, DynamicHMCChain, ℓ, t)
drawsamples!(chains, 1000)

and that’s it! The final step will be to have this wrapped up in a sample method so you can say

using Soss
using SampleChainsDynamicHMC

pr = @model k begin
    σ ~ Exponential()
    α ~ Normal()
    β ~ Normal() |> iid(k)
end;

chains = sample(pr(k=5), DynamicHMC, nchains=4, nsamples=1000)

or something similar. Soss will have very little code specific to DynamicHMC, and SampleChainsDynamicHMC will have no code at all that’s specific to Soss.
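
To be concrete, here’s a rough sketch of what that sample method could reduce to on the Soss side (hypothetical glue code, not what will actually ship; it just strings together the initialize!/drawsamples! calls shown above):

# Hypothetical wrapper: the only backend-specific piece is the chain
# type supplied by the caller.
function sample(m, chaintype; nchains = 4, nsamples = 1000)
    ℓ(x) = logdensity(m, x)    # log-density of the conditioned model
    t = xform(m)               # transform to unconstrained space
    chains = initialize!(nchains, chaintype, ℓ, t)
    drawsamples!(chains, nsamples)
    return chains
end

chains = sample(pr(k=5), DynamicHMCChain; nchains = 4, nsamples = 1000)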

2 Likes

Yeah, so I think the reason we’re pushing back a little is that this is essentially what AbstractMCMC does. If you’re interested in implementing that interface, the ultimate goal would be to have an AbstractMCMC wrapper around DynamicHMC, and then you just overload bundle_samples to wrap things up into this MultiChain struct.

1 Like

Maybe part of the disconnect here is that at the core I’m using TupleVectors to store everything. It doesn’t make sense in my approach to put everything into a Base.Vector{NamedTuple} and then empty and repack it later.

Yeah you don’t have to do that at all. AbstractMCMC does not require you to use Vector{NamedTuple}. You can work internally with TupleVectors as much as you want.

Ok, so I guess it’s not clear to me which parts of this code could be moved into AbstractMCMC without giving up functionality. I should try to connect to that next; maybe then it will become clear.

The main point of this is to make things as easy as possible for end-users. A very common problem is a user drawing some samples, thinking it will be enough. But it’s not, so they need to draw some more. With this you can just do

julia> chains = initialize!(4, DynamicHMCChain, ℓ, t);

julia> drawsamples!(chains, 199)
800-element MultiChain with 4 chains and schema (σ = Float64, α = Float64, β = Vector{Float64})
(σ = 1.07±0.94, α = -0.02±0.81, β = [-0.03±0.82, 0.08±1.1, 0.18±1.2, -0.05±1.1, -0.05±0.83])

julia> drawsamples!(chains, 800)
4000-element MultiChain with 4 chains and schema (σ = Float64, α = Float64, β = Vector{Float64})
(σ = 1.07±0.95, α = -0.02±0.79, β = [-0.02±0.81, 0.08±1.1, 0.19±1.2, -0.05±1.0, -0.0±0.85])

Without easy access to the iterator, I can’t imagine how you’d even go about this.

My guess:

drawsamples!(model, chains, 800)

I think you’d have to go back through warmup in that case, unless model stores the sampler state

model would be able to query chains for the state in this case.

I hope this didn’t come off as dismissive.

No problems!

That’s not the case; the algorithm is in DynamicHMC.jl. SampleChainsDynamicHMC is just a very thin wrapper that abstracts away some of the details specific to DynamicHMC.

Yep, I understand that, but I’ll admit it took me a minute of reading your code to understand where the “sampler” was. In terms of design, it is not obvious to me that DynamicHMCChain is also the thing that holds all the options for the HMC algorithm. It is much clearer to me when you see a struct like HMC(opts...).

I guess I’m missing something; I don’t see how this would be harder than any other back-end.

I was probably conflating topics here: before AbstractMCMC used its iterator interface, when you called sample(model, alg, N) it would typically pre-allocate the memory, which is great for speed but doesn’t work for dynamic output sizes like nested sampling produces. You aren’t doing that here, but I wanted to point it out in case you’re weighing different design options.

If the iterator is separate, an end user will have to write boilerplate code to wrap it.

I disagree. AbstractMCMC does just this: the boilerplate is written into the package, so neither users nor sampler authors have to define the same for loop again.

A very common problem is a user drawing some samples, thinking it will be enough. But it’s not, so they need to draw some more.

I agree this is very useful. I actually think the iterator interface of AbstractMCMC does this pretty well, but that is at odds with the previous point about avoiding boilerplate. Perhaps the solution is for even the sample functions to return the state too, something like

mod = #...
alg = NUTS(0.65) # e.g.
chains, state = sample(mod, alg, 1000)
# oh no, my samples suck
chains2, state2 = sample(mod, alg, state, 1000)
chains_final = vcat(chains, chains2)

(written in the style of AbstractMCMC but hopefully the point is clear)

Only if chains holds the state, which it seemed you were arguing against.

Thanks. I hope this might just be a matter of some better documentation and examples, and getting used to a different paradigm. I agree it’s confusing that it’s so different, but hopefully it’s worth it 🙂

I see. Yes, pre-allocation is great for speed but from what I’m seeing ElasticArrays makes it plenty fast and gives a more elegant interface. The cost of pushing O(1000) numbers is negligible relative to the cost of sampling.

I think I’m not being clear. The boilerplate I’m talking about in this case is the code a user needs to add to manually track the iterator state in case more samples are needed. I agree AbstractMCMC does a good job with it, but it’s convenient for the user to not worry about planning ahead in this way.
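
To spell out the contrast, here’s a sketch of the two workflows; the step function and explicit state below are illustrative stand-ins, not any package’s actual API:

# Iterator-style (illustrative names): the user threads the state
# themselves and has to keep it around in case more draws are needed.
draw, state = step(ℓ, t)
draws = [draw]
for _ in 1:999
    draw, state = step(ℓ, t, state)
    push!(draws, draw)
end
# ...later: `state` must still be in scope to continue from here.

# SampleChains-style: the chain owns its sampler state, so continuing
# is just another call.
chains = initialize!(4, DynamicHMCChain, ℓ, t)
drawsamples!(chains, 1000)
drawsamples!(chains, 1000)   # no state bookkeeping on the user's side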

I’m only arguing against holding an iterator, which as I understand it contains some kind of sampling logic. The state differs from the iterator in that it’s a functional input to the model and sampling logic, I think?

1 Like

Oh, I think I see. So you’d like a chain to hold the iterator state, just not the iterator itself? So you’re effectively rebuilding the iterator each time?

The compiler can probably make that cheap, I’m just concerned it could get tricky to keep track of all the dependencies you need to do it. I don’t think model would be enough, for example that doesn’t know what algorithm you’re using.

It doesn’t have to! The model just passes the state to the sampler and moves on with its day. Everything is compartmentalized and just works. And it’s super cheap to do: iterators of this sort are not expensive to make, especially since they will not be rebuilt that often.

But how does it know where the sampler is? In Soss the model will get you the logdensity very quickly, but other things would need to be recomputed. If it’s all stored in the chains, that seems equivalent to what I’m doing now.