Cyclic pair iterator?

#1

Hi,

I am not very familiar with iterators, so maybe someone can give me some hints.

I would like to build a consecutive-pairs iterator that takes a collection C and returns an iterator over its consecutive pairs, where the last pair wraps around and ends with the first element:

# Desired usage: loop over the consecutive pairs of `a`, wrapping around at the end.
a=[3,4,2,1] # could be an iterator too
cp=cpair(a) # to be defined: `cpair` is the constructor this question asks for
for p in cp
   # each `p` is expected to be a 2-tuple of neighbouring elements, e.g. (3,4)
   println(p)
end

would return

(3,4)
(4,2)
(2,1)
(1,3)

with the best possible performance (as few allocations as possible).

Thank you for your help

#2

You can just look at the docs and define methods required for iteration, see https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-iteration-1

Here is one implementation:

"""
    PairIterator(it::Vector)

Iterate over the cyclic consecutive pairs of the vector `it`:
`(it[1], it[2])`, `(it[2], it[3])`, …, `(it[end], it[1])`.

An empty vector yields no pairs; a one-element vector yields the single pair
`(it[1], it[1])`.
"""
struct PairIterator{T}
    it::Vector{T}
end

# Required method: the iteration state `i` is the index of the pair's first element.
function Base.iterate(p::PairIterator, i::Int=1)
    l = length(p.it)
    i > l && return nothing
    # Wrap around with a comparison instead of `mod(i, l) + 1`: the result is the
    # same for every valid state `1 <= i <= l`, without an integer division per step.
    j = i == l ? 1 : i + 1
    return (p.it[i], p.it[j]), i + 1
end

# Important optional methods
Base.eltype(::Type{PairIterator{T}}) where {T} = Tuple{T,T}
Base.length(p::PairIterator) = length(p.it)

Example:

julia> for i in PairIterator([3,4,2,1])
           println(i)
       end
(3, 4)
(4, 2)
(2, 1)
(1, 3)
1 Like
#3

Thank you very much !

Your implementation does not handle iterator input.
I have added @ffevotte's ConsecutivePairs proposal to the following MWE.
The problem is that the sumdist function is 3-4 times slower with the pair iterators than the loop-based sumdist_noiter.

On my machine I obtain:

("gflops 0", 3.2064275973906313)  # sumdist_noiter
("gflops 1", 0.8656447249774571)  # sum dist with PairIterator
("gflops 2", 0.9534782090500973)  # sum_dist with ConsecutivePairs

I wonder if it is possible to improve the iterator based implementation (which is prettier) and recover the performance of loop based implementation…

Here is the MWE:

using BenchmarkTools
#########################
"""
    ConsecutivePairs(subiter)

Lazy iterator over the consecutive pairs `(x1, x2)`, `(x2, x3)`, … of the elements
produced by `subiter`. No wrap-around pair is produced (see [`cyclic_pairs`](@ref)
for that); fewer than two elements yield no pairs.
"""
struct ConsecutivePairs{T}
    subiter::T
end

# Iteration traits are derived from the wrapped iterator so that `collect`,
# `length` and friends work instead of hitting the default `HasLength` fallback.
Base.IteratorEltype(::Type{ConsecutivePairs{T}}) where {T} = Base.IteratorEltype(T)
Base.eltype(::Type{ConsecutivePairs{T}}) where {T} = Tuple{eltype(T),eltype(T)}

function Base.IteratorSize(::Type{ConsecutivePairs{T}}) where {T}
    sz = Base.IteratorSize(T)
    # The pairs form a flat sequence, so a shaped source only provides a length.
    sz isa Union{Base.HasLength,Base.HasShape} && return Base.HasLength()
    return sz
end

# n elements give n - 1 pairs (and never a negative count for an empty source).
Base.length(cp::ConsecutivePairs) = max(length(cp.subiter) - 1, 0)

# First step: fetch the leading element, then defer to the stateful method.
@inline function Base.iterate(cp::ConsecutivePairs)
    next = iterate(cp.subiter)
    next === nothing && return nothing
    head, substate = next

    return iterate(cp, (substate, head))
end

# `state` is `(sub-iterator state, previously seen element)`.
@inline function Base.iterate(cp::ConsecutivePairs, state)
    substate, previous = state

    next = iterate(cp.subiter, substate)
    next === nothing && return nothing
    current, substate = next

    return ((previous, current), (substate, current))
end

