[ANN] OptimalSortingNetworks: Sort small collections efficiently and with good type inference

OptimalSortingNetworks (OSN for short) is being registered. Use it to sort small tuples or vectors:

julia> using OptimalSortingNetworks

julia> sorted([10, 5, 7, 9])
4-element Vector{Int64}:
  5
  7
  9
 10

julia> sorted((10, 5, 7, 9))
(5, 7, 9, 10)

julia> sorted(reverse ∘ minmax, (10, 5, 7, 9))
(10, 9, 7, 5)

julia> const new_mm = OptimalSortingNetworks.new_minmax
new_minmax (generic function with 2 methods)

julia> sorted(new_mm(less = >), (10, 5, 7, 9))
(10, 9, 7, 5)

julia> tup = ((key=3, val="a"), (key=1, val="b"), (key=2,val="c"))
((key = 3, val = "a"), (key = 1, val = "b"), (key = 2, val = "c"))

julia> sorted(new_mm(by = (n -> n.key)), tup)
((key = 1, val = "b"), (key = 2, val = "c"), (key = 3, val = "a"))

julia> sorted(new_mm(by = (n -> n.key), less = >), tup)
((key = 3, val = "a"), (key = 2, val = "c"), (key = 1, val = "b"))

sorted is not a stable sort: elements that compare equal are not guaranteed to keep their original relative order.
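To illustrate why a sorting network is generally not stable, here is a minimal self-contained sketch. It does not use the package: `cmpswap` and `network3` are hypothetical names made up for this example, and the comparator sequence (1,3), (1,2), (2,3) is one common 3-input network, not necessarily the one OSN uses.

```julia
# Compare-exchange on the `key` field: swaps only when strictly out of order.
cmpswap(a, b) = b.key < a.key ? (b, a) : (a, b)

# Hypothetical 3-input network with comparators (1,3), (1,2), (2,3).
function network3(t::NTuple{3,Any})
    a, b, c = t
    a, c = cmpswap(a, c)
    a, b = cmpswap(a, b)
    b, c = cmpswap(b, c)
    return (a, b, c)
end

t = ((key = 1, val = "a"), (key = 1, val = "b"), (key = 0, val = "c"))

# The first comparator moves "a" past "b", so the two key-1 elements swap order.
network3(t)
# ((key = 0, val = "c"), (key = 1, val = "b"), (key = 1, val = "a"))

# Base's default vector sort is stable, so "a" stays ahead of "b".
sort(collect(t); by = x -> x.key)
```

Even though each individual comparator only swaps on a strict inequality, a comparator that spans non-adjacent positions can carry an element past an equal one.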

I intend to create another package for tuple sorting in general, which will probably use merge sort with a dependency on OSN for the base cases of the recursion. EDIT ANN

EDIT: vectors are now supported, in addition to tuples

For folks like me who did not know what a Sorting Network is: Sorting network - Wikipedia
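For a quick feel of the idea without leaving the thread, here is a minimal sketch of a 4-input sorting network written with Base's `minmax`. The function name `network4` is made up for illustration, and the comparator sequence shown is a well-known 5-comparator network; it is not taken from the package source.

```julia
# Sort four values with a fixed, data-independent sequence of compare-exchange steps.
function network4(a, b, c, d)
    a, b = minmax(a, b)   # comparator (1, 2)
    c, d = minmax(c, d)   # comparator (3, 4)
    a, c = minmax(a, c)   # comparator (1, 3): a is now the minimum
    b, d = minmax(b, d)   # comparator (2, 4): d is now the maximum
    b, c = minmax(b, c)   # comparator (2, 3): order the middle pair
    return (a, b, c, d)
end

network4(10, 5, 7, 9)  # (5, 7, 9, 10)
```

The sequence of comparisons never depends on the data, which is why the whole thing compiles to straight-line code for a fixed input size.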

Could it also be used as an algorithm for Base sort? cc @Lilith

Yeah, it could be useful as a performance optimization. Most of the effort for a PR would probably come down to careful benchmarking to find good thresholds for switching to sorting networks.

This seems to be significantly faster than Base’s tuple sorting in both runtime and compile time.

julia> using BenchmarkTools, OptimalSortingNetworks

julia> for i in 1:20
           println(i)
           x = Tuple(rand(i))
           @time sort(x)
           @time sorted(x)
           @btime sort($x)
           @btime sorted($x)
       end
1
  0.001746 seconds (2.28 k allocations: 108.430 KiB, 98.43% compilation time)
  0.001823 seconds (3.61 k allocations: 169.359 KiB, 99.16% compilation time)
  1.375 ns (0 allocations: 0 bytes)
  1.375 ns (0 allocations: 0 bytes)
2
  0.003862 seconds (14.96 k allocations: 756.289 KiB, 99.35% compilation time)
  0.002913 seconds (5.99 k allocations: 292.672 KiB, 99.49% compilation time)
  1.666 ns (0 allocations: 0 bytes)
  1.375 ns (0 allocations: 0 bytes)
3
  0.005494 seconds (20.86 k allocations: 1.080 MiB, 99.44% compilation time)
  0.002853 seconds (11.49 k allocations: 585.125 KiB, 99.53% compilation time)
  3.083 ns (0 allocations: 0 bytes)
  1.666 ns (0 allocations: 0 bytes)
4
  0.010172 seconds (42.30 k allocations: 2.316 MiB, 99.74% compilation time)
  0.003560 seconds (14.91 k allocations: 772.594 KiB, 99.47% compilation time)
  4.208 ns (0 allocations: 0 bytes)
  2.166 ns (0 allocations: 0 bytes)
5
  0.013185 seconds (38.95 k allocations: 2.096 MiB, 99.78% compilation time)
  0.003609 seconds (21.28 k allocations: 1.079 MiB, 99.51% compilation time)
  7.083 ns (0 allocations: 0 bytes)
  3.916 ns (0 allocations: 0 bytes)
6
  0.026137 seconds (65.27 k allocations: 3.628 MiB, 99.88% compilation time)
  0.004008 seconds (26.30 k allocations: 1.329 MiB, 99.60% compilation time)
  8.625 ns (0 allocations: 0 bytes)
  5.250 ns (0 allocations: 0 bytes)
7
  0.025794 seconds (51.51 k allocations: 2.736 MiB, 99.88% compilation time)
  0.005348 seconds (32.75 k allocations: 1.734 MiB, 99.71% compilation time)
  9.926 ns (0 allocations: 0 bytes)
  6.958 ns (0 allocations: 0 bytes)
8
  0.038604 seconds (67.09 k allocations: 3.516 MiB, 99.93% compilation time)
  0.005249 seconds (37.66 k allocations: 1.988 MiB, 99.66% compilation time)
  16.241 ns (0 allocations: 0 bytes)
  8.334 ns (0 allocations: 0 bytes)
9
  0.022513 seconds (55.22 k allocations: 3.057 MiB, 99.88% compilation time)
  0.006660 seconds (47.14 k allocations: 2.506 MiB, 99.82% compilation time)
  19.664 ns (0 allocations: 0 bytes)
  11.094 ns (0 allocations: 0 bytes)
10
  0.329988 seconds (1.23 M allocations: 67.375 MiB, 99.99% compilation time)
  0.009091 seconds (56.44 k allocations: 2.971 MiB, 99.72% compilation time)
  50.489 ns (2 allocations: 144 bytes)
  13.598 ns (0 allocations: 0 bytes)
11
  0.014698 seconds (79.00 k allocations: 4.314 MiB, 99.69% compilation time)
  0.009332 seconds (62.86 k allocations: 3.298 MiB, 99.85% compilation time)
  60.568 ns (2 allocations: 144 bytes)
  15.364 ns (0 allocations: 0 bytes)
12
  0.018018 seconds (83.03 k allocations: 4.528 MiB, 99.71% compilation time)
ERROR: MethodError: no method matching sorted(::OptimalSortingNetworks.var"#2#3"{typeof(identity), typeof(<)}, ::NTuple{12, Float64}, ::Depth)

Closest candidates are:
  sorted(::Any, ::NTuple{10, Any}, ::Union{Depth, Size})
   @ OptimalSortingNetworks ~/.julia/packages/OptimalSortingNetworks/MzN3U/src/OptimalSortingNetworks.jl:111
  sorted(::Any, ::Union{Tuple{}, Tuple{Any}, Tuple{Any, Any}, Tuple{Any, Any, Any}, NTuple{4, Any}, NTuple{5, Any}, NTuple{6, Any}, NTuple{7, Any}, NTuple{8, Any}, NTuple{9, Any}, NTuple{11, Any}}, ::Union{Depth, Size})
   @ OptimalSortingNetworks ~/.julia/packages/OptimalSortingNetworks/MzN3U/src/OptimalSortingNetworks.jl:106
  sorted(::Any, ::NTuple{10, Any})
   @ OptimalSortingNetworks ~/.julia/packages/OptimalSortingNetworks/MzN3U/src/OptimalSortingNetworks.jl:111
  ...

Stacktrace:
 [1] sorted(t::NTuple{12, Float64}, o::Depth)
   @ OptimalSortingNetworks ~/.julia/packages/OptimalSortingNetworks/MzN3U/src/OptimalSortingNetworks.jl:99
 [2] macro expansion
   @ ./timing.jl:282 [inlined]
 [3] top-level scope
   @ ./REPL[7]:5

PRs to Base welcome :slight_smile:

Does this have any difficult-to-remedy weaknesses compared to SortingNetworks.jl or is it a strict win?

It should be a strict win. There’s nothing deeply wrong with the SortingNetworks package (except for the fact that minmax is hardcoded instead of a parameter), but creating a new package seemed easier than trying to address the couple of perceived issues with SortingNetworks, given that the algorithms are trivial anyway.

One thing to note is that I coded each sorting network in a kind of single-assignment form. I figured that assigning each variable only once could help the compiler's type inference when the tuple is heterogeneous. SortingNetworks instead uses the minimal number of local variables. I'm not totally sure of the implications of one implementation approach versus the other.
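The two coding styles can be sketched side by side as follows. Both functions are hypothetical illustrations written for this comparison (neither is copied from OSN or SortingNetworks), and `f` stands for a `minmax`-like function that returns its two arguments in order.

```julia
# Rebinding style: a minimal set of names, each rebound by later comparators.
function sort3_rebinding(f, t::NTuple{3,Any})
    a, b, c = t
    a, c = f(a, c)
    a, b = f(a, b)
    b, c = f(b, c)
    return (a, b, c)
end

# Single-assignment style: every intermediate value gets a fresh name.
function sort3_single_assignment(f, t::NTuple{3,Any})
    (a0, b0, c0) = t
    (a1, c1) = f(a0, c0)
    (a2, b1) = f(a1, b0)
    (b2, c2) = f(b1, c1)
    return (a2, b2, c2)
end

sort3_rebinding(minmax, (3, 1, 2))          # (1, 2, 3)
sort3_single_assignment(minmax, (3, 1, 2))  # (1, 2, 3)
```

Both compute the same result. Whether the second form actually helps inference on heterogeneous tuples is something to check with `@code_typed` or `@code_warntype` rather than assume.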

FTR, it crossed my mind that it would make sense to support static vectors in addition to tuples, because an SArray is basically just a Tuple in a wrapper with additional type annotations. However, I then realized that StaticArrays already supports some sorting functionality. Given that circular dependencies are nonsensical (two packages can't depend on each other, I think), I'll instead try to make StaticArrays depend on OptimalSortingNetworks, to improve the SA sort performance (glancing at their code, it seems like it could be improved in both small and large cases).

For homogeneous tuples of subtypes of Real (which is the only kind supported by SortingNetworks) it probably makes no difference at all once Julia has transformed it to SSA form. At least @code_typed sorted((1, 2, 3)) and @code_typed swapsort((1, 2, 3)) look structurally identical.

I’m not sure how you arrived at the circular dependency, but this looks like an ideal case for a package extension.

You could also consider a convenience function for sorting of short (standard) vectors. SortingNetworks sort of (pun intended) supports that except that it somewhat weirdly returns a tuple.

Thanks, I supposed it works like that but wasn’t sure.

StaticArrays already implements one sorting algorithm (<:Base.Sort.Algorithm). I think it would make sense to add another one (based on merge sort and OSN for the base cases) to StaticArrays, but then I couldn't depend on SA from OSN.

I want to do this, but need to be careful to prevent run time dispatch.

I would try something like this, with or without metaprogramming:

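# Non-mutating entry point: sort a copy of the input.
sorted(x::AbstractVector) = sorted!(copy(x))

# Branch on the runtime length so that each call to the tuple method
# receives a tuple whose size is known at compile time.
function sorted!(x::AbstractVector)
    n = length(x)
    if n == 1
        x .= sorted((x[1],))
    elseif n == 2
        x .= sorted((x[1], x[2]))
    elseif n == 3
        x .= sorted((x[1], x[2], x[3]))
    else
        error("Unsupported vector length: $n")
    end
    return x
end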

Edit: Feel free to optimize out the one element case. :slight_smile:

Now the package supports sorting vectors. The interface currently allocates and copies a new vector, so I should probably also add a mutating sort, something like your sorted!. I just have to think about how to do it in the nicest way possible; I'm not sure whether to simply add a new function or do something else.

Yeah, it loses some of the charm of having a super fast Vector sort algorithm if you can’t do it in place.

Added sorted!, but it allocates even though it works in-place. I think this is due to a bug in Julia:

https://github.com/JuliaLang/julia/issues/51955

EDIT: sorted! no longer allocates (when applied to homogeneous vectors), thanks to a workaround/fix provided by @Lilith

Not ideal for a default sort because it’s unstable, but it could be another algorithm for sort(;alg=SortingNetwork()).
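As a rough sketch of what such an opt-in algorithm could look like, the snippet below defines a tiny custom `Base.Sort.Algorithm` using the long-standing `sort!(v, lo, hi, alg, order)` extension point (the same one SortingAlgorithms.jl uses). The names `SmallNetworkSort` and `cmpswap!` are hypothetical, the 3-element cap is only to keep the example short, and this is not code from Base or from OSN.

```julia
# Hypothetical algorithm type; the name is an assumption for illustration.
struct SmallNetworkSort <: Base.Sort.Algorithm end

# Compare-exchange positions i and j of v under the ordering o.
function cmpswap!(v, i, j, o)
    if Base.Order.lt(o, v[j], v[i])
        v[i], v[j] = v[j], v[i]
    end
    return v
end

function Base.sort!(v::AbstractVector, lo::Integer, hi::Integer,
                    ::SmallNetworkSort, o::Base.Order.Ordering)
    n = hi - lo + 1
    n <= 1 && return v
    if n == 2
        cmpswap!(v, lo, hi, o)
    elseif n == 3
        # Comparators (1,3), (1,2), (2,3), offset by lo.
        cmpswap!(v, lo, lo + 2, o)
        cmpswap!(v, lo, lo + 1, o)
        cmpswap!(v, lo + 1, lo + 2, o)
    else
        throw(ArgumentError("SmallNetworkSort supports at most 3 elements"))
    end
    return v
end

sort([3, 1, 2]; alg = SmallNetworkSort())              # [1, 2, 3]
sort([3, 1, 2]; alg = SmallNetworkSort(), rev = true)  # [3, 2, 1]
```

Because the comparisons go through `Base.Order.lt`, the usual `by`, `lt`, and `rev` keywords keep working without any extra code in the algorithm itself.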