Type inference with a tuple slice

Hi,

I’m playing around with NTuple to get more familiar with variable-length function arguments, etc. In the process, I ran into an example where type inference fails. Here is a minimal example:

function test()
    tuple = (2,2,2)
    return tuple[1:2]
end

@code_warntype test()

This outputs

@code_warntype BLG_DMRG.test()
Variables
  #self#::Core.Compiler.Const(BLG_DMRG.test, false)
  tuple::Tuple{Int64,Int64,Int64}

Body::Tuple{Vararg{Int64,N} where N}
1 ─      (tuple = Core.tuple(2, 2, 2))
│   %2 = tuple::Core.Compiler.Const((2, 2, 2), false)::Core.Compiler.Const((2, 2, 2), false)
│   %3 = (1:2)::Core.Compiler.Const(1:2, false)
│   %4 = Base.getindex(%2, %3)::Tuple{Vararg{Int64,N} where N}
└──      return %4

Even though the compiler knows the value of the slice range (1:2), it still cannot figure out the length of the output tuple. Why is that, and is there a way to make type inference succeed?

(This issue came up while I was playing around with the following code, which contracts the last index of the first tensor with the first index of the second tensor. The output type is fully determined by the input types (given an Array{T,N} and an Array{T,M}, the output is an Array{T,N+M-2}), but the compiler cannot seem to figure that out, apparently because of a similar tuple type inference failure. I added the result of @code_warntype for this below.)

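function contract_last_and_first(W::AbstractArray{<:Any,N}, Z::AbstractArray{<:Any,M}) where {N,M}
    # N and M are now bound from the argument types. In the original post they were not
    # defined inside the function, so they were looked up as globals; that is why (N - 1)
    # and (2:M) are inferred as ::Any in the @code_warntype output below.
    dimensions1 = size(W)
    dimensions2 = size(Z)
    W2 = reshape(W, prod(dimensions1[1:N-1]), dimensions1[end])
    Z2 = reshape(Z, dimensions2[1], prod(dimensions2[2:M]))
    Tensor = W2 * Z2
    reshape(Tensor, (dimensions1[1:end-1]..., dimensions2[2:end]...))
end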

W = rand(3,3,3)
Z = rand(3,3,3,3)
@code_warntype contract_last_and_first(W,Z)
Variables
  #self#::Core.Compiler.Const(contract_last_and_first, false)
  W::Array{Float64,3}
  Z::Array{Float64,4}
  dimensions1::Tuple{Int64,Int64,Int64}
  dimensions2::NTuple{4,Int64}
  W2::Array{Float64,2}
  Z2::Array{Float64,2}
  Tensor::Array{Float64,2}

Body::Array
1 ─       (dimensions1 = size(W))
│         (dimensions2 = size(Z))
│   %3  = dimensions1::Tuple{Int64,Int64,Int64}
│   %4  = (N - 1)::Any
│   %5  = (1:%4)::Any
│   %6  = Base.getindex(%3, %5)::Any
│   %7  = prod(%6)::Any
│   %8  = dimensions1::Tuple{Int64,Int64,Int64}
│   %9  = Base.lastindex(dimensions1)::Core.Compiler.Const(3, false)
│   %10 = Base.getindex(%8, %9)::Int64
│         (W2 = reshape(W, %7, %10))
│   %12 = Base.getindex(dimensions2, 1)::Int64
│   %13 = dimensions2::NTuple{4,Int64}
│   %14 = (2:M)::Any
│   %15 = Base.getindex(%13, %14)::Any
│   %16 = prod(%15)::Any
│         (Z2 = reshape(Z, %12, %16))
│         (Tensor = W2 * Z2)
│   %19 = Tensor::Array{Float64,2}
│   %20 = dimensions1::Tuple{Int64,Int64,Int64}
│   %21 = Base.lastindex(dimensions1)::Core.Compiler.Const(3, false)
│   %22 = (%21 - 1)::Core.Compiler.Const(2, false)
│   %23 = (1:%22)::Core.Compiler.Const(1:2, false)
│   %24 = Base.getindex(%20, %23)::Tuple{Vararg{Int64,N} where N}
│   %25 = dimensions2::NTuple{4,Int64}
│   %26 = Base.lastindex(dimensions2)::Core.Compiler.Const(4, false)
│   %27 = (2:%26)::Core.Compiler.Const(2:4, false)
│   %28 = Base.getindex(%25, %27)::Tuple{Vararg{Int64,N} where N}
│   %29 = Core._apply(Core.tuple, %24, %28)::Tuple{Vararg{Int64,N} where N}
│   %30 = reshape(%19, %29)::Array
└──       return %30

Thanks in advance!
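A short illustration of why the length gets lost (a sketch; the exact printed types can differ between Julia versions): a range's type does not record how many elements it contains, so a getindex method for a tuple and a UnitRange, inferred from the argument types alone, cannot know the length of its result. Recovering it for a literal such as 1:2 would require constant propagation through that method, which evidently did not happen in the @code_warntype output above.

```julia
# A range's type carries no information about its length:
r2 = 1:2
r3 = 1:3
println(typeof(r2) === typeof(r3))  # true: both are UnitRange{Int}

# Infer getindex from argument *types* only, with no constant values available:
rt = Base.return_types(getindex, (NTuple{3,Int}, UnitRange{Int}))[1]
println(rt)                   # a tuple type whose length is left open
println(isconcretetype(rt))   # false
```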

While we are waiting for the explanation, you can actually circumvent this problem by giving the compiler a hint with a type assertion (I was actually surprised that you can call a function inside the type annotation without any performance issues):

function test()
    tuple = (2,2,2)
    return tuple[1:2]
end

function test2()
    test() .+ test()
end

function testa()
    tuple = (2,2,2)
    return tuple[1:2]::NTuple{length(1:2), Int}
end

function test2a()
    testa() .+ testa()
end

function testb()
    tuple = (2,2,2)
    return tuple[1:2]::NTuple{2, Int}
end

function test2b()
    testb() .+ testb()
end

I used the test2* functions to see how the problem propagates:

@code_warntype test2()
Body::Any
1 ─ %1 = Main.test()::Tuple{Vararg{Int64,N} where N}
│   %2 = Main.test()::Tuple{Vararg{Int64,N} where N}
│   %3 = Base.broadcasted(Main.:+, %1, %2)::Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(+),_A} where _A<:Tuple
│   %4 = Base.materialize(%3)::Any
└──      return %4

@code_warntype test2a()
Body::Tuple{Int64,Int64}
1 ─ %1 = Main.testa()::Tuple{Int64,Int64}
│   %2 = Main.testa()::Tuple{Int64,Int64}
│   %3 = Base.broadcasted(Main.:+, %1, %2)::Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(+),Tuple{Tuple{Int64,Int64},Tuple{Int64,Int64}}}
│   %4 = Base.materialize(%3)::Tuple{Int64,Int64}
└──      return %4

And the performance:

using BenchmarkTools  # provides @btime
@btime test2()  # 1.385 μs (7 allocations: 384 bytes)
@btime test2a() # 967.667 ns (4 allocations: 256 bytes)
@btime test2b() # 972.800 ns (4 allocations: 256 bytes)
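A related option that does not come from this thread: move the slice length into the type domain with Val, so that ntuple(f, Val(K)) builds a tuple whose length the compiler knows without any assertion. The helper firstn and the function test_val are hypothetical names used only for this sketch:

```julia
# Hypothetical helper: the first K elements of t, with K carried in the type as Val{K}.
# K must not exceed length(t), otherwise t[i] throws a BoundsError.
firstn(t::Tuple, ::Val{K}) where {K} = ntuple(i -> t[i], Val(K))

function test_val()
    t = (2, 2, 2)
    return firstn(t, Val(2))
end

println(test_val())                                           # (2, 2)
println(Base.return_types(test_val, ())[1] == NTuple{2,Int})  # true
```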

You may also consider using the Base functions first, Base.tail, and sometimes the odd-looking reverse(Base.tail(reverse(tup))). These are type stable.
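Applied to small tuples like the one in the question, these helpers behave as follows (a sketch; Base.front and Base.tail are not exported, hence the Base. prefix, and sliced is a hypothetical name):

```julia
t = (1, 2, 3)

@show first(t)                        # 1
@show last(t)                         # 3
@show Base.tail(t)                    # (2, 3): drops the first element
@show Base.front(t)                   # (1, 2): drops the last element
@show reverse(Base.tail(reverse(t)))  # (1, 2): the reverse/tail way to drop the last element

# The tuple[1:2] example from the question, written with Base.front instead:
sliced() = Base.front((2, 2, 2))
@show Base.return_types(sliced, ())[1] == NTuple{2,Int}  # true
```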

NTuple{length(1:2), Int} is a pretty neat trick! Thanks! For now I will use type hinting as suggested.

Ah, thanks! It looks like first, last, Base.tail, and Base.front are exactly what I need for my particular use case.
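Applied to the contraction from the original question, a sketch using these helpers might look like the following. The name contract_last_and_first_stable is hypothetical, and the sketch assumes both arrays have at least two dimensions. Because Base.front and Base.tail return tuples whose lengths follow from the input tuple types, the N+M-2 dimensionality of the result should be inferable, which @inferred (from the Test standard library) checks:

```julia
using LinearAlgebra  # matrix multiplication for the reshaped factors
using Test           # provides @inferred

# Hypothetical type-stable variant: contracts the last index of W with the first index of Z.
# Assumes ndims(W) >= 2 and ndims(Z) >= 2.
function contract_last_and_first_stable(W::AbstractArray, Z::AbstractArray)
    dims1 = size(W)
    dims2 = size(Z)
    outer1 = Base.front(dims1)   # all dimensions of W except the last
    outer2 = Base.tail(dims2)    # all dimensions of Z except the first
    W2 = reshape(W, prod(outer1), last(dims1))
    Z2 = reshape(Z, first(dims2), prod(outer2))
    return reshape(W2 * Z2, (outer1..., outer2...))
end

W = rand(3, 3, 3)
Z = rand(3, 3, 3, 3)
R = @inferred contract_last_and_first_stable(W, Z)  # throws if the return type is not inferred
println(size(R))  # (3, 3, 3, 3, 3)
```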

For future reference: instead of the reverse/tail trick, one may use a more straightforward approach with @generated functions:

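@generated function droplast(t::NTuple{N}) where N
    # The body runs at compile time with N known and returns the expression
    # (t[1], ..., t[N-1]), built up one element at a time.
    ex = :()
    for i in 1:N-1
        ex = :($ex..., t[$i])
    end
    return ex
end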

droplast((1, 2, 3, 4)) # (1, 2, 3)

It is type stable and fast

@code_typed droplast((1, 2, 3, 4))

CodeInfo(
1 ─ %1 = Base.getfield(t, 1, true)::Int64
│   %2 = Base.getfield(t, 2, true)::Int64
│   %3 = Base.getfield(t, 3, true)::Int64
│   %4 = Core.tuple(%1, %2, %3)::Tuple{Int64,Int64,Int64}
└──      return %4
) => Tuple{Int64,Int64,Int64}

Generally one would avoid generated functions whenever possible (and here it is possible). See

Thank you for the link! It is very informative. But how can one avoid generated functions in this case, other than with the reverse/tail trick, which looks odd?

It may look unfamiliar, but it is a very idiomatic solution for these kinds of problems. Check out

Just writing it explicitly with ntuple often works fine; these seem pretty much indistinguishable:

droplast2(t) = ntuple(i -> t[i], length(t)-1)

droplast3(t) = reverse(Base.tail(reverse(t)))

@code_typed droplast2((1, 2, 3, 4)) 
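A quick check that droplast2 and droplast3 agree, plus a variant (droplast4, a hypothetical name) that derives the output length from the tuple's type parameter via Val(N - 1) instead of relying on constant propagation of length(t) - 1. This is a sketch, and it requires a non-empty tuple:

```julia
droplast2(t) = ntuple(i -> t[i], length(t) - 1)
droplast3(t) = reverse(Base.tail(reverse(t)))

# Hypothetical variant: N is a static parameter, so Val(N - 1) is a compile-time constant.
droplast4(t::NTuple{N,Any}) where {N} = ntuple(i -> t[i], Val(N - 1))

t = (1, 2, 3, 4)
println(droplast2(t))  # (1, 2, 3)
println(droplast3(t))  # (1, 2, 3)
println(droplast4(t))  # (1, 2, 3)
println(Base.return_types(droplast4, (NTuple{4,Int},))[1] == NTuple{3,Int})  # true
```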

ntuple is amazing! Thank you!