Recursive merge for named tuples

On the Turing call this morning, I mentioned a need for a very efficient recursive merge on named tuples. @mohamed82008 asked where this need comes in, and… I guess it was too early? Coffee hadn’t kicked in? Something like that. Anyway, here it is:

Say you have models like this in Soss:

μdist = @model ν begin
    s ~ Gamma(ν, ν)
    z ~ Normal()
    return sqrt(s)*z
end

σdist = @model begin
    x ~ Normal()
    return abs(x)
end

m = @model begin
    μ ~ μdist(ν=1.0)
    σ ~ σdist()
    x ~ Normal(μ,σ) |> iid(10)
    return x
end

with observations

julia> obs = (x=randn(2), μ = (z = 1.0,))
(x = [-1.2176456082647917, 0.22103067164326148], μ = (z = 1.0,))

For HMC, the transform we end up with is

julia> tr = xform(m() | obs)
TransformVariables.TransformTuple{NamedTuple{(:σ, :μ),Tuple{TransformVariables.TransformTuple{NamedTuple{(:x,),Tuple{TransformVariables.Identity}}},TransformVariables.TransformTuple{NamedTuple{(:s,),Tuple{TransformVariables.ShiftedExp{true,Float64}}}}}}}((σ = TransformVariables.TransformTuple{NamedTuple{(:x,),Tuple{TransformVariables.Identity}}}((x = asℝ,), 1), μ = TransformVariables.TransformTuple{NamedTuple{(:s,),Tuple{TransformVariables.ShiftedExp{true,Float64}}}}((s = asℝ₊,), 1)), 2)

julia> tr(randn(2))
(σ = (x = 0.8762773413553324,), μ = (s = 1.1202778056291636,))

The need for a merge comes in when we try to evaluate the log-density. We get some information from the observations, and the rest from the transformed values.
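To make the shape of the problem concrete, here's a naive sketch of the merge in question (`rmerge` is a hypothetical name, not Soss's actual implementation): when both sides hold NamedTuples under the same key, recurse; otherwise the right-hand side wins.

```julia
rmerge(x, y) = y   # non-tuple leaves: right side wins
function rmerge(x::NamedTuple, y::NamedTuple)
    ks = union(keys(x), keys(y))
    vals = map(ks) do k
        if haskey(x, k) && haskey(y, k)
            rmerge(getfield(x, k), getfield(y, k))
        elseif haskey(x, k)
            getfield(x, k)
        else
            getfield(y, k)
        end
    end
    NamedTuple{Tuple(ks)}(Tuple(vals))
end

obs = (x = [-1.2, 0.2], μ = (z = 1.0,))
transformed = (σ = (x = 0.88,), μ = (s = 1.12,))
rmerge(obs, transformed)
# (x = [-1.2, 0.2], μ = (z = 1.0, s = 1.12), σ = (x = 0.88,))
```

This gets the semantics right, but the dynamic key handling makes it slow; the rest of this thread is about doing the same thing fast.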

That’s not to say that this merging needs to happen at runtime. Another option would be to build a function at compile time, something like

f(arr) = ( μ = ( s = exp(arr[1])
               , z = obs.μ.z
               )
         , σ = ( x = arr[2],)
         , x = obs.x
         )
As it is, this is very slow, but I'm sure it (or something like it) can be made much faster:

julia> arr = randn(2)
2-element Array{Float64,1}:

julia> @btime f($arr)
  463.320 ns (12 allocations: 320 bytes)
(μ = (s = 0.5867907144635588, z = 1.0), σ = (x = 0.6365326906148479,), x = [-1.2176456082647917, 0.22103067164326148])

NamedTupleTools.jl has a merge and a merge_recursive. I cannot tell from your note whether this is the behavior you need, or whether it operates at the speed you require.


Thanks @JeffreySarnoff,

Yes, I’m a big fan of NamedTupleTools. merge_recursive is useful, but I think we can make it faster. For example,

julia> x
(d = (b = :b,), c = (e = :e, f = :f))

julia> y
(c = (f = :f, e = :e), e = (a = :a, b = :b))

julia> @code_warntype merge_recursive(x,y)
  #self#::Core.Compiler.Const(NamedTupleTools.merge_recursive, false)
  nt1::NamedTuple{(:d, :c),Tuple{NamedTuple{(:b,),Tuple{Symbol}},NamedTuple{(:e, :f),Tuple{Symbol,Symbol}}}}
  nt2::NamedTuple{(:c, :e),Tuple{NamedTuple{(:f, :e),Tuple{Symbol,Symbol}},NamedTuple{(:a, :b),Tuple{Symbol,Symbol}}}}
  #5::NamedTupleTools.var"#5#6"{NamedTuple{(:d, :c),Tuple{NamedTuple{(:b,),Tuple{Symbol}},NamedTuple{(:e, :f),Tuple{Symbol,Symbol}}}},NamedTuple{(:c, :e),Tuple{NamedTuple{(:f, :e),Tuple{Symbol,Symbol}},NamedTuple{(:a, :b),Tuple{Symbol,Symbol}}}}}
  gen::Base.Generator{Array{Symbol,1},NamedTupleTools.var"#5#6"{NamedTuple{(:d, :c),Tuple{NamedTuple{(:b,),Tuple{Symbol}},NamedTuple{(:e, :f),Tuple{Symbol,Symbol}}}},NamedTuple{(:c, :e),Tuple{NamedTuple{(:f, :e),Tuple{Symbol,Symbol}},NamedTuple{(:a, :b),Tuple{Symbol,Symbol}}}}}}

1 ─ %1  = NamedTupleTools.keys(nt1)::Core.Compiler.Const((:d, :c), false)
│   %2  = NamedTupleTools.keys(nt2)::Core.Compiler.Const((:c, :e), false)
│         (all_keys = NamedTupleTools.union(%1, %2))
│   %4  = Base.Generator::Core.Compiler.Const(Base.Generator, false)
│   %5  = NamedTupleTools.:(var"#5#6")::Core.Compiler.Const(NamedTupleTools.var"#5#6", false)
│   %6  = Core.typeof(nt1)::Core.Compiler.Const(NamedTuple{(:d, :c),Tuple{NamedTuple{(:b,),Tuple{Symbol}},NamedTuple{(:e, :f),Tuple{Symbol,Symbol}}}}, false)
│   %7  = Core.typeof(nt2)::Core.Compiler.Const(NamedTuple{(:c, :e),Tuple{NamedTuple{(:f, :e),Tuple{Symbol,Symbol}},NamedTuple{(:a, :b),Tuple{Symbol,Symbol}}}}, false)
│   %8  = Core.apply_type(%5, %6, %7)::Core.Compiler.Const(NamedTupleTools.var"#5#6"{NamedTuple{(:d, :c),Tuple{NamedTuple{(:b,),Tuple{Symbol}},NamedTuple{(:e, :f),Tuple{Symbol,Symbol}}}},NamedTuple{(:c, :e),Tuple{NamedTuple{(:f, :e),Tuple{Symbol,Symbol}},NamedTuple{(:a, :b),Tuple{Symbol,Symbol}}}}}, false)
│         (#5 = %new(%8, nt1, nt2))
│   %10 = #5::NamedTupleTools.var"#5#6"{NamedTuple{(:d, :c),Tuple{NamedTuple{(:b,),Tuple{Symbol}},NamedTuple{(:e, :f),Tuple{Symbol,Symbol}}}},NamedTuple{(:c, :e),Tuple{NamedTuple{(:f, :e),Tuple{Symbol,Symbol}},NamedTuple{(:a, :b),Tuple{Symbol,Symbol}}}}}
│         (gen = (%4)(%10, all_keys))
│   %12 = Base.NamedTuple()::Core.Compiler.Const(NamedTuple(), false)
│   %13 = Base.merge(%12, gen)::NamedTuple
└──       return %13

That %13 = Base.merge(%12, gen)::NamedTuple is dynamic, so there’s some slowdown there.

Also, my use case is a little different:

  • For a given pair of types, the recursive merge will happen many times, so it’s well worth a little compilation overhead (generated functions)
  • NamedTupleTools is fairly low-level, so I'd guess you want to avoid taking on too many dependencies. But I'm assuming some things like GeneralizedGenerated.jl and Accessors.jl will already be needed.
  • In my context it’s reasonable to first recursively sort the keys. So even without generated functions, a linear walk down the keys should be very quick.
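On that key-sorting point, a hypothetical helper (not in NamedTupleTools) might look like this: recursively sort the keys, so that a later merge of two such tuples can be a single linear walk over both key lists.

```julia
# Recursively sort the keys of a nested NamedTuple.
sortkeys(x) = x   # leaves pass through unchanged
function sortkeys(nt::NamedTuple)
    ks = Tuple(sort(collect(keys(nt))))
    NamedTuple{ks}(map(k -> sortkeys(getfield(nt, k)), ks))
end

sortkeys((d = (b = :b,), c = (f = :f, e = :e)))
# (c = (e = :e, f = :f), d = (b = :b,))
```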

Actually, on that last point, taking a union of keys slows things down quite a bit:

julia> @btime union(keys($x), keys($y))
  194.821 ns (6 allocations: 608 bytes)
3-element Array{Symbol,1}:

So there’s still some low-hanging fruit:

@generated function keyunion(x::NamedTuple{Kx, Tx}, y::NamedTuple{Ky, Ty}) where {Kx,Tx,Ky,Ty}
    union(Kx, Ky)   # computed once at compile time; spliced in as a constant
end

julia> @btime keyunion($x,$y)
  1.152 ns (0 allocations: 0 bytes)
3-element Array{Symbol,1}:

Thanks for the help!

Thank you! It’s really useful to have libraries to collect utility functions, rather than everyone re-implementing the same thing.

Maybe we should update NTT to use keyunion. Probably best to have it as an if @generated so the compiler can choose. And I’ll keep it in mind in case I end up with an approach that doesn’t need other dependencies 🙂
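For reference, the if @generated pattern looks roughly like this (a sketch with a hypothetical `keyunion2`, not the actual NTT code); the compiler is free to use either branch, so both must compute the same answer:

```julia
function keyunion2(x::NamedTuple{Kx}, y::NamedTuple{Ky}) where {Kx,Ky}
    if @generated
        # Kx and Ky are tuples of symbols known at generation time;
        # splice their union in as a tuple literal
        :(($(QuoteNode.(union(Kx, Ky))...),))
    else
        # run-time fallback computing the same thing
        Tuple(union(Kx, Ky))
    end
end

keyunion2((a = 1, b = 2), (b = 3, c = 4))  # (:a, :b, :c)
```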

Please take a moment to look at last night’s NamedTupleTools (@generated branch).

Thanks for the pointer @JeffreySarnoff

This stuff gets really tricky for me; there are so many subtleties. I guess the important thing is to identify places where the compiler is really struggling, and generate simpler code to make its job easier.

For example, in the keyunion code, we might instead just do union(keys(x), keys(y)). That builds an array, and the compiler can’t reduce it too much:

julia> x = (a=1, b=2)
(a = 1, b = 2)

julia> y = (b=3, c=4)
(b = 3, c = 4)

julia> @code_typed union(keys(x), keys(y))
CodeInfo(
1 ─ %1 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Symbol,1}, svec(Any, Int64), 0, :(:ccall), Array{Symbol,1}, 0, 0))::Array{Symbol,1}
│   %2 = (getfield)(sets, 1)::Tuple{Symbol,Symbol}
│   %3 = invoke Base.union!(%1::Array{Symbol,1}, _2::Tuple{Symbol,Symbol}, %2::Tuple{Symbol,Symbol})::Array{Symbol,1}
└──      return %3
) => Array{Symbol,1}

I think part of the weirdness here comes from the dynamic types. So in this case we can generate code to return a tuple, and it should do even better.

julia> @generated function keyunion(x::NamedTuple{Kx, Tx}, y::NamedTuple{Ky, Ty}) where {Kx,Tx,Ky,Ty}
           :(($(QuoteNode.(union(Kx, Ky))...),))
       end
keyunion (generic function with 1 method)

julia> @code_typed keyunion(x,y)
CodeInfo(
1 ─     return (:a, :b, :c)
) => Tuple{Symbol,Symbol,Symbol}

Can’t do much better than that :slight_smile:

But then sometimes the compiler has no trouble, e.g.

julia> fieldnames1(x::NamedTuple{N}) where {N} = N
fieldnames1 (generic function with 1 method)

julia> @code_llvm fieldnames1((a=1,b=2,c=3))
;  @ REPL[60]:1 within `fieldnames1'
define void @julia_fieldnames1_4718([3 x %jl_value_t*]* noalias nocapture sret %0, [3 x i64]* nocapture nonnull readonly dereferenceable(24) %1) {
  %2 = bitcast [3 x %jl_value_t*]* %0 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* nonnull align 8 dereferenceable(24) %2, i8* nonnull align 16 dereferenceable(24) inttoptr (i64 140503769783536 to i8*), i64 24, i1 false)
  ret void
}

julia> @generated function fieldnames2(x::T) where {N,S,T<:NamedTuple{N,S}}
           QuoteNode(N)
       end
fieldnames2 (generic function with 1 method)

julia> @code_llvm fieldnames2((a=1,b=2,c=3))
;  @ REPL[62]:1 within `fieldnames2'
define void @julia_fieldnames2_4720([3 x %jl_value_t*]* noalias nocapture sret %0, [3 x i64]* nocapture nonnull readonly dereferenceable(24) %1) {
; ┌ @ REPL[62]:1 within `macro expansion'
   %2 = bitcast [3 x %jl_value_t*]* %0 to i8*
   call void @llvm.memcpy.p0i8.p0i8.i64(i8* nonnull align 8 dereferenceable(24) %2, i8* nonnull align 16 dereferenceable(24) inttoptr (i64 140503769783536 to i8*), i64 24, i1 false)
   ret void
; └
}

Again, there are lots of subtleties here; honestly, I’m still tuning my mental model.

There were Julia versions where doing the obvious thing with functional handling of type parameters could be less performant. I am taking license with your more developed example – this is my first use of @generated, so I am starting with the simplest version to get it right. I do intend to revert to the more straightforward coding for fieldnames and the internal field_types once your approach is properly made part of this new branch.

I found @generated very weird to get used to. Here’s a little example that would have helped me:

@generated function f(x)
    y = g(x)
    quote
        h(x, $y)
    end
end

In this case,

  • The x in g(x) refers to typeof(x)
  • The x inside the quote refers to the argument value x (this is the only place it’s visible)
  • The x in the quote does not need to be interpolated, but
  • The y does (because it was constructed statically)
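To make that concrete, here's a runnable toy version; g and h are stand-ins I've defined just for illustration, not anything from the thread:

```julia
g(T::Type) = fieldcount(T)  # called at generation time with typeof(x)
h(x, y) = (x, y)            # called at run time with actual values

@generated function f(x)
    y = g(x)       # x here is the *type* of the argument
    quote
        h(x, $y)   # x here is the argument *value*; y must be interpolated
    end
end

f((1, 2.0))  # ((1, 2.0), 2)
```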

@JeffreySarnoff I’ve added some issues for adding functions I’ve found useful for building and testing generated functions on named tuples.

I think schema in particular is big enough to warrant its own file. Maybe we can discuss in the package issues how you’d like to structure things.

I figured out a way to do this with merge_recursive at compile time. Check it out: say you have

julia> x = (a = (b = 1, c = 2), d = 3)
(a = (b = 1, c = 2), d = 3)

julia> y = (a = (d = 3,),e = (f = 1, g = 2))
(a = (d = 3,), e = (f = 1, g = 2))

Then we now have

julia> f = leaf_setter(x,y)
function = (x, y;) -> begin
        (var"##259", var"##260", var"##261") = x
        (var"##262", var"##263", var"##264") = y
        return (a = (b = var"##259", c = var"##260", d = var"##262"), d = var"##261", e = (f = var"##263", g = var"##264"))
    end

so, for example,

julia> @btime $f((1,2,3),(4,5,6))
  0.010 ns (0 allocations: 0 bytes)
(a = (b = 1, c = 2, d = 4), d = 3, e = (f = 5, g = 6))

I think this fits our current use case really well. In PPL, some of the arguments will be static, coming from the observed data. This could go where the (1,2,3) currently sits. Then for any proposal, say (4,5,6) for the remaining variables, we need to be able to very quickly construct the named tuple for evaluation.
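For instance, with a hand-written equivalent of the generated function above, pinning the observed leaves once gives a fast closure over just the proposal values (the variable names here are mine, chosen for readability):

```julia
# Hand-written stand-in for the leaf_setter output shown above
f = (x, y) -> begin
    (a1, a2, a3) = x   # static leaves from the observed data
    (b1, b2, b3) = y   # leaves from the proposal
    (a = (b = a1, c = a2, d = b1), d = a3, e = (f = b2, g = b3))
end

g = y -> f((1, 2, 3), y)  # pin the observed part once
g((4, 5, 6))              # (a = (b = 1, c = 2, d = 4), d = 3, e = (f = 5, g = 6))
```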

@cscherrer do you have a proposed revision for merge_recursive? I have cleaned up the intended revision enough to work on this now.

The generated-code version of merge_recursive is very fast when there are few enough leaves, but past a point it gets very slow (over a microsecond). I’m guessing this is because of the compiler’s recursion limit.

For now, I’ve added a LazyMerge struct to NestedTuples.jl that does this one layer at a time (and NestedTuples will be registered soon). But it’s no longer a NamedTuple, kind of an AbstractNamedTuple if there were such a thing 🙂
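Roughly, the idea is this (a minimal sketch, not NestedTuples' actual implementation): hold both tuples and resolve one property lookup at a time, deferring the merge of nested layers until they're accessed.

```julia
struct LazyMerge{X,Y}
    x::X
    y::Y
end

function Base.getproperty(m::LazyMerge, k::Symbol)
    x = getfield(m, :x)
    y = getfield(m, :y)
    if haskey(x, k) && haskey(y, k)
        xk, yk = getproperty(x, k), getproperty(y, k)
        # defer one more layer if both sides are nested; otherwise y wins
        xk isa NamedTuple && yk isa NamedTuple ? LazyMerge(xk, yk) : yk
    elseif haskey(y, k)
        getproperty(y, k)
    else
        getproperty(x, k)
    end
end

m = LazyMerge((a = (b = 1,), c = 2), (a = (d = 3,), e = 4))
m.a.b, m.a.d, m.c, m.e  # (1, 3, 2, 4)
```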

There’s some functionality of NestedTuples that stays in NamedTuple land and has no heavy dependencies. I’ve added some issues to NamedTupleTools to transition these in case they’re more useful to others.