Turing model: multiple return values

Here is a model that returns two values, the vector of observations and their sum:

julia> using Turing

julia> @model function observed_sum(N)
          # For each of N iterations: if the first coin flip is true, record
          # `true_answers`; otherwise record the second coin flip. Returns the tuple
          # `(observed, sum_obs)`, so `generated_quantities` yields one tuple per sample.
          # `p` is the success probability of the `true_answers` Bernoulli draw.
          p ~ Uniform(0,1)
          # `[]` creates a `Vector{Any}`; `Bool[]` would give a concretely typed vector.
          observed = []
          for i in 1:N
              # NOTE(review): these three `~` statements reuse the same variable name on
              # every iteration, so only a single value of each is tracked and every
              # `observed` vector ends up all-true or all-false. For one draw per
              # iteration, index them instead, e.g. declare
              # `true_answers = Vector{Bool}(undef, N)` before the loop and write
              # `true_answers[i] ~ Bernoulli(p)` (same for the coin flips).
              # TODO confirm against the installed Turing/DynamicPPL version.
              true_answers ~ Bernoulli(p)
              first_coin_flip ~ Bernoulli(0.5)
              second_coin_flip ~ Bernoulli(0.5)

              if first_coin_flip
                 push!(observed, true_answers)
              else
                 # Ternary used only for its side effect; equivalent to
                 # `push!(observed, second_coin_flip)` since the condition must be a Bool.
                 second_coin_flip ? push!(observed, true) :  push!(observed, false)
             end
          end
         # Summing the Bool entries counts the `true` answers.
         # NOTE(review): with N = 0, `observed` is an empty `Vector{Any}` and `sum`
         # likely throws; confirm if N = 0 must be supported.
         sum_obs = sum(observed)
        # Alternative single-value return, kept commented out:
        # return   observed
         return   observed, sum_obs
       end
observed_sum (generic function with 2 methods)

julia> model = observed_sum(10)
DynamicPPL.Model{typeof(observed_sum), (:N,), (), (), Tuple{Int64}, Tuple{}, DynamicPPL.DefaultContext}(:observed_sum, observed_sum, (N = 10,), NamedTuple(), DynamicPPL.DefaultContext())

julia> model = observed_sum(10)
DynamicPPL.Model{typeof(observed_sum), (:N,), (), (), Tuple{Int64}, Tuple{}, DynamicPPL.DefaultContext}(:observed_sum, observed_sum, (N = 10,), NamedTuple(), DynamicPPL.DefaultContext())

