New Soss interface - request for feedback

Hi, I’m working on extending Soss, and hoping I might get some design feedback.

In the current release of Soss, a model is restricted to be a DAG. To enforce this, each line of the body of a model is required to have one of the forms

x = rhs

or

x ~ some_measure

There’s a bit more to it, but let’s leave it at this for now.
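As a rough illustration of that restriction, here is a sketch that checks a quoted model body line by line against the two forms above. bodyline_kind and is_dag_body are hypothetical helpers, not Soss functions. They ignore the extra conditions alluded to above, and, for example, they don’t check that the dependencies are acyclic.

```julia
# Classify one line of a model body by the two allowed forms
function bodyline_kind(ex)
    if ex isa Expr && ex.head == :(=) && ex.args[1] isa Symbol
        return :assign                      # x = rhs
    elseif ex isa Expr && ex.head == :call && ex.args[1] == :~ && ex.args[2] isa Symbol
        return :sample                      # x ~ some_measure
    else
        return :other                       # anything else (if, for, bare calls, ...)
    end
end

# True when every line of a quoted `begin ... end` body has one of the two forms
function is_dag_body(body::Expr)
    lines = filter(l -> !(l isa LineNumberNode), body.args)
    all(l -> bodyline_kind(l) != :other, lines)
end

body = quote
    p ~ Uniform()
    y = 2p
    x ~ Bernoulli(y / 2)
end
is_dag_body(body)   # true: only `=` and `~` lines
```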

Starting from this, there are a few things I’d like to improve:

  1. Models should be more flexible, falling back to DAGs only for cases where that’s required or helpful for performance
  2. Inference primitives should be easier to build
  3. Inference primitives should be easily extensible by adding new methods, for example allowing rand to return a Dict instead of a NamedTuple, if desired
  4. All of this with no new overhead

Threading a Context

To motivate the current state of the new approach, let’s look at a simple example. Say you have a little model like

@model begin
    p ~ Uniform()
    x ~ Bernoulli(p)
end

we can think of turning this into something like

_ctx -> begin
    (p, _ctx) = tilde(:p, Uniform(), _ctx)
    (x, _ctx) = tilde(:x, Bernoulli(p), _ctx)
    _ctx
end

This threads _ctx through the model and lets users add new methods for tilde outside of it. For example, to implement rand we might do something like

function tilde(v::Symbol, d, _ctx)
    x = rand(d)
    _ctx = merge(_ctx, NamedTuple{(v,)}((x,)))  # record the draw under the name v
    (x, _ctx)
end
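Here is a minimal runnable sketch of that rewrite, using only the Random standard library. Unif and Bern are hypothetical stand-ins for Uniform() and Bernoulli(p) (Distributions.jl isn’t loaded), and lowered is a hand-written guess at what the generated code would look like, returning _ctx at the end.

```julia
using Random

struct Unif end                  # hypothetical stand-in for Uniform()
struct Bern; p::Float64; end     # hypothetical stand-in for Bernoulli(p)
Random.rand(rng::AbstractRNG, ::Unif) = rand(rng)
Random.rand(rng::AbstractRNG, d::Bern) = rand(rng) < d.p

# rand-style tilde: sample, then record the draw under the name v
function tilde(v::Symbol, d, _ctx)
    x = rand(d)   # falls back to rand(Random.default_rng(), d)
    _ctx = merge(_ctx, NamedTuple{(v,)}((x,)))
    (x, _ctx)
end

# Hand-written version of the rewritten model body
lowered = _ctx -> begin
    (p, _ctx) = tilde(:p, Unif(), _ctx)
    (x, _ctx) = tilde(:x, Bern(p), _ctx)
    _ctx
end

draw = lowered(NamedTuple())   # a NamedTuple with fields p and x
```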

Adding a Configuration

Ok, that’s not quite enough. We shouldn’t really implement rand without allowing a choice of RNG. The context is intended to change as it flows through the code, so it’s probably not the best place for that. Instead, let’s add a _cfg variable. Now we would rewrite our code as

(_cfg, _ctx) -> begin
    (p, _ctx) = tilde(:p, Uniform(), _cfg, _ctx)
    (x, _ctx) = tilde(:x, Bernoulli(p), _cfg, _ctx)
    _ctx
end

And our tilde function could be

function tilde(v::Symbol, d, _cfg, _ctx::NamedTuple)
    x = rand(_cfg.rng, d)  # the RNG comes from the configuration
    _ctx = merge(_ctx, NamedTuple{(v,)}((x,)))
    (x, _ctx)
end

Dispatching on Context

In that last snippet, I also added a type constraint on _ctx. That’s because we can also do things like

function tilde(v, d, cfg, ctx::Dict)
    x = rand(cfg.rng, d)
    ctx[v] = x  # a Dict is mutable, so update it in place
    (x, ctx)
end
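A sketch of the NamedTuple and Dict methods side by side, with the RNG taken from a _cfg named tuple. Unif and Bern are hypothetical stand-ins for Uniform() and Bernoulli(p), and lowered is a hand-written version of the generated body. The same body yields a NamedTuple or a Dict depending only on the type of the context passed in, and fixing the seed in _cfg makes the draws reproducible.

```julia
using Random

struct Unif end                  # hypothetical stand-in for Uniform()
struct Bern; p::Float64; end     # hypothetical stand-in for Bernoulli(p)
Random.rand(rng::AbstractRNG, ::Unif) = rand(rng)
Random.rand(rng::AbstractRNG, d::Bern) = rand(rng) < d.p

# Immutable context: build a new NamedTuple with the draw added
function tilde(v, d, cfg, ctx::NamedTuple)
    x = rand(cfg.rng, d)
    (x, merge(ctx, NamedTuple{(v,)}((x,))))
end

# Mutable context: store the draw in place
function tilde(v, d, cfg, ctx::Dict)
    x = rand(cfg.rng, d)
    ctx[v] = x
    (x, ctx)
end

# Hand-written version of the rewritten model body
lowered = (_cfg, _ctx) -> begin
    (p, _ctx) = tilde(:p, Unif(), _cfg, _ctx)
    (x, _ctx) = tilde(:x, Bern(p), _cfg, _ctx)
    _ctx
end

nt = lowered((rng = MersenneTwister(1),), NamedTuple())
dict = lowered((rng = MersenneTwister(1),), Dict{Symbol,Any}())
```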

Customizing the Return Value

So far we’re always returning the context at the end of the code. But that’s restrictive, and it’s convenient to add a third element to the tuple each tilde returns: the return value. This can be updated at each step, with the final value returned.

If this only assigns one local variable to another, we can reasonably expect the compiler to optimize it away. So our rewrite can become

(_cfg, _ctx) -> begin
    local _retn
    (p, _ctx, _retn) = tilde(:p, Uniform(), _cfg, _ctx)
    (x, _ctx, _retn) = tilde(:x, Bernoulli(p), _cfg, _ctx)
    _retn
end

and our tilde function can look like this:

function tilde(v, d, cfg, ctx::Dict)
    x = rand(cfg.rng, d)
    ctx[v] = x
    (x, ctx, ctx)  # (value, updated context, return value)
end

This also allows alternative methods, such as this one that returns only the last-computed value:

function tilde(v, d, cfg, ctx::Tuple{})
    x = rand(cfg.rng, d)
    (x, (), x)  # keep no context; the return value is the latest draw
end
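A sketch of the three-element form, again with hypothetical Unif and Bern stand-ins and a hand-written lowered body. The body returns _retn, so the context’s type decides whether the caller gets every draw or only the last one.

```julia
using Random

struct Unif end                  # hypothetical stand-in for Uniform()
struct Bern; p::Float64; end     # hypothetical stand-in for Bernoulli(p)
Random.rand(rng::AbstractRNG, ::Unif) = rand(rng)
Random.rand(rng::AbstractRNG, d::Bern) = rand(rng) < d.p

# Record every draw and return the whole context
function tilde(v, d, cfg, ctx::NamedTuple)
    x = rand(cfg.rng, d)
    ctx = merge(ctx, NamedTuple{(v,)}((x,)))
    (x, ctx, ctx)
end

# Keep no context at all and return only the latest draw
function tilde(v, d, cfg, ctx::Tuple{})
    x = rand(cfg.rng, d)
    (x, (), x)
end

# Hand-written version of the rewritten body, returning _retn
lowered = (_cfg, _ctx) -> begin
    local _retn
    (p, _ctx, _retn) = tilde(:p, Unif(), _cfg, _ctx)
    (x, _ctx, _retn) = tilde(:x, Bern(p), _cfg, _ctx)
    _retn
end

full = lowered((rng = MersenneTwister(1),), NamedTuple())   # NamedTuple with p and x
lastval = lowered((rng = MersenneTwister(1),), ())          # just the value of x
```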

Which tilde to call

So far every code rewrite has used a function just called tilde. That’s clearly not enough. The current implementation (in the cs-astmodels branch) has, for example, rand building code with a tilde_rand, logdensity building code with a tilde_logdensity, etc. Alternatively, we could have one single tilde, and pass rand or logdensity as the first argument so we can dispatch on it.
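For the single-tilde alternative, a sketch might dispatch on typeof(rand) and typeof(logdensity). Here logdensity is a hypothetical function defined only for the stand-in measures Unif and Bern; it is not the Soss or Distributions.jl API, and reading observed values from _cfg.obs is an assumption of this sketch.

```julia
using Random

struct Unif end                  # hypothetical stand-in for Uniform()
struct Bern; p::Float64; end     # hypothetical stand-in for Bernoulli(p)
Random.rand(rng::AbstractRNG, ::Unif) = rand(rng)
Random.rand(rng::AbstractRNG, d::Bern) = rand(rng) < d.p

# Hypothetical log densities for the stand-in measures
logdensity(::Unif, x) = 0 <= x <= 1 ? 0.0 : -Inf
logdensity(d::Bern, x::Bool) = x ? log(d.p) : log1p(-d.p)

# One tilde; the primitive passed first selects the behavior
function tilde(::typeof(rand), v, d, cfg, ctx::NamedTuple)
    x = rand(cfg.rng, d)
    (x, merge(ctx, NamedTuple{(v,)}((x,))))
end

function tilde(::typeof(logdensity), v, d, cfg, ctx::NamedTuple)
    x = getproperty(cfg.obs, v)   # read the value instead of sampling it
    (x, merge(ctx, (logp = ctx.logp + logdensity(d, x),)))
end

# Hand-written body, parameterized by the primitive f
lowered = (f, _cfg, _ctx) -> begin
    (p, _ctx) = tilde(f, :p, Unif(), _cfg, _ctx)
    (x, _ctx) = tilde(f, :x, Bern(p), _cfg, _ctx)
    _ctx
end

draw = lowered(rand, (rng = MersenneTwister(1),), NamedTuple())
lp = lowered(logdensity, (obs = (p = 0.25, x = true),), (logp = 0.0,)).logp   # log(0.25)
```

One generated body serves both primitives; adding a new primitive means adding a tilde method rather than a new codegen path.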

Dispatching Across Functions

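One (usually minor) limitation of Julia is that you can’t write a single method that dispatches on an argument across all functions at once. For example, say you have a wrapper type MyWrapper with a field value.

Then something like this is sometimes desirable, but it doesn’t work: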

(f::Function)(w::MyWrapper) = f(w.value)

But since we’re rewriting the code anyway, we can fix this. We have the option to pass a call function, say mycall, and rewrite every f(args...; kwargs...) as mycall(f, args...; kwargs...).

Then we could just say

mycall(f, w::MyWrapper) = f(w.value)
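A sketch of that call rewrite on a quoted expression. rewrite_calls is a hypothetical helper, not part of Soss; it only handles ordinary call syntax (not broadcasting or macros), and the MyWrapper definition is assumed.

```julia
# True for the :parameters node that holds keyword arguments written after `;`
isparams(a) = a isa Expr && a.head == :parameters

# Recursively rewrite every call f(args...; kwargs...) as callfn(f, args...; kwargs...)
function rewrite_calls(ex, callfn::Symbol = :mycall)
    ex isa Expr || return ex                       # symbols, literals, LineNumberNodes
    args = Any[rewrite_calls(a, callfn) for a in ex.args]
    if ex.head == :call
        f, rest = args[1], args[2:end]
        params = filter(isparams, rest)            # keep keywords in :parameters form
        positional = filter(!isparams, rest)
        return Expr(:call, callfn, params..., f, positional...)
    end
    return Expr(ex.head, args...)
end

struct MyWrapper{T}   # assumed definition
    value::T
end

mycall(f, args...; kwargs...) = f(args...; kwargs...)   # default: an ordinary call
mycall(f, w::MyWrapper) = f(w.value)                    # unwrap for any f

ex = rewrite_calls(:(sqrt(MyWrapper(4.0))))   # :(mycall(sqrt, mycall(MyWrapper, 4.0)))
eval(ex)                                      # 2.0
```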

_args and _obs

Say we have a model m, and we want to work with m(args) | obs, where both are named tuples. We bring these in as local variables, but we also need to make them available to the tilde function. Currently I’m doing this through _args and _obs fields in the configuration (_cfg).

In some cases, we might also want the behavior to depend on whether a variable appears in these. So the generated code also includes

inargs = Val(x ∈ getntkeys(_args))
inobs = Val(x ∈ getntkeys(_obs))

These are passed to the tilde function so we can dispatch on them, allowing different behavior for different cases.
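A sketch of dispatching on those Val flags. Here getntkeys is a local stand-in that assumes it returns the key tuple of a NamedTuple, observed variables are read from _cfg.obs instead of being sampled, and Unif and Bern are hypothetical stand-ins for Uniform() and Bernoulli(p).

```julia
using Random

struct Unif end                  # hypothetical stand-in for Uniform()
struct Bern; p::Float64; end     # hypothetical stand-in for Bernoulli(p)
Random.rand(rng::AbstractRNG, ::Unif) = rand(rng)
Random.rand(rng::AbstractRNG, d::Bern) = rand(rng) < d.p

# Stand-in for getntkeys, assumed to return the keys of a NamedTuple
getntkeys(::NamedTuple{K}) where {K} = K

function tilde(v, d, cfg, ctx::NamedTuple, inargs::Val, inobs::Val{false})
    x = rand(cfg.rng, d)             # not observed: sample it
    (x, merge(ctx, NamedTuple{(v,)}((x,))))
end

function tilde(v, d, cfg, ctx::NamedTuple, inargs::Val, inobs::Val{true})
    x = getproperty(cfg.obs, v)      # observed: use the data instead of sampling
    (x, merge(ctx, NamedTuple{(v,)}((x,))))
end

# Hand-written body; the Val flags are computed per variable
lowered = (_cfg, _ctx) -> begin
    (p, _ctx) = tilde(:p, Unif(), _cfg, _ctx,
                      Val(:p ∈ getntkeys(_cfg.args)), Val(:p ∈ getntkeys(_cfg.obs)))
    (x, _ctx) = tilde(:x, Bern(p), _cfg, _ctx,
                      Val(:x ∈ getntkeys(_cfg.args)), Val(:x ∈ getntkeys(_cfg.obs)))
    _ctx
end

cfg = (rng = MersenneTwister(1), args = NamedTuple(), obs = (x = true,))
out = lowered(cfg, NamedTuple())   # out.x comes from obs; out.p is sampled
```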


Looking good so far:

julia> rng = Random.MersenneTwister(1)

julia> @btime rand($rng, $(m()))
  4.058 ns (0 allocations: 0 bytes)
(p = 0.8502655945472257, x = true)

julia> @btime rand($rng, $(m()); ctx=NamedTuple())
  4.068 ns (0 allocations: 0 bytes)
(p = 0.40102396520892314, x = true)

julia> @btime rand($rng, $(m()); ctx=$(Dict()))
  32.042 ns (1 allocation: 16 bytes)
Dict{Any, Any} with 2 entries:
  :p => 0.763435
  :x => true

julia> @btime rand($rng, $(m()); ctx=())
  2.765 ns (0 allocations: 0 bytes)

For comparison,

julia> @btime rand($rng) < rand($rng)
  4.238 ns (0 allocations: 0 bytes)

Are there models that aren’t easily expressible as DAGs? If so, is that really the use case for Soss, and the most important thing to work on for getting more people using Soss? I would expect that if your model is that weird, you’d move over to something like Turing or Gen.

Great question. There are a few different things going on here…

First, some primitives like rand and logdensity don’t make much use of the DAG structure of a model. Looking at the ~ statements alone, DAGs are, and will continue to be, very common. But beyond those, it can be convenient to allow arbitrary Julia code. In many cases this seems like a benefit with effectively no downside. And DAG models that benefit from that constraint can still be handled in those terms.

What I describe above is an ASTModel. We’ll still have DAGModels, and users can start with one of these, or convert an ASTModel into one, if the DAG requirements are met.

Second, I think there’s a misconception that Soss is all about DAGs. That’s not it at all. Soss is about syntax-level representation and manipulation, and run-time codegen. The syntactic representation is “initial”, so we should be able to write a model at this level and automatically build a Turing or Gen version of it, or use methods specific to Soss.


I have some beginner-level questions.

  • What motivates ctx's place in the argument order? For functions like map and filter the transformed container goes last, enabling partial application; for DataFrames functions it goes first, enabling piping. I often see “context” parameters going first, but I don’t know where that idiom comes from – maybe it’s a remnant of OOP or maybe it’s something about monads…

  • Why are some of the variables in your examples prefixed with underscores? Is that being used to avoid name collisions?


Argument order here is somewhat arbitrary. Mostly I wanted the first few arguments to mirror the x ~ some_measure expression itself. These are mostly utility functions for building new primitives. End users would call rand, etc., as usual, but now have the option of adding a context. In that case I wanted to change the usual argument order of Base.rand as little as possible.

As for argument order in general… I’m not sure about the origin, but I’ve found it’s sometimes handy to have things in an order to make dispatch easy. For example you might have something like

foo(::T, args...; kwargs...) = some_foo_specific_to_T(args...; kwargs...)

In cases when there’s already a some_foo_specific_to_T, this makes it easy to manage the other parameters.
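A toy version of that forwarding pattern; foo, foo_for_int, and foo_for_string are hypothetical names. Dispatch happens only on the leading argument, and everything else, keyword arguments included, passes through unchanged.

```julia
# Hypothetical type-specific implementations that already exist
foo_for_int(x, y; scale = 1) = scale * (x + y)
foo_for_string(x, y; sep = " ") = string(x, sep, y)

# Dispatch on the leading argument only; forward everything else untouched
foo(::Int, args...; kwargs...) = foo_for_int(args...; kwargs...)
foo(::String, args...; kwargs...) = foo_for_string(args...; kwargs...)

foo(0, 2, 3; scale = 2)        # 10
foo("", "a", "b"; sep = "-")   # "a-b"
```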

Yep, that’s it. It doesn’t make any difference in the functions themselves, but when names appear in code we want to avoid collision. We could gensym everything, but that takes extra code to pass the names around. So for now I just have it as “leading underscores are for internal use”.


I generally like this proposal, as it reflects a number of design choices I’m familiar with from Gen, especially the notion of threading a context object through, and/or otherwise dispatching on a context object to control how certain inference methods work.

From a performance perspective, my main concern is that there are a number of things unaddressed by this proposal.

Using rand to collect random choices into your ctx object, for example, seems lightweight, but it doesn’t expose or quantify performance across the full complexity of inference. Specifically, consider implementing multi-choice MH steps: this requires re-visiting the choices in a model using a ctx that has already been filled by sampling from (or assigning choices from) the measure the model denotes.

Including arbitrary Julia code requires thinking explicitly about how to cache or otherwise incrementally compute subsets of your model AST; otherwise MH updates will not be asymptotically optimal. Ideally, this is something we should be able to solve programmatically.

I think codegen from an AST representation likely provides a way for ASTModel to be asymptotically optimal, as Gen’s static compiler is, since they are based on similar concepts. But the main impediment to including arbitrary Julia code is that inference algorithms which re-visit choices will re-execute code that may not have changed since the previous sample.
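To make the re-execution cost concrete, here is a small sketch of a single-site MH step on p using a replay context: each run reads choices from a stored trace and accumulates the log density. Replay, logdensity, and mh_step are hypothetical, Unif and Bern stand in for Uniform() and Bernoulli(p), and proposing from the flat prior is a simplification.

```julia
using Random

struct Unif end                  # hypothetical stand-in for Uniform()
struct Bern; p::Float64; end     # hypothetical stand-in for Bernoulli(p)
logdensity(::Unif, x) = 0 <= x <= 1 ? 0.0 : -Inf
logdensity(d::Bern, x::Bool) = x ? log(d.p) : log1p(-d.p)

# Replay context: read choices from a stored trace, accumulate the log density,
# and count how many tilde statements are executed
mutable struct Replay
    trace::Dict{Symbol,Any}
    logp::Float64
    visits::Int
end

function tilde(v::Symbol, d, ctx::Replay)
    ctx.visits += 1
    x = ctx.trace[v]
    ctx.logp += logdensity(d, x)
    (x, ctx)
end

# Rewritten model body: p ~ Uniform(); x ~ Bernoulli(p)
function model_body(ctx)
    (p, ctx) = tilde(:p, Unif(), ctx)
    (x, ctx) = tilde(:x, Bern(p), ctx)
    ctx
end

# Single-site MH on :p, proposing from the flat prior, so the acceptance
# ratio is just the ratio of the joint densities computed by the two replays
function mh_step(rng::AbstractRNG, trace::Dict{Symbol,Any})
    old_run = model_body(Replay(trace, 0.0, 0))
    proposal = merge(trace, Dict{Symbol,Any}(:p => rand(rng)))
    new_run = model_body(Replay(proposal, 0.0, 0))
    accepted = log(rand(rng)) < new_run.logp - old_run.logp
    (accepted ? proposal : trace, old_run.visits + new_run.visits)
end

trace, nvisits = mh_step(MersenneTwister(1), Dict{Symbol,Any}(:p => 0.5, :x => true))
```

Here nvisits is 4: the whole body is replayed for each evaluation, even though only p was proposed. In a larger model, any code not downstream of p would be re-run needlessly, which is the cost described above.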

This particular requirement induces a number of design complexities – and I’m generally really curious about how you might address them.

Also, just to mention: there’s room to grow in the design space for the issues I mentioned above, so I am really curious about what ASTModel might enable, or even about design inefficiencies in parts of Gen’s static subset that ASTModel might improve on.

I think it would actually be really nice to be able to encode structured computation (including control flow style constructs) in ASTModel while also eliminating overhead during inference – especially for code which hasn’t changed.

Thanks @McCoy !

Yep, this is a tricky question. It’s one nice thing about having everything at the top level. Some possibilities I considered:

  • There’s the old continuations route. But these are awkward in Julia, and mostly handled by Turing, so it’s maybe better just to go through Turing for this than to reinvent it.
  • Something like inserting an @label before and after each ~ statement. But at the Julia AST level, I still don’t think it’s possible to control this from outside a function. But maybe something could be inserted here as a “message” to lower-level code manipulation?
  • Statically determine the variables in each block (the “space between the tilde expressions”), and rewrite it as a function, so the whole model becomes a composition of functions. Then you could easily evaluate a subset of these. But I think this may only apply when everything is at the top-level, in which case it may be better to treat it as a DAG.

Hmmm, maybe the best option yet, but there could be a lot to it… Maybe we could do something like

  1. Rewrite it in SSA form (still at the Julia AST level) with @label and @goto as needed
  2. Run once, to get the types of all variables
  3. Make variables of these type available at the beginning of the code, so they persist throughout
  4. Break it apart at each @label or @goto, so now you have chunks that can be run in any order.

Just an idea, not sure it will pan out.
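One way to picture the “composition of functions” idea, where only part of the model is re-run, is the sketch below. The chunks and their read/write sets are written by hand here (Chunk, CHUNKS, and update! are hypothetical); in the proposal they would be derived statically from the rewritten code.

```julia
# A chunk of deterministic code: which variables it reads, which it writes
struct Chunk
    reads::Tuple{Vararg{Symbol}}
    writes::Symbol
    f::Function
end

# Hand-split body of a hypothetical model:  a ~ ...;  b ~ ...;  u = 2a;  w = b + 1
const CHUNKS = [
    Chunk((:a,), :u, s -> 2 * s[:a]),
    Chunk((:b,), :w, s -> s[:b] + 1),
]

# Set one variable, then re-run only the chunks that (transitively) depend on it
function update!(state::Dict{Symbol,Any}, changed::Symbol, newval)
    state[changed] = newval
    dirty = Set([changed])
    reran = Symbol[]
    for c in CHUNKS                       # chunks are kept in program order
        if any(r -> r in dirty, c.reads)
            state[c.writes] = c.f(state)
            push!(dirty, c.writes)
            push!(reran, c.writes)
        end
    end
    reran
end

state = Dict{Symbol,Any}(:a => 1.0, :b => 2.0, :u => 2.0, :w => 3.0)
update!(state, :a, 5.0)   # returns [:u]; the chunk computing w is skipped
```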

Finally, though I want to leverage AST benefits, I also want to keep an eye out for ways to connect explicitly with Gen, Turing, etc. Please let me know as you see possibilities for that :slight_smile:


Hey @McCoy , I think I see a way to make all of this work. Do you have any suggestions for a good proof of concept? Like, a “hello world” model and inference methods that are as simple as possible for demonstrating the approach?

Come to think of it, it seems like it could be useful to have a general collection of these, for different kinds of capabilities. But that’s maybe a separate discussion :slight_smile: