# How to perform an argmax / argmin on a subset of a vector?

Hi! Is there a way to perform an `argmax` on a subset of a vector determined by a bitarray? The bitarray tells whether each index should be considered or not.

Example: if my vector is `x = [0, 2, 1, 0]` and the bitarray is `b = [true, false, true, false]`, then the result should be `3` as only indexes `1` and `3` should be considered (as indicated by `b`).
I cannot do `argmax(x[b])`, as “filtering” with `b` changes the indexing and thus returns `2` instead of `3`.

Thanks!

Probably someone will have a better idea, but one quick and dirty way for argmax could be:

``````
julia> x = [0, 2, 1, 0]
4-element Vector{Int64}:
 0
 2
 1
 0

julia> b = [true, false, true, false]
4-element Vector{Bool}:
 1
 0
 1
 0

julia> argmax(x.*b)
3
``````

All the elements multiplied by false will become zero and therefore will never be a max, provided you know your max is greater than zero.
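
To make that caveat concrete, here is a small sketch with made-up values where every selected element is negative, so a zero introduced by the mask wins instead:

```julia
# Illustrative values only: all selected entries are negative.
x = [-1, -2, -3]
b = [false, true, true]

masked = x .* b            # entries where b is false become 0
naive = argmax(masked)     # picks index 1, which b excludes

# Mapping the filtered position back to the original index avoids this.
safe = findall(b)[argmax(x[b])]
```

Here `naive` points at an index that `b` excludes, while `safe` is the index of `-2`, the largest selected value.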

You could easily adapt this method for `argmin`, or make it a little more general by combining it with `Inf` (see Gustaphe’s response below), but it will allocate an intermediate vector.

The other natural way would be with a loop. I think this is the best choice.
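
For reference, a minimal sketch of what such a loop could look like. The name `masked_argmax` and the choice to return `nothing` when no index is selected are assumptions for illustration, not something from Base:

```julia
# Hypothetical helper; returns `nothing` if b selects no index.
function masked_argmax(x, b)
    best = nothing              # index of the best selected element so far
    for i in eachindex(x, b)    # also checks that x and b have matching indices
        b[i] || continue        # skip indices that are masked out
        if best === nothing || x[i] > x[best]
            best = i
        end
    end
    return best
end

masked_argmax([0, 2, 1, 0], [true, false, true, false])  # 3
```

The strict `>` keeps the first of several equal maxima, which matches what `argmax` does on a plain vector. No intermediate array is allocated.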

``````
argmax(x .- Inf .* .!(b))
``````

Doesn’t rely on positive `x`.
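
The same trick should carry over to `argmin` by adding `Inf` instead of subtracting it. A minimal sketch, assuming `x` holds real numbers and at least one entry of `b` is true:

```julia
x = [0, 2, 1, 0]
b = [true, false, true, false]

# Inf * false == 0.0 in Julia, so selected entries are left unchanged,
# while excluded entries are pushed to +Inf and can never be the minimum.
argmin(x .+ Inf .* .!(b))
```

Like the `argmax` version, this allocates a temporary floating-point vector.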


In Julia 1.7 you can use (untested, on phone in bed):

``````
# collect() gives findmax an indexable domain, so `ind` is a position in x
_, ind = findmax(((xi, bi),) -> bi ? xi : -Inf, collect(zip(x, b)))
``````

Making a loop is probably more self documenting, but how cool is that?

In Julia 1.6, this seems to work too:

``````
julia> argmax(view(x,b))
2
``````

This doesn’t do what is asked (the answer should be 3, not 2).


oh, sorry, you are right

``````
julia> findall(b)[argmax(@view x[b])]
3
``````

this (now) also returns 3

``````
julia> c = argmax(view(x,b))
2 #index position of the nth true value in b
julia> findnext(identity,b,c)
3
``````

Dumber alternative:

``````
eachindex(x)[b][argmax(x[b])]
``````

One more:

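``````
# invperm turns the sorting permutation into ranks; sortperm(x) .* b alone
# fails for e.g. x = [3, 1, 2], b = [true, true, false]
argmax(invperm(sortperm(x)) .* b)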
``````

I ran a benchmark. Code below.

Note the logarithmic `y` scale. Lower is better, and the lowest point is the most important. The numbers vary a bit between runs, but the order is constant as far as I’ve seen. These are the functions I could make pass the test. My conclusion: `findall` or `findnext` wins (unless we count writing a loop to do it, which is a bit sad).

Code
``````
#!/bin/julia

using BenchmarkTools
using BenchmarkPlots
using StatsPlots
using Test
using Random: bitrand

BenchmarkTools.DEFAULT_PARAMETERS.seconds = 20
N = 10_000

preservingargmax = Dict{Symbol,Function}()

preservingargmax[:Inf] = (x, b) -> argmax(x .- Inf .* .!(b))
preservingargmax[:findall] = (x, b) -> findall(b)[argmax(x[b])]
preservingargmax[:findall_views] = (x, b) -> first(@view findall(b)[argmax(@view x[b])])
preservingargmax[:findnext] = (x, b) -> begin
    c = argmax(x[b])
    i0 = findfirst(identity, b)
    for _ in 1:c-1
        i0 = findnext(identity, b, i0 + 1)
    end
    return i0
end
preservingargmax[:findnext_views] = (x, b) -> begin
    c = argmax(view(x, b))
    i0 = findfirst(identity, b)
    for _ in 1:c-1
        i0 = findnext(identity, b, i0 + 1)
    end
    return i0
end
preservingargmax[:eachindex] = (x, b) -> eachindex(x)[b][argmax(x[b])]
preservingargmax[:eachindex_views] = (x, b) -> first(view(view(eachindex(x),b),argmax(view(x,b))))
preservingargmax[:sortperm] = (x, b) -> argmax(invperm(sortperm(x)) .* b) # invperm gives ranks
preservingargmax[:argmax] = (x, b) -> argmax(i -> b[i] ? x[i] : -Inf, eachindex(x))
preservingargmax[:enumerate] = (x, b) -> first(argmax(t -> t[2][1] ? t[2][2] : -Inf, enumerate(zip(b, x))))
preservingargmax[:zip] = (x, b) -> first(argmax(t -> t[2] ? t[3] : -Inf, zip(eachindex(x), b, x)))
preservingargmax[:loop] = (x, b) -> begin
    maxx = -Inf
    maxi = 0
    for (xi, bi, i) in zip(x, b, eachindex(x))
        (bi && xi > maxx) || continue
        maxx = xi
        maxi = i
    end
    return maxi
end
preservingargmax[:branchless] = (x, b) -> begin
    maxx = -Inf
    maxi = 0
    for (xi, bi, i) in zip(x, b, eachindex(x))
        maxi = (bi && xi > maxx)*i + !(bi && xi > maxx)*maxi
        maxx = (bi && xi > maxx)*xi + !(bi && xi > maxx)*maxx
    end
    return maxi
end
preservingargmax[:to_index] = (x, b) -> (i=Base.to_index(b); Vector{Int}(i)[argmax(x[i])])
preservingargmax[:typemin] = (x, b) -> (y = copy(x); y[.!b] .= typemin(eltype(x)); argmax(y))
preservingargmax[:parentindices] = (x, b) -> (v = view(x, b); first(parentindices(v))[argmax(v)])

labels = [
    :Inf,
    :findall,
    :findall_views,
    :findnext,
    :findnext_views,
    :eachindex,
    :eachindex_views,
    :sortperm,
    :argmax,
    :enumerate,
    :zip,
    :loop,
    :branchless,
    :to_index,
    :typemin,
    :parentindices,
]

bg = BenchmarkGroup()

x_1 = [0, 2, 1, -3, 0]
b_1 = [true, false, true, true, false]
x_2 = collect(1:100)
b_2 = falses(100); b_2[[1, 3, 10, 100]] .= true
@testset "correctness, $k" for (k, f) in preservingargmax
    @test f(x_1, b_1) == 3
    @test f(x_2, b_2) == 100
    bg[k] = @benchmarkable $f(x, b) setup = (x = randn($N); b = bitrand($N))
end

res = run(bg)

plot(
    res,
    labels;
    yscale=:log10,
    xrotation=30,
)

savefig("benchmark.png")
``````

@gustaphe, thanks for sharing this brilliant benchmarking code.
For large arrays, @longemen3000’s `findnext` solution seems to be the fastest.


Another one with Julia 1.7:

``````
argmax(i -> b[i] ? x[i] : -Inf, eachindex(x))
``````

Added to the benchmark. I tried to make something similar to that before, it’s really annoying there is no `indmax` function.

I’m also trying

``````
first(argmax(t -> t[2][1] ? t[2][2] : -Inf, enumerate(zip(b, x))))
``````

for which I had high hopes. Sadly, it’s the second worst suggestion yet.

The current leader is simply writing a loop. I wish it weren’t so, but sometimes you have to do things yourself.


I feel like there is a package somewhere that does views that have offset indexes to match the original positions.
But it isn’t in OffsetArrays nor PaddedViews.jl

One more lamp to add to Gustaphe’s fantastic lamp store:

``````
function argmax_typemin(x,b)
    y = copy(x)
    y[.!b] .= typemin(eltype(x))
    argmax(y)
end
``````

For N=10K, the store’s window display is (updated based on posts further down):

If one runs the benchmark again for N=10K, some lamps change shape and relative position, since the input is random.


One day I want to (have the time to) implement a group wide `setup` for BenchmarkTools so you can have a common rand seed for every test. Sorely lacking functionality.
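
Until something like that exists, one workaround might be to generate the inputs once from a seeded RNG and hand the same arrays to every benchmark. A sketch using only the `Random` standard library; the helper name `make_inputs` and the seed `1234` are arbitrary choices for illustration:

```julia
using Random

# Hypothetical helper: builds one reproducible (x, b) pair for all candidates.
function make_inputs(n; seed=1234)
    rng = MersenneTwister(seed)
    x = randn(rng, n)      # same values on every call with the same seed
    b = bitrand(rng, n)    # same mask on every call with the same seed
    return x, b
end

x, b = make_inputs(10_000)
```

These arrays could then be interpolated into each benchmark, e.g. `@benchmarkable $f($x, $b)`, so that every function sees identical data.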


I might be missing something, but the above code doesn’t seem to work in general:

``````
julia> x = 1:100
1:100 # Maximum value is 100 at index 100

julia> b = falses(100); b[[1, 3, 10, 100]] .= true; # Compare only these indexes

julia> c = argmax(view(x, b))
4 # Index of maximal value in x[b]

julia> findnext(identity, b, c)
10 # Should be 100
``````
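
The call above starts searching at position `c` instead of locating the `c`-th true entry of `b`. A sketch of one way to get the intended behavior, by stepping through the true entries one at a time (`nth_true` is a made-up helper name, not a Base function):

```julia
# Hypothetical helper: index of the n-th true entry of b, or nothing.
function nth_true(b, n)
    i = findfirst(b)
    for _ in 2:n
        i === nothing && return nothing
        i = findnext(b, i + 1)   # continue searching after the previous hit
    end
    return i
end

x = 1:100
b = falses(100); b[[1, 3, 10, 100]] .= true
c = argmax(view(x, b))   # position within the filtered view
nth_true(b, c)           # position within the original vector
```

With the inputs from the counterexample, this walks over indices 1, 3, 10 and stops at 100.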