Sampling predictive with MvNormal in Turing

TL;DR: none of these work, and each one fails with a different error:

Vector{Float64}(undef, n) ~ MvNormal(yhat, Diagonal(fill(1.0, n)))
Vector{Real}(undef, n) ~ MvNormal(yhat, Diagonal(fill(1.0, n)))
Vector{Missing}(missing, n) ~ MvNormal(yhat, Diagonal(fill(1.0, n)))

I was building this relatively simple univariate partial pooling model and was finding that it was very slow relative to numpyro, so I tried replacing all the for loops and switching to a different backend. That helped, but I’m now running into an issue sampling a predictive distribution from the model.

I’m relatively new to Turing and julia so I might be missing something basic.

Here’s a version of the model where everything works, using a for loop and Normal:

using Random: Random
using StatsBase
using Turing
using LinearAlgebra
using ArviZ: ArviZ

# Record the environment the example was run in.
versioninfo()
using Pkg
Pkg.status()

seed_global = 42;
Random.seed!(seed_global);

# Synthetic data: 3_000 observations. Columns 1–2 of `X` are all ones and columns
# 3–5 are standard-normal draws; stimulus labels come from 1:8, subject labels
# from 1:50, and the response `y` is standard normal. All draws use the seeded
# global RNG, so their order (X, jstimulus, jsubject, y) fixes the data set.
data = let n_obs = 3_000
    X = ones(Float64, n_obs, 5)
    X[:, 3:5] = Random.randn((n_obs, 3))
    # Call arguments are evaluated left to right, which preserves the draw order.
    Dict{String,Any}(
        "X" => X,
        "jstimulus" => Random.rand(1:8, n_obs),
        "jsubject" => Random.rand(1:50, n_obs),
        "y" => Random.randn(n_obs),
    )
end;

# Partial-pooling linear model with a per-element Normal likelihood.
#
# For observation i, with s = jstimulus[i] and j = jsubject[i]:
#   yhat[i] = β0_stimulus[s]*X[i,1] + β0_subject[j]*X[i,2]
#             + β1[s]*X[i,3] + β2[s]*X[i,4] + β3[s]*X[i,5]
#   y[i] ~ Normal(yhat[i], 1.0)
#
# Keyword arguments
# - `y`: responses, one per row of `X`. Any `AbstractVector` is accepted, including
#   `Vector{Union{Missing,Float64}}` for partially observed data. A vector of
#   `missing`s (e.g. `similar(y, Missing)`) gives the predictive model used with
#   `Turing.predict`: each missing `y[i]` is drawn rather than observed.
# - `X`: design matrix with 5 columns and one row per observation.
# - `jstimulus`, `jsubject`: 1-based integer group labels, one per observation.
#
# Returns `(yhat=yhat,)`, where `yhat` is the n×1 matrix of linear predictors.
# Throws `DimensionMismatch` if `length(y) != size(X, 1)`.
@model function demo(;
    y::AbstractVector,
    X::AbstractMatrix{<:Real},
    jstimulus::AbstractVector{<:Integer},
    jsubject::AbstractVector{<:Integer})

    n = size(X, 1)
    length(y) == n ||
        throw(DimensionMismatch("y has $(length(y)) elements but X has $n rows"))

    # Size the group-level vectors by the largest label: `length(unique(...))`
    # under-counts whenever the labels are not exactly 1:K, after which
    # `β[jstimulus]` throws a BoundsError.
    n_stimuli = maximum(jstimulus; init=0)
    n_subjects = maximum(jsubject; init=0)

    μ_prior_dist = Normal(0, 5)
    σ_prior_dist = truncated(Normal(0, 10), 0, Inf)

    # Population-level means and scales for each coefficient group.
    μ0_stimulus ~ μ_prior_dist
    σ0_stimulus ~ σ_prior_dist

    μ0_subject ~ μ_prior_dist
    σ0_subject ~ σ_prior_dist

    μ1 ~ μ_prior_dist
    σ1 ~ σ_prior_dist

    μ2 ~ μ_prior_dist
    σ2 ~ σ_prior_dist

    μ3 ~ μ_prior_dist
    σ3 ~ σ_prior_dist

    # Group-level coefficients, partially pooled towards their population means.
    β0_stimulus ~ filldist(Normal(μ0_stimulus, σ0_stimulus), n_stimuli)
    β0_subject ~ filldist(Normal(μ0_subject, σ0_subject), n_subjects)
    β1 ~ filldist(Normal(μ1, σ1), n_stimuli)
    β2 ~ filldist(Normal(μ2, σ2), n_stimuli)
    β3 ~ filldist(Normal(μ3, σ3), n_stimuli)

    # Row i holds the five coefficients that apply to observation i.
    A = hcat(
        β0_stimulus[jstimulus],
        β0_subject[jsubject],
        β1[jstimulus],
        β2[jstimulus],
        β3[jstimulus],
    )
    # Row-wise dot product of coefficients and covariates (an n×1 matrix).
    yhat = sum(A .* X, dims=2)

    for i in eachindex(y)
        y[i] ~ Normal(yhat[i], 1.0)
    end

    return (yhat=yhat,)
