Hello all,

Just thought I’d talk a bit about the latest with Soss. Everything here is from the Soss, MeasureTheory, NestedTuples, and SymbolicCodegen packages.

# An easy problem

Say we have a simple model with some observations:

``````julia> m = @model σ begin
           μ ~ StudentT(3.0)
           x ~ Normal(μ, σ) |> iid(1000)
           return x
       end;

julia> x = rand(m(σ=2.0));
``````

Then we can condition on the observed `x` to get the posterior, and symbolically compute its log-density:

``````julia> post = m(σ=2.0) | (;x)
ConditionalModel given
    arguments    (:σ,)
    observations (:x,)
@model σ begin
        μ ~ StudentT(3.0)
        x ~ Normal(μ, σ) |> iid(1000)
        return x
end

julia> symlogdensity(post)
-1300.8333911728728 + 229.7492699560466μ + -2.0log(3.0 + (μ^2)) + -125.0(μ^2)
``````

The `σ` and `x` are gone, because we were able to substitute them into the generated code. That works like this:

1. Run the model forward to get all the types
2. Generate symbolic values from these, with the appropriate type tags
3. Evaluate the log-density on this result. This sounds like a bad idea, but all vectors are symbolic, so the code stays small
4. Walk the expression, and replace any `<: Number` subexpression with its value
5. Return the result.
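
Step 4 can be sketched as a small constant folder (a toy stand-in, not the actual Soss internals — the real pass works on tagged symbolic values, and `fold` is a name made up here):

``````julia
# Toy constant folder (illustrative; not the Soss implementation).
# Rebuild each Expr bottom-up; if a subexpression evaluates to a plain
# Number (no free variables), replace it with that value.
fold(x) = x

function fold(ex::Expr)
    ex = Expr(ex.head, map(fold, ex.args)...)
    try
        v = eval(ex)           # only safe for pure, side-effect-free code
        v isa Number && return v
    catch
    end
    return ex
end
``````

So `fold(:(log(2.0) + μ^2))` collapses the `log(2.0)` call to a literal but leaves the `μ`-dependent part symbolic.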

Next comes common subexpression elimination (CSE):

``````julia> symlogdensity(post) |> SymbolicCodegen.cse
7-element Vector{Pair{Symbol, Symbolic}}:
 Symbol("##930") => 229.7492699560466μ
 Symbol("##931") => μ^2
 Symbol("##932") => 3.0 + ##931
 Symbol("##933") => log(##932)
 Symbol("##934") => -2.0##933
 Symbol("##935") => -125.0##931
 Symbol("##936") => -1300.8333911728728 + ##930 + ##934 + ##935
``````
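
A minimal version of CSE over plain `Expr`s might look like this (hypothetical sketch; SymbolicCodegen operates on symbolic terms, and the names here are made up):

``````julia
# Toy CSE (illustrative only): walk an Expr bottom-up, binding each
# distinct subexpression to a fresh symbol, and reusing that symbol
# whenever a structurally equal subexpression appears again.
function cse!(ex, table::Dict{Any,Symbol}, out::Vector{Pair{Symbol,Expr}})
    ex isa Expr || return ex
    ex = Expr(ex.head, map(a -> cse!(a, table, out), ex.args)...)
    get!(table, ex) do
        s = gensym()
        push!(out, s => ex)
        s
    end
end

function cse(ex)
    out = Pair{Symbol,Expr}[]
    cse!(ex, Dict{Any,Symbol}(), out)
    return out
end
``````

On `:(μ^2 + log(3 + μ^2))` this emits four bindings, with the repeated `μ^2` assigned only once.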

We put those together:

``````julia> sourceCodegen(post)
quote
    var"##1010" = (*)(229.7492699560466, μ)
    var"##1011" = (^)(μ, 2)
    var"##1012" = (+)(3.0, var"##1011")
    var"##1013" = (log)(var"##1012")
    var"##1014" = (*)(-2.0, var"##1013")
    var"##1015" = (*)(-125.0, var"##1011")
    var"##1016" = (+)(-1300.8333911728728, var"##1010", var"##1014", var"##1015")
end
``````
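
Stitching the CSE bindings back into source is then just a fold into a block of assignments; something like this hypothetical helper:

``````julia
# Turn [name => rhs, ...] bindings into a block of assignments.
# The block's value is the value of the last assignment.
assemble(bindings) = Expr(:block, (:($name = $rhs) for (name, rhs) in bindings)...)
``````

For example, `assemble([:a => :(2 + 3), :b => :(a * 2)])` yields a block that evaluates to `10`.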

``````julia> codegen(post).f
function = (_args, _data, _pars;) -> begin
        begin
            μ = (Main).getproperty(_pars, :μ)
            x = (Main).getproperty(_data, :x)
            σ = (Main).getproperty(_args, :σ)
            var"##1316" = (*)(229.7492699560466, μ)
            var"##1317" = (^)(μ, 2)
            var"##1318" = (+)(3.0, var"##1317")
            var"##1319" = (log)(var"##1318")
            var"##1320" = (*)(-2.0, var"##1319")
            var"##1321" = (*)(-125.0, var"##1317")
            var"##1322" = (+)(-1300.8333911728728, var"##1316", var"##1320", var"##1321")
        end
    end
``````
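
The wrapper around the generated body is just unpacking: each of `_args`, `_data`, and `_pars` is a named tuple, and the preamble pulls variables out with `getproperty`. A rough sketch of that pattern (hypothetical helper, not the Soss API):

``````julia
# Build a function that unpacks named tuples, then runs `body`.
function mkfun(body; args=Symbol[], data=Symbol[], pars=Symbol[])
    unpack(src, names) = [:($n = getproperty($src, $(QuoteNode(n)))) for n in names]
    eval(:((_args, _data, _pars) -> begin
        $(unpack(:_args, args)...)
        $(unpack(:_data, data)...)
        $(unpack(:_pars, pars)...)
        $body
    end))
end
``````

Then `mkfun(:(μ + σ); args=[:σ], pars=[:μ])` behaves like the generated `.f` above (modulo the CSE'd body).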

Most of that’s hidden from the user. They just build it and then use it like a `logpdf`:

``````julia> ℓ = codegen(post)
#159 (generic function with 1 method)

julia> ℓ(post, (μ=0.2,))
-1262.1072522124998
``````

And really, even building it could be part of the inference algorithm. Anyway, this is pretty fast:

``````julia> @btime $ℓ($post, $((μ=0.2,)))
6.853 ns (0 allocations: 0 bytes)
-1262.1072522124998
``````

# … And a Slightly Tougher One

That case was easy, but the sum over observations won’t always collapse to a closed form. That’s fine too:

``````julia> m = @model σ begin
           μ ~ StudentT(3.0)
           x ~ Laplace(μ, σ) |> iid(1000)
           return x
       end;
``````
``````julia> post = m(σ=2.0) | (;x)
ConditionalModel given
    arguments    (:σ,)
    observations (:x,)
@model σ begin
        μ ~ StudentT(3.0)
        x ~ Laplace(μ, σ) |> iid(1000)
        return x
end
``````
``````julia> symlogdensity(post)
3.4166191036395746 + -2.0log(3.0 + (μ^2)) + -1(693.1471805599452 + Sum(abs(-0.5μ + 0.5getindex(x, ##i#1584)), ##i#1584, 1, 1000))

julia> codegen(post).f
function = (_args, _data, _pars;) -> begin
        begin
            μ = (Main).getproperty(_pars, :μ)
            x = (Main).getproperty(_data, :x)
            σ = (Main).getproperty(_args, :σ)
            var"##1643" = (^)(μ, 2)
            var"##1644" = (+)(3.0, var"##1643")
            var"##1645" = (log)(var"##1644")
            var"##1646" = (*)(-2.0, var"##1645")
            var"##1647" = begin
                    var"##sum#1676" = 0.0
                    var"##lo#1678" = 1
                    var"##hi#1679" = 1000
                    begin
                        $(Expr(:inbounds, true))
                        local var"#266#val" = for var"##i#1624" = (Main).:(:)(var"##lo#1678", var"##hi#1679")
                                begin
                                    var"##Δsum#1677" = (abs)((+)((*)(-0.5, μ), (*)(0.5, (getindex)(x, var"##i#1624"))))
                                    var"##sum#1676" = (+)(var"##sum#1676", var"##Δsum#1677")
                                end
                            end
                        $(Expr(:inbounds, :((Main).pop)))
                        var"#266#val"
                    end
                    var"##sum#1676"
                end
            var"##1648" = (+)(693.1471805599452, var"##1647")
            var"##1649" = (*)(-1, var"##1648")
            var"##1650" = (+)(3.4166191036395746, var"##1646", var"##1649")
        end
    end
``````
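
The `Sum` term turns into exactly the loop you’d write by hand. A simplified version of that lowering (a sketch under made-up names; the real codegen also handles `@inbounds` and variable hygiene):

``````julia
# Lower a symbolic Sum(body, ivar, lo, hi) into loop code.
function lower_sum(body, ivar, lo, hi)
    acc = gensym(:sum)
    quote
        $acc = 0.0
        for $ivar in $lo:$hi
            $acc += $body
        end
        $acc
    end
end

# Splice the loop into a function body:
sumabs = eval(:(x -> $(lower_sum(:(abs(x[i])), :i, 1, 3))))
``````

Calling `sumabs([1.0, -2.0, 3.0])` then gives `6.0`.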

And it’s still pretty fast:

``````julia> ℓ = codegen(post)
#159 (generic function with 1 method)

julia> @btime $ℓ($post, $((μ=0.2,)))
162.952 ns (0 allocations: 0 bytes)
-1540.1297266997808
``````

One thing I really like about this approach is that speeding it up just means finding better rewrite rules and better codegen. So the challenge is abstracted away a bit, and any solution can also help with problems in other domains.


@Tamas_Papp I’d guess that soon (if not already) the constant folding and CSE can make this faster than what you might write by hand. I’m hoping to benchmark some simple examples before long, along with comparisons to Stan, etc. Let me know if you have any that would be interesting to you.

@dilumaluthge When this is ready to register, we’ll need to update SossMLJ accordingly. I think things like linear regression ought to really fly.

Thanks for the ping, I am following this with interest. Note that my focus is on log density with gradients, for NUTS.


Yep, mine too. I think the generated code should be very AD-friendly, but we could also fine-tune that down the road (symbolic gradients with codegen)


Thanks for the very nice writeup. Keep them coming!


Thanks @Tamas_Papp! Here’s the next one:
https://informativeprior.com/blog/2021/01-28-measure-theory/
