This works (it runs, at least; I haven’t tested it thoroughly):
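using LinearAlgebra  # provides diagm

@model function linear_regression(n, m)
    # Priors on the coefficient vector a and the scalar intercept b
    a ~ MvNormalMeanCovariance(zeros(m), diagm(ones(m)))
    b ~ NormalMeanVariance(0.0, 1.0)
    # Repeat the scalar intercept into a length-n vector
    c ~ ones(n) * b
    # Each datavar holds a whole matrix/vector, not a collection of scalar datavars
    x = datavar(Matrix{Float64})
    y = datavar(Vector{Float64})
    z ~ x * a + c
    y ~ MvNormalMeanCovariance(z, tiny .* diagm(ones(n)))
    return a, y
end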
# n (number of observations) and m (number of regressors) must already be defined
results = inference(
    model = Model(linear_regression, n, m),
    data = (y = randn(n), x = randn(n, m)),
    returnvars = (a = KeepLast(),),
    iterations = 20
);
Guess I just need to get my head around the difference between a single datavar/randomvar that holds a vector or matrix and a vector or matrix of datavars/randomvars, and brush up on some very basic linear algebra!
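For the linear algebra part, here is a quick shape check in plain Julia, with no inference package involved. It only mirrors the deterministic lines of the model (`ones(n)*b` and `x*a+c`). The sizes and the intercept value are made up for illustration, and `z_rows` is just a helper name I introduced for the comparison.

```julia
using LinearAlgebra  # standard library; provides dot and diagm

n, m = 5, 2          # hypothetical sizes, chosen only for illustration
x = randn(n, m)      # design matrix, n-by-m
a = randn(m)         # coefficient vector, length m
b = 0.5              # scalar intercept (arbitrary value)

c = ones(n) * b      # the scalar intercept repeated into a length-n vector
z = x * a + c        # linear predictor: (n-by-m) * (m) + (n) gives length n

# The same predictor built one observation at a time
z_rows = [dot(x[i, :], a) + b for i in 1:n]

println("size(c) = ", size(c))
println("size(z) = ", size(z))
println("size(diagm(ones(n))) = ", size(diagm(ones(n))))
println("matrix form matches row-by-row form: ", z ≈ z_rows)
```

The matrix form is what a single vector-valued variable works with; the row-by-row form is closer to what a vector of scalar variables would express.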