"""
    cyclic_pairs(it)

Consecutive pairs of `it`, followed by the closing pair `(last, first)`.

The first element is appended as a one-element tuple so that it is yielded as a
single item even when it is itself iterable (e.g. a `String` or a `Tuple`).
`first(it)` is evaluated eagerly when the iterator is built.
"""
cyclic_pairs(it) = ConsecutivePairs(Iterators.flatten((it, (first(it),))))
################################
"""
    PairIterator(it::Vector)

Iterate over the cyclic consecutive pairs of the vector `it`:
`(it[1], it[2])`, `(it[2], it[3])`, …, `(it[end], it[1])`.

An empty vector yields no pairs; a one-element vector yields the single pair
`(it[1], it[1])`.
"""
struct PairIterator{T}
    it::Vector{T}
end

# Required method: the iteration state `i` is the index of the pair's first element.
function Base.iterate(p::PairIterator, i::Int=1)
    l = length(p.it)
    i > l && return nothing
    # Wrap around with a comparison instead of `mod(i, l) + 1`: the result is the
    # same for every valid state `1 <= i <= l`, without an integer division per step.
    j = i == l ? 1 : i + 1
    return (p.it[i], p.it[j]), i + 1
end

# Important optional methods
Base.eltype(::Type{PairIterator{T}}) where {T} = Tuple{T,T}
Base.length(p::PairIterator) = length(p.it)
##############################
"""
    sumdist(cp)

Sum of `(last(p) - first(p))^2` over every pair `p` produced by `cp`.

The accumulator starts as the zero of the element type of the first component of
the first pair, so `cp` must produce at least one pair.
"""
function sumdist(cp)
    leading = first(first(cp))
    total = zero(eltype(leading))
    for pair in cp
        delta = last(pair) - first(pair)
        total = total + delta^2
    end
    return total
end
"""
    sumdist_noiter(a)

Index-based reference implementation of the cyclic sum of squared differences:
the squared differences between consecutive elements of `a`, plus the closing
term between the last and the first element.

`a` must be a non-empty indexable collection (`first(a)` is read up front).
"""
function sumdist_noiter(a)
    s = zero(eltype(a))
    v1 = first(a)
    # Take the bounds from the collection itself so that `@inbounds` stays valid
    # for arrays whose first index is not 1.
    @inbounds for i in (firstindex(a) + 1):lastindex(a)
        d = a[i] - v1
        s += d^2
        v1 = a[i]
    end
    # Closing term: `v1` now holds the last element.
    d = v1 - first(a)
    s += d^2
    return s
end


# Benchmark driver: checks that the three implementations agree on a random vector,
# then times each of them with BenchmarkTools and prints a throughput figure.
function test()

    a=rand(1000)
    # Sanity check: all variants must return exactly the same sum.
    @show sumdist(PairIterator(a))==sumdist(cyclic_pairs(a))==sumdist_noiter(a)
    t0=@benchmark sumdist_noiter($a)
    t1=@benchmark sumdist(PairIterator($a))
    t2=@benchmark sumdist(cyclic_pairs($a))
    println(t1)
    @show t2
    # Throughput label per variant: 3 operations are counted per element and divided
    # by the fastest run (presumably nanoseconds, hence "gflops" — confirm against
    # the BenchmarkTools documentation).
    @show "gflops 0",3*length(a)/(minimum(t0).time)
    @show "gflops 1",3*length(a)/(minimum(t1).time)
    @show "gflops 2",3*length(a)/(minimum(t2).time)

    # Same measurements with a range instead of a materialized vector.
    @show t3=@btime sumdist_noiter(1:1000) #funny compiler trick
    # PairIterator stores a `Vector{T}`, so it cannot wrap a range:
    # @show @btime sumdist(PairIterator(1:1000)) # does not work with iterator
    @show t4=@btime sumdist(cyclic_pairs(1:1000))

end

# Run the benchmark when the script is executed.
test()
1 Like
#4

In the version based on cyclic_pairs, it looks like the performance is lost when composing the ConsecutivePairs iterator with Iterators.flatten in order to loop back to the first element.

Implementing the whole logic (iterating on pairs + looping back to the first element) in one iterator gets the for loop performance back:

"""
    CyclicPairs(subiter)

Lazy iterator over the consecutive pairs of the elements produced by `subiter`,
followed by one closing pair `(last, first)`. A source with `n` elements yields
`n` pairs: nothing for an empty source, `(x, x)` for a single element `x`.
"""
struct CyclicPairs{T}
    subiter::T
end

# First step: fetch the leading element, then defer to the stateful method.
@inline function Base.iterate(cp::CyclicPairs)
    next = iterate(cp.subiter)
    next === nothing && return nothing
    head, substate = next

    # State layout: (sub-iterator state, latest element, first element, finished flag).
    return iterate(cp, (substate, head, head, false))
end

@inline function Base.iterate(cp::CyclicPairs, state)
    substate, latest, head, finished = state

    next = iterate(cp.subiter, substate)
    if next === nothing
        # The source is exhausted: emit the closing pair exactly once, then stop.
        finished && return nothing
        return ((latest, head), (substate, latest, head, true))
    end
    current, substate = next

    return ((latest, current), (substate, current, head, false))
