Use of update!() in RxInfer to handle multiple iterations

I am trying to implement the van de Laar chance constraint model in RxInfer, and as a warm-up I have implemented the reference model from the same paper. The code at the end of this post produces an output graph that closely resembles the published one. I used the mountain car example as a template. This post continues a GitHub issue for RxInfer.

While the code runs fine, I have questions about modifying it to achieve two goals. Any thoughts, @bvdmitri or others?

First, I would like to use multiple iterations in the inference process, the equivalent of, e.g., iterations=10 in the inference function, but stop the procedure after any iteration in which some condition is met. This is what the chance constraint code (written for ForneyLab.jl) does.

I see two possible ways to control iterations: (1) use a loop containing update!(); or (2) use a callback. I discuss issues with both.

The Advanced Tutorial section of the RxInfer documentation appears to implement inference iterations by stepping through a loop that calls update!(y, dataset) on each pass. In that example, y is a vector of datavar returned from create_model(), and dataset is a vector of Float64.

But my infer function and problem are not structured that way. I set some values directly in Float64 vectors (e.g., m_u[1] = agent_action_t). I could return a datavar vector m_u from the model and grab it in my infer function via create_model, but then I couldn’t update m_u[1] manually as required.

What I want is to call update! with something like update!(data) where data is the dict that contains for example (:m_u => m_u). But that doesn’t work.

An alternative would be to use a callback in the inference function. Building on an example callback in the documentation, the following runs if I first define a function cb(model, name, update) ... end:

result = inference(
    ...,
    callbacks = (
        on_marginal_update = (model, name, update) -> cb(model, name, update),
    )
)

But from the documentation, there do not appear to be any callbacks (e.g., before_data_update) that would allow halting iterations.

Any ideas on how to use update! or callbacks in my situation, or how to otherwise halt iterations based on some condition?

Second, I would like to print out free energy values as is done in the Advanced Tutorial example. It uses something akin to:

model2, returnval = create_model(model)
bfe_observable = score(model2, Float64, BetheFreeEnergy())
bfe_subscription = subscribe!(bfe_observable, (fe) -> println("Current BFE value: ", fe));

This code is commented out in my infer function. If I remove the comments, nothing is printed. I’m just not sure where this code is supposed to go, or whether I have specified it correctly.

In trying to follow the Advanced Tutorial example, I assume that my final infer function, after addressing the two questions above and using update!, would look something like:

infer = (agent_action_t::Float64, agent_obs_t::Float64, t::Int) -> begin
    m_w[:] = [fx_m_wind(t+k) for k in 0:T-1]  # mean wind
    m_u[1] = agent_action_t  # register action with the generative model
    v_u[1] = tiny  # clamp control prior to performed action
    m_x[1] = agent_obs_t  # register observation with the generative model
    v_x[1] = tiny  # clamp goal prior to observation    
    
    data = Dict(
            :m_u => m_u, 
            :v_u => v_u, 
            :m_x => m_x, 
            :v_x => v_x,
            :m_w => m_w,
            :v_w => v_w,
            :m_x_t_last => m_x_t_last,
            :v_x_t_last => v_x_t_last
        )
    
    model, (u, uw, x, x_t_last) = create_model(ref_model(; T=T))
    bfe_observable = score(model, Float64, BetheFreeEnergy()) 
    bfe_subscription = subscribe!(bfe_observable, (fe) -> println("Current BFE value: ", fe));

    setmarginal!.(u, NormalMeanVariance(0.0, huge)) 
    setmarginal!.(uw, NormalMeanVariance(0.0, huge))
    setmarginal!.(x, NormalMeanVariance(0.0, huge))
    setmarginal!.(x_t_last, NormalMeanVariance(0.0, huge))

    u_values = keep(Marginal)
    uw_values = keep(Marginal)
    x_values = keep(Marginal)
    x_t_last_values = keep(Marginal)

    u_subscription = [subscribe!(getmarginal(u[i]), u_values) for i in 1:T]
    uw_subscription = [subscribe!(getmarginal(uw[i]), uw_values) for i in 1:T]
    x_subscription = [subscribe!(getmarginal(x[i]), x_values) for i in 1:T]
    x_t_last_subscription = subscribe!(getmarginal(x_t_last), x_t_last_values)

    for i in 1:n_iterations
        update!(data)
       # now perform some test to halt iterations
    end
    
    return result  # where result is a collection of marginals
end

Is that about correct? The above does not work, of course, as update! does not work. And I’m not sure it would print out free energy as desired. Once I have an answer for my two questions, then I will try to tackle the harder problem of how to specify the @node and @rules code for the chance constraint model.

The code that does work is below. It can be run by using include("./RXRef.jl") in the REPL, if the script is saved with that filename. In brief, the goal is to keep a drone above a target elevation, where the drone is subject to a vertical wind.

As an aside, if I change the code below to, e.g., iterations=10 in the inference function, the printout of E_avg at the end does not change from iterations=1. So I’m not even sure that multiple iterations are doing anything.

module RXRef


import Plots
import Random

# use some private functionality from ReactiveMP, 
import RxInfer.ReactiveMP: getrecent, messageout

using Formatting
using Infiltrator
using Revise
using RxInfer

Random.seed!(51233) # Set random seed for reproducibility

Plots.scalefontsizes()
Plots.scalefontsizes(0.8)

# ==================================================================================================
# --------------------------------------------------------------------------------------------------
function createWorld(; v_wind)
    # functions for interacting with the simulated environment.
        
    x_0 = 0.0 # Initial position, drone elevation
    x_t0 = x_0
    x_t1 = x_0
    
    execute = (action_t::Float64, m_wind_t::Float64) -> begin
        # Execute the action
        # action_t = ascent velocity, dx = action_t * t, with t=1
        x_t1 = x_t0 + action_t + m_wind_t + sqrt(v_wind) * randn()  # Update elevation
        x_t0 = x_t1 # Prepare for next step
    end
    
    observe = () -> begin 
        return x_t0  # Observe the current state
    end
    
    return (execute, observe)
end


# --------------------------------------------------------------------------------------------------
@model function ref_model(; T)
    # during each training step, from t to t+1, this function is called only once. 
    
    # the mean and variance of observation at the last t
    m_x_t_last = datavar(Float64)
    v_x_t_last = datavar(Float64)
    
    # sample the observation at the last t
    x_t_last ~ GaussianMeanVariance(m_x_t_last, v_x_t_last) 
    x_k_last = x_t_last
    
    # control
    m_u = datavar(Float64, T)
    v_u = datavar(Float64, T)
        
    # obs
    m_x = datavar(Float64, T)
    v_x = datavar(Float64, T)
    
    # wind
    m_w = datavar(Float64, T)
    v_w = datavar(Float64, T)
        
    # random variables
    u = randomvar(T)  # control
    x = randomvar(T)  # height
    uw = randomvar(T)  # control + wind variance
    
    # loop over horizon
    for k = 1:T
        x[k] ~ GaussianMeanVariance(m_x[k], v_x[k])   # goal prior
        u[k] ~ GaussianMeanVariance(m_u[k], v_u[k])   # control prior
        uw[k] ~ GaussianMeanVariance(u[k], v_w[k])    # control + wind variance
        x[k] ~ x_k_last + uw[k] + m_w[k]
        x_k_last = x[k]
    end
    
    return (x,)
end


# --------------------------------------------------------------------------------------------------
function createAgent(; T, fx_m_wind, v_wind, m_goal, v_goal, lambda)
    
    # control prior
    m_u = Vector{Float64}([0.0 for k=1:T ])
    v_u = Vector{Float64}([lambda^(-1) for k=1:T])
    
    # goal
    m_x = Vector{Float64}([m_goal for k=1:T])
    v_x = Vector{Float64}([v_goal for k=1:T])  
    
    # wind
    m_w = Vector{Float64}([0.0 for k=1:T])
    v_w = Vector{Float64}([v_wind for k=1:T])
    
    # initial position and variance
    m_x_t_last = 0.0
    v_x_t_last = convert(Float64, tiny)
    
    # Set current inference results
    result = nothing
            
    infer = (agent_action_t::Float64, agent_obs_t::Float64, t::Int) -> begin
        
        m_w[:] = [fx_m_wind(t+k) for k in 0:T-1]  # mean wind
        m_u[1] = agent_action_t  # register action with the generative model
        v_u[1] = tiny  # clamp control prior to performed action
        m_x[1] = agent_obs_t  # register observation with the generative model
        v_x[1] = tiny  # clamp goal prior to observation
        
        data = Dict(
            :m_u => m_u, 
            :v_u => v_u, 
            :m_x => m_x, 
            :v_x => v_x,
            :m_w => m_w,
            :v_w => v_w,
            :m_x_t_last => m_x_t_last,
            :v_x_t_last => v_x_t_last
        )
                
        model  = ref_model(; T=T) 
        #model2, returnval = create_model(model)
        #bfe_observable = score(model2, Float64, BetheFreeEnergy())
        #bfe_subscription = subscribe!(bfe_observable, (fe) -> println("Current BFE value: ", fe));
        
        result = inference(
            model = model, 
            data = data,
            initmarginals = (
                u = NormalMeanVariance(0.0, huge),
                uw = NormalMeanVariance(0.0, huge),
                x = NormalMeanVariance(0.0, huge),
                x_t_last = NormalMeanVariance(0.0, huge),
            ),
            iterations = 1,
            free_energy = true,
            #showprogress = true,
            returnvars = (
                u = KeepLast(), 
                uw = KeepLast(), 
                x_t_last = KeepLast(),
                x = KeepLast()
            ),
        ) 
        #unsubscribe!([bfe_subscription, ])
    end
    
    
    # The `act` function returns the inferred best possible action
    act = () -> begin
        if result !== nothing
            return mode(result.posteriors[:u][2])  # index 2 for next step, when iterations=1
        else
            return 0.0 # Without inference result we return some 'random' action
        end
    end
    
    # The `future` function returns the inferred future states
    future = () -> begin 
        if result !== nothing 
            return getindex.(mode.(result.posteriors[:x]), 1)  # for iterations =1
        else
            return zeros(T)
        end
    end

    slide = () -> begin
        (x, ) = result.returnval
        slide_msg_idx = 3 # This index is model dependent
        (m_x_t_last, v_x_t_last) = mean_var(getrecent(messageout(x[2], slide_msg_idx)))
        
        # these are not actually necessary for this simple problem, as the vectors do not change
        m_u = circshift(m_u, -1)
        m_u[end] = 0.0
        
        v_u = circshift(v_u, -1)
        v_u[end] = lambda^(-1)

        m_x = circshift(m_x, -1)
        m_x[end] = m_goal
        
        v_x = circshift(v_x, -1)
        v_x[end] = v_goal
    end
    
    return (infer, act, slide, future)    
end


# --------------------------------------------------------------------------------------------------
function plotTrajectory(obs, act, per_n, fx_m_wind, epsilon)

    (L,N) = size(obs)
    
    # mean of wind
    p1 = Plots.scatter(0:L, fx_m_wind.(0:L), color="black", xlim= (0,L), grid=true, ylabel="E[w]", legend=false)
    
    # Trajectory per run
    p2 = Plots.plot(legend=false)
    Plots.hline!(p2, [1.0], color="red", ls=:dash)
    for n=1:per_n:N
        Plots.plot!(1:L, obs[:,n], color="black", alpha=0.1)
    end
    Plots.plot!(xlim=(1,L), ylim=(-5.5, 5.5), grid=true, ylabel="Elevation (x)")
    
    # Control signal per run
    p3 = Plots.plot(legend=false)
    for n=1:per_n:N
        Plots.plot!(1:L, act[:,n], color="black", alpha=0.1)
    end
    Plots.plot!(xlim= (1,L), ylim=(-1.5, 2.5), grid=true, ylabel= "Control Signal (a)")
    
    # Violation ratio (of runs) over time
    p4= Plots.plot(legend=false)
    Plots.hline!(p4, [epsilon], color="red", ls=:dash)
    r = vec(mean(obs .< 1.0, dims=2))
    Plots.scatter!(1:L, r, color="black")
    Plots.plot!(xlim=(1,L), ylim=(0, 0.05), grid=true, xlabel="Time (t)", ylabel="Target Violation Ratio")
    
    Plots.plot(p1,p2,p3,p4, layout=(4,1), size=(800,1000), left_margin=5Plots.mm, )
    Plots.savefig("./sim_reference.png")
    #@infiltrate; @assert false
end


# --------------------------------------------------------------------------------------------------
function main()
    #=
    no hidden states, agent directly observes its elevation and wind velocity mean/variance
    x = elevation (e.g., meters)
    T = time for look-ahead horizon (e.g., seconds)
    action = selected control value, vertical velocity (e.g., m/s)
    t = current moment
    k = time steps in look ahead horizon, k= t:t+T-1
    m_w = mean of vertical wind velocity (e.g., m/s)  
    v_w = variance of vertical wind velocity (e.g, m^2)
    u = control variable = action = vertical velocity (e.g, m/s)
    L = simulation time (e.g., seconds)
    
    m_ = mean of _
    v_ = variance of _
    =#
    
    # Simulation parameters
    L = 20
    T = 9
    v_wind = 0.2
    m_goal = 2.0
    v_goal = 0.18478
    lambda = 0.01  # control prior precision
    epsilon = 0.01  # unsafe mass
    
    fx_m_wind(t::Int64) = 5<=t<10 ? -1.0 : 0.0 # Wind mean as function of time t
    N = 100  # number of trials
    
    # Step through experimental protocol
    actions =  Matrix{Float64}(undef, L, N) 
    observations = Matrix{Float64}(undef, L, N) 
        
    for n in 1:N
        # Let there be a world
        (execute_ai, observe_ai) = createWorld(
            v_wind = v_wind,
        ) 

        # Let there be an agent
        (infer_ai, act_ai, slide_ai, future_ai) = createAgent(; 
            T  = T, 
            fx_m_wind = fx_m_wind,
            v_wind = v_wind,
            m_goal = m_goal, 
            v_goal = v_goal, 
            lambda = lambda,
        ) 
            
        for t=1:L
            actions[t,n] = act_ai()  # invoke an action from the agent
            futures = future_ai()  # fetch the predicted future states
            execute_ai(actions[t,n], fx_m_wind(t))  # the action influences hidden external states
            observations[t,n] = observe_ai()  # observe the current environmental outcome
            infer_ai(actions[t,n], observations[t,n], t)  # infer beliefs from current model state
            slide_ai()  # prepare for next iteration
        end
    end
        
    E_avg = round(mean(sum(actions.^2, dims=1)), digits=2) # Average quadratic cost of control (over all runs)
    @show E_avg
    
    per_n = ceil(Int, N/100) # Plot one in every per_n trajectories
    plotTrajectory(observations, actions, per_n, fx_m_wind, epsilon)
    
end


end  # module ----------------------------------------------------------

RXRef.main()

Hey @John_Boik, thanks for trying out RxInfer. We appreciate all feedback from our users.
Your message is quite long, so please don’t hesitate to ask again if I missed an important question. So:

  1. But from the documentation, it appears that there are not any callback choices (e.g., before_data_update ) that would allow for halting iterations.

That is a very valid concern and we should implement that. In fact, I can work on this feature this week; it should be relatively easy to implement. Would you mind opening an issue in the RxInfer repository so that we can track progress on it? In the issue you can also explain in more detail, for example, what arguments you would expect such a callback to receive.

  1. If I remove the comments, nothing is printed. I’m just not sure where this code is suppose to go, or whether I have specified it correctly.

The Free Energy code should not print anything unless actual iterations are performed, so without valid calls to update! it behaves as expected. RxInfer exports “observables” and you can subscribe to changes in those observables, but without actual inference there are no changes to react to.

  1. Is that about correct? The above does not work, of course, as update! does not work.

Without diving into too much detail, I would say the code looks correct (well, the result variable is not defined, but I suppose execution just doesn’t get that far yet :) ).

EDIT: After looking at your code a second time, I noticed that the way you create and subscribe to u_values = keep(Marginal) is not correct. You should either create u_values = [ keep(Marginal) for i in 1:T ] or, more simply, write:

u_values = keep(Vector{Marginal})
u_subscription = subscribe!(getmarginals(u), u_values) # note the plural `getmarginals`, not `getmarginal`

Do the same for the remaining vectors of random variables, and stop saving marginals for x_t_last because it simply points to x. Just save x instead.

Then you can simply write the following in your for loop, and inference should start running:

for i in 1:n_iterations
    # push every entry of `data` into the matching datavar(s) of the model
    for (name, value) in data
        update!(model[name], value)   # e.g. update!(model[:m_u], data[:m_u])
    end
    # check a stopping condition here
end

And then you can perform some checks in between to halt the iterations.
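As a minimal sketch of such a check (plain Julia, not RxInfer API): assume each pass of the loop is wrapped in a function that performs the update! calls and returns the latest Bethe free energy value. The helper name `iterate_until_converged` and the relative-tolerance stopping rule are hypothetical choices for illustration.

```julia
"""
    iterate_until_converged(step!, max_iterations; rtol = 1e-6)

Call `step!(i)` for `i = 1, 2, ...` up to `max_iterations` times. `step!` is assumed to
run one inference pass and return the current free energy. Stop early once the
relative change between two consecutive free energy values falls below `rtol`.
Return the vector of recorded free energy values.
"""
function iterate_until_converged(step!, max_iterations::Integer; rtol::Real = 1e-6)
    history = Float64[]
    for i in 1:max_iterations
        fe = Float64(step!(i))
        push!(history, fe)
        if length(history) >= 2
            previous = history[end-1]
            relative_change = abs(fe - previous) / max(abs(previous), eps())
            relative_change < rtol && break   # halting condition met
        end
    end
    return history
end

# Stand-in for a real inference pass: a free energy sequence that settles at 10.0
fake_step(i) = 10.0 + 2.0^(-i)
history = iterate_until_converged(fake_step, 100; rtol = 1e-3)
println("stopped after ", length(history), " iterations")
```

In the infer function, `step!` could run the inner `update!` loop and then return `fe_values[end]`, where `fe_values` is filled by `subscribe!(bfe_observable, fe -> push!(fe_values, fe))`. This wiring is an untested assumption; in particular, whether the observable emits exactly once per pass is worth checking. Any other condition, such as a chance-constraint check, can replace the tolerance test.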

  1. So, I’m not even sure that multiple iterations is doing anything.

From the model definition you provided, I can see that you use neither the @constraints macro nor the constraints = ... option in the inference function. This means you perform CBFE (constrained Bethe free energy) optimization subject to no constraints, which is equivalent to Belief Propagation (Sum-Product) and gives the exact solution on the first iteration. So this behaviour is expected, but be aware that exact inference is not possible in many models. In that case you will need @constraints and more iterations to converge to an approximate answer.
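To see why factorisation constraints make iterations matter, here is a small illustration in plain Julia (not RxInfer): the standard mean-field coordinate updates for q(x1)q(x2) approximating a correlated 2-D Gaussian with mean μ and precision matrix Λ. Each update uses the other factor’s current mean, so the estimates only approach μ over several iterations, whereas exact inference gives μ directly. The values of μ and Λ are arbitrary illustrative choices.

```julia
# Mean-field q(x1)q(x2) for a 2-D Gaussian with mean μ and precision matrix Λ.
# Each update uses the other factor's current mean, so convergence takes iterations.
function meanfield_means(μ::Vector{Float64}, Λ::Matrix{Float64}, iterations::Int;
                         m_init::Tuple{Float64,Float64} = (0.0, 0.0))
    m1, m2 = m_init
    trace = Tuple{Float64,Float64}[]
    for _ in 1:iterations
        m1 = μ[1] - Λ[1, 2] / Λ[1, 1] * (m2 - μ[2])   # update q(x1) given E[x2]
        m2 = μ[2] - Λ[2, 1] / Λ[2, 2] * (m1 - μ[1])   # update q(x2) given E[x1]
        push!(trace, (m1, m2))
    end
    return trace
end

μ = [1.0, 2.0]
Λ = [2.0 1.5; 1.5 2.0]          # strongly coupled, positive definite
trace = meanfield_means(μ, Λ, 20)
println("after 1 iteration:   ", trace[1])
println("after 20 iterations: ", trace[end])
```

After one sweep the estimate is still far from μ = (1, 2), and it only settles after many sweeps. Without such a factorisation (the exact case), one pass already gives the answer, which is consistent with E_avg not changing between iterations=1 and iterations=10.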

For a very detailed (but also quite mathematically involved) description of inference constraints, I can refer you to this paper (RxInfer essentially implements it).

For example, the Mountain Car example you referred to also does not use iterations, because it performs exact inference. Perhaps in the van de Laar chance constraint paper they used some extra inference constraints, so that iterations matter?

EDIT: I double-checked this example, and indeed Thijs uses extra factorisation constraints:

q = PosteriorFactorization(u, x, ids=[:U, :X]) # Factorization of the variational posterior

In RxInfer you can achieve the same by doing

@constraints function my_constraints()
    q(u, x) = q(u)q(x)
end

and calling it as

inference(constraints = my_constraints(), ...)

or passing it directly to the create_model function:

create_model(my_model(), constraints = my_constraints())

RxInfer tries to find some “obvious” constraints automatically, so I’m not sure this factorisation constraint will change anything for your particular model (most probably not; the + operator ignores factorisation constraints). But I can see that Thijs also uses further inference constraints, such as a PointMass constraint or a ChanceConstraint that act on a variable. You can do something similar with:

@constraints function my_constraints(some_argument)
    q(u, x) = q(u)q(x)
    q(x) :: MyConstraint(some_argument)
end

And implement your own MyConstraint by following the documentation here.

There is a second case, however, in which iterations do matter. If your model has loops, the inference procedure essentially becomes Loopy Belief Propagation, which also gives an approximate solution, but without convergence guarantees. As far as I can see, your model does not have loops.

Thanks again for your questions. Please don’t hesitate to ask twice if I forgot something in your message.

Hi @John_Boik, I hope Dmitry’s answers are helpful. Regarding the questions posted on GitHub, this is what I can offer:

  1. The slide_msg_idx is legacy code from ForneyLab, in which it was possible to populate a dictionary with the messages passed around during inference. In ReactiveMP, messages are passed asynchronously and tracking them is harder (although it is possible, see this issue). In the mountain car example, the correct slide_msg_idx was found manually. As far as I know, there is no easy procedure for finding the index of a message at a specific node.
  2. Adding a TruncatedGaussian node is a procedure that requires adding a node specification and, indeed, a rule specification to the source code. This requires some understanding of the internals of ReactiveMP and RxInfer, and it might therefore be harder to do for external contributors (although the plan is to add a manual to the ReactiveMP documentation, see here). I would recommend waiting for us to implement it (it is on the roadmap with an ETA within 6 weeks). If you don’t want to, I can post some instructions on how to add a node and try to guide you through it.
  3. Those two steps (adding a node and a correct constraint specification) are necessary but possibly not sufficient conditions. Thijs will know whether more is needed, but he’s been quite busy this week.

Thanks @wmkouw and @bvdmitri.
With your suggestions @bvdmitri, the reference model using the update! function now runs without error. Now I’ve started on the chance constraint version, just to see how far I can get with a little assistance here. I have implemented a function for the truncated Gaussian (which works) and also a node and rule for the ChanceConstraint. But I must have misspecified the node or rule in some way, or I’m not calling ChanceConstraint correctly. I get the error:

ERROR: LoadError: MethodError: no method matching make_node(::Type{Main.RXChanceUpdate.ChanceConstraint}, ::ReactiveMP.FactorNodeCreationOptions{NTuple{7, Tuple{Int64}}, Nothing, Nothing}, ::ReactiveMP.RandomVariable, ::ReactiveMP.DataVariable{ReactiveMP.PointMass{Float64}, Rocket.RecentSubjectInstance{ReactiveMP.Message{ReactiveMP.PointMass{Float64}}, Rocket.Subject{ReactiveMP.Message{ReactiveMP.PointMass{Float64}}, Rocket.AsapScheduler, Rocket.AsapScheduler}}}, ::ReactiveMP.DataVariable{ReactiveMP.PointMass{Float64}, Rocket.RecentSubjectInstance{ReactiveMP.Message{ReactiveMP.PointMass{Float64}}, Rocket.Subject{ReactiveMP.Message{ReactiveMP.PointMass{Float64}}, Rocket.AsapScheduler, Rocket.AsapScheduler}}}, ::Tuple{ReactiveMP.ConstVariable{ReactiveMP.PointMass{Float64}, Rocket.SingleObservable{ReactiveMP.Message{ReactiveMP.PointMass{Float64}, Nothing}, Rocket.AsapScheduler}}, ReactiveMP.ConstVariable{ReactiveMP.PointMass{Float64}, Rocket.SingleObservable{ReactiveMP.Message{ReactiveMP.PointMass{Float64}, Nothing}, Rocket.AsapScheduler}}}, ::ReactiveMP.ConstVariable{ReactiveMP.PointMass{Float64}, Rocket.SingleObservable{ReactiveMP.Message{ReactiveMP.PointMass{Float64}, Nothing}, Rocket.AsapScheduler}}, ::ReactiveMP.ConstVariable{ReactiveMP.PointMass{Float64}, Rocket.SingleObservable{ReactiveMP.Message{ReactiveMP.PointMass{Float64}, Nothing}, Rocket.AsapScheduler}})
Closest candidates are:
  make_node(::Union{Type{Type{<:Main.RXChanceUpdate.ChanceConstraint}}, Type{<:Main.RXChanceUpdate.ChanceConstraint}}, ::ReactiveMP.FactorNodeCreationOptions, ::ReactiveMP.AbstractVariable, ::ReactiveMP.AbstractVariable, ::ReactiveMP.AbstractVariable, ::ReactiveMP.AbstractVariable, ::ReactiveMP.AbstractVariable, ::ReactiveMP.AbstractVariable) at ~/.julia/packages/ReactiveMP/nrbxT/src/node.jl:1201
  make_node(::Union{Type{Type{<:Main.RXChanceUpdate.ChanceConstraint}}, Type{<:Main.RXChanceUpdate.ChanceConstraint}}, ::ReactiveMP.FactorNodeCreationOptions, ::ReactiveMP.AbstractVariable...) at ~/.julia/packages/ReactiveMP/nrbxT/src/node.jl:1211
  make_node(::FactorGraphModel, ::ReactiveMP.FactorNodeCreationOptions, ::Any, ::Any...) at ~/.julia/packages/RxInfer/NiAqM/src/model.jl:339
  ...

I don’t understand what’s wrong from the error, other than I’m calling something wrong. The problem is probably something simple. My code is almost identical to that of the reference model, except now I include the truncated Gaussian function, a node, a rule, and I call ChanceConstraint in the model. Right now, I’m only trying to get the code to run enough to bring me into the first line of the rule, then I will fix up the body of the rule later. So, below I only show the code for the node, rule, and how I call ChanceConstraint in the @model function. Any ideas why the error occurs? I created the node and rule by following the Advanced Tutorial documentation, as best I could.

The following is in the global space of my package:

struct ChanceConstraint end  # struct to keep Julia happy
@node ChanceConstraint Stochastic [out, mu, V, G, epsilon, atol]

@rule ChanceConstraint(:out, Marginalisation) (mu::Any, V::Any, G::Any, epsilon::Any, atol::Any) = begin 
    @infiltrate; @assert false  # just trying to get this far without error
end

The following is in my model function:

# loop over horizon
for k = 1:T
    #x[k] ~ GaussianMeanVariance(m_x[k], v_x[k])   # goal prior, for reference model
    x[k] ~ ChanceConstraint(m_x[k], v_x[k], (1.0, Inf), epsilon, atol)  # fails here
    # .....
end

@John_Boik I think the problem here is that you pass a tuple (1.0, Inf) as one input to the G. It is not really supported at the moment and the error message is unhelpful. What you can do is split G into G1 and G2 and pass 1.0 and Inf separately.

EDIT: or you can try to create a constvar reference to G like

G = constvar((1.0, Inf))
# and pass it like
x[k] ~ ChanceConstraint(m_x[k], v_x[k], G, epsilon, atol) # the @model macro gets confused by the tuple

We have plans to redesign the node and rules such that this case will be supported as well, but this is work in progress and is not available for regular users.

@bvdmitri Unfortunately, splitting the input tuple was not a fix. I’ve simplified the problem below, now using only two parameters. I also include fields in the ChanceConstraint struct, as without them I get the error:

MethodError: no method matching Main.RXChanceUpdate.ChanceConstraint(::Float64, ::Float64)
struct ChanceConstraint 
    u 
    V 
    function ChanceConstraint(u, V)
        new(u, V)
    end
end  

@node ChanceConstraint Stochastic [out, u, V]

@rule ChanceConstraint(:out, Marginalisation) (u::Any, V::Any) = begin 
    @infiltrate; @assert false  # just trying to get this far without error
end

I call ChanceConstraint in my model function like so, where x is a randomvar array:

x[k] ~ ChanceConstraint(1., 1.)

This gives the error:

MethodError: no method matching ~(::ReactiveMP.RandomVariable, ::Main.RXChanceUpdate.ChanceConstraint)

I can call just ChanceConstraint(1., 1.) without error. Seems that something is still misspecified. Any ideas?

The original ForneyLab code included interfaces::Array{Interface,1} in the structure, and associated out to the interfaces. But I assume this is no longer necessary. When I look at how ReactiveMP creates the struct, node, and rule for NormalMeanVariance, it looks similar to what I have done above, as far as I can tell.

@John_Boik

I think you’re doing everything correctly and, actually, I cannot reproduce your errors. I don’t have access to the model you’re trying to run, but I just created a simple model that uses the ChanceConstraint node and everything worked fine (apart from the missing rule):

x[k] ~ ChanceConstraint(1., 1.)

This kind of error cannot happen in the @model macro as all ~ operators are rewritten at parse-time.

MethodError: no method matching ~(...)

Perhaps you tried to use ~ expression outside of the @model macro?

struct ChanceConstraint 
    u 
    V 
    function ChanceConstraint(u, V)
        new(u, V)
    end
end 

This kind of struct specification is redundant. You could just write

struct ChanceConstraint 
    u 
    V 
end

and Julia will generate the default constructor automatically. I’m mentioning this because it is very easy to override the default constructor in a way that Julia can no longer create a particular object, and then you get weird errors like:

MethodError: no method matching Main.RXChanceUpdate.ChanceConstraint(::Float64, ::Float64)

but I don’t think it has anything to do with RxInfer. RxInfer does not try to create an actual object with fields, and the number of edges or their names have no connection to the fields of the struct. This is why you can simply write

struct ChanceConstraint end # does not matter that the struct does not have `u` and `V` fields

@node ChanceConstraint Stochastic [out, u, V]
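The constructor point can be checked in plain Julia. This small sketch (the struct names are hypothetical) shows that a struct without an inner constructor gets a default constructor with one argument per field, that defining an inner constructor removes that default, and that an empty struct has no two-argument method at all. The last case is the kind of MethodError reported above when ChanceConstraint(1., 1.) is evaluated outside of @model.

```julia
# No inner constructor: Julia generates WithDefault(u, V) automatically.
struct WithDefault
    u
    V
end

# An inner constructor replaces the default one; only WithOverride(u) exists now.
struct WithOverride
    u
    V
    WithOverride(u) = new(u, 1.0)
end

# An empty struct, as in `struct ChanceConstraint end`: only NoFields() exists.
struct NoFields end

# Return the exception thrown by f(), or nothing if f() succeeds.
function caught(f)
    try
        f()
        return nothing
    catch e
        return e
    end
end

a = WithDefault(1.0, 2.0)                       # default constructor
b = WithOverride(1.0)                           # custom constructor, V set to 1.0
err_override = caught(() -> WithOverride(1.0, 2.0))
err_empty = caught(() -> NoFields(1.0, 2.0))
println(typeof(err_override), " ", typeof(err_empty))
```

Inside @model, the ~ expression is rewritten at parse time, so the node type is never constructed like this and the empty struct is sufficient.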

The error message you get makes me think that you are indeed trying to execute the x[k] ~ ChanceConstraint(1., 1.) expression outside of the @model macro, which is not supposed to work, but the full model specification would definitely help to pin down the core of the issue.

Splitting the input tuple also worked for me, and I didn’t get any errors with the following expression:

@node ChanceConstraint Stochastic [out, mu, V, G1, G2, epsilon, atol]

# and then in the model macro worked perfectly fine
@model function mymodel(n)
    ...
    for k in 1:n
        x[k] ~ ChanceConstraint(m_x[k], v_x[k], 1.0, Inf, epsilon, atol)
    end
    ...
end

As I mentioned you should also be able to wrap the tuple into the constvar like that:

@node ChanceConstraint Stochastic [out, mu, V, G, epsilon, atol]

# and then in the model macro
@model function mymodel(n)
    
    G = constvar((1.0, Inf))
    
    for k in 1:n
        ...
        x[k] ~ ChanceConstraint(m_x[k], v_x[k], G, epsilon, atol)
        ...
    end
end

Perhaps you can share the model specification script with us so that we can see what could be wrong, but as far as I can tell your example should just work.

Thanks @bvdmitri. You were right. I use Infiltrator for debugging and had been halting the code inside the model function and then calling ChanceConstraint. So in that sense, I was outside the model.

I’ve made progress and might be close to finished, but a last few, important issues remain. The code below runs without error. In main() I am simply attempting to calculate the control law (for zero wind speed). If I can get that far, I can finish the code. The action, a, at the end of main() should be a vector that starts at about 3 and falls to 0 (for elevation observations that start at -1 and end at 5). The futures, f, should all be about 2. But currently, my a vector starts at 2.7 and ends at -2.7. All values in f hover near 5.0.

The issue, of course, is the chance constraint node. In the original van de Laar code, it seems they define a rule for calculating the out arm of the chance constraint node, but not for any messages going in the reverse direction. And in the paper they speak of a “variable” node for x_j (x is the elevation of the drone), but I thought factor graphs only have functions as nodes. So, I’m not exactly sure how they construct their model.

My model is constructed so that the output of action + wind + x[k] is the input to the chance constraint node. The output of the chance constraint node is x[k+1], which is also wired to an observation node. The rule for the :out arm of the chance constraint node and the associated truncated Gaussian function seem to work fine. What I don’t understand is how to construct the rules for messages traveling in the reverse direction. Ignoring wind for simplicity, a sketch of my model is below, where u is the control action and x is the drone elevation. I use a dummy variable to hold x[k]+u because the input and output of the chance constraint node cannot both be x[k+1]. For the observed x, I use a tiny variance.

                u[k]
                 |
                 |
    x[k-1] ---> x[k-1]+u  ---> ChanceConstraint ---> = --->  x[k]
                                                     |
                                                     |
                                                  obs_x[k] 

I did not see any rules for the reverse messages in the original code, and I’m not sure what they should look like. I currently use dummy rules (rules 2 and 3 in the code) that just return the mean and variance of whatever was passed in. Obviously, this needs to be fixed. Because the rules are not correct, to get reasonable output I had to specify a mean and variance of (2., 10.) for observations of x[k], k>1, in the createAgent function. With correct rules, I should be able to use a vague setting of (0.0, huge).

Any suggestions as to what the reverse message rules should look like? Code is below.

module RXChance


import Plots
import Random
import Distributions
import StatsFuns

using Formatting
using Infiltrator
using Revise
using RxInfer

Random.seed!(51233) # Set random seed for reproducibility

Plots.scalefontsizes()
Plots.scalefontsizes(0.8)


# ==================================================================================================

# --------------------------------------------------------------------------------------------------
function truncatedGaussianMoments(m::Float64, V::Float64, a::Float64, b::Float64)
    
    V = clamp(V, tiny, huge)
    StdG = Distributions.Normal(m, sqrt(V))
    TrG = Distributions.Truncated(StdG, a, b)
    
    Z = Distributions.cdf(StdG, b) - Distributions.cdf(StdG, a)  # probability mass of N(m, V) in [a, b] (the safe mass when [a, b] = G)
    
    if Z < tiny
        # Invalid region; return undefined mean and variance of truncated distribution
        Z    = 0.0
        m_tr = 0.0
        V_tr = 0.0
    else
        m_tr = Distributions.mean(TrG)
        V_tr = Distributions.var(TrG)
    end
    
    #@infiltrate; @assert false
    return (Z, m_tr, V_tr) 
end


# ================ global ==========================================================================

# struct to keep Julia happy
struct ChanceConstraint end  

@node ChanceConstraint Stochastic [out, x, lo, hi, epsilon, atol]

@rule ChanceConstraint(:out, Marginalisation) (m_x::NormalMeanVariance, q_lo::PointMass, 
        q_hi::PointMass, q_epsilon::PointMass, q_atol::PointMass) = begin 
    epsilon = mean(q_epsilon)
    atol = mean(q_atol)
    lo = mean(q_lo)
    hi = mean(q_hi)
    (min_G, max_G) = (lo, hi)
    
    (m_bw, V_bw) = mean_var(m_x)
    (xi_bw, W_bw) = (m_bw / V_bw, 1. / V_bw)  # canonical form (xi = W * m); check division by zero
    (m_tilde, V_tilde) = (m_bw, V_bw)
    
    # Phi_G is called the "safe mass" 
    (Phi_G, m_G, V_G) = truncatedGaussianMoments(m_bw, V_bw, min_G, max_G) # Compute statistics (and normalizing constant) of q in G
    if epsilon <= 1.0 - Phi_G # If constraint is active
        # Initialize statistics of uncorrected belief
        m_tilde = m_bw
        V_tilde = V_bw
        for i = 1:100 # Iterate at most this many times
            (Phi_lG, m_lG, V_lG) = truncatedGaussianMoments(m_tilde, V_tilde, -Inf, min_G) # Statistics for q in region left of G
            (Phi_rG, m_rG, V_rG) = truncatedGaussianMoments(m_tilde, V_tilde, max_G, Inf) # Statistics for q in region right of G

            # Compute moments of non-G region as a mixture of left and right truncations
            Phi_nG = Phi_lG + Phi_rG
            m_nG = Phi_lG / Phi_nG * m_lG + Phi_rG / Phi_nG * m_rG
            V_nG = Phi_lG / Phi_nG * (V_lG + m_lG^2) + Phi_rG/Phi_nG * (V_rG + m_rG^2) - m_nG^2

            # Compute moments of corrected belief as a mixture of G and non-G regions
            m_tilde = (1.0 - epsilon) * m_G + epsilon * m_nG
            V_tilde = (1.0 - epsilon) * (V_G + m_G^2) + epsilon * (V_nG + m_nG^2) - m_tilde^2
            # Re-compute statistics (and normalizing constant) of corrected belief
            (Phi_G, m_G, V_G) = truncatedGaussianMoments(m_tilde, V_tilde, min_G, max_G)
            if (1.0 - Phi_G) < (1.0 + atol)*epsilon
                ii = i
                break # Break the loop if the belief is sufficiently corrected
            end
        end
    end
    #@infiltrate; @assert false
    return NormalMeanPrecision(m_tilde, 1. / V_tilde)
end

@rule ChanceConstraint(:x, Marginalisation) (m_out::NormalMeanVariance, q_lo::PointMass, 
        q_hi::PointMass, q_epsilon::PointMass, q_atol::PointMass) = begin 
    #@infiltrate; @assert false
    return NormalMeanVariance(mean_var(m_out)...) 
end

@rule ChanceConstraint(:x, Marginalisation) (m_out::NormalWeightedMeanPrecision, q_lo::PointMass, 
        q_hi::PointMass, q_epsilon::PointMass, q_atol::PointMass) = begin 
    #@infiltrate; @assert false
    return NormalMeanVariance(mean_var(m_out)...) 
end


# --------------------------------------------------------------------------------------------------
function createWorld(; v_wind)
    # functions for interacting with the simulated environment.
        
    x_0 = 0.0 # Initial position, drone elevation
    x_t0 = x_0
    x_t1 = x_0
    
    execute = (action_t::Float64, m_wind_t::Float64) -> begin
        # Execute the action
        x_t1 = x_t0 + action_t + m_wind_t + sqrt(v_wind) * randn()  # Update elevation
        x_t0 = x_t1 # Prepare for next step
    end
    
    observe = () -> begin 
        return x_t0  # Observe the current state
    end
    
    return (execute, observe)
end


# --------------------------------------------------------------------------------------------------
@model function chance_constraint_model(; T, lo, hi, epsilon, atol)
    
    # control
    m_u = datavar(Float64, T)
    v_u = datavar(Float64, T)
    
    # obs
    m_x = datavar(Float64, T)
    v_x = datavar(Float64, T)    
    
    # wind
    m_w = datavar(Float64, T)
    v_w = datavar(Float64, T)
    
    x_k_last = NormalMeanVariance(0.0, huge)
    
    # random variables
    u = randomvar(T)  # control
    x = randomvar(T)  # height
    uw = randomvar(T)  # control + wind variance
    x_dummy = randomvar(T)  # dummy
    
    # loop over horizon
    for k = 1:T
        x[k] ~ GaussianMeanVariance(m_x[k], v_x[k])   # x[1] = agent_obs_t, else vague goal for k>1
        u[k] ~ NormalMeanVariance(m_u[k], v_u[k])   # control prior
        uw[k] ~ NormalMeanVariance(u[k], v_w[k])    # control + wind variance
        x_dummy[k] ~ x_k_last + uw[k] + m_w[k]
        x[k] ~ ChanceConstraint(x_dummy[k], lo, hi, epsilon, atol)  # where { q = MeanField() } 
        x_k_last = x[k]
    end
    
    #@infiltrate; @assert false
    return (u, uw, x, x_dummy)
end


# --------------------------------------------------------------------------------------------------
function createAgent(; T, fx_m_wind, v_wind, m_goal, v_goal, lambda, n_iters, lo, hi, epsilon, atol)
    
    # control prior
    m_u = Vector{Float64}([0.0 for k=1:T ])
    v_u = Vector{Float64}([lambda^(-1) for k=1:T])
    
    # goal
    m_x = Vector{Float64}([2. for k=1:T])  
    v_x = Vector{Float64}([10. for k=1:T])  # should be vague, but that doesn't work
    
    # wind
    m_w = Vector{Float64}([0.0 for k=1:T])
    v_w = Vector{Float64}([v_wind for k=1:T])
    
    # Set current inference results
    result = nothing
    infer = (agent_action_t::Float64, agent_obs_t::Float64, t::Int) -> begin
        
        m_w[:] = [fx_m_wind(t+k) for k in 0:T-1]  # mean wind
        m_u[1] = agent_action_t  # register action with the generative model
        v_u[1] = tiny  # clamp control prior to performed action
        m_x[1] = agent_obs_t  # register observation with the generative model
        v_x[1] = tiny  # clamp goal prior to observation
        
        data = Dict(
            :m_u => m_u, 
            :v_u => v_u, 
            :m_x => m_x, 
            :v_x => v_x,
            :m_w => m_w,
            :v_w => v_w,
        )
                
        model, (u, uw, x, x_dummy) = create_model(chance_constraint_model(; T=T, lo=lo, hi=hi, 
            epsilon=epsilon, atol=atol))
        
        #bfe_observable = score(model, Float64, BetheFreeEnergy())
        #bfe_subscription = subscribe!(bfe_observable, (fe) -> println("Current BFE value: ", fe));

        setmarginal!.(u, NormalMeanVariance(0.0, huge)) 
        setmarginal!.(uw, NormalMeanVariance(0.0, huge))
        setmarginal!.(x, NormalMeanVariance(0.0, huge))
        setmarginal!.(x_dummy, NormalMeanVariance(0.0, huge))
                
        u_values = keep(Vector{Marginal})
        u_subscription = subscribe!(getmarginals(u), u_values)
        
        uw_values = keep(Vector{Marginal})
        uw_subscription = subscribe!(getmarginals(uw), uw_values)
        
        x_values = keep(Vector{Marginal})
        x_subscription = subscribe!(getmarginals(x), x_values)
        
        (Phi_G, m_tr, V_tr) = (nothing, nothing, nothing)
        ii = nothing
        for i in 1:n_iters
            ii = i
            update!(model[:m_u], data[:m_u])
            update!(model[:v_u], data[:v_u])
            
            update!(model[:m_w], data[:m_w])
            update!(model[:v_w], data[:v_w])
            
            update!(model[:m_x], data[:m_x])
            update!(model[:v_x], data[:v_x])
            
            # Check convergence of x[2]
            (m_b, v_b) = mean_var(x_values[end][2])
                        
            (Phi_G, m_tr, V_tr) = truncatedGaussianMoments(m_b, v_b, lo, hi)
                        
            if (i > 10) && ((1.0 - Phi_G) < (1.0 + atol) * epsilon)  # same test as in chance constraint
                break # Break the loop if the belief is sufficiently corrected
            end
        end
        #@infiltrate; @assert false
        result = (
            u = u_values,
            uw = uw_values,
            x = x_values,
            )
        
        unsubscribe!([u_subscription, uw_subscription, x_subscription])
        
        #@infiltrate; @assert false
        return result
    end
    

    # The `future` function returns the inferred future states
    future = () -> begin 
        if result !== nothing 
            return mode.(result.x[end])
        else
            return zeros(T)
        end
    end


    # slide the time frame forward
    slide = () -> begin
                
        # these are not actually necessary for this simple problem, as the vectors do not change
        m_u = circshift(m_u, -1)
        m_u[end] = 0.0
        
        v_u = circshift(v_u, -1)
        v_u[end] = lambda^(-1)

        m_x = circshift(m_x, -1)
        m_x[end] = m_goal
        
        v_x = circshift(v_x, -1)
        v_x[end] = v_goal
    end
    
    
    # The `act` function returns the inferred best possible action
    act = () -> begin
        if result !== nothing
            return mode(result.u[end][2])  # return the next (k = 2) action
        else
            return 0.0 # without an inference result, return a default action
        end
    end
    
    return (infer, act, slide, future)    
end


# --------------------------------------------------------------------------------------------------
function main()
    # Simulation parameters
    L = 20
    T = 10
    v_wind = 0.2
    m_goal = 2.0
    v_goal = 0.18478
    lambda = 0.01  # control prior precision
    lo = 1.0
    hi = Inf
    epsilon = 0.01  # unsafe mass
    atol = 0.01 # Convergence tolerance for chance constraints
    n_iters = 100  # number of inference iterations
    N = 100  # number of trials
    
        
    # plot control law --------------
    fx_m_wind_law(t::Int64) = 0.0 # Wind mean as function of time t
    (execute_ai, observe_ai) = createWorld(
        v_wind = 0.0,
    ) 
    # Let there be an agent
    (infer_ai, act_ai, slide_ai, future_ai) = createAgent(; 
        T  = T, 
        fx_m_wind = fx_m_wind_law,
        v_wind = 0.0,
        m_goal = m_goal, 
        v_goal = v_goal, 
        lambda = lambda,
        n_iters = n_iters,
        lo = lo,
        hi = hi,
        epsilon =  epsilon,
        atol = atol,
    ) 
    x_hat = -1.0:0.1:5.0 # trial elevations for rendering actions
    K = length(x_hat)
    a = Vector{Float64}(undef, K)  # actions
    f = Vector{Float64}(undef, K)  # next futures at k=2
    for k = 1:K
        infer_ai(0.0, x_hat[k], 1) # Perform inference under wind mean of zero
        a[k] = act_ai() # Extract corresponding action
        f[k] = future_ai()[2] 
    end
    
    # actions a vs. x_hat should look like control law graph, and all f should be > 1.0
    @infiltrate; @assert false
end

end  # module ----------------------------------------------------------

RXChance.main()

Hi @John_Boik, thanks for giving the chance constraints a spin. It would indeed be very nice to have this functionality in RxInfer as well. Although I’m not familiar with the RxInfer implementation details, I can offer some insights on the high-level idea.

A chance-constraint is meant to constrain a marginal distribution to abide by certain properties. In this case, a (posterior) probability distribution should not “overflow” a given region by more than a certain probability mass. This constraint then affects adjacent beliefs and, ultimately, the controls, which (hopefully) come to account for the imposed constraint.

In order to enforce this constraint on a marginal distribution, an auxiliary chance-constraint node is included in the graphical model. This node then sends messages that force the marginal to abide by the preset conditions. In other words, the (chance) constraint on the (posterior) marginal is converted to a prior constraint on the generative model that sends an adaptive message.
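
For reference, the constraint and the message that enforces it can be written out as below. This is a sketch in the notation of the rule code in this thread, not a quote from the paper: G = [lo, hi] is the safe region, N(m_bw, V_bw) is the message entering the chance-constraint node, and q̃ = N(m̃, Ṽ) is the corrected belief produced by the iterative moment matching.

$$
1 - \int_{\mathrm{lo}}^{\mathrm{hi}} q(x_k)\,\mathrm{d}x_k \;\le\; \epsilon
$$

$$
\overrightarrow{\mu}(x_k) \propto \frac{\tilde q(x_k)}{\overleftarrow{\mu}(x_k)}
\quad\Longrightarrow\quad
W_{fw} = \tilde W - W_{bw}, \qquad
\xi_{fw} = \tilde W \tilde m - W_{bw}\, m_{bw},
\qquad \tilde W = \tilde V^{-1},\; W_{bw} = V_{bw}^{-1}
$$

Here ξ and W are the weighted-mean and precision parameters of NormalWeightedMeanPrecision, matching the commented-out xi_fw = xi_tilde - xi_bw and W_fw = W_tilde - W_bw lines of the original code. When the constraint is inactive, q̃ equals the incoming message and the formula gives a flat message (ξ_fw = W_fw = 0).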

In your model, the chance-constraint is viewed as part of the transition model, which is not its intended use. If I were to amend your model, the FFG would look as follows, where I’ve indicated the messages to and from the chance-constraint node:

            [ ]      [ChanceConstraint]
             | u[k]      |
    x[k-1]   |      (2)v | ^(1)       x[k]
       ---> [+] ------> [=] ---> [=] --->
                                  |
                                  | y[k]        

In this view, there is only one message (2) that is sent by the chance-constraint node. This message is specifically designed to impose the required properties on q(x[k]), and depends on the incoming message (1).
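
To make message (2) concrete, here is a minimal, framework-free Julia sketch of how it can be computed from message (1). It follows the iteration in the ChanceConstraint rule posted earlier in the thread and then divides out the incoming message in canonical form, as in the commented-out ForneyLab lines. Assumptions: the names `truncated_moments` and `chance_constraint_message` are hypothetical (not RxInfer or ForneyLab API); the truncated-normal moments are computed in closed form from Distributions.jl `cdf`/`pdf` instead of `Distributions.Truncated`; and the result is returned as plain `(xi, W)` numbers rather than an RxInfer distribution.

```julia
using Distributions

# Mass, mean and variance of N(m, V) restricted to [a, b] (closed form).
# A region with (almost) no mass returns zeros, as truncatedGaussianMoments does.
function truncated_moments(m, V, a, b)
    s = sqrt(V)
    za, zb = (a - m) / s, (b - m) / s
    Z = cdf(Normal(), zb) - cdf(Normal(), za)
    Z < 1e-12 && return (0.0, 0.0, 0.0)
    pa, pb = pdf(Normal(), za), pdf(Normal(), zb)
    zp(z, p) = isfinite(z) ? z * p : 0.0    # z * pdf(z) -> 0 as z -> ±Inf
    r = (pa - pb) / Z
    return (Z, m + s * r, V * (1 + (zp(za, pa) - zp(zb, pb)) / Z - r^2))
end

# Outgoing message of the chance-constraint node, given the incoming message
# N(m_bw, V_bw) and safe region G = [lo, hi]. Returns canonical (xi_fw, W_fw).
function chance_constraint_message(m_bw, V_bw; lo, hi, epsilon, atol, max_iter = 100)
    m_t, V_t = m_bw, V_bw                         # uncorrected belief
    Phi_G, m_G, V_G = truncated_moments(m_t, V_t, lo, hi)
    if epsilon <= 1 - Phi_G                       # constraint is active
        for _ in 1:max_iter
            Phi_l, m_l, V_l = truncated_moments(m_t, V_t, -Inf, lo)  # left of G
            Phi_r, m_r, V_r = truncated_moments(m_t, V_t, hi, Inf)   # right of G
            Phi_n = Phi_l + Phi_r
            Phi_n < 1e-12 && break                # no mass left outside G
            # moments of the non-G region (mixture of left and right parts)
            m_n = (Phi_l * m_l + Phi_r * m_r) / Phi_n
            V_n = (Phi_l * (V_l + m_l^2) + Phi_r * (V_r + m_r^2)) / Phi_n - m_n^2
            # moment-match (1 - epsilon) * (G part) + epsilon * (non-G part)
            m_t = (1 - epsilon) * m_G + epsilon * m_n
            V_t = (1 - epsilon) * (V_G + m_G^2) + epsilon * (V_n + m_n^2) - m_t^2
            Phi_G, m_G, V_G = truncated_moments(m_t, V_t, lo, hi)
            (1 - Phi_G) < (1 + atol) * epsilon && break   # sufficiently corrected
        end
    end
    # message = corrected belief / incoming message, in canonical form
    W_t, W_bw = 1 / V_t, 1 / V_bw
    return (W_t * m_t - W_bw * m_bw, W_t - W_bw)
end

# Message (1) = N(0.5, 1.0); safe region [1, Inf) with 1% allowed unsafe mass
xi_fw, W_fw = chance_constraint_message(0.5, 1.0; lo = 1.0, hi = Inf, epsilon = 0.01, atol = 0.01)
```

Multiplying the returned message by message (1) recovers the corrected belief, whose mass below lo should be close to epsilon once the loop has converged. The sketch does not guard against W_fw ≤ 0, which would be an improper message. In RxInfer, a computation like this would sit inside the `@rule` for the node's `:out` interface and return `NormalWeightedMeanPrecision(xi_fw, W_fw)`.
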

Hope this clears up some confusion, let me know if you have further questions.

Thanks @ThijsvdLaar. That does clarify your intention. I have actually tried that arrangement, but was unable to code it. Let x[k] = x[k-1] + u[k] from your graph. Then, if I understand correctly, the code in the model for the chance constraint node would need to look something like:

x[k] ~ ChanceConstraint(x[k], ...)

as x[k] is both the input and the output of the node. But using the same variable for input and output is not allowed in RxInfer.

It would not help to use

xb[k] ~ ChanceConstraint(x[k], ...)

for some dummy variable xb because then you would need a statement like

x[k] ~ xb[k]

which is also not allowed.

Or, perhaps I am thinking about its implementation wrong.

@ThijsvdLaar, @bvdmitri Thanks for your help on this. It has been a good learning experience. I look forward to seeing how you code the problem in the end. I’m stuck, so I will wait.

To add to the last post, I can use a dummy variable in the model, as shown below. I believe this captures the idea in Thijs’s graph. But the same problem remains: how to specify the rules for the reverse messages. For the (three) reverse message rules that are required, I just return the mean and variance of what was sent in. The where { q = MeanField() } clause is necessary; without it, update! does nothing.

x[k] ~ GaussianMeanVariance(m_x[k], v_x[k])   # x[1] = agent_obs_t, else vague goal for k>1
u[k] ~ NormalMeanVariance(m_u[k], v_u[k])   # control prior
uw[k] ~ NormalMeanVariance(u[k], v_w[k])    # control + wind variance
x_dummy[k] ~ x_k_last + uw[k] + m_w[k]
x[k] ~ ChanceConstraint(x_dummy[k], lo, hi, epsilon, atol)   where { q = MeanField() }
x_k_last = x[k]

Probably a different set of reverse message rules is needed. Note that if I fix observations in createAgent like I do in the reference model, for example:

m_x = Vector{Float64}([m_goal for k=1:T])  
v_x = Vector{Float64}([100.0 for k=1:T])

then I obtain almost identical output to what the reference model produces using the same target variance (in this example, 100). This tells me that the chance constraint node is not having any impact. As an aside, multiple iterations in the infer function are necessary (with only one iteration, results are terrible). It does not matter how many iterations are allowed in the chance constraint node. One works the same as 100 (typically, 18 are used).

If it’s of any help, I copy my full code below.

module RXChance


import Plots
import Random
import Distributions
import StatsFuns

using Formatting
using Infiltrator
using Revise
using RxInfer

Random.seed!(51233) # Set random seed for reproducibility

Plots.scalefontsizes()
Plots.scalefontsizes(0.8)


# ==================================================================================================

# --------------------------------------------------------------------------------------------------
function truncatedGaussianMoments(m::Float64, V::Float64, a::Float64, b::Float64)
    
    V = clamp(V, tiny, huge)
    StdG = Distributions.Normal(m, sqrt(V))
    TrG = Distributions.Truncated(StdG, a, b)
    
    Z = Distributions.cdf(StdG, b) - Distributions.cdf(StdG, a)  # probability mass of N(m, V) in [a, b] (the safe mass when [a, b] = G)
    
    if Z < tiny
        # Invalid region; return undefined mean and variance of truncated distribution
        Z    = 0.0
        m_tr = 0.0
        V_tr = 0.0
    else
        m_tr = Distributions.mean(TrG)
        V_tr = Distributions.var(TrG)
    end
    
    #@infiltrate; @assert false
    return (Z, m_tr, V_tr) 
end


# ================ global ==========================================================================

# struct to keep Julia happy
struct ChanceConstraint 
end  

@node ChanceConstraint Stochastic [out, x, lo, hi, epsilon, atol]

@rule ChanceConstraint(:out, Marginalisation) (q_x::NormalMeanVariance, q_lo::PointMass, 
        q_hi::PointMass, q_epsilon::PointMass, q_atol::PointMass) = begin 
    epsilon = mean(q_epsilon)
    atol = mean(q_atol)
    lo = mean(q_lo)
    hi = mean(q_hi)
    (min_G, max_G) = (lo, hi)
    
    (m_bw, V_bw) = mean_var(q_x)
    (xi_bw, W_bw) = (m_bw / V_bw, 1. / V_bw)  # canonical form (xi = W * m); check division by zero
    (m_tilde, V_tilde) = (m_bw, V_bw)
    
    # Phi_G is called the "safe mass" 
    (Phi_G, m_G, V_G) = truncatedGaussianMoments(m_bw, V_bw, min_G, max_G) # Compute statistics (and normalizing constant) of q in G
    if epsilon <= 1.0 - Phi_G # If constraint is active
        # Initialize statistics of uncorrected belief
        m_tilde = m_bw
        V_tilde = V_bw
        ii = nothing
        for i = 1:100 # Iterate at most this many times
            ii = i
            (Phi_lG, m_lG, V_lG) = truncatedGaussianMoments(m_tilde, V_tilde, -Inf, min_G) # Statistics for q in region left of G
            (Phi_rG, m_rG, V_rG) = truncatedGaussianMoments(m_tilde, V_tilde, max_G, Inf) # Statistics for q in region right of G

            # Compute moments of non-G region as a mixture of left and right truncations
            Phi_nG = Phi_lG + Phi_rG
            m_nG = Phi_lG / Phi_nG * m_lG + Phi_rG / Phi_nG * m_rG
            V_nG = Phi_lG / Phi_nG * (V_lG + m_lG^2) + Phi_rG/Phi_nG * (V_rG + m_rG^2) - m_nG^2

            # Compute moments of corrected belief as a mixture of G and non-G regions
            m_tilde = (1.0 - epsilon) * m_G + epsilon * m_nG
            V_tilde = (1.0 - epsilon) * (V_G + m_G^2) + epsilon * (V_nG + m_nG^2) - m_tilde^2
            # Re-compute statistics (and normalizing constant) of corrected belief
            (Phi_G, m_G, V_G) = truncatedGaussianMoments(m_tilde, V_tilde, min_G, max_G)
            if (1.0 - Phi_G) < (1.0 + atol)*epsilon
                
                break # Break the loop if the belief is sufficiently corrected
            end
        end
        if ii == 100
            @infiltrate; @assert false
        end
        
        # Not sure how this message should be sent. Below is the original code.
        #= 
        # Convert moments of corrected belief to canonical form
        W_tilde = cholinv(V_tilde)
        xi_tilde = W_tilde * m_tilde

        # Compute canonical parameters of forward message
        xi_fw = xi_tilde - xi_bw
        W_fw  = W_tilde - W_bw
    end

    return NormalWeightedMeanPrecision(xi_fw, W_fw)
    =#
    
    end
    #@infiltrate; @assert false
    return NormalMeanVariance(m_tilde, V_tilde)
    
end

# the following are dummy rules, just to make it run without error
@rule ChanceConstraint(:x, Marginalisation) (q_out::NormalMeanVariance, q_lo::PointMass, 
        q_hi::PointMass, q_epsilon::PointMass, q_atol::PointMass) = begin 
    #@infiltrate; @assert false
    return NormalMeanVariance(mean_var(q_out)...) 
end

@rule ChanceConstraint(:x, Marginalisation) (q_out::NormalWeightedMeanPrecision, q_lo::PointMass, 
        q_hi::PointMass, q_epsilon::PointMass, q_atol::PointMass) = begin 
    #@infiltrate; @assert false
    return NormalMeanVariance(mean_var(q_out)...) 
end

@rule ChanceConstraint(:out, Marginalisation) (q_x::NormalWeightedMeanPrecision, q_lo::PointMass, 
        q_hi::PointMass, q_epsilon::PointMass, q_atol::PointMass) = begin 
    #@infiltrate; @assert false
    return NormalMeanVariance(mean_var(q_x)...) 
end


# --------------------------------------------------------------------------------------------------
function createWorld(; v_wind)
    # functions for interacting with the simulated environment.
        
    x_0 = 0.0 # Initial position, drone elevation
    x_t0 = x_0
    x_t1 = x_0
    
    execute = (action_t::Float64, m_wind_t::Float64) -> begin
        # Execute the action
        x_t1 = x_t0 + action_t + m_wind_t + sqrt(v_wind) * randn()  # Update elevation
        x_t0 = x_t1 # Prepare for next step
    end
    
    observe = () -> begin 
        return x_t0  # Observe the current state
    end
    
    return (execute, observe)
end


# --------------------------------------------------------------------------------------------------
@model function chance_constraint_model(; T, lo, hi, epsilon, atol)
        
    # the mean and variance of observation at the last t
    m_x_t_last = datavar(Float64)
    v_x_t_last = datavar(Float64)
    
    # sample the observation at the last t
    x_t_last ~ GaussianMeanVariance(m_x_t_last, v_x_t_last) 
    x_k_last = x_t_last
    
    # control
    m_u = datavar(Float64, T)
    v_u = datavar(Float64, T)
    
    # obs
    m_x = datavar(Float64, T)
    v_x = datavar(Float64, T) 
        
    # wind
    m_w = datavar(Float64, T)
    v_w = datavar(Float64, T)
    
    # random variables
    u = randomvar(T)  # control
    x = randomvar(T)  # height
    uw = randomvar(T)  # control + wind variance
    x_dummy = randomvar(T)
    
    # loop over horizon
    for k = 1:T
        x[k] ~ GaussianMeanVariance(m_x[k], v_x[k])   # x[1] = agent_obs_t, else vague goal for k>1
        u[k] ~ NormalMeanVariance(m_u[k], v_u[k])   # control prior
        uw[k] ~ NormalMeanVariance(u[k], v_w[k])    # control + wind variance
        x_dummy[k] ~ x_k_last + uw[k] + m_w[k]
        x[k] ~ ChanceConstraint(x_dummy[k], lo, hi, epsilon, atol)   where { q = MeanField() }
        x_k_last = x[k]
    end
    
    return (u, uw, x, x_t_last, x_dummy)
end


# --------------------------------------------------------------------------------------------------
function createAgent(; T, fx_m_wind, v_wind, m_goal, v_goal, lambda, n_iters, lo, hi, epsilon, atol)
    
    # control prior
    m_u = Vector{Float64}([0.0 for k=1:T ])
    v_u = Vector{Float64}([lambda^(-1) for k=1:T])
    
    # goal
    m_x = Vector{Float64}([0. for k=1:T])  
    v_x = Vector{Float64}([huge for k=1:T])
    
    # wind
    m_w = Vector{Float64}([0.0 for k=1:T])
    v_w = Vector{Float64}([v_wind for k=1:T])
    
    # initial position and variance
    m_x_t_last = 0.0
    v_x_t_last = convert(Float64, tiny)
    
    # Set current inference results
    result = nothing
    infer = (action_t::Float64, obs_t::Float64, t::Int) -> begin
        m_w[:] = [fx_m_wind(t+k) for k in 0:T-1]  # mean wind
        m_u[1] = action_t  # register action with the generative model
        v_u[1] = tiny  # clamp control prior to performed action
        m_x[1] = obs_t  # register observation with the generative model
        v_x[1] = tiny  # clamp goal prior to observation
                
        data = Dict(
            :m_u => m_u, 
            :v_u => v_u, 
            :m_x => m_x, 
            :v_x => v_x,
            :m_w => m_w,
            :v_w => v_w,
            :m_x_t_last => m_x_t_last,
            :v_x_t_last => v_x_t_last
        )
                
        model, (u, uw, x, x_t_last, x_dummy) = create_model(chance_constraint_model(; T=T, lo=lo,
            hi=hi, epsilon=epsilon, atol=atol))
        
        #bfe_observable = score(model, Float64, BetheFreeEnergy())
        #bfe_subscription = subscribe!(bfe_observable, (fe) -> println("Current BFE value: ", fe));

        setmarginal!.(u, NormalMeanVariance(0.0, huge)) 
        setmarginal!.(uw, NormalMeanVariance(0.0, huge))
        setmarginal!.(x, NormalMeanVariance(0.0, huge))
        setmarginal!.(x_t_last, NormalMeanVariance(0.0, huge))
        setmarginal!.(x_dummy, NormalMeanVariance(0.0, huge))
                
        u_values = keep(Vector{Marginal})
        u_subscription = subscribe!(getmarginals(u), u_values)
        
        uw_values = keep(Vector{Marginal})
        uw_subscription = subscribe!(getmarginals(uw), uw_values)
        
        x_values = keep(Vector{Marginal})
        x_subscription = subscribe!(getmarginals(x), x_values)
        
        (Phi_G, m_tr, V_tr) = (nothing, nothing, nothing)
        ii = nothing
        for i in 1:n_iters
            ii = i
            update!(model[:m_x], data[:m_x])
            update!(model[:v_x], data[:v_x])
            
            update!(model[:m_u], data[:m_u])
            update!(model[:v_u], data[:v_u])
            
            update!(model[:m_w], data[:m_w])
            update!(model[:v_w], data[:v_w])
            
            update!(model[:m_x_t_last], data[:m_x_t_last])
            update!(model[:v_x_t_last], data[:v_x_t_last])
            
            
            # Check convergence
            (m_b, v_b) = mean_var(x_values[end][end])
            (Phi_G, m_tr, V_tr) = truncatedGaussianMoments(m_b, v_b, lo, hi)
               
            if (i > 10) && ((1.0 - Phi_G) < (1.0 + atol) * epsilon)  # same test as in chance constraint
                break # Break the loop if the belief is sufficiently corrected
            end
        end
        result = (
            u = u_values,
            uw = uw_values,
            x = x_values,
            )
        
        unsubscribe!([u_subscription, uw_subscription, x_subscription])
        #@infiltrate; @assert false
        return result
    end
    

    # The `future` function returns the inferred future states
    future = () -> begin 
        if result !== nothing 
            return mode.(result.x[end])
        else
            return zeros(T)
        end
    end


    # slide the time frame forward
    slide = () -> begin
        # no slide is necessary
    end
    
    
    # The `act` function returns the inferred best possible action
    act = () -> begin
        if result !== nothing
            return mode(result.u[end][2])  # return the next (k = 2) action
        else
            return 0.0 # without an inference result, return a default action
        end
    end
    
    return (infer, act, slide, future)    
end


# --------------------------------------------------------------------------------------------------
function main()
    # Simulation parameters
    L = 20
    T = 10
    v_wind = 0.2
    m_goal = 2.0
    v_goal = 0.18478
    lambda = 0.01  # control prior precision
    lo = 1.0
    hi = Inf
    epsilon = 0.01  # unsafe mass
    atol = 0.01 # Convergence tolerance for chance constraints
    n_iters = 100  # number of inference iterations
    N = 100  # number of trials
    
    if false    
        # to plot the control law --------------
        fx_m_wind_law(t::Int64) = 0.0 # Wind mean as function of time t
        (execute_ai, observe_ai) = createWorld(
            v_wind = tiny,
        ) 
        # Let there be an agent
        (infer_ai, act_ai, slide_ai, future_ai) = createAgent(; 
            T  = T, 
            fx_m_wind = fx_m_wind_law,
            v_wind = tiny,
            m_goal = m_goal, 
            v_goal = v_goal, 
            lambda = lambda,
            n_iters = n_iters,
            lo = lo,
            hi = hi,
            epsilon =  epsilon,
            atol = atol,
        ) 
        x_hat = -1.0:0.1:5.0 # trial elevations for rendering actions
        K = length(x_hat)
        a = Vector{Float64}(undef, K)  # actions
        f = Vector{Float64}(undef, K)  # next futures at k=2
        for k = 1:K
            infer_ai(0.0, x_hat[k], 1) # Perform inference under wind mean of zero
            a[k] = act_ai() # Extract corresponding action
            f[k] = future_ai()[2] 
        end
        
        # actions a vs. x_hat should look like control law graph, and all f should be > 1.0
        plotControlLaw(x_hat, a, f)
        #@infiltrate; @assert false
    end
    
    # Step through experimental protocol
    fx_m_wind(t::Int64) = 5<=t<10 || 15<=t<17  ? -1.0 : 0.0 # Wind mean as function of time t
    actions =  Matrix{Float64}(undef, L, N) 
    observations = Matrix{Float64}(undef, L, N) 
    
    for n in 1:N
        # Let there be a world
        (execute_ai, observe_ai) = createWorld(
            v_wind = v_wind,
        ) 

        # Let there be an agent
        (infer_ai, act_ai, slide_ai, future_ai) = createAgent(; 
            T  = T, 
            fx_m_wind = fx_m_wind,
            v_wind = v_wind,
            m_goal = m_goal, 
            v_goal = v_goal, 
            lambda = lambda,
            n_iters = n_iters,
            lo = lo,
            hi = hi,
            epsilon =  epsilon,
            atol = atol,
        ) 
        
        for t=1:L
            actions[t,n] = act_ai()  # invoke an action from the agent
            futures = future_ai()  # fetch the predicted future states
            execute_ai(actions[t,n], fx_m_wind(t))  # the action influences hidden external states
            observations[t,n] = observe_ai()  # observe the current environmental outcome
            infer_ai(actions[t,n], observations[t,n], t)  # infer beliefs from current model state
            slide_ai()  # prepare for next iteration
            #@infiltrate; @assert false
        end
    end
        
    E_avg = round(mean(sum(actions.^2, dims=1)), digits=2) # Average quadratic cost of control (over all runs)
    @show E_avg
    
    rate = sum(observations[2:end,:] .< 1.0) / length(observations[2:end,:])
    @show mean(rate)
    #@infiltrate; @assert false
    
    per_n = ceil(Int, N/100) # Plot one in every per_n trajectories
    plotTrajectory(observations, actions, per_n, fx_m_wind, epsilon)
    #@infiltrate; @assert false
end


# --------------------------------------------------------------------------------------------------
function plotTrajectory(obs, act, per_n, fx_m_wind, epsilon)

    (L,N) = size(obs)
    
    # mean of wind
    p1 = Plots.scatter(0:L, fx_m_wind.(0:L), color="black", xlim= (0,L), grid=true, ylabel="E[w]", legend=false)
    
    # Trajectory per run
    p2 = Plots.plot(legend=false)
    Plots.hline!(p2, [1.0], color="red", ls=:dash)
    for n=1:per_n:N
        Plots.plot!(1:L, obs[:,n], color="black", alpha=0.1)
    end
    Plots.plot!(xlim=(1,L), ylim=(-5.5, 5.5), grid=true, ylabel="Elevation (x)")
    
    # Control signal per run
    p3 = Plots.plot(legend=false)
    for n=1:per_n:N
        Plots.plot!(1:L, act[:,n], color="black", alpha=0.1)
    end
    Plots.plot!(xlim= (1,L), ylim=(-1.5, 2.5), grid=true, ylabel= "Control Signal (a)")
    
    # Violation ratio (of runs) over time
    p4= Plots.plot(legend=false)
    Plots.hline!(p4, [epsilon], color="red", ls=:dash)
    r = vec(mean(obs .< 1.0, dims=2))
    Plots.scatter!(1:L, r, color="black")
    Plots.plot!(xlim=(1,L), ylim=(0, 0.05), grid=true, xlabel="Time (t)", ylabel="Target Violation Ratio")
    
    Plots.plot(p1,p2,p3,p4, layout=(4,1), size=(800,1000), left_margin=5Plots.mm, )
    Plots.savefig("./sim_chance.png")
    #@infiltrate; @assert false
end


# --------------------------------------------------------------------------------------------------
function plotControlLaw(x_hat, a, f)
    Plots.plot(x_hat, a, color=:green, grid="on", xlabel= "Elevation (x)", ylabel= "Value", 
        xlim= extrema(x_hat), ylim= extrema(a), label="action")
    Plots.vline!([1.0], color=:red, ls=:dash, label=false)
    Plots.plot!(x_hat, f, color=:blue, label="obs")
    Plots.savefig("control_law_chance.png")
end



end  # module ----------------------------------------------------------

RXChance.main()

Thanks for the input @John_Boik. Based on your code I’ve implemented a demo for chance-constrained active inference in RxInfer. For simplicity I assumed a fully observed Markov decision process. You can check it out here.

Wonderful @ThijsvdLaar! I see that the big item I was missing was the proper form for:

x[k] ~ ChanceConstraint(lo, hi, epsilon, atol) where { # Simultaneous constraint on state
            pipeline = RequireMessage(out = NormalWeightedMeanPrecision(0, 0.01))} # Predefine inbound message to break circular dependency

and the associated line in the rule:

m_out::UnivariateNormalDistributionsFamily, # Require inbound message

Your help is greatly appreciated.