Hi, I’m working on extending Soss, and hoping I might get some design feedback.
In the current release of Soss, a model is restricted to be a DAG. To do this, each line of the body of a model is required to be either of the form
x = rhs
or
x ~ some_measure
There’s a bit more to it, but let’s leave it at this for now.
From this, there are a few things I’d like to improve:
- Models should be more flexible, falling back to DAGs only for cases where that’s required or helpful for performance
- Inference primitives should be easier to build
- Inference primitives should be easily extensible by adding new methods, for example allowing rand to return a Dict instead of a NamedTuple, if desired
- All of this with no new overhead
Threading a Context
To motivate the current state of the new approach, let’s look at a simple example. Say you have a little model like
@model begin
    p ~ Uniform()
    x ~ Bernoulli(p)
end
We can think of turning this into something like
_ctx -> begin
    (p, _ctx) = tilde(:p, Uniform(), _ctx)
    (x, _ctx) = tilde(:x, Bernoulli(p), _ctx)
    _ctx
end
This threads _ctx through the model, and allows the user to add new methods for tilde outside of this. So for example, to implement rand we might do something like
function tilde(v::Symbol, d, _ctx)
    x = rand(d)
    _ctx = merge(_ctx, NamedTuple{(v,)}((x,)))
    (x, _ctx)
end
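To make the threading concrete, here is a minimal runnable sketch of the rewritten model above, written by hand rather than generated. The types UnitUniform and Coin and the helper draw are hypothetical stand-ins for Uniform(), Bernoulli(p), and rand(d), used only so the example runs without any packages.

```julia
# Hypothetical stand-ins for Uniform() and Bernoulli(p), so the sketch
# needs neither Distributions.jl nor MeasureTheory.jl.
struct UnitUniform end
struct Coin
    p::Float64
end
draw(::UnitUniform) = rand()
draw(d::Coin) = rand() < d.p

# Sampling method for tilde: draw a value and append it to the context.
function tilde(v::Symbol, d, _ctx::NamedTuple)
    x = draw(d)
    _ctx = merge(_ctx, NamedTuple{(v,)}((x,)))
    (x, _ctx)
end

# Hand-written version of the rewritten model body.
model = _ctx -> begin
    (p, _ctx) = tilde(:p, UnitUniform(), _ctx)
    (x, _ctx) = tilde(:x, Coin(p), _ctx)
    _ctx
end

# Start from an empty context; the result has fields p and x.
result = model(NamedTuple())
```

The model body itself never mentions sampling; everything about what a ~ statement does lives in the tilde method.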
Adding a Configuration
Ok, that’s not quite enough. We shouldn’t really be implementing rand without allowing a choice of RNG. The context is intended to change as it flows through the code, so that’s probably not the best place for it. Instead, let’s add a _cfg variable. So now we would rewrite our code as
(_cfg, _ctx) -> begin
    (p, _ctx) = tilde(:p, Uniform(), _cfg, _ctx)
    (x, _ctx) = tilde(:x, Bernoulli(p), _cfg, _ctx)
    _ctx
end
And our tilde function could be
function tilde(v::Symbol, d, _cfg, _ctx::NamedTuple)
    x = rand(_cfg.rng, d)
    _ctx = merge(_ctx, NamedTuple{(v,)}((x,)))
    (x, _ctx)
end
Dispatching on Context
In the last code, I also added a type constraint on _ctx. That’s because we can also do things like
function tilde(v, d, cfg, ctx::Dict)
    x = rand(cfg.rng, d)
    ctx[v] = x
    (x, ctx)
end
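As a rough sketch of how the two methods coexist, the same hand-written model body can be run with either kind of context, with the RNG supplied through the configuration. As before, UnitUniform, Coin, and draw are hypothetical stand-ins rather than anything from Soss.

```julia
using Random

# Hypothetical stand-ins for Uniform() and Bernoulli(p).
struct UnitUniform end
struct Coin
    p::Float64
end
draw(rng, ::UnitUniform) = rand(rng)
draw(rng, d::Coin) = rand(rng) < d.p

# NamedTuple context: build a new, immutable context at each step.
function tilde(v::Symbol, d, _cfg, _ctx::NamedTuple)
    x = draw(_cfg.rng, d)
    (x, merge(_ctx, NamedTuple{(v,)}((x,))))
end

# Dict context: mutate the context in place.
function tilde(v::Symbol, d, _cfg, _ctx::Dict)
    x = draw(_cfg.rng, d)
    _ctx[v] = x
    (x, _ctx)
end

# One model body serves both methods.
model = (_cfg, _ctx) -> begin
    (p, _ctx) = tilde(:p, UnitUniform(), _cfg, _ctx)
    (x, _ctx) = tilde(:x, Coin(p), _cfg, _ctx)
    _ctx
end

# Same seed, different context types.
nt = model((rng = MersenneTwister(1),), NamedTuple())
d = model((rng = MersenneTwister(1),), Dict{Symbol,Any}())
```

Because both runs consume the RNG in the same order, nt and d hold the same values; only the container differs.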
Customizing the Return Value
So far we’re always returning the context at the end of the code. But that’s kind of restrictive, and it’s convenient to have tilde return a third element, the return value. This can be updated at each step, with the final value returned.
If this is only assigning one local variable to another, we can reasonably expect the compiler to optimize it away. So our rewrite can become
(_cfg, _ctx) -> begin
    local _retn
    (p, _ctx, _retn) = tilde(:p, Uniform(), _cfg, _ctx)
    (x, _ctx, _retn) = tilde(:x, Bernoulli(p), _cfg, _ctx)
    _retn
end
and our tilde function can look like this:
function tilde(v, d, cfg, ctx::Dict)
    x = rand(cfg.rng, d)
    ctx[v] = x
    (x, ctx, ctx)
end
This also allows methods like this alternative, which returns only the last-computed value:
function tilde(v, d, cfg, ctx::Tuple{})
    x = rand(cfg.rng, d)
    (x, (), x)
end
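Here is a small sketch of the three-element version, run once with a Dict context and once with an empty tuple. UnitUniform, Coin, and draw are again hypothetical stand-ins, and the model body is written by hand.

```julia
using Random

# Hypothetical stand-ins for Uniform() and Bernoulli(p).
struct UnitUniform end
struct Coin
    p::Float64
end
draw(rng, ::UnitUniform) = rand(rng)
draw(rng, d::Coin) = rand(rng) < d.p

# Dict context: record every draw and return the whole context.
function tilde(v::Symbol, d, _cfg, _ctx::Dict)
    x = draw(_cfg.rng, d)
    _ctx[v] = x
    (x, _ctx, _ctx)
end

# Empty-tuple context: record nothing and return the latest draw.
function tilde(v::Symbol, d, _cfg, _ctx::Tuple{})
    x = draw(_cfg.rng, d)
    (x, (), x)
end

model = (_cfg, _ctx) -> begin
    local _retn
    (p, _ctx, _retn) = tilde(:p, UnitUniform(), _cfg, _ctx)
    (x, _ctx, _retn) = tilde(:x, Coin(p), _cfg, _ctx)
    _retn
end

cfg = (rng = MersenneTwister(1),)
everything = model(cfg, Dict{Symbol,Any}())  # Dict with keys :p and :x
last_value = model(cfg, ())                  # just the final draw of x
```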
Which tilde to call
So far every code rewrite has had a function just called tilde. That’s clearly not enough. The current implementation (in the cs-astmodels branch) has, for example, rand building code with a tilde_rand, logdensity building code with a logdensity, etc. Alternatively, we could have one single tilde, and have rand or logdensity as the first argument so we can dispatch on it.
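The single-tilde alternative might look roughly like the following sketch. Everything here is an assumption for illustration: logdensity is declared as an empty generic function used only as a dispatch tag, logdens and draw are hypothetical helpers, and the log-density method reads values from a made-up values field in the configuration.

```julia
using Random

# Hypothetical stand-ins for Uniform() and Bernoulli(p).
struct UnitUniform end
struct Coin
    p::Float64
end
draw(rng, ::UnitUniform) = rand(rng)
draw(rng, d::Coin) = rand(rng) < d.p
logdens(::UnitUniform, x) = (0 <= x <= 1) ? 0.0 : -Inf
logdens(d::Coin, x::Bool) = x ? log(d.p) : log1p(-d.p)

# Hypothetical primitive, declared without methods and used as a tag.
function logdensity end

# Sampling: the context accumulates the drawn values.
function tilde(::typeof(rand), v::Symbol, d, _cfg, _ctx::NamedTuple)
    x = draw(_cfg.rng, d)
    _ctx = merge(_ctx, NamedTuple{(v,)}((x,)))
    (x, _ctx, _ctx)
end

# Log-density: values are looked up in _cfg.values (an assumed field),
# and the context is the running sum.
function tilde(::typeof(logdensity), v::Symbol, d, _cfg, _ctx::Float64)
    x = getproperty(_cfg.values, v)
    _ctx += logdens(d, x)
    (x, _ctx, _ctx)
end

# The primitive is passed in and forwarded to every tilde call.
model = (_prim, _cfg, _ctx) -> begin
    local _retn
    (p, _ctx, _retn) = tilde(_prim, :p, UnitUniform(), _cfg, _ctx)
    (x, _ctx, _retn) = tilde(_prim, :x, Coin(p), _cfg, _ctx)
    _retn
end

sample = model(rand, (rng = MersenneTwister(1),), NamedTuple())
ll = model(logdensity, (values = (p = 0.25, x = true),), 0.0)
```

With this layout there is one generated body per model, and a new primitive only needs new tilde methods. The separate-name approach would instead generate a different body per primitive.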
Dispatching Across Functions
One (usually minor) limitation of Julia is the inability to dispatch on arguments across functions. For example, say you have
struct MyWrapper{T}
    value::T
end
Then this is sometimes desirable, but it doesn’t work:
(f::Function)(w::MyWrapper) = f(w.value)
But since we’re rewriting the code anyway, we can fix this. So we have an option to pass some call function, say mycall, and rewrite every f(args...; kwargs...) as mycall(f, args...; kwargs...).
Then we could just say
mycall(f, w::MyWrapper) = f(w.value)
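As a rough illustration of the rewrite, the sketch below walks an expression and replaces each call with a call to mycall. The function rewrite_calls is a hypothetical helper, not the Soss codegen, and it ignores cases such as dotted calls and macros. The MyWrapper method here also forwards keyword arguments, which goes slightly beyond the one-line definition above.

```julia
struct MyWrapper{T}
    value::T
end

# Fallback: behave like an ordinary call.
mycall(f, args...; kwargs...) = f(args...; kwargs...)

# Unwrap before calling, for any function f.
mycall(f, w::MyWrapper; kwargs...) = f(w.value; kwargs...)

# Hypothetical rewriter: turn f(args...; kwargs...) into
# mycall(f, args...; kwargs...) throughout an expression.
rewrite_calls(x) = x
function rewrite_calls(ex::Expr)
    args = map(rewrite_calls, ex.args)
    ex.head == :call || return Expr(ex.head, args...)
    f, rest = args[1], args[2:end]
    if !isempty(rest) && rest[1] isa Expr && rest[1].head == :parameters
        # The keyword block must stay directly after the callee.
        return Expr(:call, :mycall, rest[1], f, rest[2:end]...)
    end
    Expr(:call, :mycall, f, rest...)
end

# sqrt(w) + 1 becomes mycall(+, mycall(sqrt, w), 1)
ex = rewrite_calls(:(sqrt(w) + 1))
w = MyWrapper(9.0)
y = eval(ex)
```

Operators are ordinary calls in the AST, so + is routed through mycall as well and simply hits the fallback method.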
_args and _obs
Say we have a model m, and we want to work with m(args) | obs, where args and obs are both named tuples. We bring these in as local variables, but we also need to make them available to the tilde function. Currently I’m doing this through _args and _obs fields in the configuration (_cfg).
In some cases, we might also want to change the behavior based on whether a variable is included in these. So the codegen also has
inargs = Val(x ∈ getntkeys(_args))
inobs = Val(x ∈ getntkeys(_obs))
These are passed to the tilde function so we can dispatch on them, allowing different behavior for different cases.
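A minimal sketch of dispatching on inobs follows. The argument order of tilde and the definition of getntkeys are assumptions made for illustration (the real ones are not shown in this post), Coin and draw are hypothetical stand-ins, and inargs is omitted to keep the example short.

```julia
using Random

# Hypothetical stand-in for Bernoulli(p).
struct Coin
    p::Float64
end
draw(rng, d::Coin) = rand(rng) < d.p

# Assumed definition: the keys of a named tuple, read from its type.
getntkeys(::NamedTuple{K}) where {K} = K

# Observed variable: take its value from _cfg._obs instead of sampling.
function tilde(v::Symbol, d, _cfg, _ctx::NamedTuple, ::Val{true})
    x = getproperty(_cfg._obs, v)
    (x, merge(_ctx, NamedTuple{(v,)}((x,))))
end

# Unobserved variable: sample it.
function tilde(v::Symbol, d, _cfg, _ctx::NamedTuple, ::Val{false})
    x = draw(_cfg.rng, d)
    (x, merge(_ctx, NamedTuple{(v,)}((x,))))
end

# Hand-written body for a one-line model `x ~ Bernoulli(0.0)`.
model = (_cfg, _ctx) -> begin
    inobs = Val(:x ∈ getntkeys(_cfg._obs))
    (x, _ctx) = tilde(:x, Coin(0.0), _cfg, _ctx, inobs)
    _ctx
end

rng = MersenneTwister(1)
observed = model((rng = rng, _obs = (x = true,)), NamedTuple())
sampled = model((rng = rng, _obs = NamedTuple()), NamedTuple())
```

With p = 0.0 a sampled x is always false, so seeing x = true in the first result shows that the Val{true} method was selected.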
Performance
Looking good so far:
julia> rng = Random.MersenneTwister(1)
MersenneTwister(1)
julia> @btime rand($rng, $(m()))
4.058 ns (0 allocations: 0 bytes)
(p = 0.8502655945472257, x = true)
julia> @btime rand($rng, $(m()); ctx=NamedTuple())
4.068 ns (0 allocations: 0 bytes)
(p = 0.40102396520892314, x = true)
julia> @btime rand($rng, $(m()); ctx=$(Dict()))
32.042 ns (1 allocation: 16 bytes)
Dict{Any, Any} with 2 entries:
:p => 0.763435
:x => true
julia> @btime rand($rng, $(m()); ctx=())
2.765 ns (0 allocations: 0 bytes)
true
For comparison,
julia> @btime rand($rng) < rand($rng)
4.238 ns (0 allocations: 0 bytes)
false