Hey @stanlazic!
Thank you for your interest in RxInfer. For your example to work, you need to slightly modify your model and input data as follows:
# define model
@model function linear_regression(n)
    # I changed x to be of the vector type
    x = datavar(Vector{Float64}, n)
    y = datavar(Float64, n)
    # priors
    β0 ~ NormalMeanVariance(1.0, 100.0)
    # I changed beta to be of the multivariate type
    β ~ MvNormalMeanCovariance([ 0.0, 0.0 ], [ 1.0 0.0; 0.0 1.0 ])
    for i in 1:n
        y[i] ~ NormalMeanVariance(dot(x[i], β) + β0, 1.0)
    end
end
and the inference:
xdata = collect.(zip(xdata1, xdata2)) # collect into pairs
results = inference(
    model = linear_regression(length(xdata1)),
    data = (y = ydata, x = xdata),
    initmessages = (β0 = NormalMeanVariance(0.0, 1.0), )
)
and the results will be available:
mean_cov(results.posteriors[:β])
([1.1638631582255687, 2.1389045170632293], [0.24982928559123693 0.07513981828561604; 0.07513981828561604 0.21513388551908566])
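To read the output above more easily, here is a minimal sketch in plain Julia (no RxInfer calls) that turns the posterior covariance of β into per-coefficient standard deviations, a correlation, and rough 95% intervals. The numbers are copied from the output above; the variable names μ, Σ, σ, ρ, and intervals are my own, and the intervals assume the Gaussian form of the posterior.

```julia
using LinearAlgebra  # provides diag

# Posterior mean and covariance of β, copied from the output above
μ = [1.1638631582255687, 2.1389045170632293]
Σ = [0.24982928559123693 0.07513981828561604;
     0.07513981828561604 0.21513388551908566]

# Marginal posterior standard deviation of each coefficient
σ = sqrt.(diag(Σ))

# Posterior correlation between the two coefficients
ρ = Σ[1, 2] / (σ[1] * σ[2])

# Rough 95% intervals under the Gaussian posterior (mean ± 1.96 sd)
intervals = [(μ[k] - 1.96 * σ[k], μ[k] + 1.96 * σ[k]) for k in eachindex(μ)]
```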
The reason for your error is that dot is not defined between two vectors of RandomVariables, but rather between two multivariate RandomVariables.
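To make the shapes concrete, here is a small sketch in plain Julia, outside of any @model block. It shows what collect.(zip(...)) produces and why dot(x[i], β) is a scalar once each x[i] is a 2-element vector. The data values and the coefficients β and β0 below are hypothetical, chosen only for illustration.

```julia
using LinearAlgebra  # provides dot

# Hypothetical predictor columns, standing in for xdata1 and xdata2
xdata1 = [1.0, 2.0, 3.0]
xdata2 = [10.0, 20.0, 30.0]

# Same reshaping as in the inference call: one 2-element vector per observation
xdata = collect.(zip(xdata1, xdata2))

# Hypothetical coefficient values, only to show the shapes involved
β  = [2.0, 3.0]
β0 = 1.0

# Mean of each observation: a scalar, since both arguments of dot are length-2 vectors
means = [dot(x, β) + β0 for x in xdata]
```

Here xdata[1] is the vector [1.0, 10.0], which pairs the first entries of the two predictor columns.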
You may also be interested in the following example: ReactiveMP: How to run linear model with multiple predictors and an intercept - #6 by albertpod. That example uses an older version of RxInfer, which we called ReactiveMP. It should still work if you remove [ default_factorisation = MeanField() ] from the @model specification and pass it to the inference() function instead, as constraints = MeanField():
using RxInfer, LinearAlgebra # diagm comes from LinearAlgebra

n = 250
m = 100
@model function multivariate_linear_regression(n,m)
    a ~ MvNormalMeanCovariance(zeros(m), diagm(ones(m)))
    b ~ NormalMeanVariance(0.0,1.0)
    W ~ InverseWishart(n+2, diageye(n))
    c ~ ones(n)*b
    x = datavar(Matrix{Float64})
    y = datavar(Vector{Float64})
    z ~ x*a+c
    y ~ MvNormalMeanCovariance(z, W)
end
results = inference(
    model = multivariate_linear_regression(n,m),
    data = (y = randn(n), x = randn(n,m)),
    initmarginals = (W = InverseWishart(n+2, diageye(n)), ),
    returnvars = (a = KeepLast(), b = KeepLast(), W = KeepLast()),
    free_energy = true,
    iterations = 10,
    constraints = MeanField()
)
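For reference, the model specification above can be read as the following generative model. This is only a transcription of the code into math; the notation is my own ($I_m$ for the $m \times m$ identity, $\mathbf{1}_n$ for a vector of $n$ ones, $\mathcal{W}^{-1}$ for the inverse Wishart distribution with its degrees of freedom and scale matrix), and $x$ is the observed $n \times m$ design matrix.

```latex
\begin{aligned}
a &\sim \mathcal{N}(0,\, I_m), \qquad
b \sim \mathcal{N}(0,\, 1), \qquad
W \sim \mathcal{W}^{-1}(n + 2,\, I_n), \\
c &= \mathbf{1}_n\, b, \qquad
z = x a + c, \\
y &\sim \mathcal{N}(z,\, W).
\end{aligned}
```

With constraints = MeanField(), the approximate posterior over a, b, and W is taken to factorise across these variables, which is why an initial marginal for W is supplied.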