I had been building my log-density function like this:
```julia
julia> sourceLogdensity(linReg1D)
:(function ##logdensity#413(pars)
      @unpack (α, β, σ, x, y) = pars
      ℓ = 0.0
      ℓ += logpdf(Cauchy(0, 10), α)
      ℓ += logpdf(Cauchy(0, 2.5), β)
      ℓ += logpdf(HalfCauchy(3), σ)
      ŷ = α .+ β .* x
      N = length(x)
      ℓ += logpdf(For(1:N) do n
                  #= /home/chad/git/Soss.jl/src/examples.jl:86 =#
                  Normal(ŷ[n], σ)
              end, y)
      return ℓ
  end)
```
In inference this runs in a tight loop, so it needs to be fast. But there’s not much type information to help it out: currently `@code_warntype` complains about `Any`. So I made it generate this:
```julia
julia> sourceLogdensity(linReg1D, (x=randn(10),y=randn(10)))
:(function ##logdensity#415(pars::NamedTuple{(:x, :y, :α, :β, :σ),Tuple{Array{Real,1},Array{Real,1},Real,Real,Real}})
      @unpack (x, y, α, β, σ) = pars
      ℓ = 0.0
      ℓ += logpdf(Cauchy(0, 10), α)
      ℓ += logpdf(Cauchy(0, 2.5), β)
      ℓ += logpdf(HalfCauchy(3), σ)
      ŷ = α .+ β .* x
      N = length(x)
      ℓ += logpdf(For(1:N) do n
                  #= /home/chad/git/Soss.jl/src/examples.jl:86 =#
                  Normal(ŷ[n], σ)
              end, y)
      return ℓ
  end)
```
(note the type constraint on `pars`)
Unfortunately this still doesn’t work, I think because `NamedTuple`s aren’t covariant in their type parameters. (Neither is `Array`: a `Vector{Float64}` is not an `Array{Real,1}`.)
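The subtyping behaviour behind this can be checked directly in the REPL. In the sketch below, the two-field `Sample` type is a made-up stand-in for the model's parameter type:

```julia
# Hypothetical stand-in for the concrete type of a sample from the model
const Sample = NamedTuple{(:x, :α), Tuple{Vector{Float64}, Float64}}

# Array is invariant: a Vector{Float64} is not a Vector{Real} ...
println(Vector{Float64} <: Vector{Real})            # false
# ... but it is a Vector{<:Real}, i.e. Vector{T} where T<:Real
println(Vector{Float64} <: Vector{<:Real})          # true

# Tuple is the covariant exception
println(Tuple{Float64, Int} <: Tuple{Real, Real})   # true

# NamedTuple's second (tuple-type) parameter is invariant, so the whole
# tuple type must match exactly, field by field
println(Sample <: NamedTuple{(:x, :α), Tuple{Vector{Real}, Real}})     # false
println(Sample <: NamedTuple{(:x, :α), Tuple{Vector{Float64}, Real}})  # false
```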
If I only had to write this type once by hand, I think I could get it working. But I need code that generates it. Currently I’m doing this:
```julia
# Map the type of a sample to the same shape with every Real leaf replaced by Real
function realtypes(nt::Type{NamedTuple{S, T}}) where {S, T}
    NamedTuple{S, realtypes(T)}
end

# One method per tuple length (2 to 5 fields)
realtypes(::Type{Tuple{A,B}}) where {A,B} = Tuple{realtypes(A), realtypes(B)}
realtypes(::Type{Tuple{A,B,C}}) where {A,B,C} = Tuple{realtypes(A), realtypes(B), realtypes(C)}
realtypes(::Type{Tuple{A,B,C,D}}) where {A,B,C,D} = Tuple{realtypes(A), realtypes(B), realtypes(C), realtypes(D)}
realtypes(::Type{Tuple{A,B,C,D,E}}) where {A,B,C,D,E} = Tuple{realtypes(A), realtypes(B), realtypes(C), realtypes(D), realtypes(E)}

realtypes(::Type{Array{T, N}}) where {T, N} = Array{realtypes(T), N}
realtypes(::Type{<:Real}) = Real
```
So I just take a sample from the model and get its type, then apply `realtypes` and stick that in the code.
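The per-arity `Tuple` methods can probably be collapsed into a single method. The sketch below assumes Julia 1.1 or later, because it uses `fieldtypes`. It keeps the same mapping as `realtypes` above (leaves become `Real`). That removes the repetition but not the covariance problem. The `sample` is made up to mirror `linReg1D`'s parameters:

```julia
# Same mapping as realtypes, but one Tuple method covers every length
realtypes(::Type{NamedTuple{S, T}}) where {S, T} = NamedTuple{S, realtypes(T)}

# fieldtypes(Tuple{A, B, ...}) returns the element types as a tuple (A, B, ...)
realtypes(::Type{T}) where {T <: Tuple} = Tuple{map(realtypes, fieldtypes(T))...}

realtypes(::Type{Array{T, N}}) where {T, N} = Array{realtypes(T), N}
realtypes(::Type{<:Real}) = Real

# Made-up sample shaped like linReg1D's parameters
sample = (x = [1.0, 2.0], y = [3.0, 4.0], α = 0.1, β = 0.2, σ = 1.0)
println(realtypes(typeof(sample)))
```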
I think the solution is to turn the final `Real` into `T where {T}`, except that the `where` clause needs to propagate upward.
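One possible way to get that upward propagation is sketched below, using a hypothetical `realbound`. It is a variant of `realtypes`, not an existing function. The idea is to put a bounded type variable (`U <: ...`, i.e. a `where` clause) at each invariant level (`Array`, `NamedTuple`) and to rely on `Tuple` being covariant in between. Any sample type with the same field names and real-valued leaves should then be a subtype of the result. `toy_logdensity` and both samples are made-up stand-ins:

```julia
# Hypothetical variant of realtypes: returns a type bound rather than a
# concrete type, so samples with different Real element types all match.

# NamedTuple is invariant in its tuple-type parameter, so bound it with `where`
realbound(::Type{NamedTuple{S, T}}) where {S, T} =
    (NamedTuple{S, U} where {U <: realbound(T)})

# Tuple is covariant, so the element bounds can be used directly
realbound(::Type{T}) where {T <: Tuple} = Tuple{map(realbound, fieldtypes(T))...}

# Array is invariant in its element type, so bound that too
realbound(::Type{Array{T, N}}) where {T, N} = (Array{U, N} where {U <: realbound(T)})

realbound(::Type{<:Real}) = Real

# Take the bound from one sample, as with realtypes
sample = (x = [1.0, 2.0], y = [3.0, 4.0], α = 0.1, β = 0.2, σ = 1.0)
const ParsBound = realbound(typeof(sample))

# Stand-in for the generated log-density, annotated with the computed bound
toy_logdensity(pars::ParsBound) = pars.α + pars.β * sum(pars.x)

# A sample with other Real element types (Int, Rational, Float32)
other = (x = [1, 2], y = [3, 4], α = 1//2, β = 1, σ = 1.0f0)
println(applicable(toy_logdensity, sample))  # true
println(applicable(toy_logdensity, other))   # true
```

In generated code, the type object returned by `realbound` could be interpolated into the signature in place of the current `NamedTuple{...}` annotation.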
- Is there a better way to do this?
- Is there a way to do this?
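On the "better way" question, it may be worth checking whether the annotation is needed at all. Julia generally compiles a separate specialization of a method for the concrete argument types it is called with. An unannotated `pars` argument is therefore already concretely typed inside the body, and an annotation only restricts which calls match. Below is a sketch of that check. `toy_logdensity` and `sample` are made-up stand-ins, not Soss code:

```julia
using InteractiveUtils  # provides @code_warntype outside the REPL

# Hypothetical stand-in for the generated log-density, with no annotation on pars
toy_logdensity(pars) = pars.α + pars.β * sum(pars.x)

sample = (x = [1.0, 2.0], y = [3.0, 4.0], α = 0.1, β = 0.2, σ = 1.0)

# Inferred return type(s) for a call with this concrete NamedTuple type
println(Base.return_types(toy_logdensity, (typeof(sample),)))

# Same check in more detail, with Any/Union types highlighted
@code_warntype toy_logdensity(sample)
```

If the unannotated version already infers concretely, the `Any` reported for the real generated function probably comes from somewhere in its body rather than from `pars` itself.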