Improve this function for merging tuples so that the return type could be accurately inferred

I implemented the function merged_tuples for merging tuples in given order while preserving intra-tuple order. It seems that Julia isn’t able to infer the return types accurately, however (using Test.@inferred):

julia> @inferred merged_tuples(Val((1, 2, 1, 1, 2, 2, 2, 1)), Val(((Val(1), Val(2), Val(3), Val(4)), (Val(5), Val(6), Val(7), Val(8)))))
ERROR: return type Tuple{Val{1}, Val{5}, Val{2}, Val{3}, Val{6}, Val{7}, Val{8}, Val{4}} does not match inferred return type Tuple{Val{1}, Val{5}, Val, Val, Val, Val, Val, Val}

So it seems like the types of the first few elements get inferred accurately, but the others not so much. Perhaps this is because of the count call in merged_tuples.

This is my implementation, with a comment for documentation:

# Merges tuples while preserving intra-tuple order, with indices
# determining where each element goes (which element of the resultant
# tuple comes from which source tuple).
#
# For example, if indices is:
#  (1, 2, 1, 1, 2)
# and tuples is:
#  ((11, 12, 13), (21, 22))
# the result is:
#  (11, 21, 12, 13, 22)
function merged_tuples(::Val{indices}, ::Val{tuples}) where {indices, tuples}
  (length(indices) == sum(length, tuples, init = 0)) || error("mismatched total size")
  all(
    pair -> ==(pair...),
    zip(
      ntuple(
        let fi = firstindex(tuples)
          i -> count(==(i + fi - 1), indices, init = 0)
        end,
        Val{length(tuples)}()),
      map(length, tuples))) || error("mismatched size")

  ntuple(
    i ->
      let j = indices[i]
        tuples[j][count(==(j), indices[begin:(i - 1)], init = 0) + 1]
      end,
    Val{length(indices)}())
end

I’m using Julia 1.8.0-rc3 BTW.

In case it’s not clear what merged_tuples does, here’s another implementation. It has much worse inference, though.

# A different implementation of merged_tuples.
function merged_tuples_d(::Val{indices}, ::Val{tuples}) where {indices, tuples}
  local ret = Any[]
  local inner_inds = zeros(length(tuples))
  for i in indices
    inner_inds[i] += 1
    push!(ret, tuples[i][inner_inds[i]])
  end
  (ret...,)
end

If you’re OK with using generated functions, you can easily turn your implementation into one:

@generated function merged_tuples_gen(
        ::Val{indices}, ::Val{tuples},
    ) where {indices, tuples}
    (length(indices) == sum(length, tuples, init = 0)) || error("mismatched total size")
    all(
      pair -> ==(pair...),
      zip(
        ntuple(
          let fi = firstindex(tuples)
            i -> count(==(i + fi - 1), indices, init = 0)
          end,
          Val{length(tuples)}()),
        map(length, tuples))) || error("mismatched size")

    ret = ntuple(
      i ->
        let j = indices[i]
          tuples[j][count(==(j), indices[begin:(i - 1)], init = 0) + 1]
        end,
      Val{length(indices)}())
    :( $ret )
end


julia> @code_warntype merged_tuples_gen(
           Val((1, 2, 1, 1, 2, 2, 2, 1)),
           Val((
               (Val(1), Val(2), Val(3), Val(4)),
               (Val(5), Val(6), Val(7), Val(8)),
           ))
       )
MethodInstance for merged_tuples_gen(::Val{(1, 2, 1, 1, 2, 2, 2, 1)}, ::Val{((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))})
  from merged_tuples_gen(::Val{indices}, ::Val{tuples}) where {indices, tuples} in Main at /home/jipolanco/tmp/julia/merged_tuples.jl:75
Static Parameters
  indices = (1, 2, 1, 1, 2, 2, 2, 1)
  tuples = ((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))
Arguments
  #self#::Core.Const(merged_tuples_gen)
  _::Core.Const(Val{(1, 2, 1, 1, 2, 2, 2, 1)}())
  _::Core.Const(Val{((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))}())
Body::Tuple{Val{1}, Val{5}, Val{2}, Val{3}, Val{6}, Val{7}, Val{8}, Val{4}}
1 ─     return (Val{1}(), Val{5}(), Val{2}(), Val{3}(), Val{6}(), Val{7}(), Val{8}(), Val{4}())

EDIT: Alternatively, a recursive implementation along with the new Base.@assume_effects seems to also do the job:

function merged_tuples_rec(::Val{indices}, ::Val{tuples}) where {indices, tuples}
    (length(indices) == sum(length, tuples, init = 0)) || error("mismatched total size")
    _merged_tuples_rec((), Val(indices), Val(tuples))
end

Base.@assume_effects :foldable function _merged_tuples_rec(
        ret::Tuple, ::Val{indices}, ::Val{tuples},
    ) where {indices, tuples}
    i, inds... = indices
    x, ti... = tuples[i]
    merged = (ret..., x)
    tups = Base.setindex(tuples, ti, i)
    _merged_tuples_rec(merged, Val(inds), Val(tups))
end

_merged_tuples_rec(ret::Tuple, ::Val{()}, ::Val) = ret


julia> @code_warntype merged_tuples_rec(
           Val((1, 2, 1, 1, 2, 2, 2, 1)),
           Val((
               (Val(1), Val(2), Val(3), Val(4)),
               (Val(5), Val(6), Val(7), Val(8)),
           ))
       )
MethodInstance for merged_tuples_rec(::Val{(1, 2, 1, 1, 2, 2, 2, 1)}, ::Val{((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))})
  from merged_tuples_rec(::Val{indices}, ::Val{tuples}) where {indices, tuples} in Main at /home/jipolanco/tmp/julia/merged_tuples.jl:100
Static Parameters
  indices = (1, 2, 1, 1, 2, 2, 2, 1)
  tuples = ((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))
Arguments
  #self#::Core.Const(merged_tuples_rec)
  _::Core.Const(Val{(1, 2, 1, 1, 2, 2, 2, 1)}())
  _::Core.Const(Val{((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))}())
Body::Tuple{Val{1}, Val{5}, Val{2}, Val{3}, Val{6}, Val{7}, Val{8}, Val{4}}
1 ─ %1  = Main.length($(Expr(:static_parameter, 1)))::Core.Const(8)
│   %2  = (:init,)::Core.Const((:init,))
│   %3  = Core.apply_type(Core.NamedTuple, %2)::Core.Const(NamedTuple{(:init,)})
│   %4  = Core.tuple(0)::Core.Const((0,))
│   %5  = (%3)(%4)::Core.Const((init = 0,))
│   %6  = Core.kwfunc(Main.sum)::Core.Const(Base.var"#sum##kw"())
│   %7  = (%6)(%5, Main.sum, Main.length, $(Expr(:static_parameter, 2)))::Core.Const(8)
│   %8  = (%1 == %7)::Core.Const(true)
│         Core.typeassert(%8, Core.Bool)
└──       goto #3
2 ─       Core.Const(:(Main.error("mismatched total size")))
3 ┄ %12 = ()::Core.Const(())
│   %13 = Main.Val($(Expr(:static_parameter, 1)))::Core.Const(Val{(1, 2, 1, 1, 2, 2, 2, 1)}())
│   %14 = Main.Val($(Expr(:static_parameter, 2)))::Core.Const(Val{((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))}())
│   %15 = Main._merged_tuples_rec(%12, %13, %14)::Core.Const((Val{1}(), Val{5}(), Val{2}(), Val{3}(), Val{6}(), Val{7}(), Val{8}(), Val{4}()))
└──       return %15
1 Like