What is the Julia equivalent of Matlab's `maxk`, which returns the indices of the top K largest values?

I am currently converting my Matlab code into Julia. While trying to convert the off-the-shelf Matlab function `maxk`, which returns the values and indices of the top K largest/smallest values, I found it difficult to find such a function in Julia.


There may be something more efficient, but `sortperm` should do what you want:

``````julia> using Random

julia> a = shuffle(11:15)
5-element Array{Int64,1}:
15
12
14
13
11

julia> b = sortperm(a)
5-element Array{Int64,1}:
5
2
4
3
1

julia> collect(zip(b, a[b]))
5-element Array{Tuple{Int64,Int64},1}:
(5, 11)
(2, 12)
(4, 13)
(3, 14)
(1, 15)
``````

As above, you can also use `partialsortperm(a, 1:3)` to get only the first 3 indices, and so on. It should run a little faster. Here’s a link to the documentation, and an implementation of `maxk`. Note that I use `rev=true` to get the largest elements.

``````function maxk(a, k)
    b = partialsortperm(a, 1:k, rev=true)
    return collect(zip(b, a[b]))
end
``````
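As a quick sanity check of the `maxk` above, here is a small self-contained example (the expected output was worked out by hand; there are no ties, so the ordering is unambiguous):

```julia
# maxk as defined above: indices and values of the k largest entries.
function maxk(a, k)
    b = partialsortperm(a, 1:k, rev=true)
    return collect(zip(b, a[b]))
end

maxk([1, 3, 2, 5, 4], 3)  # → [(4, 5), (5, 4), (2, 3)]
```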

Thanks for the code. However, I found that `maxk` in Matlab (R2018a) is much faster (about 4.7 times) than this `maxk` in Julia.

``````QQ=randn(1000,1);
tic;
for i=1:10000
[~,ind]=maxk(QQ,10);
end
toc;
``````

and

``````tic();
for i=1:10000
ind1=maxk(QQ,10);
end
toc();
``````

Is it possible for the maxk in Julia to be as fast as that in Matlab?
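For reference, a roughly equivalent timing loop on the Julia side might look like this (a sketch using `@elapsed`; note that `tic`/`toc` were removed in Julia 1.0, and BenchmarkTools, used below, gives more reliable numbers):

```julia
# maxk as defined earlier in the thread.
function maxk(a, k)
    b = partialsortperm(a, 1:k, rev=true)
    return collect(zip(b, a[b]))
end

QQ = randn(1000)
maxk(QQ, 10)                 # warm-up call so compilation is not timed
t = @elapsed for i in 1:10000
    maxk(QQ, 10)
end
println(t)                   # total seconds for 10_000 calls
```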

Have a look at the performance tips. Here is how I would time:

``````julia> using BenchmarkTools

julia> a = randn(1000);

julia> @btime maxk($a, 10);
  11.064 μs (8 allocations: 8.47 KiB)
``````

I see a >30x improvement when I use `partialsortperm!` (an in-place version of `partialsortperm`) and pass in a pre-initialized vector `ix = collect(1:10)`.

``````julia> function maxk!(ix, a, k; initialized=false)
           partialsortperm!(ix, a, 1:k, rev=true, initialized=initialized)
           return collect(zip(ix, a[ix]))
       end
maxk! (generic function with 1 method)

julia> ix = collect(1:10);

julia> @btime maxk!($ix, $a, 10, initialized=true);
  345.111 ns (7 allocations: 544 bytes)
``````

Edit: a more general `maxk!` where `ix` may not necessarily be initialized.
Edit 2: see @Elrod below.

I don’t think `maxk!` is correct.

``````julia> function maxk!(ix, a, k; initialized=false)
           partialsortperm!(ix, a, 1:k, rev=true, initialized=initialized)
           return collect(zip(ix, a[ix]))
       end
maxk! (generic function with 1 method)

julia> q = randn(1000);

julia> idx = collect(1:10);

julia> maxk!(idx, q, 10, initialized = true)
10-element Array{Tuple{Int64,Float64},1}:
(4, 0.20337986959767532)
(6, -0.3077570887074856)
(1, -0.6107099712370375)
(5, -0.6993273979405977)
(10, -0.7794269061061618)
(2, -0.8531978408910773)
(7, -0.9704489858497608)
(8, -1.1237782702811618)
(9, -1.2057372641871114)
(3, -2.813777795383793)

julia> maximum(q)
4.032311634752638

julia> idx'
4  6  1  5  10  2  7  8  9  3

julia> partialsortperm(q, 1:10, rev=true)'
626  425  940  40  594  811  88  133  278  128

julia> q[ans]
1×10 Array{Float64,2}:
4.03231  3.60551  3.07742  2.75559  2.69597  2.45856  2.25731  2.2014  2.19445  2.18963
``````

EDIT:
You need:

``````function maxk!(ix, a, k; initialized=false)
    partialsortperm!(ix, a, 1:k, rev=true, initialized=initialized)
    @views collect(zip(ix[1:k], a[ix[1:k]]))
end

julia> idx = collect(1:length(q));

julia> maxk!(idx, q, 10, initialized = true)
10-element Array{Tuple{Int64,Float64},1}:
(626, 4.032311634752638)
(425, 3.6055071635748828)
(940, 3.07741625532675)
(40, 2.755589275585265)
(594, 2.695973149935006)
(811, 2.4585577302387667)
(88, 2.2573122866282973)
(133, 2.2014028244112733)
(278, 2.1944512712785187)
(128, 2.1896301246037284)
``````

EDIT:
Because this was marked the solution, I want to point out that it isn’t notably faster than the other answers on the computer I tried this on.

``````julia> @benchmark maxk!($idx, $q, 10) # setup=(copyto!($idx, 1:length($q)))
BenchmarkTools.Trial:
memory estimate:  544 bytes
allocs estimate:  10
--------------
minimum time:     9.098 μs (0.00% GC)
median time:      9.257 μs (0.00% GC)
mean time:        9.671 μs (0.00% GC)
maximum time:     26.930 μs (0.00% GC)
--------------
samples:          10000
evals/sample:     1

julia> b = @benchmarkable maxk!($idx, $q, 10, initialized=true) setup=(copyto!($idx, 1:length($q)))
Benchmark(evals=1, seconds=5.0, samples=10000)

julia> run(b)
BenchmarkTools.Trial:
memory estimate:  544 bytes
allocs estimate:  10
--------------
minimum time:     8.966 μs (0.00% GC)
median time:      9.127 μs (0.00% GC)
mean time:        9.440 μs (0.00% GC)
maximum time:     21.891 μs (0.00% GC)
--------------
samples:          10000
evals/sample:     1

julia> function maxk(a, k)
           b = partialsortperm(a, 1:k, rev=true)
           return collect(zip(b, a[b]))
       end
maxk (generic function with 1 method)

julia> @benchmark maxk($q, 10)
BenchmarkTools.Trial:
memory estimate:  8.47 KiB
allocs estimate:  8
--------------
minimum time:     9.749 μs (0.00% GC)
median time:      11.382 μs (0.00% GC)
mean time:        16.790 μs (31.13% GC)
maximum time:     42.964 ms (99.93% GC)
--------------
samples:          10000
evals/sample:     1
``````

You should have used

`collect(1:1000)`

as a fair comparison
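In other words, the scratch index vector has to cover all of `q` and be reset before each timed call, since `partialsortperm!` permutes it. A sketch of that setup (using the non-keyword form, which lets `partialsortperm!` fill `ix` itself):

```julia
q = randn(1000)
idx = collect(1:length(q))   # full length, i.e. collect(1:1000), not collect(1:10)

function maxk!(ix, a, k)
    partialsortperm!(ix, a, 1:k, rev=true)  # initializes and partially sorts ix
    @views collect(zip(ix[1:k], a[ix[1:k]]))
end

res = maxk!(idx, q, 10)
res[1][2] == maximum(q)      # true: the first entry is the largest value
```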

A comparison between `maxk` (the not-in-place version) in Julia 1.0 and `maxk` in Matlab R2018a shows that Julia is ~~several times faster~~ two times slower for this particular test on my laptop.

Julia 1.0.0:

``````using BenchmarkTools

function maxk(a, k)
    b = partialsortperm(a, 1:k, rev=true)
    return collect(zip(b, a[b]))
end

a = randn(1000)
k = 10

@benchmark maxk($a, $k)
``````
``````BenchmarkTools.Trial:
memory estimate:  8.47 KiB
allocs estimate:  8
--------------
minimum time:     9.079 μs (0.00% GC)
median time:      9.475 μs (0.00% GC)
mean time:        16.670 μs (38.25% GC)
maximum time:     50.874 ms (99.91% GC)
--------------
samples:          10000
evals/sample:     1
``````

Matlab R2018a:

``````a = randn(1000,1);
k = 10;
f = @() maxk(a,k);

tmedian = timeit(f)
``````
``````tmedian =
4.3039e-06
``````

Edit: Fixed wrong comparison (it was too late at night…).

Looks like the Matlab version is >2x faster to me.

Yes, looking at previous discussions it seems like Matlab’s sort functions are multi-threaded and heavily optimized, so they will be tough to beat. I did a naive threaded implementation here, which starts to have some gains over single-threaded for large `a`:

``````julia> using Base.Threads

julia> nthreads()
2

julia> function maxk_threaded(a, k)
           # Reconstructed sketch (the code in this post was garbled):
           # each thread finds the top-k of its chunk of `a`, then we
           # take the top-k of the combined candidates.
           nt = nthreads()
           chunks = collect(Iterators.partition(eachindex(a), cld(length(a), nt)))
           cand = Vector{Vector{Int}}(undef, length(chunks))
           @threads for i in eachindex(chunks)
               c = chunks[i]
               cand[i] = c[partialsortperm(view(a, c), 1:min(k, length(c)), rev=true)]
           end
           ix = reduce(vcat, cand)
           partialsortperm!(ix, a, 1:k, rev=true, initialized=true)
           @views collect(zip(ix[1:k], a[ix[1:k]]))
       end
maxk_threaded (generic function with 1 method)

julia> a = randn(10000);

julia> @btime maxk($a, 10)
70.201 μs (9 allocations: 78.73 KiB)
10-element Array{Tuple{Int64,Float64},1}:
(3840, 3.7524106800393)
(1359, 3.667162944745407)
(4738, 3.46128657912246)
(8532, 3.349067643815953)
(8314, 3.3363898988561234)
(3542, 3.3297030468239965)
(1159, 3.2795246783768923)
(9436, 3.259918244413647)
(9418, 3.254388944717796)
(2198, 3.155524296051535)

julia> @btime maxk_threaded($a, 10)
  53.894 μs (23 allocations: 1.58 KiB)
10-element Array{Tuple{Int64,Float64},1}:
(3840, 3.7524106800393)
(1359, 3.667162944745407)
(4738, 3.46128657912246)
(8532, 3.349067643815953)
(8314, 3.3363898988561234)
(3542, 3.3297030468239965)
(1159, 3.2795246783768923)
(9436, 3.259918244413647)
(9418, 3.254388944717796)
(2198, 3.155524296051535)
``````

Got a reasonably fast solution using SortingLab.jl.

I am not sure how fast it is compared to Matlab. Perhaps @complexfilter can do a test and share the results.