Function for resuming an MCMC chain

I want to write a function so that I don't have to wait for a long inference run to finish before seeing anything; instead, I'd like to get intermediate results along the way.

Let's say I want to run inference for 100,000 iterations.

I could get results after every 10,000 iterations and then resume the chain for the next 10,000, using the end point of the previous chain. This would give me an idea of whether the inference is working out, rather than waiting blindly for the final results.

I have tried writing such a function before, motivated by a post here on Discourse, but it didn't work as intended: it does resume the chain, but more or less just starts the chain again from scratch (at a different starting point), so it isn't learning anything from the previous chain.

I wanted to ask whether you have come across this issue, and whether some of you have already implemented something like this in your programs. If not, I would appreciate any suggestions you think would help in my case, because waiting days for inference to finish and then getting weird fits would be a waste of time.

Thanks

I’m not sure how you are implementing your sampler, but whenever I have implemented this sort of thing (in other languages), the sampler object had state, and I just asked it for the next sample. Is there a particular Julia library you are using, or an example you could show us of what you have tried?


If I understand properly, you are wishing to be able to “pause” an inference chain at a specific point to investigate if it is converging well, and then be able to resume (or have it automatically resume).

I presume you are attempting some sort of Bayesian inference? If so, here are my thoughts:

  • If you are struggling to converge, it is often an issue with the model/priors. My typical approach is to prototype with shorter chains (1k-5k) to see whether convergence begins by then.
    • Terms you might be interested in looking up include stabilizing priors and reparameterization.
  • If you’ve run the model and it does converge after a very long run (50k or more, maybe?), could you use that as a starting point for your next run? Like @frylock mentioned, you could extract the state from the Chains struct and then use that as the starting point for the next run.

My experience with Stan (only a couple of projects) suggests that if something isn’t beginning to converge quickly, I need to adjust my priors/model away from a pathological condition.

This might also be a good topic to put in the Specific Domains/Probabilistic Programming category; that will make it more visible to the community.


This is my current code.

chain_reloaded = deserialize("/Program_Julia/chain.jls")

for i in 1:20
    println("Start$(i)")

    # Resume from the previous chain's saved state.
    chains = sample(model, NUTS(), MCMCThreads(), 1000, 2;
                    progress = true, save_state = true, resume_from = chain_reloaded)

    plot(chains)
    savefig("/Program_Julia/trial/chain_$(i).pdf")

    plot_fit(chains, i)
    savefig("/Program_Julia/trial/fit_$(i).pdf")

    # Checkpoint the chain to disk, then use it as the resume point
    # for the next iteration.
    serialize("/Program_Julia/chain_new.jls", chains)
    chain_reloaded = chains

    println("End$(i)")
end

Yes, I’m doing Bayesian inference using Turing.jl.

I also approach inference like that: I start with a low iteration count to catch issues in the inference or my code. But if the model takes, say, around 120k iterations to converge, it's hard to judge anything from 2,000-5,000 iterations.

As you can see in the code I shared in reply to @frylock, I'm doing something similar to what you mentioned about extracting the state from the chain object.

I have already done inference for this model in R with one chain. Here I am using multiple chains, and even at around 60k I can't see stable chains, so I expect it will take much longer for all 3 chains to converge. For now I'm running it for 120k iterations with 1 chain to check there is nothing wrong in the code, as I expect it to converge by 120k with 1 chain.

I have changed the category to Specific Domains/Probabilistic Programming as you suggested.


Unfortunately, at the moment (but this will change soon) you have to make use of the underlying step interface to be able to resume a chain.

julia> using Turing, Random

julia> @model function gdemo(xs)
           # Assumptions
           σ² ~ InverseGamma(2, 3)
           μ ~ Normal(0, √σ²)
           # Observations
           for i = 1:length(xs)
               xs[i] ~ Normal(μ, √σ²)
           end
       end

gdemo (generic function with 2 methods)

julia> # Set up.
       xs = randn(100);

julia> model = gdemo(xs);

julia> # Sampler.
       alg = NUTS(0.65);

julia> kwargs = (nadapts=50,);

julia> num_samples = 100;

julia> ### The following two methods are equivalent ###
       ## Using `sample` ##
       rng = MersenneTwister(42);

julia> chain = sample(rng, model, alg, num_samples; kwargs...)
┌ Info: Found initial step size
└   ϵ = 0.4
Sampling 100%|█████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (100×14×1 Array{Float64, 3}):

Iterations        = 51:1:150
Number of chains  = 1
Samples per chain = 100
Wall duration     = 1.12 seconds
Compute duration  = 1.12 seconds
parameters        = σ², μ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64 

          σ²    1.1090    0.1728    0.0591    10.8467    51.5563    1.1136        9.6587
           μ   -0.1753    0.1030    0.0126    66.7404    78.3393    0.9940       59.4304

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          σ²    0.8499    0.9980    1.1052    1.2155    1.4554
           μ   -0.3526   -0.2485   -0.1774   -0.1074    0.0239


julia> ## Using the iterator-interface ##
       rng = MersenneTwister(42);

julia> spl = DynamicPPL.Sampler(alg);

julia> nadapts = 50;

julia> # Create an iterator we can just step through.
       it = AbstractMCMC.Stepper(rng, model, spl, kwargs);

julia> # Initial sample and state.
       transition, state = iterate(it);
┌ Info: Found initial step size
└   ϵ = 0.4

julia> # Simple container to hold the samples.
       transitions = [];

julia> # Simple condition that says we only want `num_samples` samples.
       condition(spls) = length(spls) < num_samples
condition (generic function with 1 method)

julia> # Sample until `condition` is no longer satisfied
       while condition(transitions)
           # For an iterator we pass in the previous `state` as the second argument
           transition, state = iterate(it, state)
           # Save `transition` if we're not adapting anymore
           if state.i > nadapts
               push!(transitions, transition)
           end
       end

julia> length(transitions), state.i, state.i == length(transitions) + nadapts
(100, 150, true)

julia> # Finally, if you want to convert the vector of `transitions` into a
       # `MCMCChains.Chains` like is typically done:
       chain = AbstractMCMC.bundle_samples(
           map(identity, transitions),  # trick to concretize the eltype of `transitions`
           model,
           spl,
           state,
           MCMCChains.Chains
       )
Chains MCMC chain (100×14×1 Array{Float64, 3}):

Iterations        = 1:1:100
Number of chains  = 1
Samples per chain = 100
parameters        = σ², μ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Missing 

          σ²    1.1090    0.1728    0.0591    10.8467    51.5563    1.1136       missing
           μ   -0.1753    0.1030    0.0126    66.7404    78.3393    0.9940       missing

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          σ²    0.8499    0.9980    1.1052    1.2155    1.4554
           μ   -0.3526   -0.2485   -0.1774   -0.1074    0.0239

(this example is from Simple example of using NUTS with the new iterator interface in AbstractMCMC.jl available using Turing.jl > 0.15. · GitHub)

From here you can just serialize the state; when you want to continue from a given state, load it from disk, skip the first call to iterate(it), and instead pass in the state you just loaded :)
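For the serialization step itself, Julia's standard-library Serialization module is all you need. A minimal sketch, where `state` is just a placeholder NamedTuple standing in for whatever sampler state `iterate(it, state)` returned:

```julia
using Serialization

# Placeholder for the sampler state returned by `iterate(it, state)`.
state = (i = 150, θ = [1.0, -0.5])

# Checkpoint the state to disk ...
path = tempname() * ".jls"
serialize(path, state)

# ... and later, instead of calling `iterate(it)` from scratch,
# load the saved state and continue with `iterate(it, restored)`.
restored = deserialize(path)
println(restored.i)   # 150
```

Note that deserialization assumes the same package versions are loaded when you restore, so this works best for checkpointing within the same project environment.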


Half a year later, someone else is looking for a solution to this problem. In addition to what was described above, I am trying to hand over initial parameters to the sampler (which does not seem to be an accepted keyword argument) with several chains (using multi-threading). I played around a bit with the code snippet you provided, but could not quite figure out how to do this.
You also mentioned that the provided approach would change soon. Would you now recommend a different way to code all of this?

Indeed, it’s a bit easier to do these things now.

Now AbstractMCMC.sample (which is the same sample as the one in Turing.jl) allows you to pass initial_state as a keyword argument.

Combine this with the save_state=true keyword argument, and it’s fairly straightforward to resume a run.

julia> using Turing


julia> @model function gdemo(xs)
           # Assumptions
           σ² ~ InverseGamma(2, 3)
           μ ~ Normal(0, √σ²)
           # Observations
           for i = 1:length(xs)
               xs[i] ~ Normal(μ, √σ²)
           end
       end
gdemo (generic function with 2 methods)

julia> # Set up.
       xs = randn(100);

julia> model = gdemo(xs);

julia> # Sampler.
       alg = NUTS(0.65);

julia> # Note the `save_state=true`, which allows us to extract the last state from the chain.
       chain_first = sample(model, alg, 1000; save_state=true);

┌ Info: Found initial step size
└   ϵ = 0.4
Sampling 100%|█████████████████████████████████████████████████████| Time: 0:00:01

julia> last_state = chain_first.info.samplerstate;

julia> # Continue sampling.
       chain_continuation = sample(
           model, alg, 500;
           # NOTE: At the moment we have to use `resume_from` because Turing.jl
           # is slightly lagging behind AbstractMCMC.jl, but soon we will use
           # `initial_state` instead, which is consistent with the rest of the
           # ecosystem.
           resume_from=last_state,
           # initial_state=last_state,
       );
Sampling 100%|█████████████████████████████████████████████████████| Time: 0:00:00

julia> range(chain_first)
501:1:1500

julia> range(chain_continuation)
1:1:500

julia> # Can only concatenate chains if the iterations are consistent.
       # So we have to update the iterations of the second chain.
       chain_continuation = setrange(
           chain_continuation,
           range(chain_first)[end] .+ range(chain_continuation)
       );

julia> chain_combined = vcat(chain_first, chain_continuation);

julia> range(chain_combined)
1500-element Vector{Int64}:
  501
  502
  503
  504
  505
  506
  507
  508
  509
  510
  511
  512
  513
  514
    ⋮
 1987
 1988
 1989
 1990
 1991
 1992
 1993
 1994
 1995
 1996
 1997
 1998
 1999
 2000

Note that there are clearly some quirks to iron out in the above example:

  • Turing.jl is a bit behind the AbstractMCMC.jl interface, so we need to update that (I’ll get on this asap though): we need to remove resume_from in favour of the new initial_state.
  • vcat for two chains concretizes the ranges, i.e. when concatenating 1:1:3 and 4:1:5 we get [1, 2, 3, 4, 5] instead of 1:1:5, which, though not incorrect, is potentially annoying since we show the entire range when displaying a Chains.
  • vcat requires monotonically increasing range, while Turing.jl constructs the range purely based on the number of samples + thinning, completely ignoring which iteration we started out at. Hence we need the line in the snippet above to manually align the two iteration ranges.

But these things can be sorted out fairly easily:)

EDIT: Ref Do we need `resume_from` now that we have `initial_state`? · Issue #2171 · TuringLang/Turing.jl · GitHub. This should be a fairly quick “aye” and then we just need a few simple PRs.
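To see why passing the final state makes the second run a genuine continuation rather than a restart, here is a self-contained toy illustration (standard library only, no Turing.jl): a hand-rolled random-walk Metropolis sampler whose state is just the current position plus the RNG. Two 500-sample runs, the second resumed from the first's saved state, reproduce a single 1000-sample run exactly:

```julia
using Random

# Target: log-density of a standard normal (toy example).
logp(x) = -x^2 / 2

# One random-walk Metropolis step; `state` is (position, rng).
function mh_step(state)
    x, rng = state
    x′ = x + 0.5 * randn(rng)               # propose
    if log(rand(rng)) < logp(x′) - logp(x)   # accept/reject
        x = x′
    end
    return x, (x, rng)
end

# Draw `n` samples starting from `state`; return samples and final state.
function mh_run(state, n)
    samples = Float64[]
    for _ in 1:n
        s, state = mh_step(state)
        push!(samples, s)
    end
    return samples, state
end

# One long run ...
long, _ = mh_run((0.0, MersenneTwister(1)), 1000)

# ... versus two short runs, the second resumed from the saved state.
first_half, saved_state = mh_run((0.0, MersenneTwister(1)), 500)
second_half, _ = mh_run(saved_state, 500)

println(long == vcat(first_half, second_half))  # true
```

With NUTS the state additionally carries the adapted step size and mass matrix, which is exactly why reusing it, rather than starting fresh, matters so much.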


Just wondering if this has now been released so we can resume the chain using initial_state, or do we still need to use the resume_from syntax?

Hello,

I used this function, but it seems one must already have found the most probable region via a long warmup in the first chain before it makes sense to continue, which I don't find very useful.

The way I imagined this, sample(model, NUTS(500, 0.65), MCMCSerial(), 1000, 1) should give the same results as iterating sample(model, NUTS(0, 0.65), MCMCSerial(), 100, 1) 15 times, which is not the case. Can you please explain why the same total number of samples, but without warmup, doesn't work?

NUTS warmup samples are not just more MCMC samples: during warmup the sampler is adapting the mass matrix, the step size, and so on. At least that's my understanding. When you run NUTS you almost always want enough adaptation to get into the high-probability region and adapt the mass matrix; after that you want only a relatively small number of samples. If the chain hasn't converged by the first retained sample, you basically haven't done enough warmup.

To get NUTS working well, first try using optimization to find the region of interest; it helps if you have at least a rough idea of that region even before optimizing. Then take the optimized point and perturb it slightly as a starting point (it's best not to be right on the optimum). Run NUTS with, say, 500 warmup and 100 real samples on, say, 3 threaded chains. Did you get convergence? If not, run a fresh run from your new starting point with even more warmup. Basically, keep increasing warmup until the chains converge out of the box. Then, finally, draw real samples; usually a small number is sufficient: 100, 1000, or 5000, but not 100k.

If you’ve done warmup properly NUTS will move around ok.


OK, I understand, but let's say my model is complex and even getting 200 samples after 100 warmup iterations takes 2-3 days. I don't want to sit back waiting for results, only to finally decide "oops, it didn't work", go for a longer warmup, and have that fail too. This feels like working with a black box. Is there no way to check how the inference is going while it's running?

The issue is the warmup. Before warmup is done, it's not a Markov chain, since the sampler is modifying what it does based on the output of the samples so far. So the question is how to get warmed up quickly.

One thing you can do is find a starting point that makes good sense; Pathfinder and plain optimization can both help with that.

Then there’s the question of the mass matrix (usually this is basically a diagonal covariance matrix). I'm not sure whether there's a way for you to supply a starting mass matrix, but maybe.

Then there’s sampling using variational inference to get some approximate posteriors before you give it a real run.

Hopefully those can give you some ideas and/or you can ask more specifically about the mass matrix tuning with someone who knows the code internals.

You can use Pathfinder’s preconditioner as an estimate of the local precision, so it should also work as a mass matrix. In fact, the Stan people were talking about doing just that. I'm not sure whether it is currently implemented that way, though.