How to perform an argmax / argmin on a subset of a vector?

Hi! Is there a way to perform an argmax on a subset of a vector determined by a bitarray? This bitarray tells whether each index should be considered or not.

Example: if my vector is x = [0, 2, 1, 0] and the bitarray is b = [true, false, true, false], then the result should be 3, as only indexes 1 and 3 should be considered (as indicated by b).
I cannot do argmax(x[b]), as “filtering” with b changes the indexing and thus returns 2 instead of 3.

Thanks!

Probably someone will have a better idea, but one quick and dirty way for argmax could be:

julia> x = [0, 2, 1, 0]
4-element Vector{Int64}:
 0
 2
 1
 0

julia> b = [true, false, true, false]
4-element Vector{Bool}:
 1
 0
 1
 0

julia> argmax(x.*b)
3

All the elements multiplied by false will become zero and therefore will never be a max, provided you know your max is greater than zero.

You could easily adapt this method for argmin, or make it a little more general by combining it with Inf (see Gustaphe’s response below), but it will allocate an intermediate vector.
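A minimal sketch of the argmin adaptation, assuming the same `x` and `b` as above; `masked_argmin_inf` is a hypothetical helper name. Excluded entries are shifted to `+Inf` instead of being zeroed, so the sign of `x` no longer matters:

```julia
# Masked argmin: push the excluded entries up to +Inf, then call argmin.
# In Julia `Inf * false == 0.0`, so the selected entries are left unchanged.
# Allocates one temporary Float64 vector, like the argmax version.
masked_argmin_inf(x, b) = argmin(x .+ Inf .* .!b)

x = [0, 2, 1, 0]
b = [true, false, true, false]
masked_argmin_inf(x, b)  # 1: only x[1] = 0 and x[3] = 1 compete
```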

The other natural way would be with a loop. I think this is the best choice.
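A sketch of such a loop, written as a hypothetical helper `masked_argmax`. It makes one pass, allocates nothing, keeps the first index on ties (as `argmax` does), and returns `nothing` when the mask selects no element; that last choice is an assumption, and throwing an error would be just as reasonable:

```julia
# Single-pass masked argmax over any ordered element type.
# eachindex(x, b) throws a DimensionMismatch if x and b have different axes.
function masked_argmax(x::AbstractVector, b::AbstractVector{Bool})
    best = nothing                      # index of the best selected element so far
    for i in eachindex(x, b)
        b[i] || continue                # skip indices the mask excludes
        if best === nothing || isless(x[best], x[i])
            best = i                    # strictly larger: the first maximum is kept on ties
        end
    end
    return best                         # `nothing` if b selects no element
end
```

For example, `masked_argmax([0, 2, 1, 0], [true, false, true, false])` gives 3.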

argmax(x .- Inf .* .!(b))

Doesn’t rely on positive x.
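To see why the sign matters, here is a small made-up input where every value is negative. The multiplication trick zeroes the excluded entries, and a zero beats every negative value, whereas the `Inf` version does not depend on the sign:

```julia
x = [-5, -2, -1, -4]
b = [true, false, true, false]

# Multiplying by the mask turns excluded entries into 0, which beats any negative value:
argmax(x .* b)              # 2, an excluded index

# Subtracting Inf from the excluded entries works regardless of sign
# (`Inf * false == 0.0`, so the selected entries are unchanged):
argmax(x .- Inf .* .!b)     # 3, since x[3] = -1 > x[1] = -5
```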


In Julia 1.7 you can use (untested, on phone in bed):

_, ind = findmax(((xi, bi),) -> bi ? xi : -Inf, collect(zip(x, b)))

Making a loop is probably more self-documenting, but how cool is that?
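For reference, a sketch of the same idea that can be run as is on Julia 1.7 or later. It iterates over `eachindex(x)` instead of `zip(x, b)`, so the index that `findmax(f, domain)` reports is directly an index of `x`:

```julia
# Requires Julia 1.7 or later for findmax(f, domain), which returns (f(best), index of best).
x = [0, 2, 1, 0]
b = [true, false, true, false]

# Over eachindex(x), the "index in the domain" is the index in x itself.
# Excluded indices map to -Inf, so they can never be the maximum.
val, ind = findmax(i -> b[i] ? x[i] : -Inf, eachindex(x))
(val, ind)  # (1, 3)
```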

In Julia 1.6, this seems to work too:

julia> argmax(view(x,b))
2

This doesn’t do what is asked (the answer should be 3, not 2).


oh, sorry, you are right

How about

julia> findall(b)[argmax(@view x[b])]
3
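The idea behind this one: `findall(b)` lists the indices of `x` that `b` keeps, in order, so the position of the maximum inside `x[b]` can be looked up in it. A sketch as a hypothetical helper `masked_argmax_findall`, with a guard for an all-`false` mask (returning `nothing` there is an assumption):

```julia
# Map the position of the maximum inside the kept subset back to an index of x.
function masked_argmax_findall(x, b)
    kept = findall(b)                   # e.g. [1, 3] for b = [true, false, true, false]
    isempty(kept) && return nothing     # argmax of an empty collection would throw
    return kept[argmax(view(x, kept))]  # c-th kept element -> its original index
end
```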

This (now) also returns 3:

julia> c = argmax(view(x,b))
2 # the maximum is at the 2nd true entry of b
julia> findnext(identity,b,c)
3
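Note that `findnext(identity, b, c)` returns the first `true` at or after position `c`, which is not in general the `c`-th `true` of `b`. A sketch of a variant (hypothetical name `masked_argmax_findnext`) that instead starts at the first `true` and steps through `c - 1` further ones, as the `:findnext` entry of the benchmark further down does:

```julia
# c is the position of the maximum among the selected entries, so the answer
# is the index of the c-th `true` in b.
function masked_argmax_findnext(x, b)
    c = argmax(view(x, b))      # throws if b selects nothing
    i = findfirst(b)            # index of the 1st true
    for _ in 2:c
        i = findnext(b, i + 1)  # index of the next true after i
    end
    return i
end
```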

Dumber alternative:

eachindex(x)[b][argmax(x[b])]

One more:

argmax(invperm(sortperm(x)) .* b)
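On the examples in this thread, `sortperm(x) .* b` happens to give the right index, but the entries of `sortperm(x)` are positions of `x`'s elements in sorted order, not the ranks of the elements. A small counterexample of our own, and a rank-based variant that uses `invperm` to turn the permutation into ranks:

```julia
x = [3, 1, 2]
b = [true, true, false]   # the answer should be 1, since x[1] = 3 is the largest kept value

p = sortperm(x)           # [2, 3, 1]: indices of x in ascending order, not ranks
argmax(p .* b)            # 2, wrong

r = invperm(p)            # [3, 1, 2]: r[i] is the rank of x[i]
argmax(r .* b)            # 1, correct
```

Since ranks start at 1, excluded entries (multiplied by `false`) can never win unless `b` is all `false`. Because `sortperm` places equal elements in ascending index order, the rank-based version resolves ties to the last tied index rather than the first.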

I ran a benchmark. Code below.

Note the logarithmic y scale. Lower is better, and the lowest point is the most important. The numbers vary a bit between runs, but the ordering has stayed the same as far as I’ve seen. These are the functions I could make pass the test. My conclusion: findall or findnext wins (unless we count writing a loop to do it, which is a bit sad).

[Figure: benchmark results, one entry per method, logarithmic y axis]

Code
#!/usr/bin/env julia

using BenchmarkTools
using BenchmarkPlots
using StatsPlots
using Test
using Random: bitrand

BenchmarkTools.DEFAULT_PARAMETERS.seconds = 20
N = 10_000

preservingargmax = Dict{Symbol,Function}()

preservingargmax[:Inf] = (x, b) -> argmax(x .- Inf .* .!(b))
preservingargmax[:findall] = (x, b) -> findall(b)[argmax(x[b])]
preservingargmax[:findall_views] = (x, b) -> first(@view findall(b)[argmax(@view x[b])])
preservingargmax[:findnext] = (x, b) -> begin
    c = argmax(x[b])
    i0 = findfirst(identity, b)
    for _ in 1:c-1
        i0 = findnext(identity, b, i0 + 1)
    end
    return i0
end
preservingargmax[:findnext_views] = (x, b) -> begin
    c = argmax(view(x, b))
    i0 = findfirst(identity, b)
    for _ in 1:c-1
        i0 = findnext(identity, b, i0 + 1)
    end
    return i0
end
preservingargmax[:eachindex] = (x, b) -> eachindex(x)[b][argmax(x[b])]
preservingargmax[:eachindex_views] = (x, b) -> first(view(view(eachindex(x),b),argmax(view(x,b))))
preservingargmax[:sortperm] = (x, b) -> argmax(sortperm(x) .* b)
preservingargmax[:argmax] = (x, b) -> argmax(i -> b[i] ? x[i] : -Inf, eachindex(x))
preservingargmax[:enumerate] = (x, b) -> first(argmax(t -> t[2][1] ? t[2][2] : -Inf, enumerate(zip(b, x))))
preservingargmax[:zip] = (x, b) -> first(argmax(t -> t[2] ? t[3] : -Inf, zip(eachindex(x), b, x)))
preservingargmax[:loop] = (x, b) -> begin
    maxx = -Inf
    maxi = 0
    for (xi, bi, i) in zip(x, b, eachindex(x))
        (bi && xi > maxx) || continue
        maxx = xi
        maxi = i
    end
    return maxi
end
preservingargmax[:branchless] = (x, b) -> begin
    maxx = -Inf
    maxi = 0
    for (xi, bi, i) in zip(x, b, eachindex(x))
        maxi = (bi && xi > maxx)*i + !(bi && xi > maxx)*maxi
        maxx = (bi && xi > maxx)*xi + !(bi && xi > maxx)*maxx
    end
    return maxi
end
preservingargmax[:to_index] = (x, b) -> (i=Base.to_index(b); Vector{Int}(i)[argmax(x[i])])
preservingargmax[:typemin] = (x, b) -> (y = copy(x); y[.!b] .= typemin(eltype(x)); argmax(y))
preservingargmax[:parentindices] = (x, b) -> (v = view(x, b); first(parentindices(v))[argmax(v)])

labels = [
        :Inf,
        :findall,
        :findall_views,
        :findnext,
        :findnext_views,
        :eachindex,
        :eachindex_views,
        :sortperm,
        :argmax,
        :enumerate,
        :zip,
        :loop,
        :branchless,
        :to_index,
        :typemin,
        :parentindices,
    ]

bg = BenchmarkGroup()

x_1 = [0, 2, 1, -3, 0]
b_1 = [true, false, true, true, false]
x_2 = collect(1:100)
b_2 = falses(100); b_2[[1, 3, 10, 100]] .= true
@testset "correctness, $k" for (k, f) in preservingargmax
    @test f(x_1, b_1) == 3
    @test f(x_2, b_2) == 100
    bg[k] = @benchmarkable $f(x, b) setup = (x = randn($N); b = bitrand($N))
end

res = run(bg)


plot(
    res,
    labels,
    ;
    yscale=:log10,
    xrotation=30,
)

savefig("benchmark.png")

@gustaphe, thanks for sharing this brilliant benchmarking code.
For large arrays, @longemen3000’s findnext solution seems to be the fastest.


Another one with Julia 1.7:

argmax(i -> b[i] ? x[i] : -Inf, eachindex(x))
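The argmin counterpart follows the same pattern, with excluded indices mapped to `+Inf`. A sketch on the example from the question (Julia 1.7 or later):

```julia
# Requires Julia 1.7 or later for argmax(f, domain) / argmin(f, domain),
# which return the element of the domain (here: an index of x) that maximizes/minimizes f.
x = [0, 2, 1, 0]
b = [true, false, true, false]

argmax(i -> b[i] ? x[i] : -Inf, eachindex(x))   # 3
argmin(i -> b[i] ? x[i] : Inf, eachindex(x))    # 1: x[4] = 0 is excluded
```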

Added to the benchmark. I tried to make something similar to that before; it’s really annoying that there is no indmax function.

I’m also trying

first(argmax(t -> t[2][1] ? t[2][2] : -Inf, enumerate(zip(b, x))))

for which I had high hopes. Sadly it’s the second worst suggestion yet :frowning:

The current leader is simply writing a loop. I wish it weren’t so, but sometimes you have to do things yourself.


I feel like there is a package somewhere that provides views with offset indices matching the original positions,
but it isn’t OffsetArrays.jl or PaddedViews.jl.

This thread is addictive…

One more lamp to add to Gustaphe’s fantastic lamp store:

function argmax_typemin(x,b)
    y = copy(x)
    y[.!b] .= typemin(eltype(x))
    argmax(y)
end

For N=10K, the store’s window display is (updated based on posts further down):

[Figure: benchmark results for N = 10K, run 1]

If one runs the benchmark for N=10K again, some lamps change shape and relative position, because the input is random.

[Figure: benchmark results for N = 10K, run 2]


One day I want to (have the time to) implement a group-wide setup for BenchmarkTools so you can have a common random seed for every test. Sorely lacking functionality.
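Until such a group-wide option exists, one workaround (our sketch, using only the Random standard library rather than BenchmarkTools) is to generate the input once from a seeded RNG and pass the same `x` and `b` to every method. The seed `1234` and the choice of methods are arbitrary, and `@elapsed` on a single call is far noisier than BenchmarkTools; with BenchmarkTools, the pre-generated arrays can instead be interpolated into each call, e.g. `@benchmarkable $f($x, $b)`.

```julia
using Random

# Draw one input from a seeded RNG so that every method, and every rerun on the
# same Julia version, sees the same data. N matches the benchmark's 10_000.
rng = MersenneTwister(1234)
N = 10_000
x = randn(rng, N)
b = bitrand(rng, N)

# A few of the thread's methods (argmax(f, domain) needs Julia 1.7 or later).
candidates = Dict(
    :findall => (x, b) -> findall(b)[argmax(view(x, b))],
    :Inf     => (x, b) -> argmax(x .- Inf .* .!b),
    :argmax  => (x, b) -> argmax(i -> b[i] ? x[i] : -Inf, eachindex(x)),
)

for (name, f) in candidates
    f(x, b)                  # warm-up call so compilation is not timed
    t = @elapsed f(x, b)     # one crude timing on the shared input
    println(rpad(string(name), 8), f(x, b), "  ", t, " s")
end
```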


I might be missing something, but the above code doesn’t seem to work in general:

julia> x = 1:100
1:100 # Maximum value is 100 at index 100

julia> b = falses(100); b[[1, 3, 10, 100]] .= true; # Compare only these indexes

julia> c = argmax(view(x, b))
4 # Index of maximal value in x[b]

julia> findnext(identity, b, c)
10 # Should be 100