Hello all,
Just thought I’d talk a bit about the latest with Soss. Everything here is from the Soss, MeasureTheory, NestedTuples, and SymbolicCodegen packages.
An easy problem
Say we have a simple model with some observations:
julia> m = @model σ begin
μ ~ StudentT(3.0)
x ~ Normal(μ, σ) |> iid(1000)
return x
end;
julia> x = rand(m(σ=2.0));
Then we can condition on the observations to get the posterior, and symbolically compute its log-density:
julia> post = m(σ=2.0) | (;x)
ConditionalModel given
arguments (:σ,)
observations (:x,)
@model σ begin
μ ~ StudentT(3.0)
x ~ Normal(μ, σ) |> iid(1000)
return x
end
julia> symlogdensity(post)
-1300.8333911728728 + 229.7492699560466μ + -2.0log(3.0 + (μ^2)) + -125.0(μ^2)
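The shape of that result can be checked by hand. This is only a sketch: it assumes the usual parameterizations (StudentT with 3 degrees of freedom, Normal with σ as a standard deviation), writes n for the number of observations, and lumps every term that does not depend on μ into a constant C.

```latex
\begin{aligned}
\ell(\mu) &= C' - 2\log\!\left(1 + \frac{\mu^2}{3}\right)
            - \frac{1}{2\sigma^2}\sum_{i=1}^{n}(x_i - \mu)^2 \\
          &= C - 2\log\!\left(3 + \mu^2\right)
            + \frac{\sum_{i=1}^{n} x_i}{\sigma^2}\,\mu
            - \frac{n}{2\sigma^2}\,\mu^2
\end{aligned}
```

With n = 1000 and σ = 2, the last coefficient is 1000 / 8 = 125, which matches the -125.0(μ^2) term. The coefficient of μ is the sum of the observations divided by σ² = 4, so it depends on the particular draw of x. The data only enter through sums that can be computed once, which is why no loop over x survives.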
The σ and x are gone, because we were able to substitute them into the generated code. That works like this:
- Run the model forward to get all the types
- Generate symbolic values from these, with the appropriate type tags
- Evaluate the log-density on this result. This sounds like a bad idea, but all vectors are symbolic, so the code stays small
- Walk the expression, and replace any <: Number subexpression with its value
- Return the result.
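The "replace any numeric subexpression with its value" step can be sketched on plain Julia `Expr`s. This is a minimal illustration, not the code Soss or SymbolicCodegen actually runs (those work on symbolic terms rather than `Expr`); the helper names `substitute` and `foldconstants` are made up for this example.

```julia
# Hypothetical helpers for illustration only; not part of Soss or SymbolicCodegen.

# Replace known symbols (e.g. model arguments) with their values.
substitute(x, vals) = x
substitute(s::Symbol, vals) = get(vals, s, s)
substitute(ex::Expr, vals) = Expr(ex.head, map(a -> substitute(a, vals), ex.args)...)

# Evaluate any call whose arguments have all become plain numbers.
foldconstants(x) = x   # symbols and literals are left alone
function foldconstants(ex::Expr)
    ex.head === :call || return Expr(ex.head, map(foldconstants, ex.args)...)
    f = ex.args[1]
    args = map(foldconstants, ex.args[2:end])
    if f isa Symbol && isdefined(Base, f) && all(a -> a isa Number, args)
        return getfield(Base, f)(args...)   # fully numeric: compute it now
    end
    return Expr(:call, f, args...)          # still depends on a free symbol
end

ex = :(n / (2 * σ^2) * μ^2)
folded = foldconstants(substitute(ex, Dict(:n => 1000, :σ => 2.0)))
# `folded` is the expression 125.0 * μ ^ 2; only μ is left as a free symbol.
```

Everything that depends only on known values collapses to a number, and the parameter μ is the only thing left for the generated function to read at run time.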
Next is common subexpression elimination (CSE):
julia> symlogdensity(post) |> SymbolicCodegen.cse
7-element Vector{Pair{Symbol, Symbolic}}:
Symbol("##930") => 229.7492699560466μ
Symbol("##931") => μ^2
Symbol("##932") => 3.0 + ##931
Symbol("##933") => log(##932)
Symbol("##934") => -2.0##933
Symbol("##935") => -125.0##931
Symbol("##936") => -1300.8333911728728 + ##930 + ##934 + ##935
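For readers who have not seen CSE before, here is a minimal sketch of the idea on plain Julia `Expr`s. It is not SymbolicCodegen's implementation; `cse_sketch` and the `t1, t2, …` temporary names are invented for this example.

```julia
# Hypothetical sketch of common subexpression elimination; for illustration only.
function cse_sketch(ex)
    seen = Dict{Any,Symbol}()          # rewritten subexpression => temporary name
    bindings = Pair{Symbol,Any}[]      # assignments, in dependency order
    function visit(x)
        x isa Expr && x.head === :call || return x   # leaves stay inline
        # Rewrite the children first, so repeated subtrees become the same name.
        node = Expr(:call, x.args[1], map(visit, x.args[2:end])...)
        return get!(seen, node) do
            name = Symbol("t", length(bindings) + 1)
            push!(bindings, name => node)
            name
        end
    end
    visit(ex)
    return bindings
end

# Same shape as the log-density above, with symbols standing in for the constants.
bindings = cse_sketch(:(c + a * μ + b * log(3.0 + μ^2) + d * μ^2))
```

Each distinct subexpression gets exactly one binding, so `μ^2` is computed once and then reused by both the `log` term and the quadratic term, just as `##931` is in the output above.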
We put those together:
julia> sourceCodegen(post)
quote
var"##1010" = (*)(229.7492699560466, μ)
var"##1011" = (^)(μ, 2)
var"##1012" = (+)(3.0, var"##1011")
var"##1013" = (log)(var"##1012")
var"##1014" = (*)(-2.0, var"##1013")
var"##1015" = (*)(-125.0, var"##1011")
var"##1016" = (+)(-1300.8333911728728, var"##1010", var"##1014", var"##1015")
end
Then add some variable loading, and generate the function:
julia> codegen(post).f
function = (_args, _data, _pars;) -> begin
begin
μ = (Main).getproperty(_pars, :μ)
x = (Main).getproperty(_data, :x)
σ = (Main).getproperty(_args, :σ)
var"##1316" = (*)(229.7492699560466, μ)
var"##1317" = (^)(μ, 2)
var"##1318" = (+)(3.0, var"##1317")
var"##1319" = (log)(var"##1318")
var"##1320" = (*)(-2.0, var"##1319")
var"##1321" = (*)(-125.0, var"##1317")
var"##1322" = (+)(-1300.8333911728728, var"##1316", var"##1320", var"##1321")
end
end
Most of that’s hidden from the user. They just build it and then use it like a logpdf:
julia> ℓ = codegen(post)
#159 (generic function with 1 method)
julia> ℓ(post, (μ=0.2,))
-1262.1072522124998
And really, even building it could be part of the inference algorithm. Anyway, this is pretty fast:
julia> @btime $ℓ($post, $((μ=0.2,)))
6.853 ns (0 allocations: 0 bytes)
-1262.1072522124998
… And a Slightly Tougher One
That case was easy, but the sum won’t always go away. That’s fine too:
julia> m = @model σ begin
μ ~ StudentT(3.0)
x ~ Laplace(μ, σ) |> iid(1000)
return x
end;
julia> post = m(σ=2.0) | (;x)
ConditionalModel given
arguments (:σ,)
observations (:x,)
@model σ begin
μ ~ StudentT(3.0)
x ~ Laplace(μ, σ) |> iid(1000)
return x
end
julia> symlogdensity(post)
3.4166191036395746 + -2.0log(3.0 + (μ^2)) + -1(693.1471805599452 + Sum(abs(-0.5μ + 0.5getindex(x, ##i#1584)), ##i#1584, 1, 1000))
julia> codegen(post).f
function = (_args, _data, _pars;) -> begin
begin
μ = (Main).getproperty(_pars, :μ)
x = (Main).getproperty(_data, :x)
σ = (Main).getproperty(_args, :σ)
var"##1643" = (^)(μ, 2)
var"##1644" = (+)(3.0, var"##1643")
var"##1645" = (log)(var"##1644")
var"##1646" = (*)(-2.0, var"##1645")
var"##1647" = begin
var"##sum#1676" = 0.0
var"##lo#1678" = 1
var"##hi#1679" = 1000
begin
$(Expr(:inbounds, true))
local var"#266#val" = for var"##i#1624" = (Main).:(:)(var"##lo#1678", var"##hi#1679")
begin
var"##Δsum#1677" = (abs)((+)((*)(-0.5, μ), (*)(0.5, (getindex)(x, var"##i#1624"))))
var"##sum#1676" = (Main).Base.FastMath.add_fast(var"##sum#1676", var"##Δsum#1677")
end
end
$(Expr(:inbounds, :((Main).pop)))
var"#266#val"
end
var"##sum#1676"
end
var"##1648" = (+)(693.1471805599452, var"##1647")
var"##1649" = (*)(-1, var"##1648")
var"##1650" = (+)(3.4166191036395746, var"##1646", var"##1649")
end
end
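The generated loop is easier to read when written by hand. The following is a rough equivalent of just the `Sum(...)` part, with σ = 2.0 already substituted as in the output above; `laplace_sum` is a made-up name, not something the packages export.

```julia
# Hand-written equivalent of the generated summation loop (illustrative only).
function laplace_sum(μ, x)
    s = 0.0
    @inbounds for i in eachindex(x)
        # Same term as in the generated code: abs(-0.5μ + 0.5x[i])
        Δ = abs(-0.5 * μ + 0.5 * x[i])
        # add_fast lets the compiler reassociate the additions, which helps it vectorize
        s = Base.FastMath.add_fast(s, Δ)
    end
    return s
end

laplace_sum(0.2, [1.0, -1.0, 0.2])
```

Unlike the Normal case, the absolute value keeps μ and each x[i] tied together, so the data cannot be reduced to a few precomputed sums and the loop has to run on every evaluation.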
And it’s still pretty fast:
julia> ℓ = codegen(post)
#159 (generic function with 1 method)
julia> @btime $ℓ($post, $((μ=0.2,)))
162.952 ns (0 allocations: 0 bytes)
-1540.1297266997808
One thing I really like about this approach is that speeding it up just means finding better rewrite rules and better codegen. So the challenge is abstracted away a bit, and any solution can also help with problems in other domains.