# Variational inference for linear regression

The code below fits a linear regression to simulated data. With NUTS, the parameter estimates are close to the true values; with ADVI, they are far off.

Because ADVI is much faster than NUTS, I would like to use it if possible, but only if the parameter estimates are accurate. How can I improve them?

```julia
import Random
using DataFrames

using StatsBase: StatsBase, mad, median, percentile, sample, shuffle
using Distributions
using Turing
using Chain

Random.seed!(1)
##

df, parameters = let
    β = (
        β_0 = 7,
        β_1 = 0.05,
        β_2 = 0.10,
        β_3 = 0.15,
        β_4 = 0.20,
    )
    σ = 2
    parameters = (β..., σ=σ)
    parameters_df = DataFrame(
        parameter=collect(keys(parameters)),
        value=collect(values(parameters)),
    )

    N = 100_000
    X = DataFrame(
        x_0=fill(1, N),
        x_1=rand([1, 2], N),
        x_2=rand(1:1:10, N),
        x_3=rand([0, 1], N),
        x_4=rand(0:1:10, N),
    )
    μ = Matrix(X) * collect(β)
    ϵ = rand(Normal(0, σ), N)
    y = μ .+ ϵ

    (
        df = transform(
            X,
            [] => (() -> μ) => :μ,
            [] => (() -> ϵ) => :ϵ,
            [] => (() -> y) => :y,
        ),
        parameters = parameters_df,
    )
end

##
@model function linear_outcome_model(X, y)
    β_0 ~ Normal(mean(y), 2std(y))
    β_1 ~ Normal(0, 0.5)
    β_2 ~ Normal(0, 0.5)
    β_3 ~ Normal(0, 0.5)
    β_4 ~ Normal(0, 0.5)

    μ = (
        β_0 .* X.x_0 .+
        β_1 .* X.x_1 .+
        β_2 .* X.x_2 .+
        β_3 .* X.x_3 .+
        β_4 .* X.x_4
    )
    σ ~ truncated(Normal(0, 2std(y)), 0, Inf)
    y ~ MvNormal(μ, σ)
end

fields = [:x_0, :x_1, :x_2, :x_3, :x_4]

model = @chain df begin
    linear_outcome_model(_[:, fields], _.y)
end

samples_nuts = sample(model, NUTS(0.65), 3000)

samples_vi = let
    # The call that defines `estimate_vi` was missing above; this is the
    # standard Turing usage (10 MC samples per step, 1000 optimization steps).
    advi = ADVI(10, 1000)
    estimate_vi = vi(model, advi)

    samplesarray_vi = rand(estimate_vi, 1000)
    _, sym2range = bijector(model, Val(true))
    parameters = keys(sym2range)

    rearranged = DataFrame(Dict(
        param => vec(samplesarray_vi[sym2range[param]..., :])
        for param in parameters
    ))
    insertcols!(rearranged, 1, :index => 1:nrow(rearranged))
    DataFrames.stack(rearranged, Not(:index), variable_name=:parameter)
end

summary_nuts = @chain samples_nuts begin
    DataFrame(summarize(_, mean, std))
    rename(_, :parameters => :parameter)
    transform(:parameter => ByRow(Symbol) => :parameter)
    innerjoin(_, parameters, on=:parameter)
end

summary_vi = @chain samples_vi begin
    groupby(:parameter)
    combine(:value => mean => :mean, :value => std => :std)
    transform(:parameter => ByRow(Symbol) => :parameter)
    innerjoin(_, parameters, on=:parameter)
end
``````

Results show NUTS is close to the true `value` but ADVI is not:

```
julia> summary_nuts
6×4 DataFrame
 Row │ parameter  mean       std         value
     │ Symbol     Float64    Float64     Real
─────┼─────────────────────────────────────────
   1 │ β_0        6.97143    0.0257945       7
   2 │ β_1        0.0409737  0.0124979    0.05
   3 │ β_2        0.101222   0.00222162    0.1
   4 │ β_3        0.170142   0.0125686    0.15
   5 │ β_4        0.203275   0.00194464    0.2
   6 │ σ          1.99375    0.00440001      2

julia> summary_vi
6×4 DataFrame
 Row │ parameter  mean      std        value
     │ Symbol     Float64   Float64    Real
─────┼───────────────────────────────────────
   1 │ β_0        2.37664   0.312357       7
   2 │ β_1        1.26278   0.145315    0.05
   3 │ β_2        0.333547  0.0577251    0.1
   4 │ β_3        1.07337   0.239592    0.15
   5 │ β_4        0.377414  0.0625201    0.2
   6 │ σ          2.33721   0.227075       2
```
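One diagnostic worth noting (my own conjecture, not something the results above establish): Turing's default ADVI uses a mean-field, i.e. axis-aligned, Gaussian approximation, and with an intercept plus uncentered predictors the posterior over the coefficients is strongly correlated, which such an approximation represents poorly. A minimal stdlib-only sketch of the effect (variable names are illustrative, not from the model above):

```julia
using Statistics, LinearAlgebra

# For a Gaussian likelihood, the posterior covariance of the regression
# coefficients is proportional to inv(X'X). With an intercept column and an
# uncentered predictor, intercept and slope are strongly correlated;
# centering the predictor removes almost all of that correlation.
N = 10_000
x = rand(1:10, N)                     # uncentered predictor, like x_2 above
X_raw = hcat(ones(N), x)              # design matrix with intercept
X_ctr = hcat(ones(N), x .- mean(x))   # same predictor, centered

# Implied intercept–slope correlation from inv(X'X)
coef_corr(X) = (C = inv(X'X); C[1, 2] / sqrt(C[1, 1] * C[2, 2]))

abs(coef_corr(X_raw))   # ≈ 0.89: strong correlation
abs(coef_corr(X_ctr))   # ≈ 0.0: essentially none
```

If this is the culprit, centering (or standardizing) the predictors before fitting and transforming the coefficients back afterwards may bring the ADVI estimates closer to the NUTS ones.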