Hey, I have a very simple problem: given a vector `v::AbstractVector{<:Integer}`, I'd like to find the last index `idx` such that `v[idx] ≠ idx` (falling back to `firstindex(v)` when there is no such index).
Here’s a demo/test function and some benchmarks.
I tried several things but can’t really make it much faster. If I can nerd-snipe someone in this forum to show me better patterns for optimizing this, or just a better execution of the ideas below I’d be very grateful!
Summary:
- `myfindlast0`: a simple application of `findlast` + resolution of `nothing`
- `myfindlast1`: a manual, short-circuiting loop
- `myfindlast2`: use chunks (of 64) + `@nany`
- `myfindlast3`: as above, but deal with the tail first
- `myfindlast4`: try to force vector instructions by `|=` over the whole chunk + a loop for finding the index
- `myfindlast5`: as above, but try to use a `UInt` as `set_int` (a small bit set) to find the index
"""
    myfindlast0(v::AbstractVector{<:Integer})

Reference implementation: return the last index `i` of `v` with `v[i] ≠ i`, or
`firstindex(v)` when every element equals its own index (this includes empty `v`).
"""
function myfindlast0(v::AbstractVector{<:Integer})
    ismismatch = i -> v[i] ≠ i
    found = findlast(ismismatch, eachindex(v))
    # `findlast` yields `nothing` when no index matches; map that to `firstindex(v)`
    return isnothing(found) ? firstindex(v) : found
end
"""
    test_findlast(fl)

Check that `fl(v) == myfindlast0(v)` on a set of test vectors, then benchmark `fl`
with `@btime` on the six main vectors (`v0`, `v1`, `v2`, `w0`, `w1`, `w2`).

Requires `Test` (for `@test`) and `BenchmarkTools` (for `@btime`) to be loaded.
"""
function test_findlast(fl)
    # v*: length 2^10, an exact multiple of the chunk width 64 (no tail)
    v0 = collect(1:2^10)
    v1 = copy(v0)
    k = length(v1) - 64 + 18
    v1[k] = k + 1          # mismatch close to the end
    v2 = copy(v0)
    v2[2] = 3              # mismatch near the front: a backward scan must visit almost everything
    # w*: length 2^10 + 53, so a 53-element tail follows the last full chunk of 64
    w0 = collect(1:2^10+53)
    w1 = copy(w0)
    l = length(w1) - 128 + 18
    w1[l] = l + 1
    w2 = copy(w0)
    w2[2] = 3
    # Edge cases the vectors above do not reach: empty input, and tails of exactly
    # one and two elements whose last element mismatches.
    e0 = Int[]
    t1 = collect(1:65)
    t1[65] = 0
    t2 = collect(1:66)
    t2[66] = 0

    bench = (v0, v1, v2, w0, w1, w2)
    for v in (bench..., e0, t1, t2)
        @test fl(v) == myfindlast0(v)
    end

    @info fl
    for v in bench
        @btime $fl($v)
    end
end
# Compare `myfindlast0` with itself (a sanity check of the harness), then benchmark it.
test_findlast(myfindlast0)
[ Info: myfindlast0
253.200 ns (0 allocations: 0 bytes)
14.206 ns (0 allocations: 0 bytes)
257.921 ns (0 allocations: 0 bytes)
267.984 ns (0 allocations: 0 bytes)
33.197 ns (0 allocations: 0 bytes)
266.925 ns (0 allocations: 0 bytes)
Maybe this is fast, or maybe not. Let's try writing a short-circuiting loop by hand (no vectorisation, but low latency if `idx` is at the end of `v`):
"""
    myfindlast1(v::AbstractVector{<:Integer})

Return the last index `j` of `v` with `v[j] ≠ j`, or `firstindex(v)` if there is
none. Scans backwards one element at a time and stops at the first mismatch met,
which keeps latency low when the answer lies near the end of `v`.
"""
function myfindlast1(v::AbstractVector{<:Integer})
    lo = firstindex(v)
    j = lastindex(v)
    @inbounds while j ≥ lo
        v[j] ≠ j && return j
        j -= 1
    end
    return lo
end
# Check `myfindlast1` against the reference `myfindlast0`, then benchmark it.
test_findlast(myfindlast1)
[ Info: myfindlast1
260.920 ns (0 allocations: 0 bytes)
17.566 ns (0 allocations: 0 bytes)
259.644 ns (0 allocations: 0 bytes)
271.759 ns (0 allocations: 0 bytes)
32.701 ns (0 allocations: 0 bytes)
270.607 ns (0 allocations: 0 bytes)
Well, that didn't change much, did it? OK, so let's try something from the standard textbook: divide the array into chunks of constant size (a power of 2), check whether any element of a chunk satisfies the condition (the chunks are now statically sized!), and then deal with the tail. My first try uses `@nany`:
"""
    __unsafe_findlast(v, lastidx, firstidx)

Scan `v` backwards over `lastidx` down to `firstidx` (no bounds checks) and return
the first index `j` met with `v[j] ≠ j`. If there is none, return the sentinel
`lastidx + 1`, one past the searched range; this keeps the result an integer rather
than `nothing`.
"""
@inline function __unsafe_findlast(v, lastidx, firstidx)
    j = lastidx
    @inbounds while j ≥ firstidx
        if v[j] ≠ j
            return j
        end
        j -= 1
    end
    return lastidx + 1
end
"""
    myfindlast2(v::AbstractVector{<:Integer})

Return the last index `i` of `v` with `v[i] ≠ i`, or `firstindex(v)` if there is
none. Walks backwards in chunks of 64 indices, testing each whole chunk at once with
an unrolled `@nany`. Only the chunk that contains a mismatch is rescanned to locate
it. The remaining (at most 64) leading indices are handled at the end.

All indices are derived from `firstindex`/`lastindex`, so arrays that are not
1-based are handled correctly despite the `@inbounds` accesses.
"""
function myfindlast2(v::AbstractVector{<:Integer})
    step = 64  # chunk width; must equal the literal in `@nany 64` below
    lo = firstindex(v)
    i = lastindex(v)
    # Process full chunks i-step+1:i while they lie strictly above `lo`
    # (for 1-based `v` this is the condition `i > step`).
    @inbounds while i - step ≥ lo
        found = Base.Cartesian.@nany 64 d -> @inbounds v[i+1-d] ≠ i + 1 - d
        if found
            ans = __unsafe_findlast(v, i, i + 1 - step)
            @assert ans ≤ i
            return ans
        end
        i -= step
    end
    # Leading indices lo:i. Here `v[lo]` needs no separate check: if nothing above it
    # mismatches, the answer is `lo` either way.
    if i > lo
        ans = __unsafe_findlast(v, i, lo)
        return ifelse(ans ≤ i, ans, lo)
    else
        return lo
    end
end
# Check `myfindlast2` against the reference `myfindlast0`, then benchmark it.
test_findlast(myfindlast2)
[ Info: myfindlast2
178.678 ns (0 allocations: 0 bytes)
25.736 ns (0 allocations: 0 bytes)
178.797 ns (0 allocations: 0 bytes)
185.015 ns (0 allocations: 0 bytes)
30.736 ns (0 allocations: 0 bytes)
185.514 ns (0 allocations: 0 bytes)
OK, now we paid some latency (e.g. 14 ns → 26 ns on `v1`) for an overall speed-up of roughly 40–45% on the full scans (e.g. 253 ns → 179 ns on `v0`); I could live with that.
The next thing I thought of is that even though we're traversing memory 64 elements at a time, these chunks might not be aligned in memory (remember, we're traversing `v` in reverse)! Let's try to deal with the tail first.
"""
    myfindlast3(v::AbstractVector{<:Integer})

Return the last index `i` of `v` with `v[i] ≠ i`, or `firstindex(v)` if there is
none. Same chunked `@nany` scan as `myfindlast2`, but the chunk boundaries are
aligned to multiples of 64 counted from `firstindex(v)`. The leftover tail at the end
of `v` is scanned first, then the full chunks are walked backwards.
"""
function myfindlast3(v::AbstractVector{<:Integer})
    step = 64  # chunk width; must equal the literal in `@nany 64` below
    lo = firstindex(v)
    hi = lastindex(v)
    # `i` is the last index of the longest prefix whose length is a multiple of `step`;
    # i+1:hi is the (possibly empty) tail.
    i = hi - length(v) % step
    # Tail first. The test must be `i < hi` (not `i + 1 < hi`), otherwise a
    # one-element tail would never be checked.
    if i < hi
        ans = __unsafe_findlast(v, hi, i + 1)
        ans ≤ hi && return ans
    end
    # Full chunks i-step+1:i (for 1-based `v` this is the condition `i ≥ step`)
    @inbounds while i - step + 1 ≥ lo
        found = Base.Cartesian.@nany 64 d -> @inbounds v[i+1-d] ≠ i + 1 - d
        if found
            ans = __unsafe_findlast(v, i, i + 1 - step)
            @assert ans ≤ i
            return ans
        end
        i -= step
    end
    return lo
end
# Check `myfindlast3` against the reference `myfindlast0`, then benchmark it.
test_findlast(myfindlast3)
[ Info: myfindlast3
170.201 ns (0 allocations: 0 bytes)
25.902 ns (0 allocations: 0 bytes)
189.619 ns (0 allocations: 0 bytes)
186.107 ns (0 allocations: 0 bytes)
49.392 ns (0 allocations: 0 bytes)
201.911 ns (0 allocations: 0 bytes)
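Hmm, in the examples where we actually do need to deal with the tail (`w0`, `w1`, `w2`, whose length is not a multiple of 64), we actually got a slowdown?! Looking at `@code_native`, it seems that this version didn't use vector instructions, though. Let's try to use them! Here we traverse the block (of size 64) that contains the requested index twice. The first pass is a branch-free `|=` reduction that only detects whether there is a mismatch. The second is a short-circuiting loop that locates it.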
"""
    __unsafe_findlast2(v, lastidx, firstidx)

Like `__unsafe_findlast`, but first makes a branch-free pass over `firstidx:lastidx`
that ORs together all mismatch flags (`v[j] ≠ j`), a loop shape the compiler can
vectorise. Only if that pass saw a mismatch does it scan backwards to locate the
last one. Returns that index, or the sentinel `lastidx + 1` if there is none.
No bounds checks are performed.
"""
@inline function __unsafe_findlast2(v, lastidx, firstidx)
    anymismatch = false
    @inbounds for j in firstidx:lastidx
        anymismatch |= v[j] ≠ j
    end
    anymismatch || return lastidx + 1
    @inbounds for j in lastidx:-1:firstidx
        v[j] ≠ j && return j
    end
    return lastidx + 1
end
"""
    myfindlast4(v::AbstractVector{<:Integer})

Return the last index `i` of `v` with `v[i] ≠ i`, or `firstindex(v)` if there is
none. Walks backwards in chunks of 64 indices. Each chunk is checked with
`__unsafe_findlast2`, a branch-free `|=` reduction followed by a locating scan only
when that reduction saw a mismatch.

All indices are derived from `firstindex`/`lastindex`, so arrays that are not
1-based are handled correctly despite the `@inbounds` accesses.
"""
function myfindlast4(v::AbstractVector{<:Integer})
    step = 64
    lo = firstindex(v)
    i = lastindex(v)
    # Full chunks i-step+1:i while they lie strictly above `lo`
    # (for 1-based `v` this is the condition `i > step`).
    @inbounds while i - step ≥ lo
        ans = __unsafe_findlast2(v, i, i + 1 - step)
        ans ≤ i && return ans
        i -= step
    end
    # Leading indices lo:i. Here `v[lo]` needs no separate check: if nothing above it
    # mismatches, the answer is `lo` either way.
    if i > lo
        ans = __unsafe_findlast2(v, i, lo)
        return ifelse(ans ≤ i, ans, lo)
    else
        return lo
    end
end
# Check `myfindlast4` against the reference `myfindlast0`, then benchmark it.
test_findlast(myfindlast4)
[ Info: myfindlast4
151.057 ns (0 allocations: 0 bytes)
26.392 ns (0 allocations: 0 bytes)
168.459 ns (0 allocations: 0 bytes)
157.583 ns (0 allocations: 0 bytes)
36.751 ns (0 allocations: 0 bytes)
179.509 ns (0 allocations: 0 bytes)
Ok, we got back to speed, but not much else.
I made another attempt that uses an unsigned integer (`UInt`) as a small bit set, `set_int`, to find the index. It uses vector instructions as well, but is overall slower (and more complicated).
"""
    __unsafe_findlast3(v, lastidx, firstidx)

Return the last index `j` in `firstidx:lastidx` with `v[j] ≠ j`, or the sentinel
`lastidx + 1` if there is none. No bounds checks are performed.

The block is encoded as a 64-bit mask. Bit `j - firstidx` is set iff `v[j] ≠ j`.
The mask is built with a branch-free loop, so the compiler can vectorise it. The
highest set bit, found with `leading_zeros`, then gives the answer directly without
a second scan. The block may hold at most 64 indices.
"""
@inline function __unsafe_findlast3(v, lastidx, firstidx)
    # UInt64 rather than UInt, so that a 64-wide block also fits on 32-bit platforms
    mask = zero(UInt64)
    width = 8 * sizeof(mask)
    @assert lastidx - firstidx + 1 ≤ width
    @inbounds for idx in firstidx:lastidx
        mask |= UInt64(v[idx] ≠ idx) << (idx - firstidx)
    end
    iszero(mask) && return lastidx + 1
    # highest set bit ↔ last mismatching index
    return firstidx + (width - 1) - leading_zeros(mask)
end
"""
    myfindlast5(v::AbstractVector{<:Integer})

Return the last index `i` of `v` with `v[i] ≠ i`, or `firstindex(v)` if there is
none. Walks backwards in chunks of 64 indices. Each chunk is checked with
`__unsafe_findlast3`, which records mismatches in a bit mask and reads the last one
off its highest set bit.

All indices are derived from `firstindex`/`lastindex`, so arrays that are not
1-based are handled correctly despite the `@inbounds` accesses.
"""
function myfindlast5(v::AbstractVector{<:Integer})
    step = 64  # must not exceed the 64-bit mask width used by `__unsafe_findlast3`
    lo = firstindex(v)
    i = lastindex(v)
    # Full chunks i-step+1:i that still lie within `v`
    # (for 1-based `v` this is the condition `i ≥ step`).
    @inbounds while i - step + 1 ≥ lo
        ans = __unsafe_findlast3(v, i, i + 1 - step)
        ans ≤ i && return ans
        i -= step
    end
    # Leading indices lo:i. Here `v[lo]` needs no separate check: if nothing above it
    # mismatches, the answer is `lo` either way.
    if i > lo
        ans = __unsafe_findlast3(v, i, lo)
        return ifelse(ans ≤ i, ans, lo)
    else
        return lo
    end
end
# Check `myfindlast5` against the reference `myfindlast0`, then benchmark it.
test_findlast(myfindlast5)
[ Info: myfindlast5
323.478 ns (0 allocations: 0 bytes)
18.634 ns (0 allocations: 0 bytes)
335.813 ns (0 allocations: 0 bytes)
352.505 ns (0 allocations: 0 bytes)
42.258 ns (0 allocations: 0 bytes)
353.798 ns (0 allocations: 0 bytes)