I'm having a tough time writing performant iterators that wrap multiple other iterators. As an illustrative example, this is how I would currently write a simple merge sort iterator:
"""
    MergeSortIterator(a, b)

Lazily merge two iterators `a` and `b` into one ascending sequence.

Each input is assumed to already be sorted in `isless` order, the default order
produced by `sort!`. On ties the element from `a` is yielded first, so the merge
is stable.
"""
struct MergeSortIterator{I1,I2}
    a::I1
    b::I2
end

# Total number of items; requires both wrapped iterators to support `length`.
Base.length(itr::MergeSortIterator) = length(itr.a) + length(itr.b)

# Declared element type: the promotion of both inputs' element types. Items are
# yielded unconverted, exactly as the wrapped iterators produce them.
Base.eltype(::Type{MergeSortIterator{I1,I2}}) where {I1,I2} =
    promote_type(eltype(I1), eltype(I2))

# The state is `(a_iter_result, b_iter_result)`: the pending results of calling
# `iterate` on `itr.a` and `itr.b`. Each is `nothing` once that input is
# exhausted, otherwise `(item, substate)`.
function Base.iterate(itr::MergeSortIterator, state=(iterate(itr.a), iterate(itr.b)))
    a_iter_result, b_iter_result = state
    a_iter_result === nothing && b_iter_result === nothing && return nothing
    if a_iter_result === nothing
        # Only `b` has items left.
        item_b, state_b = b_iter_result
        return (item_b, (nothing, iterate(itr.b, state_b)))
    end
    if b_iter_result === nothing
        # Only `a` has items left.
        item_a, state_a = a_iter_result
        return (item_a, (iterate(itr.a, state_a), nothing))
    end
    item_a, state_a = a_iter_result
    item_b, state_b = b_iter_result
    # Take from `b` only when its head is strictly smaller under `isless`.
    # This matches the order `sort!` produces (NaN last, -0.0 before 0.0),
    # where `<` would misplace NaN. It also prefers `a` on ties, which keeps
    # the merge stable.
    if isless(item_b, item_a)
        return (item_b, (a_iter_result, iterate(itr.b, state_b)))
    else
        return (item_a, (iterate(itr.a, state_a), b_iter_result))
    end
