In the version based on `cyclic_pairs`, it looks like the performance is lost when composing the `ConsecutivePairs` iterator with `Iterators.flatten` in order to loop back to the first element.
Implementing the whole logic (iterating over pairs and looping back to the first element) in a single iterator recovers the performance of the `for` loop:
"""
    CyclicPairs(subiter)

Iterate over the consecutive pairs of `subiter` and close the cycle with a final
`(last, first)` pair. For elements `x1, x2, …, xn` this yields
`(x1, x2), (x2, x3), …, (xn, x1)`. A one-element iterator yields `(x1, x1)` and an
empty iterator yields nothing.
"""
struct CyclicPairs{T}
    subiter::T
end

@inline function Base.iterate(cp::CyclicPairs)
    i = iterate(cp.subiter)
    i === nothing && return nothing
    head, substate = i
    # state = (substate of subiter, latest element seen, first element, finished flag)
    return iterate(cp, (substate, head, head, #=finished=#false))
end

@inline function Base.iterate(cp::CyclicPairs, state)
    substate, latest, head, finished = state
    # Stop without touching `subiter` again: the iteration protocol does not promise
    # that `iterate` may be called once more after it has returned `nothing`.
    finished && return nothing
    i = iterate(cp.subiter, substate)
    if i === nothing
        # subiter is exhausted: emit the closing (last, first) pair exactly once
        return ((latest, head), (substate, latest, head, #=finished=#true))
    end
    current, substate = i
    return ((latest, current), (substate, current, head, #=finished=#false))
end

# The number of pairs equals the number of elements. A shaped subiter (e.g. a Matrix)
# has no meaningful shape for the pair sequence, so report a plain length instead.
# Unknown or infinite sizes carry over unchanged, so `collect` works on iterators
# that have no `length`.
_cyclic_size(::Base.HasShape) = Base.HasLength()
_cyclic_size(s::Base.IteratorSize) = s
Base.IteratorSize(::Type{CyclicPairs{T}}) where {T} = _cyclic_size(Base.IteratorSize(T))
@inline Base.length(cp::CyclicPairs) = length(cp.subiter)

# Each pair holds two elements of subiter; forward its eltype knowledge.
Base.IteratorEltype(::Type{CyclicPairs{T}}) where {T} = Base.IteratorEltype(T)
Base.eltype(::Type{CyclicPairs{T}}) where {T} = Tuple{eltype(T),eltype(T)}
Here are the benchmark timings on my machine (performance degrades very slightly, probably because of the allocation of the iterator itself):
julia> @btime sumdist_noiter($a)
782.621 ns (0 allocations: 0 bytes)
166.4520614732102
julia> @btime sumdist(CyclicPairs($a))
797.495 ns (1 allocation: 16 bytes)
166.4520614732102
Or, reusing your own benchmarks:
("gflops 0", (3 * length(a)) / (minimum(t0)).time) = ("gflops 0", 3.8318927566066048)
("gflops 1", (3 * length(a)) / (minimum(t1)).time) = ("gflops 1", 0.5404107121412274)
("gflops 2", (3 * length(a)) / (minimum(t2)).time) = ("gflops 2", 0.7005137100540396)
("gflops 3", (3 * length(a)) / (minimum(t3)).time) = ("gflops 3", 3.7575645707806506)
Complete code
using BenchmarkTools
#########################
"""
    ConsecutivePairs(subiter)

Iterate over the overlapping pairs of neighbouring elements of `subiter`. For elements
`x1, x2, …, xn` this yields `(x1, x2), (x2, x3), …, (x(n-1), xn)`. Iterators with fewer
than two elements yield no pairs.
"""
struct ConsecutivePairs{T}
    subiter::T
end

@inline function Base.iterate(cp::ConsecutivePairs)
    i = iterate(cp.subiter)
    i === nothing && return nothing
    prev, substate = i
    # state = (substate of subiter, most recent element)
    return iterate(cp, (substate, prev))
end

@inline function Base.iterate(cp::ConsecutivePairs, state)
    substate, prev = state
    i = iterate(cp.subiter, substate)
    i === nothing && return nothing
    cur, substate = i
    return ((prev, cur), (substate, cur))
end

# n elements give max(n - 1, 0) pairs. A shaped subiter maps to a plain length. Unknown
# or infinite sizes carry over, so `collect` works without a `length` method (e.g. on
# the `Iterators.flatten` wrapper built by `cyclic_pairs`).
_consecutive_size(::Base.HasShape) = Base.HasLength()
_consecutive_size(s::Base.IteratorSize) = s
Base.IteratorSize(::Type{ConsecutivePairs{T}}) where {T} =
    _consecutive_size(Base.IteratorSize(T))
Base.length(cp::ConsecutivePairs) = max(length(cp.subiter) - 1, 0)

# Each pair holds two elements of subiter; forward its eltype knowledge.
Base.IteratorEltype(::Type{ConsecutivePairs{T}}) where {T} = Base.IteratorEltype(T)
Base.eltype(::Type{ConsecutivePairs{T}}) where {T} = Tuple{eltype(T),eltype(T)}
"""
    cyclic_pairs(it)

Return a lazy iterator over the consecutive pairs of `it` followed by the closing pair
`(last, first)`. It is built by appending `first(it)` after `it` and pairing neighbours
with `ConsecutivePairs`. `it` is traversed twice (once by `first`, once while pairing),
so it must be re-iterable. An empty `it` raises an error from `first`.
"""
@inline function cyclic_pairs(it)
    # Wrap the appended element in a 1-tuple so `flatten` emits that element itself.
    # Flattening the bare element would instead iterate *its* contents, e.g. the
    # characters of a string element or the entries of a vector element.
    return ConsecutivePairs(Iterators.flatten((it, (first(it),))))
end
#######################
"""
    CyclicPairs(subiter)

Iterate over the consecutive pairs of `subiter` and close the cycle with a final
`(last, first)` pair. For elements `x1, x2, …, xn` this yields
`(x1, x2), (x2, x3), …, (xn, x1)`. A one-element iterator yields `(x1, x1)` and an
empty iterator yields nothing.
"""
struct CyclicPairs{T}
    subiter::T
end

@inline function Base.iterate(cp::CyclicPairs)
    i = iterate(cp.subiter)
    i === nothing && return nothing
    head, substate = i
    # state = (substate of subiter, latest element seen, first element, finished flag)
    return iterate(cp, (substate, head, head, #=finished=#false))
end

@inline function Base.iterate(cp::CyclicPairs, state)
    substate, latest, head, finished = state
    # Stop without touching `subiter` again: the iteration protocol does not promise
    # that `iterate` may be called once more after it has returned `nothing`.
    finished && return nothing
    i = iterate(cp.subiter, substate)
    if i === nothing
        # subiter is exhausted: emit the closing (last, first) pair exactly once
        return ((latest, head), (substate, latest, head, #=finished=#true))
    end
    current, substate = i
    return ((latest, current), (substate, current, head, #=finished=#false))
end

# The number of pairs equals the number of elements. A shaped subiter (e.g. a Matrix)
# has no meaningful shape for the pair sequence, so report a plain length instead.
# Unknown or infinite sizes carry over unchanged, so `collect` works on iterators
# that have no `length`.
_cyclic_size(::Base.HasShape) = Base.HasLength()
_cyclic_size(s::Base.IteratorSize) = s
Base.IteratorSize(::Type{CyclicPairs{T}}) where {T} = _cyclic_size(Base.IteratorSize(T))
@inline Base.length(cp::CyclicPairs) = length(cp.subiter)

# Each pair holds two elements of subiter; forward its eltype knowledge.
Base.IteratorEltype(::Type{CyclicPairs{T}}) where {T} = Base.IteratorEltype(T)
Base.eltype(::Type{CyclicPairs{T}}) where {T} = Tuple{eltype(T),eltype(T)}
################################
"""
    PairIterator(v::AbstractVector)

Index-based cyclic pair iterator. For every index `i` of `v` it yields `(v[i], v[j])`,
where `j` is the next index and the last index wraps around to the first. Accepts any
`AbstractVector` (e.g. ranges), not only `Vector`.
"""
struct PairIterator{V<:AbstractVector}
    it::V
end

# Required method
function Base.iterate(p::PairIterator, i::Int=firstindex(p.it))
    v = p.it
    i > lastindex(v) && return nothing
    f = firstindex(v)
    # successor index, wrapping from the last index back to the first;
    # relative to `f` so vectors whose indices do not start at 1 are handled
    j = f + mod(i - f + 1, length(v))
    return (v[i], v[j]), i + 1
end

# Important optional methods
Base.eltype(::Type{PairIterator{V}}) where {V} = Tuple{eltype(V),eltype(V)}
Base.length(p::PairIterator) = length(p.it)
##############################
"""
    sumdist(cp)

Accumulate `(last(pair) - first(pair))^2` over every `pair` produced by `cp`.
The running total starts at zero of the type of the first pair's first component,
so `cp` must yield at least one pair (`first(cp)` errors otherwise).
"""
function sumdist(cp)
    T = eltype(first(first(cp)))
    total = zero(T)
    for pair in cp
        gap = last(pair) - first(pair)
        total = total + gap^2
    end
    return total
end
"""
    sumdist_noiter(a)

Hand-written reference without custom iterators. Returns the sum of
`(a[k+1] - a[k])^2` over neighbouring elements of `a`, plus the closing term
`(last - first)^2`. `a` must be non-empty (`first(a)` errors otherwise).
"""
function sumdist_noiter(a)
    s = zero(eltype(a))
    head = first(a)
    prev = head
    # Loop over the array's actual linear index range rather than `2:length(a)`,
    # so `@inbounds` stays valid for arrays whose indices do not start at 1.
    @inbounds for i in (firstindex(a) + 1):lastindex(a)
        cur = a[i]
        d = cur - prev
        s += d^2
        prev = cur
    end
    # closing pair: last element back to the first
    s += (prev - head)^2
    return s
end
# Benchmark driver: checks that the implementations agree on one random input,
# benchmarks each of them, and prints an estimated throughput.
function test()
    a=rand(100_000_000);  # 10^8 random Float64 values (~800 MB)
    # Sanity check: all four implementations should give identical (==) results.
    # They visit the pairs in the same order, so the floating-point sums can match exactly.
    @show sumdist(PairIterator(a))==sumdist(cyclic_pairs(a))==sumdist_noiter(a)==sumdist(CyclicPairs(a))
    t0=@benchmark sumdist_noiter($a)            # baseline: hand-written index loop
    t1=@benchmark sumdist(PairIterator($a))     # index-based iterator (mod per step)
    t2=@benchmark sumdist(cyclic_pairs($a))     # ConsecutivePairs over Iterators.flatten
    t3=@benchmark sumdist(CyclicPairs($a))      # single fused cyclic-pair iterator
    println(t1)
    @show t2
    # Throughput estimate: 3 floating-point operations per pair (subtract, square, add).
    # `minimum(t).time` is in nanoseconds (BenchmarkTools convention), so flops/ns = GFLOP/s.
    @show "gflops 0",3*length(a)/(minimum(t0).time)
    @show "gflops 1",3*length(a)/(minimum(t1).time)
    @show "gflops 2",3*length(a)/(minimum(t2).time)
    @show "gflops 3",3*length(a)/(minimum(t3).time)
    # NOTE(review): on a range input the compiler presumably folds the loop into a closed
    # form, hence the suspiciously fast timing; confirm with @code_llvm.
    @show t4=@btime sumdist_noiter(1:1000) #funny compiler trick
    # @show @btime sumdist(PairIterator(1:1000)) # does not work with iterator
    @show t5=@btime sumdist(cyclic_pairs(1:1000))
end
# Run the benchmarks when the script is executed.
test()