Problems fitting a simple Multinomial model in RxInfer

Hello everyone,

I’m experimenting with the awesome RxInfer package, but I’m having some difficulty understanding a few basic principles.

I tried to adapt one of the first simple RxInfer examples to build a small model for Net Promoter Score (NPS). Each data point is a three-dimensional vector holding the counts of Detractors, Neutrals and Promoters. Here is the very simple model. Everything runs until it’s time to call inference, at which point I get an error saying that I need to define a custom node for Multinomial. Any hints on how to do that?

using RxInfer
using Random

ns = rand(Binomial(300, 0.3), 100)
trueθ = [0.3, 0.5, 0.2] # detractors, neutral, promoters
dataset = float.([rand(Multinomial(ns[i], trueθ)) for i in eachindex(ns)])
#dataset = reduce(hcat, dataset)

@model function npsmodel(n)
    y = datavar(Vector{Float64}, n)
    # We endow θ parameter of our model with a conjugate prior
    θ ~ Dirichlet([10, 10, 10])
    # We assume that outcome of each trial
    # is governed by the Multinomial distribution
    for i in 1:n
        y[i] ~ Multinomial(90, θ)
    end
end

result = inference(
    # dataset is a Vector of count vectors, so its length is the number of observations
    model=npsmodel(length(dataset)),
    data=(y=dataset,)
)

The stacktrace:

ERROR: `Multinomial` is not available as a node in the inference engine. Used in `y ~ Multinomial(...)` expression.
Use `@node` macro to add a custom factor node corresponding to `Multinomial`. See `@node` macro for additional documentation and examples.

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] make_node(::Type, ::FactorNodeCreationOptions{Tuple{Tuple{Int64}, Tuple{Int64}, Tuple{Int64}}, Nothing, Nothing}, ::DataVariable{PointMass{Vector{Float64}}, Rocket.RecentSubjectInstance{Message{PointMass{Vector{Float64}}}, Subject{Message{PointMass{Vector{Float64}}}, AsapScheduler, AsapScheduler}}}, ::ConstVariable{PointMass{Int64}, SingleObservable{Message{PointMass{Int64}, Nothing}, AsapScheduler}}, ::RandomVariable)
    @ ReactiveMP ~/.julia/packages/ReactiveMP/vWHNY/src/node.jl:980
  [3] make_node(::FactorGraphModel, ::FactorNodeCreationOptions{Nothing, Nothing, Nothing}, ::Type, ::DataVariable{PointMass{Vector{Float64}}, Rocket.RecentSubjectInstance{Message{PointMass{Vector{Float64}}}, Subject{Message{PointMass{Vector{Float64}}}, AsapScheduler, AsapScheduler}}}, ::ConstVariable{PointMass{Int64}, SingleObservable{Message{PointMass{Int64}, Nothing}, AsapScheduler}}, ::RandomVariable)
    @ RxInfer ~/.julia/packages/RxInfer/NiAqM/src/model.jl:340
  [4] macro expansion
    @ ./REPL[163]:12 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/GraphPPL/n5QGe/src/model.jl:445 [inlined]
  [6] var"##npsmodel#355"(model#352::FactorGraphModel, n::Int64)
    @ Main ~/.julia/packages/RxInfer/NiAqM/src/graphppl.jl:33
  [7] ModelGenerator
    @ ~/.julia/packages/RxInfer/NiAqM/src/model.jl:248 [inlined]
  [8] #create_model#98
    @ ~/.julia/packages/RxInfer/NiAqM/src/model.jl:265 [inlined]
  [9] inference(; model::RxInfer.ModelGenerator{var"###npsmodel#355", Tuple{Int64}, NamedTuple{(), Tuple{}}}, data::NamedTuple{(:y,), Tuple{Vector{Vector{Float64}}}}, initmarginals::Nothing, initmessages::Nothing, constraints::Nothing, meta::Nothing, options::Nothing, returnvars::Nothing, iterations::Nothing, free_energy::Bool, free_energy_diagnostics::Tuple{BetheFreeEnergyCheckNaNs, BetheFreeEnergyCheckInfs}, showprogress::Bool, callbacks::Nothing, addons::Nothing, postprocess::DefaultPostprocess, warn::Bool)
    @ RxInfer ~/.julia/packages/RxInfer/NiAqM/src/inference.jl:499
 [10] top-level scope
    @ REPL[164]:1

Update: I realized that I need to define my own @node and the corresponding rules. That seems like a lot of work, and I have no idea how to do it. Is there a quick way to achieve this?

For RxInfer questions it might be good to tag @bvdmitri (unless he prefers otherwise), since I don’t know how closely he watches this forum.


Hi @DoktorMike!

Thank you for trying out RxInfer.jl. Unfortunately, the Multinomial node is currently not available in ReactiveMP.jl, which is the inference engine used in RxInfer.jl.

I can imagine that it’s frustrating not to be able to run inference in such a simple model. However, the message update rules for this model can be derived analytically, so implementing the node is possible in principle. You could open an issue on ReactiveMP.jl requesting this feature.

Alternatively, you could implement the node and its rules yourself. Unfortunately, I can’t point you to documentation for that yet, as we are still working on an example of how to derive rules.
At the moment, the tutorials can be found in this thesis (see Appendix A), but I understand that they may be difficult to follow. I will open an issue on RxInfer.jl to add instructions on implementing a node to the documentation.

Technically, what needs to be done is to define the node:

@node Multinomial Stochastic [out, n, k]

and then specify the update rule for the k interface (θ in your model):

@rule Multinomial(:k, Marginalisation) (q_out::PointMass, q_n::PointMass, ) = begin 
    return Dirichlet(probvec(q_out) .+ one(eltype(probvec(q_out))))
end
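For reference, here is a sketch of the reasoning behind this rule, assuming the count vector x = (x_1, ..., x_K) is fully observed (the case covered by the q_out::PointMass signature) and N = x_1 + ... + x_K. Viewed as a function of θ, the Multinomial likelihood is an unnormalised Dirichlet density with parameters x + 1:

```latex
p(x \mid N, \theta)
  = \frac{N!}{\prod_{k=1}^{K} x_k!} \prod_{k=1}^{K} \theta_k^{x_k}
  \;\propto\; \prod_{k=1}^{K} \theta_k^{(x_k + 1) - 1}
  \;\propto\; \mathrm{Dir}(\theta \mid x + \mathbf{1})
```

N only enters the normalising constant, which is why the rule can ignore q_n (and why the fixed 90 in the model does not change the posterior of θ, even though the observed totals vary). Multiplying the prior Dir(θ | α) by the n messages Dir(θ | x_i + 1) gives Dir(θ | α + Σ_i x_i), the usual Dirichlet-Multinomial conjugate update.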

Given all that, your snippet can look as follows:

using RxInfer
using Random

ns = rand(Binomial(300, 0.3), 100)
trueθ = [0.3, 0.5, 0.2] # detractors, neutral, promoters
dataset = float.([rand(Multinomial(ns[i], trueθ)) for i in eachindex(ns)])

@node Multinomial Stochastic [out, n, k]

@rule Multinomial(:k, Marginalisation) (q_out::PointMass, q_n::PointMass, ) = begin 
    return Dirichlet(probvec(q_out) .+ one(eltype(probvec(q_out))))
end

@model function npsmodel(n)
    y = datavar(Vector{Float64}, n)
    # We endow θ parameter of our model with a conjugate prior
    θ ~ Dirichlet([10, 10, 10])
    # We assume that outcome of each trial
    # is governed by the Multinomial distribution
    for i in 1:n
        y[i] ~ Multinomial(90, θ) 
    end
end

result = inference(
    model=npsmodel(size(dataset, 1)),
    data=(y=dataset,)
)

@show mean(result.posteriors[:θ])

mean(result.posteriors[:θ]) = [0.29851203608373783, 0.5084326126391137, 0.1930553512771486]
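Because this model is conjugate, the reported posterior can be cross-checked without RxInfer. The sketch below is plain Julia with no packages; dirichlet_posterior and dirichlet_mean are hypothetical helper names, and the counts are made up for illustration rather than taken from the dataset above. It simply adds each observed count vector to the Dirichlet prior parameters, which is what multiplying the prior by the per-observation messages Dirichlet(x .+ 1) amounts to.

```julia
# Hypothetical helper: closed-form posterior Dirichlet parameters for a
# Dirichlet(α_prior) prior and fully observed Multinomial count vectors.
function dirichlet_posterior(α_prior::AbstractVector{<:Real},
                             counts::AbstractVector{<:AbstractVector{<:Real}})
    α_post = float.(α_prior)  # broadcasting allocates a copy, so the prior is not mutated
    for x in counts
        # Dir(a) times the message Dir(x .+ 1) is Dir(a .+ (x .+ 1) .- 1) = Dir(a .+ x),
        # so each observation adds its counts to the parameters.
        α_post .+= x
    end
    return α_post
end

# Mean of a Dirichlet distribution with parameter vector α.
dirichlet_mean(α::AbstractVector{<:Real}) = α ./ sum(α)

# Made-up counts (detractors, neutral, promoters) for two batches of respondents.
counts = [[3.0, 5.0, 2.0], [27.0, 45.0, 18.0]]
α_post = dirichlet_posterior([10, 10, 10], counts)
println("posterior parameters: ", α_post)
println("posterior mean:       ", dirichlet_mean(α_post))
```

Calling dirichlet_mean(dirichlet_posterior([10, 10, 10], dataset)) on the dataset from the snippet above should agree with mean(result.posteriors[:θ]) up to floating-point error, assuming the message-passing result is the exact conjugate posterior, which is what one would expect here since θ is the only latent variable.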

P.S. There might be a way to hack this model together with RxInfer’s approximation methods, but I doubt you want to do that for such a simple case.


Great, thank you for the pointers. I will read up on the methodology and see if I can contribute to the rules.

You’re welcome. @bvdmitri pointed me to this small example on nodes and rules, which will be extended in the near future.

Update: the rule I’ve provided appears to be correct.
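A quick numerical sanity check of that claim, in plain Julia with no packages: if Dirichlet(x .+ 1) is the right message, then for a fixed count vector x the difference between the Multinomial log-likelihood and the Dirichlet log-density should be the same for every θ. The helpers logfactorial, multinomial_logpmf, dirichlet_logpdf and gap are hypothetical names written for this sketch (not RxInfer or Distributions.jl API), x is a made-up count vector, and the sketch assumes probvec(q_out) yields the observed counts, as the rule relies on.

```julia
# log(n!) as a sum of logs; adequate for the small counts used here.
logfactorial(n::Integer) = sum(log, 1:n; init = 0.0)

# Log-pmf of Multinomial(N, θ) at integer counts x, with N = sum(x).
multinomial_logpmf(x::AbstractVector{<:Integer}, θ::AbstractVector{<:Real}) =
    logfactorial(sum(x)) - sum(logfactorial, x) + sum(x .* log.(θ))

# Log-density of Dirichlet(α) at θ for integer α, using Γ(a) = (a - 1)!.
dirichlet_logpdf(α::AbstractVector{<:Integer}, θ::AbstractVector{<:Real}) =
    logfactorial(sum(α) - 1) - sum(a -> logfactorial(a - 1), α) + sum((α .- 1) .* log.(θ))

x = [3, 5, 2]   # made-up counts: detractors, neutral, promoters
α_msg = x .+ 1  # parameters of the message the rule returns, Dirichlet(x .+ 1)

# If the rule is right, this log-ratio does not depend on θ;
# it equals log(N!) - log((N + K - 1)!) for N = sum(x) and K = length(x).
gap(θ) = multinomial_logpmf(x, θ) - dirichlet_logpdf(α_msg, θ)

for θ in ([0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [1/3, 1/3, 1/3])
    println(θ, " => ", gap(θ))
end
```

For these counts every line should show the same value, -log(132) ≈ -4.88, up to rounding.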


That perfectly solves it! Thank you @albertpod