Need some help on sampler error with simple linear regression

After watching the Statistical Rethinking lecture on linear regression, I tried to follow the same steps with other data using Turing.
I also looked at this reference: https://github.com/StatisticalRethinkingJulia/SR2TuringPluto.jl

Here is the code I used:

@model function lin_reg(x, y)

	α ~ Normal(mean(y), 10)
	β ~ LogNormal(0, 1)
	σ ~ Uniform(0, 10)

	μ = @. α + β * (x - mean(x))
	y ~ MvNormal(μ, σ)
end

model = lin_reg(Xs, ys)

chain = sample(model, NUTS(), 1000)

After running the NUTS sampler as above, I get these warning messages and the results are bad.
Something is wrong with the standard deviation: if I change the prior to σ ~ Uniform(0, 50), then σ comes out at about 50.

The current proposal will be rejected due to numerical error(s).
isfinite.((θ, r, ℓπ, ℓκ)) (true, false, false, false)

Here are the trace plots:

For your information, the data looks like this in a scatter plot:
[image: sss(1)]

Thanks! :+1:
I may not reply asap since time is 1am now. :pray:

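Thanks for your question. It looks like your model cannot estimate σ because the σ your data needs is larger than the upper bound of the Uniform prior (10, or 50), so the posterior piles up against that bound. A couple of suggestions you could try: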

  1. Standardize your variables by subtracting the mean and dividing by the standard deviation. This is by far the best method to use, because it makes setting priors easier and it helps the sampler run faster.
  2. Divide your ys by 1000 if you don’t want to standardize.
  3. Set a larger upper bound on your uniform prior (e.g., Uniform(0, 1000)) or use a flat prior for σ.
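
To make the first suggestion a bit more concrete, here is the algebra behind it. This is only a sketch: the subscript s and the symbols $\bar{x}$, $s_x$, $\bar{y}$, $s_y$ (sample means and sample standard deviations) are my own notation, and I am assuming the centred linear model from your post.

$$
x_s = \frac{x - \bar{x}}{s_x}, \qquad y_s = \frac{y - \bar{y}}{s_y}
$$

Because $y_s$ has standard deviation 1, the residual standard deviation of a straight-line fit cannot exceed 1, so a prior like Uniform(0, 10) covers it comfortably. If you fit $y_s \sim \mathrm{Normal}(\alpha_s + \beta_s x_s, \sigma_s)$ and want the parameters back on the original scale:

$$
\beta = \beta_s \frac{s_y}{s_x}, \qquad \alpha = \bar{y} + s_y \alpha_s, \qquad \sigma = s_y \sigma_s
$$

Here $\alpha$ is the intercept at $x = \bar{x}$, matching the centred form $\alpha + \beta (x - \bar{x})$ in your model.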

Examples

using Turing
using Distributions
using Plots, StatsPlots
using Random, Statistics

# Generate data that resembles scatterplot
Random.seed!(43143)
N = 500
α = 0.2
β = 0.1
xs = rand(Uniform(0, 60), N)
ys = rand.(Poisson.(exp.(α .+ β .* xs)))

scatter(xs, ys)

# model
@model function lin_reg(x, y)

	α ~ Normal(mean(y), 10)
	β ~ LogNormal(0, 1)
	σ ~ Uniform(0, 50)

	μ = @. α + β * (x - mean(x))
	y ~ MvNormal(μ, σ)
end

# model without changes to show same behaviour
model = lin_reg(xs, ys)
chain = sample(model, NUTS(), 1000)
plot(chain) # shows same behaviour as σ is hitting bounds

# standardize
standardize(x) = (x .- mean(x)) ./ std(x)
xs_standardized = standardize(xs)
ys_standardized = standardize(ys)

model_standardized = lin_reg(xs_standardized, ys_standardized)
chain_standardized = sample(model_standardized, NUTS(), 1000)
plot(chain_standardized)


# divide by 1000
ys_1000 = ys ./ 1000  # broadcasting already allocates a new array, so no copy is needed
model_1000 = lin_reg(xs, ys_1000)
chain_1000 = sample(model_1000, NUTS(), 1000)
plot(chain_1000)


# Flat prior
@model function lin_reg_flat(x, y)

	α ~ Normal(mean(y), 10)
	β ~ LogNormal(0, 1)
	σ ~ FlatPos(0.0)  # flat prior restricted to σ > 0; plain Flat() would also allow negative σ

	μ = @. α + β * (x - mean(x))
	y ~ MvNormal(μ, σ)
end
model_flat = lin_reg_flat(xs, ys)
chains_flat = sample(model_flat, NUTS(), 1000)
plot(chains_flat)
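
If you want to check before sampling whether the prior bound is likely to be the problem, a rough approach is to fit an ordinary least-squares line and compare its residual standard deviation with the upper bound of the σ prior. The sketch below uses only the standard library. The data are hypothetical (an exponential trend plus noise, not your data), and `residual_sd` and `prior_upper` are names I made up for illustration.

```julia
using Random, Statistics

# Hypothetical data, roughly on the scale discussed in this thread
Random.seed!(1)
x = 60 .* rand(500)
y = exp.(0.2 .+ 0.1 .* x) .+ 5 .* randn(500)

# Residual standard deviation of a closed-form least-squares line
# fitted to centred x (same parameterisation as the Turing model)
function residual_sd(x, y)
    xc = x .- mean(x)
    β_ols = sum(xc .* (y .- mean(y))) / sum(abs2, xc)
    α_ols = mean(y)
    return std(y .- (α_ols .+ β_ols .* xc))
end

standardize(v) = (v .- mean(v)) ./ std(v)

prior_upper = 10.0  # upper bound of σ ~ Uniform(0, 10)

sd_raw = residual_sd(x, y)
sd_std = residual_sd(standardize(x), standardize(y))

println("raw scale:    residual sd = ", round(sd_raw; digits = 2),
        ", inside prior? ", sd_raw < prior_upper)
println("standardized: residual sd = ", round(sd_std; digits = 2),
        ", inside prior? ", sd_std < prior_upper)
```

If the raw-scale residual sd is above the prior's upper bound, the sampler can only push σ against that bound, which matches the behaviour in your trace plots. On the standardized scale the residual sd is at most 1, so the same prior no longer gets in the way.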


Thx for the answer! I’ll give it a try.