The use case is probabilistic programming. In the current release of Soss.jl, a `Model` is a block of `Statement`s, each of which is either an `Assign` (`lhs = rhs`) or a `Sample` (`lhs ~ rhs`). I want to make this more flexible and allow control flow, especially for things like more complex dependencies among entries of an array.
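For concreteness, here’s a minimal sketch of that flat structure. The type names follow the description above, but the fields and the `to_statement` helper are my own hypothetical simplification, not Soss’s actual definitions. The point is that anything other than `=` or `~` (a `for` loop, an `if`) has nowhere to go:

```julia
# Hypothetical, simplified mirror of the flat Model structure (not Soss's real types).
abstract type Statement end

struct Assign <: Statement
    lhs::Symbol
    rhs::Any        # right-hand expression, e.g. :(μ + 1)
end

struct Sample <: Statement
    lhs::Symbol
    rhs::Any        # distribution expression, e.g. :(Normal(μ, σ))
end

struct Model
    args::Vector{Symbol}
    body::Vector{Statement}
end

# Convert one top-level expression into a Statement; control flow is rejected.
function to_statement(ex::Expr)
    if ex.head === :(=) && ex.args[1] isa Symbol
        return Assign(ex.args[1], ex.args[2])
    elseif ex.head === :call && ex.args[1] === :~ && ex.args[2] isa Symbol
        return Sample(ex.args[2], ex.args[3])
    else
        throw(ArgumentError("unsupported statement: $ex"))
    end
end

# Build a Model from a `begin ... end` block, skipping line-number nodes.
Model(args::Vector{Symbol}, block::Expr) =
    Model(args, Statement[to_statement(ex) for ex in block.args if !(ex isa LineNumberNode)])
```

So `Model([:μ], quote σ = 1.0; x ~ Normal(μ, σ) end)` works, but a block containing a loop throws an `ArgumentError`.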
I have a good start on this, mostly using ideas in this post. The current challenge is how to update a sampled value while keeping the dependency graph consistent.
It’s very popular to do this using something like a nested dictionary to represent the trace. But this has a lot of overhead, and I’d much rather have something with a narrower scope of use cases than something that’s slow. I had been thinking the “trace” could just be made of local variables, so it’s all in code with no data structure at all. At each `~`, we could add a `@label` and a `@goto`, and the `tilde` function could pass back information about where to jump next.
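A rough sketch of the label-per-`~` idea, using Base’s `@label`/`@goto`. Since `@goto` needs a literal label, the entry point is chosen with a plain `if` chain. `resume` and `draw` are hypothetical names, with `draw(mean)` standing in for whatever `tilde` would do; having `tilde` itself return the next jump target isn’t shown here:

```julia
# Hypothetical "trace as local variables" for the model
#   x ~ Normal(0, 1)
#   y ~ Normal(x, 1)
# `draw(mean)` stands in for sampling; `resume_at` picks which `~` to restart from.
function resume(resume_at::Symbol, x, y, draw)
    if resume_at === :x
        @goto sample_x
    elseif resume_at === :y
        @goto sample_y
    else
        throw(ArgumentError("no label for $resume_at"))
    end

    @label sample_x
    x = draw(0.0)        # x ~ Normal(0, 1)
    @label sample_y
    y = draw(x)          # y ~ Normal(x, 1)
    return (x = x, y = y)
end
```

With `draw = m -> m + randn()` this becomes an actual sampler. Resuming at `:y` leaves the existing `x` untouched, and there are no dictionary lookups anywhere.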
Then I started poking around some more in @BenLauwens’s ResumableFunctions.jl. I had played with this before, but he’s now made it really fast. There are some cool ideas in the implementation:
https://benlauwens.github.io/ResumableFunctions.jl/dev/internals/
The mutable struct + state machine approach is really slick, and using the typed code to get the types for the slots is really clever. But for my purposes, name shadowing could get in the way: reusing a variable name means one field of the struct takes on multiple roles, which could lead to problems once we start jumping around.
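To make the shadowing concern concrete, here’s a hand-written approximation of the mutable-struct-plus-state-machine idea. It is not the code ResumableFunctions.jl generates, just the same general shape. The commented source reuses `x` in a `let` block; giving each binding its own field (`x_1`, `x_2`) keeps the outer value intact, whereas a single shared `x` field would make the third call return `10`:

```julia
# Hand-written state machine for a hypothetical resumable function:
#   x = 1
#   @yield x          # 1
#   let x = 10        # a different `x` that shadows the outer one
#       @yield x      # 10
#   end
#   @yield x          # should be 1 again
mutable struct ShadowFSM
    state::Int   # which `@yield` we stopped at
    x_1::Int     # slot for the outer `x`
    x_2::Int     # slot for the inner `x` (renamed so it can't clobber `x_1`)
end
ShadowFSM() = ShadowFSM(0, 0, 0)

# Each call runs until the next yield point and returns the yielded value.
function (fsm::ShadowFSM)()
    if fsm.state == 0
        fsm.x_1 = 1
        fsm.state = 1
        return fsm.x_1
    elseif fsm.state == 1
        fsm.x_2 = 10
        fsm.state = 2
        return fsm.x_2
    elseif fsm.state == 2
        fsm.state = 3
        return fsm.x_1
    else
        error("ShadowFSM has already finished")
    end
end
```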
Luckily, I think @thautwarm’s JuliaVariables.jl can be made to help with this. I opened a new issue about that here:
https://github.com/JuliaStaging/JuliaVariables.jl/issues/28
So I think the process would look something like this:
- Use JuliaVariables to solve for scopes and make sure names are unique
- Still at the AST level, make transformations like those described here to get us closer to the lowered representation
- Add `@label`s and `@goto`s to make it easy to jump between samples
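The label step could be a small AST pass. Here’s a minimal sketch that puts a `@label sample_<name>` in front of each top-level `lhs ~ rhs`. The `add_labels` helper and the label naming scheme are hypothetical, and it only handles a flat block; statements nested inside loops would need a recursive walk:

```julia
# True for expressions of the form `lhs ~ rhs` (parsed as a call to `~`).
is_tilde(ex) = ex isa Expr && ex.head === :call && length(ex.args) == 3 && ex.args[1] === :~

# Hypothetical pass: prefix each top-level `name ~ rhs` with `@label sample_name`.
function add_labels(block::Expr)
    block.head === :block || throw(ArgumentError("expected a block"))
    out = Expr(:block)
    for st in block.args
        if is_tilde(st) && st.args[2] isa Symbol
            label = Symbol(:sample_, st.args[2])
            push!(out.args, Expr(:macrocall, Symbol("@label"), LineNumberNode(0), label))
        end
        push!(out.args, st)
    end
    return out
end
```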
At this point, we probably need to replace each `~` with a concrete `tilde` function, depending on whether we want to call `rand` or `logdensityof`, or something else. Then:
- Lower the code
- Use LoweredCodeUtils.jl to get the edges (see the “Edges” page of the LoweredCodeUtils docs)
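Here’s a minimal sketch of the `~` replacement as a recursive AST rewrite. The `mode` switch, the `lower_tilde` name, and the `_ℓ` accumulator are all hypothetical. `logdensityof` is the DensityInterface.jl function; this block only emits it inside an expression and never calls it. Because the rewrite recurses into every subexpression, it also handles `~` inside loops, which the current flat `Model` can’t express:

```julia
# Hypothetical rewrite: turn every `lhs ~ rhs` into a concrete operation.
#   mode = :rand        ->  lhs = rand(rhs)
#   mode = :logdensity  ->  _ℓ += logdensityof(rhs, lhs)
function lower_tilde(ex, mode::Symbol)
    if ex isa Expr && ex.head === :call && length(ex.args) == 3 && ex.args[1] === :~
        lhs, rhs = ex.args[2], ex.args[3]
        if mode === :rand
            return :($lhs = rand($rhs))
        elseif mode === :logdensity
            return :(_ℓ += logdensityof($rhs, $lhs))
        else
            throw(ArgumentError("unknown mode $mode"))
        end
    elseif ex isa Expr
        # Recurse so that `~` inside loops and branches is rewritten too.
        return Expr(ex.head, (lower_tilde(a, mode) for a in ex.args)...)
    else
        return ex   # symbols, literals, line-number nodes
    end
end
```

Generating one variant per `mode` keeps each version free of runtime branching on what `~` means.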
Then, for any variable that’s updated, the edges tell us what else needs to be recomputed. Maybe that means generating an abbreviated version of the code for each case, or adding jumps to cover each case. I’m not really sure yet.
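Whatever form the edges take, the core query is “what’s downstream of this variable?” Here’s a minimal sketch with a hand-written `Dict` standing in for the edges; it does not use LoweredCodeUtils’ actual data structures. `to_rerun` is a hypothetical helper that returns the stale variables in program order, i.e. the statements an abbreviated version of the code would need to re-run:

```julia
# Variables that must be recomputed after `updated` changes, in program order.
# `children[v]` lists the variables computed directly from `v`.
function to_rerun(children::Dict{Symbol,Vector{Symbol}}, order::Vector{Symbol}, updated::Symbol)
    stale = Set{Symbol}()
    stack = [updated]
    while !isempty(stack)             # depth-first walk over the dependency edges
        v = pop!(stack)
        for c in get(children, v, Symbol[])
            if c ∉ stale
                push!(stale, c)
                push!(stack, c)
            end
        end
    end
    return filter(in(stale), order)   # keep the original statement order
end
```

For a hypothetical model `μ ~ …; σ ~ …; x ~ Normal(μ, σ); y = 2x`, updating `σ` gives `[:x, :y]`, while updating `y` gives nothing to re-run.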
I also don’t know the best way to “run” lowered code. It’s important to avoid world age issues, obviously. I’ve been very happy with GeneralizedGenerated, but if this lowered code is lifted back up to the AST level, it will be much bigger and (IIUC) any closures will be gone. So maybe it would need RuntimeGeneratedFunctions.jl? Or maybe lowered code can be executed directly? I’m not sure yet about some of these things.
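For reference, here’s the world-age problem in its smallest form, using only Base. A function created with `eval` inside a running function can’t be called directly from that same call. `Base.invokelatest` works around it, but it goes through dynamic dispatch on every call; as I understand it, packages like RuntimeGeneratedFunctions.jl and GeneralizedGenerated are designed to avoid both the error and that overhead. `build_and_call` is a hypothetical name:

```julia
# Build a one-argument function from an expression at run time, then call it.
function build_and_call(body::Expr, arg)
    f = eval(:(x -> $body))
    # Calling `f(arg)` directly here would typically throw a world-age
    # MethodError, since `f` was defined after `build_and_call` started running.
    return Base.invokelatest(f, arg)
end
```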
OK, this response got pretty long. Maybe it’s a little too down-in-the-weeds, but if you have ideas or suggestions, I’d enjoy a discussion.