Metaprogramming for sorting networks

I’m continuing my efforts to implement SIMD-friendly sorting in Julia, and I am trying to implement the basic sorting network steps. This looks like a good use case for metaprogramming (again), but it still seems to be beyond my abilities. How do I do this?

I basically want a function that will take a tuple in, compare specific pairs of values and output another tuple. We run multiple such steps in sequence, and hopefully, the whole thing gets compiled into a big chunk of highly optimized code.

One example of a sorting network for 4 values:

function sort_4(a)
    b = (min(a[1],a[2]), max(a[1],a[2]), min(a[3],a[4]), max(a[3],a[4]))
    c = (min(b[1],b[3]), max(b[1],b[3]), min(b[2],b[4]), max(b[2],b[4]))
    (c[1], min(c[2],c[3]), max(c[2],c[3]), c[4])
end

In practice, each tuple can actually contain e.g. 16x16-bit values, and that’s when things begin to get exciting. Larger networks can also be created for 8 or 16 inputs. I want to write a macro or generated function (or actual function?) to help me implement all these different possible networks.

A generic compare_step function would allow us to write the previous function like this:

function sort_4(a)
    b = compare_step(a, (1,2), (3, 4))
    c = compare_step(b, (1,3), (2, 4))
    compare_step(c, (2,3))
end

The following function achieves it using an array:

function compare_step_arr(input::NTuple{N, T}, tt...) where {N, T}
    output = [input[n] for n in 1:N]
    for t in tt
        output[t[1]] = min(input[t[1]], input[t[2]])
        output[t[2]] = max(input[t[1]], input[t[2]])
    end
    tuple(output...)
end

julia> vals = (9,8,7,6,5,4,3,2)
(9, 8, 7, 6, 5, 4, 3, 2)

julia> compare_step_arr(vals, (1,3), (5,7))
(7, 8, 9, 6, 3, 4, 5, 2)

But this is ruined by the fact that it uses a mutable array. The code might be adapted to output expressions, though.
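For comparison, here is a minimal sketch of a tuple-only version that avoids the array without any metaprogramming. It builds the output with `ntuple` and a `Val(N)` length. The name `compare_step` is just the hypothetical name used in the `sort_4` example above, and whether the compiler fully unrolls the inner loop over the pairs is an assumption that would need checking with `@code_llvm` or a benchmark.

```julia
# Tuple-only compare step (a sketch, not a tuned implementation).
# For each output position i, look for a pair that mentions i and take the
# min (if i is the first index) or the max (if i is the second index).
function compare_step(input::NTuple{N,T}, tt::Tuple{Int,Int}...) where {N,T}
    ntuple(Val(N)) do i
        for (lo, hi) in tt
            i == lo && return min(input[lo], input[hi])
            i == hi && return max(input[lo], input[hi])
        end
        return input[i]   # position i is not touched by this step
    end
end

vals = (9, 8, 7, 6, 5, 4, 3, 2)
compare_step(vals, (1, 3), (5, 7))   # same result as compare_step_arr above
```

Since the pairs are ordinary runtime arguments here, this relies on constant propagation and inlining for speed, which is exactly the part that metaprogramming would make explicit.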

function compare_step_exp(input::NTuple{N, T}, tt...) where {N, T}
    # Interpolate the indices with `$` so each expression holds a literal index
    output = [:(input[$n]) for n in 1:N]
    for t in tt
        output[t[1]] = :(min(input[$(t[1])], input[$(t[2])]))
        output[t[2]] = :(max(input[$(t[1])], input[$(t[2])]))
    end
    Expr(:call, :tuple, output...)
end

julia> vals = (9,8,7,6,5,4,3,2)
(9, 8, 7, 6, 5, 4, 3, 2)

julia> compare_step_exp(vals, (1,3), (5,7))
:(tuple(min(input[1], input[3]), input[2], max(input[1], input[3]), input[4], min(input[5], input[7]), input[6], max(input[5], input[7]), input[8]))

But this is as far as I could go. How do I now create a macro or generated function out of this? Or what could be a better approach instead?
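One possible route from here, sketched under assumptions: a generated function only sees the types of its arguments, so the index pairs have to be moved into the type domain, for example by wrapping them in a `Val`. The names `compare_step_gen` and `sort_4_gen` are hypothetical, and this is only one way to do it, not necessarily the best one.

```julia
# Generated compare step: the pairs arrive as the type parameter P of Val{P},
# so they are known when the function body is generated.
@generated function compare_step_gen(input::NTuple{N,T}, ::Val{P}) where {N,T,P}
    out = Any[:(input[$n]) for n in 1:N]
    for (i, j) in P
        out[i] = :(min(input[$i], input[$j]))
        out[j] = :(max(input[$i], input[$j]))
    end
    Expr(:tuple, out...)   # the returned expression becomes the method body
end

# The 4-input network from the first example, written with the generated step.
function sort_4_gen(a)
    b = compare_step_gen(a, Val(((1, 2), (3, 4))))
    c = compare_step_gen(b, Val(((1, 3), (2, 4))))
    compare_step_gen(c, Val(((2, 3),)))
end

sort_4_gen((4, 3, 2, 1))
```

Each distinct set of pairs gets its own specialized method body, which is the expression printed by `compare_step_exp`, only compiled instead of displayed.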

Not a direct answer, but I have a note in an ideas file that says “use sorting networks as a base case for sorting?” It would be interesting to see if arbitrary sized sorts can be sped up that way.

Cool to hear that you’re interested! Indeed, the idea is that small arrays get handled by these specialized and parallel sorting networks, and above that is merge sort, but also using a SIMD-based k-way merging. I have some initial tests here: SIMD based merge sort in Julia · GitHub

Notice these publications are from the time of 128-bit registers. I am trying to adapt it all to handle 256 and 512 bits too!

Are you aware of https://github.com/JeffreySarnoff/SortingNetworks.jl by @JeffreySarnoff ?

Not yet, thanks!

I’m interested in learning this kind of programming as much as just solving the problem… That package looks great, and will definitely come in handy, but unfortunately it seems they just implemented each network in the “hard-coded” way.

It took me a while to understand that eval evaluates in the global scope! Now that I finally understood this I managed to write something that seems to work. Any comments are appreciated.

# Build step_n(...step_1(input)...) as a nested call expression
function nested_calls(name, n)
    if n == 0
        :input
    else
        Expr(:call, Symbol(name, n), nested_calls(name, n-1))
    end
end

# Each entry: (number of inputs, steps), where a step is a tuple of index pairs
nets = (
    (4, (((1,2), (3,4)), ((1,3), (2,4)), ((2,3),))),
)

for (inlen, net_params) in nets
    nsteps = length(net_params)

    # Define one function per step: sort_<inlen>_step_<st>(input)
    for st in 1:nsteps
        aa = [:(@inbounds input[$n]) for n in 1:inlen]
        for t in net_params[st]
            aa[t[1]] = :(@inbounds min(input[$(t[1])], input[$(t[2])]))
            aa[t[2]] = :(@inbounds max(input[$(t[1])], input[$(t[2])]))
        end
        eval(Expr(:(=),
                  Expr(:call, Symbol("sort_", inlen, "_step_", st), :input),
                  Expr(:call, :tuple, aa...)))
    end

    # Define sort_<inlen>(input) once, after all of its steps exist
    eval(Expr(:(=),
              Expr(:call, Symbol("sort_", inlen), :input),
              nested_calls("sort_$(inlen)_step_", nsteps)))
end

julia> sort_4(rand(4))
(0.02459309151726008, 0.10103906580180921, 0.17137334025578843, 0.48658695848852007)
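A single random input is weak evidence that a network is correct. A cheap check, sketched here, uses the zero-one principle: a comparison network sorts every input if and only if it sorts every input made of 0s and 1s, so testing all 2^n binary inputs is enough. The helper names `apply_network` and `sorts_all_binary_inputs` are hypothetical. The network is given in the same format as the steps stored in `nets`.

```julia
# Apply a network given as a tuple of steps, each step a tuple of (i, j) pairs.
# The pairs within a step are assumed to be disjoint, so applying them one
# after another gives the same result as applying them in parallel.
function apply_network(steps, input)
    v = collect(input)
    for step in steps, (i, j) in step
        v[i], v[j] = min(v[i], v[j]), max(v[i], v[j])
    end
    return v
end

# Zero-one principle: checking all 2^n binary inputs is sufficient.
function sorts_all_binary_inputs(n, steps)
    for bits in 0:(2^n - 1)
        input = [(bits >> k) & 1 for k in 0:(n - 1)]
        issorted(apply_network(steps, input)) || return false
    end
    return true
end

net4 = (((1, 2), (3, 4)), ((1, 3), (2, 4)), ((2, 3),))
sorts_all_binary_inputs(4, net4)
```

This grows as 2^n, so it stays practical for the 8 and 16 input networks mentioned earlier but not for much larger ones.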

I assume you’ve seen SortingAlgorithms.jl and SortingLab.jl
