Here:
julia> using Turing
julia> @model function observed_sum(N)
p ~ Uniform(0,1)
observed = []
for i in 1:N
true_answers ~ Bernoulli(p)
first_coin_flip ~ Bernoulli(0.5)
second_coin_flip ~ Bernoulli(0.5)
if first_coin_flip
push!(observed, true_answers)
else
second_coin_flip ? push!(observed, true) : push!(observed, false)
end
end
sum_obs = sum(observed)
# return observed
return observed, sum_obs
end
observed_sum (generic function with 2 methods)
julia> model = observed_sum(10)
DynamicPPL.Model{typeof(observed_sum), (:N,), (), (), Tuple{Int64}, Tuple{}, DynamicPPL.DefaultContext}(:observed_sum, observed_sum, (N = 10,), NamedTuple(), DynamicPPL.DefaultContext())
julia> model = observed_sum(10)
DynamicPPL.Model{typeof(observed_sum), (:N,), (), (), Tuple{Int64}, Tuple{}, DynamicPPL.DefaultContext}(:observed_sum, observed_sum, (N = 10,), NamedTuple(), DynamicPPL.DefaultContext())
julia> methodswith(DynamicPPL.Model)
[1] observed_sum(__model__::DynamicPPL.Model, __varinfo__::DynamicPPL.AbstractVarInfo, __context__::DynamicPPL.AbstractContext, N) in Main at REPL[11]:1
[2] loglikelihood(left::NamedTuple, right::NamedTuple, _model::DynamicPPL.Model, _vi::Union{Nothing, DynamicPPL.VarInfo}) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/prob_macro.jl:188
[3] loglikelihood(model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:521
[4] bijector(model::DynamicPPL.Model) in Turing.Variational at /home/storopoli/.julia/packages/Turing/uMQmD/src/variational/advi.jl:7
[5] bijector(model::DynamicPPL.Model, ::Val{sym2ranges}; varinfo) where sym2ranges in Turing.Variational at /home/storopoli/.julia/packages/Turing/uMQmD/src/variational/advi.jl:7
[6] generated_quantities(model::DynamicPPL.Model, chain::AbstractMCMC.AbstractChains) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:585
[7] generated_quantities(model::DynamicPPL.Model, parameters::NamedTuple) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:629
[8] generated_quantities(model::DynamicPPL.Model, values, keys) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:635
[9] logjoint(model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:497
[10] logprior(model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:509
[11] logprior(left::NamedTuple, right::NamedTuple, _model::DynamicPPL.Model, _vi::Union{Nothing, DynamicPPL.VarInfo}) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/prob_macro.jl:129
[12] pointwise_loglikelihoods(model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/loglikelihoods.jl:243
[13] pointwise_loglikelihoods(model::DynamicPPL.Model, chain) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/loglikelihoods.jl:220
[14] pointwise_loglikelihoods(model::DynamicPPL.Model, chain, keytype::Type{T}) where T in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/loglikelihoods.jl:220
[15] predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, chain::Chains; include_all) in Turing.Inference at /home/storopoli/.julia/packages/Turing/uMQmD/src/inference/Inference.jl:531
[16] predict(model::DynamicPPL.Model, chain::Chains; kwargs...) in Turing.Inference at /home/storopoli/.julia/packages/Turing/uMQmD/src/inference/Inference.jl:528
[17] vi(model::DynamicPPL.Model, alg::ADVI; optimizer) in Turing.Variational at /home/storopoli/.julia/packages/Turing/uMQmD/src/variational/advi.jl:90
[18] vi(model::DynamicPPL.Model, alg::ADVI, q::Bijectors.TransformedDistribution{var"#s149", B, V} where {var"#s149"<:DistributionsAD.TuringDiagMvNormal, B, V}; optimizer) in Turing.Variational at /home/storopoli/.julia/packages/Turing/uMQmD/src/variational/advi.jl:100
[19] condition(model::DynamicPPL.Model; values...) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:257
[20] condition(model::DynamicPPL.Model, values) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:258
[21] decondition(model::DynamicPPL.Model, syms...) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:310
[22] getargnames(model::DynamicPPL.Model{_F, argnames, defaultnames, missings, Targs, Tdefaults, Ctx} where {defaultnames, missings, Targs, Tdefaults, Ctx<:DynamicPPL.AbstractContext}) where {argnames, _F} in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:474
[23] getmissings(model::DynamicPPL.Model{_F, _a, _d, missings, Targs, Tdefaults, Ctx} where {Targs, Tdefaults, Ctx<:DynamicPPL.AbstractContext}) where {missings, _F, _a, _d} in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:481
[24] nameof(model::DynamicPPL.Model) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:488
[25] |(model::DynamicPPL.Model, values) in DynamicPPL at /home/storopoli/.julia/packages/DynamicPPL/RcfQU/src/model.jl:93
All the things you can do with a instantiated Turing model.
For example Metropolis-Hastings monte carlo MH()
:
julia> chain = sample(model, MH(), 100)
Sampling 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:05
Chains MCMC chain (100×5×1 Array{Float64, 3}):
Iterations = 1:1:100
Number of chains = 1
Samples per chain = 100
Wall duration = 5.87 seconds
Compute duration = 5.87 seconds
parameters = p, true_answers, first_coin_flip, second_coin_flip
internals = lp
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
p 0.1192 0.2242 0.0224 0.0551 15.7810 1.0457 2.6870
true_answers 0.0800 0.2727 0.0273 0.0696 8.6870 1.0801 1.4791
first_coin_flip 0.5500 0.5000 0.0500 0.1352 9.3447 1.0000 1.5911
second_coin_flip 0.5900 0.4943 0.0494 0.1362 3.7103 1.2364 0.6318
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
p 0.0002 0.0355 0.0372 0.0830 0.8566
true_answers 0.0000 0.0000 0.0000 0.0000 1.0000
first_coin_flip 0.0000 0.0000 1.0000 1.0000 1.0000
second_coin_flip 0.0000 0.0000 1.0000 1.0000 1.0000
To generate the return values use the function generated_quantities
:
help?> generated_quantities
search: generated_quantities
generated_quantities(model::Model, chain::AbstractChains)
Execute model for each of the samples in chain and return an array of the values returned by the model for each sample.
Examples
≡≡≡≡≡≡≡≡≡≡
General
=========
Often you might have additional quantities computed inside the model that you want to inspect, e.g.
@model function demo(x)
# sample and observe
θ ~ Prior()
x ~ Likelihood()
return interesting_quantity(θ, x)
end
m = demo(data)
chain = sample(m, alg, n)
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
# from the posterior/`chain`:
generated_quantities(m, chain) # <= results in a `Vector` of returned values
# from `interesting_quantity(θ, x)`
Concrete (and simple)
=======================
julia> using DynamicPPL, Turing
julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
m_shifted ~ Normal(10, √s)
m = m_shifted - 10
for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end
return (m, )
end
demo (generic function with 1 method)
julia> model = demo(randn(10));
julia> chain = sample(model, MH(), 10);
julia> generated_quantities(model, chain)
10×1 Array{Tuple{Float64},2}:
(2.1964758025119338,)
(2.1964758025119338,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.043088571494005024,)
(-0.16489786710222099,)
(-0.16489786710222099,)
──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
generated_quantities(model::Model, parameters::NamedTuple)
generated_quantities(model::Model, values, keys)
generated_quantities(model::Model, values, keys)
Execute model with variables keys set to values and return the values returned by the model.
If a NamedTuple is given, keys=keys(parameters) and values=values(parameters).
Example
≡≡≡≡≡≡≡≡≡
julia> using DynamicPPL, Distributions
julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
m_shifted ~ Normal(10, √s)
m = m_shifted - 10
for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end
return (m, )
end
demo (generic function with 2 methods)
julia> model = demo(randn(10));
julia> parameters = (; s = 1.0, m_shifted=10);
julia> generated_quantities(model, parameters)
(0.0,)
julia> generated_quantities(model, values(parameters), keys(parameters))
(0.0,)
And voilà:
julia> generated_quantities(model, chain)
100×1 Matrix{Tuple{Vector{Any}, Int64}}:
([false, false, false, false, false, false, false, false, false, false], 0)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([false, false, false, false, false, false, false, false, false, false], 0)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
⋮
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([false, false, false, false, false, false, false, false, false, false], 0)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)
([true, true, true, true, true, true, true, true, true, true], 10)