Metaprogramming with types

I had been building my log-density function like this:

julia> sourceLogdensity(linReg1D)
:(function ##logdensity#413(pars)
      @unpack (α, β, σ, x, y) = pars
      ℓ = 0.0
      ℓ += logpdf(Cauchy(0, 10), α)
      ℓ += logpdf(Cauchy(0, 2.5), β)
      ℓ += logpdf(HalfCauchy(3), σ)
      ŷ = α .+ β .* x
      N = length(x)
      ℓ += logpdf(For(1:N) do n
                  #= /home/chad/git/Soss.jl/src/examples.jl:86 =#
                  Normal(ŷ[n], σ)
              end, y)
      return ℓ
  end)

In inference this goes in a tight loop, so it needs to be fast. But there’s not much type information to help it out - currently @code_warntype complains about Any. So I made it generate this:

julia> sourceLogdensity(linReg1D, (x=randn(10),y=randn(10)))
:(function ##logdensity#415(pars::NamedTuple{(:x, :y, :α, :β, :σ),Tuple{Array{Real,1},Array{Real,1},Real,Real,Real}})
      @unpack (x, y, α, β, σ) = pars
      ℓ = 0.0
      ℓ += logpdf(Cauchy(0, 10), α)
      ℓ += logpdf(Cauchy(0, 2.5), β)
      ℓ += logpdf(HalfCauchy(3), σ)
      ŷ = α .+ β .* x
      N = length(x)
      ℓ += logpdf(For(1:N) do n
                  #= /home/chad/git/Soss.jl/src/examples.jl:86 =#
                  Normal(ŷ[n], σ)
              end, y)
      return ℓ
  end)

(note the type constraint on pars)

Unfortunately this still doesn’t work, I think because NamedTuples aren’t covariant in their type parameters.
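A quick way to check the invariance issue, as a sketch using only Base types (the variable names are just for illustration): Tuple is covariant in its parameters, but NamedTuple and Array are not, so a value holding Float64s does not match a signature written with Real.

```julia
# Tuple is covariant: a tuple of Float64 is a tuple of Real.
tuple_ok = Tuple{Float64} <: Tuple{Real}

# NamedTuple and Array are invariant in their type parameters.
nt_ok  = NamedTuple{(:α,),Tuple{Float64}} <: NamedTuple{(:α,),Tuple{Real}}
arr_ok = Array{Float64,1} <: Array{Real,1}

# So a concrete parameter value does not match a signature written with Real.
pars = (α = 1.0,)
matches = pars isa NamedTuple{(:α,),Tuple{Real}}
```

Here only `tuple_ok` is true, which would explain why the generated method never applies to real inputs.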

If I were just writing it once, I think I could get it working. But I need to write code to generate this type. Currently I’m doing this:

function realtypes(nt::Type{NamedTuple{S, T}}) where {S, T}
    NamedTuple{S, realtypes(T)}
end

realtypes(::Type{Tuple{A,B}}) where {A,B} = Tuple{realtypes(A), realtypes(B)}
realtypes(::Type{Tuple{A,B,C}}) where {A,B,C} = Tuple{realtypes(A), realtypes(B), realtypes(C)}
realtypes(::Type{Tuple{A,B,C,D}}) where {A,B,C,D} = Tuple{realtypes(A), realtypes(B), realtypes(C), realtypes(D)}
realtypes(::Type{Tuple{A,B,C,D,E}}) where {A,B,C,D,E} = Tuple{realtypes(A), realtypes(B), realtypes(C), realtypes(D), realtypes(E)}

realtypes(::Type{Array{T, N}}) where {T, N} = Array{realtypes(T), N}

realtypes(::Type{<:Real}) = Real

So I just take a sample from the model and get its type, then apply realtypes and stick that in the code.

I think the solution is to turn the final Real into T where {T}, except that the where clause needs to propagate upward.
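One way this could look, as a sketch rather than a tested fix for Soss: instead of substituting Real as an exact parameter, wrap each level in its own `where` with an upper bound. The name `realbounds` is hypothetical, and `fieldtypes` assumes Julia 1.1 or later.

```julia
# Hypothetical variant of realtypes that emits upper bounds instead of exact
# parameters, so concrete types such as Float64 still match.
function realbounds(::Type{NamedTuple{S,T}}) where {S,T}
    B = realbounds(T)
    return NamedTuple{S,U} where {U<:B}
end

# Tuple is covariant, so it only needs its fields rewritten (any length).
realbounds(::Type{T}) where {T<:Tuple} = Tuple{map(realbounds, fieldtypes(T))...}

function realbounds(::Type{Array{T,N}}) where {T,N}
    B = realbounds(T)
    return Array{E,N} where {E<:B}
end

realbounds(::Type{<:Real}) = Real

sample = (x = randn(3), α = 1.0)
bound = realbounds(typeof(sample))
```

With this, `sample isa bound` holds, and so does `(x = [1, 2], α = 1) isa bound`, since the bound accepts any Real element type.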

  1. Is there a better way to do this?
  2. Is there a way to do this? :wink:

Type annotation of the inputs doesn’t affect type stability. You should study the output of @code_warntype and see what is the source of type instability, then try to fix it.
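As a small illustration (hypothetical function names, using `Base.return_types` to query inference): the annotated and unannotated versions infer identically, because Julia specializes on the concrete argument types either way. Instability comes from the body.

```julia
# Same body, with and without an argument annotation.
plain(x) = 2x + 1
annotated(x::Real) = 2x + 1

# An annotated argument does not help here: the branches return
# different types (Float64 or Int), so the result is a Union.
unstable(x::Float64) = x > 0 ? x : 0

rt_plain     = first(Base.return_types(plain, (Float64,)))
rt_annotated = first(Base.return_types(annotated, (Float64,)))
rt_unstable  = first(Base.return_types(unstable, (Float64,)))
```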

Why do you need code generation here?


Here’s what I’m trying to get around:

julia> m = linReg1D
@model (x, y) begin
        α ~ Cauchy(0, 10)
        β ~ Cauchy(0, 2.5)
        σ ~ HalfCauchy(3)
        ŷ = α .+ β .* x
        N = length(x)
        y ~ For(1:N) do n
                Normal(ŷ[n], σ)
            end
    end

julia>     f1 = makeLogdensity(m)
(::getfield(Soss, Symbol("#f#41")){getfield(Soss, Symbol("###logdensity#435"))}) (generic function with 1 method)

julia>     ℓ(pars) = f1(merge(data,pars))
ℓ (generic function with 1 method)

julia> @code_warntype ℓ((α=1.0,β=1.0,σ=1.0))
Variables
  #self#::Core.Compiler.Const(ℓ, false)
  pars::NamedTuple{(:α, :β, :σ),Tuple{Float64,Float64,Float64}}

Body::Any
1 ─ %1 = Main.merge(Main.data, pars)::NamedTuple{_A,_B} where _B where _A
│   %2 = Main.f1(%1)::Any
└──      return %2

makeLogdensity uses sourceLogdensity, which produces this:

julia> sourceLogdensity(m)
:(function ##logdensity#438(pars)
      @unpack (α, β, σ, x, y) = pars
      ℓ = 0.0
      ℓ += logpdf(Cauchy(0, 10), α)
      ℓ += logpdf(Cauchy(0, 2.5), β)
      ℓ += logpdf(HalfCauchy(3), σ)
      ŷ = α .+ β .* x
      N = length(x)
      ℓ += logpdf(For(1:N) do n
                  #= /home/chad/git/jl/Soss/src/examples.jl:86 =#
                  Normal(ŷ[n], σ)
              end, y)
      return ℓ
  end)

Stick an @code_warntype before the inner function call, e.g. @code_warntype Main.f1(%1). This is where the inference is failing. Actually not quite, it seems your merge function is not inferring to begin with. Is Main.data a global variable?

I think the big picture story is that @cscherrer is doing various symbolic transformations on probabilistic models for Soss.jl. There was a big discussion about the design of Soss.jl in Working with ASTs where I tried to understand this very question: why does Chad need so much code generation, and could he maybe do it without calling eval so much :slight_smile:

As I understand it, the sticking point is that the desired transformations are being implemented symbolically and the choice of transformation isn’t known at macro expansion time. So the Soss Model data structure ends up holding an AST and transforming that into the various forms which can be eval'd as appropriate to the needs of a given inference method. This still makes me uneasy because it’s quite different from the Julia compilation model: ASTs are carried around outside the code context in which they’re originally written, and it seems hard for the Soss code analysis tools to make much use of the Julia compiler (other than as a backend). So it feels like there’s still a weird design mismatch here, which would be nice to resolve, but as a DSL for statistical models the Soss approach seems to make sense.

@cscherrer hopefully that summary is approximately accurate!

Oh right. Yes, that’s it. At least I got some type-level practice :laughing:

EDIT: Wait, no it’s still having trouble (see below)

I guess any interactive system will have this issue. Do you see significant performance effects from this in Turing? Do you tell users to make inputs const or wrap everything in a function, or is it not enough of an issue?

Perfect, thank you! I had thought @stevengj was asking why I needed codegen for the types, but either way more detail is very helpful.

This is something I’ve struggled with for a long time. When I describe it, the first response is often that I’m doing it wrong. But I’ve tried using macros and generated functions with no luck, and as much as I’ve asked around no one has been able to come up with a better approach.

It turns out I’m not alone in this; there’s some discussion about this as a relatively common problem in DSL-land here.

After a minor tweak from the original, I’m currently using this:

@inline @generated function _invokefrozen(f, ::Type{rt}, args...) where rt
    tupargs = Expr(:tuple,(a==Nothing ? Int : a for a in args)...)
    quote
        _f = $(Expr(:cfunction, Base.CFunction, :f, rt, :((Core.svec)($((a==Nothing ? Int : a for a in args)...))), :(:ccall)))
        return ccall(_f.ptr,rt,$tupargs,$((:(getindex(args,$i) === nothing ? 0 : getindex(args,$i)) for i in 1:length(args))...))
    end
end

@inline function invokefrozen(f, rt, args...; kwargs...)
    g(kwargs, args...) = f(args...; kwargs...)
    kwargs = (;kwargs...)
    _invokefrozen(g, rt, (;kwargs...), args...)
end

@inline function invokefrozen(f, rt, args...)
    _invokefrozen(f, rt, args...)
end

Seems to be working quite well for the most part. But wrapping things in a function still doesn’t make everything completely happy. Here’s a simpler example:

using Soss

m = @model x begin
    α ~ Normal(0,1)
    x ~ Normal(α,1)
end

function f(x) 
    nuts(m, (x=x,))
end

@code_warntype f(1.2)
f(1.2) |> typeof

This produces

julia> @code_warntype f(1.2)
Variables
  #self#::Core.Compiler.Const(f, false)
  x::Float64

Body::Soss.NUTS_result{_A} where _A
1 ─ %1 = (:x,)::Core.Compiler.Const((:x,), false)
│   %2 = Core.apply_type(Core.NamedTuple, %1)::Core.Compiler.Const(NamedTuple{(:x,),T} where T<:Tuple, false)
│   %3 = Core.tuple(x)::Tuple{Float64}
│   %4 = (%2)(%3)::NamedTuple{(:x,),Tuple{Float64}}
│   %5 = Main.nuts(Main.m, %4)::Soss.NUTS_result{_A} where _A
└──      return %5

julia> f(1.2) |> typeof
MCMC, adapting ϵ (75 steps)
2.8e-5 s/step ...done
MCMC, adapting ϵ (25 steps)
3.2e-5 s/step ...done
MCMC, adapting ϵ (50 steps)
3.1e-5 s/step ...done
MCMC, adapting ϵ (100 steps)
2.8e-5 s/step ...done
MCMC, adapting ϵ (200 steps)
2.7e-5 s/step ...done
MCMC, adapting ϵ (400 steps)
2.2e-5 s/step ...done
MCMC, adapting ϵ (50 steps)
5.9e-5 s/step ...done
MCMC (1000 steps)
3.9e-5 s/step ...done
Soss.NUTS_result{NamedTuple{(:α,),Tuple{Float64}}}

Which takes me back to thinking I need to do some codegen tricks in order to know the type in advance.

I don’t quite understand your comment here. Main.merge(Main.data, pars) is struggling to infer. I am assuming pars is input to the function so its type is known, but Main.data is not, it is closed over by the function. When you do this in global scope, you get type instability because Main.data can change value and type in the life of a REPL session. In Turing, we take the data in a struct and then make the struct callable. You can also pass Main.data as input to an outer function that returns a closure (inner function) over data. This should also infer.
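A minimal sketch of the outer/inner function approach, assuming a stand-in `logdensity` in place of the function Soss generates (`make_ℓ` is a hypothetical name). Because `data` is an argument of the outer function, its type is fixed inside the returned closure and the `merge` call can be inferred.

```julia
# Stand-in for the generated log-density; the real one comes from Soss.
logdensity(nt) = sum(nt.x) + nt.α

# Outer function: `data` is an argument, so the closure captures a value
# whose type is known when the closure is compiled.
make_ℓ(data) = pars -> logdensity(merge(data, pars))

data = (x = randn(10), y = randn(10))
ℓ = make_ℓ(data)

# Ask inference what ℓ returns for a NamedTuple with a Float64 α.
inferred = first(Base.return_types(ℓ, (typeof((α = 1.0,)),)))
```

Reassigning the global `data` afterwards does not affect `ℓ`, which keeps the value it was built with.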


Oh, that’s interesting. Does making it callable have any effect? If it’s just putting it in a struct that matters, how would this be different than putting it in a NamedTuple? Think I’m missing a fine point here.

Hmm, ok I think I need to improve my sense of which things infer and which don’t, and how to predict that. Any suggestions for getting a better handle on this?

It’s not. It’s just about not closing over variables without telling the Julia compiler that the type of this variable will not change during the life of the function. When you close over a variable x = 1 for example, doing x = "s" between 2 calls of the closing function is valid and will change the type of x.

x = 1
f() = x
f() # returns 1
x = "s" # Changes f
f() # returns "s"

When you pass x as input to a function, doing x = "s" will only affect the global state, not the variables inside the function. Similarly, when making a callable struct, you are “freezing” the type of the closed over variable. So let f = F(x) be a callable struct that uses the variable f.x inside its body, e.g.

struct F
    x::Int
end
(f::F)() = f.x

You can tell the Julia compiler that f.x is not going to change type using the field type of x, Int above. So I can do:

x = 1
f = F(x)
f() # returns 1
x = "s" # Doesn't change f
f() # returns 1

so x and f are now “decoupled”. The same applies for the outer and inner function approach.

This is probably a good place to start: https://docs.julialang.org/en/v1/manual/performance-tips/index.html


Haha ok, I should have known you’d say that. I’ve seen that, but guess I need to reread it a few more times. I’ll keep it under my pillow :wink:

BTW, that’s a great example you gave, and very clearly written. I’d love to see that added to the performance page, if that’s a possibility.

You can always make a PR :wink:


Great example. One point should be clarified here:

I’m not sure it’s correct to say that this version of f() closes over x. Rather it refers symbolically to the global variable Main.x which need not even be defined at the time f is defined:

julia> @code_lowered f()
CodeInfo(
1 ─     return Main.x
)

This is different to closing over a local variable x which creates a callable type quite like your F:

julia> function g(x)
           f() = x
           f
       end

julia> h = g("a")
(::getfield(Main, Symbol("#f#7")){String}) (generic function with 1 method)

julia> fieldnames(typeof(h))
(:x,)

julia> h()
"a"

julia> @code_lowered h()
CodeInfo(
1 ─ %1 = (Core.getfield)(#self#, :x)
└──      return %1
)

Thanks for the correction :slight_smile:

In this issue, @Chris_Foster suggests using the notation m(a=2, b=[2,3,4]) for observations. So for example we might have

nuts(m(a=2, b=[2,3,4]); numSamples=1000)

This means we’d need to define m(a=2, b=[2,3,4]), independent of any reference to sampling.

But maybe that’s a good thing! We could for example have an m.data field to hold the observations, and the type of this could be determined at model construction time.
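Roughly what that might look like, as a sketch only: `SketchModel` is a hypothetical stand-in, not the actual Soss model type. The point is that the observations live in a type parameter, so their type is fixed once the model is constructed.

```julia
# Hypothetical stand-in for a Soss model; not the actual Soss type.
struct SketchModel{D<:NamedTuple}
    body::Expr   # the model AST
    data::D      # observations, stored with a concrete type
end

# A model with no observations yet.
SketchModel(body::Expr) = SketchModel(body, NamedTuple())

# m(a=2, b=[2,3,4]) returns a new model that carries the observations.
(m::SketchModel)(; kwargs...) = SketchModel(m.body, merge(m.data, (; kwargs...)))

m = SketchModel(:(a ~ Normal(0, 1)))
observed = m(a = 2, b = [2, 3, 4])
```

Anything that later receives `observed` as an argument can dispatch and infer on `typeof(observed.data)`, with no reference to a global.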

It seems to me like that would solve the type inference issue. @mohamed82008 and @Chris_Foster does that sound right?