Hi, I’m working on extending Soss, and hoping I might get some design feedback.
In the current release of Soss, a model is restricted to be a DAG. To do this, each line of the body of a model is required to be either of the form
x = rhs
or
x ~ some_measure
There’s a bit more to it, but let’s leave it at this for now.
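As a rough illustration of that restriction, here is a small sketch that classifies a single statement by its parsed form. The helper name statement_kind is hypothetical and is not how Soss itself is implemented; it only relies on how Julia parses `x = rhs` and `x ~ m`.

```julia
# Hypothetical helper: classify one statement of a model body by its AST.
function statement_kind(ex)
    if ex isa Expr && ex.head === :(=) && ex.args[1] isa Symbol
        return :assign        # x = rhs
    elseif ex isa Expr && ex.head === :call &&
           ex.args[1] === :~ && ex.args[2] isa Symbol
        return :sample        # x ~ some_measure
    else
        return :unsupported   # anything else is outside the two allowed forms
    end
end

kinds = map(statement_kind, [:(p ~ Uniform()), :(y = 2p), :(println("hi"))])
```

Quoting with `:( ... )` only parses the code, so Uniform does not need to be defined for this to run.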
Starting from this, there are a few things I’d like to improve:
- Models should be more flexible, falling back to DAGs only for cases where that’s required or helpful for performance
- Inference primitives should be easier to build
- Inference primitives should be easily extensible by adding new methods, for example allowing rand to return a Dict instead of a NamedTuple, if desired
- All of this with no new overhead
Threading a Context
To motivate the current state of the new approach, let’s look at a simple example. Say you have a little model like
@model begin
    p ~ Uniform()
    x ~ Bernoulli(p)
end
we can think of turning this into something like
_ctx -> begin
    (p, _ctx) = tilde(:p, Uniform(), _ctx)
    (x, _ctx) = tilde(:x, Bernoulli(p), _ctx)
    _ctx
end
This threads _ctx through the model, and allows the user to add new methods for tilde outside of this. So for example, to implement rand we might do something like
function tilde(v::Symbol, d, _ctx)
    x = rand(d)
    _ctx = merge(_ctx, NamedTuple{(v,)}((x,)))
    (x, _ctx)
end
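To check that the threading idea runs end to end, here is a minimal self-contained sketch of the hand-written rewrite. UniformUnit and Bern are hypothetical toy stand-ins for Uniform() and Bernoulli(p), defined here only so the block needs no packages; they are not Soss or Distributions types.

```julia
using Random

# Hypothetical toy measures standing in for Uniform() and Bernoulli(p).
struct UniformUnit end
struct Bern
    p::Float64
end
Base.rand(rng::AbstractRNG, ::UniformUnit) = rand(rng)
Base.rand(rng::AbstractRNG, d::Bern) = rand(rng) < d.p

# Sample from d, then record the value under the name v in the context.
function tilde(v::Symbol, d, _ctx)
    x = rand(d)
    _ctx = merge(_ctx, NamedTuple{(v,)}((x,)))
    (x, _ctx)
end

# Hand-written version of what the model body could be rewritten to.
rewritten = _ctx -> begin
    (p, _ctx) = tilde(:p, UniformUnit(), _ctx)
    (x, _ctx) = tilde(:x, Bern(p), _ctx)
    _ctx
end

draw = rewritten(NamedTuple())   # a NamedTuple with fields p and x
```

Starting from an empty NamedTuple, each tilde call extends the context by one field, so the result has the same shape as the model's variables.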
Adding a Configuration
Ok, that’s not quite enough. We shouldn’t really be implementing rand without allowing a choice of RNG. The context is intended to change as it flows through the code, so that’s probably not the best place for it. Instead, let’s add a _cfg variable. So now we would rewrite our code as
(_cfg, _ctx) -> begin
    (p, _ctx) = tilde(:p, Uniform(), _cfg, _ctx)
    (x, _ctx) = tilde(:x, Bernoulli(p), _cfg, _ctx)
    _ctx
end
And our tilde function could be
function tilde(v::Symbol, d, _cfg, _ctx::NamedTuple)
    x = rand(_cfg.rng, d)
    _ctx = merge(_ctx, NamedTuple{(v,)}((x,)))
    (x, _ctx)
end
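One practical payoff of keeping the RNG in the configuration is reproducibility. The sketch below assumes the configuration is a plain NamedTuple with an rng field; UniformUnit and Bern are again hypothetical toy measures used only to keep the block self-contained.

```julia
using Random

# Hypothetical toy measures standing in for Uniform() and Bernoulli(p).
struct UniformUnit end
struct Bern
    p::Float64
end
Base.rand(rng::AbstractRNG, ::UniformUnit) = rand(rng)
Base.rand(rng::AbstractRNG, d::Bern) = rand(rng) < d.p

# The RNG comes from the (fixed) configuration; only the context changes.
function tilde(v::Symbol, d, _cfg, _ctx::NamedTuple)
    x = rand(_cfg.rng, d)
    _ctx = merge(_ctx, NamedTuple{(v,)}((x,)))
    (x, _ctx)
end

rewritten = (_cfg, _ctx) -> begin
    (p, _ctx) = tilde(:p, UniformUnit(), _cfg, _ctx)
    (x, _ctx) = tilde(:x, Bern(p), _cfg, _ctx)
    _ctx
end

# Two configurations with identically seeded RNGs give identical draws.
a = rewritten((rng = MersenneTwister(1),), NamedTuple())
b = rewritten((rng = MersenneTwister(1),), NamedTuple())
same = (a == b)
```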
Dispatching on Context
In the last snippet, I also added a type constraint on _ctx. That’s because we can also do things like
function tilde(v, d, cfg, ctx::Dict)
    x = rand(cfg.rng, d)
    ctx[v] = x
    (x, ctx)
end
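Here is a small sketch showing both context types driven by the same rewritten body. To stay package-free it uses integer ranges in place of measures (rand already knows how to sample from a range); that substitution and the name run_model are assumptions made for this example.

```julia
using Random

# NamedTuple context: immutable, so build a new one at every step.
function tilde(v::Symbol, d, cfg, ctx::NamedTuple)
    x = rand(cfg.rng, d)
    (x, merge(ctx, NamedTuple{(v,)}((x,))))
end

# Dict context: mutable, so update it in place.
function tilde(v::Symbol, d, cfg, ctx::Dict)
    x = rand(cfg.rng, d)
    ctx[v] = x
    (x, ctx)
end

# Hand-written rewrite of a two-line model: n ~ 1:6, then k ~ 1:n.
function run_model(cfg, ctx)
    (n, ctx) = tilde(:n, 1:6, cfg, ctx)
    (k, ctx) = tilde(:k, 1:n, cfg, ctx)
    ctx
end

nt = run_model((rng = MersenneTwister(1),), NamedTuple())
d  = run_model((rng = MersenneTwister(1),), Dict{Symbol,Int}())
```

The model code is identical in both calls; only the type of the initial context decides which tilde method runs and what kind of container comes back.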
Customizing the Return Value
So far we’re always returning the context at the end of the code. But that’s kind of restrictive, and it’s convenient to have tilde return a third value, the return value. This can be updated at each step, with the final value returned.
If this is only assigning one local variable to another, we can reasonably expect the compiler to optimize it away. So our rewrite can become
(_cfg, _ctx) -> begin
    local _retn
    (p, _ctx, _retn) = tilde(:p, Uniform(), _cfg, _ctx)
    (x, _ctx, _retn) = tilde(:x, Bernoulli(p), _cfg, _ctx)
    _retn
end
and our tilde function can look like this:
function tilde(v, d, cfg, ctx::Dict)
    x = rand(cfg.rng, d)
    ctx[v] = x
    (x, ctx, ctx)
end
This also allows methods like this alternative, which only returns the last-computed value:
function tilde(v, d, cfg, ctx::Tuple{})
    x = rand(cfg.rng, d)
    (x, (), x)
end
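The two return-value policies can be compared side by side in a short sketch. As an assumption for this example, integer ranges stand in for measures so that no packages are needed, and run_model is a hypothetical name for the hand-written rewrite.

```julia
using Random

# Dict context: the running return value is the context itself.
function tilde(v, d, cfg, ctx::Dict)
    x = rand(cfg.rng, d)
    ctx[v] = x
    (x, ctx, ctx)
end

# Empty-tuple context: store nothing, return only the latest draw.
function tilde(v, d, cfg, ctx::Tuple{})
    x = rand(cfg.rng, d)
    (x, (), x)
end

# Hand-written rewrite of: n ~ 1:6, then k ~ 1:n.
function run_model(_cfg, _ctx)
    local _retn
    (n, _ctx, _retn) = tilde(:n, 1:6, _cfg, _ctx)
    (k, _ctx, _retn) = tilde(:k, 1:n, _cfg, _ctx)
    _retn
end

all_values = run_model((rng = MersenneTwister(1),), Dict{Symbol,Int}())
last_value = run_model((rng = MersenneTwister(1),), ())
```

With the same seed, last_value matches the :k entry of all_values, since both runs consume the RNG in the same order.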
Which tilde to call
So far every code rewrite has had a function just called tilde. That’s clearly not enough. The current implementation (in the cs-astmodels branch) has, for example, rand building code with a tilde_rand, logdensity building code with a logdensity, etc. Alternatively, we could have one single tilde, and have rand or logdensity as the first argument so we can dispatch on it.
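The single-tilde alternative could look roughly like the sketch below, which dispatches on typeof(rand) and typeof(logdensity). Everything here is an assumption for illustration: logdensity is declared locally as an empty generic function, ranges stand in for discrete uniform measures, and the values and ℓ field names are made up.

```julia
using Random

# Hypothetical generic function, used here only as a dispatch tag.
function logdensity end

# Sampling: draw a value and record it in the context.
function tilde(::typeof(rand), v::Symbol, d, cfg, ctx::NamedTuple)
    x = rand(cfg.rng, d)
    (x, merge(ctx, NamedTuple{(v,)}((x,))))
end

# Log-density: read the value from cfg.values and accumulate into ctx.ℓ.
# A range is treated as a discrete uniform measure (toy assumption).
function tilde(::typeof(logdensity), v::Symbol, d::AbstractRange, cfg, ctx::NamedTuple)
    x = getproperty(cfg.values, v)
    ℓ = x in d ? -log(length(d)) : -Inf
    (x, (ℓ = ctx.ℓ + ℓ,))
end

# One rewritten body, parameterised by the primitive f.
function run_model(f, cfg, ctx)
    (n, ctx) = tilde(f, :n, 1:6, cfg, ctx)
    (k, ctx) = tilde(f, :k, 1:n, cfg, ctx)
    ctx
end

draw = run_model(rand, (rng = MersenneTwister(1),), NamedTuple())
logp = run_model(logdensity, (values = (n = 4, k = 2),), (ℓ = 0.0,))
```

One trade-off of this design is that the generated code is the same for every primitive, whereas separate tilde_rand-style functions let each primitive generate different code.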
Dispatching Across Functions
One (usually minor) limitation of Julia is the inability to dispatch on arguments across functions. For example, say you have
struct MyWrapper{T}
    value::T
end
Then this is sometimes desirable, but it doesn’t work:
(f::Function)(w::MyWrapper) = f(w.value)
But since we’re rewriting the code anyway, we can fix this. So we have an option to pass some call function, say mycall, and rewrite every f(args...; kwargs...) as mycall(f, args...; kwargs...).
Then we could just say
mycall(f, w::MyWrapper) = f(w.value)
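A rough sketch of such a rewrite is below. The macro name @with_mycall and the helper rewrite_calls are hypothetical, and this is not the Soss codegen; it only handles plain call expressions (including operators) and would mangle things like function definitions.

```julia
# Fallback: behave exactly like an ordinary call.
mycall(f, args...; kwargs...) = f(args...; kwargs...)

struct MyWrapper{T}
    value::T
end

# Unwrap before calling, for any function f.
mycall(f, w::MyWrapper) = mycall(f, w.value)

# Recursively rewrite f(args...) into mycall(f, args...).
rewrite_calls(x) = x   # literals, symbols, line-number nodes
function rewrite_calls(ex::Expr)
    args = map(rewrite_calls, ex.args)
    ex.head === :call || return Expr(ex.head, args...)
    f, rest = args[1], args[2:end]
    # Keyword arguments written after `;` must stay right after the callee.
    if !isempty(rest) && rest[1] isa Expr && rest[1].head === :parameters
        return Expr(:call, :mycall, rest[1], f, rest[2:end]...)
    end
    Expr(:call, :mycall, f, rest...)
end

macro with_mycall(ex)
    esc(rewrite_calls(ex))
end

w = MyWrapper(16.0)
result = @with_mycall sqrt(w) + 1   # becomes mycall(+, mycall(sqrt, w), 1)
```

Because sqrt has no method for MyWrapper, the unrewritten expression would throw a MethodError; after the rewrite, the wrapper is unwrapped first.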
_args and _obs
Say we have a model m, and we want to work with m(args) | obs, where args and obs are each named tuples. We bring these in as local variables, but we also need to make them available to the tilde function. Currently I’m doing this through _args and _obs fields in the configuration (_cfg).
In some cases, we might also want to change the behavior based on whether a variable is included in these. So the codegen also has
inargs = Val(x ∈ getntkeys(_args))
inobs = Val(x ∈ getntkeys(_obs))
These are passed to the tilde function so we can dispatch on them, allowing different behavior for different cases.
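As a sketch of what dispatching on these flags can buy, the example below samples unobserved variables and reads observed ones from _obs. Several things are assumptions: getntkeys is defined locally with the behaviour its name suggests, tilde_step is a hypothetical helper, ranges stand in for measures, and the flags are computed at run time here rather than during code generation.

```julia
using Random

# Assumed behaviour of getntkeys: the field names of a named tuple.
getntkeys(::NamedTuple{K}) where {K} = K

# Observed variable: take the value from _obs instead of sampling.
function tilde(v::Symbol, d, cfg, ctx, inargs, ::Val{true})
    x = getproperty(cfg._obs, v)
    (x, merge(ctx, NamedTuple{(v,)}((x,))))
end

# Unobserved variable: sample as before.
function tilde(v::Symbol, d, cfg, ctx, inargs, ::Val{false})
    x = rand(cfg.rng, d)
    (x, merge(ctx, NamedTuple{(v,)}((x,))))
end

# Compute the two flags for variable v and forward them to tilde.
function tilde_step(v::Symbol, d, cfg, ctx)
    inargs = Val(v ∈ getntkeys(cfg._args))
    inobs  = Val(v ∈ getntkeys(cfg._obs))
    tilde(v, d, cfg, ctx, inargs, inobs)
end

cfg = (rng = MersenneTwister(1), _args = NamedTuple(), _obs = (k = 2,))
(n, ctx) = tilde_step(:n, 1:6, cfg, NamedTuple())   # sampled
(k, ctx) = tilde_step(:k, 1:n, cfg, ctx)            # taken from _obs
```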
Performance
Looking good so far:
julia> rng = Random.MersenneTwister(1)
MersenneTwister(1)
julia> @btime rand($rng, $(m()))
4.058 ns (0 allocations: 0 bytes)
(p = 0.8502655945472257, x = true)
julia> @btime rand($rng, $(m()); ctx=NamedTuple())
4.068 ns (0 allocations: 0 bytes)
(p = 0.40102396520892314, x = true)
julia> @btime rand($rng, $(m()); ctx=$(Dict()))
32.042 ns (1 allocation: 16 bytes)
Dict{Any, Any} with 2 entries:
:p => 0.763435
:x => true
julia> @btime rand($rng, $(m()); ctx=())
2.765 ns (0 allocations: 0 bytes)
true
For comparison,
julia> @btime rand($rng) < rand($rng)
4.238 ns (0 allocations: 0 bytes)
false