Pushing the limits of recursive inlining with tuples

I am trying to understand the limits of what the compiler can do when functions are given a lot of compile-time information, such as fixed-length tuples or static arrays.

In this example, we have a selection algorithm that computes the median of an unsorted input. When the input is small, say 5 or 7 values, it would be great if inlining produced code that is just a sequence of comparisons leading to the answer, with no intermediate function calls.
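For three values, the kind of straight-line code described here can be written by hand. A minimal sketch (median3 is a hypothetical name, not from the thread) uses only min and max, so there is no recursion and nothing left to call once it is inlined:

```julia
# Median of three values from a fixed sequence of min/max comparisons.
# No recursion and no dependence on tuple lengths, so the compiler
# only sees straight-line code.
median3(a, b, c) = max(min(a, b), min(max(a, b), c))

median3(7, 8, 4)  # 7, the same answer as select(2, (7, 8, 4))
```

Writing such comparison sequences by hand quickly gets tedious for 5 or 7 inputs, which is why it would be nice to have the compiler derive them.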

This is not what happens, though. Can I modify this program somehow to obtain this result? Or is this a case where I should use a generated function?

partition(p, aa) = partition(p, aa, (), ())

function partition(p, a, l, r)
    if length(a) == 0
        (l, r)
    elseif a[1] <= p
        partition(p, a[2:end], (a[1], l...), r)
    else
        partition(p, a[2:end], l, (a[1], r...))
    end
end

function select(k, a)
    println(k, a)
    if length(a) == 1
        a[1]
    else
        p = a[1]
        l, r = partition(p, a)
        println(l, r)
        if length(l) > k
            select(k, l)
        elseif length(l) < k
            select(k - length(l), r)
        else # length(l)==k
            p
        end
    end
end

@code_native select(2, (7,8,4))

You have a lot of type instabilities, ranging from the accidental and avoidable:

julia> foo(a) = a[2:end]
foo (generic function with 1 method)

julia> @code_warntype foo(a)
Body::Tuple{Vararg{Int64,N} where N}
1 ─ %1  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Int64,1}, svec(Any, Int64), :(:ccall), 2, Array{Int64,1}, 2, 2))::Array{Int64,1}
└──       goto #7 if not true
2 ┄ %3  = φ (#1 => 1, #6 => %13)::Int64
│   %4  = φ (#1 => 1, #6 => %14)::Int64
│   %5  = (Base.add_int)(1, %3)::Int64
│   %6  = (Base.getfield)(a, %5, true)::Int64
│         (Base.arrayset)(false, %1, %6, %3)
│   %8  = (%4 === 2)::Bool
└──       goto #4 if not %8
3 ─       goto #5
4 ─ %11 = (Base.add_int)(%4, 1)::Int64
└──       goto #5
5 ┄ %13 = φ (#4 => %11)::Int64
│   %14 = φ (#4 => %11)::Int64
│   %15 = φ (#3 => true, #4 => false)::Bool
│   %16 = (Base.not_int)(%15)::Bool
└──       goto #7 if not %16
6 ─       goto #2
7 ┄ %19 = (Core._apply)(Core.tuple, %1)::Tuple{Vararg{Int64,N} where N}
└──       goto #8
8 ─       return %19

to the fundamental, where the return type depends on the values of the inputs:

julia> partition(3, (2,7,8))
((2,), (8, 7))

julia> partition(3, (2,1,8))
((1, 2), (8,))

You can fix the accidental ones by defining functions like:

sub2through(a::NTuple{N}) where {N} = ntuple(i -> a[i+1], Val(N-1))
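A quick way to confirm that this version is type-stable is to ask inference directly. The sketch below assumes Julia 1.x; Base.tail is an unexported Base helper that drops the first element of any tuple and does the same job here:

```julia
# The reply's helper: the output length N-1 comes from the type parameter N,
# so inference knows the exact tuple type.
sub2through(a::NTuple{N}) where {N} = ntuple(i -> a[i+1], Val(N-1))

t = (7, 8, 4)
sub2through(t)   # (8, 4)
Base.tail(t)     # (8, 4), the unexported Base equivalent

# Ask inference for the return type given a 3-tuple of Int.
Base.return_types(sub2through, (NTuple{3,Int},))
```

For a 3-tuple of Int, inference should report the concrete Tuple{Int64, Int64}, whereas a[2:end] goes through a range whose length is only a run-time value.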

For the fundamental ones, you would have to change the algorithm so that the types of the inputs always determine the type of the output.
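One way to change the algorithm along these lines, offered as a sketch and not as the thread's solution, is to swap quickselect for rank counting: the k-th smallest element is the one with exactly k-1 elements before it in a (value, index) ordering. It costs O(N^2) comparisons, but the return type depends only on the input type, and for a small NTuple the loop bounds are compile-time constants, which gives LLVM the chance to unroll everything into plain comparisons (whether it actually does is worth checking with @code_native). The name select_rank is hypothetical.

```julia
# Rank-based selection (not quickselect). Returns the k-th smallest element
# of a; the result type is always T, regardless of the values in a.
function select_rank(k::Int, a::NTuple{N,T}) where {N,T}
    1 <= k <= N || throw(ArgumentError("k must be in 1:$N"))
    result = a[1]
    for i in 1:N
        # Rank of a[i]: elements strictly smaller, plus equal elements
        # at earlier indices (tie-break by position).
        r = 0
        for j in 1:N
            r += (a[j] < a[i]) | ((a[j] == a[i]) & (j < i))
        end
        # Branch-free update: keep a[i] if its rank is k-1.
        result = ifelse(r == k - 1, a[i], result)
    end
    return result
end

select_rank(2, (7, 8, 4))        # 7
select_rank(3, (1, 3, 5, 7, 9))  # 5
```

Because ties are broken by index, the ranks always form a permutation of 0:N-1 and exactly one element matches. This also sidesteps a problem with the quickselect version in the question, which keeps recursing on inputs such as select(1, (3, 3)): the pivot lands in l together with its duplicate, so l never shrinks.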


I keep forgetting about the ntuple function, thanks! It seems to be the key to many of the things I want to do.

Would fixed-size tuples, with separate variables storing the lengths, do it? My impression is that we need a variable encoding some kind of state, and in this case the tuple lengths could be that state. What I actually want to implement is, for instance, a merge sort, and I'm trying to get a better understanding of how to approach algorithms like this in general.
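For merge sort, one sketch that keeps every length in the type is to recurse on the tuples themselves: merging an N-tuple with an M-tuple always yields an (N+M)-tuple, and splitting uses ntuple with Val so both halves get lengths derived from N. The names merge_sorted and msort are hypothetical, and whether inference resolves the recursion to concrete types for a given size should be confirmed with @code_warntype.

```julia
# Merge two sorted tuples; the result has length N + M.
merge_sorted(a::Tuple{}, b::Tuple{}) = ()
merge_sorted(a::Tuple, b::Tuple{}) = a
merge_sorted(a::Tuple{}, b::Tuple) = b
function merge_sorted(a::Tuple, b::Tuple)
    # Emit the smaller head, then merge what remains.
    if a[1] <= b[1]
        (a[1], merge_sorted(Base.tail(a), b)...)
    else
        (b[1], merge_sorted(a, Base.tail(b))...)
    end
end

# Merge sort on a tuple: split at N ÷ 2, with both lengths taken from the type.
msort(t::Tuple{}) = t
msort(t::NTuple{1,T}) where {T} = t
function msort(t::NTuple{N,T}) where {N,T}
    h = N ÷ 2
    left  = ntuple(i -> t[i], Val(h))
    right = ntuple(i -> t[h + i], Val(N - h))
    merge_sorted(msort(left), msort(right))
end

msort((5, 2, 9, 1, 7))  # (1, 2, 5, 7, 9)
```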

Maybe the best description of what I am looking for is: how to avoid writing code with lots of variables like a1 = ...; a2 = ...

Maybe a good general approach is to encode your whole state into a struct?
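A minimal sketch of that idea for the partition step, assuming the state is a buffer with the same length as the input plus two counters (PartState, push_elem and partition_fixed are hypothetical names): elements <= p are written at the front, elements > p at the back, and Base.setindex (unexported; it returns a new tuple with one entry replaced) keeps the buffer type fixed, so the state type never changes inside the loop.

```julia
# Fixed-size partition state: the buffer always has length N.
struct PartState{N,T}
    buf::NTuple{N,T}  # elements <= p in buf[1:nl], elements > p in buf[N-nr+1:N]
    nl::Int           # number of elements <= p placed so far
    nr::Int           # number of elements > p placed so far
end

# One step: write x at the front or the back. Both branches return PartState{N,T}.
function push_elem(s::PartState{N,T}, x::T, p) where {N,T}
    if x <= p
        PartState{N,T}(Base.setindex(s.buf, x, s.nl + 1), s.nl + 1, s.nr)
    else
        PartState{N,T}(Base.setindex(s.buf, x, N - s.nr), s.nl, s.nr + 1)
    end
end

function partition_fixed(p, a::NTuple{N,T}) where {N,T}
    s = PartState{N,T}(a, 0, 0)  # initial contents are all overwritten
    for x in a                   # read from the original tuple, not s.buf
        s = push_elem(s, x, p)
    end
    return s
end

s = partition_fixed(4, (1, 3, 5, 7, 9))
s.nl, s.buf  # (2, (1, 3, 9, 7, 5))
```

The split point is now a field (nl) rather than a tuple length, so a follow-up step such as recursing into one side would have to work on index ranges of buf instead of on shorter tuples.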

I’m trying to come up with something, but I suppose as long as I use tuples of varying length anywhere, it’s not going to be type-stable?

The following code shows how the state can be represented as a type-stable object, with the variable-length tuples existing only as transient values. But it seems that this is not enough.

I believe what I am looking for is a language or technique that lets me implement a stack on top of the registers. Or is that a highly sophisticated compiler optimization that is not really practical or relevant?

(Why am I doing this? I am trying to implement the algorithm from this article, Proceedings of the VLDB Endowment, which uses specialized merging code for small lists that works entirely in registers. I want to obtain something like that while writing as much high-level Julia as I can.)

function tail(a::NTuple{N,T})::NTuple{N-1,T} where {N,T}
    ntuple(i -> a[i+1], Val(N-1))
end

function pack_tuples(tt...)
    ((length(tt[1]),
      length(tt[2])),
     (tt[1]..., tt[2]..., tt[3]...))
end

# pack_tuples((1,2,3), (4,6), (7,8,9))

function unpack_tuples(p, v)
    (ntuple(i -> v[i], p[1]),
     ntuple(i -> v[i + p[1]], p[2]),
     ntuple(i -> v[i + p[1] + p[2]], length(v) - (p[1]+p[2])))
end

# unpack_tuples((3, 2), (1, 2, 3, 4, 6, 7, 8, 9))

partition(p, aa::NTuple{N, Integer}) where N = partition(p, pack_tuples(aa, (), ()))

@inline function partition(p::Int, state::Tuple{Tuple{Int,Int}, NTuple{5, T}})::Tuple{Tuple{Int,Int}, NTuple{5, T}} where {T}
    a, l, r = unpack_tuples(state...)

    if length(a) == 0
        state
    elseif a[1] <= p
        newstate = pack_tuples(tail(a), (a[1], l...), r)
        partition(p, newstate)
    else
        newstate = pack_tuples(tail(a), l, (a[1], r...))
        partition(p, newstate)
    end
end

aa = (1,3,5,7,9)
k = 4
@code_warntype partition(k, aa)
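A likely reason this attempt is still unstable sits in unpack_tuples: ntuple(f, n) with n an ordinary Int builds a tuple whose length is only known at run time, so inference cannot assign it a concrete type, whereas passing the length through the type (for example via Val) can. A small comparison, with the hypothetical names dyn_prefix and static_prefix:

```julia
# Length known only at run time: inference cannot pick a single tuple type.
dyn_prefix(v::NTuple{N,Int}, n::Int) where {N} = ntuple(i -> v[i], n)

# Length carried in the type via Val: the result type is concrete.
static_prefix(v::NTuple{N,Int}, ::Val{M}) where {N,M} = ntuple(i -> v[i], Val(M))

t = (1, 3, 5, 7, 9)
dyn_prefix(t, 2)          # (1, 3)
static_prefix(t, Val(2))  # (1, 3)

# Same values, different inferred types.
isconcretetype(only(Base.return_types(dyn_prefix, (NTuple{5,Int}, Int))))        # false
isconcretetype(only(Base.return_types(static_prefix, (NTuple{5,Int}, Val{2}))))  # true
```

In unpack_tuples the lengths come from p[1] and p[2], which are run-time values, so it falls into the first case even though the packed state itself has a fixed type.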