Unexpected Allocation with Iterator Extension

I am trying to implement an iterator that acts on heterogeneous types whose elements may or may not need to be iterated over themselves. Hence, this is a sort of nested iterator. However, I keep running into unexpected allocations when this is implemented. For context, I am trying to implement this for an expression graph which has mixed node types for efficiency.

Here I implement an MWE with a more contrived case using integer arrays:

# Wrapper around a vector whose elements are each either a single `Int`
# or a `Vector{Int}`.
struct Itr
    # The mixed scalar / vector elements.
    a::Vector{Union{Int, Vector{Int}}}
end

# Mutable state record with one slot per level of traversal: an outer
# vector plus its iteration state, and an inner `Vector{Int}` plus its
# iteration state.
mutable struct _ItrData
    # Outer vector of mixed scalar / vector elements.
    itr::Vector{Union{Int, Vector{Int}}}
    # Iteration state for `itr`.
    state::Int
    # Flag indicating whether `internal_itr` is currently in use.
    use_internal::Bool
    # Inner vector being traversed.
    internal_itr::Vector{Int}
    # Iteration state for `internal_itr`.
    internal_state::Int 
end

# Handle a scalar element: clear the `use_internal` flag on `state` and hand
# the integer back unchanged.
function _process_itr(raw::Int, state)
    # A scalar has no inner iteration to resume afterwards.
    setproperty!(state, :use_internal, false)
    return raw
end

# Handle a non-scalar element `raw`.
#
# If `raw` yields a first value, record `raw` and its iteration state in
# `state` and return that value. If `raw` is empty, advance `state.itr` and
# process the next outer element instead; return `nothing` when the outer
# collection is exhausted.
function _process_itr(raw, state)
    inner_step = iterate(raw)
    if inner_step !== nothing
        # Remember where we are inside `raw` so later calls can resume there.
        state.internal_itr = raw
        state.use_internal = true
        state.internal_state = inner_step[2]
        return inner_step[1]
    end

    # `raw` is empty: skip it and look at the next outer element, if any.
    outer_step = iterate(state.itr, state.state)
    outer_step === nothing && return nothing
    state.state = outer_step[2]
    return _process_itr(outer_step[1], state)
end

# Start iterating `itr`.
#
# Returns `nothing` when there is nothing to yield, otherwise a
# `(value, state)` tuple where `state` is a freshly created `_ItrData`.
function Base.iterate(itr::Itr)
    first_step = iterate(itr.a)
    first_step === nothing && return nothing

    # The inner-iteration fields start out as placeholders (`Int[]`, `0`).
    state = _ItrData(itr.a, first_step[2], false, Int[], 0)
    value = _process_itr(first_step[1], state)
    value === nothing && return nothing
    return (value, state)
end

# Continue iterating `itr` from `state`.
#
# While an inner vector is active its next value is returned; once it is
# exhausted (or if none is active) the outer collection is advanced.
# Returns `nothing` at the end, otherwise `(value, state)`.
function Base.iterate(itr::Itr, state)
    # Only consult the inner vector when one is currently being traversed.
    inner_step = state.use_internal ?
        iterate(state.internal_itr, state.internal_state) : nothing
    if inner_step !== nothing
        state.internal_state = inner_step[2]
        return (inner_step[1], state)
    end

    # No inner value available: move on to the next outer element.
    outer_step = iterate(state.itr, state.state)
    outer_step === nothing && return nothing
    state.state = outer_step[2]
    value = _process_itr(outer_step[1], state)
    return value === nothing ? nothing : (value, state)
end

# Iteration traits are declared on the *type*, as the iteration interface
# expects. Base's generic fallbacks (`IteratorSize(x) = IteratorSize(typeof(x))`
# and `eltype(x) = eltype(typeof(x))`) keep the instance forms working, while
# type-level queries (used e.g. by generators/comprehensions over an `Itr`)
# now see `SizeUnknown()` and `Int` instead of the defaults `HasLength()`
# and `Any`.
Base.IteratorSize(::Type{Itr}) = Base.SizeUnknown()
Base.eltype(::Type{Itr}) = Int

This works, but where I would expect it to allocate only once (to create an instance of _ItrData), it actually allocates a lot more (in proportion to the number of iterations):

# Walk through every element of `itr`, discarding each value, and return
# `nothing`. Written as the explicit `iterate` protocol loop.
function test(itr)
    step = iterate(itr)
    while step !== nothing
        # step[1] is the element (unused); step[2] is the next state.
        step = iterate(itr, step[2])
    end
    return nothing
end

# Build two example iterables: one mixing scalars with nested vectors,
# one containing scalars only
itr1 = Itr([1, [2, 3], 4, [5]])
itr2 = Itr([1, 2, 3, 4, 5])

# Check that both produce the same flattened sequence
collect(itr1) # returns [1, 2, 3, 4, 5] as expected
collect(itr2) # returns [1, 2, 3, 4, 5] as expected

# Measure allocations (trailing comments are the reported output)
@time test(itr1) # 0.000002 seconds (7 allocations: 288 bytes)
@time test(itr2) # 0.000002 seconds (7 allocations: 288 bytes)

I cannot figure out why the extra allocations are occurring, and I would greatly appreciate help figuring out why this is happening and how I can fix it.

I used @code_warntype to improve the type stability of the functions, and I simplified things a little more. Unfortunately, now it allocates even more…

# Wrapper around a vector whose elements are each either a single `Int`
# or a `Vector{Int}`.
struct Itr
    # The mixed scalar / vector elements.
    a::Vector{Union{Int, Vector{Int}}}
