10x faster sortperm()

Motivated by LilithHafner's work to add radix sorting to Julia 1.9 and the resulting performance improvements of sort(), I ran some tests to see whether sortperm() could be sped up as well.

To test this, I implemented different ways to sort a vector of UInt64s and compared their performance:

sort

  • Standard sort(), used as a performance reference.

sortperm

  • standard sortperm()
  • If I am not mistaken, this uses Quickersort (similar to Quicksort) to sort the indices, with a comparison function that looks up the values for each comparison.

packed sortperm

  • Pack indices and values into a UInt128 vector.
  • Sort using sort!() (uses radix sort).
  • Unpack indices.

packed sortperm 2

  • Pack indices and values into a UInt128 vector.
  • Sort using a custom radix sort that skips the lowest 64 bits, as the indices are already sorted.
  • Unpack indices.
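
To illustrate why the lowest 64 bits can be skipped, here is a small sketch. The helper name pack and the sample data are made up for this illustration and are not part of the benchmark code. Because the indices are packed in ascending order, the low halves start out sorted, so a stable sort that only looks at the high halves should give the same result as sorting the full 128-bit values:

```julia
# Hypothetical helper: value in the high 64 bits, index in the low 64 bits.
pack(value::UInt64, index::Integer) = (UInt128(value) << 64) | UInt128(index)

v = UInt64[5, 3, 5, 1, 3]                      # sample data with ties
tups = [pack(v[i], i) for i in eachindex(v)]

# The low halves are just 1, 2, 3, ... and therefore already sorted.
low_halves = [x % UInt64 for x in tups]
@assert issorted(low_halves)

# A stable sort that only looks at the high 64 bits ...
by_high = sort(tups; by = x -> x >> 64, alg = MergeSort)
# ... gives the same order as sorting the full 128-bit values.
@assert by_high == sort(tups)

# Unpacking the low halves yields the permutation.
perm = [Int(x % UInt64) for x in by_high]
@assert perm == sortperm(v)
```

An LSD radix sort is stable in the same way, which is why its passes over the low 64 bits would not change anything here.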

radix sortperm by reference

  • Sort indices using a custom radix sort that looks up the value for each index during each radix pass.

All radix sorts are minimal modifications of the LSD radix sort in Julia 1.9.
The benchmarks were run on Windows 10 with an AMD 3900X and DDR4-3600 memory (Julia version 1.9.0-beta2 7daffeecb8).

Observations

  • Both methods that sort by reference (standard sortperm and radix sortperm by reference) see a significant drop in performance for vectors of more than about 10^6 elements, which roughly corresponds to the size of the L3 cache (directly accessible from one core).
  • Radix sort by reference is the slowest option for big inputs, but the fastest for small ones.
  • Packed sortperm using radix sort can achieve up to a 10x speedup over standard sortperm for large inputs; the optimized version, which skips half the bits, is about twice as fast as the basic version.

I don’t know if it would be worth it to include something like this in the standard library, especially because this packed version uses twice the memory, but I wanted to share my results nonetheless.
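
As a rough illustration of the memory trade-off, the following sketch compares the size of the index array that sortperm returns with the size of a packed UInt128 array. The function names are made up for this example, and the numbers are simple arithmetic on element sizes (assuming a 64-bit Int), not measurements:

```julia
# Bytes needed for n elements, counting only the arrays named here.
index_array_bytes(n)  = n * sizeof(Int)       # permutation returned by sortperm
packed_array_bytes(n) = n * sizeof(UInt128)   # value and index packed together

gib(bytes) = bytes / 2^30                     # convert bytes to GiB

n = 10^9
println("index array:  ", round(gib(index_array_bytes(n)); digits = 2), " GiB")
println("packed array: ", round(gib(packed_array_bytes(n)); digits = 2), " GiB")
```

The custom radix sort variant additionally uses a scratch array t of the same size as the packed array.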

Here is my (ugly) code, modified to run up to vectors of size 10^7 (10^9 needs 64GB of RAM):

using Random
import Plots
mypack(a,b) = UInt128(a)<<64 + UInt128(b)

function my_radix_sort!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsigned, 
                     t::AbstractVector{U}, offset::Integer,
                     shift,chunk_size) where U <: Unsigned
    # bits is unsigned for performance reasons.
    counts = Vector{Int}(undef, 1 << chunk_size + 1) # TODO use scratch for this
    while true
        @noinline my_radix_sort_pass!(t, lo, hi, offset, counts, v, shift, chunk_size)
        # the latest data resides in t
        shift += chunk_size
        shift < bits || return false
        @noinline my_radix_sort_pass!(v, lo+offset, hi+offset, -offset, counts, t, shift, chunk_size)
        # the latest data resides in v
        shift += chunk_size
        shift < bits || return true
    end
end
function my_radix_sort_pass!(t, lo, hi, offset, counts, v, shift, chunk_size)
    mask = UInt(1) << chunk_size - 1  # mask is defined in pass so that the compiler
    @inbounds begin                   #  ↳ knows its shape
        # counts[2:mask+2] will store the number of elements that fall into each bucket.
        # if chunk_size = 8, counts[2] is bucket 0x00 and counts[257] is bucket 0xff.
        counts .= 0
        for k in lo:hi
            x = v[k]                  # lookup the element
            i = (x >> shift)&mask + 2 # compute its bucket's index for this pass
            counts[i] += 1            # increment that bucket's count
        end

        counts[1] = lo                # set target index for the first bucket
        cumsum!(counts, counts)       # set target indices for subsequent buckets
        # counts[1:mask+1] now stores indices where the first member of each bucket
        # belongs, not the number of elements in each bucket. We will put the first element
        # of bucket 0x00 in t[counts[1]], the next element of bucket 0x00 in t[counts[1]+1],
        # and the last element of bucket 0x00 in t[counts[2]-1].

        for k in lo:hi
            x = v[k]                  # lookup the element
            i = (x >> shift)&mask + 1 # compute its bucket's index for this pass
            j = counts[i]             # lookup the target index
            t[j + offset] = x         # put the element where it belongs
            counts[i] = j + 1         # increment the target index for the next
        end                           #  ↳ element in this bucket
    end
end


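function packedsortperm(v)
    # pack each value (high 64 bits) together with its index (low 64 bits)
    tups = Vector{UInt128}(undef, length(v))
    for idx = 1:length(v)
        tups[idx] = mypack(v[idx], idx)
    end
    sort!(tups)
    # the index occupies the low 64 bits, so mask all 64 of them
    return Int.(tups .& typemax(UInt64))
end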

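function packedsortperm2(v)
    # pack each value (high 64 bits) together with its index (low 64 bits)
    tups = Vector{UInt128}(undef, length(v))
    for idx = 1:length(v)
        tups[idx] = mypack(v[idx], idx)
    end
    t = Vector{UInt128}(undef, length(v))
    # start at shift = 64 to skip the index bits, which are already sorted
    flag = my_radix_sort!(tups, 1, length(v), UInt(128), t, UInt(0), 64, UInt8(10))
    # flag tells us which buffer holds the sorted data
    if flag
        return Int.(tups .& typemax(UInt64))
    else
        return Int.(t .& typemax(UInt64))
    end
end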


function my_radix_sort_pass_by_reference!(t, lo, hi, offset, counts, v, shift, chunk_size,key)
    mask = UInt(1) << chunk_size - 1  # mask is defined in pass so that the compiler
    @inbounds begin                   #  ↳ knows its shape
        # counts[2:mask+2] will store the number of elements that fall into each bucket.
        # if chunk_size = 8, counts[2] is bucket 0x00 and counts[257] is bucket 0xff.
        counts .= 0
        for k in lo:hi
            x = v[k]                  # lookup the index stored at position k
            i = (key[x] >> shift)&mask + 2 # compute its bucket's index for this pass
            counts[i] += 1            # increment that bucket's count
        end

        counts[1] = lo                # set target index for the first bucket
        cumsum!(counts, counts)       # set target indices for subsequent buckets
        # counts[1:mask+1] now stores indices where the first member of each bucket
        # belongs, not the number of elements in each bucket. We will put the first element
        # of bucket 0x00 in t[counts[1]], the next element of bucket 0x00 in t[counts[1]+1],
        # and the last element of bucket 0x00 in t[counts[2]-1].

        for k in lo:hi
            x = v[k]                  # lookup the element
            i = (key[x] >> shift)&mask + 1 # compute its bucket's index for this pass
            j = counts[i]             # lookup the target index
            t[j + offset] = x         # put the element where it belongs
            counts[i] = j + 1         # increment the target index for the next
        end                           #  ↳ element in this bucket
    end
end

function my_radix_sort_by_reference!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsigned, 
                     t::AbstractVector{U}, offset::Integer,
                     shift,chunk_size,key) where U <: Unsigned
    # bits is unsigned for performance reasons.
    counts = Vector{Int}(undef, 1 << chunk_size + 1) # TODO use scratch for this
    while true
        @noinline my_radix_sort_pass_by_reference!(t, lo, hi, offset, counts, v, shift, chunk_size,key)
        # the latest data resides in t
        shift += chunk_size
        shift < bits || return false
        @noinline my_radix_sort_pass_by_reference!(v, lo+offset, hi+offset, -offset, counts, t, shift, chunk_size,key)
        # the latest data resides in v
        shift += chunk_size
        shift < bits || return true
    end
end


function radixsortperm(v)
    idxs = UInt.(collect(1:length(v)))
    t = Vector{UInt64}(undef, length(v))
    flag = my_radix_sort_by_reference!(idxs, 1, length(v), UInt(64), t, UInt(0), 0, UInt8(10), v)
    # flag tells us which buffer holds the sorted indices
    if flag
        return Int.(idxs)
    else
        return Int.(t)
    end
end


function sortpermbench(sizes)
	nSamples = length(sizes)
	t = zeros(nSamples,5)
	for (idx,n) in enumerate(sizes)
		print("$(idx)/$(nSamples)")
		v = rand(UInt64,n)
		GC.gc()
		
		v1 = copy(v)
		t[idx,1] = @elapsed v2 = sort(v1)
		print(".")
		
		v1 = copy(v)
		t[idx,2] = @elapsed p2 = sortperm(v1)
		print(".")
		
		v1 = copy(v)
		t[idx,3] = @elapsed p3 = packedsortperm(v1)
		print(".")
		
		v1 = copy(v)
		t[idx,4] = @elapsed p4 = packedsortperm2(v1)
		print(".")
		
		v1 = copy(v)
		t[idx,5] = @elapsed p5 = radixsortperm(v1)
		print(".")
		
		@assert p2 == p3
		@assert p2 == p4
		@assert p2 == p5
		println(".")
		
	end
	t
end


biasExp = 6
# sizes from 10^2 to 10^7, spaced more densely toward the small end
sizes = convert.(Int,round.(10 .^ (range(2^(1/biasExp),7^(1/biasExp),250)).^biasExp)) |> shuffle

t = sortpermbench(sizes)

Plots.scatter(sizes,sizes./t,xaxis=:log,yaxis=:log,xlabel="Input Size / Elements",ylabel= "Elements / Second",label=["sort" "sortperm" "packed sortperm" "packed sortperm 2" "radix sortperm by reference"],legend=:bottomleft,xticks=10.0 .^ (2:7),minorticks=10)
Plots.ylims!(1e6, 1e8)
Plots.savefig("packedsortpermbench3.png")

Plots.hline([1], color=:black,lw=2,label="sortperm")
Plots.hline!([1], color=:black,lw=1,label=nothing)
Plots.scatter!(sizes,t[:,2]./t[:,3:5],xaxis=:log,xlabel="Input Size / Elements",ylabel= "Speedup vs sortperm",label=["packed sortperm" "packed sortperm 2" "radix sortperm by reference"],legend=:topleft,xticks=10.0 .^ (2:7),minorticks=10,alpha = 1.0)
Plots.ylims!(0, 12)
Plots.xlims!(1e2, 1e7)
Plots.savefig("packedsortpermbench3_speedup.png")

Nice work.

A simple optimization for arrays (simple conceptually, not necessarily in code) would be to adapt the number of bits used for the extension to the array size and data type.
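
One possible reading of this suggestion, as a minimal sketch: reserve only as many low bits as the largest index needs, and pick the smallest unsigned type that holds the value bits plus the index bits. All names here (index_bits, packed_type, packedsortperm_adaptive) are hypothetical, and the sketch assumes a 1-based vector of unsigned integers:

```julia
# Number of bits needed to represent the indices 1:n.
index_bits(n::Integer) = 8 * sizeof(n) - leading_zeros(n)

# Hypothetical helper: smallest packed type that fits value bits + index bits.
function packed_type(::Type{T}, n::Integer) where {T<:Unsigned}
    total = 8 * sizeof(T) + index_bits(n)
    total <= 64  && return UInt64
    total <= 128 && return UInt128
    throw(ArgumentError("value and index do not fit into 128 bits"))
end

function packedsortperm_adaptive(v::Vector{T}) where {T<:Unsigned}
    n = length(v)
    P = packed_type(T, n)
    shift = index_bits(n)                 # bits reserved for the index
    mask = (one(P) << shift) - one(P)     # selects the index bits
    packed = Vector{P}(undef, n)
    for i in 1:n
        packed[i] = (P(v[i]) << shift) | P(i)
    end
    sort!(packed)
    return [Int(x & mask) for x in packed]
end

v = UInt32[7, 2, 7, 0, 2]
@assert packed_type(UInt32, length(v)) === UInt64   # fits into 64 bits
@assert packedsortperm_adaptive(v) == sortperm(v)
```

For UInt32 values this packs into UInt64 instead of UInt128, which halves the memory of the packed array; whether it is faster in practice would have to be benchmarked.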


I tried to find some more realistic ways to speed up sortperm (i.e. without using enormous amounts of memory). Using 32-bit array indices is an obvious approach, but that would still use too much memory and does not scale to really big arrays, where conserving memory is arguably most important. Instead I tried the following approaches:

sortperm with alg=MergeSort

  • Mergesort uses fewer comparisons than Quicksort, so using it when sorting by reference leads to a small performance gain for very large instances.
  • The speedup is very small and does not warrant introducing a special case for this.
  • Maybe doing the merge interleaved from the left and the right side (known as a parity merge) can lead to further speedups, by allowing more independent array accesses to be in flight at the same time.
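
For reference, selecting the algorithm only needs the alg keyword of sortperm. A minimal example with made-up test data:

```julia
using Random

v = rand(MersenneTwister(1), UInt64, 1_000)

p_default = sortperm(v)                   # default algorithm
p_merge   = sortperm(v; alg = MergeSort)  # same permutation, different algorithm

@assert p_default == p_merge
@assert issorted(v[p_merge])
```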

Packed QuickSort

  • Using an in-place sorting algorithm with packed values and indices uses exactly the same amount of memory as regular sortperm while running. When returning the indices, the packed array and the vector of indices to be returned are briefly present in memory at the same time. Maybe it is possible to avoid this?

  • QuickSort (not ScratchQuickSort!) does work in-place. Not being a stable sort is no problem, as ties are broken using the indices, which reside in the lower bits of the 128-bit values that are compared.

  • The speedup increases with the input size and reaches 4x when sorting 10^9 (~7.45 GiB) UInts.

  • Open questions:

    • Can the brief increased memory usage when returning the indices be avoided?
    • If yes, can the technique be extended to yield a function p = sort!perm(v) that sorts v in place and additionally returns the permutation?
    • How much does the overhead of the order-preserving transformation for Float64->UInt matter?
    • Is it beneficial to use a MSD radix sort here?
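
The packed QuickSort approach can be sketched as follows. This is a simplified illustration, not the benchmarked code: the names pack, packed_quicksortperm and float_key are made up here. The float_key function shows one common order-preserving Float64 -> UInt64 transformation (flip all bits of negative numbers, flip only the sign bit otherwise); it assumes the input contains no NaNs:

```julia
# Hypothetical helper: value in the high 64 bits, index in the low 64 bits.
pack(value::UInt64, index::Integer) = (UInt128(value) << 64) | UInt128(index)

function packed_quicksortperm(v::Vector{UInt64})
    packed = [pack(v[i], i) for i in eachindex(v)]
    # QuickSort is not stable, but ties between equal values are broken
    # by the index stored in the low bits, so the result is still stable.
    sort!(packed; alg = QuickSort)
    # Unpacking allocates the result while `packed` is still alive.
    return [Int(x % UInt64) for x in packed]
end

# Order-preserving map from Float64 to UInt64 (assumes no NaNs).
function float_key(x::Float64)
    bits = reinterpret(UInt64, x)
    # Negative numbers: flip all bits. Non-negative: flip only the sign bit.
    return (bits >> 63) == 1 ? ~bits : bits ⊻ 0x8000000000000000
end

v = UInt64[5, 3, 5, 1, 3]
@assert packed_quicksortperm(v) == sortperm(v)

x = [0.5, -1.0, 3.0, -0.0, 0.0, -Inf]
@assert packed_quicksortperm(float_key.(x)) == sortperm(x)
```

The comprehension in the last step of packed_quicksortperm is where the packed array and the index vector coexist in memory, which is the brief increase mentioned in the open questions.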

Results using the same hardware and Julia version as before:

[Figure: packedsortpermbench5]

[Figure: packedsortpermbench5_speedup]
