Hi @svilupp!
Thanks for trying out ReactiveMP.jl. We somehow missed your question on Discourse; sorry for the long wait for a reply.
Please note that assigning a Gamma prior to the precision parameter of a Normal likelihood results in analytically intractable inference. In the context of message passing, this means that you cannot run belief propagation on your graph, which is what causes the first error: the prefix `m_` stands for a message, while `q_` stands for a marginal.
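To make the intractability concrete (a sketch of ours, not something from your post): during belief propagation, the message that the likelihood node sends towards `sigma` is obtained by integrating the likelihood against the incoming Normal message over the mean. For a single observation $y_i$ with an incoming message $\mathcal{N}(\mu \mid m, v)$ this gives

$$m_{\sigma}(\sigma) = \int \mathcal{N}(y_i \mid \mu, \sigma^{-1}) \, \mathcal{N}(\mu \mid m, v) \, \mathrm{d}\mu = \mathcal{N}(y_i \mid m, \sigma^{-1} + v),$$

which, viewed as a function of $\sigma$, does not belong to the Gamma family (or any standard exponential family), so there is no closed-form belief propagation update.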
To circumvent this issue, we need to resort to, for example, variational inference, i.e. variational message passing (VMP) in the context of factor graphs.
Now, there are a few ways to do that, but let me first point out that the only place where we need to factorize our model is around the `NormalMeanPrecision(a * time_index[i] + b, sigma)` factor. In particular, we need a mean-field factorization between the mean and the precision, i.e. between `a * time_index[i] + b` and `sigma`.
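Written out (with $t_i$ denoting `time_index[i]`), the factorization we are after is

$$q(a t_i + b, \sigma) \approx q(a t_i + b) \, q(\sigma),$$

i.e. the joint posterior over the mean and the precision of each likelihood factor is approximated by a product of independent marginals.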
The first and easiest way to do that is to use the `default_factorisation` option when specifying the model:
```julia
@model [ default_factorisation = MeanField() ] function linreg(n)
    # Priors over the slope, intercept and noise precision
    a     ~ NormalMeanVariance(0.0, 10.0)
    b     ~ NormalMeanVariance(0.0, 10.0)
    sigma ~ GammaShapeRate(1.0, 1.0)

    # Data placeholders
    time_index = datavar(Float64, n)
    y          = datavar(Float64, n)

    # Likelihood: Normal with a linear mean and shared precision
    for i in 1:n
        y[i] ~ NormalMeanPrecision(a * time_index[i] + b, sigma)
    end
end
```
In this way, `ReactiveMP` will use the mean-field factorization wherever possible, which in our case is just between `a * time_index[i] + b` and `sigma`.
The inference will then proceed smoothly:
```julia
results = inference(
    model         = Model(linreg, length(time_index)),
    data          = (y = y, time_index = time_index),
    initmessages  = (b = vague(NormalMeanVariance),),
    initmarginals = (sigma = vague(GammaShapeRate),),
    returnvars    = (a = KeepLast(), b = KeepLast(), sigma = KeepLast()),
    iterations    = 20,
    showprogress  = true,
)
```
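Once it finishes, you can inspect the posterior marginals. A minimal usage sketch (assuming the returned object exposes the marginals via its `posteriors` field, keyed by variable name):

```julia
# With KeepLast(), each entry holds the marginal from the final VMP iteration
println(mean(results.posteriors[:a]))      # posterior mean of the slope
println(mean(results.posteriors[:b]))      # posterior mean of the intercept
println(mean(results.posteriors[:sigma]))  # posterior mean of the noise precision
```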
A note about `initmarginals` and `initmessages`: non-rigorously, you need to provide initial marginals when you resort to VMP. Likewise, you need to provide initial messages when the graph contains loops (technically there is a little more to it). Here, we need to initialize a marginal for `sigma` and a message for either `a` or `b`.
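For example, initializing the message on `a` instead of `b` should work just as well:

```julia
initmessages = (a = vague(NormalMeanVariance),)
```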
Alternatively, you can create an auxiliary vector `aux` that represents the means of your likelihood, i.e.:
```julia
@model function linreg(n)
    a     ~ NormalMeanVariance(0.0, 10.0)
    b     ~ NormalMeanVariance(0.0, 10.0)
    sigma ~ GammaShapeRate(1.0, 1.0)

    # Auxiliary random variables holding the per-observation means
    aux = randomvar(n)

    time_index = datavar(Float64, n)
    y          = datavar(Float64, n)

    for i in 1:n
        aux[i] ~ a * time_index[i] + b
        y[i]   ~ NormalMeanPrecision(aux[i], sigma)
    end
end
```
In this case, you would need to provide specific constraints on your posterior factorization:
```julia
constraints = @constraints begin
    q(aux, sigma) = q(aux)q(sigma)
end
```
and feed them into the `inference` function (this enforces the mean-field factorization only between `aux` and `sigma`, leaving the rest of the graph intact):
```julia
results = inference(
    model         = Model(linreg, length(time_index)),
    data          = (y = y, time_index = time_index),
    constraints   = constraints,
    initmessages  = (b = vague(NormalMeanVariance),),
    initmarginals = (sigma = vague(GammaShapeRate),),
    returnvars    = (a = KeepLast(), b = KeepLast(), sigma = KeepLast()),
    iterations    = 20,
    showprogress  = true,
)
```
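If you want to reproduce this end-to-end, here is a hypothetical toy dataset (the names `time_index` and `y` match your snippet; the values are made up purely for illustration):

```julia
using Random

Random.seed!(123)
n          = 100
time_index = collect(1.0:n)                       # predictor
y          = 2.0 .* time_index .+ 1.0 .+ randn(n) # slope 2, intercept 1, unit-variance noise
```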
This is a good extension of the linear regression demo; please feel free to send a PR with a new demo!