Hello everyone,
I’m experimenting with the awesome RxInfer package, but I’m having some difficulty understanding a few basic principles.
I tried to adapt one of the first simple RxInfer examples to build a small model for the Net Promoter Score (NPS). Each data point is a three-dimensional vector whose entries are the counts of detractors, neutrals and promoters. Everything runs until it is time to call `inference`. Then I get an error saying that `Multinomial` is not available as a node in the inference engine, so apparently I need to define the node and its inference rules myself. Any hints on how to do that?
Here is the very simple model:
```julia
using RxInfer
using Random

ns = rand(Binomial(300, 0.3), 100)        # number of respondents per batch
trueθ = [0.3, 0.5, 0.2]                   # detractors, neutral, promoters
dataset = float.([rand(Multinomial(ns[i], trueθ)) for i in eachindex(ns)])

@model function npsmodel(n, trials)
    y = datavar(Vector{Float64}, n)
    # We endow the θ parameter of our model with a conjugate prior
    θ ~ Dirichlet([10.0, 10.0, 10.0])
    # We assume that the outcome of each batch is governed by a
    # Multinomial distribution with that batch's number of trials
    # (a fixed 90 would give zero likelihood to batches whose counts don't sum to 90)
    for i in 1:n
        y[i] ~ Multinomial(trials[i], θ)
    end
end

result = inference(
    model = npsmodel(length(dataset), ns),   # `dataset` is a Vector, so size(dataset, 2) == 1
    data  = (y = dataset,)
)
```
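For reference, here is a sketch that does not use RxInfer at all, only Distributions.jl. Because the Dirichlet prior is conjugate to the Multinomial likelihood, the exact posterior over θ is Dirichlet(α + Σᵢ yᵢ), so whatever `inference` eventually returns can be checked against it. The seed and the NPS summary (percentage of promoters minus percentage of detractors, taken at the posterior mean) are illustrative additions, not part of the model above.

```julia
using Distributions, Random, Statistics

Random.seed!(42)  # arbitrary seed, only for reproducibility

# Simulate data the same way as above: a varying number of trials per batch
ns = rand(Binomial(300, 0.3), 100)
trueθ = [0.3, 0.5, 0.2]  # detractors, neutral, promoters
dataset = [rand(Multinomial(ns[i], trueθ)) for i in eachindex(ns)]

# Dirichlet prior from the model
prior = Dirichlet([10.0, 10.0, 10.0])

# Conjugate update: add the summed category counts to the prior concentrations
posterior_alpha = prior.alpha .+ sum(dataset)
posterior = Dirichlet(posterior_alpha)

# Posterior mean of θ and the corresponding NPS (promoters% - detractors%)
θ_mean = mean(posterior)
nps = 100 * (θ_mean[3] - θ_mean[1])

println("posterior mean θ = ", round.(θ_mean; digits = 3))
println("estimated NPS    = ", round(nps; digits = 1))
```

Since the NPS is linear in θ, evaluating it at the posterior mean gives the posterior mean of the NPS.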
The stack trace:
```
ERROR: `Multinomial` is not available as a node in the inference engine. Used in `y ~ Multinomial(...)` expression.
Use `@node` macro to add a custom factor node corresponding to `Multinomial`. See `@node` macro for additional documentation and examples.
Stacktrace:
 [1] error(s::String)
    @ Base ./error.jl:35
 [2] make_node(::Type, ::FactorNodeCreationOptions{Tuple{Tuple{Int64}, Tuple{Int64}, Tuple{Int64}}, Nothing, Nothing}, ::DataVariable{PointMass{Vector{Float64}}, Rocket.RecentSubjectInstance{Message{PointMass{Vector{Float64}}}, Subject{Message{PointMass{Vector{Float64}}}, AsapScheduler, AsapScheduler}}}, ::ConstVariable{PointMass{Int64}, SingleObservable{Message{PointMass{Int64}, Nothing}, AsapScheduler}}, ::RandomVariable)
    @ ReactiveMP ~/.julia/packages/ReactiveMP/vWHNY/src/node.jl:980
 [3] make_node(::FactorGraphModel, ::FactorNodeCreationOptions{Nothing, Nothing, Nothing}, ::Type, ::DataVariable{PointMass{Vector{Float64}}, Rocket.RecentSubjectInstance{Message{PointMass{Vector{Float64}}}, Subject{Message{PointMass{Vector{Float64}}}, AsapScheduler, AsapScheduler}}}, ::ConstVariable{PointMass{Int64}, SingleObservable{Message{PointMass{Int64}, Nothing}, AsapScheduler}}, ::RandomVariable)
    @ RxInfer ~/.julia/packages/RxInfer/NiAqM/src/model.jl:340
 [4] macro expansion
    @ ./REPL[163]:12 [inlined]
 [5] macro expansion
    @ ~/.julia/packages/GraphPPL/n5QGe/src/model.jl:445 [inlined]
 [6] var"##npsmodel#355"(model#352::FactorGraphModel, n::Int64)
    @ Main ~/.julia/packages/RxInfer/NiAqM/src/graphppl.jl:33
 [7] ModelGenerator
    @ ~/.julia/packages/RxInfer/NiAqM/src/model.jl:248 [inlined]
 [8] #create_model#98
    @ ~/.julia/packages/RxInfer/NiAqM/src/model.jl:265 [inlined]
 [9] inference(; model::RxInfer.ModelGenerator{var"###npsmodel#355", Tuple{Int64}, NamedTuple{(), Tuple{}}}, data::NamedTuple{(:y,), Tuple{Vector{Vector{Float64}}}}, initmarginals::Nothing, initmessages::Nothing, constraints::Nothing, meta::Nothing, options::Nothing, returnvars::Nothing, iterations::Nothing, free_energy::Bool, free_energy_diagnostics::Tuple{BetheFreeEnergyCheckNaNs, BetheFreeEnergyCheckInfs}, showprogress::Bool, callbacks::Nothing, addons::Nothing, postprocess::DefaultPostprocess, warn::Bool)
    @ RxInfer ~/.julia/packages/RxInfer/NiAqM/src/inference.jl:499
 [10] top-level scope
    @ REPL[164]:1
```
Update: I realized that I need to define my own `@node` and the corresponding message-passing rules. That seems like a lot of work, and I have no idea how to do it. Is there a quick way to achieve this?
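A sketch of the math such a rule would have to implement, assuming the number of trials of each Multinomial is a known constant and y is observed: as a function of θ, the likelihood is proportional to ∏ₖ θₖ^yₖ, which is the kernel of Dirichlet(y + 1), and multiplying two Dirichlet densities over the same θ adds their exponents. This is plain Julia with Distributions.jl; `multinomial_message_to_θ` and `dirichlet_product` are hypothetical helper names, not RxInfer functions. Registering this computation inside RxInfer would still go through the `@node` and `@rule` macros, whose exact syntax should be taken from the RxInfer/ReactiveMP documentation for the installed version.

```julia
using Distributions

# Hypothetical helper (not RxInfer API): message from a Multinomial(n, θ) node
# towards θ when its output y is observed. As a function of θ the likelihood
# is ∝ ∏ θ_k^y_k, the kernel of Dirichlet(y .+ 1).
multinomial_message_to_θ(y::AbstractVector{<:Real}) = Dirichlet(y .+ 1.0)

# Hypothetical helper: normalized product of two Dirichlet densities over θ.
# Dir(a) * Dir(b) ∝ ∏ θ_k^(a_k + b_k - 2), i.e. Dirichlet(a .+ b .- 1).
dirichlet_product(d1::Dirichlet, d2::Dirichlet) = Dirichlet(d1.alpha .+ d2.alpha .- 1.0)

# Toy observations: counts of detractors, neutrals, promoters per batch
ys = [[30.0, 45.0, 15.0], [25.0, 50.0, 20.0], [28.0, 47.0, 21.0]]
prior = Dirichlet([10.0, 10.0, 10.0])

# The model is a tree, so the posterior of θ is the prior times every incoming message
posterior = foldl(dirichlet_product, multinomial_message_to_θ.(ys); init = prior)

println("posterior concentrations: ", posterior.alpha)
```

The result is the same Dirichlet(α + Σᵢ yᵢ) as the closed-form conjugate update, which is a convenient check for any hand-written rule.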