Avoiding type instability when slicing a tuple

This question is an adapted version of one that appeared in the JuliaLang Zulip helpdesk.


Suppose I have a function that takes in a heterogeneously typed Tuple and will return a slice of that tuple where the slice indices may be statically inferred from only type information. How can I write my function in such a way that the output type is correctly inferred?

For example, suppose my function is

function f(t::Tuple, A::Array{T, N}) where {T, N}
    if T <: AbstractFloat
        imin = 1
    elseif T <: Integer
        imin = 2
    else
        imin = 3
    end
    imax = N+2    
    t[imin:imax]
end 

Then we see that type inference only figures out that this produces a Tuple, not its length or element types, even though all the needed information is available at compile time:

julia> let t = (:a, "b", 2, 3.0, Val(1), 2+im), A = rand(Int, 3,3)
           Base.return_types(f, Tuple{typeof(t), typeof(A)})
       end
1-element Array{Any,1}:
 Tuple

How can I write f such that this works?


The strategy I’d be most comfortable with (but maybe there’s an easier way?) is to write a @generated function to manually ensure julia does the type-level operations I want at compile time:

@generated function f2(t::Tuple, A::Array{T, N}) where {T, N}
    if T <: AbstractFloat
        imin = 1
    elseif T <: Integer
        imin = 2
    else
        imin = 3
    end
    imax = N+2
    out_expr = Expr(:tuple, (:(t[$i]) for i ∈ imin:imax)...)
end 

The idea here is that in the generated function body, at compile time, we determine what imin and imax are, and then we manually build out an expression for our function body that reads (t[imin], t[imin+1], ..., t[imax-1], t[imax]).
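As a quick sanity check, you can run the same expression-building line by hand for the example arguments from the question (T = Int and N = 2, so imin = 2 and imax = 4) and look at the expression the generator will return:

julia> imin, imax = 2, 4;

julia> Expr(:tuple, (:(t[$i]) for i ∈ imin:imax)...)
:((t[2], t[3], t[4]))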

For whatever reason, julia is better able to reason about a sequence of getindex(::Tuple, ::Int) calls than it is about slicing a tuple, even with a statically known slice, so by manually building this expression, the compiler is able to do what we want:

julia> let t = (:a, "b", 2, 3.0, Val(1), 2+im), A = rand(Int, 3,3)
           Base.return_types(f2, Tuple{typeof(t), typeof(A)})
       end
1-element Array{Any,1}:
 Tuple{String,Int64,Float64}

Voila, the inferred output type is a Tuple of length 3 whose elements are statically known to be a String, an Int64 and a Float64!


Here is one way to do this without @generated, relying on the good behaviour of ntuple in such situations:

julia> function f3(t::Tuple, A::Array{T,N}) where {T,N}
           min = T <: AbstractFloat ? 1 :
                 T <: Integer ? 2 : 3
           max = N+2
           ntuple(i -> t[min+i-1], max-min+1)
       end

julia> @btime f2(t,A) setup=(t = (:a, "b", 2, 3.0, Val(1), 2+im); A = rand(Int, 3,3);)
  8.257 ns (1 allocation: 32 bytes)
("b", 2, 3.0)

julia> @btime f3(t,A) setup=(t = (:a, "b", 2, 3.0, Val(1), 2+im); A = rand(Int, 3,3);)
  8.223 ns (1 allocation: 32 bytes)
("b", 2, 3.0)

While @code_lowered f3(t, A) is longer than for f2, the output of @code_llvm appears to be identical.


Oh interesting! I tried something very similar initially (almost the identical code) and inference gave up:

function f4(t::Tuple, A::Array{T,N}) where {T,N}
    if T <: AbstractFloat
        imin = 1
    elseif T <: Integer
        imin = 2
    else
        imin = 3
    end
    imax = N+2
    ntuple(i -> t[imin+i-1], imax-imin+1)
end
julia> let t = (:a, "b", 2, 3.0, Val(1), 2+im), A = rand(Int, 3,3)
           @show Base.return_types(f4, Tuple{typeof(t), typeof(A)})
           @btime f4(t, A)  setup=(t = (:a, "b", 2, 3.0, Val(1), 2+im); A = rand(Int, 3,3);)
       end
Base.return_types(f4, Tuple{typeof(t), typeof(A)}) = Any[Any]
  1.480 ΞΌs (4 allocations: 96 bytes)
("b", 2, 3.0)

whereas your code does indeed work:

julia> let t = (:a, "b", 2, 3.0, Val(1), 2+im), A = rand(Int, 3,3)
           @show Base.return_types(f3, Tuple{typeof(t), typeof(A)})
           @btime f3(t, A)  setup=(t = (:a, "b", 2, 3.0, Val(1), 2+im); A = rand(Int, 3,3);)
       end
Base.return_types(f3, Tuple{typeof(t), typeof(A)}) = Any[Tuple{String,Int64,Float64}]
  5.480 ns (1 allocation: 32 bytes)
("b", 2, 3.0)

That kind of fragility, where a slight change to the source makes inference give up, makes me want to lean towards generated functions.


On Julia 1.0, neither of these infer, but they do differ:

Base.return_types(f3, Tuple{typeof(t), typeof(A)}) = Any[Tuple{Any,Any,Any}]
Base.return_types(f4, Tuple{typeof(t), typeof(A)}) = Any[Any]

I’d be curious if the handful of special-cased unitranges in #31138 would be sufficient for the original questioner. In short, it special-cases t[1:end], t[2:end] , t[3:end] , t[1:end-1] , and t[1:end-2] such that those are inferred.
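To sketch how that would apply here (a hypothetical f5, and only under the assumption that the desired slice runs all the way to the end of the tuple, which the original f does not guarantee since imax = N+2):

function f5(t::Tuple, A::Array{T, N}) where {T, N}
    if T <: AbstractFloat
        t[1:end]    # one of the special-cased literal slices
    elseif T <: Integer
        t[2:end]    # special-cased
    else
        t[3:end]    # special-cased
    end
end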

I plan on merging that in time for 1.5.


For generic slicing, I’d try StaticNumbers.jl.


I tried StaticRanges.jl already, but it didn’t seem to help, and I have doubts that StaticNumbers.jl would.

It works for me:

julia> function f(t::Tuple, A::Array{T, N}) where {T, N}
           if T <: AbstractFloat
               imin = static(1)
           elseif T <: Integer
               imin = static(2)
           else
               imin = static(3)
           end
           imax = static(N+2)
           t[imin:imax]
       end
f (generic function with 1 method)

julia> Base.return_types(f, Tuple{typeof((:a, "b", 2, 3.0, Val(1), 2+im)), Array{Int,4}})
1-element Array{Any,1}:
 Tuple{String,Int64,Float64,Val{1},Complex{Int64}}
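For reference, a short sketch of why static indices help (this assumes the static above comes from StaticNumbers.jl, the package named in the previous reply): the numeric value is carried in the argument's type, so imin:imax, and therefore the slice t[imin:imax], is fully determined by the input types, which is exactly what inference needs, while at runtime the values still behave like ordinary integers.

using StaticNumbers

imin = static(2)                 # the value 2, but carried in typeof(imin)
imax = static(4)
@assert imin == 2 && imax == 4   # compares equal to plain Ints at runtime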

Inference gave up because of the closure-capture bug: it boxed imin. Fixing it with a let works:

julia> function f4(t::Tuple, A::Array{T,N}) where {T,N}
           if T <: AbstractFloat
               imin = 1
           elseif T <: Integer
               imin = 2
           else
               imin = 3
           end
           imax = N+2
           let imin = imin
               ntuple(i -> t[imin+i-1], imax-imin+1)
           end
       end
f4 (generic function with 1 method)

julia> f4(t,A)
("b", 2, 3.0)

julia> @code_warntype f4(t,A)
Variables
  #self#::Core.Compiler.Const(f4, false)
  t::Tuple{Symbol,String,Int64,Float64,Val{1},Complex{Int64}}
  A::Array{Int64,2}
  imin@_4::Int64
  imax::Int64
  #34::var"#34#35"{Tuple{Symbol,String,Int64,Float64,Val{1},Complex{Int64}},Int64}
  imin@_7::Int64

Body::Tuple{String,Int64,Float64}
1 ─       Core.NewvarNode(:(imin@_4))
β”‚         Core.NewvarNode(:(imax))
β”‚   %3  = ($(Expr(:static_parameter, 1)) <: Main.AbstractFloat)::Core.Compiler.Const(false, false)
└──       goto #3 if not %3
2 ─       Core.Compiler.Const(:(imin@_4 = 1), false)
└──       Core.Compiler.Const(:(goto %12), false)
3 β”„ %7  = ($(Expr(:static_parameter, 1)) <: Main.Integer)::Core.Compiler.Const(true, false)
β”‚         %7
β”‚         (imin@_4 = 2)
└──       goto #5
4 ─       Core.Compiler.Const(:(imin@_4 = 3), false)
5 β”„       (imax = $(Expr(:static_parameter, 2)) + 2)
β”‚   %13 = imin@_4::Core.Compiler.Const(2, false)::Core.Compiler.Const(2, false)
β”‚         (imin@_7 = %13)
β”‚   %15 = Main.:(var"#34#35")::Core.Compiler.Const(var"#34#35", false)
β”‚   %16 = Core.typeof(t)::Core.Compiler.Const(Tuple{Symbol,String,Int64,Float64,Val{1},Complex{Int64}}, false)
β”‚   %17 = Core.typeof(imin@_7::Core.Compiler.Const(2, false))::Core.Compiler.Const(Int64, false)
β”‚   %18 = Core.apply_type(%15, %16, %17)::Core.Compiler.Const(var"#34#35"{Tuple{Symbol,String,Int64,Float64,Val{1},Complex{Int64}},Int64}, false)
β”‚         (#34 = %new(%18, t, imin@_7::Core.Compiler.Const(2, false)))
β”‚   %20 = #34::Core.Compiler.PartialStruct(var"#34#35"{Tuple{Symbol,String,Int64,Float64,Val{1},Complex{Int64}},Int64}, Any[Tuple{Symbol,String,Int64,Float64,Val{1},Complex{Int64}}, Core.Compiler.Const(2, false)])::Core.Compiler.PartialStruct(var"#34#35"{Tuple{Symbol,String,Int64,Float64,Val{1},Complex{Int64}},Int64}, Any[Tuple{Symbol,String,Int64,Float64,Val{1},Complex{Int64}}, Core.Compiler.Const(2, false)])
β”‚   %21 = (imax::Core.Compiler.Const(4, false) - imin@_7::Core.Compiler.Const(2, false))::Core.Compiler.Const(2, false)
β”‚   %22 = (%21 + 1)::Core.Compiler.Const(3, false)
β”‚   %23 = Main.ntuple(%20, %22)::Tuple{String,Int64,Float64}
└──       return %23

julia> @btime f4($t,$A)
  5.869 ns (1 allocation: 32 bytes)
("b", 2, 3.0)
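For anyone who hasn't run into this before, here is a minimal sketch of the same pattern (hypothetical names; this is the well-known captured-variable issue julia#15276). A captured variable that is assigned by more than one statement gets boxed, which is also why f3 above was fine: its ternary assigns min exactly once. Rebinding the variable with let restores inference:

# x is assigned by two statements, so capturing it boxes it and
# the closure's result is no longer inferred concretely
function boxed_capture(flag::Bool)
    if flag
        x = 1
    else
        x = 2
    end
    h = i -> x + i   # x is captured as a Core.Box here
    h(1)
end

# Rebinding x in a let introduces a new, singly-assigned local,
# so the capture stays concrete and inference works again
function unboxed_capture(flag::Bool)
    if flag
        x = 1
    else
        x = 2
    end
    let x = x
        h = i -> x + i
        h(1)
    end
end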

Cheers!


I see, good find! That explains why @tkf’s solution worked as well!