I am trying to implement a custom iterator for multidimensional arrays that, among other things, lets one specify the order in which the axes are iterated. Sadly, I am having trouble getting even close to the performance of `CartesianIndices`, even though almost all the code is identical. In particular, I found that the performance depends greatly on whether the index calculation is done within the iterator or in the caller.
Below is a MWE. `Iter1` is basically a stripped-down version of `CartesianIndices` and achieves good performance. It returns a tuple, which is converted to the appropriate linear index in the caller. `Iter2` is nearly identical to `Iter1`; the only difference is that the conversion to a linear index is done within the `iterate` function instead of in the caller. Surprisingly to me, this results in a nearly 8x slowdown. (This is on Julia 1.5.)
# Return `(strides, stop)` for iterating an array of size `sz` in axis order `p`:
# iteration axis `i` walks array dimension `p[i]`, so both its column-major stride and
# its extent are taken from dimension `p[i]`.
# Throws `ArgumentError` if `p` is not a permutation of `1:N`.
@inline function _permuted_layout(sz::Dims{N}, p::Dims{N}) where {N}
    # Allocation-free permutation check; the constructors run on every call of the
    # benchmark functions, so avoid the BitVector that `isperm` would allocate.
    all(d -> count(==(d), p) == 1, 1:N) ||
        throw(ArgumentError("p must be a permutation of 1:$N, got $p"))
    # Column-major strides of the unpermuted array: (1, sz[1], sz[1]*sz[2], ...).
    strides = cumprod(ntuple(i -> i > 1 ? sz[i-1] : 1, Val(N)))
    # Permute BOTH strides and extents. Permuting only the strides (keeping `stop = sz`)
    # repeats some elements and skips others whenever the array is not cube-shaped.
    return ntuple(i -> strides[p[i]], Val(N)), ntuple(i -> sz[p[i]], Val(N))
end

"""
    Iter1(sz::Dims{N}, p::Dims{N})

Iterate over every element of an array of size `sz`, with array dimension `p[1]` varying
fastest, then `p[2]`, and so on. Each iteration yields an index tuple `I` in which `I[i]`
is the position along array dimension `p[i]`; `calc_index(I, it.strides)` converts it to
a 1-based linear index. Throws `ArgumentError` if `p` is not a permutation of `1:N`.
"""
struct Iter1{N}
    strides::Dims{N}  # column-major stride of array dimension p[i]
    stop::Dims{N}     # extent of array dimension p[i]
    function Iter1(sz::Dims{N}, p::Dims{N}) where {N}
        strides, stop = _permuted_layout(sz, p)
        return new{N}(strides, stop)
    end
end

"""
    Iter2(sz::Dims{N}, p::Dims{N})

Same fields and construction as `Iter1`; the two types differ only in what `iterate`
yields (`Iter2` yields the 1-based linear index instead of the index tuple).
"""
struct Iter2{N}
    strides::Dims{N}  # column-major stride of array dimension p[i]
    stop::Dims{N}     # extent of array dimension p[i]
    function Iter2(sz::Dims{N}, p::Dims{N}) where {N}
        strides, stop = _permuted_layout(sz, p)
        return new{N}(strides, stop)
    end
end
# Iter1 yields the index tuple itself, and that same tuple serves as the iteration state.
@inline function Base.iterate(it::Iter1{N}) where {N}
    # An empty dimension means there is no element to visit.
    any(<(1), it.stop) && return nothing
    origin = ntuple(_ -> 1, Val(N))
    return (origin, origin)
end
@inline function Base.iterate(it::Iter1{N}, state::Dims{N}) where {N}
    ok, advanced = __inc(state, it.stop)
    return ok ? (advanced, advanced) : nothing
end
# Iter2 yields the 1-based linear index; the index tuple is carried as the state.
@inline function Base.iterate(it::Iter2{N}) where {N}
    # An empty dimension means there is no element to visit.
    any(<(1), it.stop) && return nothing
    I = ntuple(_ -> 1, Val(N))
    # The all-ones tuple always maps to linear index 1, whatever the strides are.
    return 1, I
