Timings for several attempts. Note that the `x` used
here contains some repeated elements:
julia> Random.seed!(42); x = rand(randn(1_000), 1_000, 1_000); # some repeats
julia> sum(x .== maximum(x, dims=1), dims=1)
1×1000 Array{Int64,2}:
1 1 2 2 1 2 1 4 2 1 1 1 … 2 1 2 1 3 3 2 4 2 1 3 3
julia> @btime vreduce(max, $x, dims=1);
193.752 μs (1 allocation: 7.94 KiB)
julia> @btime findmax_mask($x);
5.286 ms (4 allocations: 7.65 MiB)
julia> @btime remove_extra_1s(($x .== maximum($x, dims=1))); # from vkv
3.063 ms (8 allocations: 134.48 KiB)
julia> @btime my_mask_creator($x); # from Henrique_Becker
1.263 ms (2 allocations: 976.70 KiB)
julia> @btime findmax_mask2($x); # from Seif_Shebl --> 1.667 ms no repeats
2.308 ms (3 allocations: 122.25 KiB)
julia> @btime findmax_mask3($x); # --> 1.149 ms no repeats
998.595 μs (4 allocations: 130.19 KiB)
julia> @btime findmax_mask3_threads($x); # multi-threaded
488.894 μs (95 allocations: 136.55 KiB)
# Attempts at a vectorised version:
julia> @btime findmax_mask4($x); # --> 1.329 ms no repeats
1.328 ms (3 allocations: 984.64 KiB)
julia> @btime findmax_mask6($x);
797.529 μs (4 allocations: 130.19 KiB)
# Alternative x without repeats; the timings that change with it are marked --> above
julia> Random.seed!(0); x = rand(1_000, 1_000); # no repeats
julia> sum(x .== maximum(x, dims=1), dims=1)
1×1000 Array{Int64,2}:
1 1 1 1 1 1 1 1 1 1 1 1 … 1 1 1 1 1 1 1 1 1 1 1 1
julia> VERSION
v"1.5.0"
All of these are quite far from `vreduce`, which made me think it ought to be possible to do better.
`findmax_mask6` is an attempt to do this with LoopVectorization, but it's not actually faster.
(Edit – some improvements.) But perhaps @Elrod knows a better way?
using LoopVectorization, Tullio, Random, BenchmarkTools
"""
    findmax_mask3(x::Matrix)

Return a `BitArray{2}` with the size of `x` that is `true` only at the first entry
of each column equal to the value `vreduce(max, x, dims=1)` gives for that column.
Columns in which no entry compares equal stay entirely `false`.
"""
function findmax_mask3(x::Matrix)
    colmax = vreduce(max, x, dims=1)
    # BitArray{2}, all false to begin with; an Array{Bool,2} from
    # fill(false, size(x)) ran at much the same speed.
    out = falses(size(x))
    for j in axes(x, 2)
        @inbounds for i in axes(x, 1)
            if x[i, j] == colmax[j]
                # Only the first hit is marked; the rest of the column keeps
                # its initial `false`.
                out[i, j] = true
                break
            end
        end
    end
    return out
end
"""
    findmax_mask3_threads(x::Matrix)

Multi-threaded variant: return a `BitArray{2}` with the size of `x` that is `true`
only at the first entry of each column equal to that column's `max`-reduction `y[c]`.

The threaded loop does not write into the `BitArray` directly. A `BitArray` packs
its bits into shared 64-bit chunks, so concurrent `setindex!` calls from different
columns can overwrite each other's bits. Each thread instead stores the row of its
column's first hit in its own slot of an `Int` vector, and the mask is filled in
afterwards on a single thread.
"""
function findmax_mask3_threads(x::Matrix)
    @tullio (max) y[c] := x[r,c];
    # hits[c] is the row of the first match in column c; 0 means "no match".
    hits = zeros(Int, size(x, 2))
    Threads.@threads for c in axes(x,2)
        @inbounds for r in axes(x,1)
            if x[r,c] == y[c]
                hits[c] = r
                break # the early exit prevents @simd
            end
        end
    end
    mask = falses(size(x))
    # Serial pass: the only place the packed BitArray is written.
    @inbounds for c in axes(x,2)
        r = hits[c]
        r == 0 || (mask[r,c] = true)
    end
    return mask
