I am trying to implement a custom iterator for multidimensional arrays which, among other things, allows one to specify the order in which the axes are iterated. Sadly, I am having trouble getting even close to the performance of CartesianIndices, even though almost all of the code is identical. In particular, I found that the performance depends greatly on whether the index calculation is done within the iterator or in the caller.
Below is an MWE. Iter1 is basically a stripped-down version of CartesianIndices and achieves good performance. It returns a tuple, which is converted to the appropriate linear index in the caller. Iter2 is nearly identical to Iter1; the only difference is that the conversion to a linear index is done within the iterate function instead of in the caller. Surprisingly (to me), this results in a more than 7x slowdown. (This is on Julia 1.5.)
struct Iter1{N}
    strides::Dims{N}
    stop::Dims{N}
    function Iter1(sz::Dims{N}, p::Dims{N}) where {N}
        # column-major strides of the array, permuted so counter position i advances along axis p[i]
        strides = cumprod(ntuple(i -> i>1 ? sz[i-1] : 1, Val(N)))
        new{N}(ntuple(i->strides[p[i]], Val(N)), sz)
    end
end
# (Exactly the same as Iter1)
struct Iter2{N}
    strides::Dims{N}
    stop::Dims{N}
    function Iter2(sz::Dims{N}, p::Dims{N}) where {N}
        strides = cumprod(ntuple(i -> i>1 ? sz[i-1] : 1, Val(N)))
        new{N}(ntuple(i->strides[p[i]], Val(N)), sz)
    end
end
# Iterating Iter1 returns a Tuple
@inline function Base.iterate(it::Iter1{N}) where {N}
    # nothing to iterate if any dimension is empty
    if any(map(s -> s<1, it.stop))
        return nothing
    end
    I = ntuple(i->1, Val(N))
    return I, I
end
@inline function Base.iterate(it::Iter1{N}, state::Dims{N}) where {N}
    valid, I = __inc(state, it.stop)
    valid || return nothing
    return I, I
end
# Iterating Iter2 returns a linear index
@inline function Base.iterate(it::Iter2{N}) where {N}
    if any(map(s -> s<1, it.stop))
        return nothing
    end
    I = ntuple(i->1, Val(N))
    return 1, I
end
@inline function Base.iterate(it::Iter2{N}, state::Dims{N}) where {N}
    valid, I = __inc(state, it.stop)
    valid || return nothing
    return calc_index(I, it.strides), I
end
# helper functions
@inline __inc(::Tuple{}, ::Tuple{}) = false, ()
@inline function __inc(state::Dims{N}, stop::Dims{N}) where {N}
    if state[1] < stop[1]
        return true, (state[1]+1, Base.tail(state)...)
    end
    valid, I = __inc(Base.tail(state), Base.tail(stop))
    return valid, (1, I...)
end
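For reference, __inc behaves like an odometer in which the first position varies fastest. These calls are just to illustrate the helper; they are not part of the benchmark:
__inc((1,1,1,1,1), (4,4,4,4,4)) # (true, (2,1,1,1,1))
__inc((4,1,1,1,1), (4,4,4,4,4)) # (true, (1,2,1,1,1))  -- carry into the second position
__inc((4,4,4,4,4), (4,4,4,4,4)) # (false, (1,1,1,1,1)) -- counter exhausted, iteration stops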
@inline function calc_index(I::Dims{N}, strides::Dims{N}) where {N}
    1 + sum(ntuple(i -> (I[i]-1)*strides[i], Val(N)))
end
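And calc_index maps such a counter tuple to a linear index via the permuted strides. For the sizes used in the benchmark below (again, purely illustrative):
it = Iter1((4,4,4,4,4), (5,4,3,2,1))
it.strides                          # (256, 64, 16, 4, 1)
calc_index((1,1,1,1,1), it.strides) # 1
calc_index((2,1,1,1,1), it.strides) # 257, i.e. one step along the original 5th axis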
# Benchmarks
function test1(A::AbstractArray{Float64,N}, p::Dims{N}) where {N}
    it = Iter1(size(A), p)
    for I in it
        iA = calc_index(I, it.strides)
        A[iA] += 1.0
    end
end
function test2(A::AbstractArray{Float64,N}, p::Dims{N}) where {N}
    it = Iter2(size(A), p)
    for iA in it
        A[iA] += 1.0
    end
end
# Like the two-argument test2, but with the axis-reversing permutation hard-coded
function test2(A::AbstractArray{Float64,N}) where {N}
    it = Iter2(size(A), ntuple(i->1+N-i, Val(N)))
    for iA in it
        A[iA] += 1.0
    end
end
using BenchmarkTools
A = rand(4,4,4,4,4)
@btime test1($A, (5,4,3,2,1)) # 710 ns
@btime test2($A, (5,4,3,2,1)) # 5.2 μs
@btime test2($A) # 1.6 μs
Note that the only difference between test1 and test2 is where calc_index is called. Evidently the compiler is using information about it.strides in test1 that it does not use in iterate(::Iter2). Hard-coding the permutation p (the second version of test2) helps somewhat, but it is still more than 2x slower than test1.
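For anyone who wants to dig into this, the generated loop code for the two versions can be compared directly (just an inspection aid, not part of the benchmark):
using InteractiveUtils
# dump the optimized LLVM IR for each loop so the two can be compared
@code_llvm debuginfo=:none test1(A, (5,4,3,2,1))
@code_llvm debuginfo=:none test2(A, (5,4,3,2,1))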
I’d greatly appreciate any insight anyone has to offer. I’ve been pulling my hair out for two days trying to figure out why my iterator is performing so poorly.