end
# `Iter2{N}` (rather than bare `Iter2`) matches Iter1's method and guarantees at dispatch
# time that `state` has the same rank as the iterator's strides.
@inline function Base.iterate(it::Iter2{N}, state::Dims{N}) where {N}
    valid, I = __inc(state, it.stop)
    valid || return nothing
    return calc_index(I, it.strides), I
end
# Helper functions

# `__inc(state, stop)` advances `state` like an odometer. The first entry is incremented
# if it is still below `stop[1]`; otherwise it wraps to 1 and the increment carries into
# the remaining entries. Returns `(valid, next_state)`, where `valid == false` means the
# carry ran off the end (every entry wrapped).
@inline function __inc(::Tuple{}, ::Tuple{})
    # Nothing left to carry into.
    return (false, ())
end
@inline function __inc(state::Dims{N}, stop::Dims{N}) where {N}
    lead, rest = first(state), Base.tail(state)
    if lead >= first(stop)
        # Wrap the leading entry back to 1 and carry into the rest.
        carried, rest_next = __inc(rest, Base.tail(stop))
        return (carried, (1, rest_next...))
    end
    return (true, (lead + 1, rest...))
end
"""
    calc_index(I::Dims{N}, strides::Dims{N})

Return the 1-based linear index `1 + Σᵢ (I[i] - 1) * strides[i]`. For `N == 0` the
result is `1`, the single element of a zero-dimensional array.
"""
@inline function calc_index(I::Dims{N}, strides::Dims{N}) where {N}
    # Splat into `+` instead of calling `sum`: `sum(())` throws on the empty reduction
    # (N == 0), whereas `+(1)` simply returns 1. For N ≥ 1 this is the same unrolled sum.
    return +(1, map((i, s) -> (i - 1) * s, I, strides)...)
end
# Benchmarks

# Add 1.0 to every element of `A`, visiting elements in the order set by `p`. The linear
# index is computed here, in the caller, from the tuple that `Iter1` yields.
function test1(A::AbstractArray{Float64,N}, p::Dims{N}) where {N}
    iter = Iter1(size(A), p)
    for cart in iter
        lin = calc_index(cart, iter.strides)
        A[lin] = A[lin] + 1.0
    end
    return nothing
end
# Same work as `test1`, but `Iter2` yields the linear index itself, so the index
# arithmetic happens inside `iterate` rather than here.
function test2(A::AbstractArray{Float64,N}, p::Dims{N}) where {N}
    linear_iter = Iter2(size(A), p)
    for k in linear_iter
        A[k] = A[k] + 1.0
    end
    return nothing
end
# Variant of `test2` whose permutation (N, N-1, ..., 1) is built from the type parameter
# `N` alone instead of being passed in as a runtime argument.
function test2(A::AbstractArray{Float64,N}) where {N}
    reversed = ntuple(i -> N + 1 - i, Val(N))
    linear_iter = Iter2(size(A), reversed)
    for k in linear_iter
        A[k] = A[k] + 1.0
    end
    return nothing
end
using BenchmarkTools
# 4×4×4×4×4 array: every dimension has the same extent (4).
A = rand(4,4,4,4,4)
# `$A` interpolates the global into each benchmark so that untyped-global access is not
# part of the measurement. Every call mutates `A` in place (adds 1.0 to each element).
# The trailing timings are the ones reported for Julia 1.5.
@btime test1($A, (5,4,3,2,1)) # 710 ns
@btime test2($A, (5,4,3,2,1)) # 5.2 μs
@btime test2($A) # 1.6 μs
Note that the only difference between `test1` and `test2` is where `calc_index` is called. Evidently the compiler is using information about `it.strides` in `test1` that it does not use in `iterate(::Iter2)`. Hard-coding the permutation `p` (the second version of `test2`) helps somewhat, but it is still about 2x slower than `test1`.
I’d greatly appreciate any insight anyone has to offer. I’ve been pulling my hair out for two days trying to figure out why my iterator is performing so poorly.