Type unstable function returning a tuple

Good evening,

I have been facing a (slightly) annoying type inference issue when writing a function which returns a tuple.

Here is a minimal working example:

f(v, ::Val{N}) where {N} = ntuple(i -> sum(v[1]), N) # type unstable
g(v, ::Val{N}) where {N} = ntuple(i -> v[1][1], N) # type stable

For instance, set v = [[1,2,3], [4,5,6]].

We have:

julia> @code_warntype f(v, Val(3))
Variables
  #self#::Core.Compiler.Const(f, false)
  v::Array{Array{Int64,1},1}
  #unused#::Core.Compiler.Const(Val{3}(), false)
  #11::var"#11#12"{Array{Array{Int64,1},1}}

Body::Tuple{Vararg{Int64,N} where N}
1 ─ %1 = Main.:(var"#11#12")::Core.Compiler.Const(var"#11#12", false)
│   %2 = Core.typeof(v)::Core.Compiler.Const(Array{Array{Int64,1},1}, false)
│   %3 = Core.apply_type(%1, %2)::Core.Compiler.Const(var"#11#12"{Array{Array{Int64,1},1}}, false)
│        (#11 = %new(%3, v))
│   %5 = #11::var"#11#12"{Array{Array{Int64,1},1}}
│   %6 = Main.ntuple(%5, $(Expr(:static_parameter, 1)))::Tuple{Vararg{Int64,N} where N}
└──      return %6

However, g is type stable:

julia> @code_warntype g(v, Val(3))
Variables
  #self#::Core.Compiler.Const(g, false)
  v::Array{Array{Int64,1},1}
  #unused#::Core.Compiler.Const(Val{3}(), false)
  #9::var"#9#10"{Array{Array{Int64,1},1}}

Body::Tuple{Int64,Int64,Int64}
1 ─ %1 = Main.:(var"#9#10")::Core.Compiler.Const(var"#9#10", false)
│   %2 = Core.typeof(v)::Core.Compiler.Const(Array{Array{Int64,1},1}, false)
│   %3 = Core.apply_type(%1, %2)::Core.Compiler.Const(var"#9#10"{Array{Array{Int64,1},1}}, false)
│        (#9 = %new(%3, v))
│   %5 = #9::var"#9#10"{Array{Array{Int64,1},1}}
│   %6 = Main.ntuple(%5, $(Expr(:static_parameter, 1)))::Tuple{Int64,Int64,Int64}
└──      return %6
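You can script the same comparison instead of reading it off @code_warntype. Here is a minimal sketch using Base.return_types, which runs inference for a given argument-type signature. The names rt_f and rt_g are illustrative. Results on newer Julia versions may differ from the 1.5.2 output above, because constant propagation has changed since then. For that reason the sketch prints the inferred types rather than asserting stability.

```julia
v = [[1, 2, 3], [4, 5, 6]]
f(v, ::Val{N}) where {N} = ntuple(i -> sum(v[1]), N)  # reported unstable on Julia 1.5.2
g(v, ::Val{N}) where {N} = ntuple(i -> v[1][1], N)    # reported stable

# Base.return_types infers the result type for the given argument types.
rt_f = only(Base.return_types(f, (typeof(v), Val{3})))
rt_g = only(Base.return_types(g, (typeof(v), Val{3})))

println("f: ", rt_f, "  concrete? ", isconcretetype(rt_f))
println("g: ", rt_g, "  concrete? ", isconcretetype(rt_g))
```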

In my code, the tuple is used in a loop, so I worked around this type instability with a function barrier.
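The post does not show the loop itself. Here is a minimal sketch of the function-barrier pattern, assuming the tuple length is only available as a run-time Int. The names make_tuple, accumulate_tuple and outer_loop are hypothetical.

```julia
# Hypothetical producer: n is a run-time value, so the result type of this
# call cannot be inferred concretely by the caller.
make_tuple(v, n::Int) = ntuple(i -> sum(v[1]), n)

# Hypothetical kernel behind the barrier: compiled separately for each
# concrete NTuple{N,Int} it receives, so the loop body is type stable.
function accumulate_tuple(t::NTuple{N,Int}, iters::Int) where {N}
    s = 0
    for _ in 1:iters, x in t
        s += x
    end
    return s
end

# Outer function: pays for one dynamic dispatch at the call to
# accumulate_tuple instead of type-unstable work in every iteration.
function outer_loop(v, n::Int, iters::Int)
    t = make_tuple(v, n)
    return accumulate_tuple(t, iters)
end

v = [[1, 2, 3], [4, 5, 6]]
outer_loop(v, 3, 10)
```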

Still, why could the compiler not figure out the length of the tuple returned by f? After all, it knows what sum does, and the length N is given as a type parameter.

EDIT: I should mention that I am using Julia 1.5.2. Also, in case it is related, the function

h(v::Vector{Vector{Int}}) = ntuple(i -> sum(v[1]), 3)

is also type unstable.

I think that the problem is that the elements of v could in principle be of any type. If you construct v with a specified element type, what happens?

v = Vector{Int}[ .... ]

Sorry, I cannot test this right now. (Edit: looking closer, that does not seem to be the case.)
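There is a quick way to check the guess. The untyped literal already gets a concrete element type, as the v::Array{Array{Int64,1},1} line in the @code_warntype output suggests, so a typed literal changes nothing. A small sketch (v_untyped and v_typed are illustrative names):

```julia
# Untyped literal: Julia already infers a concrete element type.
v_untyped = [[1, 2, 3], [4, 5, 6]]

# Typed literal, as suggested: Vector{Int}[ ... ]
v_typed = Vector{Int}[[1, 2, 3], [4, 5, 6]]

println(typeof(v_untyped))
println(typeof(v_typed))
```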

I tried defining

f(v::Vector{Vector{Int}}, ::Val{N}) where {N} = ntuple(i -> sum(v[1]), N)

But this function remains unstable.

Again, sorry for just guessing. You may need to guarantee that the elements of v have dimension 1, and use static arrays for that. Or perhaps tell sum that you only want the sum over the first dimension.

No worries. Interestingly, it seems that you are right; the following function is type stable:

l(v::Vector{NTuple{N,Int}}, ::Val{M}) where {N,M} = ntuple(i -> sum(v[1]), M)

However, I still do not understand the problem with the other functions, f and h (cf. the edit in the original post). The compiler knows that, whatever the dimension or length, the result of sum(v[1]) is an Int64, no?

If v[1] were a vector of vectors, sum would return a vector:

sum([ones(3),ones(3)])

[2.0, 2.0, 2.0]
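You can check both halves of this exchange directly. The sketch below uses Test.@inferred, which throws if the inferred return type does not match the actual one. It shows that sum over a Vector{Int} infers as Int, while sum over a Vector{Vector{Int}} adds the inner vectors elementwise. This suggests that the instability in h comes from ntuple rather than from sum.

```julia
using Test

# sum over a Vector{Int}: @inferred errors unless inference yields the actual type (Int).
s = @inferred sum([1, 2, 3])

# sum over a Vector{Vector{Int}}: the inner vectors are added elementwise.
w = sum([[1, 2, 3], [4, 5, 6]])
```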

Agreed! However, given the definition of h, Julia knows that v::Vector{Vector{Int}}, so that v[1]::Vector{Int}, no?

In fact, isn’t that what @code_warntype shows us, i.e. Body::Tuple{Vararg{Int64,N} where N}? Perhaps I am not reading this properly… Julia seems to know that the result is a tuple of Int64 but does not figure out its length. Perhaps the compiler deliberately chooses not to specialize these functions?
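That reading of the Body line seems right. The sketch below compares the two inferred types with isconcretetype. Tuple{Vararg{Int}} is the same type that the 1.5 output prints as Tuple{Vararg{Int64,N} where N}, and it is abstract. By contrast, g's Tuple{Int64,Int64,Int64} is concrete.

```julia
unknown_len = Tuple{Vararg{Int}}  # "a tuple of Ints of unknown length", as inferred for f
known_len = NTuple{3,Int}         # Tuple{Int64,Int64,Int64}, as inferred for g

println(isconcretetype(unknown_len))  # false
println(isconcretetype(known_len))    # true
println(known_len <: unknown_len)     # true: the actual result fits the inferred type
```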

For the sake of completeness, in my case the function I use is closer to something like this:

foo(t::NTuple{N,NTuple{M,Int}}) where {N,M} =
    ntuple(i -> nextpow(2, mapreduce(tj -> tj[i], +, t)), M)

Again, to me it seems that everything is known at compile time… yet it is type unstable.


So I actually found a trick to fix the type stability issue: pass a Val to ntuple. The following versions of h, f and foo are all type stable:

h_(v::Vector{Vector{Int}}) = ntuple(i -> sum(v[1]), Val{3}())
f_(v, ::Val{N}) where {N} = ntuple(i -> sum(v[1]), Val{N}())
foo_(t::NTuple{N,NTuple{M,Int}}) where {N,M} =
    ntuple(i -> nextpow(2, mapreduce(tj -> tj[i], +, t)), Val{M}())

(take v = [[1,2,3],[4,5,6]] and t = ((1, 2, 3), (4, 5, 6)) for instance)
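Here is a minimal sketch that checks these three versions with Test.@inferred, using the inputs suggested above. The names rh, rf and rfoo are illustrative.

```julia
using Test

h_(v::Vector{Vector{Int}}) = ntuple(i -> sum(v[1]), Val{3}())
f_(v, ::Val{N}) where {N} = ntuple(i -> sum(v[1]), Val{N}())
foo_(t::NTuple{N,NTuple{M,Int}}) where {N,M} =
    ntuple(i -> nextpow(2, mapreduce(tj -> tj[i], +, t)), Val{M}())

v = [[1, 2, 3], [4, 5, 6]]
t = ((1, 2, 3), (4, 5, 6))

# @inferred throws unless the inferred return type matches the actual one.
rh = @inferred h_(v)
rf = @inferred f_(v, Val(3))
rfoo = @inferred foo_(t)
```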

This does not explain why h, f and foo were not type stable, though.

ntuple(f, n::Integer) is an inherently type unstable function, because the type of the output depends on the value of n, rather than on the type of n. For example, the return type of ntuple(i -> 2i, 2) is Tuple{Int,Int}, but the return type of ntuple(i -> 2i, 3) is Tuple{Int,Int,Int}.
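Here is a small sketch of that point. A hypothetical named function, double, stands in for i -> 2i. Both calls have the same argument types but different result types. Inference on the signature (typeof(double), Int) alone therefore cannot produce a concrete tuple type.

```julia
double(i) = 2i

# Same argument types, (typeof(double), Int), but different result types:
println(typeof(ntuple(double, 2)))
println(typeof(ntuple(double, 3)))

# Inference only sees the argument types, so a non-constant n leaves the length open.
rt = only(Base.return_types(ntuple, (typeof(double), Int)))
println(rt, "  concrete? ", isconcretetype(rt))
```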

If you call methods(ntuple), you can see that Base provides some type-stable methods of ntuple that take a Val:

julia> methods(ntuple)
# 6 methods for generic function "ntuple":
[1] ntuple(f::F, n::Integer) where F in Base at ntuple.jl:17
[2] ntuple(f, ::Val{0}) in Base at ntuple.jl:40
[3] ntuple(f, ::Val{1}) in Base at ntuple.jl:41
[4] ntuple(f, ::Val{2}) in Base at ntuple.jl:42
[5] ntuple(f, ::Val{3}) in Base at ntuple.jl:43
[6] ntuple(f::F, ::Val{N}) where {F, N} in Base at ntuple.jl:45

However, these type-stable methods are not documented in the docstring for ntuple.

EDIT: Upon further reflection, maybe that doesn’t explain everything. As you originally asked, why is g type stable? My guess is that when ntuple is used inside a function like f or g above, the compiler can sometimes tell that N is a constant. But why would that depend on the function passed to ntuple?

Just to point out that even f becomes type-stable if the elements of v have a length known from their type (here, static vectors from StaticArrays.jl):

julia> f(v, ::Val{N}) where {N} = ntuple(i -> sum(v[1]), N)

julia> using StaticArrays

julia> v2 = [ SVector([1,2,3]...), SVector([1,2,3]...) ]
2-element Array{SArray{Tuple{3},Int64,1,3},1}:
 [1, 2, 3]
 [1, 2, 3]

julia> @code_warntype f(v2,Val(3))
Variables
  #self#::Core.Compiler.Const(f, false)
  v::Array{SArray{Tuple{3},Int64,1,3},1}
  #unused#::Core.Compiler.Const(Val{3}(), false)
  #1::var"#1#2"{Array{SArray{Tuple{3},Int64,1,3},1}}

Body::Tuple{Int64,Int64,Int64}

You can just write

f(v, ::Val{N}) where {N} = ntuple(i -> sum(v[1]), Val(N))

or, alternatively

f(v, n::Val{N}) where {N} = ntuple(i -> sum(v[1]), n)

Or even:

f(v, n::Val) = ntuple(i -> sum(v[1]), n)

There are helper functions in Unrolled.jl that help with this in the general case:

using Unrolled
f(v, ::Val{N}) where {N} = unrolled_map(i->sum(v[1]), @fixed_range 1:N)

works. @DNF’s solution is simpler though in this instance.


I thought I had found a nice solution for this, namely:

foo(t) = ntuple(i -> nextpow(2, mapreduce(tj -> tj[i], +, t)), length(t[Val(1)]))

since @code_warntype gave a nice result. However, apparently, you cannot index with Val:

julia> foo(((1,2,3), (3,4,5), (5,6,7)))
ERROR: MethodError: no method matching getindex(::Tuple{Tuple{Int64,Int64,Int64},Tuple{Int64,Int64,Int64},Tuple{Int64,Int64,Int64}}, ::Val{1})

I wonder why that is.

Actually, you should just write

f(v, n) = ntuple(i -> sum(v[1]), n)

As long as you call it with f(v, Val(3)) it will be type stable.
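The sketch below contrasts the two call signatures of this single unannotated method via Base.return_types. The name f_any is illustrative and avoids clashing with the earlier f.

```julia
f_any(v, n) = ntuple(i -> sum(v[1]), n)

v = [[1, 2, 3], [4, 5, 6]]
V = typeof(v)

# One method, two signatures: the length as a value (Int) or as a type (Val{3}).
rt_int = only(Base.return_types(f_any, (V, Int)))
rt_val = only(Base.return_types(f_any, (V, Val{3})))
println(rt_int, "  concrete? ", isconcretetype(rt_int))
println(rt_val, "  concrete? ", isconcretetype(rt_val))
```

Both calls return the same value. Only the Val call lets inference see the length.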

I like this. Just keep removing type annotations until there’s nothing left, and then you have the right solution.


Why should it be? Val is not a subtype of Number, nor of any other index type.

Well, being able to index with Val(1) would be very nice. But I see it could be difficult to implement.

I do not think it is difficult…

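Base.getindex(a::AbstractArray, ::Val{T}) where {T} = a[T]
Base.getindex(t::Tuple, ::Val{T}) where {T} = t[T]  # the Tuple method is what the foo example needs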

I just think they are different concepts which probably should be kept separate.

I am not sure your Val(1) in length(t[Val(1)]) makes any sense. It seems to me that the compiler should be able to do any optimization just as well with a plain 1 instead of Val(1). The previous code used Val together with a function barrier to compile a different, type-stable method for each number of elements in the tuple.
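Along those lines, here is a sketch that gets the inner tuple length from the type rather than from a Val index. The names inner_length and foo_type are hypothetical. Dispatch extracts M from the argument type and hands ntuple a Val{M}(), so the length never passes through a run-time value.

```julia
using Test

# Hypothetical helper: read the inner length M off the argument's type and
# return it already wrapped in Val, so it is known at compile time.
inner_length(::NTuple{N,NTuple{M,Int}}) where {N,M} = Val{M}()

# Same computation as foo, but the length comes from inner_length(t).
foo_type(t) = ntuple(i -> nextpow(2, mapreduce(tj -> tj[i], +, t)), inner_length(t))

t = ((1, 2, 3), (3, 4, 5), (5, 6, 7))
r = @inferred foo_type(t)
```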

Yes, preferably, the compiler should just see this. But as for it making sense: Does

getindex(a, Val(3))

make any less sense than

ntuple(i->i^2, Val(3))

?

This removes the error, but doesn’t help with the type instability, btw.

No, you are correct, it does not. I think someone should make a PR to investigate why ntuple does not throw an error when Val is passed as second parameter. The documentation is clear that the second parameter should be an Integer and Val(3) is not.

Yes, that is the whole point, and it is why I said your attempt did not make sense; I think I am not getting my point across.

In the other cases, the type instability did not disappear through some Val magic; Val is a simple parametric type that any of us could define. It disappeared because the part where it arose was wrapped in a function, and the tuple size, which determines the return type, was passed as a type (i.e., wrapped in Val) instead of as a value. So for the same input types (i.e., the same method specialization), we always get the same return type. That is what solves the type instability. Consequently, continuing to pass the tuple size as a value and only wrapping it in Val inside the function has no effect.
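To show that there is no Val magic, here is a sketch with a home-made parametric type and a generated function. MyVal and my_ntuple are hypothetical names, and a generated function is just one way to unroll the tuple construction. Because N is part of the argument type, every signature has a single, fixed return type.

```julia
using Test

# Home-made stand-in for Base.Val (hypothetical name): a parametric type
# whose only content is its type parameter N.
struct MyVal{N} end

# Hypothetical helper: the generator sees N as a compile-time constant and
# emits the literal expression tuple(f(1), ..., f(N)) for each N.
@generated function my_ntuple(f, ::MyVal{N}) where {N}
    calls = [:(f($i)) for i in 1:N]
    return :(tuple($(calls...)))
end

sq = @inferred my_ntuple(i -> i^2, MyVal{3}())
e = @inferred my_ntuple(i -> i^2, MyVal{0}())
```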