Type inference: map + accumulate tuple recursively


#1

I have a function f that returns two results: the first one maps the argument to something, the second is a scalar. For the MWE, let

f(x::Real) = x + 1, abs2(x)
f(x::AbstractArray) = x .+ 1, sum(abs2, x)

Please treat these as a black box otherwise.

I would like to define a method f(::Tuple) that calls f on each field and returns a tuple of the first values together with the sum of the second values. I would also like the compiler to be able to infer the return type. For example,

julia> f((-1, [-2], [-3.0, -4.0]))
((0, [-1], [-2.0, -3.0]), 30.0)

I tried

function _f(acc, ys, x, xs...)
    y, s = f(x)
    _f(acc + s, (ys..., y), xs...)
end

_f(acc, ys) = ys, acc

f(x::Tuple) = _f(0, (), x...)

but

julia> @code_warntype f((-1, [-2], [-3.0, -4.0]))
Body::Tuple{Tuple,Float64}
 1 1 ─ %1 = (getfield)(x, 1)::Int64                                                        │  
   │   %2 = (getfield)(x, 2)::Array{Int64,1}                                               │  
   │   %3 = (getfield)(x, 3)::Array{Float64,1}                                             │  
   │   %4 = (Base.add_int)(%1, 1)::Int64                                                   │╻╷ _f
   │   %5 = (Base.mul_int)(%1, %1)::Int64                                                  ││╻  f
   │   %6 = (Base.add_int)(0, %5)::Int64                                                   ││╻  +
   │   %7 = (Core.tuple)(%4)::Tuple{Int64}                                                 ││ 
   │   %8 = invoke Main._f(%6::Int64, %7::Tuple{Int64}, %2::Array{Int64,1}, %3::Array{Float64,1})::Tuple{Tuple,Float64}
   └──      return %8                                                                      │  

julia> VERSION
v"1.1.0-DEV.671"
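A sketch of checking this without reading the @code_warntype listing, assuming a recent Julia with the Test standard library. Base.return_types lists what inference concludes for a given argument signature, and Test.@inferred throws when that differs from the actual result type. The names x and inferred_ok are only for illustration, and the exact printed types may differ between Julia versions.

```julia
using Test  # for @inferred

f(x::Real) = x + 1, abs2(x)
f(x::AbstractArray) = x .+ 1, sum(abs2, x)

# First attempt from the question: an accumulator plus a growing tuple of results
function _f(acc, ys, x, xs...)
    y, s = f(x)
    _f(acc + s, (ys..., y), xs...)
end
_f(acc, ys) = ys, acc
f(x::Tuple) = _f(0, (), x...)

x = (-1, [-2], [-3.0, -4.0])

# Return type(s) inference finds for this signature: a concrete
# Tuple{Tuple{Int64, ...}, Float64} means success, while an abstract
# Tuple{Tuple, Float64} reproduces the problem shown above.
println(Base.return_types(f, (typeof(x),)))

# @inferred throws if the inferred return type differs from the actual one,
# so wrap it to get a Bool instead of an exception.
inferred_ok = try
    @inferred f(x)
    true
catch
    false
end
println("inferred concretely: ", inferred_ok)
```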

#2

Also, just to clarify: I know how to use a generated function; I am asking for a recursive solution.


#3

This works:

function _f2(x, xs...)
    y, s = f(x)
    ys, ss = _f2(xs...)
    (y, ys...), s + ss
end

_f2() = (), 0

f(x::Tuple) = _f2(x...)

but why doesn’t the first one infer?
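The head/tail recursion in _f2 does not depend on f itself, so as a sketch it can be factored into a helper that takes the per-element function as an argument. map_and_sum and _map_and_sum are hypothetical names, not part of Base; @inferred is used so that the example errors if the return type is not inferred, instead of silently returning an abstract type.

```julia
using Test  # for @inferred

f(x::Real) = x + 1, abs2(x)
f(x::AbstractArray) = x .+ 1, sum(abs2, x)

# Hypothetical generalisation of _f2: `g` must return a (value, scalar) pair.
# Base case: no elements left, so no values and a zero sum.
_map_and_sum(g) = (), 0

# Peel off the head, recurse on the strictly shorter tail, then prepend.
function _map_and_sum(g, x, xs...)
    y, s = g(x)
    ys, ss = _map_and_sum(g, xs...)
    (y, ys...), s + ss
end

# `g::G ... where {G}` asks Julia to specialize on the function argument,
# even though this method only passes it through.
map_and_sum(g::G, t::Tuple) where {G} = _map_and_sum(g, t...)

t = (-1, [-2], [-3.0, -4.0])
result = @inferred map_and_sum(f, t)  # throws if the return type is not inferred
println(result)
```

With an empty tuple it returns ((), 0), matching _f2() above.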


#4

I’m guessing it’s because in the second version it’s easier for the compiler to prove that the recursion is finite.

This reminded me of Jameson’s post in the thread “Efficient tuple concatenation”, and of the explanation given in the same thread.
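One way to see the difference is to print the type of the tuple each recursive call receives. The instrumented copies _f_trace and _f2_trace below are hypothetical and only add println calls: in the accumulating version the type of ys gains a parameter on every call, while in the head/tail version xs loses one. My understanding, which is an assumption and not something established in this thread, is that inference’s recursion limiting widens an argument whose type keeps growing, which would explain the Tuple{Tuple,Float64} result in the first post.

```julia
f(x::Real) = x + 1, abs2(x)
f(x::AbstractArray) = x .+ 1, sum(abs2, x)

# Accumulating version: ys grows by one element per call.
# (Assumption: this growth is what triggers inference's widening to Tuple.)
function _f_trace(acc, ys, x, xs...)
    println("accumulating: ys::", typeof(ys))
    y, s = f(x)
    _f_trace(acc + s, (ys..., y), xs...)
end
_f_trace(acc, ys) = (println("accumulating: ys::", typeof(ys)); (ys, acc))

# Head/tail version: the remaining arguments xs shrink by one per call.
function _f2_trace(x, xs...)
    println("head/tail:    xs::", typeof(xs))
    y, s = f(x)
    ys, ss = _f2_trace(xs...)
    (y, ys...), s + ss
end
_f2_trace() = (println("head/tail:    (no arguments left)"); ((), 0))

t = (-1, [-2], [-3.0, -4.0])
r1 = _f_trace(0, (), t...)
r2 = _f2_trace(t...)
```

Both versions compute the same value; they differ only in how the argument types evolve across the recursion.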