This code fits a linear regression to simulated data. With NUTS, the parameter estimates are close to the true parameter values; with ADVI, however, they are far off.
Because ADVI is much faster than NUTS, I would like to use it if possible, but I also want correct parameter estimates. How can I improve the ADVI estimates?
import Random
using DataFrames
using StatsBase: StatsBase, mad, median, percentile, sample, shuffle
using Distributions
using Turing
using Chain
Random.seed!(1)
##
# Simulate predictors and an outcome from a linear model with known parameters.
df, parameters = let
    β = (
        β_0 = 7,
        β_1 = 0.05,
        β_2 = 0.10,
        β_3 = 0.15,
        β_4 = 0.20,
    )
    σ = 2
    parameters = (β..., σ=σ)
    parameters_df = DataFrame(
        parameter=collect(keys(parameters)),
        value=collect(values(parameters)),
    )
    N = 100_000
    X = DataFrame(
        x_0=fill(1, N),
        x_1=rand([1, 2], N),
        x_2=rand(1:1:10, N),
        x_3=rand([0, 1], N),
        x_4=rand(0:1:10, N),
    )
    df = transform(X)
    μ = Matrix(df) * collect(β)
    ϵ = rand(Normal(0, σ), N)
    y = μ .+ ϵ
    (
        df = transform(
            X,
            [] => (() -> μ) => :μ,
            [] => (() -> ϵ) => :ϵ,
            [] => (() -> y) => :y,
        ),
        parameters = parameters_df,
    )
end
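As a sanity check on the generated data (an illustrative sketch, not something the question depends on), an ordinary least-squares fit should recover values close to the true β and σ, given N = 100_000:
# Sanity check (sketch): OLS on the simulated design should roughly recover β and σ.
X_check = Matrix(df[:, [:x_0, :x_1, :x_2, :x_3, :x_4]])
β_ols = X_check \ df.y                  # least-squares coefficient estimates
σ_ols = std(df.y .- X_check * β_ols)    # residual standard deviation, should be ≈ σ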
##
@model function linear_outcome_model(X, y)
    β_0 ~ Normal(mean(y), 2std(y))
    β_1 ~ Normal(0, .5)
    β_2 ~ Normal(0, .5)
    β_3 ~ Normal(0, .5)
    β_4 ~ Normal(0, .5)
    μ = (
        β_0 .* X.x_0
        .+ β_1 .* X.x_1
        .+ β_2 .* X.x_2
        .+ β_3 .* X.x_3
        .+ β_4 .* X.x_4
    )
    σ ~ truncated(Normal(0, 2std(y)), 0, Inf)
    y ~ MvNormal(μ, σ)
end
fields = [:x_0, :x_1, :x_2, :x_3, :x_4]
model = @chain df begin
    linear_outcome_model(_[:, fields], _.y)
end
samples_nuts = sample(model, NUTS(0.65), 3000)
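(If it helps to quantify the "much faster" claim, wrapping both fitting calls in @time is one option; a sketch, commented out so the samplers are not run twice:)
# Optional timing sketch:
# @time sample(model, NUTS(0.65), 3000)
# @time vi(model, ADVI())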
samples_vi = let
    estimate_vi = vi(model, ADVI())
    samplesarray_vi = rand(estimate_vi, 1000)
    # Map each parameter name to its range in the flat array of VI draws,
    # then reshape into a long DataFrame of (index, parameter, value).
    _, sym2range = bijector(model, Val(true))
    parameters = keys(sym2range)
    rearranged = DataFrame(Dict(
        param => vec(samplesarray_vi[sym2range[param]..., :])
        for param in parameters
    ))
    insertcols!(rearranged, 1, :index => 1:1:nrow(rearranged))
    DataFrames.stack(rearranged, Not(:index), variable_name=:parameter)
end
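For reference, the vi call above uses ADVI's defaults; a variant with more gradient samples per step and more optimisation steps (a sketch assuming the ADVI(samples_per_step, max_iters) constructor; the _long names are just for illustration) would look like:
# Sketch (unverified): give the ADVI optimiser more samples per step and more steps.
estimate_vi_long = vi(model, ADVI(10, 10_000))
samplesarray_vi_long = rand(estimate_vi_long, 1000)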
summary_nuts = @chain samples_nuts begin
    DataFrame(summarize(_, mean, std))
    rename(_, :parameters => :parameter)
    transform(:parameter => ByRow(Symbol) => :parameter)
    innerjoin(_, parameters, on=:parameter)
end
summary_vi = @chain samples_vi begin
    groupby(:parameter)
    combine(:value => mean => :mean, :value => std => :std)
    transform(:parameter => ByRow(Symbol) => :parameter)
    innerjoin(parameters, on=:parameter)
end
The results (the value column holds the true parameter values) show that the NUTS estimates are close to the truth, but the ADVI estimates are not:
julia> summary_nuts
6×4 DataFrame
 Row │ parameter  mean       std         value
     │ Symbol     Float64    Float64     Real
─────┼─────────────────────────────────────────
   1 │ β_0        6.97143    0.0257945   7
   2 │ β_1        0.0409737  0.0124979   0.05
   3 │ β_2        0.101222   0.00222162  0.1
   4 │ β_3        0.170142   0.0125686   0.15
   5 │ β_4        0.203275   0.00194464  0.2
   6 │ σ          1.99375    0.00440001  2

julia> summary_vi
6×4 DataFrame
 Row │ parameter  mean      std        value
     │ Symbol     Float64   Float64    Real
─────┼───────────────────────────────────────
   1 │ β_0        2.37664   0.312357   7
   2 │ β_1        1.26278   0.145315   0.05
   3 │ β_2        0.333547  0.0577251  0.1
   4 │ β_3        1.07337   0.239592   0.15
   5 │ β_4        0.377414  0.0625201  0.2
   6 │ σ          2.33721   0.227075   2