tl;dr: none of these seem to work, and each gives a different error:

```julia
Vector{Float64}(undef, n)   ~ MvNormal(yhat, Diagonal(fill(1.0, n)))
Vector{Real}(undef, n)      ~ MvNormal(yhat, Diagonal(fill(1.0, n)))
Vector{Missing}(missing, n) ~ MvNormal(yhat, Diagonal(fill(1.0, n)))
```
I was building this relatively simple univariate partial-pooling model and found it very slow relative to numpyro, so I tried replacing all the for loops and switching to a different backend. That helped, but I'm now running into an issue sampling a predictive distribution from the model. I'm relatively new to Turing and Julia, so I might be missing something basic.
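For context, the backend switch was just the sampler's AD option, roughly like this (a sketch assuming ReverseDiff, which is in my environment, via Turing's `adtype` keyword; this part isn't where the problem is):

```julia
# Rough sketch of the AD-backend switch; not the source of the issue.
using Turing
using ReverseDiff  # enables AutoReverseDiff
sampler = Turing.NUTS(; adtype=AutoReverseDiff())
```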
Here’s a version of the model where everything works, using a `for` loop and `Normal`:
```julia
using Random: Random
using StatsBase
using Turing
using LinearAlgebra
using ArviZ: ArviZ

versioninfo()
using Pkg
Pkg.status()

seed_global = 42;
Random.seed!(seed_global);

# Simulated data: 3_000 observations; the design matrix's first two
# columns are ones (stimulus and subject intercepts) and its last three
# are standard-normal predictors, plus random stimulus/subject indices.
data = Dict();
let
    global data
    X = ones(Float64, 3_000, 5);
    X[:, 3:5] = Random.randn((3_000, 3));
    data["X"] = X;
    data["jstimulus"] = Random.rand(1:8, 3_000);
    data["jsubject"] = Random.rand(1:50, 3_000);
    data["y"] = Random.randn(3_000);
end;
```
```julia
@model function demo(;
    y::Union{Vector{Float64},Vector{Real},Vector{Missing}},
    X::Matrix{Float64},
    jstimulus::Vector{Int},
    jsubject::Vector{Int},
)
    n_stimuli = length(unique(jstimulus))
    n_subjects = length(unique(jsubject))

    # Shared hyperpriors
    μ_prior_dist = Normal(0, 5)
    σ_prior_dist = truncated(Normal(0, 10), 0, Inf)
    μ0_stimulus ~ μ_prior_dist
    σ0_stimulus ~ σ_prior_dist
    μ0_subject ~ μ_prior_dist
    σ0_subject ~ σ_prior_dist
    μ1 ~ μ_prior_dist
    σ1 ~ σ_prior_dist
    μ2 ~ μ_prior_dist
    σ2 ~ σ_prior_dist
    μ3 ~ μ_prior_dist
    σ3 ~ σ_prior_dist

    # Partially pooled coefficients per stimulus / subject
    β0_stimulus ~ filldist(Normal(μ0_stimulus, σ0_stimulus), n_stimuli)
    β0_subject ~ filldist(Normal(μ0_subject, σ0_subject), n_subjects)
    β1 ~ filldist(Normal(μ1, σ1), n_stimuli)
    β2 ~ filldist(Normal(μ2, σ2), n_stimuli)
    β3 ~ filldist(Normal(μ3, σ3), n_stimuli)

    # One row of coefficients per observation, matched to X's columns
    A = hcat(
        β0_stimulus[jstimulus],
        β0_subject[jsubject],
        β1[jstimulus],
        β2[jstimulus],
        β3[jstimulus],
    )
    yhat = sum(A .* X, dims=2)

    # Observation model: the loop version that works with `predict`
    for i in eachindex(y)
        y[i] ~ Normal(yhat[i], 1.0)
    end
    return (yhat=yhat,)
end
```
```julia
model_demo = demo(;
    y=data["y"],
    X=data["X"],
    jstimulus=data["jstimulus"],
    jsubject=data["jsubject"],
);

rng = Random.MersenneTwister(seed_global)
posterior_chain = sample(rng, model_demo, Turing.NUTS(), 100; n_adapts=200, discard_adapt=true)
prior_chain = Turing.sample(rng, model_demo, Prior(), 200)

# Same model, but with `y` replaced by Missings so `predict` treats it as random
model_predictive = demo(;
    y=similar(data["y"], Missing),
    X=data["X"],
    jstimulus=data["jstimulus"],
    jsubject=data["jsubject"],
);
prior_predictive = Turing.predict(rng, model_predictive, prior_chain)
posterior_predictive = Turing.predict(rng, model_predictive, posterior_chain)
```
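Both `predict` calls work with this version, and each `y[i]` comes back as its own column in the returned chain. As a quick shape check (assuming MCMCChains' `Array` conversion):

```julia
# One column per observation (3_000), one row per posterior draw (100 here).
size(Array(posterior_predictive))  # (100, 3000)
```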
However, if I replace the `for` loop with

```julia
Σ = Diagonal(fill(1.0, length(yhat)))
y ~ MvNormal(vec(yhat), Σ)
```

(where `vec` is needed because `sum(A .* X, dims=2)` returns a 3_000×1 matrix), I get this error:

```julia
julia> prior_predictive = Turing.predict(rng, model_predictive, prior_chain)
ERROR: MethodError: no method matching loglikelihood(::DiagNormal, ::Vector{Missing})
```
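The same `MethodError` reproduces outside of Turing, so it looks like `Distributions.loglikelihood` simply has no method for non-`Real` element types (a minimal sketch to illustrate, not part of the original script):

```julia
using Distributions, LinearAlgebra

d = MvNormal(zeros(3), Diagonal(ones(3)))      # displays as DiagNormal
loglikelihood(d, zeros(3))                     # fine: Float64 elements
loglikelihood(d, Vector{Missing}(missing, 3))  # MethodError, as above
```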
So I tried `y=Vector{Float64}(undef, length(data["y"]))` rather than `y=similar(data["y"], Missing)`, as is suggested in the documentation.
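In other words, rebuilding `model_predictive` along these lines (just restating the change above in full, in case I'm holding it wrong):

```julia
model_predictive = demo(;
    y=Vector{Float64}(undef, length(data["y"])),  # undef placeholder instead of Missings
    X=data["X"],
    jstimulus=data["jstimulus"],
    jsubject=data["jsubject"],
);
```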
That allows the model to run, but it doesn’t seem to treat the `Float64` undef vector as parameters, and it returns unusable chains:
```julia
julia> prior_predictive = Turing.predict(rng, model_predictive, prior_chain)
Chains MCMC chain (200×0×1 Array{Float64, 3}):

Iterations        = 1:1:200
Number of chains  = 1
Samples per chain = 200
parameters        =
internals         =

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol       Any       Any   Float64    Float64    Float64   Float64       Missing

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol       Any       Any       Any       Any       Any
```
And I also tried `y=Vector{Real}(undef, length(data["y"]))`, which gives this error:
```julia
prior_predictive = Turing.predict(Random.MersenneTwister(seed_global), model_predictive, prior_chain)
ERROR: UndefRefError: access to undefined reference
Stacktrace:
  [1] getindex
    @ ./essentials.jl:13 [inlined]
  [2] _broadcast_getindex
    @ ./broadcast.jl:675 [inlined]
  [3] _getindex
    @ ./broadcast.jl:705 [inlined]
  [4] _broadcast_getindex
    @ ./broadcast.jl:681 [inlined]
  [5] getindex
    @ ./broadcast.jl:636 [inlined]
  [6] copy
    @ ./broadcast.jl:942 [inlined]
  [7] materialize
    @ ./broadcast.jl:903 [inlined]
  [8] sqmahal(d::DiagNormal, x::Vector{Real})
    @ Distributions ~/.julia/packages/Distributions/ji8PW/src/multivariate/mvnormal.jl:267
  [9] _logpdf(d::DiagNormal, x::Vector{Real})
    @ Distributions ~/.julia/packages/Distributions/ji8PW/src/multivariate/mvnormal.jl:143
 [10] logpdf
    @ ~/.julia/packages/Distributions/ji8PW/src/common.jl:263 [inlined]
 [11] loglikelihood
    @ ~/.julia/packages/Distributions/ji8PW/src/common.jl:448 [inlined]
 [12] observe
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:266 [inlined]
 [13] observe
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:263 [inlined]
 [14] tilde_observe
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:158 [inlined]
 [15] tilde_observe
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:156 [inlined]
 [16] tilde_observe
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:151 [inlined]
 [17] tilde_observe!!(context::DynamicPPL.SamplingContext{…}, right::DiagNormal, left::Vector{…}, vi::DynamicPPL.ThreadSafeVarInfo{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:207
 [18] tilde_observe!!(context::DynamicPPL.SamplingContext{…}, right::DiagNormal, left::Vector{…}, vname::AbstractPPL.VarName{…}, vi::DynamicPPL.ThreadSafeVarInfo{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:194
 [19] demo(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{…}, __context__::DynamicPPL.SamplingContext{…}; y::Vector{…}, X::Matrix{…}, jstimulus::Vector{…}, jsubject::Vector{…})
    @ Main ./REPL[19]:42
 [20] demo
    @ ./REPL[19]:1 [inlined]
 [21] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:963 [inlined]
 [22] evaluate_threadsafe!!
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:952 [inlined]
 [23] evaluate!!
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:887 [inlined]
 [24] evaluate!!(model::DynamicPPL.Model{…}, rng::Random.TaskLocalRNG, varinfo::DynamicPPL.UntypedVarInfo{…}, sampler::DynamicPPL.SampleFromPrior, context::DynamicPPL.DefaultContext)
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:900
 [25] (::DynamicPPL.Model{…})(::Random.TaskLocalRNG, ::Vararg{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:860
 [26] VarInfo
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:129 [inlined]
 [27] VarInfo
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:128 [inlined]
 [28] VarInfo
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:132 [inlined]
 [29] transitions_from_chain(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, chain::Chains{…}; sampler::DynamicPPL.SampleFromPrior)
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:718
 [30] predict(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, chain::Chains{…}; include_all::Bool)
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:631
 [31] predict(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, chain::Chains{…})
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:624
 [32] top-level scope
    @ REPL[24]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
From the stacktrace, it looks like `sqmahal` broadcasts over the `Vector{Real}`, whose entries are still undefined references, hence the `UndefRefError`. So again, apologies if I’ve missed something basic in the documentation, but how should I go about sampling a predictive distribution when the likelihood uses `MvNormal`? (As a side note, I first tried to use `DiagNormal` directly, but I couldn’t even get it to work with the data; I’m not sure whether it’s internal-only.) Thanks for the help!
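For what it's worth, my current reading is that `DiagNormal` isn't a separate distribution at all, which would explain why I couldn't construct it from my data. A sketch of that reading (my interpretation of the Distributions.jl source, not something from its docs):

```julia
using Distributions, LinearAlgebra

# MvNormal with a Diagonal covariance stores it as a PDiagMat, and
# DiagNormal appears to be just a type alias for that concrete MvNormal:
d = MvNormal(zeros(3), Diagonal(ones(3)))
d isa Distributions.DiagNormal  # true
```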
```julia
julia> versioninfo()
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 12 × Apple M2 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 8 default, 0 interactive, 4 GC (on 8 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 8
```
```julia
julia> Pkg.status()
  [0bf59076] AdvancedHMC v0.6.1
  [cbdf2221] AlgebraOfGraphics v0.6.19
  [c7e460c6] ArgParse v1.2.0
  [131c737c] ArviZ v0.10.5
  [4a6e88f0] ArviZPythonPlots v0.1.5
  [336ed68f] CSV v0.10.14
⌃ [13f3f980] CairoMakie v0.11.11
  [324d7699] CategoricalArrays v0.10.8
  [a93c6f00] DataFrames v1.6.1
  [1a297f60] FillArrays v1.11.0
  [663a7486] FreeTypeAbstraction v0.10.3
  [682c06a0] JSON v0.21.4
  [98e50ef6] JuliaFormatter v1.0.56
⌅ [ee78f7c6] Makie v0.20.10
  [7f7a1694] Optimization v3.25.1
  [b1d3bc72] Pathfinder v0.8.7
  [f27b6e38] Polynomials v4.0.9
  [438e738f] PyCall v1.96.4
  [37e2e3b7] ReverseDiff v1.15.3
  [295af30f] Revise v3.5.14
  [2913bbd2] StatsBase v0.34.3
  [f3b207a7] StatsPlots v0.15.7
  [fce5fe82] Turing v0.32.3
  [e88e6eb3] Zygote v0.6.70
```