end

# Iteration traits are derived from the wrapped iterator so that `collect` builds a
# concretely typed result and also works for sources without a known length.
Base.IteratorEltype(::Type{CyclicPairs{T}}) where {T} = Base.IteratorEltype(T)
Base.eltype(::Type{CyclicPairs{T}}) where {T} = Tuple{eltype(T),eltype(T)}

function Base.IteratorSize(::Type{CyclicPairs{T}}) where {T}
    sz = Base.IteratorSize(T)
    # The pairs form a flat sequence, so a shaped source only provides a length.
    sz isa Union{Base.HasLength,Base.HasShape} && return Base.HasLength()
    return sz
end

# One pair per source element (the closing pair replaces the missing successor).
@inline Base.length(cp::CyclicPairs) = length(cp.subiter)

Here are the benchmarked times on my machine (with a very slight degradation in performance, probably due to the allocation of the iterator itself):

julia> @btime sumdist_noiter($a)
  782.621 ns (0 allocations: 0 bytes)
166.4520614732102

julia> @btime sumdist(CyclicPairs($a))
  797.495 ns (1 allocation: 16 bytes)
166.4520614732102

Or, re-using your own benchmarks:

("gflops 0", (3 * length(a)) / (minimum(t0)).time) = ("gflops 0", 3.8318927566066048)
("gflops 1", (3 * length(a)) / (minimum(t1)).time) = ("gflops 1", 0.5404107121412274)
("gflops 2", (3 * length(a)) / (minimum(t2)).time) = ("gflops 2", 0.7005137100540396)
("gflops 3", (3 * length(a)) / (minimum(t3)).time) = ("gflops 3", 3.7575645707806506)


Complete code
using BenchmarkTools


#########################


"""
    ConsecutivePairs(subiter)

Lazy iterator over the consecutive pairs `(x1, x2)`, `(x2, x3)`, … of the elements
produced by `subiter`. No wrap-around pair is produced (see [`cyclic_pairs`](@ref)
for that); fewer than two elements yield no pairs.
"""
struct ConsecutivePairs{T}
    subiter::T
end

# Iteration traits are derived from the wrapped iterator so that `collect`,
# `length` and friends work instead of hitting the default `HasLength` fallback.
Base.IteratorEltype(::Type{ConsecutivePairs{T}}) where {T} = Base.IteratorEltype(T)
Base.eltype(::Type{ConsecutivePairs{T}}) where {T} = Tuple{eltype(T),eltype(T)}

function Base.IteratorSize(::Type{ConsecutivePairs{T}}) where {T}
    sz = Base.IteratorSize(T)
    # The pairs form a flat sequence, so a shaped source only provides a length.
    sz isa Union{Base.HasLength,Base.HasShape} && return Base.HasLength()
    return sz
end

# n elements give n - 1 pairs (and never a negative count for an empty source).
Base.length(cp::ConsecutivePairs) = max(length(cp.subiter) - 1, 0)

# First step: fetch the leading element, then defer to the stateful method.
@inline function Base.iterate(cp::ConsecutivePairs)
    next = iterate(cp.subiter)
    next === nothing && return nothing
    head, substate = next

    return iterate(cp, (substate, head))
end

# `state` is `(sub-iterator state, previously seen element)`.
@inline function Base.iterate(cp::ConsecutivePairs, state)
    substate, previous = state

    next = iterate(cp.subiter, substate)
    next === nothing && return nothing
    current, substate = next

    return ((previous, current), (substate, current))
end

"""
    cyclic_pairs(it)

Consecutive pairs of `it`, followed by the closing pair `(last, first)`.

The first element is appended as a one-element tuple so that it is yielded as a
single item even when it is itself iterable (e.g. a `String` or a `Tuple`).
`first(it)` is evaluated eagerly when the iterator is built.
"""
@inline cyclic_pairs(it) = ConsecutivePairs(Iterators.flatten((it, (first(it),))))


#######################


"""
    CyclicPairs(subiter)

Lazy iterator over the consecutive pairs of the elements produced by `subiter`,
followed by one closing pair `(last, first)`. A source with `n` elements yields
`n` pairs: nothing for an empty source, `(x, x)` for a single element `x`.
"""
struct CyclicPairs{T}
    subiter::T
end

# First step: fetch the leading element, then defer to the stateful method.
@inline function Base.iterate(cp::CyclicPairs)
    next = iterate(cp.subiter)
    next === nothing && return nothing
    head, substate = next

    # State layout: (sub-iterator state, latest element, first element, finished flag).
    return iterate(cp, (substate, head, head, false))
end

