Avoid creation of temporaries in non-trivial iterator

Consider the following Julia “compound” iterator: it merges two iterators, a and b, each of which is assumed to be sorted according to order, into a single ordered sequence:

struct MergeSorted{T,A,B,O}
    a::A
    b::B
    order::O
    MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O} =
        new{promote_type(eltype(A),eltype(B)),A,B,O}(a, b, order)
end

Base.eltype(::Type{MergeSorted{T,A,B,O}}) where {T,A,B,O} = T

function Base.iterate(self::MergeSorted{T}, 
                      state=(iterate(self.a), iterate(self.b))) where T
    a_result, b_result = state
    if b_result === nothing
        a_result === nothing && return nothing
        a_curr, a_state = a_result
        return T(a_curr), (iterate(self.a, a_state), b_result)
    end
    b_curr, b_state = b_result
    if a_result !== nothing
        a_curr, a_state = a_result
        Base.Order.lt(self.order, a_curr, b_curr) &&
            return T(a_curr), (iterate(self.a, a_state), b_result)
    end
    return T(b_curr), (a_result, iterate(self.b, b_state))
end

x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134])
sum(x); print(@allocated sum(x))

This code works, but creates a temporary in each iteration step. Running iterate(::MergeSorted) through @code_warntype reveals that the return type is Union{Nothing, Tuple{Int, Any}}, which means that Julia gives up on specializing on the return type.
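For readers who want to reproduce this kind of check outside the REPL, here is a minimal sketch using Base.return_types. The functions stable and nested are hypothetical toys, not part of MergeSorted; nested only mimics the pattern above of storing the result of the next iterate call inside the state.

```julia
# Hypothetical toy functions (not from the thread), one method each.

# The state slot is always a plain Int index.
stable(v::Vector{Int}, i::Int) = i > length(v) ? nothing : (v[i], i + 1)

# The state slot holds the *result* of the next `iterate` call,
# so it is either `nothing` or a tuple.
nested(v::Vector{Int}, i::Int) = i > length(v) ? nothing : (v[i], iterate(v, i + 1))

# Ask the compiler what it infers for each function.
rt_stable = only(Base.return_types(stable, (Vector{Int}, Int)))
rt_nested = only(Base.return_types(nested, (Vector{Int}, Int)))

println(rt_stable)
println(rt_nested)
```

The first printed type should be a small two-member Union; the second has a Union nested inside the tuple, which is the shape that gets harder for the compiler once states are fed back into iterate.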

I first posted this question on Stack Overflow, where user Bogumił Kamiński suggested a rather ugly workaround, which introduces a special state type to “force” Julia to specialize. We were wondering whether there is a better possibility, e.g., by forcing Julia to specialize on more complicated types?

Or maybe there is a smart way to rewrite this such that Julia can figure this out on its own? I was encouraged by the standard library’s Base.Iterators module: its flatten() routine does pretty much the same thing as MergeSorted(), except that it does not intersperse the elements, yet it does not create temporaries. Can anyone with more iterator-fu transfer that approach to this case?

This seems like a case for multiple dispatch: implement four methods for the four different possibilities (A + B, A + nothing, nothing + B, nothing + nothing) and then let the type system figure out the rest.

@josuagrw: Unfortunately, no. Here is the code for multiple dispatch, which does not help to avoid temporaries:

struct MergeSorted{T,A,B,O,S}
    a::A
    b::B
    order::O
    init::S

    function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
        T = promote_type(eltype(A), eltype(B))
        init = _make_state(iterate(a), iterate(b))
        new{T,A,B,O,typeof(init)}(a,b,order,init)
    end
end

Base.eltype(::Type{MergeSorted{T,A,B,O,S}}) where {T,A,B,O,S} = T

Base.iterate(self::MergeSorted) = iterate(self, self.init)

Base.iterate(self::MergeSorted{T}, ::Val{:done}) where T = nothing

function Base.iterate(self::MergeSorted{T}, state::Tuple{Val{:a_only},A}) where {T,A}
    _, (a_curr, a_state) = state
    # Only `a` remains: advance `a` and keep the `b` slot empty.
    return T(a_curr), _make_state(iterate(self.a, a_state), nothing)
end
    
function Base.iterate(self::MergeSorted{T}, state::Tuple{Val{:b_only},B}) where {T,B}
    _, (b_curr, b_state) = state
    return T(b_curr), _make_state(nothing, iterate(self.b, b_state))
end

function Base.iterate(self::MergeSorted{T}, state::Tuple{Val{:both},A,B}) where {T,A,B}
    _, (a_curr, a_state), (b_curr, b_state) = state
    if Base.Order.lt(self.order, a_curr, b_curr)
        return T(a_curr), _make_state(iterate(self.a, a_state), (b_curr, b_state))
    else
        return T(b_curr), _make_state((a_curr, a_state), iterate(self.b, b_state))
    end        
end

_make_state(a_result::Nothing, b_result::Nothing) = Val(:done)
_make_state(a_result::Nothing, b_result) = Val(:b_only), b_result
_make_state(a_result, b_result::Nothing) = Val(:a_only), a_result
_make_state(a_result, b_result) = Val(:both), a_result, b_result

What an interesting little problem!

Here’s a multiple dispatch + look-ahead approach that is inferable and free of allocations on Julia v1.6+:

struct MergeSorted{T,A,B,O}
    a::A
    b::B
    order::O
    MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O} =
        new{promote_type(eltype(A),eltype(B)),A,B,O}(a, b, order)
end

Base.eltype(::Type{MergeSorted{T,A,B,O}}) where {T,A,B,O} = T

Base.iterate(self::MergeSorted, (sa, sb) = ((), ())) =
    _iterate(self, sa, sb, iterate(self.a, sa...), iterate(self.b, sb...))

_iterate(self::MergeSorted{T}, sa, sb, ::Nothing, ::Nothing) where {T} = nothing
_iterate(self::MergeSorted{T}, sa, sb, (ra, sa´), ::Nothing) where {T} = T(ra), (sa´, sb)
_iterate(self::MergeSorted{T}, sa, sb, ::Nothing, (rb, sb´)) where {T} = T(rb), (sa, sb´)
_iterate(self::MergeSorted{T}, sa, sb, (ra, sa´), (rb, sb´)) where {T} =
    Base.Order.lt(self.order, ra, rb) ? (T(ra), (sa´, sb)) : (T(rb), (sa, sb´))

x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134])
sum(x); print(@allocated sum(x))

EDIT: it may still need some small tweak to avoid allocations also in the case of empty collections

EDIT2: Forget the above, one just needs to properly cast to eltype T. Fixed.


Interesting solution, but does it require run-time dispatch to advance the iterator?

No, it just avoids type-unstable states (sa, sb). In the x example here it is always Tuple{Int,Int}.

@lekand Thanks for the neat and tidy solution!

Is there any way to understand why the added look-ahead is required here to avoid temporaries?

The pre-1.0 iteration protocol of start/next/done was type-stable, so your problem would have had a straightforward type-stable non-allocating solution in the old protocol. The newer iterate protocol is predicated on the assumption that the compiler is able to efficiently implement certain type-unstable operations, in particular, containers of objects of type Union{T,Nothing}. So solving your problem with efficient code in the iterate protocol requires reliance on aggressive compiler optimization, which is apparently not available yet. At least, I don’t see a solution.
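As a rough illustration of why the old three-function style is type-stable, here is a sketch of the merge written that way. The names MergeVecs, mstart, mdone, mnext and collect_merged are hypothetical (the real pre-1.0 functions were start, done and next), and the sketch assumes both inputs are AbstractVectors so that the state can be a plain pair of indices.

```julia
# Hypothetical sketch in the spirit of the pre-1.0 start/done/next protocol.
# Assumption: both inputs are vectors, so the state is a pair of indices.
struct MergeVecs{A<:AbstractVector,B<:AbstractVector}
    a::A
    b::B
end

# The state always has the same concrete type: a tuple of two indices.
mstart(m::MergeVecs) = (firstindex(m.a), firstindex(m.b))
mdone(m::MergeVecs, (i, j)) = i > lastindex(m.a) && j > lastindex(m.b)

# Only called while `mdone` is false, so at least one side has elements left.
function mnext(m::MergeVecs, (i, j))
    # Take from `a` when `b` is exhausted or when a[i] is strictly smaller.
    if j > lastindex(m.b) || (i <= lastindex(m.a) && isless(m.a[i], m.b[j]))
        return m.a[i], (i + 1, j)
    else
        return m.b[j], (i, j + 1)
    end
end

function collect_merged(m::MergeVecs)
    out = Vector{promote_type(eltype(m.a), eltype(m.b))}()
    state = mstart(m)
    while !mdone(m, state)
        x, state = mnext(m, state)
        push!(out, x)
    end
    return out
end

merged = collect_merged(MergeVecs([1,4,5,9,32,44], [0,7,9,24,134]))
println(merged)
```

Because "are we done?" is answered by a separate function, no slot of the state ever needs to hold nothing. This sketch does not generalize to arbitrary iterators the way MergeSorted does.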

Tangentially, there is an interesting blog post about the state of the iterator protocol: “Iterate on it” by Mike Innes.

What do you mean? The above solution is inferred fine using union splitting. You just need to remain within the bounds of what unions can be split by the optimizer.

In your proposed solution, on each iteration, one of the four versions of _iterate is called. The compiler does not know in advance which version is invoked since the choice is data dependent. Therefore, the running code must choose at run-time on each iteration which one to call, correct? This leads to some loss of performance.

Ah, but the key is that the return type of all _iterate functions is the same, regardless of which one is run. Well, to be precise, it is Union{Nothing, Tuple{Int,Int}}, so there are just two possible return types, but that is precisely what union splitting is designed to optimize (the two different type branches are precompiled, runtime dispatch is avoided, and you have essentially no runtime penalty).

The difference with the OP’s approach is that there, iterate returned a much more complex Union that shows up in red in @code_warntype, instead of yellow, signaling that you’re paying the price of runtime dispatch in that case. And the underlying reason is that the state of iterate is itself a Union.

So what I learned from this nice exercise is that, to keep the compiler happy when using the post-1.0 iteration protocol, you should construct your iterators so that iterate returns a Union{Nothing,T} for some concrete type T and nothing much more complex than that, i.e. make your iteration states type-stable.
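To make the Union{Nothing,T} rule a bit more concrete, here is roughly what a for loop over an iterator expands to, written out by hand. The function name manual_sum is hypothetical; the point is only that the single check against nothing is where the two-member Union gets split, after which the tuple branch works with concrete types.

```julia
# Hand-written equivalent of `for x in itr; total += x; end`.
# `manual_sum` is a hypothetical helper, used here only for illustration.
function manual_sum(itr)
    total = zero(eltype(itr))
    next = iterate(itr)              # Union{Nothing, Tuple{element, state}}
    while next !== nothing           # the branch that splits the Union
        x, state = next              # here `next` is known to be a tuple
        total += x
        next = iterate(itr, state)   # same Union again if `state` is type-stable
    end
    return total
end

println(manual_sum([1, 4, 5, 9, 32, 44]))
```

If the type of state changed from one step to the next, the type of next would grow beyond a simple two-member Union, and this branch alone would no longer be enough for the compiler to recover concrete types.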


Hi @mwallerb, maybe see the post above. The idea is "make iterate always return the same concrete type T unless it has finished its job, in which case return nothing". So the proposed solution just returns a tuple with the states of the iterators a and b. When either of them is finished (no new state is returned), instead of putting a nothing into the tuple, you keep the last state. That way your state always remains Tuple{Int,Int}, instead of Tuple{Union{Nothing,Int}, Union{Nothing,Int}}, which is too complex for the optimizer to handle at this point.

I see, this is interesting! So the compiled code has all four versions of _iterate in-lined at the call site? Is there any way for the user to confirm this, short of poring over the machine code?

I like that @code_warntype does this kind of interpretation for you with its color coding (yellow = union splitting, red = runtime dispatch), but perhaps there are other, lower-level ways to see it directly. I’d also be interested to know!
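One lower-level option, offered as a sketch rather than a definitive recipe, is to look at the optimized typed IR with Base.code_typed. The functions pick and total_first are hypothetical toys that imitate the _iterate pattern of dispatching on Nothing versus a tuple. In the printed IR, statically resolved calls are typically inlined or shown as invoke, while a remaining generic call to pick would hint at runtime dispatch.

```julia
# Hypothetical toy: two methods selected by whether the look-ahead is `nothing`.
pick(::Nothing) = 0
pick((x, _)::Tuple{Int,Int}) = x

# Each `iterate` call returns Union{Nothing, Tuple{Int,Int}} for a Vector{Int}.
total_first(v::Vector{Int}) = pick(iterate(v)) + pick(iterate(v, 2))

# `code_typed` returns `CodeInfo => return type` pairs, one per matching method.
ci, rt = only(Base.code_typed(total_first, (Vector{Int},)))

println(ci)   # inspect the IR for leftover generic calls to `pick`
println(rt)   # inferred return type
```

The same inspection can be done interactively with @code_typed or @code_llvm from InteractiveUtils, which are loaded by default in the REPL.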