end
"""
    findmax_mask4(x::Matrix)

Branch-free variant: return an `Array{Bool,2}` with the size of `x` that is `true`
only at the first entry of each column equal to the value
`vreduce(max, x, dims=1)` gives for that column.
"""
function findmax_mask4(x::Matrix) # no branches
    colmax = vreduce(max, x, dims=1)
    # Array{Bool,2}; measured faster here than the BitArray{2} from falses(size(x)).
    out = fill(false, size(x))
    @inbounds for j in axes(x, 2)
        found = false
        # NOTE(review): `found` is carried from one iteration to the next, so the
        # result depends on iteration order -- confirm this is acceptable under @simd.
        @simd for i in axes(x, 1)
            hit = (x[i, j] == colmax[j]) & !found
            out[i, j] = hit
            found = found | hit
        end
    end
    return out
end
# Test inputs for comparing the variants. The second assignment replaces the
# first, so run only the line for the input you want to check.
x0 = rand(1:5, 10,30) # small, many repeats
x0 = rand(rand(100), 100, 1000); # some repeats
# findmax_mask2 is not defined in this snippet (the timings credit it to Seif_Shebl);
# it must already be loaded for these comparisons to run.
findmax_mask3(x0) == findmax_mask3_threads(x0) == findmax_mask2(x0) # ok
findmax_mask4(x0) == findmax_mask2(x0) # ok
using LoopVectorization: vifelse
using LoopVectorization.VectorizationBase: SVec, vload, VectorizationBase, Mask
# Quick look at the two SIMD wrapper types imported above; the trailing comments
# show how each value prints. Note that bit 0 of the mask is printed first.
Mask(0x03) # Mask{8,Bool}<1, 1, 0, 0, 0, 0, 0, 0>
SVec{4,Int}(1,2,3,4) # SVec{4,Int64}<1, 2, 3, 4>
# onlyone(cond, seen): keep at most one `true` from `cond`, and none at all once
# `seen` is non-zero.
# Scalar method: `cond` passes through only while nothing has been seen yet.
@inline onlyone(cond::Bool, seen::Int) = cond && iszero(seen)
# Mask method: if nothing has been seen yet, keep only the FIRST set lane of `cond`.
# Bit 0 of the mask is the first lane (Mask(0x03) prints <1, 1, 0, ...>), so the
# first match is the lowest set bit. Keeping the highest bit instead selected the
# last match in the vector.
@inline function onlyone(cond::Mask{W}, seen::Union{Int,SVec}) where {W}
    allzero(seen) || return zero(cond)
    return Mask(lobit(cond.u))
end
# Isolate the lowest set bit of `n`; zero stays zero. Works for any unsigned width.
@inline lobit(n::Unsigned) = n & -n
# Isolate the highest set bit of `n`. No longer used by `onlyone`; kept in case
# other code calls it.
@inline function hibit(n::UInt16) # https://stackoverflow.com/questions/53161/find-the-highest-order-bit-in-c
    n |= (n >> 1);
    n |= (n >> 2);
    n |= (n >> 4);
    n |= (n >> 8);
    return n - (n >> 1)
end
@inline function hibit(n::UInt8)
    n |= (n >> 1);
    n |= (n >> 2);
    n |= (n >> 4);
    return n - (n >> 1)
end
# allzero(seen): true when every element of `seen` is zero.
@inline allzero(seen::Int) = iszero(seen)
# @inline allzero(seen::SVec{N,Int}) where {N} = all(ntuple(i -> iszero(seen.data[i].value), N))
@inline allzero(seen::SVec{N,Int}) where {N} = iszero((!iszero(seen)).u)
# anyone(cond): true when at least one lane of `cond` is set.
@inline anyone(cond::Bool) = cond
@inline anyone(cond::Mask) = cond != zero(cond)
onlyone(Mask(0x03), 0) # Mask{8,Bool}<1, 0, 0, 0, 0, 0, 0, 0>
anyone(Mask(0x03))
"""
    findmax_mask6(x::Matrix)

LoopVectorization attempt: return a `BitArray{2}` with the size of `x`, marking one
entry per column that equals the value `vreduce(max, x, dims=1)` gives for that
column. `onlyone` picks the entry and `seen` counts the hits found so far.
"""
function findmax_mask6(x::Matrix)
    y = vreduce(max, x, dims=1)
    mask = falses(size(x)) # BitArray{2} # faster
    # mask = fill(false, size(x)) # Array{Bool,2}
    @avx for c in axes(x,2)
        seen = 0
        for r in axes(x,1)
            flag = onlyone(x[r,c] == y[c], seen)
            mask[r,c] = flag
            seen += anyone(flag)
        end
    end
    return mask
end
# This was false while onlyone kept the last match rather than the first. It is
# expected to be true now that the first one is kept -- TODO confirm by running it.
findmax_mask6(x0) == findmax_mask2(x0)