Function for resuming an MCMC chain

I want to write a function so that I don't have to wait for my long inference run to finish before I see any results. Instead, I'd like to get intermediate results along the way.

Let's say I want to run inference for 100,000 iterations.

I could get results after every 10,000 iterations and then resume the chain for the next 10,000, using the end points of the previous chain. This would give me an idea of whether the inference is working, rather than waiting blindly for the final results.
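A minimal sketch of this chunked scheme follows. It assumes a toy one-dimensional random-walk Metropolis sampler in place of a real Turing model, and `RWState`, `run_chunk`, `run_in_chunks` and the checkpoint file name are all hypothetical. The key point is that each chunk starts from the full state left by the previous one, including the RNG, rather than from a fresh starting point:

```julia
using Random, Serialization, Statistics

# Toy target: log-density of a standard normal, standing in for a real model.
logdensity(x) = -x^2 / 2

# Everything needed to continue the chain exactly where it stopped.
struct RWState
    x::Float64      # current position
    logp::Float64   # log-density at `x`
    rng::Xoshiro    # RNG state, so a resumed run continues the same random stream
end

# Run `n` random-walk Metropolis steps from `state`; returns the draws and the new state.
function run_chunk(state::RWState, n::Integer; step = 2.4)
    x, logp, rng = state.x, state.logp, copy(state.rng)   # copy: never mutate a saved state
    draws = Vector{Float64}(undef, n)
    for i in 1:n
        proposal = x + step * randn(rng)
        logp_prop = logdensity(proposal)
        if log(rand(rng)) < logp_prop - logp
            x, logp = proposal, logp_prop
        end
        draws[i] = x
    end
    return draws, RWState(x, logp, rng)
end

# Split a long run into chunks, reporting after each one and checkpointing the state to disk.
function run_in_chunks(state::RWState; total = 100_000, chunk = 10_000,
                       checkpoint = joinpath(tempdir(), "toy_chain_state.jls"))
    all_draws = Float64[]
    for k in 1:(total ÷ chunk)
        draws, state = run_chunk(state, chunk)
        append!(all_draws, draws)
        serialize(checkpoint, state)   # overwritten after every chunk
        println("chunk $k: mean = ", round(mean(draws); digits = 3))
    end
    return all_draws, state, checkpoint
end

start = RWState(5.0, logdensity(5.0), Xoshiro(42))   # deliberately poor starting point
all_draws, final_state, checkpoint = run_in_chunks(start)
```

If the process is interrupted, `deserialize(checkpoint)` returns the last saved `RWState`, which can be passed straight back to `run_in_chunks`.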

I have tried such a function before, inspired by one of the posts here on Discourse, but it didn't work: it does resume the chain, but it more or less just starts the chain again from the beginning (from a different starting point), so it isn't learning anything from the previous chain.

I wanted to ask whether any of you have come across such issues, and whether some of you have already implemented something like this in your programs. If not, I would appreciate any suggestions you think would help in my case, because waiting days for inference to finish and then getting weird fits would be a waste of time.

Thanks

I'm not sure how you are implementing your sampler, but whenever I have implemented this sort of thing (in other languages), the sampler object had state, and I just asked it for the next sample. Is there a particular Julia library you are using, or could you give us an example of what you have tried?
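To illustrate the stateful-sampler idea in Julia, here is a hedged sketch built on Julia's own iteration protocol (`Base.iterate`). `ToySampler` and `advance` are hypothetical names, not part of any library. Because every call receives the previous state, stopping after one batch and continuing later gives exactly the same draws as one uninterrupted run:

```julia
using Random

# Hypothetical stand-alone sampler: random-walk Metropolis on a 1-D log-density.
struct ToySampler{F}
    logdensity::F
    x0::Float64     # starting point
    step::Float64   # proposal scale
    seed::Int       # seed for the sampler's own RNG
end

# First call: no state yet, so build one from the starting point.
function Base.iterate(s::ToySampler)
    rng = Xoshiro(s.seed)
    return s.x0, (x = s.x0, logp = s.logdensity(s.x0), rng = rng)
end

# Every later call receives the previous state and returns the next sample plus a new state.
function Base.iterate(s::ToySampler, state)
    x, logp, rng = state.x, state.logp, state.rng
    prop = x + s.step * randn(rng)
    logp_prop = s.logdensity(prop)
    if log(rand(rng)) < logp_prop - logp
        x, logp = prop, logp_prop
    end
    return x, (x = x, logp = logp, rng = rng)
end

# Pull `n` more samples starting from `state`.
function advance(s::ToySampler, state, n)
    draws = Float64[]
    for _ in 1:n
        x, state = iterate(s, state)
        push!(draws, x)
    end
    return draws, state
end

s = ToySampler(x -> -x^2 / 2, 3.0, 1.0, 7)
x1, st = iterate(s)                # initial sample and state
batch1, st = advance(s, st, 500)   # stop here and look at batch1 ...
batch2, st = advance(s, st, 500)   # ... then carry on from exactly the same state
```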

If I understand correctly, you want to be able to “pause” an inference chain at a specific point to check whether it is converging well, and then be able to resume it (or have it resume automatically).

I presume you are attempting some sort of Bayesian inference? If so, here are my thoughts:

  • If you are struggling to converge, it is often an issue with the model/priors. My typical approach is to prototype with shorter chains (1k-5k iterations) to see whether they begin to converge within that length.
    • Terms you might be interested in looking up include stabilizing priors and reparameterization.
  • If you've run the model and it does converge after a very long run (50k or more, maybe?), could you use that as a starting point for your next run? As @frylock mentioned, you could extract the state from the Chains struct and then use it as the starting point for the next run.
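One reason a "resumed" run can look like a fresh start is that the last parameter values are only part of the sampler state; for NUTS, the state also carries the adapted step size and mass matrix. The sketch below uses a toy adaptive Metropolis sampler to contrast resuming the full state with restarting from the last draw. All names are hypothetical, and a simple Robbins-Monro step-size rule stands in for NUTS's actual adaptation:

```julia
using Random

logdensity(x) = -x^2 / 2   # toy target: standard normal

# Full sampler state: the position plus everything learned during warm-up.
struct AdaptState
    x::Float64
    logp::Float64
    logstep::Float64   # log of the proposal scale, tuned during warm-up
    iter::Int          # iterations performed so far
end

# A fresh state: step size exp(0.0) = 1.0 and no iterations done.
init_state(x0) = AdaptState(x0, logdensity(x0), 0.0, 0)

# Run `n` iterations; the step size is adapted only while the total count is ≤ `nadapts`.
# `target = 0.44` is the commonly quoted optimal acceptance rate for 1-D random-walk Metropolis.
function step_many(rng, s::AdaptState, n; nadapts = 1_000, target = 0.44)
    x, logp, logstep, iter = s.x, s.logp, s.logstep, s.iter
    for _ in 1:n
        iter += 1
        prop = x + exp(logstep) * randn(rng)
        lp = logdensity(prop)
        accept_prob = min(1.0, exp(lp - logp))
        if rand(rng) < accept_prob
            x, logp = prop, lp
        end
        if iter <= nadapts
            # Robbins-Monro update: nudge the acceptance rate toward `target`.
            logstep += (accept_prob - target) / sqrt(iter)
        end
    end
    return AdaptState(x, logp, logstep, iter)
end

rng = Xoshiro(3)
warm = step_many(rng, init_state(4.0), 2_000)   # adaptation stops after iteration 1000

resumed   = warm                 # resume: position, step size and counter all kept
restarted = init_state(warm.x)   # restart from the last draw: adaptation is lost

println("step size when resuming:   ", exp(resumed.logstep))
println("step size when restarting: ", exp(restarted.logstep))

more = step_many(rng, resumed, 1_000)   # continues with the tuned step size, no new warm-up
```

Restarting from `init_state(warm.x)` keeps the position but discards the tuned step size and resets the iteration counter, so the warm-up would run again; resuming from `warm` skips it.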

My experience with Stan (only a couple of projects) suggests that if something isn't beginning to converge quickly, I need to adjust my priors/model to move away from a pathological condition.

This might also be a good topic for the Specific Domains/Probabilistic Programming category, which will make it more visible to the community.

This is my current code.

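using Turing, Serialization
using StatsPlots   # plotting package assumed; the original post doesn't say which one provides plot/savefig

# `model` is the user's Turing model, defined elsewhere.
chain_reloaded = deserialize("/Program_Julia/chain.jls")   # chain saved by an earlier run

for i in 1:20
    global chain_reloaded   # needed when run as a script (Julia's soft-scope rule); harmless in the REPL

    println("Start$(i)")
    # 2 chains × 1000 samples on threads, resuming from the previously saved chain
    chains = sample(model, NUTS(), MCMCThreads(), 1000, 2; progress = true, save_state = true, resume_from = chain_reloaded)

    plot(chains)
    savefig("/Program_Julia/trial/chain_$(i).pdf")

    plot_fit(chains, i)   # user-defined plotting helper
    savefig("/Program_Julia/trial/fit_$(i).pdf")

    # Checkpoint to disk and reload it as the starting point of the next block
    serialize("/Program_Julia/chain_new.jls", chains)
    chain_reloaded = deserialize("/Program_Julia/chain_new.jls")
    println("End$(i)")
end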

Yes, I'm doing Bayesian inference, using Turing.jl.

I approach inference the same way: I start with a low number of iterations to catch issues in the inference or in my code. However, if the model takes something like 120k iterations to converge, it's hard to judge anything from 2,000-5,000 iterations.

As you can see in the code I shared in reply to @frylock, I'm doing something similar to what you mentioned: extracting the state from the chain object.

I have already done inference for this model in R with one chain. Here I was using multiple chains, and even at around 60k iterations I can't see stable chains, so I expect it will take much longer for all 3 chains to converge. For now I'm running it for 120k iterations with 1 chain to check that nothing is wrong in the code, since I expect a single chain to converge by 120k.
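For judging multiple chains before they have finished, one common check is the split-R̂ statistic, which compares between-chain and within-chain variance (the `rhat` column in the MCMCChains summaries later in this thread is a diagnostic of this kind). Below is a minimal sketch of the basic split-R̂, applied to synthetic draws rather than a real model. Library implementations typically use a more robust rank-normalized variant, so the numbers will not match them exactly:

```julia
using Statistics, Random

# Basic split-R̂: `draws` is an (iterations × chains) matrix.
function split_rhat(draws::AbstractMatrix{<:Real})
    n = size(draws, 1) ÷ 2                            # length of each half-chain
    halves = hcat(draws[1:n, :], draws[n+1:2n, :])    # split every chain in two
    chain_means = vec(mean(halves; dims = 1))
    chain_vars  = vec(var(halves; dims = 1))
    W = mean(chain_vars)                              # within-chain variance
    B = n * var(chain_means)                          # between-chain variance
    var_hat = (n - 1) / n * W + B / n                 # pooled variance estimate
    return sqrt(var_hat / W)
end

rng = Xoshiro(0)
mixed = randn(rng, 1_000, 3)                   # 3 "chains" exploring the same distribution
stuck = randn(rng, 1_000, 3) .+ [0.0 3.0 6.0]  # 3 chains stuck in different regions

println("mixed: ", split_rhat(mixed))
println("stuck: ", split_rhat(stuck))
```

Values close to 1 suggest the chains agree, while values well above 1 mean they are still exploring different regions. A common rule of thumb flags values above 1.01 (older guidance used 1.1), and a check like this could be run after every chunk.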

I have changed the category to Specific Domains/Probabilistic Programming, as you suggested.

Unfortunately, at the moment (though this will change soon) you have to use the underlying step interface to be able to resume a chain.

julia> using Turing, Random

julia> @model function gdemo(xs)
           # Assumptions
           σ² ~ InverseGamma(2, 3)
           μ ~ Normal(0, √σ²)
           # Observations
           for i = 1:length(xs)
               xs[i] ~ Normal(μ, √σ²)
           end
       end

gdemo (generic function with 2 methods)

julia> # Set up.
       xs = randn(100);

julia> model = gdemo(xs);

julia> # Sampler.
       alg = NUTS(0.65);

julia> kwargs = (nadapts=50,);

julia> num_samples = 100;

julia> ### The following two methods are equivalent ###
       ## Using `sample` ##
       rng = MersenneTwister(42);

julia> chain = sample(rng, model, alg, num_samples; kwargs...)
┌ Info: Found initial step size
└   ϵ = 0.4
Sampling 100%|█████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (100×14×1 Array{Float64, 3}):

Iterations        = 51:1:150
Number of chains  = 1
Samples per chain = 100
Wall duration     = 1.12 seconds
Compute duration  = 1.12 seconds
parameters        = σ², μ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64 

          σ²    1.1090    0.1728    0.0591    10.8467    51.5563    1.1136        9.6587
           μ   -0.1753    0.1030    0.0126    66.7404    78.3393    0.9940       59.4304

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          σ²    0.8499    0.9980    1.1052    1.2155    1.4554
           μ   -0.3526   -0.2485   -0.1774   -0.1074    0.0239


julia> ## Using the iterator-interface ##
       rng = MersenneTwister(42);

julia> spl = DynamicPPL.Sampler(alg);

julia> nadapts = 50;

julia> # Create an iterator we can just step through.
       it = AbstractMCMC.Stepper(rng, model, spl, kwargs);

julia> # Initial sample and state.
       transition, state = iterate(it);
┌ Info: Found initial step size
└   ϵ = 0.4

julia> # Simple container to hold the samples.
       transitions = [];

julia> # Simple condition that says we only want `num_samples` samples.
       condition(spls) = length(spls) < num_samples
condition (generic function with 1 method)

julia> # Sample until `condition` is no longer satisfied
       while condition(transitions)
           # For an iterator we pass in the previous `state` as the second argument
           transition, state = iterate(it, state)
           # Save `transition` if we're not adapting anymore
           if state.i > nadapts
               push!(transitions, transition)
           end
       end

julia> length(transitions), state.i, state.i == length(transitions) + nadapts
(100, 150, true)

julia> # Finally, if you want to convert the vector of `transitions` into a
       # `MCMCChains.Chains` like is typically done:
       chain = AbstractMCMC.bundle_samples(
           map(identity, transitions),  # trick to concretize the eltype of `transitions`
           model,
           spl,
           state,
           MCMCChains.Chains
       )
Chains MCMC chain (100×14×1 Array{Float64, 3}):

Iterations        = 1:1:100
Number of chains  = 1
Samples per chain = 100
parameters        = σ², μ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Missing 

          σ²    1.1090    0.1728    0.0591    10.8467    51.5563    1.1136       missing
           μ   -0.1753    0.1030    0.0126    66.7404    78.3393    0.9940       missing

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          σ²    0.8499    0.9980    1.1052    1.2155    1.4554
           μ   -0.3526   -0.2485   -0.1774   -0.1074    0.0239

(This example is from the GitHub gist “Simple example of using NUTS with the new iterator interface in AbstractMCMC.jl available using Turing.jl > 0.15.”)

Here you can simply serialize the state. Then, when you want to continue from a given state, load it from disk, skip the initial call to iterate(it), and pass the loaded state to iterate(it, state) instead :)
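The same serialize-then-resume pattern can be tried without Turing. This sketch uses a hypothetical Gibbs sampler for a bivariate normal, written against `Base.iterate` (`GibbsBVN`, `take_steps` and the checkpoint path are made up for illustration). The first session calls `iterate(g)`; a later session skips that call and passes the deserialized state to `iterate(g, state)`:

```julia
using Random, Serialization

# Hypothetical toy sampler: Gibbs sampling from a standard bivariate normal with correlation ρ.
struct GibbsBVN
    ρ::Float64
    seed::Int
end

# No state yet: start at (0, 0) with a freshly seeded RNG.
Base.iterate(g::GibbsBVN) = iterate(g, (x = 0.0, y = 0.0, rng = Xoshiro(g.seed)))

# The state holds the current (x, y) and the RNG, so it can be written to disk and picked up later.
function Base.iterate(g::GibbsBVN, st)
    s = sqrt(1 - g.ρ^2)                    # conditional standard deviation
    x = g.ρ * st.y + s * randn(st.rng)     # draw x | y
    y = g.ρ * x + s * randn(st.rng)        # draw y | x
    return (x, y), (x = x, y = y, rng = st.rng)
end

# Take `n` steps from `st`, returning the draws and the final state.
function take_steps(g, st, n)
    draws = Tuple{Float64,Float64}[]
    for _ in 1:n
        d, st = iterate(g, st)
        push!(draws, d)
    end
    return draws, st
end

g = GibbsBVN(0.9, 2024)
path = joinpath(tempdir(), "gibbs_state.jls")   # hypothetical checkpoint location

# Session 1: start fresh with iterate(g), run for a while, save the state.
d0, st = iterate(g)
part1, st = take_steps(g, st, 999)
serialize(path, st)

# Session 2 (e.g. after a restart): skip iterate(g) and continue from the saved state.
st2 = deserialize(path)
part2, _ = take_steps(g, st2, 1_000)
```

Because the RNG is part of the saved state, the draws from the two sessions are identical to those of one uninterrupted run.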

Half a year later, someone else is looking for a solution to this problem. In addition to what was described above, I am trying to pass initial parameters to the sampler (which does not seem to be an accepted keyword for the sampler) while running several chains with multi-threading. I played around with the code snippet you provided but could not quite figure out how to do this.
You also mentioned that this approach would change soon. Would you now recommend a different way to code all of this?
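On the multi-threading part of the question: conceptually, each chain needs its own state and its own RNG, seeded from whatever initial parameters you choose. The toy sketch below shows only that bookkeeping, not Turing's API; `rwm!`, `advance_chains!` and the starting values are hypothetical:

```julia
using Random

logdensity(x) = -x^2 / 2   # toy 1-D target standing in for a real model

# Advance one chain by `n` random-walk Metropolis steps (mutates `rng`).
# Returns the draws and the chain's new position.
function rwm!(rng, x, n; step = 2.4)
    draws = Vector{Float64}(undef, n)
    for i in 1:n
        prop = x + step * randn(rng)
        if log(rand(rng)) < logdensity(prop) - logdensity(x)
            x = prop
        end
        draws[i] = x
    end
    return draws, x
end

# Run every chain `n` more steps, spread over threads, updating `positions` in place.
function advance_chains!(positions, rngs, n)
    out = Vector{Vector{Float64}}(undef, length(positions))
    Threads.@threads for c in eachindex(positions)
        out[c], positions[c] = rwm!(rngs[c], positions[c], n)
    end
    return out
end

init_params = [-5.0, 0.0, 5.0]                             # hypothetical starting value per chain
rngs = [Xoshiro(100 + c) for c in eachindex(init_params)]  # one independent RNG per chain
positions = copy(init_params)                              # current state of each chain

batch1 = advance_chains!(positions, rngs, 1_000)   # inspect batch1 here, then ...
batch2 = advance_chains!(positions, rngs, 1_000)   # ... every chain resumes where it stopped
```

Giving every chain an independent RNG keeps the results reproducible regardless of how many threads Julia is started with (for example `julia -t 4`).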

Indeed, it's a bit easier to do these things now.

Now AbstractMCMC.sample (which is the same sample as the one in Turing.jl) allows you to pass initial_state as a keyword argument.

Combine this with the save_state=true keyword argument, and it's fairly straightforward to resume a run.

julia> using Turing


julia> @model function gdemo(xs)
           # Assumptions
           σ² ~ InverseGamma(2, 3)
           μ ~ Normal(0, √σ²)
           # Observations
           for i = 1:length(xs)
               xs[i] ~ Normal(μ, √σ²)
           end
       end
gdemo (generic function with 2 methods)

julia> # Set up.
       xs = randn(100);

julia> model = gdemo(xs);

julia> # Sampler.
       alg = NUTS(0.65);

julia> # Note the `save_state=true`, which allows us to extract the last state from the chain.
       chain_first = sample(model, alg, 1000; save_state=true);

┌ Info: Found initial step size
└   ϵ = 0.4
Sampling 100%|█████████████████████████████████████████| Time: 0:00:01

julia> last_state = chain_first.info.samplerstate;

julia> # Continue sampling.
       chain_continuation = sample(
           model, alg, 500;
           # NOTE: At the moment we have to use `resume_from` because Turing.jl
           # is slightly lagging behind AbstractMCMC.jl, but soon we will use
           # `initial_state` instead, which is consistent with the rest of the
           # ecosystem.
           resume_from=last_state,
           # initial_state=last_state,
       );
Sampling 100%|█████████████████████████████████████████| Time: 0:00:00

julia> range(chain_first)
501:1:1500

julia> range(chain_continuation)
1:1:500

julia> # Can only concatenate chains if the iterations are consistent.
       # So we have to update the iterations of the second chain.
       chain_continuation = setrange(
           chain_continuation,
           range(chain_first)[end] .+ range(chain_continuation)
       );

julia> chain_combined = vcat(chain_first, chain_continuation);

julia> range(chain_combined)
1500-element Vector{Int64}:
  501
  502
  503
  504
  505
  506
  507
  508
  509
  510
  511
  512
  513
  514
    ⋮
 1987
 1988
 1989
 1990
 1991
 1992
 1993
 1994
 1995
 1996
 1997
 1998
 1999
 2000

Note that there are clearly some quirks to iron out in the above example:

  • Turing.jl is a bit behind the AbstractMCMC.jl interface, so we need to update it (I'll get on this asap though): resume_from has to be removed in favour of the new initial_state.
  • vcat for two chains concretizes the ranges, i.e. when concatenating 1:1:3 and 4:1:5 we get [1, 2, 3, 4, 5] instead of 1:1:5, which, though not incorrect, is potentially annoying since we show the entire range when displaying a Chains.
  • vcat requires a monotonically increasing range, while Turing.jl constructs the range purely from the number of samples and thinning, ignoring the iteration at which sampling started. Hence the setrange line in the snippet above, which manually aligns the two iteration ranges.
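The iteration bookkeeping described in these bullets can be checked with plain Julia ranges. This sketch only reuses the iteration numbers from the REPL session above and needs no Turing:

```julia
first_range = 501:1:1500   # iterations reported by the first run
cont_range  = 1:1:500      # what the continuation reports on its own

# Shift the continuation so it starts right after the first run ends
# (the same arithmetic as the setrange line in the snippet above).
shifted = first_range[end] .+ cont_range
println(shifted)           # still a range

# vcat of two ranges returns a Vector rather than a range ...
combined = vcat(first_range, shifted)
println(typeof(combined), " with ", length(combined), " elements")

# ... but when all steps are 1 it can be collapsed back into a compact range.
if all(==(1), diff(combined))
    compact = first(combined):last(combined)
    println(compact)
end
```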

But these things can be sorted out fairly easily:)

EDIT: See the GitHub issue “Do we need `resume_from` now that we have `initial_state`?” (TuringLang/Turing.jl #2171). This should be a fairly quick “aye”, and then we just need a few simple PRs.