end
This works, e.g.,
# Two independently sorted random inputs (10 and 5 values); merging them
# should give a single ascending vector of 10 + 5 = 15 values.
v1, v2 = sort!(rand(10)), sort!(rand(5))
MergeSortIterator(v1, v2) |> collect
# One sample run (values differ every run, since the inputs are random):
# [0.0135208, 0.125779, 0.221346, 0.386235, 0.530192, 0.614787, 0.629101, 0.702266, 0.708002, 0.85657, 0.86487, 0.886334, 0.895294, 0.912969, 0.947071]
but it's pretty slow because the iterator state type can be a combinatorial explosion of the component iterators' `iterate` return types (`Nothing` or `Tuple{Float64,Int64}` in this case):
julia> @code_warntype iterate(MergeSortIterator(v1, v2))
Body::Union{Nothing, Tuple{Float64,Any}}
2 1 β %1 = (Base.getfield)(itr, :a)::Array{Float64,1} ββ» getproperty
β %2 = (Base.arraylen)(%1)::Int64 βββ»β· iterate
β %3 = (Base.sle_int)(0, %2)::Bool ββββ»β· <
β %4 = (Base.bitcast)(UInt64, %2)::UInt64 βββββ» unsigned
β %5 = (Base.ult_int)(0x0000000000000000, %4)::Bool βββββ» <
β %6 = (Base.and_int)(%3, %5)::Bool βββββ» &
βββ goto #3 if not %6 βββ
2 β %8 = (Base.arrayref)(false, %1, 1)::Float64 ββββ» getindex
β %9 = (Core.tuple)(%8, 2)::Tuple{Float64,Int64} βββ
βββ goto #4 βββ
3 β %11 = Base.nothing::Nothing βββ
βββ goto #4 βββ
4 β %13 = φ (#2 => %9, #3 => %11)::Union{Nothing, Tuple{Float64,Int64}} ββ
βββ goto #5 ββ
5 β %15 = (Base.getfield)(itr, :b)::Array{Float64,1} ββ» getproperty
β %16 = (Base.arraylen)(%15)::Int64 βββ»β· iterate
β %17 = (Base.sle_int)(0, %16)::Bool ββββ»β· <
β %18 = (Base.bitcast)(UInt64, %16)::UInt64 βββββ» unsigned
β %19 = (Base.ult_int)(0x0000000000000000, %18)::Bool βββββ» <
β %20 = (Base.and_int)(%17, %19)::Bool βββββ» &
βββ goto #7 if not %20 βββ
6 β %22 = (Base.arrayref)(false, %15, 1)::Float64 ββββ» getindex
β %23 = (Core.tuple)(%22, 2)::Tuple{Float64,Int64} βββ
βββ goto #8 βββ
7 β %25 = Base.nothing::Nothing βββ
βββ goto #8 βββ
8 β %27 = φ (#6 => %23, #7 => %25)::Union{Nothing, Tuple{Float64,Int64}} ββ
βββ goto #9 ββ
9 β %29 = (Core.tuple)(%13, %27)::Tuple{Union{Nothing, Tuple{Float64,Int64}},Union{Nothing, Tuple{Float64,Int64}}}
β %30 = (#self#)(itr, %29)::Union{Nothing, Tuple{Float64,Any}} β
βββ return %30 β
Does anyone know of a better way to write iterators that need to maintain the states of sub-iterators?
I've taken a look at the implementation of `Base.zip` (in particular `Zip2`). It doesn't suffer from this type inference problem because it truncates on the first `nothing` it sees from any of the wrapped iterators. Thus the return type of `iterate` is just a union of `Nothing` and `Tuple{[Tuple of value types], [Tuple of state types]}`:
julia> @code_warntype iterate(zip(v1, v2))
Body::Union{Nothing, Tuple{Tuple{Float64,Float64},Tuple{Int64,Int64}}}
320 1 ββ %1 = (Base.getfield)(z, :a)::Array{Float64,1} ββ» getproperty
β %2 = (Base.getfield)(z, :b)::Array{Float64,1} ββ
ββββ goto #8 if not true ββ» zip_iterate
...
To achieve this sort of semi-type-stability, I can keep around an extra sentinel, as in this pretty hacky implementation (it works, and is performant). But I'm hoping there's a better way; in particular, `_dummy_iterate_result` feels problematic:
# Placeholder stand-in for an `iterate` result of an exhausted `Vector{T}`:
# an `(item, state)` pair of `zero(T)` and the `Int` index `0`. It is only
# used to occupy a state slot; its values are never yielded.
# NOTE(review): requires `zero(T)` to be defined, so non-numeric element
# types (e.g. `String`) will throw a `MethodError`.
function _dummy_iterate_result(itr::Vector{T}) where {T}
    placeholder_item = zero(T)
    placeholder_index = 0
    return (placeholder_item, placeholder_index)
end
# Build the initial state `(a_done, b_done, a_iter_result, b_iter_result)`.
# An input that is empty from the start is marked done, and its slot holds a
# placeholder from `_dummy_iterate_result` instead of `nothing`.
function _start(itr::MergeSortIterator)
    first_a = iterate(itr.a)
    first_b = iterate(itr.b)
    a_done = first_a === nothing
    b_done = first_b === nothing
    a_slot = first_a === nothing ? _dummy_iterate_result(itr.a) : first_a
    b_slot = first_b === nothing ? _dummy_iterate_result(itr.b) : first_b
    return (a_done, b_done, a_slot, b_slot)
end
# Sentinel-based variant. The state is always
# `(a_done, b_done, a_iter_result, b_iter_result)`. When an input is exhausted,
# its `*_done` flag is set and its slot holds a placeholder from
# `_dummy_iterate_result` instead of `nothing`. Placeholder values are never
# yielded, because the flags are checked first.
function Base.iterate(itr::MergeSortIterator, state=_start(itr))
    a_done, b_done, a_iter_result, b_iter_result = state
    a_done && b_done && return nothing
    # Yield from `a` when `b` is exhausted, or when `b`'s head is not strictly
    # smaller under `isless`. `isless` matches the order `sort!` produces
    # (NaN last, -0.0 before 0.0), where `<` would misplace NaN. Preferring
    # `a` on ties keeps the merge stable.
    if b_done || (!a_done && !isless(b_iter_result[1], a_iter_result[1]))
        item_a, state_a = a_iter_result
        next_a = iterate(itr.a, state_a)
        a_done = next_a === nothing
        # Test `=== nothing` directly so the compiler can narrow the type.
        a_slot = next_a === nothing ? _dummy_iterate_result(itr.a) : next_a
        return (item_a, (a_done, b_done, a_slot, b_iter_result))
    else
        item_b, state_b = b_iter_result
        next_b = iterate(itr.b, state_b)
        b_done = next_b === nothing
        b_slot = next_b === nothing ? _dummy_iterate_result(itr.b) : next_b
        return (item_b, (a_done, b_done, a_iter_result, b_slot))
    end
end