[ANN]: RxInfer.jl 2.0 Julia package for automated Bayesian inference on a factor graph with reactive message passing

Hey @stanlazic!

Thank you for your interest in RxInfer. For your example to work, you need to slightly modify your model and input data as follows:

using RxInfer, LinearAlgebra   # dot comes from LinearAlgebra

# define model
@model function linear_regression(n)
    # I changed x to be a vector of vectors: each x[i] holds both predictors for observation i
    x = datavar(Vector{Float64}, n)
    y = datavar(Float64, n)

    # priors
    β0 ~ NormalMeanVariance(1.0, 100.0)
    # I changed β to be a single multivariate (2-dimensional) variable
    β ~ MvNormalMeanCovariance([ 0.0, 0.0 ], [ 1.0 0.0; 0.0 1.0 ])

    for i in 1:n
        y[i] ~ NormalMeanVariance(dot(x[i], β) + β0, 1.0)
    end
end

and the inference:

xdata = collect.(zip(xdata1, xdata2)) # collect into pairs: xdata[i] == [xdata1[i], xdata2[i]]

results = inference(
    model = linear_regression(length(xdata1)), 
    data  = (y = ydata, x = xdata), 
    initmessages = (β0 = NormalMeanVariance(0.0, 1.0), )
)
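For reference, the collect.(zip(...)) line turns two predictor columns into a vector of 2-element vectors, which is the shape that x = datavar(Vector{Float64}, n) expects. A small plain-Julia check with made-up values (xdata1 and xdata2 below are hypothetical stand-ins for your data):

```julia
# Hypothetical predictor columns; in your case these come from your own data
xdata1 = [0.5, 1.0, 1.5]
xdata2 = [2.0, 2.5, 3.0]

# zip pairs the i-th entries as tuples; broadcasting collect turns each tuple into a Vector{Float64}
xdata = collect.(zip(xdata1, xdata2))

println(typeof(xdata))   # Vector{Vector{Float64}}
println(xdata[2])        # [1.0, 2.5]
```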

and the results will be available:

mean_cov(results.posteriors[:β])
([1.1638631582255687, 2.1389045170632293], [0.24982928559123693 0.07513981828561604; 0.07513981828561604 0.21513388551908566])
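Since both priors and the likelihood are Gaussian with a known noise variance of 1.0, this model is conjugate, so a closed-form posterior can serve as a rough cross-check of the numbers above. Below is a plain-Julia sketch under a few assumptions: the xdata1, xdata2 and ydata values are hypothetical synthetic stand-ins for your data, and β0 and β are updated jointly. The RxInfer result may differ somewhat, because the graph has loops through β0 and β (hence the initmessages) and loopy message passing is approximate.

```julia
using LinearAlgebra

# Hypothetical synthetic data standing in for your xdata1, xdata2 and ydata
xdata1 = [0.1, 0.4, 0.9, 1.3, 1.8]
xdata2 = [1.0, 0.2, 0.7, 1.5, 0.3]
ydata  = [3.4, 2.1, 3.9, 5.6, 2.8]

# Stack intercept and predictors so that θ = [β0, β[1], β[2]]
X = [ones(length(xdata1)) xdata1 xdata2]

# Priors from the model: β0 ~ N(1.0, 100.0), β ~ MvN([0, 0], I)
μ0 = [1.0, 0.0, 0.0]
Σ0 = Diagonal([100.0, 1.0, 1.0])

# Known observation variance from the likelihood
σ2 = 1.0

# Conjugate Gaussian update: posterior precision, covariance and mean
Λ = inv(Σ0) + X' * X / σ2
Σ_post = inv(Symmetric(Λ))
μ_post = Σ_post * (Σ0 \ μ0 + X' * ydata / σ2)

# Marginal of β (drop the intercept row/column), comparable to mean_cov(results.posteriors[:β])
β_mean = μ_post[2:3]
β_cov  = Σ_post[2:3, 2:3]
```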

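The reason for your error is that, inside @model, dot is not defined between two vectors of RandomVariables, but rather between two multivariate RandomVariables. That is why x is now a Vector-valued datavar and β is a single MvNormalMeanCovariance variable instead of a collection of scalar ones.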
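To make this concrete outside of the model, here is the deterministic computation that dot(x[i], β) + β0 stands for in a single observation, with made-up numbers in place of the random variables (a plain-Julia sketch, not RxInfer code; x_i, β and β0 values are hypothetical). dot receives two whole vectors at once, which is why β has to be one multivariate variable.

```julia
using LinearAlgebra

# Hypothetical concrete values standing in for one row of data and one draw of the parameters
x_i = [0.5, 2.0]     # one element of xdata: both predictors for observation i
β   = [1.2, 2.1]     # a single draw of the 2-dimensional coefficient β
β0  = 0.3            # a single draw of the intercept

# The mean of y[i] in the model: one dot product between two length-2 vectors, plus the intercept
μ_i = dot(x_i, β) + β0

# The same quantity written out term by term
μ_manual = x_i[1] * β[1] + x_i[2] * β[2] + β0
```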

You may also be interested in the following example: ReactiveMP: How to run linear model with multiple predictors and an intercept - #6 by albertpod. This example uses an older version of RxInfer, which we called ReactiveMP. The example itself should still work if you remove [ default_factorisation = MeanField() ] from the @model specification and instead pass constraints = MeanField() to the inference() function, as follows:

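using RxInfer, LinearAlgebra   # diagm comes from LinearAlgebra

n = 250
m = 100

@model function multivariate_linear_regression(n, m)
    a ~ MvNormalMeanCovariance(zeros(m), diagm(ones(m)))   # one coefficient per predictor
    b ~ NormalMeanVariance(0.0, 1.0)                          # shared intercept
    W ~ InverseWishart(n + 2, diageye(n))                     # observation covariance
    c ~ ones(n) * b                                           # intercept repeated for all n observations
    x = datavar(Matrix{Float64})                              # n×m design matrix
    y = datavar(Vector{Float64})                              # n observations
    z ~ x * a + c                                             # linear predictor
    y ~ MvNormalMeanCovariance(z, W)
end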

results = inference(
    model = multivariate_linear_regression(n,m),
    data  = (y = randn(n), x = randn(n,m)),
    initmarginals = (W = InverseWishart(n+2, diageye(n)), ),
    returnvars   = (a = KeepLast(), b = KeepLast(), W = KeepLast()),
    free_energy = true,
    iterations   = 10,
    constraints = MeanField()
)
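The inference call above feeds pure noise (randn) for both y and x, so there is nothing for the posteriors to recover. If you want to check the model, one option (my assumption, not part of the original example) is to simulate data from the same generative structure with known parameters in plain Julia. The names a_true, b_true and σ are hypothetical, and W is taken to be σ²I for simplicity; a least-squares fit is included as a quick non-Bayesian sanity check.

```julia
using LinearAlgebra, Random

Random.seed!(42)            # arbitrary seed for reproducibility

n, m = 250, 100             # same sizes as in the example above

# Hypothetical "true" parameters, used only to generate synthetic data
a_true = randn(m)           # one coefficient per predictor
b_true = 0.5                # shared intercept
σ      = 0.1                # observation noise std (assumes W = σ^2 * I)

x = randn(n, m)             # design matrix, one row per observation
c = ones(n) * b_true        # intercept vector, mirroring c ~ ones(n)*b
z = x * a_true + c          # linear predictor, mirroring z ~ x*a + c
y = z + σ * randn(n)        # one draw from MvNormal(z, σ^2 * I)

# Least squares on [x ones(n)]: the first m entries estimate a_true, the last one b_true
coef = [x ones(n)] \ y
```

The simulated y and x can then be passed as data = (y = y, x = x) in the inference call above.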