Cyclic pair iterator?

Hi,

I am not very familiar with iterators, so maybe someone can give me some hints.

I would like to build a consecutive-pairs iterator that takes a collection C and returns an iterator over its consecutive pairs, where the last pair wraps around to the first element:

# Desired usage sketch from the question: `cpair` is not defined yet. For this input
# the loop should print (3,4), (4,2), (2,1) and finally (1,3), wrapping back to the
# first element.
a=[3,4,2,1] # could be an iterator too
cp=cpair(a) #to be defined
for p in cp
   println(p)
end

should print

(3,4)
(4,2)
(2,1)
(1,3)

with the best possible performance (as few allocations as possible).

Thank you for your help

You can look at the docs and define the methods required for iteration; see the "Interfaces" page of the Julia manual (Interfaces · The Julia Language).

Here is one implementation:

"""
    PairIterator(v::Vector)

Iterate over the cyclic consecutive pairs of `v`: `(v[1], v[2])`, …,
`(v[end-1], v[end])`, and finally the closing pair `(v[end], v[1])`.
An empty vector yields no pairs; a one-element vector yields `(v[1], v[1])`.
"""
struct PairIterator{T}
    it::Vector{T}
end

# Required method
# The state `i` is the index of the first element of the next pair.
function Base.iterate(p::PairIterator, i::Int=1)
    l = length(p.it)
    i > l && return nothing
    # Wrap the last index back to 1 with a comparison rather than `mod`,
    # which avoids an integer division on every step.
    j = i == l ? 1 : i + 1
    return (p.it[i], p.it[j]), i + 1
end

# Important optional methods
Base.eltype(::Type{PairIterator{T}}) where {T} = Tuple{T,T}
Base.length(p::PairIterator) = length(p.it)

Example:

julia> for i in PairIterator([3,4,2,1])
           println(i)
       end
(3, 4)
(4, 2)
(2, 1)
(1, 3)
1 Like

Thank you very much !

Your implementation only accepts a `Vector`, so it does not handle general iterator input.
I added @ffevotte's `ConsecutivePairs` proposal to the following MWE.
The problem is that the `sumdist` function is 3–4 times slower with the pair iterators than with the loop-based `sumdist_noiter`.

On my machine I obtain:

("gflops 0", 3.2064275973906313)  # sumdist_noiter
("gflops 1", 0.8656447249774571)  # sum dist with PairIterator
("gflops 2", 0.9534782090500973)  # sum_dist with ConsecutivePairs

I wonder whether it is possible to improve the iterator-based implementation (which is prettier) so that it recovers the performance of the loop-based one…

Here is the MWE:

using BenchmarkTools
#########################
"""
    ConsecutivePairs(itr)

Lazily iterate over the overlapping consecutive pairs of `itr`:
`(x1, x2), (x2, x3), …, (x(n-1), xn)`. An iterator with fewer than two
elements yields no pairs.
"""
struct ConsecutivePairs{T}
    subiter::T
end

# The number of pairs is not tracked, so declare the size unknown. Otherwise the
# default `HasLength()` trait makes `collect`/`map` call a `length` that doesn't exist.
Base.IteratorSize(::Type{<:ConsecutivePairs}) = Base.SizeUnknown()
# Let `collect` infer the element type from the produced pairs instead of using `Any`.
Base.IteratorEltype(::Type{<:ConsecutivePairs}) = Base.EltypeUnknown()

@inline function Base.iterate(cp::ConsecutivePairs)
    item = iterate(cp.subiter)
    item === nothing && return nothing
    prev, substate = item
    return iterate(cp, (substate, prev))
end

# state = (sub-iterator state, element that starts the next pair)
@inline function Base.iterate(cp::ConsecutivePairs, state)
    substate, prev = state
    item = iterate(cp.subiter, substate)
    item === nothing && return nothing
    current, substate = item
    return ((prev, current), (substate, current))
end

# The first element is wrapped in a 1-tuple because `Iterators.flatten` iterates
# *into* every component. A bare element would be spliced in piecewise whenever it
# is itself iterable (tuple, array, string, …); numbers only worked because they
# iterate as themselves.
"""
    cyclic_pairs(it)

Consecutive pairs of `it` followed by the closing pair `(last, first)`, obtained by
appending the first element of `it` and pairing up the result. A single element `x`
yields `(x, x)`. Errors for an empty `it` (via `first`).
"""
cyclic_pairs(it) = ConsecutivePairs(Iterators.flatten((it, (first(it),))))
################################
"""
    PairIterator(v::Vector)

Iterate over the cyclic consecutive pairs of `v`: `(v[1], v[2])`, …,
`(v[end-1], v[end])`, and finally the closing pair `(v[end], v[1])`.
An empty vector yields no pairs; a one-element vector yields `(v[1], v[1])`.
"""
struct PairIterator{T}
    it::Vector{T}