end


# Fitted model: `y` holds the observed responses.
model_demo = demo(;
    y=data["y"],
    X=data["X"],
    jstimulus=data["jstimulus"],
    jsubject=data["jsubject"],
);

# A single RNG threads through every draw below, so the whole run is reproducible.
rng = Random.MersenneTwister(seed_global)
posterior_chain = sample(
    rng, model_demo, Turing.NUTS(), 100;
    n_adapts=200, discard_adapt=true,
)
prior_chain = Turing.sample(rng, model_demo, Prior(), 200)

# Predictive model: identical covariates, but every response is `missing`, so each
# `y[i]` is drawn rather than observed.
model_predictive = demo(;
    y=fill(missing, length(data["y"])),
    X=data["X"],
    jstimulus=data["jstimulus"],
    jsubject=data["jsubject"],
);

# Predictions from the prior chain first, then from the posterior chain.
prior_predictive, posterior_predictive = map(
    chain -> Turing.predict(rng, model_predictive, chain),
    (prior_chain, posterior_chain),
)

However, if I replace the for loop with

    # Vectorised replacement for the per-element loop. `yhat` comes from
    # `sum(A .* X, dims=2)`, an n×1 matrix, hence `vec`.
    Σ = Diagonal(fill(1.0, length(yhat)))
    y ~ MvNormal(vec(yhat), Σ)

I get this error:

julia> prior_predictive = Turing.predict(rng, model_predictive, prior_chain)
ERROR: MethodError: no method matching loglikelihood(::DiagNormal, ::Vector{Missing})

So I tried using
y=Vector{Float64}(undef, length(data["y"]))
instead of y=similar(data["y"], Missing), as suggested in the documentation.

That lets the model run, but the `Float64` undef vector is not treated as parameters to sample, so `predict` returns an empty chain with no parameters:

julia> prior_predictive = Turing.predict(rng, model_predictive, prior_chain)
Chains MCMC chain (200×0×1 Array{Float64, 3}):

Iterations        = 1:1:200
Number of chains  = 1
Samples per chain = 200
parameters        = 
internals         = 

Summary Statistics
  parameters   mean   std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol    Any   Any   Float64    Float64    Float64   Float64       Missing 


Quantiles
  parameters   2.5%   25.0%   50.0%   75.0%   97.5% 
      Symbol    Any     Any     Any     Any     Any 

And I also tried y=Vector{Real}(undef, length(data["y"])), which gives this error:

