On the Turing call this morning, I mentioned a need for a very efficient recursive merge on named tuples. @mohamed82008 asked where this need comes in, and… I guess it was too early? Coffee hadn’t kicked in? Something like that. Anyway, here it is:
Say you have models like this in Soss:
μdist = @model ν begin
    s ~ Gamma(ν, ν)
    z ~ Normal()
    return sqrt(s) * z
end
σdist = @model begin
    x ~ Normal()
    return abs(x)
end
m = @model begin
    μ ~ μdist(ν=1.0)
    σ ~ σdist()
    x ~ Normal(μ, σ) |> iid(10)
    return x
end
with observations
julia> obs = (x=randn(2), μ = (z = 1.0,))
(x = [-1.2176456082647917, 0.22103067164326148], μ = (z = 1.0,))
For HMC, the transform we end up with is
julia> tr = xform(m() | obs)
TransformVariables.TransformTuple{NamedTuple{(:σ, :μ),Tuple{TransformVariables.TransformTuple{NamedTuple{(:x,),Tuple{TransformVariables.Identity}}},TransformVariables.TransformTuple{NamedTuple{(:s,),Tuple{TransformVariables.ShiftedExp{true,Float64}}}}}}}((σ = TransformVariables.TransformTuple{NamedTuple{(:x,),Tuple{TransformVariables.Identity}}}((x = asℝ,), 1), μ = TransformVariables.TransformTuple{NamedTuple{(:s,),Tuple{TransformVariables.ShiftedExp{true,Float64}}}}((s = asℝ₊,), 1)), 2)
julia> tr(randn(2))
(σ = (x = 0.8762773413553324,), μ = (s = 1.1202778056291636,))
The need for a merge comes in when we try to evaluate the log-density. We get some information from the observations, and the rest from the transformed values.
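To make that concrete, here's a minimal sketch of the kind of recursive merge I mean (the name rmerge is made up, and this is just an illustration, not anything Soss actually ships): values from the second argument win, except where both sides hold nested named tuples, which get merged recursively.

rmerge(x, y) = y    # at the leaves, the value being merged in wins

function rmerge(x::NamedTuple, y::NamedTuple)
    ks = Tuple(union(keys(x), keys(y)))
    vals = map(ks) do k
        if haskey(x, k) && haskey(y, k)
            rmerge(getfield(x, k), getfield(y, k))    # shared key: recurse
        elseif haskey(y, k)
            getfield(y, k)
        else
            getfield(x, k)
        end
    end
    NamedTuple{ks}(vals)
end

With obs as above, rmerge((σ = (x = 0.88,), μ = (s = 1.12,)), obs) fills in μ.z and x from the observations and keeps σ.x and μ.s from the transformed values. Simple, but the structure of the result isn't known statically, so it's not fast.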
That’s not to say that this merging needs to happen at runtime. Another option would be to build a function at compile time, something like
f(arr) = ( μ = ( s = exp(arr[1])
               , z = obs.μ.z
               )
         , σ = (x = arr[2],)
         , x = obs.x
         )
As it is, this is very slow, but I'm sure it (or something like it) can be made much faster:
julia> arr = randn(2)
2-element Array{Float64,1}:
-0.5330870568610766
0.6365326906148479
julia> @btime f($arr)
463.320 ns (12 allocations: 320 bytes)
(μ = (s = 0.5867907144635588, z = 1.0), σ = (x = 0.6365326906148479,), x = [-1.2176456082647917, 0.22103067164326148])
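To give a sense of what "something like it" might look like (again just a sketch under my own assumptions, with a made-up name, not anything Soss does today): a generated function can unroll the recursion over the named tuples' key types, so the shape of the result is fixed at compile time and the runtime work is just field accesses.

rmerge_static(x, y) = y    # at the leaves, the value being merged in wins

@generated function rmerge_static(x::NamedTuple{Kx}, y::NamedTuple{Ky}) where {Kx, Ky}
    args = []
    for k in union(Kx, Ky)
        qk = QuoteNode(k)
        val = if k in Kx && k in Ky
            :(rmerge_static(getfield(x, $qk), getfield(y, $qk)))    # shared key: recurse
        elseif k in Ky
            :(getfield(y, $qk))
        else
            :(getfield(x, $qk))
        end
        push!(args, Expr(:(=), k, val))
    end
    Expr(:tuple, args...)    # the AST of a named-tuple literal
end

Same idea as the hand-written f, but derived from the key types instead of written out by hand, so the compiler can see through the whole merge.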