end

# Mutable state record holding an outer iteration state plus an inner
# `Vector{Int}` and its iteration state.
mutable struct _ItrData
    # Iteration state for the outer vector.
    state::Int
    # Flag indicating whether `internal_itr` is currently in use.
    use_internal::Bool
    # Inner vector being traversed.
    internal_itr::Vector{Int}
    # Iteration state for `internal_itr`.
    internal_state::Int 
end

# Handle an outer iteration result whose element is a scalar: clear the
# `use_internal` flag on `state` and return the element itself.
function _process_itr_arg(itr_out::Tuple{Int, Int}, state)
    # A scalar has no inner iteration to resume afterwards.
    setproperty!(state, :use_internal, false)
    return first(itr_out)
end

# Handle an outer iteration result whose element is an inner `Vector{Int}`.
#
# For a non-empty inner vector, record it and its iteration state in `state`
# and return its first value. For an empty inner vector, return `nothing`
# and leave `state` untouched, so the caller can skip to the next outer
# element (the `_process_raw(::Nothing, ...)` method handles that case).
# Previously an empty inner vector raised a `MethodError` from indexing into
# `nothing`.
function _process_itr_arg(itr_out::Tuple{Vector{Int}, Int}, state)
    inner = itr_out[1]
    out = iterate(inner)
    # `iterate` returns `nothing` for an empty vector; there is no value to
    # yield and no inner iteration to set up.
    out === nothing && return nothing
    state.internal_itr = inner
    state.use_internal = true
    state.internal_state = out[2]
    return out[1]
end

# No value was produced for the current outer element: clear the
# `use_internal` flag and continue with the next step of `itr`.
function _process_raw(raw::Nothing, itr, state)
    setproperty!(state, :use_internal, false)
    next_step = iterate(itr, state)
    return next_step
end

# A value was produced: package it with `state` as the `(value, state)`
# tuple. `itr` is not used by this method.
_process_raw(raw::Int, itr, state) = (raw, state)

# Start iterating `itr`.
#
# Returns `nothing` when `itr.a` is empty; otherwise creates a fresh
# `_ItrData` state and returns whatever `_process_raw` produces for the
# first outer element.
function Base.iterate(itr::Itr)
    first_step = iterate(itr.a)
    first_step === nothing && return nothing

    # The inner-iteration fields start out as placeholders (`Int[]`, `0`).
    state = _ItrData(first_step[2], false, Int[], 0)
    return _process_raw(_process_itr_arg(first_step, state), itr, state)
end

# Continue iterating `itr` from `state`.
#
# While an inner vector is active its next value is returned. Once it is
# exhausted the flag is cleared and the outer vector is advanced. Returns
# `nothing` when the outer vector is exhausted.
function Base.iterate(itr::Itr, state)
    if state.use_internal
        inner_step = iterate(state.internal_itr, state.internal_state)
        if inner_step !== nothing
            state.internal_state = inner_step[2]
            return (inner_step[1], state)
        end
        # Inner vector exhausted: fall through to the outer vector below.
        state.use_internal = false
    end

    outer_step = iterate(itr.a, state.state)
    outer_step === nothing && return nothing
    state.state = outer_step[2]
    return _process_raw(_process_itr_arg(outer_step, state), itr, state)
end

# Iteration traits are declared on the *type*, as the iteration interface
# expects. Base's generic fallbacks (`IteratorSize(x) = IteratorSize(typeof(x))`
# and `eltype(x) = eltype(typeof(x))`) keep the instance forms working, while
# type-level queries (used e.g. by generators/comprehensions over an `Itr`)
# now see `SizeUnknown()` and `Int` instead of the defaults `HasLength()`
# and `Any`.
Base.IteratorSize(::Type{Itr}) = Base.SizeUnknown()
Base.eltype(::Type{Itr}) = Int
# Walk through every element of `itr`, discarding each value, and return
# `nothing`. Written as the explicit `iterate` protocol loop.
function test(itr)
    step = iterate(itr)
    while step !== nothing
        # step[1] is the element (unused); step[2] is the next state.
        step = iterate(itr, step[2])
    end
    return nothing
end

# Build two example iterables: one mixing scalars with nested vectors,
# one containing scalars only
itr1 = Itr([1, [2, 3], 4, [5]])
itr2 = Itr([1, 2, 3, 4, 5])

# Check that both produce the same flattened sequence
collect(itr1) # returns [1, 2, 3, 4, 5] as expected
collect(itr2) # returns [1, 2, 3, 4, 5] as expected

# Measure allocations (trailing comments are the reported output)
@time test(itr1) # 0.000006 seconds (11 allocations: 416 bytes)
@time test(itr2) # 0.000006 seconds (12 allocations: 448 bytes)

I am not an expert on this, but since no one else has answered yet, I will take a shot at explaining what I think is the problem. Each call to iterate(itr::Itr) causes an allocation because it creates a new _ItrData object on the heap. It is heap-allocated because it is mutable and contains a mutable member. Meanwhile, the routine iterate(itr::Itr, state) on the second branch invokes _process_itr_arg which in turn invokes the one-argument version of iterate, which has an allocation. So this may be why you are seeing so many allocations. In most codes that I have written, the iteration state argument is immutable. Have you tried to implement your code with an immutable iteration state?

Note: an immutable object with a mutable member used to be heap-allocated up to Julia 1.4, but starting with 1.5 a change was made to allow such items to be stack-allocated. Perhaps someone more knowledgeable than me can explain what was changed.