[ANN]: RxInfer.jl 2.0 Julia package for automated Bayesian inference on a factor graph with reactive message passing


We are excited to announce the release of RxInfer 2.0, a Julia package for fast and scalable Bayesian inference in probabilistic models. RxInfer unites the previously released packages ReactiveMP, GraphPPL, and Rocket in a single user-friendly ecosystem that aims to execute efficient Bayesian inference in real time. GraphPPL is a package for user-friendly specification of models and inference constraints. ReactiveMP provides a high-performance, reactive message passing-based Bayesian inference engine for the specified model. The reactive extensions package Rocket supports running Bayesian inference on streaming data sets in real time.

In general, RxInfer executes inference faster than sampling-based Bayesian inference methods since it exploits local conjugate pairings in the graph representation of a probabilistic model. For instance, RxInfer significantly outperforms Julia’s state-of-the-art inference package Turing on a set of conjugate state-space models. Turing, on the other hand, may be able to execute inference for a less constrained set of models. However, the ReactiveMP inference engine has evolved significantly over the last year and is capable of running inference for many non-conjugate models as well. It is also possible to include non-linear dependencies between variables and use various built-in approximation methods, such as Statistical Linearization, Unscented transform and Conjugate-computation Variational Inference (CVI).

In short, RxInfer provides:

  • A convenient user-friendly language for specification of model and inference constraints:

    • The @model macro closely resembles the same macro in other popular PPLs, like Turing.
    • The @constraints macro specifies constraints on the variational family of distributions that should be used during the optimization procedure.
    • The @meta macro specifies approximation methods for different parts of the probabilistic model graph
  • Inference for static large datasets with millions of observations on a standard (win/mac/linux) laptop.

  • Real-time inference for streaming data sets with potentially an unlimited number of observations

  • A unified framework for different Bayesian inference algorithms, including:

    • Sum-Product
    • Mean-Field & Structured Variational Inference
    • Expectation Maximization
    • Expectation Propagation
  • Hybrid Bayesian inference algorithms across different parts of the probabilistic model’s graph

    • It is possible to use, for example, the Sum-Product rule in one part of the model and variational message passing (VMP) in another part
  • RxInfer is easily extendable with custom novel factor nodes and message update equations

  • Supports automatic differentiation of the entire inference procedure for tuning global parameters

  • Bethe Free Energy (approximation to the Bayesian model evidence) evaluation

  • And much more!

Check the Getting started section to get started. The documentation has many examples and tutorials.



This seems like a very cool package, congratulations on the work!

I have a question regarding the model on the front page:

@model function SSM(n, x0, A, B, Q, P)

    # x is a sequence of hidden states
    x = randomvar(n)

    # y is a sequence of clamped observations
    y = datavar(Vector{Float64}, n)

    # `~` expression creates a probabilistic relationship
    # between random variables
    x_prior ~ MvNormal(μ = mean(x0), Σ = cov(x0))
    x_prev = x_prior

    # Build the state-space model
    for i in 1:n
        x[i] ~ MvNormal(μ = A * x_prev, Σ = Q)
        y[i] ~ MvNormal(μ = B * x[i], Σ = P)
        x_prev = x[i]
    end
end

In particular, I’m wondering how to interpret covariance-properties of the variables defined by the line

x[i] ~ MvNormal(μ = A * x_prev, Σ = Q) 

Here, x_prev is a random variable with some distribution attached to it, so the mean of x[i] will be a random variable. How does this relate to the Q that is specified as the covariance of x[i]? Is covariance Q added to the covariance of A * x_prev in the sense that A x_{prev} + w where w \sim N(0, Q)? This would correspond to how I normally think of, e.g., a Kalman filter operating, but it visually conflicts a bit with stating that x_i \sim N(A x_{prev}, Q). Maybe this line has a different interpretation in this language?


Hi @baggepinnen!

Thanks for the question. If w[i] ~ N(0, Q), then, conditioned on x_prev, the sum x[i] = A*x_prev + w[i] is Gaussian with mean A*x_prev (because w[i] has zero mean) and covariance Q, so it can be rewritten as x[i] ~ MvNormal(μ = A*x_prev, Σ = Q).
Alternatively, you can specify the model in a Kalman-filter-like way:

@model function SSM(n, x0, A, B, Q, P)

    # w is a sequence of process-noise variables
    w = randomvar(n)
    # x is a sequence of hidden states
    x = randomvar(n)

    # y is a sequence of clamped observations
    y = datavar(Vector{Float64}, n)

    # `~` expression creates a probabilistic relationship
    # between random variables
    x_prior ~ MvNormal(μ = mean(x0), Σ = cov(x0))
    x_prev = x_prior

    # Build the state-space model
    for i in 1:n
        w[i] ~ MvNormal(μ = zeros(length(mean(x0))), Σ = Q)
        x[i] ~ A * x_prev + w[i]
        y[i] ~ MvNormal(μ = B * x[i], Σ = P)
        x_prev = x[i]
    end
end

Good question, @baggepinnen! I see that @albertpod has already replied to you; I can just add that you are right, and basically:

x[i] ~ N(A * x[i - 1], Q)

is a shorthand for

tmp ~ A * x[i - 1]
x[i] ~ N(tmp, Q)

which (thanks to the Gaussian distribution statistical properties) can be rewritten exactly like

tmp ~ A * x[i - 1]
w ~ N(0, Q)
x[i] ~ tmp + w

Regarding your original question about the covariance matrix of x[i]: if we assume that the covariance matrix of x[i - 1] is \Sigma_{i - 1}, then \Sigma_{i} = A \Sigma_{i - 1} A^T + Q (following the equations in Sascha Korl, A Factor Graph Approach to Signal Modelling, System Identification and Filtering).
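A quick numerical sanity check of this recursion, written in plain Julia with only the standard library (this is not RxInfer code): draw samples of x[i - 1] and of the noise w ~ N(0, Q), push them through x[i] = A * x[i - 1] + w, and compare the sample covariance with A \Sigma_{i - 1} A^T + Q. The matrices below are arbitrary illustrative values, not taken from the thread.

```julia
using LinearAlgebra, Random, Statistics

# Illustrative (hypothetical) 2-D model parameters
A = [1.0 0.1; 0.0 0.9]
Q = [0.2 0.0; 0.0 0.1]
m_prev = [1.0, -1.0]            # mean of x[i - 1]
Σ_prev = [1.0 0.3; 0.3 0.5]     # covariance of x[i - 1]

rng = MersenneTwister(42)
N = 200_000

# Sample x[i - 1] ~ N(m_prev, Σ_prev) and w ~ N(0, Q) via Cholesky factors
x_prev = m_prev .+ cholesky(Σ_prev).L * randn(rng, 2, N)
w = cholesky(Q).L * randn(rng, 2, N)

# Propagate every sample through x[i] = A * x[i - 1] + w
x = A * x_prev + w

Σ_mc = cov(x; dims = 2)          # Monte Carlo estimate (columns are samples)
Σ_exact = A * Σ_prev * A' + Q    # closed-form covariance propagation

println(round.(Σ_mc; digits = 3))
println(round.(Σ_exact; digits = 3))
```

The two printed matrices should agree up to Monte Carlo error, which shrinks as N grows.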

Thanks for your comments guys!

What’s confusing me is probably that

x[i] ~ MvNormal(μ = A * x_prev, Σ = Q) 

looks like x[i] is sampled from a normal distribution with covariance Q, but in fact, the covariance of x[i] is \Sigma_i from the Lyapunov iteration \Sigma_i = A \Sigma_{i-1} A^T + Q.
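The distinction can be made concrete: Q is the covariance of x[i] given x_prev, while the marginal covariance \Sigma_i evolves through the Lyapunov iteration and, for stable dynamics, settles at a fixed point solving \Sigma = A \Sigma A^T + Q. A plain-Julia sketch (standard library only; the function name `lyapunov_iterates` and the values of A and Q are hypothetical choices for illustration):

```julia
using LinearAlgebra

# Iterate Σ_i = A Σ_{i-1} A' + Q for n steps and return every Σ_i
function lyapunov_iterates(A, Q, Σ0, n)
    Σs = Vector{Matrix{Float64}}(undef, n)
    Σ = Σ0
    for i in 1:n
        Σ = A * Σ * A' + Q
        Σs[i] = Σ
    end
    return Σs
end

A = [0.9 0.2; 0.0 0.7]    # hypothetical stable dynamics (eigenvalues 0.9 and 0.7)
Q = [0.1 0.0; 0.0 0.1]
Σ0 = zeros(2, 2)          # perfectly known initial state

Σs = lyapunov_iterates(A, Q, Σ0, 200)
Σ_inf = Σs[end]

# At steady state, Σ solves the discrete Lyapunov equation Σ = A Σ A' + Q
residual = norm(A * Σ_inf * A' + Q - Σ_inf)
println("first covariance: ", Σs[1])   # equals Q, because Σ0 = 0
println("steady-state residual: ", residual)
```

Only the very first step has covariance exactly Q here (because the initial state is known); after that the marginal covariance grows toward the steady-state solution.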

In either case, I now understand how your modeling language works. Keep up the nice work! :slight_smile:


Looks great. Can you share more about where it should excel and where sampling-based methods should be preferred? Are there any guidelines for that?

Hi @RoyiAvital!

Great question! Unfortunately, there are no written guidelines for that at the moment.

Generally, if you can exploit conjugacy (or conditional conjugacy) in your model, or if there are pre-defined nodes that serve your needs (for example, the Gamma Mixture node or the GCV node), then RxInfer.jl will excel in terms of both accuracy and speed!

Most of these nodes representing distributions can be easily combined, e.g., a Gamma Mixture can serve as a prior for the precision parameter of a Normal distribution.
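To illustrate what "exploiting conjugacy" buys, here is the kind of closed-form update that becomes available when a plain Gamma prior (not the Gamma Mixture node) is placed on the precision of a Normal likelihood with known mean. This is a plain-Julia sketch, not RxInfer code; the shape/rate parameterization and the function name `gamma_precision_posterior` are assumptions made for illustration.

```julia
# Conjugate update for the precision τ of a Normal likelihood with known mean μ:
#   x_k ~ Normal(μ, 1/τ),  τ ~ Gamma(shape = a, rate = b)
# Posterior: τ | x ~ Gamma(shape = a + n/2, rate = b + Σ_k (x_k - μ)^2 / 2)
function gamma_precision_posterior(x::AbstractVector{<:Real}, μ::Real, a::Real, b::Real)
    n = length(x)
    a_post = a + n / 2
    b_post = b + sum(abs2, x .- μ) / 2
    return a_post, b_post
end

a_post, b_post = gamma_precision_posterior([1.0, 2.0, 3.0], 2.0, 2.0, 1.0)
println((a_post, b_post))                          # (3.5, 2.0)
println("posterior mean of τ: ", a_post / b_post)  # 1.75
```

No sampling or numerical optimization is involved: the update is two additions, which is why conjugate parts of a model are cheap.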

However, we recognize that sometimes you may want to put a Beta prior on the mean of a Normal distribution. In this case there is no analytical solution, so we advise using the CVI algorithm, which relies on sampling. Note that the CVI approximation is employed locally; the rest of the inference will still try to use analytical updates (if possible). Advanced users can implement a more efficient local update for this scenario.

In a situation where your model consists only of “non-standard” dependencies, which can be solved only through sampling, you can, in principle, run inference solely with CVI or Unscented approximations. Note, however, that CVI will not be fast.

We will add more examples highlighting the advantage of the RxInfer.jl approach.

If you have more concrete questions, we will be happy to answer them.


So if I get you right, there are 3 cases:

  1. Model based on Conjugacy
    In this case RxInfer.jl will be (Almost?) as efficient as using the explicit formulas for the conjugate priors. But if one wants, probably one can do it manually.
  2. Model without Any Conjugacy
    In this case sampling based methods are the way to go.
  3. Mixed Model
    In this case RxInfer.jl also should have advantage as it can use the local cases where conjugacy can be exploited.

In case (3) the question is how efficient the sampling mechanisms in RxInfer.jl are. If they are as efficient as the competition's, RxInfer.jl will be uniformly the best choice in this case. If not, it will depend on the specific model and the balance between its conjugate and non-conjugate parts.

It’s pretty hard to provide a general recipe on what to do even within these 3 cases.

  1. Model based on Conjugacy:

Even if your model exhibits conjugacy, inference might still be problematic, so manually deriving the update rules is not always as easy as it may seem. Under the hood, RxInfer.jl casts your probabilistic model into a computational graph, a Forney-style Factor Graph (FFG), which has a one-to-one correspondence with your probabilistic model. If there are loops in your graph, you would typically need an update schedule, even in conjugate cases.
The unique feature of RxInfer is that it is based on a reactive programming framework: there is no fixed inference algorithm. The model is broken into a network of factors (nodes), and each node independently processes incoming data streams.
It lets you iterate fast through your model and inference and frees you from manual derivations of the inference algorithm.
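For a sense of what a manual derivation looks like in the simplest conjugate case: for the linear Gaussian state-space model discussed earlier in this thread, forward sum-product messages on its chain-structured graph reduce to the classical Kalman filter. The sketch below is plain Julia (LinearAlgebra only), not RxInfer code; the function name `kalman_filter` and its argument order are my own choices, and the noise names follow the thread's SSM(n, x0, A, B, Q, P).

```julia
using LinearAlgebra

# Forward (filtering) pass for the state-space model from earlier in the thread:
#   x[i] = A * x[i-1] + w[i],  w[i] ~ N(0, Q)
#   y[i] = B * x[i]   + v[i],  v[i] ~ N(0, P)
# with prior x_prior ~ N(m0, V0). Returns the filtered means and covariances.
function kalman_filter(ys, m0, V0, A, B, Q, P)
    m, V = m0, V0
    means = Vector{Vector{Float64}}()
    covs = Vector{Matrix{Float64}}()
    for y in ys
        # Predict: push the previous belief through the transition
        m = A * m
        V = A * V * A' + Q
        # Update: condition on the observation y[i]
        S = B * V * B' + P          # predicted observation covariance
        K = V * B' / S              # Kalman gain
        m = m + K * (y - B * m)
        V = (I - K * B) * V
        push!(means, m)
        push!(covs, V)
    end
    return means, covs
end

# Scalar check: A = B = 1, Q = 0, P = 1, prior N(0, 1). With Q = 0 the state is
# constant, so after n observations the posterior mean is sum(y) / (n + 1).
one1 = fill(1.0, 1, 1)
ys = [[2.0], [4.0]]
means, covs = kalman_filter(ys, [0.0], one1, one1, one1, zeros(1, 1), one1)
println(means[end])   # ≈ [2.0], i.e. (2 + 4) / 3
```

Even this simple, loop-free case needs a fixed predict/update order; loops in the graph are where the schedule becomes hard to derive by hand.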

  2. Model without Any Conjugacy:

Generally, yes; however, if you can get by with the Unscented transform or linearization and you crave speed, we advise using these techniques. Despite its name, the CVI approximation can handle many non-conjugate scenarios, but this method still needs more examples. We plan to add more advanced and faster approximations to deal with different intractabilities.
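As background on why the Unscented transform is fast: it replaces sampling with 2n + 1 deterministically chosen sigma points. The sketch below is a generic textbook version in plain Julia, not RxInfer's implementation; the function name and the default tuning parameters (α = 1, β = 2, κ = 0) are my assumptions.

```julia
using LinearAlgebra

# Unscented transform: approximate the mean and covariance of f(x) for x ~ N(μ, Σ)
# from 2n + 1 deterministically chosen sigma points.
function unscented_transform(f, μ::Vector{Float64}, Σ::Matrix{Float64};
                             α = 1.0, β = 2.0, κ = 0.0)
    n = length(μ)
    λ = α^2 * (n + κ) - n
    L = cholesky(Symmetric((n + λ) * Σ)).L    # matrix square root of (n + λ)Σ
    points = [μ]
    for j in 1:n
        push!(points, μ + L[:, j])
        push!(points, μ - L[:, j])
    end
    # Weights for the mean (wm) and the covariance (wc)
    wm = fill(1 / (2 * (n + λ)), 2n + 1)
    wc = copy(wm)
    wm[1] = λ / (n + λ)
    wc[1] = λ / (n + λ) + (1 - α^2 + β)
    ys = [f(p) for p in points]
    m = sum(wm[k] * ys[k] for k in eachindex(ys))
    C = sum(wc[k] * (ys[k] - m) * (ys[k] - m)' for k in eachindex(ys))
    return m, C
end

# For a linear map the transform is exact: it recovers A μ + b and A Σ A'.
A = [1.0 2.0; 0.0 1.0]
b = [0.5, -0.5]
μ = [1.0, 2.0]
Σ = [1.0 0.2; 0.2 0.5]
m_lin, C_lin = unscented_transform(x -> A * x + b, μ, Σ)

# A nonlinear example: polar-to-Cartesian conversion of an uncertain point
m_nl, C_nl = unscented_transform(x -> [x[1] * cos(x[2]), x[1] * sin(x[2])],
                                 [1.0, 0.5], [0.01 0.0; 0.0 0.04])
```

The cost is a handful of function evaluations per node, which is why it is suggested above when speed matters more than exactness for strongly nonlinear dependencies.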

  3. Mixed Model

Let me put it plainly: RxInfer.jl does not have a purely sampling-based inference mechanism like those in Turing.jl (e.g. HMC, SMC). We plan to add this feature later, probably by integrating Turing.jl into our package. However, RxInfer.jl does have sampling elements in the CVI method, which is used for the Delta node.


This looks great! I’m trying to run a multiple regression analysis based on the linear regression example but can’t figure out how to use matrix multiplication. This works:

using RxInfer

# generate data
xdata1 = randn(10)
xdata2 = randn(10)
ydata = 2 .* xdata1 .+ 3 .* xdata2 .+ randn(10)

# define model
@model function linear_regression(n)
    x = datavar(Float64, (n, 2))
    y = datavar(Float64, n)

    # priors
    β0 ~ NormalMeanVariance(0.0, 100.0)
    β = randomvar(2)
    β .~ NormalMeanVariance(0.0, 1.0)
    
    for i in 1:n       
        y[i] ~ NormalMeanVariance(β[1] * x[i, 1] + β[2] * x[i, 2] + β0, 1.0)
        # I was expecting something like this to work:
        #  y[i] ~ NormalMeanVariance(dot(x[i, :], β) + β0, 1.0)
    end
end

# run inference
results = inference(
    model = linear_regression(length(xdata1)), 
    data  = (y = ydata, x = hcat(xdata1, xdata2)), 
    initmessages = (β = NormalMeanVariance(0.0, 1.0),
                    β0 = NormalMeanVariance(0.0, 1.0))
)

When using matrix multiplication (as in the commented-out line) I get MethodError: no method matching make_node(::typeof(dot), and I get similar error messages when I try transpose() or '.

Thanks for the contribution, this looks great! I have two questions:

  1. Is there a public non-macro interface? If not, can we get one? :slight_smile:
  2. Would you consider renaming the package to ReactiveInfer? The “Rx” is a bit cryptic, and “Rx” is a common abbreviation for “prescription” (as in a medical prescription).

Hey @stanlazic !

Thank you for the interest in RxInfer. For your example to work you need to slightly modify your model and input data in the following way:

# define model
@model function linear_regression(n)
    # I changed x to be of the vector type
    x = datavar(Vector{Float64}, n)
    y = datavar(Float64, n)

    # priors
    β0 ~ NormalMeanVariance(1.0, 100.0)
    # I changed beta to be of the multivariate type
    β ~ MvNormalMeanCovariance([ 0.0, 0.0 ], [ 1.0 0.0; 0.0 1.0 ])
    
    for i in 1:n       
        y[i] ~ NormalMeanVariance(dot(x[i], β) + β0, 1.0)
    end
end

and the inference:

xdata = collect.(zip(xdata1, xdata2)) # one 2-element Vector{Float64} per observation

results = inference(
    model = linear_regression(length(xdata1)), 
    data  = (y = ydata, x = xdata), 
    initmessages = (β0 = NormalMeanVariance(0.0, 1.0), )
)

and the results will be available:

mean_cov(results.posteriors[:β])
([1.1638631582255687, 2.1389045170632293], [0.24982928559123693 0.07513981828561604; 0.07513981828561604 0.21513388551908566])
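The collect.(zip(xdata1, xdata2)) line in the reply turns two column vectors into a vector of 2-element vectors, one per observation, which matches datavar(Vector{Float64}, n). If the data already sits in an n×2 matrix, the same layout can be built row by row. A small plain-Julia check (the short xdata1/xdata2 values are hypothetical stand-ins for the randn(10) data):

```julia
# Hypothetical stand-ins for the randn(10) vectors used earlier in the thread
xdata1 = [0.1, -0.4, 1.2]
xdata2 = [2.0, 0.3, -0.7]

X = hcat(xdata1, xdata2)                           # n×2 matrix, as in the original attempt
xdata_zip  = collect.(zip(xdata1, xdata2))         # layout used in the reply
xdata_rows = [collect(row) for row in eachrow(X)]  # same layout, built from the matrix

println(xdata_zip == xdata_rows)  # true
println(length(xdata_rows), " observations of length ", length(first(xdata_rows)))
```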

The reason for your error is that dot is not defined for arrays of univariate random variables (such as x[i, :] and β in your original model), but rather between multivariate variables, such as a Vector{Float64} datavar and a multivariate β.
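Because this regression has a fixed noise variance of 1.0 and Gaussian priors, the exact joint posterior over (β0, β) is also available in closed form, which can be handy as a reference when checking inference results. This is a plain-Julia sketch (LinearAlgebra only), not RxInfer code; the function name `blr_posterior` is hypothetical, and since RxInfer runs iterative message passing on a graph with loops here, its numbers are not guaranteed to match these exactly.

```julia
using LinearAlgebra

# Exact Gaussian posterior for y = X * w + ε with ε ~ N(0, σ2 * I) and prior w ~ N(μ0, Σ0)
function blr_posterior(X::AbstractMatrix, y::AbstractVector,
                       μ0::AbstractVector, Σ0::AbstractMatrix, σ2::Real)
    Λ0 = inv(Σ0)                      # prior precision
    Λ = Λ0 + X' * X / σ2              # posterior precision
    Σ = Matrix(inv(Symmetric(Λ)))     # posterior covariance
    μ = Σ * (Λ0 * μ0 + X' * y / σ2)   # posterior mean
    return μ, Σ
end

# Same setup as in the thread: y = 2 x1 + 3 x2 + noise with unit noise variance
xdata1 = randn(10)
xdata2 = randn(10)
ydata = 2 .* xdata1 .+ 3 .* xdata2 .+ randn(10)

# A leading column of ones turns the intercept β0 into an ordinary coefficient
Xa = hcat(ones(length(ydata)), xdata1, xdata2)
μ0 = [1.0, 0.0, 0.0]                  # prior means of (β0, β[1], β[2]), as in the reply's model
Σ0 = Diagonal([100.0, 1.0, 1.0])      # prior variances of (β0, β[1], β[2])
μ_post, Σ_post = blr_posterior(Xa, ydata, μ0, Σ0, 1.0)
```

Folding the intercept into the design matrix is the same idea as the multivariate example below, where the intercept enters through c ~ ones(n)*b.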

You may also be interested in the following example: ReactiveMP: How to run linear model with multiple predictors and an intercept - #6 by albertpod. This example uses an old version of RxInfer, which was then called ReactiveMP. The example itself should still work if you remove [ default_factorisation = MeanField() ] from the @model specification and instead pass constraints = MeanField() to the inference() function, as follows:

using RxInfer, LinearAlgebra  # diagm comes from LinearAlgebra

n = 250
m = 100

@model function multivariate_linear_regression(n,m)
    a ~ MvNormalMeanCovariance(zeros(m), diagm(ones(m)))
    b ~ NormalMeanVariance(0.0,1.0)
    W ~ InverseWishart(n+2, diageye(n))
    c ~ ones(n)*b
    x = datavar(Matrix{Float64})
    y = datavar(Vector{Float64})
    z ~ x*a+c
    y ~ MvNormalMeanCovariance(z, W)

end

results = inference(
    model = multivariate_linear_regression(n,m),
    data  = (y = randn(n), x = randn(n,m)),
    initmarginals = (W = InverseWishart(n+2, diageye(n)), ),
    returnvars   = (a = KeepLast(), b = KeepLast(), W = KeepLast()),
    free_energy = true,
    iterations   = 10,
    constraints = MeanField()
)

Hey @CameronBieganek ! Thanks for your interest!

  1. No, there is none. The @model macro is rather old (we previously used it with the ReactiveMP.jl library directly). We have recently prioritised refactoring this part, including adding a public non-macro interface for model specification (so that models can be created on the fly or loaded from a file). We cannot provide a definite time frame for when it will be ready, though. So at some point you will get one :slight_smile:
  2. We chose "Rx" to associate the package with the Reactive eXtensions family of libraries here: ReactiveX - Languages. We considered that confusion when choosing the name, but AFAIK Rx is a common abbreviation only in the US and is far less known outside of it (I might be wrong, though).

I got it working – many thanks for your help!


In a model definition, how do I perform non-trivial calculations on a random variable for use as a parameter in another distribution? I’d like to use intermediate variables, but it doesn’t seem possible. For example, this works:

x ~ Normal(0, 1)
y ~ Normal(3 * x + 1, 1)

but this does not:

x ~ Normal(0, 1)
a = 3 * x + 1
y ~ Normal(a, 1)

I’d like to use the second form as, unlike in the simple case above, I’m not sure how to express my calculation in a closed form without introducing intermediary variables.
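Probabilistically, the intermediate variable in the second form is just a deterministic transform: if x ~ Normal(0, 1), then a = 3x + 1 ~ Normal(1, 9), and y given a ~ Normal(a, 1) is marginally Normal(1, 10). The plain-Julia Monte Carlo check below (standard library only, not RxInfer code) confirms this. Elsewhere in this thread deterministic relations inside @model are written with `~` rather than `=` (for example x[i] ~ A * x_prev + w[i]), which suggests trying a ~ 3 * x + 1; I have not verified that against RxInfer, so treat it as an assumption.

```julia
using Random, Statistics

rng = MersenneTwister(1)
N = 500_000

x = randn(rng, N)          # x ~ Normal(0, 1)
a = 3 .* x .+ 1            # deterministic intermediate: a ~ Normal(1, 3^2)
y = a .+ randn(rng, N)     # y | a ~ Normal(a, 1), so marginally y ~ Normal(1, 10)

println((mean(a), var(a)))   # close to (1, 9)
println((mean(y), var(y)))   # close to (1, 10)
```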

Thank you in advance!