# Improve this function for merging tuples so that the return type could be accurately inferred

I implemented the function `merged_tuples` for merging tuples in given order while preserving intra-tuple order. It seems that Julia isn’t able to infer the return types accurately, however (using `Test.@inferred`):

``````julia> @inferred merged_tuples(Val((1, 2, 1, 1, 2, 2, 2, 1)), Val(((Val(1), Val(2), Val(3), Val(4)), (Val(5), Val(6), Val(7), Val(8)))))
ERROR: return type Tuple{Val{1}, Val{5}, Val{2}, Val{3}, Val{6}, Val{7}, Val{8}, Val{4}} does not match inferred return type Tuple{Val{1}, Val{5}, Val, Val, Val, Val, Val, Val}
``````

So it seems like the types of the first few elements get inferred accurately, but the others not so much. Perhaps this is because of the `count` call in `merged_tuples`.

This is my implementation, with a comment for documentation:

``````# Merges tuples while preserving intra-tuple order, with indices
# determining where each element goes (which element of the resultant
# tuple comes from which source tuple).
#
# For example, if indices is:
#  (1, 2, 1, 1, 2)
# and tuples is:
#  ((11, 12, 13), (21, 22))
# the result is:
#  (11, 21, 12, 13, 22)
function merged_tuples(::Val{indices}, ::Val{tuples}) where {indices, tuples}
(length(indices) == sum(length, tuples, init = 0)) || error("mismatched total size")
all(
pair -> ==(pair...),
zip(
ntuple(
let fi = firstindex(tuples)
i -> count(==(i + fi - 1), indices, init = 0)
end,
Val{length(tuples)}()),
map(length, tuples))) || error("mismatched size")

ntuple(
i ->
let j = indices[i]
tuples[j][count(==(j), indices[begin:(i - 1)], init = 0) + 1]
end,
Val{length(indices)}())
end
``````

I’m using Julia 1.8.0-rc3 BTW.

In case it’s not clear what `merged_tuples` does, here’s another implementation. It has much worse inference, though.

``````# A different implementation of merged_tuples.
function merged_tuples_d(::Val{indices}, ::Val{tuples}) where {indices, tuples}
local ret = Any[]
local inner_inds = zeros(length(tuples))
for i in indices
inner_inds[i] += 1
push!(ret, tuples[i][inner_inds[i]])
end
(ret...,)
end
``````

If you’re OK with using generated functions, you can easily turn your implementation into one:

``````@generated function merged_tuples_gen(
::Val{indices}, ::Val{tuples},
) where {indices, tuples}
(length(indices) == sum(length, tuples, init = 0)) || error("mismatched total size")
all(
pair -> ==(pair...),
zip(
ntuple(
let fi = firstindex(tuples)
i -> count(==(i + fi - 1), indices, init = 0)
end,
Val{length(tuples)}()),
map(length, tuples))) || error("mismatched size")

ret = ntuple(
i ->
let j = indices[i]
tuples[j][count(==(j), indices[begin:(i - 1)], init = 0) + 1]
end,
Val{length(indices)}())
:( \$ret )
end

julia> @code_warntype merged_tuples_gen(
Val((1, 2, 1, 1, 2, 2, 2, 1)),
Val((
(Val(1), Val(2), Val(3), Val(4)),
(Val(5), Val(6), Val(7), Val(8)),
))
)
MethodInstance for merged_tuples_gen(::Val{(1, 2, 1, 1, 2, 2, 2, 1)}, ::Val{((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))})
from merged_tuples_gen(::Val{indices}, ::Val{tuples}) where {indices, tuples} in Main at /home/jipolanco/tmp/julia/merged_tuples.jl:75
Static Parameters
indices = (1, 2, 1, 1, 2, 2, 2, 1)
tuples = ((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))
Arguments
#self#::Core.Const(merged_tuples_gen)
_::Core.Const(Val{(1, 2, 1, 1, 2, 2, 2, 1)}())
_::Core.Const(Val{((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))}())
Body::Tuple{Val{1}, Val{5}, Val{2}, Val{3}, Val{6}, Val{7}, Val{8}, Val{4}}
1 ─     return (Val{1}(), Val{5}(), Val{2}(), Val{3}(), Val{6}(), Val{7}(), Val{8}(), Val{4}())
``````

