10x faster sortperm()

Motivated by LilithHafner's work to add radix sorting to Julia 1.9 and the resulting performance improvements of sort(), I did some tests to see if sortperm() could be sped up as well.

To test this, I implemented different ways to sort a vector of UInt64s and compared their performance:

sort
standard sort() as a performance reference

sortperm

  • standard sortperm()
  • If I am not mistaken, this uses QuickerSort, a Quicksort variant, to sort the indices with a comparison function that looks up the values for each comparison.
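As a point of reference, comparison-based sorting by reference can be sketched with `Base.Order.Perm`, an internal (unexported) ordering in Base that compares `v[i]` with `v[j]` and breaks ties by the indices. This is a minimal illustration, not necessarily the exact code path `sortperm()` takes; `sortperm_by_reference` is a hypothetical name.

```julia
# Sort the index vector 1:n by the values it points to.
# Base.Order.Perm (internal API) reads v[i] and v[j] on every comparison,
# so each comparison is a data-dependent memory access into v.
function sortperm_by_reference(v::AbstractVector)
    p = collect(1:length(v))
    order = Base.Order.Perm(Base.Order.Forward, v)   # ties broken by index
    sort!(p; order = order)
    return p
end

v = rand(UInt64(0):UInt64(9), 10_000)   # many ties on purpose
p = sortperm_by_reference(v)
@assert issorted(v[p])
@assert p == sortperm(v)                # same permutation as the built-in sortperm
```

Because `Perm` breaks ties by index, the resulting permutation is unique, so it matches the stable `sortperm` even though no stability is requested.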

packed sortperm

  • Pack indices and values into a UInt128 vector.
  • Sort using sort!() (uses radix sort).
  • Unpack indices.

packed sortperm 2

  • Pack indices and values into a UInt128 vector.
  • Sort using a custom radix sort that skips the lowest 64 bits, because the indices are already in sorted order.
  • Unpack indices.
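To see why the lowest 64 bits can be skipped, here is a minimal LSD radix sort sketch that only processes bits 64 to 127 (using 8-bit digits instead of the 10-bit chunks of the code below). The packed vector is built in index order and every counting-sort pass is stable, so equal values keep their index order and the result equals a full sort. `pack`, `radix_sort_high64`, and `packed_sortperm2_sketch` are hypothetical names.

```julia
# Value in the high 64 bits, 1-based index in the low 64 bits.
pack(value::UInt64, idx::Integer) = (UInt128(value) << 64) | UInt128(idx)

# Stable LSD radix sort that looks only at bits 64..127 of each key.
function radix_sort_high64(input::Vector{UInt128}; digit_bits::Int = 8)
    v = copy(input)
    t = similar(v)
    mask = (UInt128(1) << digit_bits) - 1
    counts = zeros(Int, (1 << digit_bits) + 1)
    for shift in 64:digit_bits:127
        fill!(counts, 0)
        for x in v
            counts[Int((x >> shift) & mask) + 2] += 1   # histogram of this digit
        end
        counts[1] = 1
        cumsum!(counts, counts)                          # counts[b+1] = start of bucket b
        for x in v
            b = Int((x >> shift) & mask) + 1
            t[counts[b]] = x                             # stable scatter
            counts[b] += 1
        end
        v, t = t, v                                      # output of this pass is the next input
    end
    return v
end

function packed_sortperm2_sketch(values::AbstractVector{UInt64})
    packed = [pack(x, i) for (i, x) in enumerate(values)]   # built in index order
    return [Int(x % UInt64) for x in radix_sort_high64(packed)]   # low 64 bits = index
end

v = rand(UInt64(0):UInt64(99), 1_000)
@assert packed_sortperm2_sketch(v) == sortperm(v)
```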

radix sortperm by reference

  • Sort the indices using a custom radix sort that looks up the value for each index during each radix pass.

All radix sorts are minimal modifications of the LSD radix sort in Julia 1.9.
The benchmarks were run on Windows 10 with an AMD 3900x and DDR4-3600 memory (Julia version 1.9.0-beta2 7daffeecb8).

Observations

  • Both methods that sort by reference see a significant drop in performance for vectors of more than about 10^6 elements. This roughly corresponds to the size of the L3 cache directly accessible from one core (16 MB per CCX on the 3900x): 10^6 UInt64 values plus 10^6 Int indices occupy 16 MB.
  • Radix sort by reference is the slowest option for large inputs, but the fastest for small ones.
  • Packed sortperm using radix sort achieves up to a 10x speedup for large inputs; the optimized version, which skips the lower 64 bits, is about twice as fast as the basic version.

I don’t know if it would be worth it to include something like this in the standard library, especially because this packed version uses twice the memory, but I wanted to share my results nonetheless.

Here is my (ugly) code, modified to run up to vectors of size 10^7 (10^9 needs 64GB of RAM):

using Random
import Plots
mypack(a,b) = UInt128(a)<<64 + UInt128(b)

function my_radix_sort!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsigned, 
                     t::AbstractVector{U}, offset::Integer,
                     shift,chunk_size) where U <: Unsigned
    # bits is unsigned for performance reasons.
    counts = Vector{Int}(undef, 1 << chunk_size + 1) # TODO use scratch for this
    while true
        @noinline my_radix_sort_pass!(t, lo, hi, offset, counts, v, shift, chunk_size)
        # the latest data resides in t
        shift += chunk_size
        shift < bits || return false
        @noinline my_radix_sort_pass!(v, lo+offset, hi+offset, -offset, counts, t, shift, chunk_size)
        # the latest data resides in v
        shift += chunk_size
        shift < bits || return true
    end
end
function my_radix_sort_pass!(t, lo, hi, offset, counts, v, shift, chunk_size)
    mask = UInt(1) << chunk_size - 1  # mask is defined in pass so that the compiler
    @inbounds begin                   #  ↳ knows its shape
        # counts[2:mask+2] will store the number of elements that fall into each bucket.
        # if chunk_size = 8, counts[2] is bucket 0x00 and counts[257] is bucket 0xff.
        counts .= 0
        for k in lo:hi
            x = v[k]                  # lookup the element
            i = (x >> shift)&mask + 2 # compute its bucket's index for this pass
            counts[i] += 1            # increment that bucket's count
        end

        counts[1] = lo                # set target index for the first bucket
        cumsum!(counts, counts)       # set target indices for subsequent buckets
        # counts[1:mask+1] now stores indices where the first member of each bucket
        # belongs, not the number of elements in each bucket. We will put the first element
        # of bucket 0x00 in t[counts[1]], the next element of bucket 0x00 in t[counts[1]+1],
        # and the last element of bucket 0x00 in t[counts[2]-1].

        for k in lo:hi
            x = v[k]                  # lookup the element
            i = (x >> shift)&mask + 1 # compute its bucket's index for this pass
            j = counts[i]             # lookup the target index
            t[j + offset] = x         # put the element where it belongs
            counts[i] = j + 1         # increment the target index for the next
        end                           #  ↳ element in this bucket
    end
end


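function packedsortperm(v)
    tups = Vector{UInt128}(undef, length(v))
    for idx = 1:length(v)
        tups[idx] = mypack(v[idx], idx)   # value in the high 64 bits, index in the low 64 bits
    end
    sort!(tups)                           # radix sort on the packed keys
    Int.(tups .& typemax(UInt64))         # unpack: keep the full low 64 bits (the index)
end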

function packedsortperm2(v)
    tups = Vector{UInt128}(undef, length(v))
    for idx = 1:length(v)
        tups[idx] = mypack(v[idx], idx)
    end
    t = Vector{UInt128}(undef, length(v))
    # radix sort only bits 64..127 (the values), 10 bits per pass
    flag = my_radix_sort!(tups, 1, length(v), UInt(128), t, UInt(0), 64, UInt8(10))
    # flag == true: the result is in tups; false: the result is in t
    if flag
        return Int.(tups .& typemax(UInt64))
    else
        return Int.(t .& typemax(UInt64))
    end
end


function my_radix_sort_pass_by_reference!(t, lo, hi, offset, counts, v, shift, chunk_size,key)
    mask = UInt(1) << chunk_size - 1  # mask is defined in pass so that the compiler
    @inbounds begin                   #  ↳ knows its shape
        # counts[2:mask+2] will store the number of elements that fall into each bucket.
        # if chunk_size = 8, counts[2] is bucket 0x00 and counts[257] is bucket 0xff.
        counts .= 0
        for k in lo:hi
            x = v[k]                  # lookup the element
            i = (key[x] >> shift)&mask + 2 # compute its bucket's index for this pass
            counts[i] += 1            # increment that bucket's count
        end

        counts[1] = lo                # set target index for the first bucket
        cumsum!(counts, counts)       # set target indices for subsequent buckets
        # counts[1:mask+1] now stores indices where the first member of each bucket
        # belongs, not the number of elements in each bucket. We will put the first element
        # of bucket 0x00 in t[counts[1]], the next element of bucket 0x00 in t[counts[1]+1],
        # and the last element of bucket 0x00 in t[counts[2]-1].

        for k in lo:hi
            x = v[k]                  # lookup the element
            i = (key[x] >> shift)&mask + 1 # compute its bucket's index for this pass
            j = counts[i]             # lookup the target index
            t[j + offset] = x         # put the element where it belongs
            counts[i] = j + 1         # increment the target index for the next
        end                           #  ↳ element in this bucket
    end
end

function my_radix_sort_by_reference!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsigned, 
                     t::AbstractVector{U}, offset::Integer,
                     shift,chunk_size,key) where U <: Unsigned
    # bits is unsigned for performance reasons.
    counts = Vector{Int}(undef, 1 << chunk_size + 1) # TODO use scratch for this
    while true
        @noinline my_radix_sort_pass_by_reference!(t, lo, hi, offset, counts, v, shift, chunk_size,key)
        # the latest data resides in t
        shift += chunk_size
        shift < bits || return false
        @noinline my_radix_sort_pass_by_reference!(v, lo+offset, hi+offset, -offset, counts, t, shift, chunk_size,key)
        # the latest data resides in v
        shift += chunk_size
        shift < bits || return true
    end
end


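function radixsortperm(v)
    idxs = UInt.(collect(1:length(v)))
    t = Vector{UInt64}(undef, length(v))
    # radix sort the indices by v[idx], 10 bits per pass over all 64 value bits
    flag = my_radix_sort_by_reference!(idxs, 1, length(v), UInt(64), t, UInt(0), 0, UInt8(10), v)
    # flag == true: the result is in idxs; false: the result is in t
    if flag
        return Int.(idxs)
    else
        return Int.(t)
    end
end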


function sortpermbench(sizes)
	nSamples = length(sizes)
	t = zeros(nSamples,5)
	for (idx,n) in enumerate(sizes)
		print("$(idx)/$(nSamples)")
		v = rand(UInt64,n)
		GC.gc()
		
		v1 = copy(v)
		t[idx,1] = @elapsed v2 = sort(v1)
		print(".")
		
		v1 = copy(v)
		t[idx,2] = @elapsed p2 = sortperm(v1)
		print(".")
		
		v1 = copy(v)
		t[idx,3] = @elapsed p3 = packedsortperm(v1)
		print(".")
		
		v1 = copy(v)
		t[idx,4] = @elapsed p4 = packedsortperm2(v1)
		print(".")
		
		v1 = copy(v)
		t[idx,5] = @elapsed p5 = radixsortperm(v1)
		print(".")
		
		@assert p2 == p3
		@assert p2 == p4
		@assert p2 == p5
		println(".")
		
	end
	t
end


biasExp = 6
sizes = convert.(Int,round.(10 .^ (range(2^(1/biasExp),7^(1/biasExp),250)).^biasExp)) |> shuffle  # 250 sizes from 10^2 to 10^7, in random order

t = sortpermbench(sizes)

Plots.scatter(sizes,sizes./t,xaxis=:log,yaxis=:log,xlabel="Input Size / Elements",ylabel= "Elements / Second",label=["sort" "sortperm" "packed sortperm" "packed sortperm 2" "radix sortperm by reference"],legend=:bottomleft,xticks=10.0 .^ (2:7),minorticks=10)
Plots.ylims!(1e6, 1e8)
Plots.savefig("packedsortpermbench3.png")

Plots.hline([1], color=:black,lw=2,label="sortperm")
Plots.hline!([1], color=:black,lw=1,label=nothing)
Plots.scatter!(sizes,t[:,2]./t[:,3:5],xaxis=:log,xlabel="Input Size / Elements",ylabel= "Speedup vs sortperm",label=["packed sortperm" "packed sortperm 2" "radix sortperm by reference"],legend=:topleft,xticks=10.0 .^ (2:7),minorticks=10,alpha = 1.0)
Plots.ylims!(0, 12)
Plots.xlims!(1e2, 1e7)
Plots.savefig("packedsortpermbench3_speedup.png")

Nice work.

A simple (conceptually, not in code) optimization for arrays would be to adapt the number of bits used for the index extension to the array size and data type.
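The suggestion above could look roughly like this for unsigned integer inputs. `index_bits`, `packed_type`, and `packed_sortperm_adaptive` are hypothetical names, and the sketch simply uses `sort!` on the packed keys. For example, `UInt32` values with fewer than 2^32 elements fit into `UInt64` keys, which halves the memory compared to `UInt128` packing.

```julia
# Number of bits needed to store every index in 1:n.
index_bits(n::Integer) = 64 - leading_zeros(UInt64(n))

# Narrowest unsigned type that holds the value bits plus the index bits.
function packed_type(::Type{T}, n::Integer) where {T<:Unsigned}
    total = 8 * sizeof(T) + index_bits(n)
    total <= 32 && return UInt32
    total <= 64 && return UInt64
    total <= 128 && return UInt128
    error("value and index do not fit into 128 bits")
end

function packed_sortperm_adaptive(v::AbstractVector{T}) where {T<:Unsigned}
    n = length(v)
    P = packed_type(T, n)
    ib = index_bits(n)
    packed = Vector{P}(undef, n)
    for i in 1:n
        packed[i] = (P(v[i]) << ib) | P(i)   # value above, index in the low ib bits
    end
    sort!(packed)
    mask = (one(P) << ib) - one(P)
    return [Int(x & mask) for x in packed]
end

v = rand(UInt32, 10_000)
@assert packed_type(UInt32, length(v)) == UInt64
@assert packed_sortperm_adaptive(v) == sortperm(v)
```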


I tried to find some more realistic ways to speed up sortperm (i.e., without using enormous amounts of memory). Using 32-bit array indices is an obvious approach, but that would still use too much memory and does not scale to really big arrays, where conserving memory is arguably most important. Instead, I tried the following approaches:

sortperm with alg=MergeSort

  • Mergesort uses fewer comparisons than Quicksort, so using it when sorting by reference leads to a small performance gain for very large inputs.
  • The speedup is very small and does not warrant introducing a special case for this.
  • Merging interleaved from the left and the right side at the same time (called Parity Merge here) might lead to further speedups, because it allows more independent memory accesses to be in flight at once.
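The built-in option is `sortperm(v; alg = MergeSort)`. The interleaved merge idea can be sketched as a merge that fills the output from both ends at once: the front places the k-th smallest element while the back places the k-th largest, and the two halves never depend on each other, so their memory accesses can overlap. This is a hypothetical, simplified `bidirectional_merge!`, not the Parity Merge implementation referred to above; ties take the element from `a` at the front and from `b` at the back, which keeps the merge stable.

```julia
# Stable merge of the sorted vectors a and b into out, filling out from both ends.
function bidirectional_merge!(out::AbstractVector, a::AbstractVector, b::AbstractVector, lt = isless)
    na, nb = length(a), length(b)
    n = na + nb
    length(out) == n || throw(DimensionMismatch("out must have length(a) + length(b)"))
    i, j = 1, 1       # front cursors into a and b
    ia, jb = na, nb   # back cursors into a and b
    for k in 1:(n ÷ 2)
        # Front: take from a unless b is strictly smaller (a wins ties).
        if j > nb || (i <= na && !lt(b[j], a[i]))
            out[k] = a[i]; i += 1
        else
            out[k] = b[j]; j += 1
        end
        # Back: take from b unless a is strictly larger (b wins ties).
        if ia < 1 || (jb >= 1 && !lt(b[jb], a[ia]))
            out[n + 1 - k] = b[jb]; jb -= 1
        else
            out[n + 1 - k] = a[ia]; ia -= 1
        end
    end
    if isodd(n)       # one element is left in the middle; the front rule places it
        k = n ÷ 2 + 1
        if j > nb || (i <= na && !lt(b[j], a[i]))
            out[k] = a[i]
        else
            out[k] = b[j]
        end
    end
    return out
end

# Sorting by reference: merge two stably sorted runs of indices by their values in v.
v = [5, 3, 5, 1, 4, 3]
a = [2, 1, 3]   # indices 1:3 sorted by v
b = [4, 6, 5]   # indices 4:6 sorted by v
out = bidirectional_merge!(similar(a, 6), a, b, (x, y) -> v[x] < v[y])
@assert out == sortperm(v) == sortperm(v; alg = MergeSort)
```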

Packed QuickSort

  • Using an in-place sorting algorithm on packed values and indices uses exactly the same amount of memory as regular sortperm while running. When returning the indices, the packed array and the vector of indices to be returned are briefly present in memory at the same time. Maybe it is possible to avoid this?

  • QuickSort (not ScratchQuickSort!) works in place. Not being a stable sort is no problem, because ties are broken by the indices, which reside in the lower bits of the 128-bit values that are compared.

  • The speedup increases with the input size and reaches 4x when sorting 10^9 (~7.45 GiB) UInts.

  • Open questions:

    • Can the brief increased memory usage when returning the indices be avoided?
    • If yes, can the technique be extended to yield a function p = sort!perm(v) that sorts v in place and additionally returns the permutation?
    • How much does the overhead of the order-preserving transformation for Float64->UInt matter?
    • Is it beneficial to use an MSD radix sort here?
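A minimal sketch of packed QuickSort for `UInt64` input, assuming the indices fit into 64 bits; `packed_sortperm_quick` is a hypothetical name. Because no two packed keys are equal, the unstable `QuickSort` returns the same permutation as the stable `sortperm`. The final comprehension still allocates the index vector alongside the packed array, which is the brief memory peak mentioned in the first open question.

```julia
function packed_sortperm_quick(v::AbstractVector{UInt64})
    packed = Vector{UInt128}(undef, length(v))
    for (i, x) in enumerate(v)
        packed[i] = (UInt128(x) << 64) | UInt128(i)   # value high, index low
    end
    # In-place, unstable QuickSort; equal values are ordered by their index bits.
    sort!(packed; alg = QuickSort)
    return [Int(x % UInt64) for x in packed]          # low 64 bits hold the index
end

v = rand(UInt64(0):UInt64(99), 10_000)   # many duplicate values
@assert packed_sortperm_quick(v) == sortperm(v)
```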

Results using the same hardware and Julia version as before:

[Figure: packedsortpermbench5]

[Figure: packedsortpermbench5_speedup]


@LSchwerdt This looks promising!

I wonder if this would be convertible to sortperm!?

For anyone still interested:

I am working on a package that implements a very similar approach to speed up sortperm:

By using StructArrays it is not only faster, but also more general, i.e. it works with arrays of all types.
I am very happy with the results so far. But before listing the package publicly, some more polishing is required.

"I wonder if this would be convertible to sortperm!?"

You will be happy to know that SimultaneousSortperm implements not only ssortperm!(ix,v), but also ssortperm!(v) and ssortperm!!(ix,v), which also sort the input vector and thereby use only O(1) extra space instead of O(n).
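To illustrate what sorting the values and the index vector simultaneously means, here is a minimal sketch using heapsort, which needs only O(1) extra memory. It is not how SimultaneousSortperm works (that package builds on pattern-defeating quicksort); `simultaneous_heapsort!` and its helpers are hypothetical names. As with the packed approach, ties are broken by the index, so the unstable heapsort still reproduces `sortperm` when `ix` starts as `1:n`.

```julia
# Lexicographic order on the pairs (v[a], ix[a]); ties in v are broken by ix.
@inline pair_lt(v, ix, a, b) = isless(v[a], v[b]) || (isequal(v[a], v[b]) && ix[a] < ix[b])

# Swap positions a and b in both vectors.
@inline function pair_swap!(v, ix, a, b)
    v[a], v[b] = v[b], v[a]
    ix[a], ix[b] = ix[b], ix[a]
    return nothing
end

# Restore the max-heap property below `root`, considering only positions 1:n.
function sift_down!(v, ix, root, n)
    while true
        child = 2 * root
        child > n && return nothing
        if child < n && pair_lt(v, ix, child, child + 1)
            child += 1                    # use the larger child
        end
        pair_lt(v, ix, root, child) || return nothing
        pair_swap!(v, ix, root, child)
        root = child
    end
end

# Sort v in place and apply the same permutation to ix, using O(1) extra memory.
function simultaneous_heapsort!(v::AbstractVector, ix::AbstractVector{<:Integer})
    n = length(v)
    length(ix) == n || throw(DimensionMismatch("v and ix must have the same length"))
    for root in (n ÷ 2):-1:1              # build a max-heap of pairs
        sift_down!(v, ix, root, n)
    end
    for hi in n:-1:2                      # move the current maximum behind the heap
        pair_swap!(v, ix, 1, hi)
        sift_down!(v, ix, 1, hi - 1)
    end
    return v, ix
end

v = rand(1:10, 1_000)
ix = collect(1:length(v))
expected = sortperm(v)
simultaneous_heapsort!(v, ix)
@assert ix == expected && issorted(v)
```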

Here are some benchmarks on my AMD 3900x with DDR4-3600 CL 16 memory. Note that the advantage of this approach is even greater when using (regular) memory with higher latency.

Benchmarks (one figure per input type): Int64, Int64_almost_presorted, Float64, Float64_by_abs2, Int128, Int64_missing5, Categorical, shortstrings30


Could this be a PR to base? cc @Lilith

That is my goal. But there is quite a bit of code, so getting it to the required quality for base will take some time.
By releasing this as a package and opening a PR for the underlying pattern-defeating quicksort in SortingAlgorithms.jl first, the individual parts are easier to review and improve. It will be usable earlier, and I avoid creating a basically unreviewable 1000-line PR in base.

Sadly, implementing only a simplified limited version is not an option. Using another underlying sorting algorithm would yield worse performance, and removing all the optimizations for special inputs would lead to significant performance regressions for these special cases.


Hi @LSchwerdt, do you think that the partial-sort versions would be a lot of work? I have an application that uses partial sorting, so I would be interested in trying to write something for this. Alternatively (perhaps a naive question), do you think that sort! could be modified easily and efficiently to keep track of the indices? It seems that the main reason sortperm! is slower than sort! is the element access cost for the custom ordering. An alternative could be to sort tuples containing the value and the original index: for initial data x = randn(n); ix = zeros(Int64, n), suppose we have y = [(xi, i) for (i, xi) in enumerate(x)]; then we could sort!(y) and set @inbounds for (i, yy) in enumerate(y); ix[i] = yy[2]; end. However, it seems that accessing yy[2] is slow, which I guess is what SimultaneousSortperm tries to avoid?
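A runnable version of the tuple-based approach described above, assuming plain `(value, index)` tuples. Tuples compare lexicographically, so sorting `y` (not `x`) yields the same permutation as the stable `sortperm`. For the partial case, Base already provides `partialsortperm(v, k)` and `partialsortperm!`.

```julia
n = 1_000
x = randn(n)
ix = zeros(Int64, n)

# Pair every value with its original index; equal values are ordered by index.
y = [(xi, i) for (i, xi) in enumerate(x)]
sort!(y)

# Plain tuples have no field `.i`; the index is the second element.
for (k, yy) in enumerate(y)
    ix[k] = yy[2]
end

@assert ix == sortperm(x)
```

Note that `y` holds a full copy of the values next to the indices, so this needs about as much extra memory as the packed UInt128 approach.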

Partition could potentially benefit as well.

Amazing, hopefully when more mature it gets into base :slight_smile:

As a follow-up, I have a small PR that should implement a partial-sort !! version (which modifies both the vector and the index vector); I have tested it only to a limited extent. It is not much, but it is a start on the partial versions and seems to work for my purpose at the moment. It basically does quickselect and then calls the sortperm!! function.
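The "quickselect, then sort" idea can be sketched as follows. This is a hypothetical illustration, not the code from that PR: `partialsortperm_pairs!!` and its helpers are made-up names, it uses a random pivot with Lomuto partitioning, and it finishes the first k pairs with insertion sort (for brevity) instead of calling `ssortperm!!`. It assumes `ix` holds distinct entries, for example `1:n`, so ties are broken by index.

```julia
# Lexicographic order on the pairs (v[a], ix[a]); ties in v are broken by ix.
@inline pair_less(v, ix, a, b) = isless(v[a], v[b]) || (isequal(v[a], v[b]) && ix[a] < ix[b])

# Swap positions a and b in both vectors.
@inline function swap_pair!(v, ix, a, b)
    v[a], v[b] = v[b], v[a]
    ix[a], ix[b] = ix[b], ix[a]
    return nothing
end

# Lomuto partition of lo:hi around the pair at position hi; returns the pivot's final position.
function partition_pair!(v, ix, lo, hi)
    store = lo
    for k in lo:hi-1
        if pair_less(v, ix, k, hi)
            swap_pair!(v, ix, store, k)
            store += 1
        end
    end
    swap_pair!(v, ix, store, hi)
    return store
end

# Afterwards v[1:k] holds the k smallest values in order and ix[1:k] the matching
# indices; positions k+1:n are left unordered.
function partialsortperm_pairs!!(ix::AbstractVector{<:Integer}, v::AbstractVector, k::Integer)
    n = length(v)
    length(ix) == n || throw(DimensionMismatch("v and ix must have the same length"))
    1 <= k <= n || throw(ArgumentError("k must be in 1:length(v)"))
    lo, hi = 1, n
    while lo < hi                            # quickselect on both vectors at once
        swap_pair!(v, ix, rand(lo:hi), hi)   # random pivot, moved to the end
        m = partition_pair!(v, ix, lo, hi)
        if m == k
            break
        elseif m < k
            lo = m + 1
        else
            hi = m - 1
        end
    end
    for a in 2:k                             # sort the k smallest pairs
        b = a
        while b > 1 && pair_less(v, ix, b, b - 1)
            swap_pair!(v, ix, b, b - 1)
            b -= 1
        end
    end
    return ix, v
end

v = rand(1:50, 1_000)
expected = sortperm(v)[1:10]
ix = collect(1:length(v))
partialsortperm_pairs!!(ix, v, 10)
@assert ix[1:10] == expected
```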
