Resuming Turing chains

Hi everybody,
I want to use Turing on my university's LSF cluster. Following the official documentation and using the ClusterManagers package, I was able to run Turing on the cluster.

using Turing
using Distributed
using ClusterManagers

# add 20 worker processes on the long queue
ClusterManagers.addprocs_lsf(20; bsub_flags = `-q long`)

# load Turing on all processes
@everywhere using Turing

# toy model from the Turing documentation
@everywhere @model function gdemo(x)
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))

    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s²))
    end
end

@everywhere model = gdemo([1.5, 2.0])

# sampling
chains = sample(model, NUTS(), MCMCDistributed(), 1000, 10; save_state = true)

write("chain-file.jls", chains)

This worked: the samples were generated by 10 processes on the cluster and the chains were saved.
The next step would be to resume the chains from their saved state.

So I tried the following:

chains_reloaded = read("chain-file.jls", Chains)

# resume sampling
sample(model, NUTS(), MCMCDistributed(), 10000, 10; save_state = true, resume_from = chains_reloaded)

However, it looks to me like this did not work. I obtained:

Chains MCMC chain (10000×14×10 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 10
Samples per chain = 10000
Wall duration     = 1.85 seconds
Compute duration  = 5.34 seconds
parameters        = s², m
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse          ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64      Float64   Float64       Float64 

          s²    2.0238    1.8881     0.0060    0.0102   35011.0145    1.0001     6556.3698
           m    1.1626    0.8159     0.0026    0.0037   46594.2097    1.0000     8725.5074

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          s²    0.5652    1.0406    1.5241    2.3436    6.5242
           m   -0.4784    0.6858    1.1662    1.6391    2.7920

However, I expected to find 10 chains with 20,000 samples each… not 10,000.
Is there anything I am doing wrong?

Cheers,
Marco


I think you’ve asked for 10k samples starting from the end point of the previous chain, and that’s what it gave you, no? What were you expecting? That it would append the new samples to the original ones?


How can I check whether these chains were resumed from the previous ones? I would have preferred to have the new samples appended to the old ones. What do you suggest? Manually merging the chains?

Yes, I think just concatenate the chains. I’m not sure how you can check; have you tried comparing the first sample of the new chain to the last sample of the earlier chain?
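Something like this might do it, a minimal sketch assuming `model` is still defined on all workers and that both runs share the same parameter names and number of chains (MCMCChains can concatenate along the iteration dimension with vcat):

chains_old = read("chain-file.jls", Chains)
chains_new = sample(model, NUTS(), MCMCDistributed(), 10000, 10;
                    save_state = true, resume_from = chains_old)

# rough check: where does each resumed chain start, relative to
# where the corresponding old chain ended?
old_draws = Array(chains_old, append_chains = false)  # one matrix per chain
new_draws = Array(chains_new, append_chains = false)
[old_draws[i][end, :] ≈ new_draws[i][1, :] for i in eachindex(old_draws)]

# append the new iterations to the old ones
chains_all = vcat(chains_old, chains_new)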


Sorry for the late answer.
I checked and found out that all chains restart from the last point of the first chain (!)… I am confused :sweat_smile:

Yeah, that’s what resume means, right? Start sampling where the chain left off. Like if I drive to a nearby town, sleep overnight, and then resume my trip the next morning, it means I continue on from that town.


Yeah, that’s true. But if I am running multiple chains in parallel, I would expect the i-th chain to resume from the last point of the i-th chain, not from the last point of the first chain.

Am I missing something?

Oh, I see what you’re saying. I misunderstood.

Yes, I don’t know why the code doesn’t do that, but I think your workaround will be to manually @spawn N tasks and then resume from each single chain?
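Something along these lines, a sketch that assumes the per-chain subset chains_reloaded[:, :, i] carries enough saved sampler state for resume_from to use (which is exactly what’s in doubt here, so treat it as untested):

using Distributed, MCMCChains

# resume each chain from its own subset, one remote call per chain
resumed = pmap(1:size(chains_reloaded, 3)) do i
    sample(model, NUTS(), 1000;
           save_state = true, resume_from = chains_reloaded[:, :, i])
end

# put the independently resumed chains back side by side
chains_new = chainscat(resumed...)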


I don’t know if the right solution is a workaround, or rather looking inside the Turing code and opening an issue or a PR. This is not the best approach to resuming chains, IMHO.

You could certainly open a PR. Arguably it’s better to give the user control over where exactly they’re starting: you might, for example, have a chain that’s stuck and not want to restart from that one. It’s not totally clear that there’s a single right answer.
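For example, instead of resume_from you could pass each chain’s last draw explicitly as a starting point. A sketch, assuming your Turing version accepts the init_params keyword (newer AbstractMCMC releases spell it initial_params), taking one initial parameter vector per chain:

# last draw of every chain, parameters only, in the model's parameter order
per_chain = Array(chains_reloaded, append_chains = false)
inits = [m[end, :] for m in per_chain]

chains_new = sample(model, NUTS(), MCMCDistributed(), 1000, 10;
                    init_params = inits)

Note this discards the saved sampler state, so NUTS re-runs its adaptation phase rather than truly resuming.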

Yeah, this is not intended functionality and needs to be fixed. An issue would be most welcome.


Here’s a related issue:

using LinearAlgebra  # for diagm (`model`, `npars`, and `op` come from my session)

startch = sample(model, MH(diagm([.002 for i in 1:npars])), 1200;
                 thinning = 10, init_theta = op.values.array)

resumech = sample(model, MH(diagm([.002 for i in 1:npars])), 1200;
                  resume_from = startch)

The second call gives an error:

julia> resumech = sample(model,MH(diagm([.002 for i in 1:npars])),1200;resume_from=startch)
ERROR: type NamedTuple has no field model
Stacktrace:
 [1] getindex(t::NamedTuple{(:start_time, :stop_time), Tuple{Float64, Float64}}, i::Symbol)
   @ Base ./namedtuple.jl:127
 [2] resume(rng::Random._GLOBAL_RNG, chain::Chains{Float64, AxisArrays.AxisArray{Float64, 3, Array{Float64, 3}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:var, Vector{Symbol}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}, Missing, NamedTuple{(:parameters, :internals), Tuple{Vector{Symbol}, Vector{Symbol}}}, NamedTuple{(:start_time, :stop_time), Tuple{Float64, Float64}}}, args::Int64; progress::Bool, kwargs::Base.Pairs{Symbol, UnionAll, Tuple{Symbol}, NamedTuple{(:chain_type,), Tuple{UnionAll}}})
   @ Turing.Inference ~/.julia/packages/Turing/uMQmD/src/inference/Inference.jl:404
 [3] resume(chain::Chains{Float64, AxisArrays.AxisArray{Float64, 3, Array{Float64, 3}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:var, Vector{Symbol}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}, Missing, NamedTuple{(:parameters, :internals), Tuple{Vector{Symbol}, Vector{Symbol}}}, NamedTuple{(:start_time, :stop_time), Tuple{Float64, Float64}}}, args::Int64; kwargs::Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol}, NamedTuple{(:chain_type, :progress), Tuple{UnionAll, Bool}}})
   @ Turing.Inference ~/.julia/packages/Turing/uMQmD/src/inference/Inference.jl:396
 [4] sample(rng::Random._GLOBAL_RNG, model::DynamicPPL.Model{typeof(teamskill), (:Nteams, :orefteam, :drefteam, :m1, :m2, :m3, :m4, :week, :ns, :hometeam, :homepts, :awayteam, :awaypts), (), (), Tuple{Int64, Vector{Union{Nothing, Int64}}, Vector{Union{Nothing, Int64}}, Float64, Float64, Float64, Float64, Vector{Int64}, Vector{Bool}, Vector{Union{Nothing, Int64}}, Vector{Union{Missing, Int64}}, Vector{Union{Nothing, Int64}}, Vector{Union{Missing, Int64}}}, Tuple{}, DynamicPPL.DefaultContext}, sampler::DynamicPPL.Sampler{MH{(), RandomWalkProposal{false, ZeroMeanFullNormal{Tuple{Base.OneTo{Int64}}}}}}, N::Int64; chain_type::Type, resume_from::Chains{Float64, AxisArrays.AxisArray{Float64, 3, Array{Float64, 3}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:var, Vector{Symbol}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}, Missing, NamedTuple{(:parameters, :internals), Tuple{Vector{Symbol}, Vector{Symbol}}}, NamedTuple{(:start_time, :stop_time), Tuple{Float64, Float64}}}, progress::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ Turing.Inference ~/.julia/packages/Turing/uMQmD/src/inference/Inference.jl:159
 [5] sample(rng::Random._GLOBAL_RNG, model::DynamicPPL.Model{typeof(teamskill), (:Nteams, :orefteam, :drefteam, :m1, :m2, :m3, :m4, :week, :ns, :hometeam, :homepts, :awayteam, :awaypts), (), (), Tuple{Int64, Vector{Union{Nothing, Int64}}, Vector{Union{Nothing, Int64}}, Float64, Float64, Float64, Float64, Vector{Int64}, Vector{Bool}, Vector{Union{Nothing, Int64}}, Vector{Union{Missing, Int64}}, Vector{Union{Nothing, Int64}}, Vector{Union{Missing, Int64}}}, Tuple{}, DynamicPPL.DefaultContext}, alg::MH{(), RandomWalkProposal{false, ZeroMeanFullNormal{Tuple{Base.OneTo{Int64}}}}}, N::Int64; kwargs::Base.Pairs{Symbol, Chains{Float64, AxisArrays.AxisArray{Float64, 3, Array{Float64, 3}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:var, Vector{Symbol}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}, Missing, NamedTuple{(:parameters, :internals), Tuple{Vector{Symbol}, Vector{Symbol}}}, NamedTuple{(:start_time, :stop_time), Tuple{Float64, Float64}}}, Tuple{Symbol}, NamedTuple{(:resume_from,), Tuple{Chains{Float64, AxisArrays.AxisArray{Float64, 3, Array{Float64, 3}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:var, Vector{Symbol}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}, Missing, NamedTuple{(:parameters, :internals), Tuple{Vector{Symbol}, Vector{Symbol}}}, NamedTuple{(:start_time, :stop_time), Tuple{Float64, Float64}}}}}})
   @ Turing.Inference ~/.julia/packages/Turing/uMQmD/src/inference/Inference.jl:142
 [6] sample(model::DynamicPPL.Model{typeof(teamskill), (:Nteams, :orefteam, :drefteam, :m1, :m2, :m3, :m4, :week, :ns, :hometeam, :homepts, :awayteam, :awaypts), (), (), Tuple{Int64, Vector{Union{Nothing, Int64}}, Vector{Union{Nothing, Int64}}, Float64, Float64, Float64, Float64, Vector{Int64}, Vector{Bool}, Vector{Union{Nothing, Int64}}, Vector{Union{Missing, Int64}}, Vector{Union{Nothing, Int64}}, Vector{Union{Missing, Int64}}}, Tuple{}, DynamicPPL.DefaultContext}, alg::MH{(), RandomWalkProposal{false, ZeroMeanFullNormal{Tuple{Base.OneTo{Int64}}}}}, N::Int64; kwargs::Base.Pairs{Symbol, Chains{Float64, AxisArrays.AxisArray{Float64, 3, Array{Float64, 3}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:var, Vector{Symbol}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}, Missing, NamedTuple{(:parameters, :internals), Tuple{Vector{Symbol}, Vector{Symbol}}}, NamedTuple{(:start_time, :stop_time), Tuple{Float64, Float64}}}, Tuple{Symbol}, NamedTuple{(:resume_from,), Tuple{Chains{Float64, AxisArrays.AxisArray{Float64, 3, Array{Float64, 3}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:var, Vector{Symbol}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}, Missing, NamedTuple{(:parameters, :internals), Tuple{Vector{Symbol}, Vector{Symbol}}}, NamedTuple{(:start_time, :stop_time), Tuple{Float64, Float64}}}}}})
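If I read that stack trace correctly, chain.info only holds start_time and stop_time, which suggests the first sample call ran without save_state = true; resume then fails when it looks for the saved state in chain.info. Re-running the first line with save_state = true might be worth a try.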

Did you find any way to do this?