ReactiveMP hierarchical linear regression

Hello,

I am playing around with ReactiveMP and wanted to try a hierarchical linear regression model.
I tried adapting the linear regression example in the docs, but I cannot figure out how to properly index the group level parameters (denoted c for class in the code). Appreciate any help on this, and would be happy to contribute a working example to the docs if wanted.

Thanks!

using ReactiveMP, GraphPPL, Rocket, Random, Plots, StableRNGs

rng = StableRNG(1234)
N = 250
c = Int32.(vcat(ones(div(N, 2)), 2 .* ones(div(N, 2))))
C = length(unique(c))
true_α = vcat(1.4, 1.7)
true_β = vcat(-0.2, -0.3)
true_σ = 1.5

xx = vcat(collect(1:N/2), collect(1:N/2))
x = xx .+ randn(rng, N)
y = true_α[c] .+ true_β[c] .* xx .+ true_σ.*randn(rng, N);
scatter(x, y, title = "Data", group=c)

@model [ default_factorisation = MeanField() ] function hlinreg(N, C, T=Float64)

### Priors
μα ~ NormalMeanVariance(0.0, 10.0)
τα ~ GammaShapeRate(1.0, 1.0)

μβ ~ NormalMeanVariance(0.0, 10.0)
τβ ~ GammaShapeRate(1.0, 1.0)

τ ~ GammaShapeRate(1.0, 1.0)

### Class level parameters
α = randomvar(C)
β = randomvar(C)
for c in 1:C
α[c] ~ NormalMeanPrecision(μα, τα)
β[c] ~ NormalMeanPrecision(μβ, τβ)
end

### Observations
x = datavar(T, N)
y = datavar(T, N)
c = datavar(Int32, N)
μ = randomvar(N)
for i in 1:N
cᵢ = c[i] # class index for this observation
μ[i] ~ α[cᵢ] + β[cᵢ] * x[i] # fails with 'invalid index: DataVariable(c_1)'
y[i] ~ NormalMeanPrecision(μ[i], τ)
end

return α, β, τ, x, y
end

results = inference(
model=Model(hlinreg, N, C),
data=(x=x, y=y, c=c),
initmessages=(
μα=NormalMeanVariance(0.0, 100.0),
μβ=NormalMeanVariance(0.0, 100.0),
),
initmarginals=(
τ=vague(GammaShapeRate),
),
returnvars=(α=KeepLast(), β=KeepLast(), τ=KeepLast()),
iterations=20
);

Hey! Thanks for your interest in the ReactiveMP.jl package. Your model needs a slight modification: as far as I can tell, there is no need to make c a datavar. That said, I do think your use case is totally valid. We will treat it as a bug and try to improve this functionality in the future. For now, you can slightly modify your model specification:

# I simply moved the c vector to the arguments
@model [ default_factorisation = MeanField() ] function hlinreg(N, C, c, T=Float64)

### Priors
μα ~ NormalMeanVariance(0.0, 10.0)
τα ~ GammaShapeRate(1.0, 1.0)

μβ ~ NormalMeanVariance(0.0, 10.0)
τβ ~ GammaShapeRate(1.0, 1.0)

τ ~ GammaShapeRate(1.0, 1.0)

### Class level parameters
α = randomvar(C)
β = randomvar(C)
for k in 1:C # loop variable renamed to k so it does not shadow the c argument
α[k] ~ NormalMeanPrecision(μα, τα)
β[k] ~ NormalMeanPrecision(μβ, τβ)
end

### Observations
x = datavar(T, N)
y = datavar(T, N)
μ = randomvar(N)
for i in 1:N
cᵢ = c[i] # class index for this observation
μ[i] ~ α[cᵢ] + β[cᵢ] * x[i] # works now: cᵢ is a plain integer, not a DataVariable
y[i] ~ NormalMeanPrecision(μ[i], τ)
end

return α, β, τ, x, y
end


And run it as:

results = inference(
model=Model(hlinreg, N, C, c),
data=(x=x, y=y,),
initmessages=(
α = NormalMeanVariance(0.0, 100.0),
),
initmarginals=(
τ = vague(GammaShapeRate),
μα = vague(NormalMeanPrecision),
μβ = vague(NormalMeanPrecision),
τα = vague(GammaShapeRate),
τβ = vague(GammaShapeRate),
),
returnvars=(α=KeepLast(), β=KeepLast(), τ=KeepLast()),
iterations=20
);


On my end it gives me:

mean_var.(results.posteriors[:α])
2-element Vector{Tuple{Float64, Float64}}:
(1.3188253339155878, 0.017369964039966924)
(1.6090116679773925, 0.017369964061618996)

mean_var.(results.posteriors[:β])
2-element Vector{Tuple{Float64, Float64}}:
(-0.20076853532497577, 3.3708045276105617e-6)
(-0.29805951728076235, 3.3766653505304176e-6)


Which is close enough to the real values.

Hi @bvdmitri, thanks for responding!

Your version with c as an argument works great. Is it correct to think “c is not directly observed so it should be an argument”?

If that is correct, how would that work in a filtering setting? My end goal with this is to try and write a filtering version of this model, but then c has to be a datavar, correct? However, if I understand your comment correctly this might not be supported currently?

Is it correct to think “c is not directly observed so it should be an argument”?

Yes, that is true. Initially I thought there must be a workaround, but there is another restriction on datavar: the observed value must not change the structure of the model. Unfortunately, this is exactly what happens with your c variable, since it attempts to change the underlying factor graph of the probabilistic model on the fly. This is not supported.
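
To make the point about graph structure concrete, here is a small sketch in plain Julia. It makes no ReactiveMP calls; the `edges` helper and the node names are hypothetical and only illustrate which class-level variables each observation would be wired to for a given c:

```julia
# Hypothetical illustration: list the (α, β, y) connections implied by a class vector c.
# Observation i is attached to α_k and β_k, where k = c[i].
edges(c) = [(Symbol("α_", k), Symbol("β_", k), Symbol("y_", i)) for (i, k) in enumerate(c)]

g1 = edges([1, 1, 2])  # y_2 is wired to α_1 and β_1
g2 = edges([1, 2, 2])  # same length, but y_2 is now wired to α_2 and β_2

g1 == g2               # false: a different c gives a different set of connections
```

Changing the values of x or y never alters this list, while changing c does, which is why c behaves differently from an ordinary observation.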

If that is correct, how would that work in a filtering setting? My end goal with this is to try and write a filtering version of this model, but then c has to be a datavar, correct?

Hm, that is an interesting point. As I mentioned, changing the graph structure on the fly is not allowed, so you have to create a new model for each of your filtering steps. This also seems valid to me, as you indeed have a different model whenever c changes. It should not be a big problem, since creating a model is fast in the ReactiveMP package.
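
A rough sketch of what "a new model per step" could look like is below. The `run_batches` helper is hypothetical and written in plain Julia: it only slices the data and hands each slice, together with its own piece of c, to a user-supplied `fit` function. It does not carry posteriors from one step into the priors of the next, which a real filter would need on top of this.

```julia
# Hypothetical helper: walk over the data in consecutive batches and call
# `fit(x_batch, y_batch, c_batch)` once per batch, collecting whatever it returns.
function run_batches(fit, x, y, c; batchsize = 50)
    results = []
    for start in 1:batchsize:length(y)
        idx = start:min(start + batchsize - 1, length(y))
        push!(results, fit(x[idx], y[idx], c[idx]))
    end
    return results
end

# With the model from this thread, `fit` could wrap the same call as above, e.g.
#   (xb, yb, cb) -> inference(model = Model(hlinreg, length(yb), C, cb),
#                             data = (x = xb, y = yb), ...)
# so that a fresh model is built from each batch's class vector.
```

Because the class vector is passed to the model constructor for every batch, each step gets a graph that matches its own c.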

I must admit it is an interesting use case, though, and we will think about whether it is possible to support it in the future. You might also take some extra inspiration from this question about multivariate linear regression. It models α as a multivariate Gaussian and thus does not require the c parameter.
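
One possible reading of that idea, sketched in plain Julia: stack the coefficients into a single vector θ = [α₁, α₂, β₁, β₂] and move the class information into a design matrix, so the model itself never indexes by class. The `design_matrix` helper and the toy data are assumptions for illustration, and the least-squares solve only checks the encoding; in the Bayesian version θ would get a multivariate Gaussian prior instead.

```julia
# Hypothetical sketch: encode class membership in a design matrix X so that
# μ = X * θ with θ = [α₁, ..., α_C, β₁, ..., β_C].
function design_matrix(x, c, C)
    N = length(x)
    X = zeros(N, 2C)
    for i in 1:N
        X[i, c[i]] = 1.0        # intercept column of class c[i]
        X[i, C + c[i]] = x[i]   # slope column of class c[i]
    end
    return X
end

x = collect(1.0:6.0)
c = [1, 1, 1, 2, 2, 2]
θ = [1.4, 1.7, -0.2, -0.3]    # true values from the original post
X = design_matrix(x, c, 2)
y = X * θ                     # noiseless observations, for illustration only
θ_hat = X \ y                 # ordinary least squares recovers θ up to rounding
```

Here the rows of X are plain data, so changing c changes numbers in X but not the shape of the graph.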

P.S. Our team is currently working on releasing a next-gen version of ReactiveMP, which we call RxInfer. RxInfer will include the whole ecosystem (ReactiveMP, GraphPPL and Rocket) and aims to provide user-friendly inference routines. You might be interested in it, since it includes a new inference procedure for filtering: here is an example. If everything goes according to plan, we will release RxInfer at the end of this month.

Thank you, the explanation of c changing the graph makes a lot of sense. After thinking some more about what I wrote, I realized the “not directly observed” argument did not make sense, since x is also kind of like an index variable. But, as you mentioned, x would not change the graph structure the way c would, so it can be a datavar.

Thanks for the link and the heads-up on RxInfer! Looking forward to the release.