Oh right. Yes, that’s it. At least I got some type-level practice.
EDIT: Wait, no, it’s still having trouble (see below).
I guess any interactive system will have this issue. Do you see significant performance effects from this in Turing? Do you tell users to make inputs `const` or wrap everything in a function, or is it not a big enough issue to worry about?
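For context, here is a minimal sketch of the effect being asked about, in plain Julia with no packages. A non-`const` global can change type at any time, so code that reads it can't be specialized. A `const` global, or passing the value in as a function argument (a "function barrier"), gives inference a concrete type. The names `x_global`, `g_global`, `g_const` and `g_arg` are hypothetical.

```julia
# Hypothetical names; plain Julia, no packages needed.
x_global = 1.2          # non-const global: its type may change later
const x_const = 1.2     # const global: its type is fixed as Float64

g_global() = x_global + 1   # reads an untyped global, so inference gives up
g_const() = x_const + 1     # reads a const global of known type
g_arg(x) = x + 1            # function barrier: specialized on typeof(x)

# Base.return_types lists the inferred return type of each matching method.
@show Base.return_types(g_global, ())
@show Base.return_types(g_const, ())
@show Base.return_types(g_arg, (Float64,))
```

On current Julia versions this should report `Any` for `g_global` and `Float64` for the other two.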
Perfect, thank you! I had thought @stevengj was asking why I needed codegen for the types, but either way, more detail is very helpful.
This is something I’ve struggled with for a long time. When I describe it, the first response is often that I’m doing it wrong. But I’ve tried macros and generated functions with no luck, and although I’ve asked around quite a bit, no one has come up with a better approach.
It turns out I’m not alone; there’s some discussion here about this being a relatively common problem in DSL-land.
After a minor tweak from the original, I’m currently using this:
```julia
# Call `f` through a C function pointer whose return type is fixed to `rt`.
# `Nothing` arguments are replaced by the placeholder `0::Int`, both in the
# cfunction signature and in the ccall argument list.
@inline @generated function _invokefrozen(f, ::Type{rt}, args...) where rt
    tupargs = Expr(:tuple, (a == Nothing ? Int : a for a in args)...)
    quote
        _f = $(Expr(:cfunction, Base.CFunction, :f, rt, :((Core.svec)($((a == Nothing ? Int : a for a in args)...))), :(:ccall)))
        return ccall(_f.ptr, rt, $tupargs, $((:(getindex(args, $i) === nothing ? 0 : getindex(args, $i)) for i in 1:length(args))...))
    end
end

# Keyword arguments are bundled into a NamedTuple and passed positionally.
@inline function invokefrozen(f, rt, args...; kwargs...)
    g(kwargs, args...) = f(args...; kwargs...)
    _invokefrozen(g, rt, (;kwargs...), args...)
end

@inline function invokefrozen(f, rt, args...)
    _invokefrozen(f, rt, args...)
end
```
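The reason for passing an explicit return type `rt` can be seen with Base’s own `invokelatest`, which looks up the method at run time and is therefore typically inferred as `Any`. This is a minimal sketch that assumes this loss of inference is what `invokefrozen` is meant to avoid. The names `h`, `call_latest` and `call_latest_typed` are hypothetical. A type assertion at the call site restores a concrete type for callers, and it throws a `TypeError` if the guess is wrong.

```julia
# Hypothetical helper that is only ever called via invokelatest.
h(x) = x + 1

# invokelatest dispatches at run time, so inference typically cannot see
# through it and reports `Any` for this caller.
call_latest(x) = Base.invokelatest(h, x)

# Declaring the expected return type (the role `rt` plays for invokefrozen)
# gives callers a concrete type; a wrong guess raises a TypeError.
call_latest_typed(x::Float64) = Base.invokelatest(h, x)::Float64

@show Base.return_types(call_latest, (Float64,))
@show Base.return_types(call_latest_typed, (Float64,))
@show call_latest_typed(1.0)
```

This only addresses inference. It is not a replacement for the C-function-pointer approach above.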
It seems to work quite well for the most part, but wrapping things in a function still doesn’t give fully inferred types everywhere. Here’s a simpler example:
```julia
using Soss

m = @model x begin
    α ~ Normal(0, 1)
    x ~ Normal(α, 1)
end

function f(x)
    nuts(m, (x=x,))
end

@code_warntype f(1.2)
f(1.2) |> typeof
```
This produces:

```
julia> @code_warntype f(1.2)
Variables
  #self#::Core.Compiler.Const(f, false)
  x::Float64

Body::Soss.NUTS_result{_A} where _A
1 ─ %1 = (:x,)::Core.Compiler.Const((:x,), false)
│   %2 = Core.apply_type(Core.NamedTuple, %1)::Core.Compiler.Const(NamedTuple{(:x,),T} where T<:Tuple, false)
│   %3 = Core.tuple(x)::Tuple{Float64}
│   %4 = (%2)(%3)::NamedTuple{(:x,),Tuple{Float64}}
│   %5 = Main.nuts(Main.m, %4)::Soss.NUTS_result{_A} where _A
└──      return %5

julia> f(1.2) |> typeof
MCMC, adapting ϵ (75 steps)
2.8e-5 s/step ...done
MCMC, adapting ϵ (25 steps)
3.2e-5 s/step ...done
MCMC, adapting ϵ (50 steps)
3.1e-5 s/step ...done
MCMC, adapting ϵ (100 steps)
2.8e-5 s/step ...done
MCMC, adapting ϵ (200 steps)
2.7e-5 s/step ...done
MCMC, adapting ϵ (400 steps)
2.2e-5 s/step ...done
MCMC, adapting ϵ (50 steps)
5.9e-5 s/step ...done
MCMC (1000 steps)
3.9e-5 s/step ...done
Soss.NUTS_result{NamedTuple{(:α,),Tuple{Float64}}}
```
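So inference only sees `Soss.NUTS_result{_A} where _A`, while the value actually returned is a `Soss.NUTS_result{NamedTuple{(:α,),Tuple{Float64}}}`. That takes me back to thinking I need some codegen tricks to know the type in advance.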
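One common alternative to codegen assumes that only the code consuming the result needs to be fast. You accept that the sampler’s own return type is only partially inferred, and you pass the result through a function barrier. The consuming code is then compiled for the concrete type seen at run time. In this sketch, `Result`, `draw`, `sample_result`, `summarize` and `analyze` are hypothetical stand-ins, not Soss’s API. `invokelatest` is used only to hide the payload type from inference, similar to what happens with `NUTS_result{_A}` above.

```julia
# Hypothetical stand-in for a result type like Soss.NUTS_result{T}.
struct Result{T}
    value::T
end

# The payload is produced behind invokelatest, so inference cannot see its
# type and `sample_result` is only partially inferred (like NUTS_result{_A}).
draw(x) = (α = x / 2,)
sample_result(x) = Result(Base.invokelatest(draw, x))

# Function barrier: `summarize` is compiled separately for whichever concrete
# Result{T} arrives at run time, so its body is type-stable.
summarize(r::Result) = 2 * r.value.α

# One dynamic dispatch happens at the call to `summarize`; everything
# inside it runs on concrete types.
analyze(x) = summarize(sample_result(x))

@show typeof(sample_result(1.2))
@show Base.return_types(summarize, (Result{NamedTuple{(:α,),Tuple{Float64}}},))
@show analyze(1.2)
```

The cost is a single dynamic dispatch per call to `analyze`. This is usually negligible next to running a sampler, but it does not make the top-level return type inferable.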