I am trying to implement the van de Laar chance constraint model in RxInfer, and as a warm-up I have implemented the reference model from the same paper. Code is included at the end of this post that produces an output graph closely resembling the published one. I used the mountain car example as a template. This post is a continuation of a GitHub issue for RxInfer.

While the code runs fine, I have questions about modifying it to achieve two goals. Any thoughts, @bvdmitri or others?

First, I would like to use multiple iterations in the inference process, the equivalent of, e.g., `iterations = 10` in the `inference` function. But I want to stop the inference procedure after an iteration if a condition is met. This is what the chance constraint code (written for ForneyLab.jl) does.
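
Schematically, the control flow I am after looks like this (pseudocode only; `perform_iteration!` and `stop_condition` are hypothetical stand-ins, not real RxInfer functions):

```
# Desired control flow (pseudocode; perform_iteration! and stop_condition
# are stand-ins, not RxInfer functions):
for i in 1:n_iterations
    perform_iteration!(model, data)            # one variational iteration
    stop_condition(current_marginals) && break # halt early once the condition is met
end
```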

I see two possible ways to control iterations: (1) use a loop containing `update!()`; or (2) use a callback. I discuss issues with both below.

The Advanced Tutorial section of the RxInfer documentation appears to implement inference iterations by stepping through a loop, where in each iteration `update!(y, dataset)` is called. In that example, `y` is a vector of datavars returned from `create_model()`, and `dataset` is a vector of `Float64`.
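
My reading of that tutorial pattern is roughly the following (the model name and variables here are placeholders for the tutorial's own):

```
# Rough shape of the Advanced Tutorial loop, as I understand it;
# `my_model` is a placeholder for the tutorial's model
model, (x, y) = create_model(my_model(n))
x_marginals = keep(Marginal) # accumulate posterior marginals of x
subscription = subscribe!(getmarginal(x), x_marginals)
for i in 1:10
    update!(y, dataset) # y: vector of datavars, dataset: Vector{Float64}
end
unsubscribe!(subscription)
```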

But this is not the way my `infer` function or problem is structured. I set some values in `Float64` vectors (e.g., `m_u[1] = agent_action_t`). I could return a datavar vector `m_u` from the model and then grab it in my `infer` function using a `create_model` statement, but then I couldn't update `m_u[1]` manually as required.
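
For concreteness, the closest equivalent I can picture is returning the datavars and updating them one vector at a time, along these lines (a sketch; the return signature of `create_model` here is my assumption):

```
# Sketch (assumed return signature): mutate the plain Float64 vectors first,
# then push each one into its corresponding datavar
model, (m_u_var, v_u_var) = create_model(ref_model(; T = T))
m_u[1] = agent_action_t # the manual edit happens on the plain vector
update!(m_u_var, m_u)   # then the whole vector is pushed into the datavar
update!(v_u_var, v_u)
```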

What I want is to call `update!` with something like `update!(data)`, where `data` is the dict that contains, for example, `(:m_u => m_u)`. But that doesn't work.
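
That is, concretely:

```
# What I would like to write; as far as I can tell, no such update! method exists
data = Dict(:m_u => m_u, :v_u => v_u, :m_x => m_x, :v_x => v_x,
            :m_w => m_w, :v_w => v_w,
            :m_x_t_last => m_x_t_last, :v_x_t_last => v_x_t_last)
update!(data) # errors
```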

An alternative would be to use a callback in the `inference` function. Building on an example callback in the documentation, the following runs if I first create `function cb(model, name, update) ... end`:

```
result = inference(
    ...,
    callbacks = (
        on_marginal_update = (model, name, update) -> cb(model, name, update),
    )
)
```
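
where `cb` can be as trivial as:

```
# A minimal callback that just reports which marginal was updated
function cb(model, name, update)
    println("Marginal updated: ", name, " -> ", update)
end
```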

But from the documentation, it appears that there are no callback choices (e.g., `before_data_update`) that would allow for halting iterations.

Any ideas on how to use `update!` or callbacks in my situation, or how to otherwise halt iterations based on some condition?

Second, I would like to print out free energy values as is done in the Advanced Tutorial example. It uses something akin to:

```
model2, returnval = create_model(model)
bfe_observable = score(model2, Float64, BetheFreeEnergy())
bfe_subscription = subscribe!(bfe_observable, (fe) -> println("Current BFE value: ", fe));
```
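
My assumption is that the subscription only fires for free energy values computed after `subscribe!` is called, so it would have to be set up before the iterations run and torn down afterwards, something like (a sketch; `y` and `dataset` stand in for this model's datavars and data):

```
# Assumed placement of the BFE subscription around the iteration loop (sketch)
bfe_observable = score(model2, Float64, BetheFreeEnergy())
bfe_subscription = subscribe!(bfe_observable, (fe) -> println("Current BFE value: ", fe))
for i in 1:n_iterations
    update!(y, dataset) # each iteration should emit one BFE value, if I understand correctly
end
unsubscribe!(bfe_subscription)
```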

This code is commented out in my `infer` function. If I remove the comments, nothing is printed. I'm just not sure where this code is supposed to go, or whether I have specified it correctly.

In trying to follow the example in the Advanced Tutorial, I assume that my final `infer` function, after addressing the above two questions and if I use `update!`, would look something like:

```
infer = (agent_action_t::Float64, agent_obs_t::Float64, t::Int) -> begin
    m_w[:] = [fx_m_wind(t+k) for k in 0:T-1] # mean wind
    m_u[1] = agent_action_t # register action with the generative model
    v_u[1] = tiny           # clamp control prior to performed action
    m_x[1] = agent_obs_t    # register observation with the generative model
    v_x[1] = tiny           # clamp goal prior to observation
    data = Dict(
        :m_u => m_u,
        :v_u => v_u,
        :m_x => m_x,
        :v_x => v_x,
        :m_w => m_w,
        :v_w => v_w,
        :m_x_t_last => m_x_t_last,
        :v_x_t_last => v_x_t_last
    )
    model, (u, uw, x, x_t_last) = create_model(ref_model(; T=T))
    bfe_observable = score(model, Float64, BetheFreeEnergy())
    bfe_subscription = subscribe!(bfe_observable, (fe) -> println("Current BFE value: ", fe));
    setmarginal!.(u, NormalMeanVariance(0.0, huge))
    setmarginal!.(uw, NormalMeanVariance(0.0, huge))
    setmarginal!.(x, NormalMeanVariance(0.0, huge))
    setmarginal!.(x_t_last, NormalMeanVariance(0.0, huge))
    u_values = keep(Marginal)
    uw_values = keep(Marginal)
    x_values = keep(Marginal)
    x_t_last_values = keep(Marginal)
    u_subscription = [subscribe!(getmarginal(u[i]), u_values) for i in 1:T]
    uw_subscription = [subscribe!(getmarginal(uw[i]), uw_values) for i in 1:T]
    x_subscription = [subscribe!(getmarginal(x[i]), x_values) for i in 1:T]
    x_t_last_subscription = subscribe!(getmarginal(x_t_last), x_t_last_values)
    for i in 1:n_iterations
        update!(data)
        # now perform some test to halt iterations
    end
    return result # where result is a collection of marginals
end
```

Is that about correct? The above does not work, of course, as `update!` does not work. And I'm not sure it would print out free energy as desired. Once I have answers to my two questions, I will try to tackle the harder problem of how to specify the `@node` and `@rules` code for the chance constraint model.
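
(For later reference, and based only on my reading of the ReactiveMP documentation, I expect the skeleton to look roughly like the following, where `ChanceConstraint` and the edge names are placeholders:)

```
# Placeholder skeleton for a custom node, per my reading of the ReactiveMP docs;
# the actual chance-constraint message computations are still to be worked out
struct ChanceConstraint end

@node ChanceConstraint Stochastic [out, x]

@rule ChanceConstraint(:x, Marginalisation) (m_out::NormalMeanVariance,) = begin
    # TODO: compute the outgoing chance-constraint message here
    return NormalMeanVariance(mean(m_out), var(m_out)) # placeholder passthrough
end
```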

The code that does work is below. It can be run by calling `include("./RXRef.jl")` in the REPL, if the script is saved with that filename. In brief, the goal is to keep a drone above a target elevation, where the drone is subject to a vertical wind.

As an aside, if I change the code below to, e.g., `iterations = 10` in the `inference` function, the printout of `E_avg` at the end does not change from when `iterations = 1`. So I'm not even sure that running multiple iterations is doing anything.

```
module RXRef

import Plots
import Random
# use some private functionality from ReactiveMP
import RxInfer.ReactiveMP: getrecent, messageout
using Formatting
using Infiltrator
using Revise
using RxInfer

Random.seed!(51233) # set random seed for reproducibility

Plots.scalefontsizes()
Plots.scalefontsizes(0.8)
# ==================================================================================================
# --------------------------------------------------------------------------------------------------
function createWorld(; v_wind)
    # functions for interacting with the simulated environment
    x_0 = 0.0 # initial position, drone elevation
    x_t0 = x_0
    x_t1 = x_0
    execute = (action_t::Float64, m_wind_t::Float64) -> begin
        # Execute the action
        # action_t = ascension velocity, dx = action_t * t, with t = 1
        x_t1 = x_t0 + action_t + m_wind_t + sqrt(v_wind) * randn() # update elevation
        x_t0 = x_t1 # prepare for next step
    end
    observe = () -> begin
        return x_t0 # observe the current state
    end
    return (execute, observe)
end
# --------------------------------------------------------------------------------------------------
@model function ref_model(; T)
    # during each training step, from t to t+1, this function is called only once
    # mean and variance of the observation at the last t
    m_x_t_last = datavar(Float64)
    v_x_t_last = datavar(Float64)
    # sample the observation at the last t
    x_t_last ~ GaussianMeanVariance(m_x_t_last, v_x_t_last)
    x_k_last = x_t_last
    # control
    m_u = datavar(Float64, T)
    v_u = datavar(Float64, T)
    # obs
    m_x = datavar(Float64, T)
    v_x = datavar(Float64, T)
    # wind
    m_w = datavar(Float64, T)
    v_w = datavar(Float64, T)
    # random variables
    u = randomvar(T)  # control
    x = randomvar(T)  # height
    uw = randomvar(T) # control + wind variance
    # loop over horizon
    for k = 1:T
        x[k] ~ GaussianMeanVariance(m_x[k], v_x[k]) # goal prior
        u[k] ~ GaussianMeanVariance(m_u[k], v_u[k]) # control prior
        uw[k] ~ GaussianMeanVariance(u[k], v_w[k])  # control + wind variance
        x[k] ~ x_k_last + uw[k] + m_w[k]
        x_k_last = x[k]
    end
    return (x,)
end
# --------------------------------------------------------------------------------------------------
function createAgent(; T, fx_m_wind, v_wind, m_goal, v_goal, lambda)
    # control prior
    m_u = Vector{Float64}([0.0 for k=1:T])
    v_u = Vector{Float64}([lambda^(-1) for k=1:T])
    # goal
    m_x = Vector{Float64}([m_goal for k=1:T])
    v_x = Vector{Float64}([v_goal for k=1:T])
    # wind
    m_w = Vector{Float64}([0.0 for k=1:T])
    v_w = Vector{Float64}([v_wind for k=1:T])
    # initial position and variance
    m_x_t_last = 0.0
    v_x_t_last = convert(Float64, tiny)
    # current inference results
    result = nothing
    infer = (agent_action_t::Float64, agent_obs_t::Float64, t::Int) -> begin
        m_w[:] = [fx_m_wind(t+k) for k in 0:T-1] # mean wind
        m_u[1] = agent_action_t # register action with the generative model
        v_u[1] = tiny           # clamp control prior to performed action
        m_x[1] = agent_obs_t    # register observation with the generative model
        v_x[1] = tiny           # clamp goal prior to observation
        data = Dict(
            :m_u => m_u,
            :v_u => v_u,
            :m_x => m_x,
            :v_x => v_x,
            :m_w => m_w,
            :v_w => v_w,
            :m_x_t_last => m_x_t_last,
            :v_x_t_last => v_x_t_last
        )
        model = ref_model(; T=T)
        #model2, returnval = create_model(model)
        #bfe_observable = score(model2, Float64, BetheFreeEnergy())
        #bfe_subscription = subscribe!(bfe_observable, (fe) -> println("Current BFE value: ", fe));
        result = inference(
            model = model,
            data = data,
            initmarginals = (
                u = NormalMeanVariance(0.0, huge),
                uw = NormalMeanVariance(0.0, huge),
                x = NormalMeanVariance(0.0, huge),
                x_t_last = NormalMeanVariance(0.0, huge),
            ),
            iterations = 1,
            free_energy = true,
            #showprogress = true,
            returnvars = (
                u = KeepLast(),
                uw = KeepLast(),
                x_t_last = KeepLast(),
                x = KeepLast()
            ),
        )
        #unsubscribe!([bfe_subscription, ])
    end
    # The `act` function returns the inferred best possible action
    act = () -> begin
        if result !== nothing
            return mode(result.posteriors[:u][2]) # index 2 for next step, when iterations=1
        else
            return 0.0 # without an inference result we return some 'random' action
        end
    end
    # The `future` function returns the inferred future states
    future = () -> begin
        if result !== nothing
            return getindex.(mode.(result.posteriors[:x]), 1) # for iterations=1
        else
            return zeros(T)
        end
    end
    slide = () -> begin
        (x, ) = result.returnval
        slide_msg_idx = 3 # this index is model dependent
        (m_x_t_last, v_x_t_last) = mean_var(getrecent(messageout(x[2], slide_msg_idx)))
        # these are not actually necessary for this simple problem, as the vectors do not change
        m_u = circshift(m_u, -1)
        m_u[end] = 0.0
        v_u = circshift(v_u, -1)
        v_u[end] = lambda^(-1)
        m_x = circshift(m_x, -1)
        m_x[end] = m_goal
        v_x = circshift(v_x, -1)
        v_x[end] = v_goal
    end
    return (infer, act, slide, future)
end
# --------------------------------------------------------------------------------------------------
function plotTrajectory(obs, act, per_n, fx_m_wind, epsilon)
    (L, N) = size(obs)
    # mean of wind
    p1 = Plots.scatter(0:L, fx_m_wind.(0:L), color="black", xlim=(0,L), grid=true, ylabel="E[w]", legend=false)
    # trajectory per run
    p2 = Plots.plot(legend=false)
    Plots.hline!(p2, [1.0], color="red", ls=:dash)
    for n=1:per_n:N
        Plots.plot!(1:L, obs[:,n], color="black", alpha=0.1)
    end
    Plots.plot!(xlim=(1,L), ylim=(-5.5, 5.5), grid=true, ylabel="Elevation (x)")
    # control signal per run
    p3 = Plots.plot(legend=false)
    for n=1:per_n:N
        Plots.plot!(1:L, act[:,n], color="black", alpha=0.1)
    end
    Plots.plot!(xlim=(1,L), ylim=(-1.5, 2.5), grid=true, ylabel="Control Signal (a)")
    # violation ratio (of runs) over time
    p4 = Plots.plot(legend=false)
    Plots.hline!(p4, [epsilon], color="red", ls=:dash)
    r = vec(mean(obs .< 1.0, dims=2))
    Plots.scatter!(1:L, r, color="black")
    Plots.plot!(xlim=(1,L), ylim=(0, 0.05), grid=true, xlabel="Time (t)", ylabel="Target Violation Ratio")
    Plots.plot(p1, p2, p3, p4, layout=(4,1), size=(800,1000), left_margin=5Plots.mm)
    Plots.savefig("./sim_reference.png")
    #@infiltrate; @assert false
end
# --------------------------------------------------------------------------------------------------
function main()
    #=
    No hidden states: the agent directly observes its elevation and the wind velocity mean/variance.
    x      = elevation (e.g., meters)
    T      = time for look-ahead horizon (e.g., seconds)
    action = selected control value, vertical velocity (e.g., m/s)
    t      = current moment
    k      = time steps in look-ahead horizon, k = t:t+T-1
    m_w    = mean of vertical wind velocity (e.g., m/s)
    v_w    = variance of vertical wind velocity (e.g., m^2)
    u      = control variable = action = vertical velocity (e.g., m/s)
    L      = simulation time (e.g., seconds)
    m_     = mean of _
    v_     = variance of _
    =#
    # Simulation parameters
    L = 20
    T = 9
    v_wind = 0.2
    m_goal = 2.0
    v_goal = 0.18478
    lambda = 0.01  # control prior precision
    epsilon = 0.01 # unsafe mass
    fx_m_wind(t::Int64) = 5 <= t < 10 ? -1.0 : 0.0 # wind mean as a function of time t
    N = 100 # number of trials
    # Step through the experimental protocol
    actions = Matrix{Float64}(undef, L, N)
    observations = Matrix{Float64}(undef, L, N)
    for n in 1:N
        # Let there be a world
        (execute_ai, observe_ai) = createWorld(
            v_wind = v_wind,
        )
        # Let there be an agent
        (infer_ai, act_ai, slide_ai, future_ai) = createAgent(;
            T = T,
            fx_m_wind = fx_m_wind,
            v_wind = v_wind,
            m_goal = m_goal,
            v_goal = v_goal,
            lambda = lambda,
        )
        for t = 1:L
            actions[t,n] = act_ai()                      # invoke an action from the agent
            futures = future_ai()                        # fetch the predicted future states
            execute_ai(actions[t,n], fx_m_wind(t))       # the action influences hidden external states
            observations[t,n] = observe_ai()             # observe the current environmental outcome
            infer_ai(actions[t,n], observations[t,n], t) # infer beliefs from the current model state
            slide_ai()                                   # prepare for the next iteration
        end
    end
    E_avg = round(mean(sum(actions.^2, dims=1)), digits=2) # average quadratic cost of control (over all runs)
    @show E_avg
    per_n = ceil(Int, N/100) # plot one in every per_n trajectories
    plotTrajectory(observations, actions, per_n, fx_m_wind, epsilon)
end
end # module ----------------------------------------------------------
RXRef.main()
```