end

# Required method
# The state `i` is the index of the first element of the next pair.
function Base.iterate(p::PairIterator, i::Int=1)
    l = length(p.it)
    i > l && return nothing
    # Wrap the last index back to 1 with a comparison rather than `mod`,
    # which avoids an integer division on every step.
    j = i == l ? 1 : i + 1
    return (p.it[i], p.it[j]), i + 1
end

# Important optional methods
Base.eltype(::Type{PairIterator{T}}) where {T} = Tuple{T,T}
Base.length(p::PairIterator) = length(p.it)
##############################
"""
    sumdist(cp)

Add up `(last(p) - first(p))^2` over every pair `p` yielded by `cp`.
The accumulator starts at the zero of the type of the first pair's first
component, so `cp` must yield at least one pair.
"""
function sumdist(cp)
    T = eltype(first(first(cp)))
    total = zero(T)
    for pair in cp
        total += (last(pair) - first(pair))^2
    end
    return total
end
"""
    sumdist_noiter(a)

Sum of squared differences between cyclically consecutive elements of `a`,
including the closing pair formed by the last and the first element. It is
written as a plain indexed loop to serve as a performance baseline. Errors for
an empty `a` (via `first`).
"""
function sumdist_noiter(a)
    s = zero(eltype(a))
    v1 = first(a)
    # Use the collection's own index range. A hard-coded `2:length(a)` under
    # `@inbounds` would read out of bounds for arrays that are not 1-based.
    @inbounds for i in firstindex(a)+1:lastindex(a)
        d = a[i] - v1
        s += d^2
        v1 = a[i]
    end
    # close the cycle: last element back to the first
    d = v1 - first(a)
    s += d^2
    return s
end


# Benchmark driver: checks that the pair-iterator implementations and the plain loop
# agree, then times each one with BenchmarkTools and prints an approximate GFLOP/s rate.
function test()

    a=rand(1000)
    # All three visit the same squared differences in the same order, so the Float64
    # sums are compared with exact `==`.
    @show sumdist(PairIterator(a))==sumdist(cyclic_pairs(a))==sumdist_noiter(a)
    # `$a` interpolates the vector into the benchmarked expression (BenchmarkTools idiom).
    t0=@benchmark sumdist_noiter($a)
    t1=@benchmark sumdist(PairIterator($a))
    t2=@benchmark sumdist(cyclic_pairs($a))
    println(t1)
    @show t2
    # 3 floating-point operations per pair (subtract, square, accumulate). BenchmarkTools
    # times are in nanoseconds, so flops per ns equals GFLOP/s.
    @show "gflops 0",3*length(a)/(minimum(t0).time)
    @show "gflops 1",3*length(a)/(minimum(t1).time)
    @show "gflops 2",3*length(a)/(minimum(t2).time)

    # `@btime` returns the value of the benchmarked expression, so t3/t4 hold sums,
    # not timings.
    # NOTE(review): "funny compiler trick" presumably means the compiler optimizes the
    # loop over a literal range aggressively; this timing is likely not comparable.
    @show t3=@btime sumdist_noiter(1:1000) #funny compiler trick
    # PairIterator stores a `Vector` field, so it cannot wrap a range like 1:1000.
    # @show @btime sumdist(PairIterator(1:1000)) # does not work with iterator
    @show t4=@btime sumdist(cyclic_pairs(1:1000))

end

# Run the benchmarks when the script is executed.
test()
1 Like

In the version based on cyclic_pairs, it looks like the performance is lost when composing the ConsecutivePairs iterator with Iterators.flatten in order to loop back to the first element.

Implementing the whole logic (iterating on pairs + looping back to the first element) in one iterator gets the for loop performance back:

"""
    CyclicPairs(itr)

Lazily iterate over the consecutive pairs of `itr`, closing the cycle with a final
pair `(last, first)`. For `[3, 4, 2, 1]` this yields `(3, 4)`, `(4, 2)`, `(2, 1)`,
`(1, 3)`. An empty `itr` yields no pairs; a single element `x` yields `(x, x)`.
"""
struct CyclicPairs{T}
    subiter::T
end

