Setting up multiple chains in Turing.jl

I’ve run a couple of Metropolis chains and combined them into a Chains object for convergence diagnostics, but I’m pretty sure there is a better way to do this:

fit1 = sample(SBM(Y = Y, Pi = π₀, B = B₀), MH(1000))
fit2 = sample(SBM(Y = Y, Pi = π₀, B = B₀), MH(1000))
# stack the :z draws of both runs along the third (chain) dimension
fit = Chains(cat(hcat(fit1[:z]...)', hcat(fit2[:z]...)', dims = 3))
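The Chains constructor used in this thread takes a 3-D array whose axes are iterations, parameters and chains (the :iter, :var and :chain axes that appear in the error message further down). A minimal plain-Julia sketch of the reshaping done in the one-liner above, assuming (as the splatted hcat suggests) that each draw of z is a vector; the numbers are made up and no Turing calls are involved:

```julia
# Hypothetical draws of a vector-valued parameter z: one vector per iteration.
z_chain1 = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]   # 3 iterations, 2 components
z_chain2 = [[4.0, 40.0], [5.0, 50.0], [6.0, 60.0]]

# hcat gives components × iterations; permutedims turns it into iterations × components.
to_matrix(draws) = permutedims(hcat(draws...))

# Stack the two runs along a third dimension: iterations × parameters × chains.
stacked = cat(to_matrix(z_chain1), to_matrix(z_chain2); dims = 3)

size(stacked)        # (3, 2, 2)
stacked[:, :, 2]     # the second run only
```

Here permutedims plays the role of the transpose in the one-liner; for real-valued draws both give the same matrix. Each slice along the third dimension should come from an independent run, since stacking a run with a copy of itself makes the chains agree by construction.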


I opened an issue requesting built-in support for running multiple chains in parallel. In the meantime, I have been using pmap to run multiple chains.

using Distributed
Nchains = 4
addprocs(Nchains)
@everywhere begin
    # put all of your supporting code here
    push!(LOAD_PATH, "path")   # LOAD_PATH takes directories: "path" should contain myModule.jl
    using Turing, myModule
end

@everywhere @model mymodel() = begin
    # stuff here
end

# run one chain per worker process
chain = pmap(x -> sample(SBM(Y = Y, Pi = π₀, B = B₀), MH(1000)), 1:Nchains)
chains = removeBurnin(chain)   # defined below; merges the array of chains into one object
r̂ = gelmandiag(chains)
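gelmandiag reports the Gelman-Rubin potential scale reduction factor (R-hat), which compares between-chain and within-chain variance: with n draws in each of m chains, B = n · Var(chain means), W = mean of the within-chain variances, V = (n - 1)/n · W + B/n and R-hat = sqrt(V / W). A minimal sketch of this classic, non-split version for plain arrays laid out iterations × parameters × chains; the names rhat and rhat_per_param are hypothetical, and MCMCChains’ own implementation may differ in details (for example, degrees-of-freedom corrections), so values need not match exactly:

```julia
using Statistics

# Classic (non-split) Gelman-Rubin R-hat for a single parameter.
# `draws` is an iterations × chains matrix.
function rhat(draws::AbstractMatrix{<:Real})
    n, m = size(draws)
    m >= 2 || throw(ArgumentError("need at least two chains"))
    chain_means = vec(mean(draws; dims = 1))
    chain_vars = vec(var(draws; dims = 1))   # within-chain sample variances (divisor n - 1)
    B = n * var(chain_means)                 # between-chain variance
    W = mean(chain_vars)                     # mean within-chain variance
    V = (n - 1) / n * W + B / n              # pooled estimate of the posterior variance
    return sqrt(V / W)
end

# Apply it to each parameter of an iterations × parameters × chains array.
rhat_per_param(draws3d::AbstractArray{<:Real,3}) =
    [rhat(draws3d[:, p, :]) for p in axes(draws3d, 2)]
```

If two identical copies of one chain are passed in, B is zero and this statistic comes out slightly below 1 (sqrt((n - 1)/n)), so a value near 1 is only meaningful when the chains are independent runs.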

There is probably a more elegant way to handle the chains and remove the burn-in samples, but this is what I use:

function removeBurnin(ch::Array{Chain{AbstractRange{Int64}},1},Nadapt=0)
    nch = Chain()
    Nchains = length(ch)
    v = ch[1].value
    # kept iterations × parameters × chains
    dims = (size(v,1)-Nadapt,size(v,2),Nchains)
    nch.value = fill(0.0,dims)
    value2 = []
    rng = (Nadapt+1):size(v,1)   # post-adaptation iterations
    for (i,c) in enumerate(ch)
        nch.value[:,:,i] = c.value[rng,:,:]
        push!(value2,c.value2[rng])
    end
    nch.value2 = reshape(vcat(value2...),dims[1],dims[3])
    # copy the metadata the plotting code relied on
    nch.names = ch[1].names
    nch.range = 1:dims[1]
    nch.chains = 1:Nchains
    return nch
end
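With the old Chain fields set aside, the core of removeBurnin is array slicing: keep iterations Nadapt+1 onward from every chain and stack the results along a third dimension. A minimal plain-Julia sketch, assuming each chain is available as an iterations × parameters matrix with identical column order; drop_burnin is a hypothetical helper, not part of Turing or MCMCChains:

```julia
# Drop the first `nadapt` iterations of each chain and stack the rest into an
# iterations × parameters × chains array. Every chain is assumed to be an
# iterations × parameters matrix with the same column order.
function drop_burnin(chains::AbstractVector{<:AbstractMatrix}, nadapt::Integer)
    niter, npar = size(first(chains))
    all(size(c) == (niter, npar) for c in chains) ||
        throw(ArgumentError("all chains must have the same size"))
    0 <= nadapt < niter || throw(ArgumentError("nadapt must be in 0:niter-1"))
    keep = (nadapt + 1):niter
    out = similar(first(chains), length(keep), npar, length(chains))
    for (i, c) in enumerate(chains)
        out[:, :, i] = c[keep, :]   # copy the post-adaptation draws of chain i
    end
    return out
end

c1 = [1.0 2.0; 3.0 4.0; 5.0 6.0]   # 3 iterations × 2 parameters
c2 = c1 .+ 100                     # a second (fake) chain
kept = drop_burnin([c1, c2], 1)    # 2 × 2 × 2 array
```

Unlike removeBurnin, the sketch checks that every chain has the same shape first; a mismatch would otherwise fail inside the loop or, for a longer chain, silently drop its extra draws.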

Wouldn’t simply creating a new Chains object work? For example:

draws = adapt_cycles+1:samples
chn2 = MCMCChain.Chains(chn.value[draws,:,:], names=chn.names)

where chn is the original, uncorrected chain.

It would also be interesting to know whether anyone has experimented with the number of draws required. I’m wondering whether those numbers have become inflated because the results of the uncorrected chain (as in describe(chn)) were usually off. I’m slowly testing that by reducing the number of draws in TuringModels.jl.

I believe this is correct for the simple case with one chain. The problem is that Turing does not handle multiple chains internally. My function restructures the array of chains returned by pmap into a single chain. At one point, all of the seemingly superfluous information was necessary to generate the plots (but not the diagnostics). That may have changed since I created the function.

Thanks! For now I will keep a copy of your removeBurnin around, if you don’t mind.

This is very helpful, thanks.

Please do, Rob. I owe you one… or ten.

You are welcome mkarikom. I’m glad it’s helpful.

Rob, I was wondering whether you or someone else might know how to adapt removeBurnin to the new Chains object. Here is what I have so far:

function removeBurnin(chainArray::Array{<:Chains,1},Nadapt)
    Nchains = length(chainArray)
    ch = chainArray[1]
    v = ch.value
    dims = (size(v,1)-Nadapt,size(v,2),Nchains)
    value = Array{Union{Missing,Real}}(undef,dims...)
    rng = (Nadapt+1):size(v,1)
    for (i,c) in enumerate(chainArray)
        value[:,:,i] = c.value[rng,:,:]
    end
    nch = Chains(value,ch.name_map.parameters)
    return nch
end

Error Message:

ArgumentError: the length of each axis must match the corresponding size of data
AxisArray(::Array{Union{Missing, Float64},3}, ::Tuple{Axis{:iter,StepRange{Int64,Int64}},Axis{:var,Array{String,1}},Axis{:chain,Array{Symbol,1}}}) at core.jl:222
AxisArray(::Array{Union{Missing, Float64},3}, ::Axis{:iter,StepRange{Int64,Int64}}, ::Vararg{Union{Axis, AbstractArray{T,1} where T},N} where N) at core.jl:215
#Chains#12(::Int64, ::Int64, ::Float64, ::NamedTuple{(),Tuple{}}, ::Type, ::Array{Union{Missing, Float64},3}, ::Array{String,1}, ::Dict{Symbol,Array{Any,1}}) at chains.jl:93
Type at chains.jl:28 [inlined]
Chains(::Array{Union{Missing, Float64},3}, ::Array{String,1}) at chains.jl:28
top-level scope at none:0

Hi Chris, I’ll take a look this morning.

Thanks, Rob. From what I can tell, the issue is that value contains both the parameter samples and information from the MCMC sampler, such as lp, so its second dimension is larger than the number of parameter names passed to Chains.

As a workaround, I removed the MCMC info:

function removeBurnin(chainArray::Array{<:Chains,1},Nadapt)
    Nchains = length(chainArray)
    ch = chainArray[1]
    v = ch.value
    parms = ch.name_map.parameters
    dims = (size(v,1)-Nadapt,length(parms),Nchains)
    value = Array{Union{Missing,Real}}(undef,dims...)
    rng = (Nadapt+1):size(v,1)
    for (i,c) in enumerate(chainArray)
        value[:,:,i] = c.value[rng,parms,:]
    end
    nch = Chains(value,parms)
    return nch
end

The downside is that the MCMC info is not included. I usually don’t use it, but it might be needed for diagnostics in some situations, so it would be nice to keep it.
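Keeping the internals does not require dropping anything: the columns can be split into named sections in one pass, as long as the selection is by name rather than by position. A minimal sketch on a plain matrix, assuming a hypothetical column layout (lp and acceptance_rate stand in for whatever internals the sampler records) and a hypothetical helper split_sections:

```julia
# Hypothetical column layout mixing parameters and sampler internals.
colnames = ["a", "lp", "b", "acceptance_rate"]
draws = [1.0 -3.0 10.0 0.9;
         2.0 -2.5 20.0 0.8]                  # 2 iterations × 4 columns
internal_names = ["lp", "acceptance_rate"]   # assumed internals

# Split the columns into two named sections, keeping both and preserving order.
function split_sections(draws, colnames, internal_names)
    is_internal = [n in internal_names for n in colnames]
    return (
        parameters = (names = colnames[.!is_internal], values = draws[:, .!is_internal]),
        internals  = (names = colnames[is_internal],   values = draws[:, is_internal]),
    )
end

sections = split_sections(draws, colnames, internal_names)
```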

Hi Chris,

Given that you don’t use the :internals section that often, would the following work (provided the chn below is defined as the output of a single run)?

function removeBurnin(chainArray::Array{<:Chains,1},Nadapt, ::Val{:parameters})
    Nchains = length(chainArray)
    ch = chainArray[1]
    v = ch.value
    parms = ch.name_map.parameters
    dims = (size(v,1)-Nadapt,length(parms),Nchains)
    value = Array{Union{Missing,Real}}(undef,dims...)
    rng = (Nadapt+1):size(v,1)
    for (i,c) in enumerate(chainArray)
        value[:,:,i] = c.value[rng,parms,:]
    end
    nch = Chains(value,parms)
    return nch
end

function removeBurnin(chainArray::Array{<:Chains,1},Nadapt, ::Val{:internals})
    Nchains = length(chainArray)
    ch = chainArray[1]
    v = ch.value
    parms = ch.name_map.internals
    dims = (size(v,1)-Nadapt,length(parms),Nchains)
    value = Array{Union{Missing,Real}}(undef,dims...)
    rng = (Nadapt+1):size(v,1)
    for (i,c) in enumerate(chainArray)
        value[:,:,i] = c.value[rng,parms,:]
    end
    nch = Chains(value,parms)
    return nch
end

ca = Vector{MCMCChains.Chains}(undef, 4)
for i in 1:4
  ca[i] = chn
end

describe(ca[2])
describe(ca[2], section=:internals)

parameters_chain = removeBurnin(ca, 1000, Val(:parameters))

describe(parameters_chain)

internals_chain = removeBurnin(ca, 1000, Val(:internals))

describe(internals_chain)

It would not be too hard to recreate the normal structure (i.e. a Chains object with both sections present); just let us know what you prefer. I stuck to your approach since you are familiar with it. I’ve cc’d the master (@cpfiffer)!

You can see other ways of removing the adaptation draws in examples such as m10.4t.jl:

draws = 1001:2000
posterior2 = Chains(posterior[draws,:,:], :parameters)

You would have to apply this to each element of the chain array.

I love these sections, especially in multilevel models such as m10.4s.jl.

Thanks, Rob! I think both of these look good.

Of course, in practice, it might be convenient to have both sections present.

Agreed, I’ll make a new version.

By the way, I noticed that in some cases removeBurnin does not assign the values to the correct parameters. Here is a simple example:

using Distributed
addprocs(4)
@everywhere using Turing,Distributions

@everywhere @model model() = begin
    a ~ Normal(-10,.1)
    b ~ Normal(-8,.1)
    c ~ Normal(-6,.1)
    d ~ Normal(-4,.1)
    return nothing
end

Nsamples = 2000
Nadapt = 1000
δ = .85
specs = NUTS(Nsamples,Nadapt,δ)
chainArray = pmap(x->sample(model(),specs),1:4)
chain = removeBurnin(chainArray,Nadapt)

Output (the estimates for a and c are swapped: a should be near -10 and c near -6):

parameters
    Mean    SD   Naive SE  MCSE   ESS
a -6.0009 0.0989   0.0016 0.0015 1000
b -8.0006 0.0996   0.0016 0.0013 1000
c -9.9994 0.1007   0.0016 0.0016 1000
d -4.0009 0.1007   0.0016 0.0015 1000

I think it has something to do with slicing the array c.value.
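One plausible mechanism for this symptom (not confirmed in the thread) is that the columns are stored in a different order from the names list used to rebuild the object, so copying by position attaches each name to the wrong column. The describe output later in the thread lists the parameters as c, b, a, d, which is consistent with that. A minimal plain-Julia sketch reproducing exactly this a/c swap with made-up numbers:

```julia
# Hypothetical: columns are stored in this order...
stored_names = ["c", "b", "a", "d"]
row = [-6.0, -8.0, -10.0, -4.0]      # one draw, in stored column order

# ...while the new object is built with the names in a different order.
wanted = ["a", "b", "c", "d"]

# Positional copy: attaches `wanted` to columns in stored order, so a and c swap.
positional = Dict(zip(wanted, row))

# Name-based copy: look up each wanted name's column before copying.
idx = [findfirst(==(n), stored_names) for n in wanted]
by_name = Dict(zip(wanted, row[idx]))

positional["a"], by_name["a"]        # (-6.0, -10.0)
```

Looking up each column by name, as construct_a3d below does, avoids depending on the storage order.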

Yep, I think that in the new MCMCChains this mapping goes through AxisArrays and can be tricky. I just finished a somewhat similar exercise, splitting an existing section into two sections, and ran into this as well. Now looking at removeBurnin().

Hi Chris, can you have a look at this setup?

using Distributed
addprocs(4)
@everywhere using Turing

@everywhere @model model() = begin
    a ~ Normal(-10,.1)
    b ~ Normal(-8,.1)
    c ~ Normal(-6,.1)
    d ~ Normal(-4,.1)
    return nothing
end

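# Collect the names from every section of the name map (parameters,
# internals, ...) into a single Vector{String}, in section order.
function flatten_name_map(chn::MCMCChains.AbstractChains)
  pn = String[]
  parms = values(chn.name_map)
  for i in 1:length(parms)
    for j in 1:length(parms[i])
      push!(pn, String(parms[i][j]))
    end
  end
  pn
end

# Copy a single-chain object into a plain iterations × names × chains array,
# pulling each column out by name so that column i always holds cnames[i].
# Only chain 1 is filled, so chn is assumed to hold a single chain.
function construct_a3d(chn::MCMCChains.AbstractChains)
  cnames = flatten_name_map(chn)
  d, p, c = size(chn.value)
  a3d = fill(0.0, d, p, c)

  for (i, par) in enumerate(cnames)
    a3d[:, i, 1] = reshape(chn[par], d)
  end
  a3d
end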

function removeBurnin(chainArray::Array{<:Chains,1},Nadapt)
    Nchains = length(chainArray)
    ch = chainArray[1]
    v = ch.value
    parms = flatten_name_map(ch)
    dims = (size(v,1)-Nadapt,length(parms),Nchains)
    value = Array{Union{Missing,Real}}(undef,dims...)
    rng = (Nadapt+1):size(v,1)
    for (i,c) in enumerate(chainArray)
        a3d = construct_a3d(c)
        value[:,:,i] = a3d[rng, :, :]
    end
    MCMCChains.Chains(value,
      Symbol.(parms),
      Dict(
        :parameters => Symbol.(values(ch.name_map.parameters)),
        :internals => Symbol.(values(ch.name_map.internals))
      )
    )
end

Nsamples = 2000
Nadapt = 1000
δ = .85
specs = NUTS(Nsamples,Nadapt,δ)
chainArray = pmap(x->sample(model(),specs),1:4)
chain = removeBurnin(chainArray,Nadapt)
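flatten_name_map concatenates the per-section name vectors into one list, and construct_a3d then pulls each column out by name. In plain Julia the flattening can be written with Iterators.flatten. A sketch with a hypothetical name map shaped like chn.name_map (a NamedTuple of name vectors; the internals names are made up), plus a name-to-column lookup of the kind that makes copying by name straightforward:

```julia
# Hypothetical name map: a NamedTuple of name vectors, one per section.
name_map = (parameters = ["a", "b", "c", "d"], internals = ["lp", "step_size"])

# Flatten every section into one Vector{String}, keeping section order.
# String(n) also converts Symbol names, should a section store Symbols.
flat = String[String(n) for n in Iterators.flatten(values(name_map))]

# Map each name to its column position so data can be copied by name.
column_of = Dict(n => i for (i, n) in enumerate(flat))
```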

Rob, this looks good as far as I can tell. Changing to AxisArrays seems to have complicated the manipulation of the chain object quite a bit. Thanks for your help!

I think you guys have a lot of moving machinery here. MCMCChains should support all of this behavior. A little while ago I added a chainscat function that concatenates chains; you should be able to do something like

chns = reduce(chainscat, pmap(x->sample(model(),specs),1:4))
subset = chns[Nadapt:Nsamples, :, :]

Here’s the output from describe(subset). Guess I should fix that quantile warning too, while I’m at it.

Log evidence      = 0.0
Iterations        = 1000:2000
Thinning interval = 1
Chains            = Chain1, Chain2, Chain3, Chain4
Samples per chain = 1001
parameters        = c, b, a, d

┌ Warning: `quantile(v::AbstractArray{<:Real})` is deprecated, use `quantile(v, [0.0, 0.25, 0.5, 0.75, 1.0])` instead.
│   caller = (::getfield(MCMCChains, Symbol("##102#104")){Chains{Union{Missing, Float64},Float64,NamedTuple{(:parameters,),Tuple{Array{String,1}}},NamedTuple{(:hashedsummary,),Tuple{Base.RefValue{Tuple{UInt64,MCMCChains.ChainSummaries}}}}}})(::String) at none:0
└ @ MCMCChains ./none:0
Empirical Posterior Estimates
───────────────────────────────────────────
parameters
    Mean    SD   Naive SE  MCSE   ESS
a -9.9998 0.0980 1.5×10⁻³ 0.0013 1001
b -7.9995 0.1019 1.6×10⁻³ 0.0016 1001
c -5.9986 0.1016 1.6×10⁻³ 0.0016 1001
d -3.9992 0.0992 1.6×10⁻³ 0.0017 1001

Quantiles
───────────────────────────────────────────
parameters
    2.5%     25.0%   50.0%   75.0%   97.5% 
a -10.3729 -10.0661 -9.9998 -9.9317 -9.6298
b  -8.3329  -8.0687 -8.0031 -7.9298 -7.6591
c  -6.3259  -6.0663 -5.9990 -5.9311 -5.6473
d  -4.3360  -4.0632 -3.9993 -3.9342 -3.6764

And here’s the full code I used to make that:

using Distributed
addprocs(4)
@everywhere using Turing,Distributions

@everywhere @model model() = begin
    a ~ Normal(-10,.1)
    b ~ Normal(-8,.1)
    c ~ Normal(-6,.1)
    d ~ Normal(-4,.1)
    return nothing
end

Nsamples = 2000
Nadapt = 1000
δ = .85
specs = NUTS(Nsamples,Nadapt,δ)
chns = reduce(chainscat, pmap(x->sample(model(),specs),1:4))
subset = chns[Nadapt:Nsamples, :, :]
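reduce(chainscat, ...) folds the per-chain objects together pairwise. The same pattern on plain arrays, with cat along dims = 3 standing in for chainscat (an analogy, not its implementation), also makes the range arithmetic easy to check: Nadapt:Nsamples contains Nsamples - Nadapt + 1 = 1001 iterations, matching "Samples per chain = 1001" in the output above, whereas Nadapt+1:Nsamples keeps exactly the 1000 draws after the first Nadapt, the same range removeBurnin uses. Which one is wanted depends on whether iteration Nadapt counts as adaptation. A sketch with fake draws:

```julia
Nsamples, Nadapt = 2000, 1000

# Four fake chains, each iterations × parameters × 1, filled with its chain number.
per_chain = [fill(Float64(k), Nsamples, 4, 1) for k in 1:4]

# Fold them together pairwise along the chain dimension.
all_chains = reduce((x, y) -> cat(x, y; dims = 3), per_chain)

length(Nadapt:Nsamples)      # 1001: includes iteration Nadapt
length(Nadapt+1:Nsamples)    # 1000: drops exactly the first Nadapt iterations

post_adapt = all_chains[Nadapt+1:Nsamples, :, :]
```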

Definitely looks simpler! I had looked at those cat functions but hadn’t figured out when to use them.

Will study them again! Thanks for your help!