EDIT: Alternatively, a recursive implementation along with the new `Base.@assume_effects` seems to also do the job:

``````function merged_tuples_rec(::Val{indices}, ::Val{tuples}) where {indices, tuples}
(length(indices) == sum(length, tuples, init = 0)) || error("mismatched total size")
_merged_tuples_rec((), Val(indices), Val(tuples))
end

Base.@assume_effects :foldable function _merged_tuples_rec(
ret::Tuple, ::Val{indices}, ::Val{tuples},
) where {indices, tuples}
i, inds... = indices
x, ti... = tuples[i]
merged = (ret..., x)
tups = Base.setindex(tuples, ti, i)
_merged_tuples_rec(merged, Val(inds), Val(tups))
end

_merged_tuples_rec(ret::Tuple, ::Val{()}, ::Val) = ret

julia> @code_warntype merged_tuples_rec(
Val((1, 2, 1, 1, 2, 2, 2, 1)),
Val((
(Val(1), Val(2), Val(3), Val(4)),
(Val(5), Val(6), Val(7), Val(8)),
))
)
MethodInstance for merged_tuples_rec(::Val{(1, 2, 1, 1, 2, 2, 2, 1)}, ::Val{((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))})
from merged_tuples_rec(::Val{indices}, ::Val{tuples}) where {indices, tuples} in Main at /home/jipolanco/tmp/julia/merged_tuples.jl:100
Static Parameters
indices = (1, 2, 1, 1, 2, 2, 2, 1)
tuples = ((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))
Arguments
#self#::Core.Const(merged_tuples_rec)
_::Core.Const(Val{(1, 2, 1, 1, 2, 2, 2, 1)}())
_::Core.Const(Val{((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))}())
Body::Tuple{Val{1}, Val{5}, Val{2}, Val{3}, Val{6}, Val{7}, Val{8}, Val{4}}
1 ─ %1  = Main.length(\$(Expr(:static_parameter, 1)))::Core.Const(8)
│   %2  = (:init,)::Core.Const((:init,))
│   %3  = Core.apply_type(Core.NamedTuple, %2)::Core.Const(NamedTuple{(:init,)})
│   %4  = Core.tuple(0)::Core.Const((0,))
│   %5  = (%3)(%4)::Core.Const((init = 0,))
│   %6  = Core.kwfunc(Main.sum)::Core.Const(Base.var"#sum##kw"())
│   %7  = (%6)(%5, Main.sum, Main.length, \$(Expr(:static_parameter, 2)))::Core.Const(8)
│   %8  = (%1 == %7)::Core.Const(true)
│         Core.typeassert(%8, Core.Bool)
└──       goto #3
2 ─       Core.Const(:(Main.error("mismatched total size")))
3 ┄ %12 = ()::Core.Const(())
│   %13 = Main.Val(\$(Expr(:static_parameter, 1)))::Core.Const(Val{(1, 2, 1, 1, 2, 2, 2, 1)}())
│   %14 = Main.Val(\$(Expr(:static_parameter, 2)))::Core.Const(Val{((Val{1}(), Val{2}(), Val{3}(), Val{4}()), (Val{5}(), Val{6}(), Val{7}(), Val{8}()))}())
│   %15 = Main._merged_tuples_rec(%12, %13, %14)::Core.Const((Val{1}(), Val{5}(), Val{2}(), Val{3}(), Val{6}(), Val{7}(), Val{8}(), Val{4}()))
└──       return %15
``````
1 Like