@inline function Base.iterate(cp::CyclicPairs)
    item = iterate(cp.subiter)
    item === nothing && return nothing
    head, substate = item
    return iterate(cp, (substate, head, head, #=finished=#false))
end

# state = (sub-iterator state, previous element, first element, closing pair emitted?)
@inline function Base.iterate(cp::CyclicPairs, state)
    substate, latest, head, finished = state
    # Stop before touching `subiter` again. Once it has returned `nothing`, calling
    # `iterate` on it a second time is avoided, since stateful sources may not
    # stay exhausted.
    finished && return nothing
    item = iterate(cp.subiter, substate)
    if item === nothing
        # sub-iterator exhausted: emit the closing pair (last, first) exactly once
        return ((latest, head), (substate, latest, head, #=finished=#true))
    end
    current, substate = item
    return ((latest, current), (substate, current, head, #=finished=#false))
end

# One pair per element, so size information follows the wrapped iterator. A shaped
# source (e.g. an array) becomes `HasLength`, because the pairs form a flat sequence.
function Base.IteratorSize(::Type{CyclicPairs{T}}) where {T}
    s = Base.IteratorSize(T)
    return s isa Base.HasShape ? Base.HasLength() : s
end
Base.IteratorEltype(::Type{CyclicPairs{T}}) where {T} = Base.IteratorEltype(T)
Base.eltype(::Type{CyclicPairs{T}}) where {T} = Tuple{eltype(T),eltype(T)}
@inline Base.length(cp::CyclicPairs) = length(cp.subiter)

Here are the benchmarked times on my machine (with a very slight degradation in performance, probably due to the allocation of the iterator itself):

julia> @btime sumdist_noiter($a)
  782.621 ns (0 allocations: 0 bytes)
166.4520614732102

julia> @btime sumdist(CyclicPairs($a))
  797.495 ns (1 allocation: 16 bytes)
166.4520614732102

Or, re-using your own benchmarks:

("gflops 0", (3 * length(a)) / (minimum(t0)).time) = ("gflops 0", 3.8318927566066048)
("gflops 1", (3 * length(a)) / (minimum(t1)).time) = ("gflops 1", 0.5404107121412274)
("gflops 2", (3 * length(a)) / (minimum(t2)).time) = ("gflops 2", 0.7005137100540396)
("gflops 3", (3 * length(a)) / (minimum(t3)).time) = ("gflops 3", 3.7575645707806506)


Complete code
using BenchmarkTools


#########################


"""
    ConsecutivePairs(itr)

Lazily iterate over the overlapping consecutive pairs of `itr`:
`(x1, x2), (x2, x3), …, (x(n-1), xn)`. An iterator with fewer than two
elements yields no pairs.
"""
struct ConsecutivePairs{T}
    subiter::T
end

# The number of pairs is not tracked, so declare the size unknown. Otherwise the
# default `HasLength()` trait makes `collect`/`map` call a `length` that doesn't exist.
Base.IteratorSize(::Type{<:ConsecutivePairs}) = Base.SizeUnknown()
# Let `collect` infer the element type from the produced pairs instead of using `Any`.
Base.IteratorEltype(::Type{<:ConsecutivePairs}) = Base.EltypeUnknown()

@inline function Base.iterate(cp::ConsecutivePairs)
    item = iterate(cp.subiter)
    item === nothing && return nothing
    prev, substate = item
    return iterate(cp, (substate, prev))
end

# state = (sub-iterator state, element that starts the next pair)
@inline function Base.iterate(cp::ConsecutivePairs, state)
    substate, prev = state
    item = iterate(cp.subiter, substate)
    item === nothing && return nothing
    current, substate = item
    return ((prev, current), (substate, current))
end

# The first element is wrapped in a 1-tuple because `Iterators.flatten` iterates
# *into* every component. A bare element would be spliced in piecewise whenever it
# is itself iterable (tuple, array, string, …); numbers only worked because they
# iterate as themselves.
"""
    cyclic_pairs(it)

Consecutive pairs of `it` followed by the closing pair `(last, first)`, obtained by
appending the first element of `it` and pairing up the result. A single element `x`
yields `(x, x)`. Errors for an empty `it` (via `first`).
"""
@inline cyclic_pairs(it) = ConsecutivePairs(Iterators.flatten((it, (first(it),))))


#######################


"""
    CyclicPairs(itr)

Lazily iterate over the consecutive pairs of `itr`, closing the cycle with a final
pair `(last, first)`. For `[3, 4, 2, 1]` this yields `(3, 4)`, `(4, 2)`, `(2, 1)`,
`(1, 3)`. An empty `itr` yields no pairs; a single element `x` yields `(x, x)`.
"""
struct CyclicPairs{T}
    subiter::T
end

@inline function Base.iterate(cp::CyclicPairs)
    item = iterate(cp.subiter)
    item === nothing && return nothing
    head, substate = item
    return iterate(cp, (substate, head, head, #=finished=#false))
end

# state = (sub-iterator state, previous element, first element, closing pair emitted?)
@inline function Base.iterate(cp::CyclicPairs, state)
    substate, latest, head, finished = state
    # Stop before touching `subiter` again. Once it has returned `nothing`, calling
    # `iterate` on it a second time is avoided, since stateful sources may not
    # stay exhausted.
    finished && return nothing
    item = iterate(cp.subiter, substate)
    if item === nothing
        # sub-iterator exhausted: emit the closing pair (last, first) exactly once
        return ((latest, head), (substate, latest, head, #=finished=#true))
    end
    current, substate = item
    return ((latest, current), (substate, current, head, #=finished=#false))
end

# One pair per element, so size information follows the wrapped iterator. A shaped
# source (e.g. an array) becomes `HasLength`, because the pairs form a flat sequence.
function Base.IteratorSize(::Type{CyclicPairs{T}}) where {T}
    s = Base.IteratorSize(T)
    return s isa Base.HasShape ? Base.HasLength() : s
end
Base.IteratorEltype(::Type{CyclicPairs{T}}) where {T} = Base.IteratorEltype(T)
Base.eltype(::Type{CyclicPairs{T}}) where {T} = Tuple{eltype(T),eltype(T)}
@inline Base.length(cp::CyclicPairs) = length(cp.subiter)


################################


"""
    PairIterator(v::Vector)

Iterate over the cyclic consecutive pairs of `v`: `(v[1], v[2])`, …,
`(v[end-1], v[end])`, and finally the closing pair `(v[end], v[1])`.
An empty vector yields no pairs; a one-element vector yields `(v[1], v[1])`.
"""
struct PairIterator{T}
    it::Vector{T}
end

# Required method
# The state `i` is the index of the first element of the next pair.
function Base.iterate(p::PairIterator, i::Int=1)
    l = length(p.it)
    i > l && return nothing
    # Wrap the last index back to 1 with a comparison rather than `mod`,
    # which avoids an integer division on every step.
    j = i == l ? 1 : i + 1
    return (p.it[i], p.it[j]), i + 1
end

# Important optional methods
Base.eltype(::Type{PairIterator{T}}) where {T} = Tuple{T,T}
Base.length(p::PairIterator) = length(p.it)


##############################


"""
    sumdist(cp)

Add up `(last(p) - first(p))^2` over every pair `p` yielded by `cp`.
The accumulator starts at the zero of the type of the first pair's first
component, so `cp` must yield at least one pair.
"""
function sumdist(cp)
    T = eltype(first(first(cp)))
    total = zero(T)
    for pair in cp
        total += (last(pair) - first(pair))^2
    end
    return total
end
"""
    sumdist_noiter(a)

Sum of squared differences between cyclically consecutive elements of `a`,
including the closing pair formed by the last and the first element. It is
written as a plain indexed loop to serve as a performance baseline. Errors for
an empty `a` (via `first`).
"""
function sumdist_noiter(a)
    s = zero(eltype(a))
    v1 = first(a)
    # Use the collection's own index range. A hard-coded `2:length(a)` under
    # `@inbounds` would read out of bounds for arrays that are not 1-based.
    @inbounds for i in firstindex(a)+1:lastindex(a)
        d = a[i] - v1
        s += d^2
        v1 = a[i]
    end
    # close the cycle: last element back to the first
    d = v1 - first(a)
    s += d^2
    return s
end


# Benchmark driver: checks that all four implementations agree, then times each one
# with BenchmarkTools and prints an approximate GFLOP/s rate.
function test()
    # 10^8 Float64 values, i.e. about 800 MB.
    a=rand(100_000_000);
    # All implementations visit the same squared differences in the same order, so the
    # Float64 sums are compared with exact `==`.
    @show sumdist(PairIterator(a))==sumdist(cyclic_pairs(a))==sumdist_noiter(a)==sumdist(CyclicPairs(a))
    # `$a` interpolates the vector into the benchmarked expression (BenchmarkTools idiom).
    t0=@benchmark sumdist_noiter($a)
    t1=@benchmark sumdist(PairIterator($a))
    t2=@benchmark sumdist(cyclic_pairs($a))
    t3=@benchmark sumdist(CyclicPairs($a))
    println(t1)
    @show t2
    # 3 floating-point operations per pair (subtract, square, accumulate). BenchmarkTools
    # times are in nanoseconds, so flops per ns equals GFLOP/s.
    @show "gflops 0",3*length(a)/(minimum(t0).time)
    @show "gflops 1",3*length(a)/(minimum(t1).time)
    @show "gflops 2",3*length(a)/(minimum(t2).time)
    @show "gflops 3",3*length(a)/(minimum(t3).time)

    # `@btime` returns the value of the benchmarked expression, so t4/t5 hold sums,
    # not timings.
    # NOTE(review): "funny compiler trick" presumably means the compiler optimizes the
    # loop over a literal range aggressively; this timing is likely not comparable.
    @show t4=@btime sumdist_noiter(1:1000) #funny compiler trick
    # PairIterator stores a `Vector` field, so it cannot wrap a range like 1:1000.
    # @show @btime sumdist(PairIterator(1:1000)) # does not work with iterator
    @show t5=@btime sumdist(cyclic_pairs(1:1000))
end

# Run the benchmarks when the script is executed.
test()
3 Likes

Great !!!
Thank you very much.