Soss updates

Hello all,

Just thought I’d talk a bit about the latest with Soss. Everything here is from the Soss, MeasureTheory, NestedTuples, and SymbolicCodegen packages.

An Easy Problem

Say we have a simple model with some observations:

julia> m = @model σ begin
           μ ~ StudentT(3.0)
           x ~ Normal(μ, σ) |> iid(1000)
           return x
       end;

julia> x = rand(m(σ=2.0));

Then we can condition on the observations to get the posterior, and symbolically compute its log-density:

julia> post = m(σ=2.0) | (;x)
ConditionalModel given
    arguments    (:σ,)
    observations (:x,)
@model σ begin
        μ ~ StudentT(3.0)
        x ~ Normal(μ, σ) |> iid(1000)
        return x
    end

julia> symlogdensity(post)
-1300.8333911728728 + 229.7492699560466μ + -2.0log(3.0 + (μ^2)) + -125.0(μ^2)

The σ and x are gone, because their known values could be substituted in and constant-folded. That works like this:

  1. Run the model forward to get all the types
  2. Generate symbolic values from these, with the appropriate type tags
  3. Evaluate the log-density on these symbolic values. This sounds like a bad idea, but the vectors stay symbolic, so the expression stays small
  4. Walk the expression, and replace any subexpression of type <: Number with its value (there's a sketch of this step just after the list)
  5. Return the result.
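
Here's a minimal sketch of step 4, over plain Exprs instead of Soss's actual symbolic types (the fold function and the example are mine, not Soss's API): any call whose arguments are all numbers gets folded into its numeric value.

julia> fold(x) = x;

julia> function fold(ex::Expr)
           # fold the children first, then check whether every
           # argument of a call is now a literal number
           args = map(fold, ex.args)
           if ex.head == :call && all(a -> a isa Number, args[2:end])
               return eval(Expr(:call, args...))
           end
           Expr(ex.head, args...)
       end;

julia> fold(:(log(2.0 + 1.0) + μ^2))
:(1.0986122886681098 + μ ^ 2)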

Next is common subexpression elimination (CSE):

julia> symlogdensity(post) |> SymbolicCodegen.cse
7-element Vector{Pair{Symbol, Symbolic}}:
 Symbol("##930") => 229.7492699560466μ
 Symbol("##931") => μ^2
 Symbol("##932") => 3.0 + ##931
 Symbol("##933") => log(##932)
 Symbol("##934") => -2.0##933
 Symbol("##935") => -125.0##931
 Symbol("##936") => -1300.8333911728728 + ##930 + ##934 + ##935

Next, we put those bindings together into source code:

julia> sourceCodegen(post)
quote
    var"##1010" = (*)(229.7492699560466, μ)
    var"##1011" = (^)(μ, 2)
    var"##1012" = (+)(3.0, var"##1011")
    var"##1013" = (log)(var"##1012")
    var"##1014" = (*)(-2.0, var"##1013")
    var"##1015" = (*)(-125.0, var"##1011")
    var"##1016" = (+)(-1300.8333911728728, var"##1010", var"##1014", var"##1015")
end
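
Going from the CSE bindings to that quote block is mechanical; roughly (a hypothetical helper, one assignment per binding):

julia> to_source(bindings) = Expr(:block, [:($s = $e) for (s, e) in bindings]...);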

Then add some variable loading, and generate the function:

julia> codegen(post).f
function = (_args, _data, _pars;) -> begin
    begin
        μ = (Main).getproperty(_pars, :μ)
        x = (Main).getproperty(_data, :x)
        σ = (Main).getproperty(_args, :σ)
        var"##1316" = (*)(229.7492699560466, μ)
        var"##1317" = (^)(μ, 2)
        var"##1318" = (+)(3.0, var"##1317")
        var"##1319" = (log)(var"##1318")
        var"##1320" = (*)(-2.0, var"##1319")
        var"##1321" = (*)(-125.0, var"##1317")
        var"##1322" = (+)(-1300.8333911728728, var"##1316", var"##1320", var"##1321")
    end
end
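
The loading lines at the top are just one getproperty per free symbol. Roughly (again a hypothetical helper, not Soss's internals):

julia> loads(syms, src) = [:($s = getproperty($src, $(QuoteNode(s)))) for s in syms];

julia> loads((:μ,), :_pars)
1-element Vector{Expr}:
 :(μ = getproperty(_pars, :μ))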

Most of that’s hidden from the user. They just build it and then use it like a logpdf:

julia> ℓ = codegen(post)
#159 (generic function with 1 method)

julia> ℓ(post, (μ=0.2,))
-1262.1072522124998

And really, even building it could be part of the inference algorithm. Anyway, this is pretty fast:

julia> @btime $ℓ($post, $((μ=0.2,)))
  6.853 ns (0 allocations: 0 bytes)
-1262.1072522124998
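
To give an idea (just a sketch, using ℓ's two-argument call from above), it can be dropped straight into something like random-walk Metropolis:

julia> function rwmh(ℓ, post; iters = 10_000, step = 0.1)
           μ = 0.0
           logp = ℓ(post, (μ = μ,))
           samples = Vector{Float64}(undef, iters)
           for j in 1:iters
               μ′ = μ + step * randn()
               logp′ = ℓ(post, (μ = μ′,))
               # standard Metropolis accept/reject on the log scale
               if log(rand()) < logp′ - logp
                   μ, logp = μ′, logp′
               end
               samples[j] = μ
           end
           samples
       end;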

… And a Slightly Tougher One

That case was easy, but the sum over the observations won't always go away. That's fine too:

julia> m = @model σ begin
           μ ~ StudentT(3.0)
           x ~ Laplace(μ, σ) |> iid(1000)
           return x
       end;

julia> post = m(σ=2.0) | (;x)
ConditionalModel given
    arguments    (:σ,)
    observations (:x,)
@model σ begin
        μ ~ StudentT(3.0)
        x ~ Laplace(μ, σ) |> iid(1000)
        return x
    end

julia> symlogdensity(post)
3.4166191036395746 + -2.0log(3.0 + (μ^2)) + -1(693.1471805599452 + Sum(abs(-0.5μ + 0.5getindex(x, ##i#1584)), ##i#1584, 1, 1000))

julia> codegen(post).f
function = (_args, _data, _pars;) -> begin
    begin
        μ = (Main).getproperty(_pars, :μ)
        x = (Main).getproperty(_data, :x)
        σ = (Main).getproperty(_args, :σ)
        var"##1643" = (^)(μ, 2)
        var"##1644" = (+)(3.0, var"##1643")
        var"##1645" = (log)(var"##1644")
        var"##1646" = (*)(-2.0, var"##1645")
        var"##1647" = begin
                var"##sum#1676" = 0.0
                var"##lo#1678" = 1
                var"##hi#1679" = 1000
                begin
                    $(Expr(:inbounds, true))
                    local var"#266#val" = for var"##i#1624" = (Main).:(:)(var"##lo#1678", var"##hi#1679")
                                begin
                                    var"##Δsum#1677" = (abs)((+)((*)(-0.5, μ), (*)(0.5, (getindex)(x, var"##i#1624"))))
                                    var"##sum#1676" = (Main).Base.FastMath.add_fast(var"##sum#1676", var"##Δsum#1677")
                                end
                            end
                    $(Expr(:inbounds, :((Main).pop)))
                    var"#266#val"
                end
                var"##sum#1676"
            end
        var"##1648" = (+)(693.1471805599452, var"##1647")
        var"##1649" = (*)(-1, var"##1648")
        var"##1650" = (+)(3.4166191036395746, var"##1646", var"##1649")
    end
end

And it’s still pretty fast:

julia> ℓ = codegen(post)
#159 (generic function with 1 method)

julia> @btime $ℓ($post, $((μ=0.2,)))
  162.952 ns (0 allocations: 0 bytes)
-1540.1297266997808

One thing I really like about this approach is that making it faster comes down to finding better rewrite rules and better codegen. The performance problem is abstracted away a bit, and any improvement there can also help with problems in other domains.
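
For instance, with SymbolicUtils-style rules (this particular rule is just an illustration, not something Soss necessarily ships):

julia> using SymbolicUtils: @syms, @rule

julia> @syms μ::Real;

julia> r = @rule log(exp(~x)) => ~x;

julia> r(log(exp(μ)))
μ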

@Tamas_Papp I’d guess that soon (if not already) the constant folding and CSE can make this faster than what you might write by hand. I’m hoping to benchmark some simple examples before long, along with comparisons to Stan, etc. Let me know if you have any that would be interesting to you.

@dilumaluthge When this is ready to register, we’ll need to update SossMLJ accordingly. I think things like linear regression ought to really fly 🙂

Thanks for the ping, I am following this with interest. Note that my focus is on log density with gradients, for NUTS.

Yep, mine too. I think the generated code should be very AD-friendly, but we could also fine-tune that down the road (symbolic gradients with codegen).

Progress!
https://informativeprior.com/blog/2021/01-25-symbolic-simplification/

Thanks for the very nice writeup. Keep them coming!

Thanks @Tamas_Papp! Here’s the next one:
https://informativeprior.com/blog/2021/01-28-measure-theory/
