After watching the Statistical Rethinking lecture about linear regression, I tried to follow the steps with other data using Turing.
I also looked up this reference: GitHub - StatisticalRethinkingJulia/SR2TuringPluto.jl (StatisticalRethinking notebook project using Turing and Pluto notebooks, derived from Max Lapan's Jupyter project).
Here is the code I used:
```julia
@model function lin_reg(x, y)
    α ~ Normal(mean(y), 10)
    β ~ LogNormal(0, 1)
    σ ~ Uniform(0, 10)
    μ = @. α + β * (x - mean(x))
    y ~ MvNormal(μ, σ)
end

model = lin_reg(Xs, ys)
chain = sample(model, NUTS(), 1000)
```
After running the NUTS sampler above, I get these warning messages and the results are bad.
Something is wrong with the standard deviation (if I change σ ~ Uniform(0, 50), then σ comes out at about 50).

```
The current proposal will be rejected due to numerical error(s).
isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
```
Here are the traceplots:
For your information, the data looks like this in a scatter plot:
Thanks!
I may not reply ASAP, since it's 1 am here.
Thanks for your question. It looks like your model cannot estimate σ because σ is greater than 10 (or 50). A couple of suggestions you could try (a quick way to check this on your own data is sketched right after the list):
- Standardize your variables by subtracting the mean and dividing by the standard deviation. This is by far the best method to use, because it makes setting priors easier and it helps the sampler run faster.
- Divide your `ys` by 1000 if you don't want to standardize.
- Set a larger upper bound on your uniform prior (e.g., `Uniform(0, 1000)`) or use a flat prior for σ.
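
Before trying any of these, here is a quick check you could run on your own data (a sketch, assuming your vectors are still called `Xs` and `ys` as in your snippet): estimate the residual spread with a rough least-squares fit and see whether it exceeds the upper bound of your Uniform prior.

```julia
using Statistics

# Rough residual scale from an ordinary least-squares fit on centered x
# (Xs and ys are assumed to be the data vectors from your snippet).
xc = Xs .- mean(Xs)
b  = sum(xc .* (ys .- mean(ys))) / sum(abs2, xc)   # OLS slope
a  = mean(ys)                                      # intercept at x == mean(x)
std(ys .- (a .+ b .* xc))                          # if this is far above 10 (or 50), the prior is the problem
```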
Examples
```julia
using Turing
using Distributions
using Plots, StatsPlots
using Random, Statistics

# Generate data that resembles scatterplot
Random.seed!(43143)
N = 500
α = 0.2
β = 0.1
xs = rand(Uniform(0, 60), N)
ys = rand.(Poisson.(exp.(α .+ β .* xs)))
scatter(xs, ys)
```
```julia
# Same model as in the question, unchanged, to show the same behaviour
@model function lin_reg(x, y)
    α ~ Normal(mean(y), 10)
    β ~ LogNormal(0, 1)
    σ ~ Uniform(0, 50)
    μ = @. α + β * (x - mean(x))
    y ~ MvNormal(μ, σ)
end

model = lin_reg(xs, ys)
chain = sample(model, NUTS(), 1000)
plot(chain)   # shows the same behaviour: σ is pushed against its upper bound
```
# standardize
standardize(x) = (x .- mean(x)) ./ std(x)
xs_standardized = standardize(xs)
ys_standardized = standardize(ys)
model_standardized = lin_reg(xs_standardized, ys_standardized)
chain_standardized = sample(model_standardized, NUTS(), 1000)
plot(chain_standardized)
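
If you standardize, the posterior is on the standardized scale. Here is a minimal sketch of mapping the posterior means back to the original units (assuming the centered parameterization of `lin_reg` above and the `xs`/`ys` generated earlier):

```julia
# Back-transform posterior means to the original scale
# (assumes both x and y were z-scored and the model centers x).
α_std = mean(chain_standardized[:α])
β_std = mean(chain_standardized[:β])
σ_std = mean(chain_standardized[:σ])

β_orig = β_std * std(ys) / std(xs)    # slope per unit of x
α_orig = mean(ys) + std(ys) * α_std   # expected y at x == mean(x)
σ_orig = σ_std * std(ys)              # residual sd in the original units
```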
```julia
# Divide ys by 1000
ys_1000 = ys ./ 1000   # broadcasting already allocates a new array, so no copy needed
model_1000 = lin_reg(xs, ys_1000)
chain_1000 = sample(model_1000, NUTS(), 1000)
plot(chain_1000)
```
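
The same idea applies to the divide-by-1000 version: since only `y` was rescaled, the posterior means simply scale back by the same factor (a sketch using the names above):

```julia
# Undo the division by 1000 on the posterior means
α_orig = 1000 * mean(chain_1000[:α])
β_orig = 1000 * mean(chain_1000[:β])
σ_orig = 1000 * mean(chain_1000[:σ])
```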
```julia
# Flat prior on σ
@model function lin_reg_flat(x, y)
    α ~ Normal(mean(y), 10)
    β ~ LogNormal(0, 1)
    σ ~ Flat()   # improper flat prior; FlatPos(0.0) would additionally keep σ positive
    μ = @. α + β * (x - mean(x))
    y ~ MvNormal(μ, σ)
end

model_flat = lin_reg_flat(xs, ys)
chains_flat = sample(model_flat, NUTS(), 1000)
plot(chains_flat)
```
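
As a rough sanity check on the flat-prior fit, you could compare the posterior mean of σ with the empirical spread of the residuals around the posterior-mean line (a sketch; variable names follow the blocks above):

```julia
# Posterior mean of σ vs. empirical residual spread
α_hat = mean(chains_flat[:α])
β_hat = mean(chains_flat[:β])
resid = ys .- (α_hat .+ β_hat .* (xs .- mean(xs)))
(mean(chains_flat[:σ]), std(resid))   # these should be in the same ballpark
```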
Thx for the answer! I’ll give it a try.