Defining a custom iterator that wraps multiple iterators

I’m having a tough time writing performant iterators that wrap multiple other iterators; as an illustrative example this is how I would currently write a simple merge sort iterator:

struct MergeSortIterator{I1,I2}
    a::I1
    b::I2
end

Base.length(itr::MergeSortIterator) = length(itr.a) + length(itr.b)
Base.eltype(::Type{MergeSortIterator{I1,I2}}) where {I1,I2} = promote_type(eltype(I1), eltype(I2))

function Base.iterate(itr::MergeSortIterator, state=(iterate(itr.a), iterate(itr.b)))
    # state holds the pending iterate result (or nothing) of each wrapped iterator
    a_iter_result, b_iter_result = state
    a_iter_result === nothing && b_iter_result === nothing && return nothing
    # one side is exhausted: drain the other
    a_iter_result === nothing && return (b_iter_result[1], (nothing, iterate(itr.b, b_iter_result[2])))
    b_iter_result === nothing && return (a_iter_result[1], (iterate(itr.a, a_iter_result[2]), nothing))
    # both sides live: emit the smaller head and advance only that side
    item_a, state_a = a_iter_result
    item_b, state_b = b_iter_result
    item_a < item_b ? (item_a, (iterate(itr.a, state_a), b_iter_result)) : (item_b, (a_iter_result, iterate(itr.b, state_b)))
end

This works, e.g.,

v1 = sort!(rand(10))
v2 = sort!(rand(5))
collect(MergeSortIterator(v1, v2))
# [0.0135208, 0.125779, 0.221346, 0.386235, 0.530192, 0.614787, 0.629101, 0.702266, 0.708002, 0.85657, 0.86487, 0.886334, 0.895294, 0.912969, 0.947071]

but it’s pretty slow because the iterator state type can become a combinatorial explosion of the component iterators’ iterate return types (Nothing or Tuple{Float64,Int64} in this case):

julia> @code_warntype iterate(MergeSortIterator(v1, v2))
Body::Union{Nothing, Tuple{Float64,Any}}
2 1 ─ %1  = (Base.getfield)(itr, :a)::Array{Float64,1}                                                    │╻    getproperty
  │   %2  = (Base.arraylen)(%1)::Int64                                                                    ││╻╷   iterate
  │   %3  = (Base.sle_int)(0, %2)::Bool                                                                   │││╻╷   <
  │   %4  = (Base.bitcast)(UInt64, %2)::UInt64                                                            ││││╻    unsigned
  │   %5  = (Base.ult_int)(0x0000000000000000, %4)::Bool                                                  ││││╻    <
  │   %6  = (Base.and_int)(%3, %5)::Bool                                                                  ││││╻    &
  └──       goto #3 if not %6                                                                             │││  
  2 ─ %8  = (Base.arrayref)(false, %1, 1)::Float64                                                        │││╻    getindex
  │   %9  = (Core.tuple)(%8, 2)::Tuple{Float64,Int64}                                                     │││  
  └──       goto #4                                                                                       │││  
  3 ─ %11 = Base.nothing::Nothing                                                                         │││  
  └──       goto #4                                                                                       │││  
  4 ┄ %13 = φ (#2 => %9, #3 => %11)::Union{Nothing, Tuple{Float64,Int64}}                                 ││   
  └──       goto #5                                                                                       ││   
  5 ─ %15 = (Base.getfield)(itr, :b)::Array{Float64,1}                                                    │╻    getproperty
  │   %16 = (Base.arraylen)(%15)::Int64                                                                   ││╻╷   iterate
  │   %17 = (Base.sle_int)(0, %16)::Bool                                                                  │││╻╷   <
  │   %18 = (Base.bitcast)(UInt64, %16)::UInt64                                                           ││││╻    unsigned
  │   %19 = (Base.ult_int)(0x0000000000000000, %18)::Bool                                                 ││││╻    <
  │   %20 = (Base.and_int)(%17, %19)::Bool                                                                ││││╻    &
  └──       goto #7 if not %20                                                                            │││  
  6 ─ %22 = (Base.arrayref)(false, %15, 1)::Float64                                                       │││╻    getindex
  │   %23 = (Core.tuple)(%22, 2)::Tuple{Float64,Int64}                                                    │││  
  └──       goto #8                                                                                       │││  
  7 ─ %25 = Base.nothing::Nothing                                                                         │││  
  └──       goto #8                                                                                       │││  
  8 ┄ %27 = φ (#6 => %23, #7 => %25)::Union{Nothing, Tuple{Float64,Int64}}                                ││   
  └──       goto #9                                                                                       ││   
  9 ─ %29 = (Core.tuple)(%13, %27)::Tuple{Union{Nothing, Tuple{Float64,Int64}},Union{Nothing, Tuple{Float64,Int64}}}
  │   %30 = (#self#)(itr, %29)::Union{Nothing, Tuple{Float64,Any}}                                        │    
  └──       return %30                                                                                    │    

Does anyone know of a better way to write iterators that need to maintain the states of sub-iterators?


I’ve taken a look at the implementation of Base.zip (in particular Zip2). It doesn’t suffer from this type inference problem because it stops at the first nothing it sees from any of the wrapped iterators (and thus the return type of iterate is just a union of Nothing and Tuple{[Tuple of value types], [Tuple of state types]}):

julia> @code_warntype iterate(zip(v1, v2))
Body::Union{Nothing, Tuple{Tuple{Float64,Float64},Tuple{Int64,Int64}}}
320 1 ── %1  = (Base.getfield)(z, :a)::Array{Float64,1}                                                  │╻     getproperty
    │    %2  = (Base.getfield)(z, :b)::Array{Float64,1}                                                  ││    
    └───       goto #8 if not true                                                                       │╻     zip_iterate
...
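To make the contrast concrete, here is a stripped-down two-iterator zip. It is a hypothetical `ZipLike` type written only for illustration, not Base's actual zip code; the point is that, because it returns `nothing` as soon as either component does, the state it hands back is always a plain tuple of the component states.

```julia
# Hypothetical ZipLike: a minimal two-iterator zip, for illustration only.
struct ZipLike{I1,I2}
    a::I1
    b::I2
end

Base.length(z::ZipLike) = min(length(z.a), length(z.b))
Base.eltype(::Type{ZipLike{I1,I2}}) where {I1,I2} = Tuple{eltype(I1),eltype(I2)}

# Stop on the first `nothing`; otherwise pair up the values and the states.
# The separate `=== nothing` checks let inference narrow away `Nothing`.
function _ziplike_step(ra, rb)
    ra === nothing && return nothing
    rb === nothing && return nothing
    xa, sa = ra
    xb, sb = rb
    return ((xa, xb), (sa, sb))
end

Base.iterate(z::ZipLike) = _ziplike_step(iterate(z.a), iterate(z.b))

function Base.iterate(z::ZipLike, state)
    sa, sb = state
    return _ziplike_step(iterate(z.a, sa), iterate(z.b, sb))
end
```

With two `Vector{Float64}` arguments the state is always `Tuple{Int,Int}`, so inference never has to track which component ran out first.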

In order to achieve this sort of semi-type-stability, I can keep around an extra sentinel as in this pretty hacky implementation (this works, and is performant), but I’m hoping there’s a better way (in particular, _dummy_iterate_result feels problematic):

# placeholder stored once a component is exhausted (never emitted as a value)
_dummy_iterate_result(itr::Vector{T}) where {T} = (zero(T), 0)

function _start(itr::MergeSortIterator)
    a_iter_result = iterate(itr.a)
    b_iter_result = iterate(itr.b)
    # swap a `nothing` for the dummy so the state type stays concrete
    if (a_done = a_iter_result === nothing)
        a_iter_result = _dummy_iterate_result(itr.a)
    end
    if (b_done = b_iter_result === nothing)
        b_iter_result = _dummy_iterate_result(itr.b)
    end
    (a_done, b_done, a_iter_result, b_iter_result)
end

function Base.iterate(itr::MergeSortIterator, state=_start(itr))
    a_done, b_done, a_iter_result, b_iter_result = state
    a_done && b_done && return nothing
    # take from `a` if `b` is exhausted, or if both are live and a's head is smaller
    if b_done || (!a_done && a_iter_result[1] < b_iter_result[1])
        item_a, state_a = a_iter_result
        a_iter_result = iterate(itr.a, state_a)
        a_done = a_iter_result === nothing
        return (item_a, (a_done, b_done, a_done ? _dummy_iterate_result(itr.a) : a_iter_result, b_iter_result))
    else
        item_b, state_b = b_iter_result
        b_iter_result = iterate(itr.b, state_b)
        b_done = b_iter_result === nothing
        return (item_b, (a_done, b_done, a_iter_result, b_done ? _dummy_iterate_result(itr.b) : b_iter_result))
    end
end

This would have been relatively easy under the old (0.6) iteration protocol. I think the following, more complicated solution may work for the new iteration protocol. First, define a mutable struct to hold the iteration states, with a type parameter for each field. If a component container is initially empty, the corresponding field is populated with nothing instead of an iteration state. This mutable struct also contains a Bool for each component iterator to indicate whether that container is exhausted. The type of this object cannot be determined at compile time, but once it is initialized its type is stable. Therefore, after it is initialized, you can use a “function barrier” to pass it to the actual worker routine.
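As a generic illustration of the function-barrier pattern (hypothetical function names, unrelated to merging): an outer function builds a value whose type inference cannot pin down, then hands it to a small kernel that Julia compiles for the concrete runtime type.

```julia
# Outer function: the type of `xs` depends on a runtime value, so inference
# only knows Union{Vector{Int}, Vector{Float64}} at this point.
function sum_of_squares(n::Int)
    xs = n > 0 ? collect(1:n) : Float64[]
    # The barrier: the call is resolved once here (by dynamic dispatch or
    # union splitting); the kernel body then runs on a concrete type.
    return _sum_of_squares_kernel(xs)
end

# Kernel: specialized on the concrete element type T.
function _sum_of_squares_kernel(xs::AbstractVector{T}) where {T}
    acc = zero(T)
    for x in xs
        acc += x * x
    end
    return acc
end
```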

Upon further thought, I decided that the solution proposed in my previous posting would have significant overhead from recompiling the worker function each time an iterator is created, so it would not be performant unless the containers are large enough to compensate for the overhead.

However, I think the proposed solution can be fixed by defining some helper routines like this (not tested):

   iterator_state_type(::Any) = Any
   iterator_state_type(::Vector{T}) where {T<:Number} = Tuple{Int,T}
   # etc for all the common types over which you might want to call MergeSort
   default_iterator_state(::Any) = nothing
   default_iterator_state(::Vector{T}) where {T<:Number} = (0,zero(T))
   # etc for all the common types

Then, in the setup routine described in my previous message, use these helper functions to create the iteration state for the merge sort. In the common cases, the compiler will be able to use the helper functions to infer the type of the struct, so that the worker’s arguments have a concrete type signature at compile time. By the way, in my previous message I said that the iteration state should be a mutable struct, but on further reflection I think an immutable struct would be better.
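A minimal sketch of that setup step with an immutable state struct, under these assumptions: the names `MergeState`, `merge_start`, and `placeholder_result` are hypothetical, and the placeholder is defined only for numeric vectors (mirroring the `zero(T)` choice used in this thread). Only the setup routine is shown; the last line asks inference what it returns.

```julia
struct MergeSortIterator{I1,I2}
    a::I1
    b::I2
end

# Hypothetical helper: a placeholder (value, state) pair used when a component
# is empty. Defined only for numeric vectors in this sketch.
placeholder_result(::Vector{T}) where {T<:Number} = (zero(T), 0)

# Immutable state: one done-flag and one (value, state) pair per component.
struct MergeState{RA,RB}
    a_done::Bool
    b_done::Bool
    a_result::RA
    b_result::RB
end

# `x === nothing ? placeholder : x` lets inference narrow away `Nothing`,
# so RA and RB come out as concrete tuple types.
function merge_start(itr::MergeSortIterator)
    ra = iterate(itr.a)
    rb = iterate(itr.b)
    return MergeState(ra === nothing, rb === nothing,
                      ra === nothing ? placeholder_result(itr.a) : ra,
                      rb === nothing ? placeholder_result(itr.b) : rb)
end

# Ask inference what merge_start returns for two Vector{Float64} components.
Base.return_types(merge_start, (MergeSortIterator{Vector{Float64},Vector{Float64}},))
```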

Thanks for your suggestions. Unless I’m mistaken, the solution you’re proposing seems similar in spirit to the “hacky implementation” at the bottom of the OP (there, a tuple is used as an anonymous version of the immutable struct you’re proposing). Over the past few days I’ve actually come to have a much more positive view of this style of approach. I can confirm that it’s zero overhead (i.e., iterate(::MergeSortIterator) doesn’t allocate as long as the iterators it wraps don’t), and in my code, variants of it actually outperform the indexing-based for loops (applicable only to AbstractVectors, not arbitrary iterators) that they are replacing.

My only lingering concern is that _dummy_iterate_result (or in your case, default_iterator_state, with the order of element/state flipped) needs to instantiate a default value of the eltype T in general (I’m not thrilled with, but ultimately okay with, defining _dummy_iterate_result for each iterator type I might use). We both use zero(T), but I can imagine cases where zero(T) doesn’t make any sense. It would be nice to have a language feature to define placeholder instances of types, where the contract is that the programmer must verify that those instances won’t actually be used; notably, the burden of proof is taken off the compiler. Maybe a better version of something like this (which actually seems to work okay for isbits types):

julia> struct A end

julia> undef(A)
ERROR: MethodError: objects of type UndefInitializer are not callable
Stacktrace:
 [1] top-level scope at none:0

julia> (::UndefInitializer)(::Type{T}) where {T} = Ref{T}()[]

julia> undef(Int)
139950704410448

julia> undef(A)
A()

julia> using StaticArrays

julia> undef(SVector{3,Float64})
3-element SArray{Tuple{3},Float64,1,3}:
 6.9529415980266e-310 
 1.0e-323             
 6.91450508901983e-310

julia> undef(typeof(+))
+ (generic function with 171 methods)

which would allow a definition like

_dummy_iterate_result(itr::Vector{T}) where {T} = (undef(T), 0)

that I would be more happy with.
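A less invasive variant of the same idea, sketched under two assumptions: `placeholder` is a hypothetical helper rather than a call method on Base's `UndefInitializer` (adding one is type piracy), and it refuses non-isbits types, for which reading an unset `Ref` throws `UndefRefError`.

```julia
# Hypothetical helper (not part of Base): returns an arbitrary value of type T,
# meant only as a never-read placeholder. Restricted to isbits types, where
# reading an uninitialized Ref yields arbitrary bits instead of throwing.
function placeholder(::Type{T}) where {T}
    isbitstype(T) || throw(ArgumentError("no placeholder for non-isbits type $T"))
    return Ref{T}()[]
end

_dummy_iterate_result(::Vector{T}) where {T} = (placeholder(T), 0)
```

For non-isbits element types this fails loudly instead of silently, so a per-type fallback would still be needed there.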

You are right: my proposed solution is quite similar to your OP; sorry I didn’t notice this. So your solution seems fine to me, except that (IIUC) there may be room for the following small improvements:

  • The Ref{T}() trick (which I’ve never seen before!) doesn’t always work (see trace below). So there may need to be a more generic fallback if you want to write a routine that won’t fail.

  • Even when it does work, it appears to me that it might allocate, i.e., it seems that Ref{Int}()[] must put something on the heap. If so, you might want to change your Base.iterate so that it does not call _dummy_iterate_result() but instead uses the last known good iteration state as the placeholder, confining the invocation of _dummy_iterate_result to _start. Not sure about this.

julia> x = Ref{Vector{Int}}()[]
ERROR: UndefRefError: access to undefined reference
Stacktrace:
 [1] getproperty at .\sysimg.jl:18 [inlined]
 [2] getindex(::Base.RefValue{Array{Int64,1}}) at .\refvalue.jl:32
 [3] top-level scope at none:0
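The second suggestion above (reusing the last known good result as the placeholder) could look roughly like this. It is a sketch adapted from the tuple-state version in the original post, not the poster's code: `_dummy_iterate_result` is called only in `_start`, and once a component is exhausted its last consumed (value, state) pair stays in the state but is never read again.

```julia
struct MergeSortIterator{I1,I2}
    a::I1
    b::I2
end

Base.length(itr::MergeSortIterator) = length(itr.a) + length(itr.b)
Base.eltype(::Type{MergeSortIterator{I1,I2}}) where {I1,I2} = promote_type(eltype(I1), eltype(I2))

# Only needed when a component is empty from the start.
_dummy_iterate_result(itr::Vector{T}) where {T} = (zero(T), 0)

function _start(itr::MergeSortIterator)
    ra = iterate(itr.a)
    rb = iterate(itr.b)
    return (ra === nothing, rb === nothing,
            ra === nothing ? _dummy_iterate_result(itr.a) : ra,
            rb === nothing ? _dummy_iterate_result(itr.b) : rb)
end

function Base.iterate(itr::MergeSortIterator, state=_start(itr))
    a_done, b_done, ra, rb = state
    a_done && b_done && return nothing
    if b_done || (!a_done && ra[1] < rb[1])
        item, sa = ra
        nxt = iterate(itr.a, sa)
        # Exhausted: keep the already-consumed `ra` as the placeholder.
        return nxt === nothing ? (item, (true, b_done, ra, rb)) :
                                 (item, (false, b_done, nxt, rb))
    else
        item, sb = rb
        nxt = iterate(itr.b, sb)
        return nxt === nothing ? (item, (a_done, true, ra, rb)) :
                                 (item, (a_done, false, ra, nxt))
    end
end
```

Both branches rebuild the state from values that already have the concrete tuple type, so no new placeholder has to be constructed after the first call.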

Yeah, I think undef(::Type{T}) would have to be a language feature/compiler hint (along the lines of @simd or @inbounds, as opposed to being a function that returns an actually accessible value) to be most useful/general.

I agree that there’s no reason to call this in any case outside of the initial Nothing/non-Nothing ambiguity at the first iterate call.