Specify a separate MvNormal for each observation of a variable in Turing

This MWE is eq.6 from Kyung et al. which features group-wise scaling of the coefficients to facilitate shrinkage, analogous to the grouped lasso.

Pretty sure I just messed up the model specification somehow but I also filed an issue, in case that makes it easier for the devs.

It seems like Turing does not like how I have defined the MvNormal for each observation of y.

I have the following model:

# Import Turing and Distributions.
using Turing, Distributions

# Import MCMCChains, Plots, and StatPlots for visualizations and diagnostics.
using MCMCChains, Plots, StatsPlots

# Functionality for splitting and normalizing the data.
using MLDataUtils: shuffleobs, splitobs, rescale!

# Functionality for evaluating the model predictions.
using Distances

# Set a seed for reproducibility.
using Random

using LinearAlgebra
Random.seed!(0)

# Hide the progress prompt while sampling.
Turing.turnprogress(false);

@model function grouped_lasso(y,x1,x2,x0,σ²,λ²)
	# for pairwise model, only two groups for lasso: home (self) and remote (other)
	ngroups = 2

	# home and remote groups
	mk = size(x1)[2] + size(x0)[2]

	# number of observations
	nobs = size(x1)[1]

	# set variance prior (shrinkage of the group-wise linear coefficients)
	τ² = Vector{Real}(undef,ngroups)
	for i = 1:ngroups
		τ²[i] ~ Gamma((mk+1)/2,λ²/2)
	end

	# set the coefficient prior
	β = Vector{Vector}(undef,ngroups)
	for i = 1:ngroups
		β[i] ~ MvNormal(mk,σ²*τ²[i])
	end
	# set the target distribution
	ntarg = size(y)[2]
	for i = 1:nobs
		mu = Vector{Real}(undef,ntarg)
		for j = 1:ntarg
			mu[j] = dot(vec([x1[i,:]' x0[i,:]']),vec(β[1])) + dot(vec([x2[i,:]' x0[i,:]']),vec(β[2]))
		end
		y[i,:] ~ MvNormal(mu,σ²)
	end
end


yTrain = rand(MvNormal(fill(0,3),1),1000)'
x1Train = rand(MvNormal(fill(0,2),1),1000)'
x2Train = rand(MvNormal(fill(0,2),1),1000)'
x0Train = rand(MvNormal(fill(0,2),1),1000)'
σ² = 1
λ² = 2

model = grouped_lasso(yTrain,x1Train,x2Train,x0Train,σ²,λ²)
chain = sample(model, NUTS(0.65), 300);

error output:

ERROR: StackOverflowError:
Stacktrace:
 [1] MvNormal(::Array{Real,1}, ::PDMats.ScalMat{Float64}) at /home/me/.julia/packages/Distributions/HjzA0/src/multivariate/mvnormal.jl:200
 [2] MvNormal(::Array{Real,1}, ::PDMats.ScalMat{Float64}) at /home/me/.julia/packages/Distributions/HjzA0/src/multivariate/mvnormal.jl:202 (repeats 65353 times)
 [3] MvNormal(::Array{Real,1}, ::Int64) at /home/me/.julia/packages/Distributions/HjzA0/src/multivariate/mvnormal.jl:220
 [4] #29 at /home/me/code/scratch/local/turing/test_lasso.jl:50 [inlined]
 [5] (::var"#29#30")(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#29#30",(:y, :x1, :x2, :x0, :σ², :λ²),(),(),Tuple{Adjoint{Float64,Array{Float64,2}},Adjoint{Float64,Array{Float64,2}},Adjoint{Float64,Array{Float64,2}},Adjoint{Float64,Array{Float64,2}},Int64,Int64},Tuple{}}, ::DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{DynamicPPL.VarName,Int64},Array{Distribution,1},Array{DynamicPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64}, ::DynamicPPL.SampleFromPrior, ::DynamicPPL.DefaultContext, ::Adjoint{Float64,Array{Float64,2}}, ::Adjoint{Float64,Array{Float64,2}}, ::Adjoint{Float64,Array{Float64,2}}, ::Adjoint{Float64,Array{Float64,2}}, ::Int64, ::Int64) at ./none:0
 [6] macro expansion at /home/me/.julia/packages/DynamicPPL/YDJiG/src/model.jl:0 [inlined]
 [7] _evaluate at /home/me/.julia/packages/DynamicPPL/YDJiG/src/model.jl:160 [inlined]
 [8] evaluate_threadunsafe at /home/me/.julia/packages/DynamicPPL/YDJiG/src/model.jl:130 [inlined]
 [9] Model at /home/me/.julia/packages/DynamicPPL/YDJiG/src/model.jl:92 [inlined]
 [10] Model at /home/me/.julia/packages/DynamicPPL/YDJiG/src/model.jl:98 [inlined]
 [11] VarInfo at /home/me/.julia/packages/DynamicPPL/YDJiG/src/varinfo.jl:110 [inlined]
 [12] VarInfo at /home/me/.julia/packages/DynamicPPL/YDJiG/src/varinfo.jl:109 [inlined]
 [13] DynamicPPL.Sampler(::NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}, ::DynamicPPL.Model{var"#29#30",(:y, :x1, :x2, :x0, :σ², :λ²),(),(),Tuple{Adjoint{Float64,Array{Float64,2}},Adjoint{Float64,Array{Float64,2}},Adjoint{Float64,Array{Float64,2}},Adjoint{Float64,Array{Float64,2}},Int64,Int64},Tuple{}}, ::DynamicPPL.Selector) at /home/me/.julia/packages/Turing/UsQlw/src/inference/hmc.jl:378
 [14] Sampler at /home/me/.julia/packages/Turing/UsQlw/src/inference/hmc.jl:376 [inlined]
 [15] sample(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#29#30",(:y, :x1, :x2, :x0, :σ², :λ²),(),(),Tuple{Adjoint{Float64,Array{Float64,2}},Adjoint{Float64,Array{Float64,2}},Adjoint{Float64,Array{Float64,2}},Adjoint{Float64,Array{Float64,2}},Int64,Int64},Tuple{}}, ::NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}, ::Int64; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/me/.julia/packages/Turing/UsQlw/src/inference/Inference.jl:164
 [16] sample at /home/me/.julia/packages/Turing/UsQlw/src/inference/Inference.jl:164 [inlined]
 [17] #sample#1 at /home/me/.julia/packages/Turing/UsQlw/src/inference/Inference.jl:154 [inlined]
 [18] sample(::DynamicPPL.Model{var"#29#30",(:y, :x1, :x2, :x0, :σ², :λ²),(),(),Tuple{Adjoint{Float64,Array{Float64,2}},Adjoint{Float64,Array{Float64,2}},Adjoint{Float64,Array{Float64,2}},Adjoint{Float64,Array{Float64,2}},Int64,Int64},Tuple{}}, ::NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}, ::Int64) at /home/me/.julia/packages/Turing/UsQlw/src/inference/Inference.jl:154
 [19] top-level scope at none:1

