Hello!
I am trying to modify this example from the RxInfer.jl documentation. I get an error saying that I did not initialize the marginals as soon as two unobserved states each influence more than one observed result in the domain logic of my Bayesian network.
Specifically, consider the following MWE, derived from the example linked above:
```julia
using RxInfer, Random

# Create Score node
struct Score end
@node Score Stochastic [out, in]

# Adding update rule for the Score node
@rule Score(:in, Marginalisation) (q_out::PointMass,) = begin
    return Bernoulli(mean(q_out))
end

@model function model_a(r)
    local s
    # Priors
    s[1] ~ Bernoulli(0.5)
    s[2] ~ Bernoulli(0.5)
    # Domain logic and results
    r[1] ~ Score(s[1])
    r[2] ~ Score(s[1] || s[2])
end

test_results = [0, 1]
inference_result = infer(
    model = model_a(),
    data = (r = test_results,)
)
results = map(params, inference_result.posteriors[:s])
```
This runs without any error and yields `s = [0, 1]`, as desired. However, consider the following model, in which I add slightly more complex logic to the Bayesian network:
```julia
# Same code as above, except for the model definition below:
@model function model_b(r)
    local s
    # Priors
    s[1] ~ Bernoulli(0.5)
    s[2] ~ Bernoulli(0.5)
    # Domain logic and results
    r[1] ~ Score(s[1] && s[2]) # r[1] now depends on both s[1] and s[2]
    r[2] ~ Score(s[1] || s[2])
end

# Run inference with the slightly more complicated model
inference_result = infer(
    model = model_b(),
    data = (r = test_results,)
)
```
When I try to infer the states `s` from this model, I get the following error, indicating that my marginals have not been initialized and hence `s` cannot be updated by RxInfer.jl:
```
1-element ExceptionStack:
Variables [ s ] have not been updated after an update event.
Therefore, make sure to initialize all required marginals and messages. See `initialization` keyword argument for the inference function.
See the official documentation for detailed information regarding the initialization.
Stacktrace:
 [1] error(s::String)
   @ Base .\error.jl:35
 [2] check_and_reset_updated!(updates::Dict{Symbol, RxInfer.MarginalHasBeenUpdated})
   @ RxInfer C:\Users\SA30308\.julia\packages\RxInfer\wbFg1\src\inference\inference.jl:79
 [3] batch_inference(; model::GraphPPL.ModelGenerator{typeof(model_b), @Kwargs{}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend}, data::@NamedTuple{r::Vector{Int64}}, initialization::Nothing, constraints::Nothing, meta::Nothing, options::Nothing, returnvars::Nothing, predictvars::Nothing, iterations::Nothing, free_energy::Bool, free_energy_diagnostics::Tuple{RxInfer.ObjectiveDiagnosticCheckNaNs, RxInfer.ObjectiveDiagnosticCheckInfs}, showprogress::Bool, callbacks::Nothing, addons::Nothing, postprocess::DefaultPostprocess, warn::Bool, catch_exception::Bool)
   @ RxInfer C:\Users\SA30308\.julia\packages\RxInfer\wbFg1\src\inference\batch.jl:306
 [4] batch_inference
   @ RxInfer C:\Users\SA30308\.julia\packages\RxInfer\wbFg1\src\inference\batch.jl:94 [inlined]
 [5] #infer#242
   @ RxInfer C:\Users\SA30308\.julia\packages\RxInfer\wbFg1\src\inference\inference.jl:306 [inlined]
 [6] top-level scope
   @ REPL[11]:1
```
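For reference, here is what exact inference should return for both models, assuming (my assumption, not anything RxInfer defines) that `Score` is a noiseless observation, i.e. `r[i]` equals its Boolean input exactly. This seems consistent with the rule above, which sends the degenerate `Bernoulli(mean(q_out))` back to `in`. The sketch enumerates the four joint states of `s` in plain Julia, without RxInfer; `exact_posterior` is a hypothetical helper name.

```julia
# Exact posterior P(s[k] = 1 | r) by brute-force enumeration.
# Hypothetical assumption: Score is noiseless, so an observation either
# matches its Boolean input exactly (likelihood 1) or not at all (likelihood 0).
# f1 and f2 encode the domain logic behind r[1] and r[2].
function exact_posterior(f1, f2, r)
    Z = 0.0   # normalizing constant
    m1 = 0.0  # unnormalized mass where s[1] is true
    m2 = 0.0  # unnormalized mass where s[2] is true
    for s1 in (false, true), s2 in (false, true)
        # Prior Bernoulli(0.5) x Bernoulli(0.5) = 0.25, times the 0/1 likelihood
        w = 0.25 * ((f1(s1, s2) == r[1]) && (f2(s1, s2) == r[2]))
        Z += w
        m1 += w * s1
        m2 += w * s2
    end
    return (m1 / Z, m2 / Z)
end

test_results = [0, 1]
pa = exact_posterior((s1, s2) -> s1, (s1, s2) -> s1 || s2, test_results)
pb = exact_posterior((s1, s2) -> s1 && s2, (s1, s2) -> s1 || s2, test_results)
println(pa)  # (0.0, 1.0), same as the RxInfer result for model_a
println(pb)  # (0.5, 0.5)
```

Under that assumption, `model_b` has two equally likely explanations (exactly one of `s[1]` and `s[2]` is true), so the target posterior is 0.5 for each. Structurally, each `s[k]` in `model_b` feeds both the `&&` and the `||` node, so its factor graph contains a loop (`s[1]`, `&&`, `s[2]`, `||`, back to `s[1]`), whereas the graph of `model_a` is a tree.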
If I set up the marginals for `s` correctly for `model_a`, shouldn't `model_b` work as well? What am I missing here? And what are the rules for modifying the domain logic in RxInfer.jl so that I avoid running into this sort of issue in the future?
Thank you very much for your time and help.