RxInfer Domain Logic

Hello!

I am trying to modify this example from RxInfer.jl’s documentation, and I get an error saying I didn’t initialize marginals whenever two unobserved states each influence multiple observed results through the domain logic of a Bayesian network.

Specifically, consider the following MWE that was derived from the example linked above.

using RxInfer, Random

# Create Score node
struct Score end
@node Score Stochastic [out, in]

# Adding update rule for the Score node
@rule Score(:in, Marginalisation) (q_out::PointMass,) = begin
    return Bernoulli(mean(q_out))
end
 
@model function model_a(r)

    local s
    # Priors
    s[1] ~ Bernoulli(0.5)
    s[2] ~ Bernoulli(0.5)

    # Domain logic and results
    r[1] ~ Score(s[1]) 
    r[2] ~ Score(s[1] || s[2]) 
end

test_results = [0, 1]
inference_result = infer(
    model=model_a(),
    data=(r=test_results,)
)
results = map(params, inference_result.posteriors[:s])

This runs without any error and yields s = [0, 1], as desired. However, consider the following model, in which I add slightly more complex logic to the Bayesian network:

# Same code as above outside of the model adjustment below: 
@model function model_b(r)

    local s
    # Priors
    s[1] ~ Bernoulli(0.5)
    s[2] ~ Bernoulli(0.5)

    # Domain logic and results
    r[1] ~ Score(s[1] && s[2]) # r[1] now depends on both s[1] and s[2].
    r[2] ~ Score(s[1] || s[2]) 
end

# Run inference with slightly more complicated model
inference_result = infer(
    model=model_b(),
    data=(r=test_results,)
)

When I try to infer the states s from this model, I get the following error. It indicates that the required marginals have not been initialized, so RxInfer.jl cannot update s.

1-element ExceptionStack:
Variables [ s ] have not been updated after an update event.
Therefore, make sure to initialize all required marginals and messages. See `initialization` keyword argument for the inference function.
See the official documentation for detailed information regarding the initialization.

Stacktrace:
 [1] error(s::String)
   @ Base .\error.jl:35
 [2] check_and_reset_updated!(updates::Dict{Symbol, RxInfer.MarginalHasBeenUpdated})
   @ RxInfer C:\Users\SA30308\.julia\packages\RxInfer\wbFg1\src\inference\inference.jl:79
 [3] batch_inference(; model::GraphPPL.ModelGenerator{typeof(model_b), @Kwargs{}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend}, data::@NamedTuple{r::Vector{Int64}}, initialization::Nothing, constraints::Nothing, meta::Nothing, options::Nothing, returnvars::Nothing, predictvars::Nothing, iterations::Nothing, free_energy::Bool, free_energy_diagnostics::Tuple{RxInfer.ObjectiveDiagnosticCheckNaNs, RxInfer.ObjectiveDiagnosticCheckInfs}, showprogress::Bool, callbacks::Nothing, addons::Nothing, postprocess::DefaultPostprocess, warn::Bool, catch_exception::Bool)
   @ RxInfer C:\Users\SA30308\.julia\packages\RxInfer\wbFg1\src\inference\batch.jl:306
 [4] batch_inference
   @ RxInfer C:\Users\SA30308\.julia\packages\RxInfer\wbFg1\src\inference\batch.jl:94 [inlined]
 [5] #infer#242
   @ RxInfer C:\Users\SA30308\.julia\packages\RxInfer\wbFg1\src\inference\inference.jl:306 [inlined]
 [6] top-level scope
   @ REPL[11]:1

If my marginals for s are set up correctly for model_a, shouldn’t model_b work as well? What am I missing here? And what rules should I follow when modifying domain logic in RxInfer.jl to avoid this sort of issue in the future?

Thank you very much for your time and help.

Hi @spolk!

The error message “Variables [ … ] have not been updated after an update event” indicates that your model B contains a loop, as described in the documentation.

For your case, I drew both factor graphs (I will add model A in a subsequent post):

You can see that model B has a loop from s_1 back to itself: it runs from the s_1 equality node through the OR node to the s_2 equality node, and then through the AND node back to the s_1 equality node. The same cycle also leads from s_2 back to itself.
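Because model B is loopy, message passing only approximates its posterior. With just two binary states, you can check the exact marginals by enumerating all four configurations. The sketch below is plain Julia with no RxInfer. It assumes that Score acts as a hard observation, meaning r[k] must equal the Boolean output of its input. The PointMass-based rule above effectively encodes this for observations 0 and 1. The function name `exact_marginals` is hypothetical.

```julia
# Exact posterior marginals P(s[k] = 1 | r) by enumerating all four states.
# Assumption (hypothetical): Score is a hard observation, so r[k] must equal
# the Boolean output of its logic; priors are Bernoulli(prior).
function exact_marginals(logic1, logic2, r; prior = 0.5)
    weights = zeros(2, 2)                    # weights[s1 + 1, s2 + 1]
    for s1 in (false, true), s2 in (false, true)
        p = (s1 ? prior : 1 - prior) * (s2 ? prior : 1 - prior)
        # Keep only configurations consistent with both observations
        consistent = logic1(s1, s2) == r[1] && logic2(s1, s2) == r[2]
        weights[s1 + 1, s2 + 1] = consistent ? p : 0.0
    end
    total = sum(weights)
    total > 0 || error("observations are impossible under this model")
    p_s1 = sum(weights[2, :]) / total        # P(s[1] = 1 | r)
    p_s2 = sum(weights[:, 2]) / total        # P(s[2] = 1 | r)
    return (p_s1, p_s2)
end

r = [0, 1]
# Model A: r[1] ~ Score(s[1]),         r[2] ~ Score(s[1] || s[2])
exact_marginals((a, b) -> a, (a, b) -> a || b, r)        # → (0.0, 1.0)
# Model B: r[1] ~ Score(s[1] && s[2]), r[2] ~ Score(s[1] || s[2])
exact_marginals((a, b) -> a && b, (a, b) -> a || b, r)   # → (0.5, 0.5)
```

Under this assumption, the exact answer for model A matches s = [0, 1]. In model B, the observations are explained equally well by s = [1, 0] and s = [0, 1], so each exact marginal is 0.5.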

To resolve this circular dependency, you need to tell RxInfer how to break the loop so that it can update the variables. You do this by initializing the messages, like so:

init = @initialization begin
    μ(s) = Bernoulli(0.5)
end

[...]

inference_result = infer(
  model = model_b(),
  data = (r=test_results,),
  initialization = init,
  [...]
)
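The following sketch shows why a starting message is needed. It is a hand-rolled sum-product loop for model B in plain Julia, not RxInfer's implementation. As before, it assumes that Score is a hard observation, so each gate-plus-observation pair becomes a 2×2 factor table over (s[1], s[2]).

Each factor-to-variable message needs a variable-to-factor message. That message in turn needs the other factor's message. As a result, nothing can be computed until some messages receive starting values. Here they start as uniform vectors, which is analogous to μ(s) = Bernoulli(0.5). The function name and the iteration count are hypothetical.

```julia
using LinearAlgebra  # matrix-vector products and transpose

# Loopy sum-product for model B on two binary variables (index 1 = false, 2 = true).
# Hypothetical sketch: Score is treated as hard evidence, priors are Bernoulli(0.5).
normalize_probs(v) = v ./ sum(v)

function loopy_bp_model_b(r; iterations = 20)
    prior = [0.5, 0.5]
    # Factor tables F[s1 + 1, s2 + 1]: 1.0 if the gate output matches the observation
    A = [Float64((s1 && s2) == r[1]) for s1 in (false, true), s2 in (false, true)]
    B = [Float64((s1 || s2) == r[2]) for s1 in (false, true), s2 in (false, true)]
    # Initial factor-to-variable messages break the loop (cf. μ(s) = Bernoulli(0.5))
    mA1 = [0.5, 0.5]; mA2 = [0.5, 0.5]; mB1 = [0.5, 0.5]; mB2 = [0.5, 0.5]
    for _ in 1:iterations
        # Variable-to-factor: prior times the message from the *other* factor
        m1A = normalize_probs(prior .* mB1); m2A = normalize_probs(prior .* mB2)
        m1B = normalize_probs(prior .* mA1); m2B = normalize_probs(prior .* mA2)
        # Factor-to-variable: sum out the other variable
        mA1, mA2 = normalize_probs(A * m2A), normalize_probs(transpose(A) * m1A)
        mB1, mB2 = normalize_probs(B * m2B), normalize_probs(transpose(B) * m1B)
    end
    # Beliefs: prior times all incoming factor messages
    return normalize_probs(prior .* mA1 .* mB1), normalize_probs(prior .* mA2 .* mB2)
end

b1, b2 = loopy_bp_model_b([0, 1])   # both come out at approximately [0.5, 0.5]
```

If the four initial assignments are removed, the first line of the loop body fails because `mB1` does not exist yet. That failure is the plain-Julia counterpart of the "have not been updated" error. For this particular model, the beliefs agree with the exact marginals of 0.5, although loopy message passing is only an approximation in general.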

I hope that helps! 🙂


Factor graph for model A for comparison:

(As a new user, I am only allowed to add one media item per post.)


Thank you!