I am trying to implement an iterator over heterogeneous elements, some of which must themselves be iterated over, so this is a sort of nested iterator. However, I keep running into unexpected allocations with this implementation. For context, I am trying to implement this for an expression graph that uses mixed node types for efficiency.
Here is an MWE (minimal working example) using a more contrived case with integer arrays:
"""
    Itr(a)

Flattening iterator over `a`, whose entries are either single `Int`s or
`Vector{Int}`s; iterating an `Itr` yields the integers in order.
"""
struct Itr
    a::Vector{Union{Vector{Int}, Int}}  # same type as Vector{Union{Int, Vector{Int}}}
end
# Mutable iteration state shared across `iterate` calls on an `Itr`.
mutable struct _ItrData
    itr::Vector{Union{Vector{Int}, Int}}  # outer collection being traversed
    state::Int                            # state for iterating `itr`
    use_internal::Bool                    # true while inside an inner vector
    internal_itr::Vector{Int}             # the inner vector currently being traversed
    internal_state::Int                   # state for iterating `internal_itr`
end
# A plain `Int` is emitted as-is; clear the flag so the next call reads the outer collection.
_process_itr(raw::Int, state) = (state.use_internal = false; raw)
# Turn an outer element `raw` that is itself iterable into the next element to emit,
# updating `state` as a side effect.
#
# If `raw` is non-empty, it becomes the current inner iterator (`state.internal_itr`)
# and its first element is returned. If it is empty, the outer iterator `state.itr`
# is advanced until a non-empty collection or an `Int` is found. Returns `nothing`
# when the outer iterator is exhausted.
#
# This is written as a loop rather than self-recursion. A long run of empty inner
# vectors therefore cannot grow the call stack, and the return type is inferred
# without going through a recursive inference cycle.
function _process_itr(raw, state)
    while true
        out = iterate(raw)
        if out !== nothing
            # Non-empty inner collection: remember it so later `iterate` calls
            # continue from here.
            state.internal_itr = raw
            state.use_internal = true
            state.internal_state = out[2]
            return out[1]
        end
        # `raw` was empty: move on to the next outer element.
        new_out = iterate(state.itr, state.state)
        new_out === nothing && return nothing
        state.state = new_out[2]
        next_raw = new_out[1]
        # Plain `Int` elements are handled by the `Int` method of `_process_itr`.
        next_raw isa Int && return _process_itr(next_raw, state)
        raw = next_raw
    end
end
# Start iterating an `Itr`: return `(first_element, state)`, or `nothing` when there is
# no element to yield.
#
# `@inline` is presumably what removes the per-iteration allocation. The returned
# `Union{Nothing, Tuple{Int, _ItrData}}` holds a reference to a mutable object, so it
# is not an isbits union. Returning it from a non-inlined call boxes the tuple on the
# heap. When the method is inlined into the caller's loop, the tuple never has to be
# materialized.
@inline function Base.iterate(itr::Itr)
    out = iterate(itr.a)
    out === nothing && return nothing
    # `internal_itr` starts as an empty placeholder; it is only read once
    # `use_internal` has been set to true.
    state = _ItrData(itr.a, out[2], false, Int[], 0)
    raw = _process_itr(out[1], state)
    return raw === nothing ? nothing : (raw, state)
end
# Advance an `Itr` iteration. First continue the current inner vector (if any), then
# move on to the next outer element. Returns `(element, state)`, or `nothing` when done.
# The mutable `state` is updated in place and returned again.
#
# `@inline`: the returned `Union{Nothing, Tuple{Int, _ItrData}}` contains a reference
# to a mutable object, so it is not an isbits union. Returning it from a non-inlined
# call boxes the tuple on the heap, presumably once per iteration. Inlining lets the
# compiler avoid materializing it.
@inline function Base.iterate(itr::Itr, state::_ItrData)
    if state.use_internal
        int_out = iterate(state.internal_itr, state.internal_state)
        if int_out !== nothing
            state.internal_state = int_out[2]
            return int_out[1], state
        end
    end
    out = iterate(state.itr, state.state)
    out === nothing && return nothing
    state.state = out[2]
    raw = _process_itr(out[1], state)
    return raw === nothing ? nothing : (raw, state)
end
# Iteration traits are defined on the type, as the iteration interface expects.
# Instance calls such as `Base.IteratorSize(itr)` and `eltype(itr)` still work through
# Base's fallbacks. Type-based queries such as `Base.IteratorSize(Itr)` no longer
# return the default `HasLength()`/`Any`.
Base.IteratorSize(::Type{Itr}) = Base.SizeUnknown()  # the flattened length is not stored
Base.eltype(::Type{Itr}) = Int
This works, but whereas I would expect it to allocate only once (to create the `_ItrData` instance), it actually allocates much more, in proportion to the number of iterations:
# Consume every element of `itr` and discard it, so that timing this function measures
# only the cost of iteration. Always returns `nothing`.
function test(itr)
    foreach(_ -> nothing, itr)
    return nothing
end
# Make example iterable objects
itr1 = Itr([1, [2, 3], 4, [5]])
itr2 = Itr([1, 2, 3, 4, 5])
# Test them
collect(itr1) # returns [1, 2, 3, 4, 5] as expected
collect(itr2) # returns [1, 2, 3, 4, 5] as expected
# Warm up `test` so the timings below do not include compilation
test(itr1)
test(itr2)
# Test allocations
@time test(itr1) # reported: 0.000002 seconds (7 allocations: 288 bytes)
@time test(itr2) # reported: 0.000002 seconds (7 allocations: 288 bytes)
I cannot figure out why the extra allocations are occurring, and I would greatly appreciate help understanding why they happen and how to fix them.