julia> methodswith(DynamicPPL.Model)
[1] observed_sum(__model__::DynamicPPL.Model, __varinfo__::DynamicPPL.AbstractVarInfo, __context__::DynamicPPL.AbstractContext, N) in Main at REPL[11]:1
[2] loglikelihood(left::NamedTuple, right::NamedTuple, _model::DynamicPPL.Model, _vi::Union{Nothing, DynamicPPL.VarInfo}) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/prob_macro.jl:188
[3] loglikelihood(model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:521
[4] bijector(model::DynamicPPL.Model) in Turing.Variational at /home/storopoli/.julia/packages/Turing/uMQmD/src/variational/advi.jl:7
[5] bijector(model::DynamicPPL.Model, ::Val{sym2ranges}; varinfo) where sym2ranges in Turing.Variational at /home/storopoli/.julia/packages/Turing/uMQmD/src/variational/advi.jl:7
[6] generated_quantities(model::DynamicPPL.Model, chain::AbstractMCMC.AbstractChains) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:585
[7] generated_quantities(model::DynamicPPL.Model, parameters::NamedTuple) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:629
[8] generated_quantities(model::DynamicPPL.Model, values, keys) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:635
[9] logjoint(model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:497
[10] logprior(model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:509
[11] logprior(left::NamedTuple, right::NamedTuple, _model::DynamicPPL.Model, _vi::Union{Nothing, DynamicPPL.VarInfo}) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/prob_macro.jl:129
[12] pointwise_loglikelihoods(model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/loglikelihoods.jl:243
[13] pointwise_loglikelihoods(model::DynamicPPL.Model, chain) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/loglikelihoods.jl:220
[14] pointwise_loglikelihoods(model::DynamicPPL.Model, chain, keytype::Type{T}) where T in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/loglikelihoods.jl:220
[15] predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, chain::Chains; include_all) in Turing.Inference at /home/storopoli/.julia/packages/Turing/uMQmD/src/inference/Inference.jl:531
[16] predict(model::DynamicPPL.Model, chain::Chains; kwargs...) in Turing.Inference at /home/storopoli/.julia/packages/Turing/uMQmD/src/inference/Inference.jl:528
[17] vi(model::DynamicPPL.Model, alg::ADVI; optimizer) in Turing.Variational at /home/storopoli/.julia/packages/Turing/uMQmD/src/variational/advi.jl:90
[18] vi(model::DynamicPPL.Model, alg::ADVI, q::Bijectors.TransformedDistribution{var"#s149", B, V} where {var"#s149"<:DistributionsAD.TuringDiagMvNormal, B, V}; optimizer) in Turing.Variational at /home/storopoli/.julia/packages/Turing/uMQmD/src/variational/advi.jl:100
[19] condition(model::DynamicPPL.Model; values...) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:257
[20] condition(model::DynamicPPL.Model, values) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:258
[21] decondition(model::DynamicPPL.Model, syms...) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:310
[22] getargnames(model::DynamicPPL.Model{_F, argnames, defaultnames, missings, Targs, Tdefaults, Ctx} where {defaultnames, missings, Targs, Tdefaults, Ctx<:DynamicPPL.AbstractContext}) where {argnames, _F} in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:474
[23] getmissings(model::DynamicPPL.Model{_F, _a, _d, missings, Targs, Tdefaults, Ctx} where {Targs, Tdefaults, Ctx<:DynamicPPL.AbstractContext}) where {missings, _F, _a, _d} in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:481
[24] nameof(model::DynamicPPL.Model) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:488
[25] |(model::DynamicPPL.Model, values) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:93

These are all the methods that accept an instantiated Turing model (`DynamicPPL.Model`).

For example, sampling with the Metropolis–Hastings sampler `MH()`:

julia> chain = sample(model, MH(), 100)
Sampling 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:05
Chains MCMC chain (100×5×1 Array{Float64, 3}):

Iterations        = 1:1:100
Number of chains  = 1
Samples per chain = 100
Wall duration     = 5.87 seconds
Compute duration  = 5.87 seconds
parameters        = p, true_answers, first_coin_flip, second_coin_flip
internals         = lp

Summary Statistics
        parameters      mean       std   naive_se      mcse       ess      rhat   ess_per_sec
            Symbol   Float64   Float64    Float64   Float64   Float64   Float64       Float64

                 p    0.1192    0.2242     0.0224    0.0551   15.7810    1.0457        2.6870
      true_answers    0.0800    0.2727     0.0273    0.0696    8.6870    1.0801        1.4791
   first_coin_flip    0.5500    0.5000     0.0500    0.1352    9.3447    1.0000        1.5911
  second_coin_flip    0.5900    0.4943     0.0494    0.1362    3.7103    1.2364        0.6318

Quantiles
        parameters      2.5%     25.0%     50.0%     75.0%     97.5%
            Symbol   Float64   Float64   Float64   Float64   Float64

                 p    0.0002    0.0355    0.0372    0.0830    0.8566
      true_answers    0.0000    0.0000    0.0000    0.0000    1.0000
   first_coin_flip    0.0000    0.0000    1.0000    1.0000    1.0000
  second_coin_flip    0.0000    0.0000    1.0000    1.0000    1.0000

To compute the model's return values for every sample in the chain, use the function `generated_quantities`:

help?> generated_quantities
search: generated_quantities

  generated_quantities(model::Model, chain::AbstractChains)

  Execute model for each of the samples in chain and return an array of the values returned by the model for each sample.

  Examples
  ≡≡≡≡≡≡≡≡≡≡

  General
  =========

  Often you might have additional quantities computed inside the model that you want to inspect, e.g.

  @model function demo(x)
      # sample and observe
      θ ~ Prior()
      x ~ Likelihood()
      return interesting_quantity(θ, x)
  end
  m = demo(data)
  chain = sample(m, alg, n)
  # To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
  # from the posterior/`chain`:
  generated_quantities(m, chain) # <= results in a `Vector` of returned values
                                 #    from `interesting_quantity(θ, x)`

  Concrete (and simple)
  =======================

  julia> using DynamicPPL, Turing

  julia> @model function demo(xs)
             s ~ InverseGamma(2, 3)
             m_shifted ~ Normal(10, √s)
             m = m_shifted - 10

             for i in eachindex(xs)
                 xs[i] ~ Normal(m, √s)
             end

             return (m, )
         end
  demo (generic function with 1 method)

  julia> model = demo(randn(10));

  julia> chain = sample(model, MH(), 10);

  julia> generated_quantities(model, chain)
  10×1 Array{Tuple{Float64},2}:
   (2.1964758025119338,)
   (2.1964758025119338,)
   (0.09270081916291417,)
   (0.09270081916291417,)
   (0.09270081916291417,)
   (0.09270081916291417,)
   (0.09270081916291417,)
   (0.043088571494005024,)
   (-0.16489786710222099,)
   (-0.16489786710222099,)

  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

  generated_quantities(model::Model, parameters::NamedTuple)
  generated_quantities(model::Model, values, keys)
  generated_quantities(model::Model, values, keys)

  Execute model with variables keys set to values and return the values returned by the model.

  If a NamedTuple is given, keys=keys(parameters) and values=values(parameters).

  Example
  ≡≡≡≡≡≡≡≡≡

  julia> using DynamicPPL, Distributions

  julia> @model function demo(xs)
             s ~ InverseGamma(2, 3)
             m_shifted ~ Normal(10, √s)
             m = m_shifted - 10
             for i in eachindex(xs)
                 xs[i] ~ Normal(m, √s)
             end
             return (m, )
         end
  demo (generic function with 2 methods)

  julia> model = demo(randn(10));

  julia> parameters = (; s = 1.0, m_shifted=10);

  julia> generated_quantities(model, parameters)
  (0.0,)

  julia> generated_quantities(model, values(parameters), keys(parameters))
  (0.0,)

And voilà, one `(observed, sum_obs)` tuple per sample:

julia> generated_quantities(model, chain)
100×1 Matrix{Tuple{Vector{Any}, Int64}}:
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ⋮
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
1 Like