I am trying to implement the van de Laar chance-constraint model in RxInfer, and as a warm-up I have implemented the reference model from the same paper. Code that produces an output graph closely resembling the published one is included at the end of this post. I used the mountain car example as a template. This post is a continuation of a GitHub issue for RxInfer.
While the code runs fine, I have questions about modifying it to achieve two goals. Any thoughts, @bvdmitri or others?
First, I would like to use multiple iterations in the inference process, the equivalent of, e.g., `iterations = 10` in the `inference` function, but I want to stop the inference procedure after an iteration if a condition is met. This is what the chance-constraint code (written for ForneyLab.jl) does.
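Conceptually, the behavior I am after is something like the following (pure pseudocode, not working RxInfer code):

```julia
# Pseudocode only: run up to n_iterations inference sweeps, but break out early
# once some condition on the current results is satisfied (this is what the
# ForneyLab.jl chance-constraint code does).
for i in 1:n_iterations
    # ... perform one inference iteration ...
    if stopping_condition_met()   # hypothetical check, e.g. a convergence test
        break
    end
end
```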
I see two possible ways to control iterations: (1) use a loop containing `update!()`; or (2) use a callback. I discuss issues with both below.
The Advanced Tutorial section of the RxInfer documentation appears to implement inference iterations by stepping through a loop in which `update!(y, dataset)` is called on each iteration. In that example, `y` is a vector of datavars returned from `create_model()`, and `dataset` is a vector of `Float64`.
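A minimal sketch of that pattern as I understand it (the model name and variables here are illustrative, not mine):

```julia
# Sketch of the Advanced Tutorial pattern: `y` is the vector of datavars returned
# by the @model function, `dataset` is a plain Vector{Float64} of observations.
model, (y,) = create_model(my_model(length(dataset)))
for i in 1:n_iterations
    update!(y, dataset)   # feeding the datavars triggers one inference sweep
end
```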
But this is not the way my `infer` function or problem is structured. I set some values in `Float64` vectors (e.g., `m_u[1] = agent_action_t`). I could return a datavar vector `m_u` from the model and then grab it in my `infer` function using a `create_model` statement, but then I couldn't update `m_u[1]` manually as required.
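For concreteness, that variant would look roughly like this (a sketch only; it assumes `ref_model` is changed to return the `m_u` datavar vector, and `m_u_var` is just an illustrative name):

```julia
# Sketch of the variant described above: return the datavar vector from the model
# and push values into it with update!.
model, (m_u_var,) = create_model(ref_model(; T = T))
update!(m_u_var, m_u)   # m_u is my plain Float64 vector of control prior means
# (the concern above: I would still need a way to clamp m_u[1] to the performed action)
```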
What I want is to call `update!` with something like `update!(data)`, where `data` is the dict that contains, for example, `:m_u => m_u`. But that doesn't work.
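In other words, something like this (the `update!(data)` call below is hypothetical; as far as I can tell no such method exists):

```julia
# Hypothetical call I would like to make: update every datavar in the model from
# a Dict of plain Float64 values (the same dict I build in my infer function).
data = Dict(:m_u => m_u, :v_u => v_u, :m_x => m_x, :v_x => v_x,
            :m_w => m_w, :v_w => v_w,
            :m_x_t_last => m_x_t_last, :v_x_t_last => v_x_t_last)
update!(data)   # <- does not work
```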
An alternative would be to use a callback in the `inference` function. Building on an example callback in the documentation, the following runs if I first define `function cb(model, name, update) ... end`:
```julia
result = inference(
    ...,
    callbacks = (
        on_marginal_update = (model, name, update) -> cb(model, name, update),
    )
)
```
But from the documentation, it appears that there are no callback choices (e.g., `before_data_update`) that would allow halting the iterations.
Any ideas on how to use `update!` or callbacks in my situation, or how to otherwise halt iterations based on some condition?
Second, I would like to print out free energy values as is done in the Advanced Tutorial example. It uses something akin to:
```julia
model2, returnval = create_model(model)
bfe_observable = score(model2, Float64, BetheFreeEnergy())
bfe_subscription = subscribe!(bfe_observable, (fe) -> println("Current BFE value: ", fe));
```
This code is commented out in my `infer` function. If I remove the comments, nothing is printed. I'm just not sure where this code is supposed to go, or whether I have specified it correctly.
In trying to follow the example in the Advanced Tutorial, I assume that my final `infer` function, after addressing the above two questions and if I use `update!`, would look something like:
```julia
infer = (agent_action_t::Float64, agent_obs_t::Float64, t::Int) -> begin
    m_w[:] = [fx_m_wind(t+k) for k in 0:T-1]   # mean wind
    m_u[1] = agent_action_t                    # register action with the generative model
    v_u[1] = tiny                              # clamp control prior to performed action
    m_x[1] = agent_obs_t                       # register observation with the generative model
    v_x[1] = tiny                              # clamp goal prior to observation
    data = Dict(
        :m_u => m_u,
        :v_u => v_u,
        :m_x => m_x,
        :v_x => v_x,
        :m_w => m_w,
        :v_w => v_w,
        :m_x_t_last => m_x_t_last,
        :v_x_t_last => v_x_t_last
    )
    model, (u, uw, x, x_t_last) = create_model(ref_model(; T=T))
    bfe_observable = score(model, Float64, BetheFreeEnergy())
    bfe_subscription = subscribe!(bfe_observable, (fe) -> println("Current BFE value: ", fe));
    setmarginal!.(u, NormalMeanVariance(0.0, huge))
    setmarginal!.(uw, NormalMeanVariance(0.0, huge))
    setmarginal!.(x, NormalMeanVariance(0.0, huge))
    setmarginal!.(x_t_last, NormalMeanVariance(0.0, huge))
    u_values        = keep(Marginal)
    uw_values       = keep(Marginal)
    x_values        = keep(Marginal)
    x_t_last_values = keep(Marginal)
    u_subscription        = [subscribe!(getmarginal(u[i]), u_values) for i in 1:T]
    uw_subscription       = [subscribe!(getmarginal(uw[i]), uw_values) for i in 1:T]
    x_subscription        = [subscribe!(getmarginal(x[i]), x_values) for i in 1:T]
    x_t_last_subscription = subscribe!(getmarginal(x_t_last), x_t_last_values)
    for i in 1:n_iterations
        update!(data)
        # now perform some test to halt iterations
    end
    return result   # where result is a collection of marginals
end
```
Is that about correct? The above does not work, of course, since that `update!` call does not work. And I'm not sure it would print out free energy as desired. Once I have answers to my two questions, I will try to tackle the harder problem of how to specify the `@node` and `@rules` code for the chance-constraint model.
The code that does work is below. It can be run with `include("./RXRef.jl")` in the REPL, if the script is saved under that filename. In brief, the goal is to keep a drone above a target elevation, where the drone is subject to a vertical wind.
As an aside, if I change the code below to, e.g., `iterations = 10` in the `inference` call, the printout of `E_avg` at the end does not change from when `iterations = 1`. So I'm not even sure that running multiple iterations is doing anything.
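One check I am considering, assuming `result.free_energy` holds one Bethe free energy value per iteration when `free_energy = true` is passed to `inference` (please correct me if that assumption is wrong):

```julia
# Hypothetical check, placed right after the inference(...) call in createAgent:
# if the per-iteration BFE values stop changing after the first sweep, the extra
# iterations are not changing the marginals for this model.
println("BFE per iteration: ", result.free_energy)
```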
```julia
module RXRef
import Plots
import Random
# use some private functionality from ReactiveMP
import RxInfer.ReactiveMP: getrecent, messageout
using Formatting
using Infiltrator
using Revise
using RxInfer
Random.seed!(51233) # Set random seed for reproducibility
Plots.scalefontsizes()
Plots.scalefontsizes(0.8)
# ==================================================================================================
# --------------------------------------------------------------------------------------------------
function createWorld(; v_wind)
# functions for interacting with the simulated environment.
x_0 = 0.0 # Initial position, drone elevation
x_t0 = x_0
x_t1 = x_0
execute = (action_t::Float64, m_wind_t::Float64) -> begin
# Execute the action
# action_t = ascension velocity; dx = action_t * dt, with dt = 1
x_t1 = x_t0 + action_t + m_wind_t + sqrt(v_wind) * randn() # Update elevation
x_t0 = x_t1 # Prepare for next step
end
observe = () -> begin
return x_t0 # Observe the current state
end
return (execute, observe)
end
# --------------------------------------------------------------------------------------------------
@model function ref_model(; T)
# during each training step, from t to t+1, this function is called only once.
# the mean and variance of observation at the last t
m_x_t_last = datavar(Float64)
v_x_t_last = datavar(Float64)
# sample the observation at the last t
x_t_last ~ GaussianMeanVariance(m_x_t_last, v_x_t_last)
x_k_last = x_t_last
# control
m_u = datavar(Float64, T)
v_u = datavar(Float64, T)
# obs
m_x = datavar(Float64, T)
v_x = datavar(Float64, T)
# wind
m_w = datavar(Float64, T)
v_w = datavar(Float64, T)
# random variables
u = randomvar(T) # control
x = randomvar(T) # height
uw = randomvar(T) # control + wind variance
# loop over horizon
for k = 1:T
x[k] ~ GaussianMeanVariance(m_x[k], v_x[k]) # goal prior
u[k] ~ GaussianMeanVariance(m_u[k], v_u[k]) # control prior
uw[k] ~ GaussianMeanVariance(u[k], v_w[k]) # control + wind variance
x[k] ~ x_k_last + uw[k] + m_w[k]
x_k_last = x[k]
end
return (x,)
end
# --------------------------------------------------------------------------------------------------
function createAgent(; T, fx_m_wind, v_wind, m_goal, v_goal, lambda)
# control prior
m_u = Vector{Float64}([0.0 for k=1:T ])
v_u = Vector{Float64}([lambda^(-1) for k=1:T])
# goal
m_x = Vector{Float64}([m_goal for k=1:T])
v_x = Vector{Float64}([v_goal for k=1:T])
# wind
m_w = Vector{Float64}([0.0 for k=1:T])
v_w = Vector{Float64}([v_wind for k=1:T])
# initial position and variance
m_x_t_last = 0.0
v_x_t_last = convert(Float64, tiny)
# Set current inference results
result = nothing
infer = (agent_action_t::Float64, agent_obs_t::Float64, t::Int) -> begin
m_w[:] = [fx_m_wind(t+k) for k in 0:T-1] # mean wind
m_u[1] = agent_action_t # register action with the generative model
v_u[1] = tiny # clamp control prior to performed action
m_x[1] = agent_obs_t # register observation with the generative model
v_x[1] = tiny # clamp goal prior to observation
data = Dict(
:m_u => m_u,
:v_u => v_u,
:m_x => m_x,
:v_x => v_x,
:m_w => m_w,
:v_w => v_w,
:m_x_t_last => m_x_t_last,
:v_x_t_last => v_x_t_last
)
model = ref_model(; T=T)
#model2, returnval = create_model(model)
#bfe_observable = score(model2, Float64, BetheFreeEnergy())
#bfe_subscription = subscribe!(bfe_observable, (fe) -> println("Current BFE value: ", fe));
result = inference(
model = model,
data = data,
initmarginals = (
u = NormalMeanVariance(0.0, huge),
uw = NormalMeanVariance(0.0, huge),
x = NormalMeanVariance(0.0, huge),
x_t_last = NormalMeanVariance(0.0, huge),
),
iterations = 1,
free_energy = true,
#showprogress = true,
returnvars = (
u = KeepLast(),
uw = KeepLast(),
x_t_last = KeepLast(),
x = KeepLast()
),
)
#unsubscribe!([bfe_subscription, ])
end
# The `act` function returns the inferred best possible action
act = () -> begin
if result !== nothing
return mode(result.posteriors[:u][2]) # index 2 for next step, when iterations=1
else
return 0.0 # Without inference result we return some 'random' action
end
end
# The `future` function returns the inferred future states
future = () -> begin
if result !== nothing
return getindex.(mode.(result.posteriors[:x]), 1) # for iterations =1
else
return zeros(T)
end
end
slide = () -> begin
(x, ) = result.returnval
slide_msg_idx = 3 # This index is model dependent
(m_x_t_last, v_x_t_last) = mean_var(getrecent(messageout(x[2], slide_msg_idx)))
# these are not actually necessary for this simple problem, as the vectors do not change
m_u = circshift(m_u, -1)
m_u[end] = 0.0
v_u = circshift(v_u, -1)
v_u[end] = lambda^(-1)
m_x = circshift(m_x, -1)
m_x[end] = m_goal
v_x = circshift(v_x, -1)
v_x[end] = v_goal
end
return (infer, act, slide, future)
end
# --------------------------------------------------------------------------------------------------
function plotTrajectory(obs, act, per_n, fx_m_wind, epsilon)
(L,N) = size(obs)
# mean of wind
p1 = Plots.scatter(0:L, fx_m_wind.(0:L), color="black", xlim= (0,L), grid=true, ylabel="E[w]", legend=false)
# Trajectory per run
p2 = Plots.plot(legend=false)
Plots.hline!(p2, [1.0], color="red", ls=:dash)
for n=1:per_n:N
Plots.plot!(1:L, obs[:,n], color="black", alpha=0.1)
end
Plots.plot!(xlim=(1,L), ylim=(-5.5, 5.5), grid=true, ylabel="Elevation (x)")
# Control signal per run
p3 = Plots.plot(legend=false)
for n=1:per_n:N
Plots.plot!(1:L, act[:,n], color="black", alpha=0.1)
end
Plots.plot!(xlim= (1,L), ylim=(-1.5, 2.5), grid=true, ylabel= "Control Signal (a)")
# Violation ratio (of runs) over time
p4 = Plots.plot(legend=false)
Plots.hline!(p4, [epsilon], color="red", ls=:dash)
r = vec(mean(obs .< 1.0, dims=2))
Plots.scatter!(1:L, r, color="black")
Plots.plot!(xlim=(1,L), ylim=(0, 0.05), grid=true, xlabel="Time (t)", ylabel="Target Violation Ratio")
Plots.plot(p1,p2,p3,p4, layout=(4,1), size=(800,1000), left_margin=5Plots.mm, )
Plots.savefig("./sim_reference.png")
#@infiltrate; @assert false
end
# --------------------------------------------------------------------------------------------------
function main()
#=
no hidden states, agent directly observes its elevation and wind velocity mean/variance
x = elevation (e.g., meters)
T = time for look-ahead horizon (e.g., seconds)
action = selected control value, vertical velocity (e.g., m/s)
t = current moment
k = time steps in look ahead horizon, k= t:t+T-1
m_w = mean of vertical wind velocity (e.g., m/s)
v_w = variance of vertical wind velocity (e.g, m^2)
u = control variable = action = vertical velocity (e.g, m/s)
L = simulation time (e.g., seconds)
m_ = mean of _
v_ = variance of _
=#
# Simulation parameters
L = 20
T = 9
v_wind = 0.2
m_goal = 2.0
v_goal = 0.18478
lambda = 0.01 # control prior precision
epsilon = 0.01 # unsafe mass
fx_m_wind(t::Int64) = 5<=t<10 ? -1.0 : 0.0 # Wind mean as function of time t
N = 100 # number of trials
# Step through experimental protocol
actions = Matrix{Float64}(undef, L, N)
observations = Matrix{Float64}(undef, L, N)
for n in 1:N
# Let there be a world
(execute_ai, observe_ai) = createWorld(
v_wind = v_wind,
)
# Let there be an agent
(infer_ai, act_ai, slide_ai, future_ai) = createAgent(;
T = T,
fx_m_wind = fx_m_wind,
v_wind = v_wind,
m_goal = m_goal,
v_goal = v_goal,
lambda = lambda,
)
for t=1:L
actions[t,n] = act_ai() # invoke an action from the agent
futures = future_ai() # fetch the predicted future states
execute_ai(actions[t,n], fx_m_wind(t)) # the action influences hidden external states
observations[t,n] = observe_ai() # observe the current environmental outcome
infer_ai(actions[t,n], observations[t,n], t) # infer beliefs from current model state
slide_ai() # prepare for next iteration
end
end
E_avg = round(mean(sum(actions.^2, dims=1)), digits=2) # Average quadratic cost of control (over all runs)
@show E_avg
per_n = ceil(Int, N/100) # Plot one in every per_n trajectories
plotTrajectory(observations, actions, per_n, fx_m_wind, epsilon)
end
end # module ----------------------------------------------------------
RXRef.main()
```