@inline function Base.iterate(cp::CyclicPairs, state)
    substate, latest, head, finished = state

    next = iterate(cp.subiter, substate)
    if next === nothing
        # The source is exhausted: emit the closing pair exactly once, then stop.
        finished && return nothing
        return ((latest, head), (substate, latest, head, true))
    end
    current, substate = next

    return ((latest, current), (substate, current, head, false))
end

# Iteration traits are derived from the wrapped iterator so that `collect` builds a
# concretely typed result and also works for sources without a known length.
Base.IteratorEltype(::Type{CyclicPairs{T}}) where {T} = Base.IteratorEltype(T)
Base.eltype(::Type{CyclicPairs{T}}) where {T} = Tuple{eltype(T),eltype(T)}

function Base.IteratorSize(::Type{CyclicPairs{T}}) where {T}
    sz = Base.IteratorSize(T)
    # The pairs form a flat sequence, so a shaped source only provides a length.
    sz isa Union{Base.HasLength,Base.HasShape} && return Base.HasLength()
    return sz
end

# One pair per source element (the closing pair replaces the missing successor).
@inline Base.length(cp::CyclicPairs) = length(cp.subiter)


################################


"""
    PairIterator(it::Vector)

Iterate over the cyclic consecutive pairs of the vector `it`:
`(it[1], it[2])`, `(it[2], it[3])`, …, `(it[end], it[1])`.

An empty vector yields no pairs; a one-element vector yields the single pair
`(it[1], it[1])`.
"""
struct PairIterator{T}
    it::Vector{T}
end

# Required method: the iteration state `i` is the index of the pair's first element.
function Base.iterate(p::PairIterator, i::Int=1)
    l = length(p.it)
    i > l && return nothing
    # Wrap around with a comparison instead of `mod(i, l) + 1`: the result is the
    # same for every valid state `1 <= i <= l`, without an integer division per step.
    j = i == l ? 1 : i + 1
    return (p.it[i], p.it[j]), i + 1
end

# Important optional methods
Base.eltype(::Type{PairIterator{T}}) where {T} = Tuple{T,T}
Base.length(p::PairIterator) = length(p.it)


##############################


"""
    sumdist(cp)

Sum of `(last(p) - first(p))^2` over every pair `p` produced by `cp`.

The accumulator starts as the zero of the element type of the first component of
the first pair, so `cp` must produce at least one pair.
"""
function sumdist(cp)
    leading = first(first(cp))
    total = zero(eltype(leading))
    for pair in cp
        delta = last(pair) - first(pair)
        total = total + delta^2
    end
    return total
end
"""
    sumdist_noiter(a)

Index-based reference implementation of the cyclic sum of squared differences:
the squared differences between consecutive elements of `a`, plus the closing
term between the last and the first element.

`a` must be a non-empty indexable collection (`first(a)` is read up front).
"""
function sumdist_noiter(a)
    s = zero(eltype(a))
    v1 = first(a)
    # Take the bounds from the collection itself so that `@inbounds` stays valid
    # for arrays whose first index is not 1.
    @inbounds for i in (firstindex(a) + 1):lastindex(a)
        d = a[i] - v1
        s += d^2
        v1 = a[i]
    end
    # Closing term: `v1` now holds the last element.
    d = v1 - first(a)
    s += d^2
    return s
end


# Benchmark driver: checks that the four implementations agree on a random vector,
# then times each of them with BenchmarkTools and prints a throughput figure.
function test()
    a=rand(100_000_000);
    # Sanity check: all variants must return exactly the same sum.
    @show sumdist(PairIterator(a))==sumdist(cyclic_pairs(a))==sumdist_noiter(a)==sumdist(CyclicPairs(a))
    t0=@benchmark sumdist_noiter($a)
    t1=@benchmark sumdist(PairIterator($a))
    t2=@benchmark sumdist(cyclic_pairs($a))
    t3=@benchmark sumdist(CyclicPairs($a))
    println(t1)
    @show t2
    # Throughput label per variant: 3 operations are counted per element and divided
    # by the fastest run (presumably nanoseconds, hence "gflops" — confirm against
    # the BenchmarkTools documentation).
    @show "gflops 0",3*length(a)/(minimum(t0).time)
    @show "gflops 1",3*length(a)/(minimum(t1).time)
    @show "gflops 2",3*length(a)/(minimum(t2).time)
    @show "gflops 3",3*length(a)/(minimum(t3).time)

    # Same measurements with a range instead of a materialized vector.
    @show t4=@btime sumdist_noiter(1:1000) #funny compiler trick
    # PairIterator stores a `Vector{T}`, so it cannot wrap a range:
    # @show @btime sumdist(PairIterator(1:1000)) # does not work with iterator
    @show t5=@btime sumdist(cyclic_pairs(1:1000))
end

# Run the benchmark when the script is executed.
test()
2 Likes
#5

Great !!!
Thank you very much.