Type inference for nested iterators

In this thread, I wrote this function:

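using Base.Iterators

function foo(itr, n, fillvalue)
    ntake = ceil(Int, length(itr)/n)                       # number of chunks needed to cover itr
    extended_itr = flatten( ( itr, repeated(fillvalue) ) )  # itr, followed by fillvalue forever
    take(partition(extended_itr, n), ntake)                # ntake chunks of length n, the last one padded if needed
end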

I noticed that when I collect the iterator returned by foo, the type inference appears to fail:

julia> collect(foo(1:7, 3, 0))
3-element Array{Array{Any,1},1}:
 [1, 2, 3]
 [4, 5, 6]
 [7, 0, 0]

I would expect the output type to be Vector{Vector{Int}} rather than Vector{Vector{Any}}. @mkitti pointed out that length is an optional part of the iteration interface, but I wouldn’t think that would affect the type inference. Can anyone explain why type inference fails here? Should I open an issue on GitHub?


It appears that the problem is with flatten( ( itr, repeated(fillvalue) ) ):

julia> Base.IteratorEltype( flatten( ( 1:7, repeated(0) ) ))
Base.EltypeUnknown()

julia> eltype(flatten( ( 1:7, repeated(0) ) ))
Any

We can fix this by @eval-ing the two missing methods into Base:

julia> @eval Base IteratorEltype(::Iterators.Flatten{Tuple{UnitRange{T}, Iterators.Repeated{T}}}) where {T} = HasEltype()
Base.IteratorEltype

julia> @eval Base eltype(::Iterators.Flatten{Tuple{UnitRange{T}, Iterators.Repeated{T}}}) where {T} = T
eltype (generic function with 70 methods)

Now we can check if the problem is fixed:

julia> collect(foo(1:7, 3, 0))
3-element Array{Array{Int64,1},1}:
 [1, 2, 3]
 [4, 5, 6]
 [7, 0, 0]

It’s basically whack-a-mole trying to add all the missing methods needed to make nested iterators communicate their types properly. I’d suggest checking out the interface docs for iteration to learn more: Interfaces · The Julia Language


Note that this doesn’t have much to do with inference. Inference is an optimization and should not affect the output type.
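A quick REPL check of that point, using nothing beyond Base: collect builds its result from the values it actually sees, so even when the input's declared element type is Any, the output can come out concretely typed. Empty inputs are the exception, since there are no values to inspect and collect has to fall back on an inferred type.

```julia
xs = Any[1, 2]               # the container's declared element type is Any
ys = collect(x for x in xs)  # the generator has no useful eltype of its own
eltype(ys)                   # Int64: taken from the values that were produced
```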


Good point.

By the way, since you’re here, do you have any thoughts on whether it’s worth trying to make a PR with methods that give flatten((itr, repeated(x))) an eltype, or is that not the sort of thing that’s likely to get accepted?

I know very little about the iterator code, but in general it’s always good to open a PR 🙂


Maybe the way to improve this situation in Base is to add an Iterators.cat function whose API requires the element types of all the input iterators to be the same. So, a call like

Iterators.cat(itr1, itr2, itr3)

where eltype(itr1) == eltype(itr2) == eltype(itr3), would create an Iterators.Cat iterator object with element type equal to the common element type of the input iterators.

I could add this suggestion to the following GitHub issue:

https://github.com/JuliaLang/julia/issues/36760

Ok, I added a comment to that GitHub issue proposing that Iterators.cat be added, and that it should propagate the common element type of the concatenated iterators (and throw an exception if they do not all have the same element type).

https://github.com/JuliaLang/julia/issues/36760#issuecomment-698698602

I don’t think a “whack-a-mole” approach like implementing eltype is the ideal solution. I think it would be great if PartitionIterator implemented the mutate-or-widen approach (start with a vector typed for the first element, and copy into a wider vector whenever an element that doesn’t fit comes along) so that the returned vectors always have an accurate eltype. Doing this in Base is not trivial, though.
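Here’s a rough sketch of the mutate-or-widen idea as a standalone function. collect_widen is a hypothetical name, and this is not how Base or PartitionIterator is actually written: the container starts out typed for the first value, is mutated in place while values fit, and is copied into a wider container when one doesn’t. Base’s collect does something along these lines for iterators with an unknown eltype, though it widens with the internal promote_typejoin rather than plain typejoin.

```julia
# Hypothetical helper illustrating "mutate-or-widen": the result's element type
# comes from the values actually produced, never from eltype(itr).
function collect_widen(itr)
    y = iterate(itr)
    y === nothing && return Any[]   # no values seen, so nothing to base a type on
    x, state = y
    dest = [x]                      # Vector{typeof(x)}
    while (y = iterate(itr, state)) !== nothing
        x, state = y
        if !(x isa eltype(dest))
            # widen: copy into a vector whose eltype admits old and new values
            dest = Vector{typejoin(eltype(dest), typeof(x))}(dest)
        end
        push!(dest, x)              # mutate in place
    end
    return dest
end

v = collect_widen(Iterators.flatten((1:2, Iterators.repeated(0, 2))))
# v == [1, 2, 0, 0] with eltype Int64, whatever eltype the flatten reports
```

The trade-off is an occasional copy when widening happens, in exchange for never trusting a declared eltype.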


After digging a little deeper, I see that the root cause of the problem is that the eltype of a tuple is often Any. The definition of eltype for a Flatten iterator is this:

eltype(::Type{Flatten{I}}) where {I} = eltype(eltype(I))

So, Flatten will know its element type if I (the iterator of iterators) knows its element type. However, eltype for tuples is often Any. We can see this in action:

julia> using Base.Iterators

julia> eltype((1, 2))
Int64

julia> eltype(flatten((1, 2)))
Int64

julia> eltype((1, 2:3))
Any

julia> eltype(flatten((1, 2:3)))
Any

julia> eltype((1:2, 3:4))
UnitRange{Int64}

julia> eltype(flatten((1:2, 3:4)))
Int64

julia> eltype((1:2, repeated(3, 2)))
Any

julia> eltype(flatten((1:2, repeated(3, 2))))
Any
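One way to see where the information gets lost (a small illustration, assuming a Julia version that provides fieldtypes): the tuple type’s own eltype has to join the iterator types UnitRange{Int64} and Take{Repeated{Int64}}, which only meet at Any, while asking each field type for its own eltype still gives Int64 for both.

```julia
using Base.Iterators: repeated

t = (1:2, repeated(3, 2))
T = typeof(t)

eltype(T)                   # Any: the join of the two *iterator* types
map(eltype, fieldtypes(T))  # (Int64, Int64): the per-iterator element types
```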

Is doing something like this ill-advised?

function eltype(::Type{Flatten{I}}) where {I}
    promote_type(eltype.(I.parameters)...)
end

Going with typejoin instead of promote_type might make more sense, though, since it won’t convert the elements to a different type; the downside is that it can lead to abstract element types 🤷‍♂️

julia> fl = Iterators.Flatten((1:3, 4:6, 7.0:9.0));

julia> function Base.eltype(::Type{Iterators.Flatten{I}}) where {I}
           promote_type(eltype.(I.parameters)...)
       end

julia> collect(fl)
9-element Array{Float64,1}:
 1.0
 2.0
 3.0
 4.0
 5.0
 6.0
 7.0
 8.0
 9.0

julia> function Base.eltype(::Type{Iterators.Flatten{I}}) where {I}
           typejoin(eltype.(I.parameters)...)
       end

julia> collect(fl)
9-element Array{Real,1}:
 1
 2
 3
 4
 5
 6
 7.0
 8.0
 9.0
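To compare the two choices without redefining Base.eltype (which is type piracy), the combining rule can be passed in explicitly. inner_eltype is a hypothetical helper, shown only as a sketch:

```julia
# Hypothetical helper, not part of Base: combine the element types of a tuple
# of iterators with an explicit rule instead of redefining Base.eltype.
inner_eltype(itrs::Tuple, combine=promote_type) =
    mapreduce(eltype, combine, itrs; init=Union{})

itrs = (1:3, 4:6, 7.0:9.0)
inner_eltype(itrs)            # Float64: the integers would be converted
inner_eltype(itrs, typejoin)  # Real: values keep their types, container is abstract
```

promote_type picks a type every element can be converted to, so collecting converts; typejoin picks a common supertype, so nothing is converted but the resulting array has an abstract eltype.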

I was thinking about adding an Iterators.cat function that checks that the element types of its inputs are equal. Here’s a partial implementation, where I use the name mycat instead of Iterators.cat:

using Base.Iterators

struct CompatibleIterators{T}
    itrs::T

    function CompatibleIterators(itrs::T) where {T}
        eltypes_same = ( length(unique(eltype.(itrs))) == 1 )
        eltypes_same || throw(ArgumentError("element types are not all equal"))
        new{T}(itrs)
    end
end

Base.eltype(c::CompatibleIterators) = eltype(first(c.itrs))

Base.iterate(c::CompatibleIterators) = iterate(c.itrs)
Base.iterate(c::CompatibleIterators, state) = iterate(c.itrs, state)

mycat(itrs...) = flatten(CompatibleIterators(itrs))

In action:

julia> collect(mycat(1:2, 3, 4:5))
5-element Array{Int64,1}:
 1
 2
 3
 4
 5

julia> mycat(1:2, 3, 4.0:5.0)
ERROR: ArgumentError: element types are not all equal

However, in some situations we would probably like type promotion to occur, as @tomerarnon mentioned.


I noticed that

using Base.Iterators

function Base.eltype(::Type{Iterators.Flatten{I}}) where {I}
    promote_type(eltype.(I.parameters)...)
end

doesn’t quite get us all the way there. It works for this:

julia> collect(flatten((1:2, 3:4, 4.5:5.5)))
6-element Array{Float64,1}:
 1.0
 2.0
 3.0
 4.0
 4.5
 5.5

But for this we get an array of Reals instead of Float64s:

julia> collect(flatten((1:2, 4, 4.5:5.5)))
5-element Array{Real,1}:
 1
 2
 4
 4.5
 5.5

We can fix this by also overriding the IteratorEltype(::Type{Flatten{I}}) method. Actually, that method just delegates to the internal helper _flatteneltype, which is what I’m going to override:

using Base.Iterators

promote_itr_eltypes(I) = promote_type(eltype.(I.parameters)...)

Base.eltype(::Type{Iterators.Flatten{I}}) where {I} = promote_itr_eltypes(I)
Base.Iterators._flatteneltype(I, ::Base.HasEltype) = Base.IteratorEltype(promote_itr_eltypes(I))

This works properly for both of the above cases:

julia> collect(flatten((1:2, 3:4, 4.5:5.5)))
6-element Array{Float64,1}:
 1.0
 2.0
 3.0
 4.0
 4.5
 5.5

julia> collect(flatten((1:2, 4, 4.5:5.5)))
5-element Array{Float64,1}:
 1.0
 2.0
 4.0
 4.5
 5.5

However, it might be necessary to restrict this behavior to Flatten{<:Tuple}, because I’m not sure that the I.parameters trick will work for arbitrary iterators.
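A sketch of that restriction, written as a separate hypothetical function flatten_eltype rather than as a method on Base.eltype: for tuple types it promotes the per-field eltypes (using fieldtypes instead of the internal parameters field), and for anything else it falls back to eltype(eltype(I)), the rule Flatten already uses. It assumes a tuple type with a fixed number of fields; fieldtypes throws on Vararg tuple types.

```julia
# Hypothetical helper, not part of Base.
# Tuples: promote the element types of the individual iterators.
flatten_eltype(::Type{I}) where {I<:Tuple} =
    mapreduce(eltype, promote_type, fieldtypes(I); init=Union{})
# Anything else: the "eltype of the eltype" rule used by Flatten.
flatten_eltype(::Type{I}) where {I} = eltype(eltype(I))

flatten_eltype(typeof((1:2, 4, 4.5:5.5)))  # Float64
flatten_eltype(Vector{Vector{String}})     # String

# Why not use I.parameters for non-tuples: for a Vector the parameters are the
# element type and the dimension 1, and eltype(1) == Int64, so promotion gives Any.
promote_type(map(eltype, collect(Vector{Vector{String}}.parameters))...)  # Any
```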


I think the root problem is rather that eltype is used (by default) to decide the element type of the output, e.g. of the vectors that PartitionIterator allocates.


Well, “root” is relative. I guess I meant the root cause, taking the current implementation as a given.

Whether or not we move to a mutate-or-widen approach (which would be great!), I think there’d still be many situations where we’d want a working eltype.


I agree that a tighter eltype is better for optimization.


Ok, here’s my latest implementation, which calculates the promoted type of all the element types. I’ve implemented it as catitrs, and I’ve taken the liberty of changing the syntax to a vararg function. The implementation is provisional and incomplete, but it gets the basic idea across.

using Base.Iterators

struct CatIterator{T, I}
    itrs::I

    function CatIterator(itrs::I) where {I}
        T = mapreduce(eltype, promote_type, itrs)
        new{T, I}(itrs)
    end
end

catitrs(itrs...) = CatIterator(itrs)
Base.eltype(::Type{CatIterator{T, I}}) where {T, I} = T
Base.length(c::CatIterator) = sum(length.(c.itrs))

The iterate method is essentially copy-pasted from the Base iterate method for Flatten.

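function Base.iterate(c::CatIterator{T}, state=()) where {T}
    # state is (outer tuple state, current inner iterator, inner state)
    if state !== ()
        # Try to continue with the inner iterator we are in the middle of
        y = iterate(Iterators.tail(state)...)
        y !== nothing && return (convert(T, y[1]), (state[1], state[2], y[2]))
    end
    # Otherwise move on to the next non-empty inner iterator
    x = (state === () ? iterate(c.itrs) : iterate(c.itrs, state[1]))
    x === nothing && return nothing
    y = iterate(x[1])
    while y === nothing
        x = iterate(c.itrs, x[2])
        x === nothing && return nothing
        y = iterate(x[1])
    end
    # Convert so the yielded values really have the declared eltype T
    return convert(T, y[1]), (x[2], x[1], y[2])
end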

At the REPL:

julia> collect(catitrs(1:2, 3, 4.5:5.5))
5-element Array{Float64,1}:
 1.0
 2.0
 3.0
 4.5
 5.5

julia> collect(catitrs(1:2, "ab"))
4-element Array{Any,1}:
 1
 2
  'a'
  'b'
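A variation on the same idea, sketched under my own assumptions (PromotedCat and promotedcat are hypothetical names, not a proposed Base API): instead of copying Flatten’s iterate method, wrap a flatten iterator, convert every element to the promoted type T so the yielded values match the declared eltype, and declare IteratorSize as SizeUnknown. That last part matters for the original use case: repeated(fillvalue) is infinite and has no length, so a length method like sum(length.(c.itrs)) should throw a MethodError as soon as something (such as partition or take) asks for the length.

```julia
using Base.Iterators: flatten, repeated, partition, take

# Hypothetical wrapper: T is the promoted element type, F the wrapped flatten.
struct PromotedCat{T, F}
    flat::F
end

function promotedcat(itrs...)
    T = mapreduce(eltype, promote_type, itrs)
    flat = flatten(itrs)
    return PromotedCat{T, typeof(flat)}(flat)
end

Base.eltype(::Type{PromotedCat{T, F}}) where {T, F} = T
# Don't promise a length: some inputs (e.g. repeated(x)) are infinite.
Base.IteratorSize(::Type{<:PromotedCat}) = Base.SizeUnknown()

function Base.iterate(c::PromotedCat{T}, state...) where {T}
    y = iterate(c.flat, state...)
    y === nothing && return nothing
    return convert(T, y[1]), y[2]   # yield values of the declared eltype
end

# The padded partition from the original question, without touching Base:
chunks = collect(take(partition(promotedcat(1:7, repeated(0)), 3), 3))
# chunks == [[1, 2, 3], [4, 5, 6], [7, 0, 0]], with eltype Vector{Int64}
```

Wrapping flatten keeps the iteration logic in Base and limits the new code to the type bookkeeping; the cost is one convert per element, which is a no-op when the types already match.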