Simple linear regression with ReactiveMP - RuleMethodError?

Hi @svilupp!

Thanks for trying out ReactiveMP.jl. We somehow missed your question on Discourse; sorry for the late reply.

Please note that assigning a Gamma prior to the precision parameter of a Normal likelihood results in analytically intractable inference. In the context of message passing, it means that you can’t run belief propagation on your graph, which is what the first error reports. In the error message, the prefix m_ stands for a message, while q_ stands for a marginal.

To circumvent this issue, we need to resort to an approximate method, for example variational inference, which in the context of factor graphs takes the form of variational message passing (VMP).

Now, there are a few ways to do that, but let me first point out that the only place where we need to factorize our model is around the NormalMeanPrecision(a * time_index[i] + b, sigma) factor. In particular, we need to use a mean-field factorization between the mean and the precision, i.e. between a * time_index[i] + b and sigma.
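
Written out, the model and the required factorization look roughly as follows. This is only a sketch in standard notation: the symbols t_i (for time_index[i]) and \mu_i (for the mean of the likelihood) are introduced here for illustration and do not appear in the code.

```latex
\begin{aligned}
a &\sim \mathcal{N}(0, 10), \quad b \sim \mathcal{N}(0, 10), \quad \sigma \sim \mathrm{Gamma}(1, 1), \\
\mu_i &= a\, t_i + b, \\
y_i &\sim \mathcal{N}(\mu_i, \sigma^{-1}), \quad i = 1, \dots, n, \\
q(\mu_i, \sigma) &= q(\mu_i)\, q(\sigma).
\end{aligned}
```

Here the second argument of each Normal prior is a variance, the Gamma is parameterized by shape and rate, and sigma plays the role of a precision, so the likelihood variance is \sigma^{-1}. The last line is the mean-field assumption that replaces the intractable joint posterior over mean and precision.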

The first and easiest way to do that is to use the default_factorisation option when specifying the model:

@model [ default_factorisation = MeanField() ] function linreg(n)
    a ~ NormalMeanVariance(0.0, 10.0)
    b ~ NormalMeanVariance(0.0, 10.0)
    sigma  ~ GammaShapeRate(1.0, 1.0)
    
    time_index = datavar(Float64, n)
    y = datavar(Float64, n)

    for i in 1:n
        y[i]   ~ NormalMeanPrecision(a * time_index[i] + b, sigma)
    end
end

In this way, ReactiveMP will use mean-field wherever possible, which in our case is just between a * time_index[i] + b and sigma.

Then the inference will follow smoothly:

results = inference(
    model = Model(linreg, length(time_index)),
    data  = (y = y, time_index = time_index),
    initmessages  = (b = vague(NormalMeanVariance),),
    initmarginals = (sigma = vague(GammaShapeRate),),
    returnvars    = (a = KeepLast(), b = KeepLast(), sigma = KeepLast()),
    iterations   = 20,
    showprogress = true,
)
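
The call above assumes that y and time_index already exist as Float64 vectors of equal length. If you want to try it without your own data, a minimal sketch for generating a synthetic dataset could look like this. The values a_true, b_true and noise_sd are hypothetical and chosen only for illustration; the least-squares fit at the end is just a rough baseline to compare the posterior means of a and b against, not part of ReactiveMP.

```julia
using Random
using LinearAlgebra

# Hypothetical ground-truth values, chosen only for illustration.
rng      = MersenneTwister(42)
n        = 100
a_true   = 0.5
b_true   = 2.0
noise_sd = 1.0  # the corresponding true precision is 1 / noise_sd^2

# Regressor and noisy observations, both Vector{Float64} of length n.
time_index = collect(1.0:n)
y = a_true .* time_index .+ b_true .+ noise_sd .* randn(rng, n)

# Ordinary least-squares baseline for the slope and the intercept.
X = hcat(time_index, ones(n))
a_ols, b_ols = X \ y
```

With this many points, the posterior means of a and b should end up close to a_ols and b_ols.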

A note on initmarginals and initmessages: roughly speaking, you need to provide initial marginals when you resort to VMP, and initial messages when the graph contains loops (technically there is a little more to it). Here, we need to initialize the marginal for sigma and a message for either a or b.

Alternatively, you can create an auxiliary vector aux that will represent the means of your likelihood, i.e.:

@model function linreg(n)
    a ~ NormalMeanVariance(0.0, 10.0)
    b ~ NormalMeanVariance(0.0, 10.0)
    sigma  ~ GammaShapeRate(1.0, 1.0)
    
    aux = randomvar(n)
    time_index = datavar(Float64, n)
    y = datavar(Float64, n)

    for i in 1:n
        aux[i] ~ a * time_index[i] + b
        y[i]   ~ NormalMeanPrecision(aux[i], sigma)
    end
end

In this case, you would need to provide specific constraints on your posterior factorization:

constraints = @constraints begin 
    q(aux, sigma) = q(aux)q(sigma)
end

and pass them to the inference function:

results = inference(
    model = Model(linreg, length(time_index)),
    data  = (y = y, time_index = time_index),
    constraints = constraints,
    initmessages  = (b = vague(NormalMeanVariance),),
    initmarginals = (sigma = vague(GammaShapeRate),),
    returnvars    = (a = KeepLast(), b = KeepLast(), sigma = KeepLast()),
    iterations   = 20,
    showprogress = true,
)

This would be a good extension of the Linear regression demo; please feel free to send a PR with a new demo!
