# Sampling predictive with MvNormal in Turing

tl;dr: none of these seem to work, and each gives a different error:

``````
Vector{Float64}(undef, n) ~ MvNormal(yhat, Diagonal(fill(1.0, n)))
Vector{Real}(undef, n) ~ MvNormal(yhat, Diagonal(fill(1.0, n)))
Vector{Missing}(missing, n) ~ MvNormal(yhat, Diagonal(fill(1.0, n)))
``````

I was building this relatively simple univariate partial-pooling model and found it very slow relative to numpyro, so I tried replacing all the for loops and switching to a different AD backend. That helped, but I'm now running into an issue sampling a predictive distribution from the model.
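(For context, the backend switch was along these lines; I'm on Turing 0.32, where the AD backend is picked per sampler via `adtype`, and ReverseDiff is just the one I happened to try:)

``````
# sketch of the AD backend switch, not the exact code I ran;
# AutoReverseDiff is re-exported by Turing, and compile=true caches the tape
# (model_demo is the model defined below)
chain = sample(model_demo, NUTS(; adtype=AutoReverseDiff(; compile=true)), 500)
``````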

I’m relatively new to Turing and Julia, so I might be missing something basic.

Here’s a version of the model where everything works, using a `for` loop and `Normal`:

``````
using Random: Random
using StatsBase
using Turing
using LinearAlgebra
using ArviZ: ArviZ

versioninfo()
using Pkg
Pkg.status()

seed_global = 42;
Random.seed!(seed_global);

data = Dict();
let
    global data
    X = ones(Float64, 3_000, 5);
    X[:, 3:5] = Random.randn((3_000, 3));
    data["X"] = X;
    data["jstimulus"] = Random.rand(1:8, 3_000);
    data["jsubject"] = Random.rand(1:50, 3_000);
    data["y"] = Random.randn(3_000);
end;

@model function demo(;
    y::Union{Vector{Float64},Vector{Real},Vector{Missing}},
    X::Matrix{Float64},
    jstimulus::Vector{Int},
    jsubject::Vector{Int})

    n_stimuli = length(unique(jstimulus))
    n_subjects = length(unique(jsubject))

    μ_prior_dist = Normal(0, 5)
    σ_prior_dist = truncated(Normal(0, 10), 0, Inf)

    μ0_stimulus ~ μ_prior_dist
    σ0_stimulus ~ σ_prior_dist

    μ0_subject ~ μ_prior_dist
    σ0_subject ~ σ_prior_dist

    μ1 ~ μ_prior_dist
    σ1 ~ σ_prior_dist

    μ2 ~ μ_prior_dist
    σ2 ~ σ_prior_dist

    μ3 ~ μ_prior_dist
    σ3 ~ σ_prior_dist

    β0_stimulus ~ filldist(Normal(μ0_stimulus, σ0_stimulus), n_stimuli)
    β0_subject ~ filldist(Normal(μ0_subject, σ0_subject), n_subjects)
    β1 ~ filldist(Normal(μ1, σ1), n_stimuli)
    β2 ~ filldist(Normal(μ2, σ2), n_stimuli)
    β3 ~ filldist(Normal(μ3, σ3), n_stimuli)

    A = hcat(
        β0_stimulus[jstimulus],
        β0_subject[jsubject],
        β1[jstimulus],
        β2[jstimulus],
        β3[jstimulus],
    )
    yhat = sum(A .* X, dims=2)

    for i in eachindex(y)
        y[i] ~ Normal(yhat[i], 1.0)
    end

    return (yhat=yhat,)
end

model_demo = demo(;
    y=data["y"],
    X=data["X"],
    jstimulus=data["jstimulus"],
    jsubject=data["jsubject"],
);

rng = Random.MersenneTwister(seed_global)

prior_chain = Turing.sample(rng, model_demo, Prior(), 200)
# posterior sampling (needed for posterior_chain below; NUTS is just one choice)
posterior_chain = Turing.sample(rng, model_demo, NUTS(), 200)

model_predictive = demo(;
    y=similar(data["y"], Missing),
    X=data["X"],
    jstimulus=data["jstimulus"],
    jsubject=data["jsubject"],
);

prior_predictive = Turing.predict(rng, model_predictive, prior_chain)
posterior_predictive = Turing.predict(rng, model_predictive, posterior_chain)
``````

However, if I replace the for loop with

``````
    Σ = Diagonal(fill(1.0, length(yhat)))
    y ~ MvNormal(vec(yhat), Σ)
``````

I get this error:

``````
julia> prior_predictive = Turing.predict(rng, model_predictive, prior_chain)
ERROR: MethodError: no method matching loglikelihood(::DiagNormal, ::Vector{Missing})
``````
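The `MethodError` seems to come straight from Distributions rather than from Turing; it reproduces outside the model (minimal sketch):

``````
using Distributions, LinearAlgebra

# no loglikelihood method accepts a Vector{Missing}, hence the MethodError above
loglikelihood(MvNormal(zeros(3), Diagonal(ones(3))), Vector{Missing}(missing, 3))
``````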

So I tried `y=Vector{Float64}(undef, length(data["y"]))` instead of `y=similar(data["y"], Missing)`, as suggested in the documentation.
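Concretely, that looks like this (the same call as before, just with a different `y`):

``````
model_predictive = demo(;
    y=Vector{Float64}(undef, length(data["y"])),
    X=data["X"],
    jstimulus=data["jstimulus"],
    jsubject=data["jsubject"],
);
``````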

That allows the model to run, but `predict` doesn’t seem to treat the `undef` `Float64` vector as parameters, and returns unusable chains:

``````
julia> prior_predictive = Turing.predict(rng, model_predictive, prior_chain)
Chains MCMC chain (200×0×1 Array{Float64, 3}):

Iterations        = 1:1:200
Number of chains  = 1
Samples per chain = 200
parameters        =
internals         =

Summary Statistics
  parameters   mean   std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol    Any   Any   Float64    Float64    Float64   Float64       Missing

Quantiles
  parameters   2.5%   25.0%   50.0%   75.0%   97.5%
      Symbol    Any     Any     Any     Any     Any
``````

And I also tried `y=Vector{Real}(undef, length(data["y"]))`, which gives this error:

``````
prior_predictive = Turing.predict(Random.MersenneTwister(seed_global), model_predictive, prior_chain)
Stacktrace:
[1] getindex
@ ./essentials.jl:13 [inlined]
[3] _getindex
[5] getindex
[6] copy
[7] materialize
[8] sqmahal(d::DiagNormal, x::Vector{Real})
@ Distributions ~/.julia/packages/Distributions/ji8PW/src/multivariate/mvnormal.jl:267
[9] _logpdf(d::DiagNormal, x::Vector{Real})
@ Distributions ~/.julia/packages/Distributions/ji8PW/src/multivariate/mvnormal.jl:143
[10] logpdf
@ ~/.julia/packages/Distributions/ji8PW/src/common.jl:263 [inlined]
[11] loglikelihood
@ ~/.julia/packages/Distributions/ji8PW/src/common.jl:448 [inlined]
[12] observe
@ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:266 [inlined]
[13] observe
@ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:263 [inlined]
[14] tilde_observe
@ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:158 [inlined]
[15] tilde_observe
@ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:156 [inlined]
[16] tilde_observe
@ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:151 [inlined]
@ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:207
[18] tilde_observe!!(context::DynamicPPL.SamplingContext{…}, right::DiagNormal, left::Vector{…}, vname::AbstractPPL.VarName{…}, vi::DynamicPPL.ThreadSafeVarInfo{…})
@ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:194
[19] demo(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{…}, __context__::DynamicPPL.SamplingContext{…}; y::Vector{…}, X::Matrix{…}, jstimulus::Vector{…}, jsubject::Vector{…})
@ Main ./REPL[19]:42
[20] demo
@ ./REPL[19]:1 [inlined]
[21] _evaluate!!
@ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:963 [inlined]
@ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:952 [inlined]
[23] evaluate!!
@ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:887 [inlined]
[24] evaluate!!(model::DynamicPPL.Model{…}, rng::Random.TaskLocalRNG, varinfo::DynamicPPL.UntypedVarInfo{…}, sampler::DynamicPPL.SampleFromPrior, context::DynamicPPL.DefaultContext)
@ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:900
@ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:860
[26] VarInfo
@ ~/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:129 [inlined]
[27] VarInfo
@ ~/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:128 [inlined]
[28] VarInfo
@ ~/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:132 [inlined]
[29] transitions_from_chain(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, chain::Chains{…}; sampler::DynamicPPL.SampleFromPrior)
@ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:718
[30] predict(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, chain::Chains{…}; include_all::Bool)
@ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:631
[31] predict(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, chain::Chains{…})
@ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:624
[32] top-level scope
@ REPL[24]:1
Some type information was truncated. Use `show(err)` to see complete types.
``````
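My guess at why this one dies in `getindex`: an abstractly typed `undef` vector holds undefined references, so even reading an element throws, unlike `Vector{Float64}(undef, n)`, which holds arbitrary but valid floats:

``````
v = Vector{Real}(undef, 3)
v[1]  # ERROR: UndefRefError: access to undefined reference
``````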

So again, apologies if I missed something basic in the documentation, but how should I go about sampling a predictive distribution using `MvNormal`? (As a side note, I first tried to use `DiagNormal`, but I couldn’t even get it to work with the data; I'm not sure if it's an internal-only type.) Thanks for the help!
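(Following up on the `DiagNormal` side note: as far as I can tell it's just a type alias in Distributions.jl, not a separate constructor; you get one by passing a `Diagonal` covariance to `MvNormal`:)

``````
using Distributions, LinearAlgebra

d = MvNormal(zeros(3), Diagonal(ones(3)))
d isa DiagNormal  # true; DiagNormal is the MvNormal specialization for diagonal covariance
``````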

``````
julia> versioninfo()
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 12 × Apple M2 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 8 default, 0 interactive, 4 GC (on 8 virtual cores)
Environment:
  JULIA_EDITOR = code
julia> Pkg.status()
[cbdf2221] AlgebraOfGraphics v0.6.19
[c7e460c6] ArgParse v1.2.0
[131c737c] ArviZ v0.10.5
[4a6e88f0] ArviZPythonPlots v0.1.5
[336ed68f] CSV v0.10.14
⌃ [13f3f980] CairoMakie v0.11.11
[324d7699] CategoricalArrays v0.10.8
[a93c6f00] DataFrames v1.6.1
[1a297f60] FillArrays v1.11.0
[663a7486] FreeTypeAbstraction v0.10.3
[682c06a0] JSON v0.21.4
[98e50ef6] JuliaFormatter v1.0.56
⌅ [ee78f7c6] Makie v0.20.10
[7f7a1694] Optimization v3.25.1
[b1d3bc72] Pathfinder v0.8.7
[f27b6e38] Polynomials v4.0.9
[438e738f] PyCall v1.96.4
[37e2e3b7] ReverseDiff v1.15.3
[295af30f] Revise v3.5.14
[2913bbd2] StatsBase v0.34.3
[f3b207a7] StatsPlots v0.15.7
[fce5fe82] Turing v0.32.3
[e88e6eb3] Zygote v0.6.70
``````

It appears that the `MvNormal` version fails when `y` is a keyword argument but works when it's a positional argument. I'm new to Julia and Turing, but this behavior is unexpected to me. Probably a Turing bug?

``````
using Random: Random
using StatsBase
using Turing
using LinearAlgebra

data = Dict();
let
    global data
    X = ones(Float64, 300, 5)
    X[:, 3:5] = Random.randn((300, 3))
    data["y"] = Random.randn(300);
    data["X"] = X
    data["jstimulus"] = Random.rand(1:10, 300);
    data["jsubject"] = Random.rand(1:20, 300);
end;

@model function demo_arg(y, X, jstimulus, jsubject, ::Type{T}=Float64) where {T}
    n = size(X, 1)
    if y === missing
        y = Vector{T}(undef, n)
    end

    n_stimuli = length(unique(jstimulus))
    n_subjects = length(unique(jsubject))

    μ_prior_dist = Normal(0, 5)
    σ_prior_dist = truncated(Normal(0, 10), 0, Inf)

    μ0_stimulus ~ μ_prior_dist
    σ0_stimulus ~ σ_prior_dist

    μ0_subject ~ μ_prior_dist
    σ0_subject ~ σ_prior_dist

    μ1 ~ μ_prior_dist
    σ1 ~ σ_prior_dist

    μ2 ~ μ_prior_dist
    σ2 ~ σ_prior_dist

    μ3 ~ μ_prior_dist
    σ3 ~ σ_prior_dist

    β0_stimulus ~ filldist(Normal(μ0_stimulus, σ0_stimulus), n_stimuli)
    β0_subject ~ filldist(Normal(μ0_subject, σ0_subject), n_subjects)
    β1 ~ filldist(Normal(μ1, σ1), n_stimuli)
    β2 ~ filldist(Normal(μ2, σ2), n_stimuli)
    β3 ~ filldist(Normal(μ3, σ3), n_stimuli)

    A = hcat(
        β0_stimulus[jstimulus],
        β0_subject[jsubject],
        β1[jstimulus],
        β2[jstimulus],
        β3[jstimulus],
    )
    yhat = vec(sum(A .* X, dims=2))

    Σ = Diagonal(fill(1.0^2, n))
    y ~ MvNormal(yhat, Σ)
    return (yhat=yhat,)
end

@model function demo_kwarg(::Type{T}=Float64; y=missing, X, jstimulus, jsubject) where {T}
    n = size(X, 1)
    if y === missing
        y = Vector{T}(undef, n)
    end

    n_stimuli = length(unique(jstimulus))
    n_subjects = length(unique(jsubject))

    μ_prior_dist = Normal(0, 5)
    σ_prior_dist = truncated(Normal(0, 10), 0, Inf)

    μ0_stimulus ~ μ_prior_dist
    σ0_stimulus ~ σ_prior_dist

    μ0_subject ~ μ_prior_dist
    σ0_subject ~ σ_prior_dist

    μ1 ~ μ_prior_dist
    σ1 ~ σ_prior_dist

    μ2 ~ μ_prior_dist
    σ2 ~ σ_prior_dist

    μ3 ~ μ_prior_dist
    σ3 ~ σ_prior_dist

    β0_stimulus ~ filldist(Normal(μ0_stimulus, σ0_stimulus), n_stimuli)
    β0_subject ~ filldist(Normal(μ0_subject, σ0_subject), n_subjects)
    β1 ~ filldist(Normal(μ1, σ1), n_stimuli)
    β2 ~ filldist(Normal(μ2, σ2), n_stimuli)
    β3 ~ filldist(Normal(μ3, σ3), n_stimuli)

    A = hcat(
        β0_stimulus[jstimulus],
        β0_subject[jsubject],
        β1[jstimulus],
        β2[jstimulus],
        β3[jstimulus],
    )
    yhat = vec(sum(A .* X, dims=2))

    Σ = Diagonal(fill(1.0^2, n))
    y ~ MvNormal(yhat, Σ)
    return (y=y,)
end

seed_global = 42;
rng = Random.MersenneTwister(seed_global);

### this works as expected ###
model_demo_arg = demo_arg(missing, data["X"], data["jstimulus"], data["jsubject"])
chain_arg = sample(rng, model_demo_arg, HMC(0.01, 5), 500)

### this fails ###
model_demo_kwarg = demo_kwarg(; y=missing, X=data["X"], jstimulus=data["jstimulus"], jsubject=data["jsubject"])
chain_kwarg = sample(rng, model_demo_kwarg, HMC(0.01, 5), 500)
``````
``````
julia> chain_kwarg = sample(rng, model_demo_kwarg, HMC(0.01, 5), 500)
ERROR: DomainError with Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}}(NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN):
Normal: the condition σ >= zero(σ) is not satisfied.
Stacktrace:
[1] #371
@ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:37 [inlined]
[2] check_args
@ ~/.julia/packages/Distributions/ji8PW/src/utils.jl:89 [inlined]
[3] #Normal#370
@ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:37 [inlined]
[4] Normal
@ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:36 [inlined]
[5] demo_kwarg(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{…}, __context__::DynamicPPL.SamplingContext{…}, arg#472::DynamicPPL.TypeWrap{…}; y::Missing, X::Matrix{…}, jstimulus::Vector{…}, jsubject::Vector{…})
@ Main ./REPL[26]:21
[6] demo_kwarg
@ ./REPL[26]:1 [inlined]
[7] _evaluate!!
@ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:963 [inlined]
@ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:952 [inlined]
[9] evaluate!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.TypedVarInfo{…}, context::DynamicPPL.SamplingContext{…})
@ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:887
[10] logdensity
@ ~/.julia/packages/DynamicPPL/E4kDs/src/logdensityfunction.jl:94 [inlined]
[11] Fix1
@ ./operators.jl:1118 [inlined]
[16] ∂logπ∂θ
@ ~/.julia/packages/Turing/IyijE/src/mcmc/hmc.jl:159 [inlined]
[17] ∂H∂θ
[19] step
[20] sample_phasepoint
[22] transition
[23] step(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, state::Turing.Inference.HMCState{…}; nadapts::Int64, kwargs::@Kwargs{})
@ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/hmc.jl:240
[24] step(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, state::Turing.Inference.HMCState{…})
@ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/hmc.jl:226
[25] macro expansion
@ ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:176 [inlined]
[26] macro expansion
@ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
[27] macro expansion
@ ~/.julia/packages/AbstractMCMC/YrmkI/src/logging.jl:9 [inlined]
[28] mcmcsample(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
@ AbstractMCMC ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:120
[29] sample(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; chain_type::Type, resume_from::Nothing, initial_state::Nothing, kwargs::@Kwargs{})
@ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/sampler.jl:93
[30] sample
@ ~/.julia/packages/DynamicPPL/E4kDs/src/sampler.jl:83 [inlined]
[31] #sample#4
@ ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:263 [inlined]
[32] sample(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, alg::HMC{…}, N::Int64)
@ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:256
[33] top-level scope
@ REPL[34]:1
Some type information was truncated. Use `show(err)` to see complete types.
``````

Anyway, I moved to the new conditioning syntax, which works:

``````
@model function demo(;
    X::Matrix{Float64},
    jstimulus::Vector{Int},
    jsubject::Vector{Int})

    n_stimuli = length(unique(jstimulus))
    n_subjects = length(unique(jsubject))

    μ_prior_dist = Normal(0, 5)
    σ_prior_dist = truncated(Normal(0, 10), 0, Inf)

    μ0_stimulus ~ μ_prior_dist
    σ0_stimulus ~ σ_prior_dist

    μ0_subject ~ μ_prior_dist
    σ0_subject ~ σ_prior_dist

    μ1 ~ μ_prior_dist
    σ1 ~ σ_prior_dist

    μ2 ~ μ_prior_dist
    σ2 ~ σ_prior_dist

    μ3 ~ μ_prior_dist
    σ3 ~ σ_prior_dist

    β0_stimulus ~ filldist(Normal(μ0_stimulus, σ0_stimulus), n_stimuli)
    β0_subject ~ filldist(Normal(μ0_subject, σ0_subject), n_subjects)
    β1 ~ filldist(Normal(μ1, σ1), n_stimuli)
    β2 ~ filldist(Normal(μ2, σ2), n_stimuli)
    β3 ~ filldist(Normal(μ3, σ3), n_stimuli)

    A = hcat(
        β0_stimulus[jstimulus],
        β0_subject[jsubject],
        β1[jstimulus],
        β2[jstimulus],
        β3[jstimulus],
    )
    yhat := sum(A .* X, dims=2)

    y ~ MvNormal(vec(yhat), LinearAlgebra.I)

    return nothing
end

model_demo = demo(;
    X=data["X"],
    jstimulus=data["jstimulus"],
    jsubject=data["jsubject"],
) | (; y=data["y"])

model_predictive = demo(;
    X=data["X"],
    jstimulus=data["jstimulus"],
    jsubject=data["jsubject"],
)
``````
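With that, sampling and predicting both work the way I originally wanted (the sampler choice here is arbitrary):

``````
rng = Random.MersenneTwister(seed_global)

prior_chain = Turing.sample(rng, model_demo, Prior(), 200)
posterior_chain = Turing.sample(rng, model_demo, NUTS(), 200)

# the unconditioned model treats y as a parameter, so predict fills it in
prior_predictive = Turing.predict(rng, model_predictive, prior_chain)
posterior_predictive = Turing.predict(rng, model_predictive, posterior_chain)
``````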