Function for resuming MCMC chain

I want to write a function so that I don't have to wait for a long inference run to finish before seeing any results; instead, I'd like to get intermediate results along the way.

Let's say I want to run inference for 100,000 iterations.

I could get results after every 10,000 iterations and resume the chain for the next 10,000, using the end point of the last chain. This would give me an idea of whether the inference is working out, rather than sitting blindly waiting for the end results.

I have tried writing such a function before, motivated by one of the posts here on Discourse, but it didn't work: it does resume the chain, but it more or less just starts the chain again from scratch (at a different starting point), so it isn't learning anything from the previous chain.

I wanted to ask whether you have come across such issues, and whether some of you have already implemented something like this in your programs. If not, I would appreciate any suggestions you think would be helpful in my case, because waiting days for inference to finish and then getting weird fits would be a waste of time.

Thanks

I'm not sure how you are implementing your sampler, but whenever I have implemented this sort of thing (in other languages), the sampler object had state, and I just asked it for the next sample. Is there a particular Julia library you are using, or an example you could give us of what you have tried?
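
To illustrate what I mean, here is a toy random-walk Metropolis sketch (not any particular library's API): the state lives in the sampler object, so pausing and resuming is trivial.

# Toy stateful sampler: the current position IS the resumable state.
mutable struct RWSampler
    x::Float64          # current position
    step::Float64       # proposal scale
    logp::Function      # log target density
end

# One Metropolis step; mutates the sampler's state and returns the new sample.
function next_sample!(s::RWSampler)
    proposal = s.x + s.step * randn()
    if log(rand()) < s.logp(proposal) - s.logp(s.x)
        s.x = proposal  # accept
    end
    return s.x
end

s = RWSampler(0.0, 0.5, x -> -x^2 / 2)           # standard-normal target
first_half = [next_sample!(s) for _ in 1:500]    # sample a while...
# ...inspect the draws, then simply continue from the same state:
second_half = [next_sample!(s) for _ in 1:500]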


If I understand properly, you wish to be able to “pause” an inference chain at a specific point, investigate whether it is converging well, and then resume it (or have it resume automatically).

I presume you are attempting some sort of Bayesian inference? If so, here are my thoughts:

  • If you are struggling to converge, it is often going to be an issue with the model/priors. My typical approach is to prototype with shorter chains (1k-5k) to see whether the model begins to converge at all.
    • Terms you might be interested in looking up include stabilizing priors and reparameterization (see the sketch after this list).
  • If you've run the model and it does converge after a very long run (50k or more, maybe?), could you use that as a starting point for your next run? Like @frylock mentioned, you could extract the state from the Chains struct and then use it as the starting point for the next run.
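
For the reparameterization point, here is a minimal Turing.jl sketch of the non-centered trick on a toy model (the model and variable names are just illustrative, not from your code):

using Turing

# Centered parameterization: sampling θ directly can create a "funnel"
# geometry that NUTS struggles with when τ is small.
@model function centered(y)
    τ ~ truncated(Cauchy(0, 2), 0, Inf)
    θ ~ Normal(0, τ)
    y ~ Normal(θ, 1)
end

# Non-centered parameterization: sample a standardized variable and scale
# it deterministically, which often removes the pathological geometry.
@model function noncentered(y)
    τ ~ truncated(Cauchy(0, 2), 0, Inf)
    θ_raw ~ Normal(0, 1)
    θ = τ * θ_raw
    y ~ Normal(θ, 1)
end

chain = sample(noncentered(1.0), NUTS(), 2_000)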

My experience with Stan (only a couple of projects) suggests that if something isn’t beginning to converge quickly, I need to adjust my priors/model away from a pathological condition.

This might also be a good topic to put in the Specific Domains/Probabilistic Programming category; that will make it more visible to the community.


This is my current code.

using Turing, Serialization, StatsPlots

# Load the previously saved chain to resume from.
chain_reloaded = deserialize("/Program_Julia/chain.jls")

for i in 1:20
    println("Start$(i)")

    # Sample the next segment: 2 chains of 1000 iterations each,
    # resuming from the last saved state.
    chains = sample(model, NUTS(), MCMCThreads(), 1000, 2;
                    progress = true, save_state = true, resume_from = chain_reloaded)

    # Diagnostic plots for this segment.
    plot(chains)
    savefig("/Program_Julia/trial/chain_$(i).pdf")

    plot_fit(chains, i)
    savefig("/Program_Julia/trial/fit_$(i).pdf")

    # Save the latest chain and use it as the resume point for the next segment.
    serialize("/Program_Julia/chain_new.jls", chains)
    chain_reloaded = deserialize("/Program_Julia/chain_new.jls")

    println("End$(i)")
end

Yes, I'm doing Bayesian inference and using Turing.jl.

I also approach inference like that: I start with a low iteration count to spot issues in the inference or in my code. However, if the model takes, say, around 120k iterations to converge, it's hard to judge anything from 2000-5000 iterations.

As you can see in the code I shared in reply to @frylock, I'm doing something similar to what you mentioned about extracting the state from the chain object.

I have already done inference for this model in R with one chain. Here I was using multiple chains, and even at around 60k iterations I can't see stable chains, so I expect it will take much longer for all 3 chains to converge. For now I'm running it for 120k iterations with 1 chain to check that there is nothing wrong in the code, as I expect it to converge by 120k with 1 chain.

I have changed the category to Specific Domains/Probabilistic Programming as you suggested.


Unfortunately, at the moment (but this will change soon) you have to make use of the underlying step interface to be able to resume a chain.

julia> using Turing, Random

julia> @model function gdemo(xs)
           # Assumptions
           σ² ~ InverseGamma(2, 3)
           μ ~ Normal(0, √σ²)
           # Observations
           for i = 1:length(xs)
               xs[i] ~ Normal(μ, √σ²)
           end
       end

gdemo (generic function with 2 methods)

julia> # Set up.
       xs = randn(100);

julia> model = gdemo(xs);

julia> # Sampler.
       alg = NUTS(0.65);

julia> kwargs = (nadapts=50,);

julia> num_samples = 100;

julia> ### The following two methods are equivalent ###
       ## Using `sample` ##
       rng = MersenneTwister(42);

julia> chain = sample(rng, model, alg, num_samples; kwargs...)
┌ Info: Found initial step size
└   ϵ = 0.4
Sampling 100%|████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (100×14×1 Array{Float64, 3}):

Iterations        = 51:1:150
Number of chains  = 1
Samples per chain = 100
Wall duration     = 1.12 seconds
Compute duration  = 1.12 seconds
parameters        = σ², μ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64 

          σ²    1.1090    0.1728    0.0591    10.8467    51.5563    1.1136        9.6587
           μ   -0.1753    0.1030    0.0126    66.7404    78.3393    0.9940       59.4304

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          σ²    0.8499    0.9980    1.1052    1.2155    1.4554
           μ   -0.3526   -0.2485   -0.1774   -0.1074    0.0239


julia> ## Using the iterator-interface ##
       rng = MersenneTwister(42);

julia> spl = DynamicPPL.Sampler(alg);

julia> nadapts = 50;

julia> # Create an iterator we can just step through.
       it = AbstractMCMC.Stepper(rng, model, spl, kwargs);

julia> # Initial sample and state.
       transition, state = iterate(it);
┌ Info: Found initial step size
└   ϵ = 0.4

julia> # Simple container to hold the samples.
       transitions = [];

julia> # Simple condition that says we only want `num_samples` samples.
       condition(spls) = length(spls) < num_samples
condition (generic function with 1 method)

julia> # Sample until `condition` is no longer satisfied
       while condition(transitions)
           # For an iterator we pass in the previous `state` as the second argument
           transition, state = iterate(it, state)
           # Save `transition` if we're not adapting anymore
           if state.i > nadapts
               push!(transitions, transition)
           end
       end

julia> length(transitions), state.i, state.i == length(transitions) + nadapts
(100, 150, true)

julia> # Finally, if you want to convert the vector of `transitions` into a
       # `MCMCChains.Chains` like is typically done:
       chain = AbstractMCMC.bundle_samples(
           map(identity, transitions),  # trick to concretize the eltype of `transitions`
           model,
           spl,
           state,
           MCMCChains.Chains
       )
Chains MCMC chain (100×14×1 Array{Float64, 3}):

Iterations        = 1:1:100
Number of chains  = 1
Samples per chain = 100
parameters        = σ², μ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Missing 

          σ²    1.1090    0.1728    0.0591    10.8467    51.5563    1.1136       missing
           μ   -0.1753    0.1030    0.0126    66.7404    78.3393    0.9940       missing

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          σ²    0.8499    0.9980    1.1052    1.2155    1.4554
           μ   -0.3526   -0.2485   -0.1774   -0.1074    0.0239

(This example is from the GitHub gist "Simple example of using NUTS with the new iterator interface in AbstractMCMC.jl, available using Turing.jl > 0.15".)

Here you can just serialize the state, and then when you want to continue from a given state, you can load the state from disk, skip the initial iterate(it) call, and use the state you just loaded instead. :)
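
For instance, something along these lines (the file name is illustrative; in a fresh session you would reconstruct it exactly as above before resuming):

using Serialization

# Checkpoint: save the current sampler state to disk.
serialize("nuts_state.jls", state)

# ...later, after rebuilding `it` as above, load the state back...
state = deserialize("nuts_state.jls")

# ...and continue stepping from it instead of calling `iterate(it)`:
transition, state = iterate(it, state)
push!(transitions, transition)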