This question is adapted from one that appeared in the JuliaLang Zulip helpdesk.
Suppose I have a function that takes a heterogeneously typed Tuple and returns a slice of that tuple, where the slice indices can be determined statically from type information alone. How can I write the function so that the output type is correctly inferred?
For example, suppose my function is
```julia
function f(t::Tuple, A::Array{T, N}) where {T, N}
    if T <: AbstractFloat
        imin = 1
    elseif T <: Integer
        imin = 2
    else
        imin = 3
    end
    imax = N+2
    t[imin:imax]
end
```
Type inference only figures out that this returns a Tuple, not its length or element types, even though all the needed information is available at compile time:
```julia
julia> let t = (:a, "b", 2, 3.0, Val(1), 2+im), A = rand(Int, 3,3)
           Base.return_types(f, Tuple{typeof(t), typeof(A)})
       end
1-element Array{Any,1}:
 Tuple
```
The strategy I'd be most comfortable with (but maybe there's an easier way?) is to write a @generated function that makes Julia perform the type-level computations I want at compile time:
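```julia
@generated function f2(t::Tuple, A::Array{T, N}) where {T, N}
    # This body runs at compile time, when T and N are known constants.
    if T <: AbstractFloat
        imin = 1
    elseif T <: Integer
        imin = 2
    else
        imin = 3
    end
    imax = N+2
    # Return the expression (t[imin], ..., t[imax]) with literal indices.
    out_expr = Expr(:tuple, (:(t[$i]) for i ∈ imin:imax)...)
end
```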
The idea is that the body of a generated function runs at compile time, when T and N are known, so we can compute imin and imax there and then build the expression for the actual function body, which reads (t[imin], t[imin+1], ..., t[imax-1], t[imax]) with each index spliced in as a literal.
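To see what the generated body hands back to the compiler, the expression-building step can be run on its own, outside any generated function. This is a minimal sketch: the helper name `slice_expr` is hypothetical, and `t` inside the quoted expression is just a symbol that will refer to the function argument once the expression becomes a method body.

```julia
# Hypothetical helper: builds the same kind of expression as the body of f2,
# but for explicit bounds instead of bounds derived from T and N.
slice_expr(imin::Int, imax::Int) = Expr(:tuple, (:(t[$i]) for i in imin:imax)...)

ex = slice_expr(2, 4)
# ex is the expression (t[2], t[3], t[4]): a tuple of getindex calls whose
# indices are literal integers rather than variables.
println(ex)
```

Because every index is a literal, the compiler sees each `t[i]` call with a constant argument, which is what the generated function relies on.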
For whatever reason, Julia is better able to reason about a sequence of getindex(::Tuple, ::Int) calls with constant indices than about slicing a tuple, even with a statically known slice, so by building this expression manually we get the inference result we want:
```julia
julia> let t = (:a, "b", 2, 3.0, Val(1), 2+im), A = rand(Int, 3,3)
           Base.return_types(f2, Tuple{typeof(t), typeof(A)})
       end
1-element Array{Any,1}:
 Tuple{String,Int64,Float64}
```
Voila, the inferred output type is a Tuple of length 3 whose elements are statically known to be a String, an Int and a Float64!
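The claim that constant-index `getindex` calls infer well can be checked without a generated function at all, by writing the literal-index tuple by hand. A minimal sketch; `explicit_slice` and `runtime_slice` are hypothetical names used only for this comparison, and the exact type printed for the runtime slice may differ between Julia versions.

```julia
# Literal indices: inference can resolve each element type individually.
explicit_slice(t::Tuple) = (t[2], t[3], t[4])

# Runtime bounds: the slice length and element types are not known statically.
runtime_slice(t::Tuple, i::Int, j::Int) = t[i:j]

t = (:a, "b", 2, 3.0, Val(1), 2+im)
rt = Base.return_types(explicit_slice, Tuple{typeof(t)})
println(rt)
println(Base.return_types(runtime_slice, Tuple{typeof(t), Int, Int}))
```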
Oh interesting! I tried something very similar initially (almost identical code) and inference gave up:
```julia
function f4(t::Tuple, A::Array{T,N}) where {T,N}
    if T <: AbstractFloat
        imin = 1
    elseif T <: Integer
        imin = 2
    else
        imin = 3
    end
    imax = N+2
    ntuple(i -> t[imin+i-1], imax-imin+1)
end
```
```julia
julia> let t = (:a, "b", 2, 3.0, Val(1), 2+im), A = rand(Int, 3,3)
           @show Base.return_types(f4, Tuple{typeof(t), typeof(A)})
           @btime f4(t, A) setup=(t = (:a, "b", 2, 3.0, Val(1), 2+im); A = rand(Int, 3,3);)
       end
Base.return_types(f4, Tuple{typeof(t), typeof(A)}) = Any[Any]
  1.480 μs (4 allocations: 96 bytes)
("b", 2, 3.0)
```
whereas your @generated approach (named f3 in my session) does indeed work:
```julia
julia> let t = (:a, "b", 2, 3.0, Val(1), 2+im), A = rand(Int, 3,3)
           @show Base.return_types(f3, Tuple{typeof(t), typeof(A)})
           @btime f3(t, A) setup=(t = (:a, "b", 2, 3.0, Val(1), 2+im); A = rand(Int, 3,3);)
       end
Base.return_types(f3, Tuple{typeof(t), typeof(A)}) = Any[Tuple{String,Int64,Float64}]
  5.480 ns (1 allocation: 32 bytes)
("b", 2, 3.0)
```
That sensitivity of inference to small changes in the source makes me lean towards generated functions.
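One possible contributor to f4 giving up (a guess, not something established in this thread) is that `imin` is assigned in several branches and then captured by the closure passed to `ntuple`; the Julia performance tips warn that captured variables which are reassigned can end up boxed with no known type. Separately, `ntuple(f, n)` with an `Int` count does not carry the length in the type. The sketch below avoids both by binding `imin` exactly once and passing the length as a `Val`. The name `f4_single` is hypothetical, and whether it infers a concrete type on a given Julia version should be checked with `Base.return_types` rather than assumed.

```julia
# Hypothetical variant of f4: imin is bound exactly once (no reassignment
# inside branches), and the output length is passed to ntuple as a Val.
function f4_single(t::Tuple, A::Array{T,N}) where {T,N}
    imin = T <: AbstractFloat ? 1 : T <: Integer ? 2 : 3
    imax = N + 2
    ntuple(i -> t[imin + i - 1], Val(imax - imin + 1))
end

t = (:a, "b", 2, 3.0, Val(1), 2+im)
A = rand(Int, 3, 3)
println(f4_single(t, A))
# Inspect the inferred type on your own Julia version:
println(Base.return_types(f4_single, Tuple{typeof(t), typeof(A)}))
```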
I'd be curious whether the handful of special-cased UnitRanges in #31138 would be sufficient for the original questioner. In short, it special-cases t[1:end], t[2:end], t[3:end], t[1:end-1], and t[1:end-2] so that those are inferred.
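Independently of that PR, Base already provides inferrable building blocks for dropping elements from either end: `Base.tail(t)` drops the first element and `Base.front(t)` drops the last. A minimal sketch composing them to mimic `t[3:end]` and `t[1:end-2]`; the helper names are hypothetical.

```julia
# Hypothetical helpers built from Base.tail / Base.front, whose results have
# a length and element types determined by the input tuple type.
drop_first2(t::Tuple) = Base.tail(Base.tail(t))    # like t[3:end]
drop_last2(t::Tuple)  = Base.front(Base.front(t))  # like t[1:end-2]

t = (:a, "b", 2, 3.0, Val(1), 2+im)
println(drop_first2(t))
println(drop_last2(t))
println(Base.return_types(drop_first2, Tuple{typeof(t)}))
```

This only covers fixed drops from the ends; a slice whose bounds depend on `T` and `N`, as in the original question, still needs one of the other approaches.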
```julia
julia> function f(t::Tuple, A::Array{T, N}) where {T, N}
           # `static` wraps an Int as a compile-time constant; it is assumed to come
           # from a package such as Static.jl (the post does not say which)
           if T <: AbstractFloat
               imin = static(1)
           elseif T <: Integer
               imin = static(2)
           else
               imin = static(3)
           end
           imax = static(N+2)
           t[imin:imax]
       end
f (generic function with 1 method)

julia> Base.return_types(f, Tuple{typeof((:a, "b", 2, 3.0, Val(1), 2+im)), Array{Int,4}})
1-element Array{Any,1}:
 Tuple{String,Int64,Float64,Val{1},Complex{Int64}}
```
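If pulling in a package for static integers is not desirable, a similar effect can be had with Base's `Val`: move the slice bounds into the type domain and reuse the @generated pattern from f2. A minimal sketch; `static_slice` is a hypothetical name, not a function from any package.

```julia
# Hypothetical helper: imin and imax are type parameters, so the generator can
# splice them in as literal indices, exactly as f2 does.
@generated function static_slice(t::Tuple, ::Val{imin}, ::Val{imax}) where {imin, imax}
    # Inside the generator, `t` is bound to the tuple's *type*, so its length is known.
    n = fieldcount(t)
    if !(1 <= imin && imax <= n)
        # Emit code that throws at run time instead of failing during generation.
        return :(throw(BoundsError(t, $imin:$imax)))
    end
    return Expr(:tuple, (:(t[$i]) for i in imin:imax)...)
end

t = (:a, "b", 2, 3.0, Val(1), 2+im)
println(static_slice(t, Val(2), Val(6)))
println(Base.return_types(static_slice, Tuple{typeof(t), Val{2}, Val{6}}))
```

Callers still have to produce the `Val`s, e.g. `Val(imin)` where `imin` is a compile-time constant, which plays the same role as `static(...)` above.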