Turing Model: Multiple return variables

Hi, I’m new to Turing and Julia and I honestly can’t figure out what I am doing wrong here.
I have a model that should return two things:

  1. an array of N Boolean values, e.g. [true, false, true, false]
  2. an integer value for the sum of that array, e.g. 2

When I define my model with only the observed array as a return, everything works as expected.
But if I define my model to return both the array and the sum value, then suddenly the elements of the array become either all false or all true, and the sum is either 0 or N.
I don’t know why this keeps happening, please help.
My code snippet is below:

using Turing

N = 5

@model function observed_sum(N)
    p ~ Uniform(0, 1)
    observed = []
    for i in 1:N
        true_answers ~ Bernoulli(p)
        first_coin_flip ~ Bernoulli(0.5)
        second_coin_flip ~ Bernoulli(0.5)

        if first_coin_flip
            push!(observed, true_answers)
        else
            second_coin_flip ? push!(observed, true) : push!(observed, false)
        end
    end
    sum_obs = sum(observed)
    # return observed
    return observed, sum_obs
end


x = observed_sum(N)
# for i in 1:N println("λ sample value is = ", x()[i]) end    # this works when only `observed` is returned
for i in 1:2 println("λ sample value is = ", x()[i]) end      # this doesn't work

You defined the model with the @model macro. Now you need to sample from it.

Take a look at these tutorials; they might help.
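As a minimal sketch of that next step, assuming the `observed_sum` model from your snippet (any Turing sampler would do; MH() is just one choice):

using Turing

model = observed_sum(5)           # instantiate the model for N = 5
chain = sample(model, MH(), 100)  # draw 100 posterior samples with Metropolis-Hastings
describe(chain)                   # summary statistics of the sampled parameters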

Here:

julia> using Turing

julia> @model function observed_sum(N)
          p ~ Uniform(0,1)
          observed = []
          for i in 1:N
              true_answers ~ Bernoulli(p)
              first_coin_flip ~ Bernoulli(0.5)
              second_coin_flip ~ Bernoulli(0.5)

              if first_coin_flip
                 push!(observed, true_answers)
              else
                 second_coin_flip ? push!(observed, true) :  push!(observed, false)
             end
          end
         sum_obs = sum(observed)
        # return   observed
         return   observed, sum_obs
       end
observed_sum (generic function with 2 methods)

julia> model = observed_sum(10)
DynamicPPL.Model{typeof(observed_sum), (:N,), (), (), Tuple{Int64}, Tuple{}, DynamicPPL.DefaultContext}(:observed_sum, observed_sum, (N = 10,), NamedTuple(), DynamicPPL.DefaultContext())

julia> methodswith(DynamicPPL.Model)
[1] observed_sum(__model__::DynamicPPL.Model, __varinfo__::DynamicPPL.AbstractVarInfo, __context__::DynamicPPL.AbstractContext, N) in Main at REPL[11]:1
[2] loglikelihood(left::NamedTuple, right::NamedTuple, _model::DynamicPPL.Model, _vi::Union{Nothing, DynamicPPL.VarInfo}) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/prob_macro.jl:188
[3] loglikelihood(model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:521
[4] bijector(model::DynamicPPL.Model) in Turing.Variational at /home/storopoli/.julia/packages/Turing/uMQmD/src/variational/advi.jl:7
[5] bijector(model::DynamicPPL.Model, ::Val{sym2ranges}; varinfo) where sym2ranges in Turing.Variational at /home/storopoli/.julia/packages/Turing/uMQmD/src/variational/advi.jl:7
[6] generated_quantities(model::DynamicPPL.Model, chain::AbstractMCMC.AbstractChains) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:585
[7] generated_quantities(model::DynamicPPL.Model, parameters::NamedTuple) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:629
[8] generated_quantities(model::DynamicPPL.Model, values, keys) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:635
[9] logjoint(model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:497
[10] logprior(model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:509
[11] logprior(left::NamedTuple, right::NamedTuple, _model::DynamicPPL.Model, _vi::Union{Nothing, DynamicPPL.VarInfo}) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/prob_macro.jl:129
[12] pointwise_loglikelihoods(model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/loglikelihoods.jl:243
[13] pointwise_loglikelihoods(model::DynamicPPL.Model, chain) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/loglikelihoods.jl:220
[14] pointwise_loglikelihoods(model::DynamicPPL.Model, chain, keytype::Type{T}) where T in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/loglikelihoods.jl:220
[15] predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, chain::Chains; include_all) in Turing.Inference at /home/storopoli/.julia/packages/Turing/uMQmD/src/inference/Inference.jl:531
[16] predict(model::DynamicPPL.Model, chain::Chains; kwargs...) in Turing.Inference at /home/storopoli/.julia/packages/Turing/uMQmD/src/inference/Inference.jl:528
[17] vi(model::DynamicPPL.Model, alg::ADVI; optimizer) in Turing.Variational at /home/storopoli/.julia/packages/Turing/uMQmD/src/variational/advi.jl:90
[18] vi(model::DynamicPPL.Model, alg::ADVI, q::Bijectors.TransformedDistribution{var"#s149", B, V} where {var"#s149"<:DistributionsAD.TuringDiagMvNormal, B, V}; optimizer) in Turing.Variational at /home/storopoli/.julia/packages/Turing/uMQmD/src/variational/advi.jl:100
[19] condition(model::DynamicPPL.Model; values...) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:257
[20] condition(model::DynamicPPL.Model, values) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:258
[21] decondition(model::DynamicPPL.Model, syms...) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:310
[22] getargnames(model::DynamicPPL.Model{_F, argnames, defaultnames, missings, Targs, Tdefaults, Ctx} where {defaultnames, missings, Targs, Tdefaults, Ctx<:DynamicPPL.AbstractContext}) where {argnames, _F} in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:474
[23] getmissings(model::DynamicPPL.Model{_F, _a, _d, missings, Targs, Tdefaults, Ctx} where {Targs, Tdefaults, Ctx<:DynamicPPL.AbstractContext}) where {missings, _F, _a, _d} in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:481
[24] nameof(model::DynamicPPL.Model) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:488
[25] |(model::DynamicPPL.Model, values) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:93

These are all the things you can do with an instantiated Turing model.

For example, Metropolis-Hastings Monte Carlo with MH():

julia> chain = sample(model, MH(), 100)
Sampling 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:05
Chains MCMC chain (100×5×1 Array{Float64, 3}):

Iterations        = 1:1:100
Number of chains  = 1
Samples per chain = 100
Wall duration     = 5.87 seconds
Compute duration  = 5.87 seconds
parameters        = p, true_answers, first_coin_flip, second_coin_flip
internals         = lp

Summary Statistics
        parameters      mean       std   naive_se      mcse       ess      rhat   ess_per_sec
            Symbol   Float64   Float64    Float64   Float64   Float64   Float64       Float64

                 p    0.1192    0.2242     0.0224    0.0551   15.7810    1.0457        2.6870
      true_answers    0.0800    0.2727     0.0273    0.0696    8.6870    1.0801        1.4791
   first_coin_flip    0.5500    0.5000     0.0500    0.1352    9.3447    1.0000        1.5911
  second_coin_flip    0.5900    0.4943     0.0494    0.1362    3.7103    1.2364        0.6318

Quantiles
        parameters      2.5%     25.0%     50.0%     75.0%     97.5%
            Symbol   Float64   Float64   Float64   Float64   Float64

                 p    0.0002    0.0355    0.0372    0.0830    0.8566
      true_answers    0.0000    0.0000    0.0000    0.0000    1.0000
   first_coin_flip    0.0000    0.0000    1.0000    1.0000    1.0000
  second_coin_flip    0.0000    0.0000    1.0000    1.0000    1.0000

To generate the return values, use the function generated_quantities:

help?> generated_quantities
search: generated_quantities

  generated_quantities(model::Model, chain::AbstractChains)

  Execute model for each of the samples in chain and return an array of the values returned by the model for each sample.

  Examples
  ≡≡≡≡≡≡≡≡≡≡

  General
  =========

  Often you might have additional quantities computed inside the model that you want to inspect, e.g.

  @model function demo(x)
      # sample and observe
      θ ~ Prior()
      x ~ Likelihood()
      return interesting_quantity(θ, x)
  end
  m = demo(data)
  chain = sample(m, alg, n)
  # To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
  # from the posterior/`chain`:
  generated_quantities(m, chain) # <= results in a `Vector` of returned values
                                 #    from `interesting_quantity(θ, x)`

  Concrete (and simple)
  =======================

  julia> using DynamicPPL, Turing

  julia> @model function demo(xs)
             s ~ InverseGamma(2, 3)
             m_shifted ~ Normal(10, √s)
             m = m_shifted - 10

             for i in eachindex(xs)
                 xs[i] ~ Normal(m, √s)
             end

             return (m, )
         end
  demo (generic function with 1 method)

  julia> model = demo(randn(10));

  julia> chain = sample(model, MH(), 10);

  julia> generated_quantities(model, chain)
  10×1 Array{Tuple{Float64},2}:
   (2.1964758025119338,)
   (2.1964758025119338,)
   (0.09270081916291417,)
   (0.09270081916291417,)
   (0.09270081916291417,)
   (0.09270081916291417,)
   (0.09270081916291417,)
   (0.043088571494005024,)
   (-0.16489786710222099,)
   (-0.16489786710222099,)

  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

  generated_quantities(model::Model, parameters::NamedTuple)
  generated_quantities(model::Model, values, keys)

  Execute model with variables keys set to values and return the values returned by the model.

  If a NamedTuple is given, keys=keys(parameters) and values=values(parameters).

  Example
  ≡≡≡≡≡≡≡≡≡

  julia> using DynamicPPL, Distributions

  julia> @model function demo(xs)
             s ~ InverseGamma(2, 3)
             m_shifted ~ Normal(10, √s)
             m = m_shifted - 10
             for i in eachindex(xs)
                 xs[i] ~ Normal(m, √s)
             end
             return (m, )
         end
  demo (generic function with 2 methods)

  julia> model = demo(randn(10));

  julia> parameters = (; s = 1.0, m_shifted=10);

  julia> generated_quantities(model, parameters)
  (0.0,)

  julia> generated_quantities(model, values(parameters), keys(parameters))
  (0.0,)

And voilà:

julia> generated_quantities(model, chain)
100×1 Matrix{Tuple{Vector{Any}, Int64}}:
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ⋮
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([false, false, false, false, false, false, false, false, false, false], 0)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
 ([true, true, true, true, true, true, true, true, true, true], 10)
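
If you want to work with those return values programmatically, here is a small sketch (assuming the `model` and `chain` from the session above) for pulling out just the sums:

using Statistics

gq = generated_quantities(model, chain)   # Matrix of (observed, sum_obs) tuples, one per sample
sums = [q[2] for q in vec(gq)]            # extract the integer sums
mean(sums)                                # e.g. the posterior mean of the sum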

Thank you so much!
