I’m trying to build an efficient iterator for a set of nested loops in which the inner loop ranges depend on the outer loop variables. So far I can’t get the iterator approach to perform comparably to the plain loop approach.
This is a minimal example of the kind of loop I’m interested in:

    function test_loop(n)
        k = 0
        for i in 1:n, j in i^2:i^3
            k += i + j
        end
        return k
    end

I would like to write it as:

    function test_iter(n)
        k = 0
        for (i, j) in MyIter(1:n)
            k += i + j
        end
        return k
    end

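For reference, the same sequence of `(i, j)` pairs can also be produced with a nested generator from Base alone. This is only a sketch of the target behaviour, not a claim about its speed (I have not benchmarked it); the name `test_gen` is made up for this example.

```julia
# Hypothetical reference version: a nested generator yields the same (i, j)
# pairs as the double loop, with the inner range depending on i.
function test_gen(n)
    k = 0
    for (i, j) in ((i, j) for i in 1:n for j in i^2:i^3)
        k += i + j
    end
    return k
end
```

It is useful mainly as a correctness check for any custom iterator, since it should return the same total as `test_loop`.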
This is what I’ve tried:

    struct MyIter
        irange::UnitRange{Int}
    end

    struct MyState
        i::Int                    # current outer value
        ist::Int                  # iteration state of irange
        jrange::UnitRange{Int}    # inner range for the current i
        j::Int                    # current inner value
        jst::Int                  # iteration state of jrange
    end

    const endstate = MyState(0, 0, 0:0, 0, 0) # Sentinel for type stability of state

    # Initial state: start iterating the outer range
    MyState(it::MyIter) = MyState(it, iterate(it.irange))
    # Outer range exhausted
    MyState(it::MyIter, ::Nothing) = endstate
    # New outer value: build its inner range and start iterating it
    MyState(it::MyIter, (i, ist)) = (jrange = i^2:i^3; MyState(it, i, ist, jrange, iterate(jrange)))
    # Inner range exhausted: advance the outer range
    MyState(it::MyIter, i, ist, jrange, ::Nothing) = MyState(it, iterate(it.irange, ist))
    # Inner value available: store it in the state
    MyState(it::MyIter, i, ist, jrange, (j, jst)) = MyState(i, ist, jrange, j, jst)

    nextstate(it::MyIter, s::MyState) = MyState(it, s.i, s.ist, s.jrange, iterate(s.jrange, s.jst))

    function Base.iterate(it::MyIter, s = MyState(it))
        s == endstate && return nothing
        newstate = nextstate(it, s)
        return (s.i, s.j), newstate
    end

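For comparison, here is a minimal sketch of the same iteration with a plain `(i, j)` tuple as the state and no sentinel or recursive constructors. `FlatIter` is a hypothetical name, and the `i^2:i^3` bounds are hard-coded to match the example above. I have not benchmarked it, so treat it as an alternative to try rather than a known fix.

```julia
# Hypothetical alternative: the state is just the next (i, j) pair to emit.
struct FlatIter
    irange::UnitRange{Int}
end

# The total length is not computed up front, so declare it unknown.
Base.IteratorSize(::Type{FlatIter}) = Base.SizeUnknown()
Base.eltype(::Type{FlatIter}) = Tuple{Int,Int}

function Base.iterate(it::FlatIter, state = (first(it.irange), first(it.irange)^2))
    i, j = state
    while i <= last(it.irange)
        # Emit (i, j) if j is still inside the inner range i^2:i^3.
        j <= i^3 && return (i, j), (i, j + 1)
        # Inner range exhausted (or empty): move on to the next outer value.
        i += 1
        j = i^2
    end
    return nothing
end
```

It drops in as `for (i, j) in FlatIter(1:n)`. The `while` loop also skips outer values whose inner range is empty, which cannot happen for `i >= 1` but would for negative `i`.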
This code is clear enough to me, and as far as I can tell it is perfectly type-stable, yet the iterator approach performs very poorly:

    julia> using BenchmarkTools

    julia> @btime test_loop(10)
      17.948 ns (0 allocations: 0 bytes)
    1000604

    julia> @btime test_iter(10)
      3.111 μs (0 allocations: 0 bytes)
    1000604

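To put these timings in perspective, it helps to know how many `(i, j)` pairs are visited for `n = 10`. The sketch below counts them and divides the measured times by that count; `count_pairs` is a hypothetical helper, and the timings are simply the numbers reported above.

```julia
# Hypothetical helper: number of (i, j) pairs visited for a given n >= 1.
count_pairs(n) = sum(length(i^2:i^3) for i in 1:n)

npairs = count_pairs(10)              # 2650 pairs for n = 10

# Per-pair cost implied by the timings reported above (in nanoseconds).
ns_per_pair_loop = 17.948 / npairs    # about 0.007 ns per pair
ns_per_pair_iter = 3111.0 / npairs    # about 1.17 ns per pair
```

About 0.007 ns per pair is far less than one clock cycle, which may mean the compiler has simplified the inner loop in `test_loop` instead of stepping through every `j`. That is an assumption on my part, not something I have verified.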
What am I doing wrong? What is a good, performant approach for this?