Hello,
I am playing around with ReactiveMP and wanted to try a hierarchical linear regression model.
I tried adapting the linear regression example in the docs, but I cannot figure out how to properly index the group-level parameters by each observation's class label (denoted `c`, for class, in the code).
I would appreciate any help on this, and I would be happy to contribute a working example to the docs if wanted.
Thanks!
using ReactiveMP, GraphPPL, Rocket, Random, Plots, StableRNGs
# Reproducible synthetic data set: N observations split evenly across two classes.
rng = StableRNG(1234)
N = 250

# Class label of every observation: the first half is class 1, the second half class 2.
c = repeat(Int32[1, 2]; inner=div(N, 2))
C = length(unique(c))

# Ground-truth per-class intercepts and slopes, and the shared observation noise scale.
true_α = [1.4, 1.7]
true_β = [-0.2, -0.3]
true_σ = 1.5

# Noise-free regressor 1, …, N/2, repeated once per class (Float64 because N/2 is a float).
xx = repeat(collect(1:N/2); outer=2)
# Regressor that is plotted and handed to the model: xx perturbed by standard-normal noise.
x = xx .+ randn(rng, N)
# Responses are generated from the noise-free xx (not from x) plus scaled Gaussian noise.
# `$` keeps the randn call itself from being broadcast by `@.`.
y = @. true_α[c] + true_β[c] * xx + true_σ * $randn(rng, N);

scatter(x, y; title="Data", group=c)
# Hierarchical linear regression with class-specific intercepts `α[k]` and slopes `β[k]`
# that share common hyper-priors, and a single observation precision `τ`.
#
# Arguments:
# - `N`: number of observations.
# - `C`: number of classes.
# - `c`: length-`N` vector of integer class labels, each in `1:C`. It is a plain model
#   argument rather than a `datavar`: a `datavar` is only a placeholder while the graph is
#   being built, so it cannot be used to index `α`/`β` (that is what raised
#   'invalid index: DataVariable(c_1)'). As an ordinary vector, `c[i]` is a concrete
#   integer at model-construction time.
# - `T`: element type of the observed `x` and `y`.
#
# Returns the tuple `(α, β, τ, x, y)`.
@model [ default_factorisation = MeanField() ] function hlinreg(N, C, c, T=Float64)
    # Hyper-priors shared by all classes
    μα ~ NormalMeanVariance(0.0, 10.0)
    τα ~ GammaShapeRate(1.0, 1.0)
    μβ ~ NormalMeanVariance(0.0, 10.0)
    τβ ~ GammaShapeRate(1.0, 1.0)

    # Observation noise precision
    τ ~ GammaShapeRate(1.0, 1.0)

    # Class-level parameters; the loop variable is `k` so it does not shadow the labels `c`
    α = randomvar(C)
    β = randomvar(C)
    for k in 1:C
        α[k] ~ NormalMeanPrecision(μα, τα)
        β[k] ~ NormalMeanPrecision(μβ, τβ)
    end

    # Observations
    x = datavar(T, N)
    y = datavar(T, N)
    μ = randomvar(N)
    for i in 1:N
        cᵢ = c[i] # concrete class index for this observation
        μ[i] ~ α[cᵢ] + β[cᵢ] * x[i]
        y[i] ~ NormalMeanPrecision(μ[i], τ)
    end

    return α, β, τ, x, y
end

# The class labels are part of the model structure, so they are passed to `Model`
# and are no longer supplied through `data`.
# NOTE(review): the initial messages/marginals below are kept from the original post;
# the mean-field updates may additionally need initial marginals for τα, τβ, α and β —
# confirm against the ReactiveMP documentation.
results = inference(
    model=Model(hlinreg, N, C, c),
    data=(x=x, y=y),
    initmessages=(
        μα=NormalMeanVariance(0.0, 100.0),
        μβ=NormalMeanVariance(0.0, 100.0),
    ),
    initmarginals=(
        τ=vague(GammaShapeRate),
    ),
    returnvars=(α=KeepLast(), β=KeepLast(), τ=KeepLast()),
    iterations=20
);