Merging disjoint NamedTuples

I need a function that merges two NamedTuples but throws an error if they have any names in common, something like this:

@generated function merge_disjoint(a::NamedTuple{A}, b::NamedTuple{B}) where {A,B}
    AB = intersect(A, B)
    if isempty(AB)
        :(merge(a, b))
    else
        msg = "found common names $(join(AB, ", "))"
        :(throw(ArgumentError($msg)))
    end
end

julia> ab = (a = 1, b = 2)
(a = 1, b = 2)

julia> merge_disjoint(ab, ab)
ERROR: ArgumentError: found common names a, b
Stacktrace:
 [1] macro expansion
   @ ./REPL[63]:1 [inlined]
 [2] merge_disjoint(a::@NamedTuple{a::Int64, b::Int64}, b::@NamedTuple{a::Int64, b::Int64})
   @ Main ./REPL[63]:1
 [3] top-level scope
   @ REPL[71]:1

julia> merge_disjoint(ab, (c = 1, d = 2))
(a = 1, b = 2, c = 1, d = 2)
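For comparison, here is a sketch (not from the thread) of the same intersection check written as an ordinary method. The names are still read from the type parameters A and B, so @generated is not strictly needed for this. The name merge_disjoint_typed is hypothetical, and whether the compiler folds the intersect away at compile time may depend on the Julia version.

```julia
# Hypothetical non-@generated variant: A and B are tuples of Symbols taken from the argument types.
function merge_disjoint_typed(a::NamedTuple{A}, b::NamedTuple{B}) where {A,B}
    common = intersect(A, B)   # names present in both arguments, as a Vector{Symbol}
    isempty(common) || throw(ArgumentError("found common names $(join(common, ", "))"))
    return merge(a, b)
end

merge_disjoint_typed((a = 1, b = 2), (c = 1, d = 2))   # (a = 1, b = 2, c = 1, d = 2)
```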

I am wondering if there is anything simpler than the above.

My attempt is probably too simple… given that you use a generated function, etc.

function merge_disjoint(x, y)
    x_k = keys(x)
    y_k = keys(y)
    if isempty(intersect(x_k, y_k))
        merge(x, y)
    else
        # throw rather than print, so callers never silently get `nothing` back
        throw(ArgumentError("Shared names"))
    end
end

What about

julia> function mergeDisjoint(a,b)
       res = merge(a,b)
       length(res) == length(a) + length(b) ? res : throw("oops")
       end

Relevant parts:

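  1. Your function doesn’t need to be @generated: merge itself is already generated.
  2. You don’t need to compute the intersection; you can simply check lengths. merge keeps each name only once, so the result is shorter than length(a) + length(b) exactly when the inputs share a name.
  3. Method specialization means the conflicting names are visible in the stacktrace anyway, since they are part of the argument types: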
julia> mergeDisjoint((;a=1, b=2), (;c=3, b=1))
ERROR: "oops"
Stacktrace:
 [1] mergeDisjoint(a::@NamedTuple{a::Int64, b::Int64}, b::@NamedTuple{c::Int64, b::Int64})
   @ Main ./REPL[39]:3
 [2] top-level scope
   @ REPL[41]:1
  4. This solution infers correctly: the return type is either Union{} (in case of error) or the concrete merged NamedTuple type.
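To see why the length check in point 2 is enough, here is a small illustration (the inputs are made up): merge keeps each name once, with the value from the later argument winning, so any shared name makes the result shorter than the sum of the input lengths.

```julia
x = (a = 1, b = 2)
y = (c = 3, b = 10)

# `b` appears in both; merge keeps a single `b`, taking its value from the second argument.
xy = merge(x, y)                          # (a = 1, b = 10, c = 3)

length(x) + length(y)                     # 4
length(xy)                                # 3, one short because one name was shared

# With disjoint names nothing is dropped, so the lengths add up.
length(merge(x, (c = 3, d = 4))) == length(x) + 2   # true
```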

Thanks! I will use that as a solution, slightly tweaked because I want an informative error message:

function merge_disjoint(a::NamedTuple, b::NamedTuple)
    ab = merge(a, b)
    if length(a) + length(b) ≠ length(ab)
        throw(ArgumentError("found common keys " *
            join(intersect(keys(a), keys(b)), ", ")))
    end
    ab
end
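A quick check of this final version, sketched with the Test standard library. The function is repeated so the block runs on its own, and the test inputs are made up for illustration. Note that the message lists the shared keys in the order they appear in the first argument, since that is the order intersect returns.

```julia
using Test

# The final version from the post above, repeated so this block is self-contained.
function merge_disjoint(a::NamedTuple, b::NamedTuple)
    ab = merge(a, b)
    if length(a) + length(b) ≠ length(ab)
        throw(ArgumentError("found common keys " *
            join(intersect(keys(a), keys(b)), ", ")))
    end
    ab
end

@testset "merge_disjoint" begin
    # disjoint names: plain merge
    @test merge_disjoint((a = 1, b = 2), (c = 3,)) == (a = 1, b = 2, c = 3)
    # merging with an empty NamedTuple is always allowed
    @test merge_disjoint((a = 1,), NamedTuple()) == (a = 1,)
    # a shared name throws an ArgumentError...
    @test_throws ArgumentError merge_disjoint((a = 1, b = 2), (b = 3, c = 4))
    # ...whose message lists the shared keys
    msg = try
        merge_disjoint((a = 1, b = 2), (b = 3, c = 4))
        ""
    catch err
        err.msg
    end
    @test msg == "found common keys b"
end
```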