Small thing, but check with the literature if it is really a Gamma prior on σ² and not on 1/σ².

Edit: Ah, it’s about Lasso, then it makes sense: https://doi.org/10.1080/00031305.2017.1291448

1 Like

Here is the much cleaner and faster MWE with suggested fixes by @devmotion (thanks for the ridiculously fast response on github!):

@model function grouped_lasso(y, X, σ,λ²)
	# number of observations and features
	p, nobs = size(X)
    mk = div(p, 2)

	# set variance prior (shrinkage of the group-wise linear coefficients)
	τ² ~ filldist(Gamma((mk + 1) / 2, 2 / λ²), 2)

	# set the coefficient prior
	β ~ arraydist(MvNormal.(mk, σ .* sqrt.(τ²)))

	# set the target distribution
    for i in 1:nobs
        mu = view(X, :, i)' * vec(β)
		y[:, i] ~ MvNormal(fill(mu,size(y)[1]), Matrix(σ*I, size(y)[1], size(y)[1]))

	end
end


yTrain = rand(MvNormal(fill(0,3),1),1000)
x1Train = rand(MvNormal(fill(0,2),1),1000)
x2Train = rand(MvNormal(fill(0,2),1),1000)
x0Train = rand(MvNormal(fill(0,2),1),1000)
σ = 1.0 #
λ² = 2


model = grouped_lasso(yTrain,[x0Train; x1Train; x0Train; x2Train;],σ,λ²)
chain = sample(model, NUTS(0.65), 300);

describe(chain)[1]
Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat  
      Symbol   Float64   Float64    Float64   Missing    Float64   Float64  
                                                                            
      β[1,1]    0.0352    0.5431     0.0443   missing    30.3005    1.0391  
      β[1,2]   -0.0189    0.5435     0.0444   missing    29.6187    1.0392  
      β[2,1]   -0.0463    0.5427     0.0443   missing    48.3097    1.0507  
      β[2,2]    0.0519    0.5421     0.0443   missing    48.6922    1.0498  
      β[3,1]   -0.0103    0.0178     0.0015   missing   103.0911    0.9942  
      β[3,2]   -0.0070    0.0205     0.0017   missing   116.8437    0.9938  
      β[4,1]   -0.0006    0.0166     0.0014   missing    25.3661    1.0488  
      β[4,2]   -0.0053    0.0182     0.0015   missing   153.0968    1.0133  
       τ²[1]    0.9193    0.8859     0.0723   missing   103.1504    0.9948  
       τ²[2]    0.7998    0.7515     0.0614   missing    73.5011    0.9947  


describe(chain)[2]

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%  
      Symbol   Float64   Float64   Float64   Float64   Float64  
                                                                
      β[1,1]   -1.0620   -0.2630    0.0450    0.3429    1.2221  
      β[1,2]   -1.2039   -0.3139   -0.0416    0.2944    1.0857  
      β[2,1]   -1.3327   -0.2722   -0.0355    0.2126    0.9825  
      β[2,2]   -0.9773   -0.2046    0.0287    0.2683    1.3501  
      β[3,1]   -0.0398   -0.0234   -0.0107    0.0024    0.0243  
      β[3,2]   -0.0496   -0.0174   -0.0058    0.0061    0.0278  
      β[4,1]   -0.0319   -0.0108   -0.0001    0.0114    0.0275  
      β[4,2]   -0.0431   -0.0168   -0.0038    0.0077    0.0247  
       τ²[1]    0.0476    0.2398    0.7270    1.2488    3.0862  
       τ²[2]    0.0617    0.2423    0.5721    1.1749    2.8963  
2 Likes