# Metaprogramming with types

I had been building my log-density function like this:

``````julia> sourceLogdensity(linReg1D)
:(function ##logdensity#413(pars)
      @unpack (α, β, σ, x, y) = pars
      ℓ = 0.0
      ℓ += logpdf(Cauchy(0, 10), α)
      ℓ += logpdf(Cauchy(0, 2.5), β)
      ℓ += logpdf(HalfCauchy(3), σ)
      ŷ = α .+ β .* x
      N = length(x)
      ℓ += logpdf(For(1:N) do n
                  Normal(ŷ[n], σ)
              end, y)
      return ℓ
  end)
``````

In inference this runs in a tight loop, so it needs to be fast. But there’s not much type information to help it out; currently `@code_warntype` complains about `Any`. So I made it generate this:

``````julia> sourceLogdensity(linReg1D, (x=randn(10),y=randn(10)))
:(function ##logdensity#415(pars::NamedTuple{(:x, :y, :α, :β, :σ),Tuple{Array{Real,1},Array{Real,1},Real,Real,Real}})
      @unpack (x, y, α, β, σ) = pars
      ℓ = 0.0
      ℓ += logpdf(Cauchy(0, 10), α)
      ℓ += logpdf(Cauchy(0, 2.5), β)
      ℓ += logpdf(HalfCauchy(3), σ)
      ŷ = α .+ β .* x
      N = length(x)
      ℓ += logpdf(For(1:N) do n
                  Normal(ŷ[n], σ)
              end, y)
      return ℓ
  end)
``````

(note the type constraint on `pars`)

Unfortunately this still doesn’t work, I think because `NamedTuple`s aren’t covariant.
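For example, `Tuple` types are covariant in Julia, but parametric types like `NamedTuple` are invariant in their parameters, so the concrete field types don’t subtype the `Real`-based annotation:

```julia
# Tuples are covariant:
Tuple{Float64} <: Tuple{Real}                                        # true

# NamedTuple is invariant in its field-type parameter:
NamedTuple{(:a,),Tuple{Float64}} <: NamedTuple{(:a,),Tuple{Real}}    # false

# A `where` clause over the field types restores the subtyping:
NamedTuple{(:a,),Tuple{Float64}} <: NamedTuple{(:a,),<:Tuple{Real}}  # true
```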

If I were just writing it once, I think I could get it working. But I need to write code to generate this type. Currently I’m doing this:

``````function realtypes(nt::Type{NamedTuple{S,T}}) where {S,T}
    NamedTuple{S, realtypes(T)}
end

realtypes(::Type{Tuple{A,B}}) where {A,B} = Tuple{realtypes(A), realtypes(B)}
realtypes(::Type{Tuple{A,B,C}}) where {A,B,C} = Tuple{realtypes(A), realtypes(B), realtypes(C)}
realtypes(::Type{Tuple{A,B,C,D}}) where {A,B,C,D} = Tuple{realtypes(A), realtypes(B), realtypes(C), realtypes(D)}
realtypes(::Type{Tuple{A,B,C,D,E}}) where {A,B,C,D,E} = Tuple{realtypes(A), realtypes(B), realtypes(C), realtypes(D), realtypes(E)}

realtypes(::Type{Array{T,N}}) where {T,N} = Array{realtypes(T), N}

realtypes(::Type{<:Real}) = Real
``````

So I just take a sample from the model and get its type, then apply `realtypes` and stick that in the code.

I think the solution is to turn the final `Real` into `T where {T}`, except that the where clause needs to propagate upward.
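One possible tidying (a sketch, not tested against Soss itself): the per-arity `Tuple` methods can be collapsed into a single method that recurses over the tuple type’s parameters. It doesn’t address the `where`-propagation question, and it doesn’t handle `Vararg` tuple types, but it removes the arity boilerplate:

```julia
realtypes(::Type{NamedTuple{S,T}}) where {S,T} = NamedTuple{S, realtypes(T)}

# One method for any tuple arity, instead of one method per arity:
realtypes(::Type{T}) where {T<:Tuple} =
    Tuple{[realtypes(p) for p in T.parameters]...}

realtypes(::Type{Array{T,N}}) where {T,N} = Array{realtypes(T), N}
realtypes(::Type{<:Real}) = Real

realtypes(typeof((x = [1.0, 2.0], α = 3.0)))
# == NamedTuple{(:x, :α), Tuple{Array{Real,1}, Real}}
```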

1. Is there a better way to do this?
2. Is there a way to do this?

Type annotation of the inputs doesn’t affect type stability. You should study the output of `@code_warntype` and see what is the source of type instability, then try to fix it.
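A minimal illustration of that point: Julia specializes on the runtime types of the arguments, so an abstract annotation neither helps nor hurts inference:

```julia
f_any(x)       = x + 1   # no annotation
f_abs(x::Real) = x + 1   # abstract annotation

# Both compile a Float64 specialization and infer a Float64 result;
# compare `@code_warntype f_any(1.0)` with `@code_warntype f_abs(1.0)`.
f_any(1.0) == f_abs(1.0) == 2.0   # true
```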

Why do you need code generation here?


Here’s what I’m trying to get around:

``````julia> m = linReg1D
@model (x, y) begin
    α ~ Cauchy(0, 10)
    β ~ Cauchy(0, 2.5)
    σ ~ HalfCauchy(3)
    ŷ = α .+ β .* x
    N = length(x)
    y ~ For(1:N) do n
        Normal(ŷ[n], σ)
    end
end

julia> f1 = makeLogdensity(m)
(::getfield(Soss, Symbol("#f#41")){getfield(Soss, Symbol("###logdensity#435"))}) (generic function with 1 method)

julia> ℓ(pars) = f1(merge(data, pars))
ℓ (generic function with 1 method)

julia> @code_warntype ℓ((α=1.0,β=1.0,σ=1.0))
Variables
  #self#::Core.Compiler.Const(ℓ, false)
  pars::NamedTuple{(:α, :β, :σ),Tuple{Float64,Float64,Float64}}

Body::Any
1 ─ %1 = Main.merge(Main.data, pars)::NamedTuple{_A,_B} where _B where _A
│   %2 = Main.f1(%1)::Any
└──      return %2
``````

`makeLogdensity` uses `sourceLogdensity`, which produces this:

``````julia> sourceLogdensity(m)
:(function ##logdensity#438(pars)
      @unpack (α, β, σ, x, y) = pars
      ℓ = 0.0
      ℓ += logpdf(Cauchy(0, 10), α)
      ℓ += logpdf(Cauchy(0, 2.5), β)
      ℓ += logpdf(HalfCauchy(3), σ)
      ŷ = α .+ β .* x
      N = length(x)
      ℓ += logpdf(For(1:N) do n
                  Normal(ŷ[n], σ)
              end, y)
      return ℓ
  end)
``````

Stick an `@code_warntype` before the inner function call, e.g. `@code_warntype Main.f1(%1)`. This is where the inference is failing. Actually, not quite: it seems your `merge` call is not inferring to begin with. Is `Main.data` a global variable?
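The minimal version of that failure mode, if that’s what’s happening here, is a non-`const` global:

```julia
a = 1
f_global() = a + 1    # `a` is a non-const global: the body infers as Any
const b = 1
f_const() = b + 1     # const global: the body infers as Int

# @code_warntype f_global()   # Body::Any
# @code_warntype f_const()    # Body::Int64
f_global(), f_const()         # (2, 2)
```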

I think the big picture story is that @cscherrer is doing various symbolic transformations on probabilistic models for Soss.jl. There was a big discussion about the design of Soss.jl in Working with ASTs, where I tried to understand this very question: why does Chad need so much code generation, and could he maybe do it without calling `eval` so much?

As I understand it, the sticking point is that the desired transformations are implemented symbolically, and the choice of transformation isn’t known at macro expansion time. So the Soss `Model` data structure ends up holding an AST and transforming it into the various forms that can be `eval`’d as appropriate to the needs of a given inference method. This still makes me uneasy because it’s quite different from the Julia compilation model: ASTs are carried around outside the code context in which they were originally written, and it seems hard for the Soss code analysis tools to make much use of the Julia compiler (other than as a backend). So it feels like there’s still a weird design mismatch here that would be nice to resolve, but as a DSL for statistical models the Soss approach seems to make sense.

@cscherrer hopefully that summary is approximately accurate!


Oh right. Yes, that’s it. At least I got some type-level practice.

EDIT: Wait, no it’s still having trouble (see below)

I guess any interactive system will have this issue. Do you see significant performance effects from this in Turing? Do you tell users to make inputs `const` or wrap everything in a function, or is it not enough of an issue?

Perfect, thank you! I had thought @stevengj was asking why I needed codegen for the types, but either way more detail is very helpful.

This is something I’ve struggled with for a long time. When I describe it, the first response is often that I’m doing it wrong. But I’ve tried using macros and generated functions with no luck, and as much as I’ve asked around no one has been able to come up with a better approach.

It turns out I’m not alone in this; there’s some discussion about this as a relatively common problem in DSL-land here.

After a minor tweak from the original, I’m currently using this:

``````@inline @generated function _invokefrozen(f, ::Type{rt}, args...) where rt
    tupargs = Expr(:tuple, (a == Nothing ? Int : a for a in args)...)
    quote
        _f = $(Expr(:cfunction, Base.CFunction, :f, rt, :((Core.svec)($((a == Nothing ? Int : a for a in args)...))), :(:ccall)))
        return ccall(_f.ptr, rt, $tupargs, $((:(getindex(args, $i) === nothing ? 0 : getindex(args, $i)) for i in 1:length(args))...))
    end
end

@inline function invokefrozen(f, rt, args...; kwargs...)
    g(kwargs, args...) = f(args...; kwargs...)
    kwargs = (; kwargs...)
    _invokefrozen(g, rt, (; kwargs...), args...)
end

@inline function invokefrozen(f, rt, args...)
    _invokefrozen(f, rt, args...)
end
``````

Seems to be working quite well for the most part. But wrapping things in a function still doesn’t make everything completely happy. Here’s a simpler example:

``````using Soss

m = @model x begin
    α ~ Normal(0,1)
    x ~ Normal(α,1)
end

function f(x)
    nuts(m, (x=x,))
end

@code_warntype f(1.2)
f(1.2) |> typeof
``````

This produces

``````julia> @code_warntype f(1.2)
Variables
  #self#::Core.Compiler.Const(f, false)
  x::Float64

Body::Soss.NUTS_result{_A} where _A
1 ─ %1 = (:x,)::Core.Compiler.Const((:x,), false)
│   %2 = Core.apply_type(Core.NamedTuple, %1)::Core.Compiler.Const(NamedTuple{(:x,),T} where T<:Tuple, false)
│   %3 = Core.tuple(x)::Tuple{Float64}
│   %4 = (%2)(%3)::NamedTuple{(:x,),Tuple{Float64}}
│   %5 = Main.nuts(Main.m, %4)::Soss.NUTS_result{_A} where _A
└──      return %5

julia> f(1.2) |> typeof
2.8e-5 s/step ...done
3.2e-5 s/step ...done
3.1e-5 s/step ...done
2.8e-5 s/step ...done
2.7e-5 s/step ...done
2.2e-5 s/step ...done
5.9e-5 s/step ...done
MCMC (1000 steps)
3.9e-5 s/step ...done
Soss.NUTS_result{NamedTuple{(:α,),Tuple{Float64}}}
``````

Which takes me back to thinking I need to do some codegen tricks in order to know the type in advance.

I don’t quite understand your comment here. `Main.merge(Main.data, pars)` is struggling to infer. I am assuming `pars` is input to the function so its type is known, but `Main.data` is not, it is closed over by the function. When you do this in global scope, you get type instability because `Main.data` can change value and type in the life of a REPL session. In Turing, we take the data in a struct and then make the struct callable. You can also pass `Main.data` as input to an outer function that returns a closure (inner function) over `data`. This should also infer.
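A minimal sketch of the closure approach (with hypothetical stand-ins for the Soss functions):

```julia
# Pass `data` through a function barrier instead of referencing a global;
# the returned closure stores `data` in a concretely typed field.
make_ℓ(data, f) = pars -> f(merge(data, pars))

f1(nt) = sum(nt.x) + nt.α            # stand-in for the generated logdensity
ℓ = make_ℓ((x = [1.0, 2.0],), f1)
ℓ((α = 1.0,))                        # 4.0, and inference is now concrete
```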


Oh, that’s interesting. Does making it callable have any effect? If it’s just putting it in a struct that matters, how would this be different from putting it in a NamedTuple? I think I’m missing a fine point here.

Hmm, ok, I think I need to improve my sense of which things infer and which don’t, and how to predict that. Any suggestions for getting a better handle on this?

It’s not. It’s just about not closing over variables without telling the Julia compiler that the type of the variable will not change during the life of the function. When you close over a variable `x = 1`, for example, doing `x = "s"` between two calls of the closure is valid and will change the type of `x`.

``````x = 1
f() = x
f() # returns 1
x = "s" # Changes f
f() # returns "s"
``````

When you pass `x` as an input to a function, doing `x = "s"` afterwards will only affect the global state, not the variable inside the function. Similarly, when making a callable struct, you are “freezing” the type of the closed-over variable. So let `f = F(x)` be a callable struct that uses the variable `f.x` inside its body, e.g.

``````struct F
    x::Int
end

(f::F)() = f.x
``````

You can tell the Julia compiler that `f.x` is not going to change type using the field type of `x`, `Int` above. So I can do:

``````x = 1
f = F(x)
f() # returns 1
x = "s" # Doesn't change f
f() # returns 1
``````

so `x` and `f` are now “decoupled”. The same applies to the outer/inner function approach.
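i.e., something along these lines:

```julia
make_f(x) = () -> x   # the closure captures `x` in a typed field

x = 1
f = make_f(x)
f()          # returns 1
x = "s"      # doesn't change f
f()          # still returns 1
```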

This is probably a good place. https://docs.julialang.org/en/v1/manual/performance-tips/index.html.


Haha, ok, I should have known you’d say that. I’ve seen that page, but I guess I need to reread it a few more times. I’ll keep it under my pillow.

BTW, that’s a great example you gave, and very clearly written. I’d love to see that added to the performance page, if that’s a possibility.

You can always make a PR


Great example. One point should be clarified here:

I’m not sure it’s correct to say that this version of `f()` closes over `x`. Rather, it refers symbolically to the global variable `Main.x`, which need not even be defined at the time `f` is defined:

``````julia> @code_lowered f()
CodeInfo(
1 ─     return Main.x
)
``````

This is different from closing over a local variable `x`, which creates a callable type quite like your `F`:

``````julia> function g(x)
           f() = x
           f
       end

julia> h = g("a")
(::getfield(Main, Symbol("#f#7")){String}) (generic function with 1 method)

julia> fieldnames(typeof(h))
(:x,)

julia> h()
"a"

julia> @code_lowered h()
CodeInfo(
1 ─ %1 = (Core.getfield)(#self#, :x)
└──      return %1
)
``````

Thanks for the correction

In this issue, @Chris_Foster suggests using the notation `m(a=2, b=[2,3,4])` for observations. So for example we might have

``````nuts(m(a=2, b=[2,3,4]); numSamples=1000)
``````

This means we’d need to define `m(a=2, b=[2,3,4])`, independent of any reference to sampling.

But maybe that’s a good thing! We could for example have an `m.data` field to hold the observations, and the type of this could be determined at model construction time.

It seems to me like that would solve the type inference issue. @mohamed82008 and @Chris_Foster, does that sound right?
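Here’s a rough sketch of what that might look like (all names hypothetical, not Soss’s actual API):

```julia
# Store the observations in a type-parameterized field so their
# concrete type is fixed when the model is constructed.
struct Model{D<:NamedTuple}
    data::D
end

# `m(a = ..., b = ...)` attaches (or extends) the observations:
(m::Model)(; obs...) = Model(merge(m.data, (; obs...)))

m  = Model(NamedTuple())        # no observations yet
m2 = m(a = 2, b = [2, 3, 4])    # D is now NamedTuple{(:a, :b), ...}
m2.data.b                       # [2, 3, 4], concretely typed
```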