prior_predictive = Turing.predict(Random.MersenneTwister(seed_global), model_predictive, prior_chain)
ERROR: UndefRefError: access to undefined reference
Stacktrace:
  [1] getindex
    @ ./essentials.jl:13 [inlined]
  [2] _broadcast_getindex
    @ ./broadcast.jl:675 [inlined]
  [3] _getindex
    @ ./broadcast.jl:705 [inlined]
  [4] _broadcast_getindex
    @ ./broadcast.jl:681 [inlined]
  [5] getindex
    @ ./broadcast.jl:636 [inlined]
  [6] copy
    @ ./broadcast.jl:942 [inlined]
  [7] materialize
    @ ./broadcast.jl:903 [inlined]
  [8] sqmahal(d::DiagNormal, x::Vector{Real})
    @ Distributions ~/.julia/packages/Distributions/ji8PW/src/multivariate/mvnormal.jl:267
  [9] _logpdf(d::DiagNormal, x::Vector{Real})
    @ Distributions ~/.julia/packages/Distributions/ji8PW/src/multivariate/mvnormal.jl:143
 [10] logpdf
    @ ~/.julia/packages/Distributions/ji8PW/src/common.jl:263 [inlined]
 [11] loglikelihood
    @ ~/.julia/packages/Distributions/ji8PW/src/common.jl:448 [inlined]
 [12] observe
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:266 [inlined]
 [13] observe
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:263 [inlined]
 [14] tilde_observe
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:158 [inlined]
 [15] tilde_observe
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:156 [inlined]
 [16] tilde_observe
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:151 [inlined]
 [17] tilde_observe!!(context::DynamicPPL.SamplingContext{…}, right::DiagNormal, left::Vector{…}, vi::DynamicPPL.ThreadSafeVarInfo{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:207
 [18] tilde_observe!!(context::DynamicPPL.SamplingContext{…}, right::DiagNormal, left::Vector{…}, vname::AbstractPPL.VarName{…}, vi::DynamicPPL.ThreadSafeVarInfo{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/context_implementations.jl:194
 [19] demo(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{…}, __context__::DynamicPPL.SamplingContext{…}; y::Vector{…}, X::Matrix{…}, jstimulus::Vector{…}, jsubject::Vector{…})
    @ Main ./REPL[19]:42
 [20] demo
    @ ./REPL[19]:1 [inlined]
 [21] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:963 [inlined]
 [22] evaluate_threadsafe!!
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:952 [inlined]
 [23] evaluate!!
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:887 [inlined]
 [24] evaluate!!(model::DynamicPPL.Model{…}, rng::Random.TaskLocalRNG, varinfo::DynamicPPL.UntypedVarInfo{…}, sampler::DynamicPPL.SampleFromPrior, context::DynamicPPL.DefaultContext)
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:900
 [25] (::DynamicPPL.Model{…})(::Random.TaskLocalRNG, ::Vararg{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:860
 [26] VarInfo
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:129 [inlined]
 [27] VarInfo
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:128 [inlined]
 [28] VarInfo
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:132 [inlined]
 [29] transitions_from_chain(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, chain::Chains{…}; sampler::DynamicPPL.SampleFromPrior)
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:718
 [30] predict(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, chain::Chains{…}; include_all::Bool)
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:631
 [31] predict(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, chain::Chains{…})
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:624
 [32] top-level scope
    @ REPL[24]:1
Some type information was truncated. Use `show(err)` to see complete types.

So again, apologies if I missed something basic in the documentation, but how should I go about sampling a predictive distribution using `MvNormal`? (As a side note, I first tried to use `DiagNormal` directly but couldn’t get it to work with the data; I’m not sure whether it’s internal-only.) Thanks for the help!

julia> versioninfo()
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 12 × Apple M2 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 8 default, 0 interactive, 4 GC (on 8 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 8

julia> Pkg.status()
  [0bf59076] AdvancedHMC v0.6.1
  [cbdf2221] AlgebraOfGraphics v0.6.19
  [c7e460c6] ArgParse v1.2.0
  [131c737c] ArviZ v0.10.5
  [4a6e88f0] ArviZPythonPlots v0.1.5
  [336ed68f] CSV v0.10.14
⌃ [13f3f980] CairoMakie v0.11.11
  [324d7699] CategoricalArrays v0.10.8
  [a93c6f00] DataFrames v1.6.1
  [1a297f60] FillArrays v1.11.0
  [663a7486] FreeTypeAbstraction v0.10.3
  [682c06a0] JSON v0.21.4
  [98e50ef6] JuliaFormatter v1.0.56
⌅ [ee78f7c6] Makie v0.20.10
  [7f7a1694] Optimization v3.25.1
  [b1d3bc72] Pathfinder v0.8.7
  [f27b6e38] Polynomials v4.0.9
  [438e738f] PyCall v1.96.4
  [37e2e3b7] ReverseDiff v1.15.3
  [295af30f] Revise v3.5.14
  [2913bbd2] StatsBase v0.34.3
  [f3b207a7] StatsPlots v0.15.7
  [fce5fe82] Turing v0.32.3
  [e88e6eb3] Zygote v0.6.70

It appears that the `MvNormal` version fails when `y` is passed as a keyword argument but works when `y` is a positional argument. I’m new to Julia and Turing, but this behavior is unexpected to me. Is this a Turing bug?

using Random: Random
using StatsBase
using Turing
using LinearAlgebra

# Smaller synthetic data set: 300 observations, stimulus labels in 1:10 and subject
# labels in 1:20. Columns 1–2 of `X` are all ones; columns 3–5 and `y` are standard
# normal. No global seed has been set at this point in the script, so the draws
# come from a dedicated seeded RNG to keep the example reproducible.
data = let n_obs = 300, data_rng = Random.MersenneTwister(42)
    X = ones(Float64, n_obs, 5)
    X[:, 3:5] = Random.randn(data_rng, n_obs, 3)
    Dict{String,Any}(
        "y" => Random.randn(data_rng, n_obs),
        "X" => X,
        "jstimulus" => Random.rand(data_rng, 1:10, n_obs),
        "jsubject" => Random.rand(data_rng, 1:20, n_obs),
    )
end;

# Partial-pooling linear model with a vectorised `MvNormal` likelihood; `y` is
# passed positionally.
#
# For observation i, with s = jstimulus[i] and j = jsubject[i]:
#   yhat[i] = β0_stimulus[s]*X[i,1] + β0_subject[j]*X[i,2]
#             + β1[s]*X[i,3] + β2[s]*X[i,4] + β3[s]*X[i,5]
# and y ~ MvNormal(yhat, Diagonal(fill(1.0^2, n))).
#
# Arguments
# - `y`: responses, or `missing` to draw them from the model.
# - `X`: design matrix with 5 columns, one row per observation.
# - `jstimulus`, `jsubject`: 1-based integer group labels, one per observation.
# - `T`: element type of the placeholder vector allocated when `y` is `missing`.
#
# Returns `(yhat=yhat,)`, the length-n vector of linear predictors.
@model function demo_arg(y, X, jstimulus, jsubject, ::Type{T}=Float64) where {T}
    n = size(X, 1)
    if y === missing
        # Placeholder for the drawn responses. With `y` passed positionally as
        # `missing`, this variant is reported to sample `y` correctly.
        y = Vector{T}(undef, n)
    end

    # Size the group-level vectors by the largest label: `length(unique(...))`
    # under-counts whenever the labels are not exactly 1:K, after which
    # `β[jstimulus]` throws a BoundsError.
    n_stimuli = maximum(jstimulus; init=0)
    n_subjects = maximum(jsubject; init=0)

    μ_prior_dist = Normal(0, 5)
    σ_prior_dist = truncated(Normal(0, 10), 0, Inf)

    # Population-level means and scales for each coefficient group.
    μ0_stimulus ~ μ_prior_dist
    σ0_stimulus ~ σ_prior_dist

    μ0_subject ~ μ_prior_dist
    σ0_subject ~ σ_prior_dist

    μ1 ~ μ_prior_dist
    σ1 ~ σ_prior_dist

    μ2 ~ μ_prior_dist
    σ2 ~ σ_prior_dist

    μ3 ~ μ_prior_dist
    σ3 ~ σ_prior_dist

    # Group-level coefficients, partially pooled towards their population means.
    β0_stimulus ~ filldist(Normal(μ0_stimulus, σ0_stimulus), n_stimuli)
    β0_subject ~ filldist(Normal(μ0_subject, σ0_subject), n_subjects)
    β1 ~ filldist(Normal(μ1, σ1), n_stimuli)
    β2 ~ filldist(Normal(μ2, σ2), n_stimuli)
    β3 ~ filldist(Normal(μ3, σ3), n_stimuli)

    # Row i holds the five coefficients that apply to observation i.
    A = hcat(
        β0_stimulus[jstimulus],
        β0_subject[jsubject],
        β1[jstimulus],
        β2[jstimulus],
        β3[jstimulus],
    )
    # Row-wise dot product of coefficients and covariates, as a plain vector.
    yhat = vec(sum(A .* X, dims=2))

    σ_y = 1.0  # fixed observation-noise standard deviation
    Σ = Diagonal(fill(σ_y^2, n))
    y ~ MvNormal(yhat, Σ)
    return (yhat=yhat,)
end

# Partial-pooling linear model with a vectorised `MvNormal` likelihood; `y` is a
# keyword argument.
#
# For observation i, with s = jstimulus[i] and j = jsubject[i]:
#   yhat[i] = β0_stimulus[s]*X[i,1] + β0_subject[j]*X[i,2]
#             + β1[s]*X[i,3] + β2[s]*X[i,4] + β3[s]*X[i,5]
# and y ~ MvNormal(yhat, Diagonal(fill(1.0^2, n))).
#
# Arguments
# - `T`: unused; kept so existing calls such as `demo_kwarg(Float64; ...)` still work.
#
# Keyword arguments
# - `y`: responses, or `missing` (the default) to draw them from the model.
# - `X`: design matrix with 5 columns, one row per observation.
# - `jstimulus`, `jsubject`: 1-based integer group labels, one per observation.
#
# Returns `(y=y,)`: the observed responses, or the drawn ones when `y` was missing.
@model function demo_kwarg(::Type{T}=Float64; y=missing, X, jstimulus, jsubject) where {T}
    n = size(X, 1)
    # Deliberately leave a missing `y` as `missing`. A keyword argument appears not
    # to be recorded as missing the way a positional one is. So replacing it with
    # `Vector{T}(undef, n)` would turn `y ~ MvNormal(...)` below into an observation
    # of uninitialised memory (NaN gradients / DomainError under HMC). While
    # `y === missing`, that statement draws `y` instead.

    # Size the group-level vectors by the largest label: `length(unique(...))`
    # under-counts whenever the labels are not exactly 1:K, after which
    # `β[jstimulus]` throws a BoundsError.
    n_stimuli = maximum(jstimulus; init=0)
    n_subjects = maximum(jsubject; init=0)

    μ_prior_dist = Normal(0, 5)
    σ_prior_dist = truncated(Normal(0, 10), 0, Inf)

    # Population-level means and scales for each coefficient group.
    μ0_stimulus ~ μ_prior_dist
    σ0_stimulus ~ σ_prior_dist

    μ0_subject ~ μ_prior_dist
    σ0_subject ~ σ_prior_dist

    μ1 ~ μ_prior_dist
    σ1 ~ σ_prior_dist

    μ2 ~ μ_prior_dist
    σ2 ~ σ_prior_dist

    μ3 ~ μ_prior_dist
    σ3 ~ σ_prior_dist

    # Group-level coefficients, partially pooled towards their population means.
    β0_stimulus ~ filldist(Normal(μ0_stimulus, σ0_stimulus), n_stimuli)
    β0_subject ~ filldist(Normal(μ0_subject, σ0_subject), n_subjects)
    β1 ~ filldist(Normal(μ1, σ1), n_stimuli)
    β2 ~ filldist(Normal(μ2, σ2), n_stimuli)
    β3 ~ filldist(Normal(μ3, σ3), n_stimuli)

    # Row i holds the five coefficients that apply to observation i.
    A = hcat(
        β0_stimulus[jstimulus],
        β0_subject[jsubject],
        β1[jstimulus],
        β2[jstimulus],
        β3[jstimulus],
    )
    # Row-wise dot product of coefficients and covariates, as a plain vector.
    yhat = vec(sum(A .* X, dims=2))

    σ_y = 1.0  # fixed observation-noise standard deviation
    Σ = Diagonal(fill(σ_y^2, n))
    y ~ MvNormal(yhat, Σ)
    return (y=y,)
end

seed_global = 42;
rng = Random.MersenneTwister(seed_global);

### this works as expected ###
# `y` is passed positionally as `missing`, so it is drawn along with the parameters.
model_demo_arg = demo_arg(missing, data["X"], data["jstimulus"], data["jsubject"])
chain_arg = sample(rng, model_demo_arg, HMC(0.01, 5), 500)

### this fails ###
# The same model with `y` passed as a keyword argument instead.
model_demo_kwarg = demo_kwarg(;
    y=missing,
    X=data["X"],
    jstimulus=data["jstimulus"],
    jsubject=data["jsubject"],
)
chain_kwarg = sample(rng, model_demo_kwarg, HMC(0.01, 5), 500)
julia> chain_kwarg = sample(rng, model_demo_kwarg, HMC(0.01, 5), 500)
ERROR: DomainError with Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}}(NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN):
Normal: the condition σ >= zero(σ) is not satisfied.
Stacktrace:
  [1] #371
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:37 [inlined]
  [2] check_args
    @ ~/.julia/packages/Distributions/ji8PW/src/utils.jl:89 [inlined]
  [3] #Normal#370
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:37 [inlined]
  [4] Normal
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:36 [inlined]
  [5] demo_kwarg(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{…}, __context__::DynamicPPL.SamplingContext{…}, arg#472::DynamicPPL.TypeWrap{…}; y::Missing, X::Matrix{…}, jstimulus::Vector{…}, jsubject::Vector{…})
    @ Main ./REPL[26]:21
  [6] demo_kwarg
    @ ./REPL[26]:1 [inlined]
  [7] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:963 [inlined]
  [8] evaluate_threadsafe!!
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:952 [inlined]
  [9] evaluate!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.TypedVarInfo{…}, context::DynamicPPL.SamplingContext{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:887
 [10] logdensity
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/logdensityfunction.jl:94 [inlined]
 [11] Fix1
    @ ./operators.jl:1118 [inlined]
 [12] chunk_mode_gradient!(result::DiffResults.MutableDiffResult{…}, f::Base.Fix1{…}, x::Vector{…}, cfg::ForwardDiff.GradientConfig{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:123
 [13] gradient!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:39 [inlined]
 [14] gradient!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:35 [inlined]
 [15] logdensity_and_gradient
    @ ~/.julia/packages/LogDensityProblemsAD/rBlLq/ext/LogDensityProblemsADForwardDiffExt.jl:118 [inlined]
 [16] ∂logπ∂θ
    @ ~/.julia/packages/Turing/IyijE/src/mcmc/hmc.jl:159 [inlined]
 [17] ∂H∂θ
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/hamiltonian.jl:38 [inlined]
 [18] step(lf::AdvancedHMC.Leapfrog{…}, h::AdvancedHMC.Hamiltonian{…}, z::AdvancedHMC.PhasePoint{…}, n_steps::Int64; fwd::Bool, full_trajectory::Val{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/AlvV4/src/integrator.jl:229
 [19] step
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/integrator.jl:199 [inlined]
 [20] sample_phasepoint
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/trajectory.jl:323 [inlined]
 [21] transition(rng::Random.MersenneTwister, τ::AdvancedHMC.Trajectory{…}, h::AdvancedHMC.Hamiltonian{…}, z::AdvancedHMC.PhasePoint{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/AlvV4/src/trajectory.jl:262
 [22] transition
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/sampler.jl:59 [inlined]
 [23] step(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, state::Turing.Inference.HMCState{…}; nadapts::Int64, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/hmc.jl:240
 [24] step(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, state::Turing.Inference.HMCState{…})
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/hmc.jl:226
 [25] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:176 [inlined]
 [26] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [27] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/logging.jl:9 [inlined]
 [28] mcmcsample(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:120
 [29] sample(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; chain_type::Type, resume_from::Nothing, initial_state::Nothing, kwargs::@Kwargs{})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/sampler.jl:93
 [30] sample
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/sampler.jl:83 [inlined]
 [31] #sample#4
    @ ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:263 [inlined]
 [32] sample(rng::Random.MersenneTwister, model::DynamicPPL.Model{…}, alg::HMC{…}, N::Int64)
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:256
 [33] top-level scope
    @ REPL[34]:1
Some type information was truncated. Use `show(err)` to see complete types.

In the end, I switched to the newer conditioning syntax (`model | (; y=...)`), which works:

# Partial-pooling linear model for use with DynamicPPL's conditioning syntax.
#
# `y` is not an argument. It is observed when the model is conditioned, e.g.
# `demo(; X, jstimulus, jsubject) | (; y=data["y"])`. Otherwise it is drawn, so the
# unconditioned model serves as the predictive model.
#
# For observation i, with s = jstimulus[i] and j = jsubject[i]:
#   yhat[i] = β0_stimulus[s]*X[i,1] + β0_subject[j]*X[i,2]
#             + β1[s]*X[i,3] + β2[s]*X[i,4] + β3[s]*X[i,5]
# and y ~ MvNormal(yhat, I).
#
# Keyword arguments
# - `X`: design matrix with 5 columns, one row per observation.
# - `jstimulus`, `jsubject`: 1-based integer group labels, one per observation.
#
# `yhat` (an n×1 matrix) is recorded with `:=` so it is tracked alongside the
# parameters. Returns `nothing`.
@model function demo(;
    X::AbstractMatrix{<:Real},
    jstimulus::AbstractVector{<:Integer},
    jsubject::AbstractVector{<:Integer})

    # Size the group-level vectors by the largest label: `length(unique(...))`
    # under-counts whenever the labels are not exactly 1:K, after which
    # `β[jstimulus]` throws a BoundsError.
    n_stimuli = maximum(jstimulus; init=0)
    n_subjects = maximum(jsubject; init=0)

    μ_prior_dist = Normal(0, 5)
    σ_prior_dist = truncated(Normal(0, 10), 0, Inf)

    # Population-level means and scales for each coefficient group.
    μ0_stimulus ~ μ_prior_dist
    σ0_stimulus ~ σ_prior_dist

    μ0_subject ~ μ_prior_dist
    σ0_subject ~ σ_prior_dist

    μ1 ~ μ_prior_dist
    σ1 ~ σ_prior_dist

    μ2 ~ μ_prior_dist
    σ2 ~ σ_prior_dist

    μ3 ~ μ_prior_dist
    σ3 ~ σ_prior_dist

    # Group-level coefficients, partially pooled towards their population means.
    β0_stimulus ~ filldist(Normal(μ0_stimulus, σ0_stimulus), n_stimuli)
    β0_subject ~ filldist(Normal(μ0_subject, σ0_subject), n_subjects)
    β1 ~ filldist(Normal(μ1, σ1), n_stimuli)
    β2 ~ filldist(Normal(μ2, σ2), n_stimuli)
    β3 ~ filldist(Normal(μ3, σ3), n_stimuli)

    # Row i holds the five coefficients that apply to observation i.
    A = hcat(
        β0_stimulus[jstimulus],
        β0_subject[jsubject],
        β1[jstimulus],
        β2[jstimulus],
        β3[jstimulus],
    )
    # Row-wise dot product of coefficients and covariates (an n×1 matrix).
    yhat := sum(A .* X, dims=2)

    y ~ MvNormal(vec(yhat), LinearAlgebra.I)

    return nothing
end

# Unconditioned model: `y` is drawn, so this is the model to use for prediction.
model_predictive = demo(;
    X=data["X"],
    jstimulus=data["jstimulus"],
    jsubject=data["jsubject"],
)

# Inference model: the same model conditioned on the observed responses.
model_demo = model_predictive | (; y=data["y"])