Probabilistic programming with source transformations

To clarify: Capstan will provide both forward-mode and reverse-mode Cassette-based AD. The Cassette-based implementation is quite different from ReverseDiff/ForwardDiff, so I decided it would be appropriate to make a new package. ReverseDiff/ForwardDiff will probably still exist for a while after Capstan is released, but my hope/goal/belief is that Capstan will eventually render both of those packages obsolete. Should be interesting to see where Capstan falls in the “operator overloading vs. source transformation” space, since the Cassette-based approach really is a hybrid of both…

Oh, I didn’t realize that. Yes, that sounds really interesting. I’m really curious about the overloading/transformation tradeoffs, especially as they might relate to the specifics of Julia’s compilation strategy.

Sorry, I’m short on time again, so here is a quick reply so as not to block you: right now you can use the functional form; the following works (with the master branches of Espresso and XGrad):

using StatsFuns
using XGrad

t = (0.0, 0.5, 2.0)
xdiff(:(normlogpdf(t[1], t[2], t[3])); t=t)

Later I’ll provide more details and other options (like using the Normal struct directly).

Ah great, for a minute there it seemed like you were saying XGrad already does the expansion. Thanks for clarifying.

Anyway, I’ll likely be slowing down on Soss for a while. I started a new job yesterday (with Metis), and I’m in a bit of a scramble right now to get up to speed before I start teaching next week.

Congrats on the new position!

Ah great, for a minute there it seemed like you were saying XGrad already does the expansion.

And it actually does. For example, when you write normlogpdf(μ, σ, x), what is sent to the differentiation engine is something like this:

quote
    StatsFuns_zval_tmp755_757 = x - μ
    StatsFuns_normlogpdf_tmp739_743 = StatsFuns_zval_tmp755_757 / σ
    StatsFuns_normlogpdf_tmp746_751 = abs2(StatsFuns_normlogpdf_tmp739_743)
    StatsFuns_normlogpdf_tmp747_752 = StatsFuns_normlogpdf_tmp746_751 + log2π
    StatsFuns_normlogpdf_tmp748_753 = -StatsFuns_normlogpdf_tmp747_752
    StatsFuns_normlogpdf_tmp749_754 = 2
    StatsFuns_normlogpdf_tmp740_744 = StatsFuns_normlogpdf_tmp748_753 / StatsFuns_normlogpdf_tmp749_754
    StatsFuns_normlogpdf_tmp741_745 = log(σ)
    tmp738 = StatsFuns_normlogpdf_tmp740_744 - StatsFuns_normlogpdf_tmp741_745
end
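To make the point concrete, here is the same expansion written by hand as a plain Julia function and checked against the closed-form normal log-density. This is only a sketch: the function names are hypothetical, and `log2π` is defined locally as `log(2π)` (StatsFuns defines a constant with the same name and value) so the block runs without any packages. Every step is a primitive with a known derivative rule, which is what makes the expanded form easy to differentiate.

```julia
# Hypothetical local copy of the constant so no packages are needed.
const log2π = log(2π)

# Hand-written version of the expanded body above; each line is one primitive.
function normlogpdf_expanded(μ, σ, x)
    z  = x - μ       # from the inlined zval
    t1 = z / σ       # standardized value
    t2 = abs2(t1)    # square of the standardized value
    t3 = t2 + log2π
    t4 = -t3
    t5 = t4 / 2
    t6 = log(σ)
    return t5 - t6
end

# Closed form of the normal log-density, for comparison
normlogpdf_closed(μ, σ, x) = -(x - μ)^2 / (2σ^2) - log(σ) - log(2π) / 2

μ, σ, x = 0.0, 0.5, 2.0
println(normlogpdf_expanded(μ, σ, x) ≈ normlogpdf_closed(μ, σ, x))  # true
```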

The problem is that not all functions are created equal, and data type constructors are particularly hard. Let me explain.

When you differentiate a composition of functions, e.g. f(g(x)), automatic differentiation tools normally split it into separate calls:

w1 = g(x)
z = w2 = f(w1)

and differentiate them one by one:

dz/dz = dz/dw2 = 1
dz/dw1 = dz/dw2 * dw2/dw1
dz/dx = dz/dw1 * dw1/dx
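The forward split and the backward chain-rule pass can be sketched in plain Julia. The choice of g(x) = x^2 and f(w) = sin(w) is arbitrary and only for illustration; the backward pass mirrors the three chain-rule lines, each adjoint built from the previous one.

```julia
# Illustrative primitives with known derivative rules
g(x)  = x^2
dg(x) = 2x
f(w)  = sin(w)
df(w) = cos(w)

function grad_fg(x)
    # forward pass: split f(g(x)) into separate calls
    w1 = g(x)
    w2 = f(w1)
    z  = w2
    # backward pass: propagate adjoints from the output back to x
    dz_dw2 = 1.0
    dz_dw1 = dz_dw2 * df(w1)   # dz/dw1 = dz/dw2 * dw2/dw1
    dz_dx  = dz_dw1 * dg(x)    # dz/dx  = dz/dw1 * dw1/dx
    return z, dz_dx
end

z, dzdx = grad_fg(1.5)
println(dzdx ≈ cos(1.5^2) * 2 * 1.5)  # true
```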

If both f and g are known and differentiation rules are defined for them, you are done. But what if one of them comes from outside?

In this case XGrad tries to recursively find and parse the unknown function’s body until the resulting expression contains only known primitives. In most cases, like with logpdf or normlogpdf, this works well. But some functions can’t be parsed into a correct algebraic expression. Default constructors are one of them (although additional outer constructors work fine in most cases). Default constructors don’t have a function body, and the best you can get from the AST is something like:

:(return $(Expr(:new, :((Core.apply_type)(Distributions.Normal, $(Expr(:static_parameter, 1)))), :((Base.convert)($(Expr(:static_parameter, 1)), _2)), :((Base.convert)($(Expr(:static_parameter, 1)), _3)))))

Hmm, this doesn’t look differentiable.
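The recursive lookup, and the point where it gives up, can be sketched with a toy inliner. This is not how XGrad actually finds function bodies: `BODIES` is a hypothetical hand-filled table standing in for source lookup, and `PRIMITIVES` is an illustrative subset of functions with known diff rules. A call to something with no entry in either table (such as a default constructor like `Normal`) raises an error.

```julia
# Hypothetical registry of function bodies: argument names and body expression
const BODIES = Dict{Symbol,Tuple{Vector{Symbol},Expr}}(
    :zval       => ([:μ, :σ, :x], :((x - μ) / σ)),
    :normlogpdf => ([:μ, :σ, :x], :(-(abs2(zval(μ, σ, x)) + log2π) / 2 - log(σ))),
)

# Functions whose derivative rules the engine already knows (illustrative)
const PRIMITIVES = Set([:+, :-, :*, :/, :abs2, :log])

# Replace symbols in an expression according to the mapping `env`
substitute(ex::Symbol, env) = get(env, ex, ex)
substitute(ex::Expr, env) = Expr(ex.head, (substitute(a, env) for a in ex.args)...)
substitute(ex, env) = ex

# Recursively inline non-primitive calls until only primitives remain
function inline_all(ex)
    ex isa Expr || return ex
    args = map(inline_all, ex.args)
    if ex.head == :call
        fname = args[1]
        if fname in PRIMITIVES
            return Expr(:call, args...)
        elseif haskey(BODIES, fname)
            params, body = BODIES[fname]
            env = Dict(zip(params, args[2:end]))
            return inline_all(substitute(body, env))
        else
            error("no body available for $fname (e.g. a default constructor)")
        end
    end
    return Expr(ex.head, args...)
end

ex = inline_all(:(normlogpdf(μ, σ, x)))  # only -, /, abs2, +, log remain
# inline_all(:(Normal(μ, σ)))            # would throw: no body for Normal
```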

Pure AD tools normally don’t have this problem since they never parse function bodies or try to differentiate constructors. I have an experimental branch that uses tracked data (similar to TrackedArray in ReverseDiff or Variable in PyTorch), but this approach has its own limitations. Notably, the graph built this way doesn’t preserve variable names, so it’s very hard to debug or edit the generated code. Anyway, I’m eager to try it on Distributions / StatsFuns.
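For contrast, the tracked-data idea can be sketched as a tiny operator-overloading reverse-mode AD for scalars. The names `Tracked`, `TAPE`, `track` and `gradient` are hypothetical and unrelated to XGrad's experimental branch or ReverseDiff's types. It never looks at a function body, so constructors and other opaque code are not a problem, but the tape only holds anonymous closures: the original variable names are gone, which is exactly the debugging drawback mentioned above.

```julia
# Scalar value plus its accumulated adjoint
mutable struct Tracked
    value::Float64
    grad::Float64
end

const TAPE = Function[]   # backward closures, recorded in execution order

track(x::Real) = Tracked(float(x), 0.0)

function Base.:*(a::Tracked, b::Tracked)
    out = Tracked(a.value * b.value, 0.0)
    push!(TAPE, () -> (a.grad += b.value * out.grad; b.grad += a.value * out.grad))
    return out
end

function Base.:+(a::Tracked, b::Tracked)
    out = Tracked(a.value + b.value, 0.0)
    push!(TAPE, () -> (a.grad += out.grad; b.grad += out.grad))
    return out
end

function Base.sin(a::Tracked)
    out = Tracked(sin(a.value), 0.0)
    push!(TAPE, () -> (a.grad += cos(a.value) * out.grad))
    return out
end

# Run f on tracked inputs, then replay the tape backwards
function gradient(f, xs::Vector{Float64})
    empty!(TAPE)
    tracked = track.(xs)
    y = f(tracked...)
    y.grad = 1.0
    for back! in Iterators.reverse(TAPE)
        back!()
    end
    return y.value, [t.grad for t in tracked]
end

h(x, y) = sin(x * y) + x
val, grads = gradient(h, [1.0, 2.0])  # grads ≈ [2cos(2) + 1, cos(2)]
```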

So what to do now? As I mentioned earlier, the easiest way to avoid all these complexities while keeping all the advantages is to transform expressions like logpdf(Normal(...), ...) into normlogpdf(...), which is trivial with the Espresso.rewrite() function.
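The transformation itself is small. The sketch below is a hand-rolled Expr walk in plain Julia, not Espresso's API (whose exact pattern syntax isn't reproduced here); `rewrite_normal` is a hypothetical name. It replaces every `logpdf(Normal(μ, σ), x)` call with `normlogpdf(μ, σ, x)`, the argument order StatsFuns uses, and leaves other distributions untouched.

```julia
# Replace logpdf(Normal(μ, σ), x) with normlogpdf(μ, σ, x) anywhere in `ex`
function rewrite_normal(ex)
    ex isa Expr || return ex
    args = map(rewrite_normal, ex.args)   # rewrite children first
    if ex.head == :call && args[1] == :logpdf && length(args) == 3
        dist = args[2]
        if dist isa Expr && dist.head == :call && dist.args[1] == :Normal &&
           length(dist.args) == 3
            μ, σ = dist.args[2], dist.args[3]
            return Expr(:call, :normlogpdf, μ, σ, args[3])
        end
    end
    return Expr(ex.head, args...)
end

ex = :(logpdf(Normal(m, s), y) + logpdf(Normal(0.0, 1.0), m))
rewrite_normal(ex)  # equals :(normlogpdf(m, s, y) + normlogpdf(0.0, 1.0, m))
```

Because the rewrite happens before differentiation, the engine never sees the `Normal` constructor at all.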

There is a new player in the probabilistic programming arena: Gen.jl. Unlike Stan or Turing, it does inference using execution traces. Here is a recent paper discussing the design. Judging from the code, it looks like a good application for Cassette.

Looks like the team developing it comes out of MIT, so maybe they’ve crossed paths with @jrevels/Cassette already.