I’m continuing my efforts to implement SIMD-friendly sorting in Julia, and I am trying to implement the basic sorting network steps. This is looking like a possible good use-case for metaprogramming (again), but it seems to be still beyond my abilities. How do I do this?
I basically want a function that will take a tuple in, compare specific pairs of values and output another tuple. We run multiple such steps in sequence, and hopefully, the whole thing gets compiled into a big chunk of highly optimized code.
One example of a sorting network for 4 values:
function sort_4(a)
    b = (min(a[1],a[2]), max(a[1],a[2]), min(a[3],a[4]), max(a[3],a[4]))
    c = (min(b[1],b[3]), max(b[1],b[3]), min(b[2],b[4]), max(b[2],b[4]))
    (c[1], min(c[2],c[3]), max(c[2],c[3]), c[4])
end
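As a quick sanity check on this network, here is a small sketch that relies on the zero-one principle (a comparison network sorts every input if it sorts every sequence of 0s and 1s). The name `all_sorted` is only illustrative, and `sort_4` is repeated so the snippet runs on its own.

```julia
# The 4-input network from above, repeated so this snippet is self-contained.
function sort_4(a)
    b = (min(a[1],a[2]), max(a[1],a[2]), min(a[3],a[4]), max(a[3],a[4]))
    c = (min(b[1],b[3]), max(b[1],b[3]), min(b[2],b[4]), max(b[2],b[4]))
    (c[1], min(c[2],c[3]), max(c[2],c[3]), c[4])
end

# Zero-one principle: checking all 2^4 tuples of 0s and 1s is enough
# to show that the network sorts any 4 comparable values.
all_sorted = all(issorted(sort_4(bits))
                 for bits in Iterators.product(0:1, 0:1, 0:1, 0:1))
```

If `all_sorted` comes out `true`, the three steps form a valid sorting network for 4 values.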
In practice, each tuple can actually contain e.g. 16x16-bit values, and that’s when things begin to get exciting. Larger networks can also be created for 8 or 16 inputs. I want to write a macro or generated function (or actual function?) to help me implement all these different possible networks.
A generic compare_step function would allow us to write the previous function like this:
function sort_4(a)
    b = compare_step(a, (1,2), (3,4))
    c = compare_step(b, (1,3), (2,4))
    compare_step(c, (2,3))
end
The following function achieves this using an array:
function compare_step_arr(input::NTuple{N, T}, tt...) where {N, T}
    output = [input[n] for n in 1:N]
    for t in tt
        output[t[1]] = min(input[t[1]], input[t[2]])
        output[t[2]] = max(input[t[1]], input[t[2]])
    end
    tuple(output...)
end
julia> vals = (9,8,7,6,5,4,3,2)
(9, 8, 7, 6, 5, 4, 3, 2)
julia> compare_step_arr(vals, (1,3), (5,7))
(7, 8, 9, 6, 3, 4, 5, 2)
But this is ruined by the mutable array, which has to be allocated on every call. The same logic can be adapted to build expressions instead, though:
function compare_step_exp(input::NTuple{N, T}, tt...) where {N, T}
    # Interpolate the indices with $ so the expressions contain literal numbers
    output = [:(input[$n]) for n in 1:N]
    for t in tt
        output[t[1]] = :(min(input[$(t[1])], input[$(t[2])]))
        output[t[2]] = :(max(input[$(t[1])], input[$(t[2])]))
    end
    Expr(:call, :tuple, output...)
end
julia> vals = (9,8,7,6,5,4,3,2)
(9, 8, 7, 6, 5, 4, 3, 2)
julia> compare_step_exp(vals, (1,3), (5,7))
:(tuple(min(input[1], input[3]), input[2], max(input[1], input[3]), input[4], min(input[5], input[7]), input[6], max(input[5], input[7]), input[8]))
But this is as far as I could go. How do I now create a macro or generated function out of this? Or what could be a better approach instead?
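One direction that might work is sketched below. It is not a definitive implementation. It assumes the pairs can be passed as a type parameter through `Val`, so that a `@generated` function sees them at compile time. The helper name `compare_step_expr` and the `Val`-based signature are assumptions made for illustration; the expression building is the same idea as `compare_step_exp` above.

```julia
# Ordinary helper that builds the tuple expression for one network step.
# `N` is the tuple length, `pairs` a tuple of (i, j) index pairs.
function compare_step_expr(N, pairs)
    output = Any[:(input[$n]) for n in 1:N]
    for (i, j) in pairs
        output[i] = :(min(input[$i], input[$j]))
        output[j] = :(max(input[$i], input[$j]))
    end
    Expr(:tuple, output...)
end

# The pairs travel in the type domain (Val), so the generator can read
# them; it returns the unrolled expression, which is compiled once per
# combination of tuple type and pairs.
@generated function compare_step(input::NTuple{N, T}, ::Val{pairs}) where {N, T, pairs}
    compare_step_expr(N, pairs)
end

function sort_4(a)
    b = compare_step(a, Val(((1, 2), (3, 4))))
    c = compare_step(b, Val(((1, 3), (2, 4))))
    compare_step(c, Val(((2, 3),)))
end
```

With this sketch, `compare_step((9,8,7,6,5,4,3,2), Val(((1,3), (5,7))))` should give the same tuple as `compare_step_arr` above, without an intermediate array. Whether the result actually vectorizes as hoped would still need checking, e.g. with `@code_native`.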