Probabilistic programming with source transformations

Oh that looks great, thanks for the link! The Turing folks might be interested in this as well. I think they’ve implemented this in Turing, but it might not be so cleanly abstracted as yours. Also, if you’re interested in variations of SMC, you should check out Anglican. The site has links to their papers, and also implementations of the algorithms in Clojure.

Some context, in case others aren’t familiar with this area

SMC-based systems are sometimes called “universal” probabilistic programming. The name is a historical artifact of an analogy with Turing completeness (an inaccurate one, since Stan is Turing complete), and it has stuck for lack of a better term.

Historically, BUGS and Church are the Fortran and Lisp of probabilistic programming. Stan descends from the former, working in terms of a known fixed-dimension parameter space with a computable density function. In exchange for this constraint, posterior conditional densities are available, usually leading to better performance (both statistical and computational).

Church-like languages are more flexible. Instead of a joint density, they work in terms of a randomized simulation process. SMC-like algorithms use a collection of “particles”, each representing one run of the simulation. Cleverly manipulating weighted particles (e.g., choosing to interrupt some and clone others) opens up a number of different possibilities for parameter inference.
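To make "interrupt some and clone others" concrete, here is a minimal sketch of one common resampling scheme (systematic resampling). It is an illustration only, not how Turing or Anglican implement it; the function name and the idea of passing the uniform draw `u` in explicitly (to keep the example deterministic) are assumptions made for this sketch.

```julia
# Systematic resampling: a minimal sketch, not any library's implementation.
# `particles` holds one simulation run each; `weights` are unnormalized weights.
# `u` is a single uniform draw in [0, 1), passed in so the example is deterministic.
function systematic_resample(particles::Vector, weights::Vector{Float64}, u::Float64)
    n = length(particles)
    cumw = cumsum(weights ./ sum(weights))   # cumulative normalized weights
    cumw[end] = 1.0                          # guard against rounding error
    out = similar(particles)
    j = 1
    for i in 1:n
        p = (i - 1 + u) / n                  # evenly spaced pointers sharing one offset
        while cumw[j] <= p                   # advance to the particle whose interval covers p
            j += 1
        end
        out[i] = particles[j]                # heavy particles get cloned, light ones dropped
    end
    return out
end

particles = ["a", "b", "c", "d"]
weights = [0.1, 0.1, 0.7, 0.1]
resampled = systematic_resample(particles, weights, 0.5)
```

The heavily weighted particle "c" ends up with several copies, while some of the light particles disappear from the population.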


The “right” way to do it is to use the integrator interface: that way you build the integrator once and then re-solve using the same cache variables each time.

Of course, I’m ignoring that right now because the algorithm issue is much more important to get correct first. But if you want to optimize it, there you go.

Or for small problems you can use static arrays. But Julia has a few type-inference problems with some of our algorithms (roughly, statements with many operations on static arrays fail to infer), so that path won’t be the best until a 1.x release fixes that.

Hold on, I found out how to fix static array stuff. I’ll get that in by tonight and these tests should then be using static arrays.

Edit: It’s done. You can use the static array parameterized function setup for this now quite performantly without reiniting.


As @dfdx had mentioned, there’s some trouble with differentiating expressions like logpdf(Normal(μ,σ),x) - “differentiating a data type” doesn’t really make sense.

I looked a bit more into the implementation. For functions like logpdf, Distributions just calls StatsFuns, which leads to a nice algebraic expression after a couple of steps. The following are all equivalent, because they are successive simplifications of the first line:

Main> logpdf(Normal(μ,σ),x)

Main> normlogpdf(μ,σ,x)

Main> normlogpdf((x-μ)/σ) - log(σ)

Main> -(abs2((x-μ)/σ) + log2π)/2 - log(σ)

Distributions manages this with a macro _delegate_statsfuns.
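As a quick sanity check of the equivalences above, here is a minimal sketch that re-implements the last three forms in plain Julia and compares them numerically. The `my_` names are hypothetical stand-ins written for this example; the real definitions live in StatsFuns.

```julia
# Hand-written stand-ins for the StatsFuns definitions quoted above.
# The `my_` names are hypothetical; they are not part of any package.
const my_log2π = log(2π)

my_normlogpdf(z) = -(abs2(z) + my_log2π) / 2
my_normlogpdf(μ, σ, x) = my_normlogpdf((x - μ) / σ) - log(σ)

μ, σ, x = 0.0, 0.5, 2.0

step2 = my_normlogpdf(μ, σ, x)                          # three-argument form
step3 = my_normlogpdf((x - μ) / σ) - log(σ)             # standardized form
step4 = -(abs2((x - μ) / σ) + my_log2π) / 2 - log(σ)    # fully expanded form
```

All three values agree, which is what makes the fully expanded form a safe target for differentiation.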

What’s really needed, I think, is a way to easily get from the first of these to the last. I could code this in Soss for this specific case, but it seems like a generally useful thing anyway. Most generally, I could imagine a function that takes an expression and returns an iterator over successive evaluation steps. I’m guessing the most efficient path might be somewhere between a quick hack and the “best possible world” evaluation iterator.

Any thoughts/suggestions?

ForwardDiff.jl works through these just fine. IMO AD is the “best possible” solution for these situations; as you have seen, the symbolic approach does not scale well.

Took longer than I expected, but I have started a set of worked examples as Jupyter notebooks in

Currently, it only contains a single problem: estimation of an unknown covariance matrix (decomposed into correlation and variance, using the LKJ prior on the correlation). This is WIP, but I made it as clear as I could. Note that the API can of course change (I will update these examples when it does).


I agree. I wasn’t after symbolic diff, but AD via source rewriting, maybe similar to what Tangent does for Python. I think XGrad is along these lines, but maybe I’ve misunderstood.

That looks really helpful. Thank you!

I think I need to clarify this a bit. In math, derivatives are defined for functions. In this piece of code Normal represents a distribution, and I don’t know of any such thing as a derivative of, or w.r.t., a distribution. We can treat Normal as a struct (i.e. a collection of its parameters) and find derivatives w.r.t. each field. In general this should work, and I’m open to this route. However, this way we would be bound to the internal structure of data types, making the code more fragile. (Distributions.jl does define the function params() to encapsulate implementation details, but from reading the source code I don’t think any AD tool will understand it and find derivatives only w.r.t. these parameters.)

Given the source rewriting approach that @cscherrer is using, I think the simpler and more robust approach is to expand things like logpdf(Normal(...)) into normlogpdf() and differentiate that directly. It’s actually quite easy to do - XGrad already does it internally. But other ideas are welcome too.

There are 2 main types of AD: forward-mode and reverse-mode. Reverse-mode AD can be dynamic or static. Static reverse-mode AD usually builds and optimizes a computation graph, which it can build using either function overloading or source transformation.

Symbolic differentiation rewrites a symbolic expression (e.g. code) into the corresponding derivatives. In most cases the rewriting is done using the same 2-step procedure as in reverse-mode AD.

XGrad (as well as Theano, TensorFlow or Tangent) falls under the definition of both AD and SD.

A short guide:

  • forward vs. reverse-mode depends on whether you have a function of the form R -> R^n (forward) or R^n -> R (reverse)
  • static vs. dynamic depends on whether you can define the computation graph beforehand
  • function overloading vs. source rewriting depends on the context
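To illustrate the first point, here is a minimal forward-mode sketch using dual numbers and function overloading. The `Dual` type and `forward_gradient` helper are hypothetical and written only for this example (ForwardDiff is far more general); the point is that forward mode needs one pass per input, which is why it gets expensive for R^n -> R with large n.

```julia
# A minimal forward-mode sketch with dual numbers.
# `Dual` and `forward_gradient` are hypothetical names, not ForwardDiff's API.
struct Dual
    val::Float64   # function value
    der::Float64   # derivative with respect to the seeded input
end

# Overload just enough arithmetic for the example below.
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.log(a::Dual) = Dual(log(a.val), a.der / a.val)

# Forward mode needs one pass per input: seed input i with derivative 1.
function forward_gradient(f, xs::Vector{Float64})
    map(eachindex(xs)) do i
        duals = [Dual(xs[j], i == j ? 1.0 : 0.0) for j in eachindex(xs)]
        f(duals).der
    end
end

# R^3 -> R example: three passes are needed for the full gradient.
f(t) = t[1] * t[2] + log(t[3])
grad = forward_gradient(f, [2.0, 3.0, 4.0])
```

Reverse mode would get the same three partial derivatives from a single backward pass, at the cost of recording the computation first.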

Ideally yes, but my impression is that ReverseDiff.jl is waiting for Cassette.jl to mature, so given the current state of packages, ForwardDiff.jl is a reasonable (but of course not ideal) choice for R^n -> R functions too, out of necessity.

Thanks for the details @dfdx. I guess my impression is that for sufficiently simple code, statically rewriting the source will always be better. For me the relevant questions are

  • Which approach is most appropriate in principle?
  • How much risk is there that a given implementation will become obsolete?

I can see that XGrad still has a way to go before maturity, but it seems like a good approach to me, and I think there will always be a place for source transformation autodiff. I was really kind of shocked to see that ReverseDiffSource seems to have been abandoned, and I’m excited to see new work in its place.

Wait, now I’m confused. XGrad already does this? Maybe I’m calling it wrong. I was thinking I’d have to reconstruct the name-mangling Distributions does to call StatsFuns. Can you show a minimal example, like computing the gradient of this?

t -> logpdf(Normal(t[1],t[2]),t[3])

Note that the Cassette-based version of ReverseDiff moved to Capstan.jl. Another battle-tested reverse-mode AD package based on function overloading is AutoGrad.jl.

I mostly use AD/SD in machine learning tasks where the size of the inputs can easily exceed 10^6. Using ForwardDiff.jl for such data isn’t an option at all :)


I always try to isolate my code from other packages’ implementation details. Normally that means fixing derivatives: finding them once for each function and defining the corresponding primitives. This way, changes to the internals of those packages won’t affect you as long as the outer interfaces stay the same. Note that source rewriting is advantageous here, since you can literally generate the code for the derivatives you need and copy-paste it into your source files.

How much risk is there that a given implementation will become obsolete?

There’s a high risk that one of the functions you don’t own will gain a part that is non-parsable (for source transformation), not generic (for function overloading) or just somehow non-differentiable (e.g. a call into C code). If this happens, your code will just break even though all the interfaces stay the same.

Most libraries I’ve seen define a number of primitives that are guaranteed to work and additionally support others’ code that is expected to work in most cases.

Wait, now I’m confused.

It’s a bit of a longer story, so I’ll answer this tomorrow :)

To clarify: Capstan will provide both forward-mode and reverse-mode Cassette-based AD. The Cassette-based implementation is quite different from ReverseDiff/ForwardDiff, so I decided it would be appropriate to make a new package. ReverseDiff/ForwardDiff will probably still exist for a while after Capstan is released, but my hope/goal/belief is that Capstan will eventually render both of those packages obsolete. Should be interesting to see where Capstan falls in the “operator overloading vs. source transformation” space, since the Cassette-based approach really is a hybrid of both…


Oh, didn’t realize that - yes that sounds really interesting. I’m really curious about the overloading/transformation tradeoffs, especially as it might relate to specifics of Julia’s compilation strategy.

Sorry, I’m short on time again, so here’s a quick reply to avoid blocking you: right now you can use the functional form. The following works (with master of Espresso and XGrad):

using StatsFuns
using XGrad

t = (0.0, 0.5, 2.0)
xdiff(:(normlogpdf(t[1], t[2], t[3])); t=t)

Later I’ll provide more details and other options (like using the Normal struct directly).

Ah great, for a minute there it seemed like you were saying XGrad already does the expansion. Thanks for clarifying.

Anyway, I’ll likely be slowing down on Soss for a while. I started a new job yesterday (with Metis), and I’m in a bit of a scramble right now to get up to speed before I start teaching next week.

Congrats on the new position!

Ah great, for a minute there it seemed like you were saying XGrad already does the expansion.

And it actually does. For example, when you write normlogpdf(μ, σ, x), what is sent to the differentiation engine is something like this:

    StatsFuns_zval_tmp755_757 = x - μ
    StatsFuns_normlogpdf_tmp739_743 = StatsFuns_zval_tmp755_757 / σ
    StatsFuns_normlogpdf_tmp746_751 = abs2(StatsFuns_normlogpdf_tmp739_743)
    StatsFuns_normlogpdf_tmp747_752 = StatsFuns_normlogpdf_tmp746_751 + log2π
    StatsFuns_normlogpdf_tmp748_753 = -StatsFuns_normlogpdf_tmp747_752
    StatsFuns_normlogpdf_tmp749_754 = 2
    StatsFuns_normlogpdf_tmp740_744 = StatsFuns_normlogpdf_tmp748_753 / StatsFuns_normlogpdf_tmp749_754
    StatsFuns_normlogpdf_tmp741_745 = log(σ)
    tmp738 = StatsFuns_normlogpdf_tmp740_744 - StatsFuns_normlogpdf_tmp741_745
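To show why a flat trace like this is convenient to differentiate, here is a hand-written sketch of a forward pass over the same steps followed by a reverse pass. The function and the shortened variable names are hypothetical and written for this example; this is not the code XGrad generates.

```julia
# Forward pass mirrors the expanded trace; the reverse pass applies the chain
# rule to each step in the opposite order. All names here are hypothetical.
function normlogpdf_with_grad(μ, σ, x)
    # forward pass
    d = x - μ
    z = d / σ
    s = abs2(z)
    a = s + log(2π)
    h = -a / 2
    l = log(σ)
    out = h - l

    # reverse pass: each `d…` variable below is d(out)/d(variable)
    dh = 1.0                       # out = h - l
    dl = -1.0
    da = -dh / 2                   # h = -a / 2
    ds = da                        # a = s + log(2π)
    dz = ds * 2z                   # s = z^2
    dd = dz / σ                    # z = d / σ
    dσ = -dz * d / σ^2 + dl / σ    # σ appears both in z and in l
    dx = dd                        # d = x - μ
    dμ = -dd
    return out, (dμ, dσ, dx)
end

val, (gμ, gσ, gx) = normlogpdf_with_grad(0.0, 0.5, 2.0)
```

The results match the closed-form partial derivatives (x-μ)/σ^2 for μ, (x-μ)^2/σ^3 - 1/σ for σ, and -(x-μ)/σ^2 for x.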

The problem is that not all functions are created equal, and data type constructors are particularly hard. Let me explain.

When you differentiate a composition of functions, e.g. f(g(x)), automatic differentiation tools normally split it into separate calls:

w1 = g(x)
z = w2 = f(w1)

and differentiate one by one:

dz/dz = dz/dw2 = 1
dz/dw1 = dz/dw2 * dw2/dw1
dz/dx = dz/dw1 * dw1/dx

If both - f and g - are known and diff rules for them are defined, you are done. But what if one of them comes from the outside?

In this case XGrad tries to recursively find and parse the unknown function’s body until the resulting expression contains only known primitives. In most cases, like with logpdf or normlogpdf, it works well. But some functions can’t be parsed into a correct algebraic expression. Default constructors are one such case (although additional outer constructors work fine in most cases). Default constructors don’t have a function body, and the best thing you can get from the AST is something like:

:(return $(Expr(:new, :((Core.apply_type)(Distributions.Normal, $(Expr(:static_parameter, 1)))), :((Base.convert)($(Expr(:static_parameter, 1)), _2)), :((Base.convert)($(Expr(:static_parameter, 1)), _3)))))

Hmm, this doesn’t look differentiable.

Pure AD tools normally don’t have this problem since they never parse function bodies or try to differentiate constructors. I have an experimental branch that uses tracked data (similar to TrackedArray in ReverseDiff or Variable in PyTorch, for example), but this approach has its own limitations. Notably, the graph built this way doesn’t preserve variable names, so it’s very hard to debug or edit generated code. Anyway, I’m eager to try it on Distributions / StatsFuns.

So what to do now? As I mentioned earlier, the easiest way to avoid all the complexities and keep all the advantages is to transform things like logpdf(Normal(...), ...) into normlogpdf(...), which is trivial using the Espresso.rewrite() function.
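For illustration, the same transformation can be sketched with plain `Expr` manipulation in Base Julia. This is a hand-rolled version written for this example (the function name is hypothetical); it does not use or reproduce Espresso's rewrite machinery, and it only handles the Normal case.

```julia
# Hand-rolled rewrite of logpdf(Normal(μ, σ), x) into normlogpdf(μ, σ, x).
# Plain `Expr` manipulation; `expand_normal_logpdf` is a hypothetical name.
function expand_normal_logpdf(ex)
    ex isa Expr || return ex                     # symbols and literals pass through
    args = map(expand_normal_logpdf, ex.args)    # rewrite subexpressions first
    if ex.head == :call && length(args) == 3 && args[1] == :logpdf
        dist, x = args[2], args[3]
        if dist isa Expr && dist.head == :call && length(dist.args) == 3 &&
           dist.args[1] == :Normal
            # logpdf(Normal(μ, σ), x) becomes normlogpdf(μ, σ, x)
            return Expr(:call, :normlogpdf, dist.args[2], dist.args[3], x)
        end
    end
    return Expr(ex.head, args...)
end

rewritten = expand_normal_logpdf(:(logpdf(Normal(t[1], t[2]), t[3]) + 1))
```

The rewritten expression contains only the normlogpdf call, so a source-transformation tool never has to look inside the Normal constructor.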

There is a new player in the probabilistic programming arena: Gen.jl. Unlike Stan or Turing, it does inference using execution traces. Here is a recent paper discussing the design. Looking at the code, it seems like a good application for Cassette.


Looks like the team developing it comes out of MIT, so maybe they’ve crossed paths with @jrevels/Cassette already.
