How to unroll tree traversal efficiently?

Hi,

Consider a number of intervals on the real line (Interval), defined by their endpoints. I want to determine whether a given point x lies inside any of these intervals. I’m actually interested in a more complex case involving Bounding Volume Hierarchies as described here, but the stated problem is simple enough to allow a self-contained code example, included below. I hope this rephrasing of the question makes it easier to provide feedback.

To calculate ‘point in interval’ efficiently, I introduce a tree structure that holds the Intervals in its leaves (Leaf) and covering intervals in each parent node (Branch), such that each child's interval lies inside its parent's interval. The tree is implemented as a nested set of structs, so the tree structure is encoded in the type (see the implementation). To illustrate, here is the structure of a randomly generated tree of depth 3 (8 intervals):

julia> print_tree(tree)
[0.02271979623487108, 0.9852042040171947]
├─ [0.02271979623487108, 0.48566702956650687]
│  ├─ [0.02271979623487108, 0.23317672198634987]
│  │  ├─ [0.02271979623487108, 0.1274017385601396]
│  │  └─ [0.16575577425266075, 0.23317672198634987]
│  └─ [0.2643166652257143, 0.48566702956650687]
│     ├─ [0.2643166652257143, 0.264794113763708]
│     └─ [0.27868113891671764, 0.48566702956650687]
└─ [0.5223866795888832, 0.9852042040171947]
   ├─ [0.5223866795888832, 0.7110029880048732]
   │  ├─ [0.5223866795888832, 0.5244640570593537]
   │  └─ [0.5276831868289535, 0.7110029880048732]
   └─ [0.783281134081528, 0.9852042040171947]
      ├─ [0.783281134081528, 0.8277029573388335]
      └─ [0.9110671243196529, 0.9852042040171947]
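The nesting invariant described above (each child's interval lies inside its parent's) can be checked mechanically. What follows is a minimal sketch using simplified stand-in types. Iv, Lf, Br, within and isnested are hypothetical names and are not the types from the code further down. Containment is taken as closed (≤), which the post does not specify. The sketch also shows that the nesting depth ends up in the type parameters.

```julia
# Hypothetical stand-ins for the post's Interval/Leaf/Branch (not the real definitions).
struct Iv
    min::Float64
    max::Float64
end
struct Lf
    data::Iv
end
struct Br{N,C}
    data::Iv
    children::NTuple{N,C}
end

# Closed containment of one interval in another (an assumption for this sketch).
within(inner::Iv, outer::Iv) = outer.min <= inner.min && inner.max <= outer.max

# Invariant: every child's interval lies inside its parent's interval, at every level.
isnested(::Lf) = true
isnested(b::Br) = all(c -> within(c.data, b.data) && isnested(c), b.children)

left  = Br(Iv(0.1, 0.4), (Lf(Iv(0.1, 0.2)), Lf(Iv(0.3, 0.4))))
right = Br(Iv(0.5, 0.9), (Lf(Iv(0.5, 0.6)), Lf(Iv(0.7, 0.9))))
tree  = Br(Iv(0.1, 0.9), (left, right))

println(typeof(tree))    # the depth-2 shape is encoded in the type parameters
println(isnested(tree))  # true
```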

Below is a self-contained code example with three alternative implementations of the inanyinterval(x, intervalcontainer) method.

  1. inanyinterval_naive(x, intervalvector) : naively loop through all intervals.
  2. inanyinterval_recursive(x, intervaltree) : recursively traverse the tree and only consider children intervals whose parent intervals cover x.
  3. inanyinterval_unrolled(x, intervaltree) : use the same logic as 2., but unroll the tree, since its structure is known at compile time.
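For comparison with implementation 2, the recursive traversal can also be written with plain multiple dispatch instead of nested `Iterators.flatten` generators. This is a minimal sketch, again with hypothetical simplified stand-in types (Iv, Lf, Br, inany). It has not been benchmarked here, and the expectation that it allocates less than `leaves` is an assumption.

```julia
# Hypothetical stand-ins for the post's Interval/Leaf/Branch (not the real definitions).
struct Iv
    min::Float64
    max::Float64
end
struct Lf
    data::Iv
end
struct Br{N,C}
    data::Iv
    children::NTuple{N,C}
end
iscontained(x, I::Iv) = I.min < x < I.max   # strict, as in the post

# Recursion through multiple dispatch: a leaf ends the recursion, and a branch only
# descends into its children when x lies inside the branch's own interval.
# `any` short-circuits, and no iterator objects are built during the traversal.
inany(x, n::Lf) = iscontained(x, n.data)
inany(x, n::Br) = iscontained(x, n.data) && any(c -> inany(x, c), n.children)

tree = Br(Iv(0.1, 0.9), (Br(Iv(0.1, 0.4), (Lf(Iv(0.1, 0.2)), Lf(Iv(0.3, 0.4)))),
                         Br(Iv(0.5, 0.9), (Lf(Iv(0.5, 0.6)), Lf(Iv(0.7, 0.9))))))

inany(0.15, tree)  # true
inany(0.25, tree)  # false: inside the left hull, but in neither of its leaves
```

In a complete tree, all nodes at the same level share one concrete type. So presumably only about one method specialization per level is compiled, rather than code proportional to the number of leaves as in the unrolled version. Whether inference stays fully concrete for very deep nested types (depth ≈ 20) is not verified here.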

To illustrate the unrolled code, consider the code generated for the tree displayed above (note that .data is an Interval and that iscontained(x, I::Interval) determines whether x lies strictly between the interval boundaries):

julia> inanyinterval_unrolled_code(treehead)
quote
    function inanyinterval_unrolled(x, h::Branch)
        if iscontained(x, h.data)
            if iscontained(x, (h.children[1]).data)
                if iscontained(x, ((h.children[1]).children[1]).data)
                    if iscontained(x, (((h.children[1]).children[1]).children[1]).data)
                        return true
                    end
                    if iscontained(x, (((h.children[1]).children[1]).children[2]).data)
                        return true
                    end
                end
                if iscontained(x, ((h.children[1]).children[2]).data)
                    if iscontained(x, (((h.children[1]).children[2]).children[1]).data)
                        return true
                    end
                    if iscontained(x, (((h.children[1]).children[2]).children[2]).data)
                        return true
                    end
                end
            end
            if iscontained(x, (h.children[2]).data)
                if iscontained(x, ((h.children[2]).children[1]).data)
                    if iscontained(x, (((h.children[2]).children[1]).children[1]).data)
                        return true
                    end
                    if iscontained(x, (((h.children[2]).children[1]).children[2]).data)
                        return true
                    end
                end
                if iscontained(x, ((h.children[2]).children[2]).data)
                    if iscontained(x, (((h.children[2]).children[2]).children[1]).data)
                        return true
                    end
                    if iscontained(x, (((h.children[2]).children[2]).children[2]).data)
                        return true
                    end
                end
            end
        end
        return false
    end
end
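The size of the generated code follows directly from the tree shape. The unrolled function contains one `if` per node, so for a complete binary tree of depth d (2^d leaves) the count is:

```latex
\#\{\texttt{if}\} \;=\; \sum_{k=0}^{d} 2^{k} \;=\; 2^{d+1} - 1
```

For d = 3 this gives 15, matching the listing above. For d = 10 it is 2047, and for d = 20 it is 2 097 151. Growth of this kind in the code handed to the compiler is a plausible reason why the build and first-call times climb so steeply with log2treesize.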

I then benchmark the code for increasing values of log2treesize (using Julia 1.1 on a MacBook Pro). I intentionally use @time rather than @btime in the code below. I run each test twice: the first run includes compilation, so comparing the two runs gives a rough estimate of the compile time.

With log2treesize=6 (64 intervals), I get the following:

test naive:
  0.032667 seconds (146.51 k allocations: 6.946 MiB)
  0.000008 seconds (7 allocations: 432 bytes)
test recursive:
  1.886771 seconds (1.76 M allocations: 80.098 MiB, 1.26% gc time)
  0.000647 seconds (2.65 k allocations: 2.098 MiB)
build inanyinterval_unrolled(x, tree::Branch):
  0.793101 seconds (1.06 M allocations: 57.854 MiB, 2.21% gc time)
  0.022648 seconds (10.99 k allocations: 717.089 KiB)
test unrolled:
  0.234309 seconds (654.85 k allocations: 27.595 MiB, 2.61% gc time)
  0.000005 seconds (6 allocations: 2.438 KiB)

Unsurprisingly, the naive implementation does really well, but my unrolled version seems to do a good job too.

With log2treesize=10 (1024 intervals), I get the following:

test naive:
  0.036071 seconds (146.51 k allocations: 6.948 MiB)
  0.001023 seconds (7 allocations: 2.391 KiB)
test recursive:
662.104077 seconds (3.89 M allocations: 760.688 MiB, 0.01% gc time)
  0.176274 seconds (89.35 k allocations: 603.809 MiB, 27.61% gc time)
build inanyinterval_unrolled(x, tree::Branch):
211.235229 seconds (2.03 M allocations: 112.390 MiB, 0.01% gc time)
  0.644516 seconds (278.68 k allocations: 16.732 MiB, 2.86% gc time)
test unrolled:
 87.097795 seconds (15.40 M allocations: 608.760 MiB, 0.93% gc time)
  0.000071 seconds (6 allocations: 34.406 KiB)

The unrolled version now clearly outperforms the naive method. However, two compilation costs are getting annoying: compiling the function that generates the unrolled code, and compiling the unrolled method itself. The recursive method is unacceptably slow (I’ve commented it out in the code below). It is still outperformed by the naive implementation, even though it makes fewer calls to iscontained.

I would ideally use this tree structure for much larger trees (log2treesize ≃ 20, i.e. around 1 000 000 intervals).
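A rough count of iscontained calls per query supports the observation that the tree makes fewer calls. The count assumes the situation produced by create_data: leaf intervals are sorted and disjoint, so sibling hulls are disjoint too. Combined with the strict inequality in iscontained, this means at most one child of any node contains x. With d = log2treesize and N = 2^d:

```latex
C_{\text{naive}} \le N = 2^{d}, \qquad C_{\text{tree}} \le 1 + 2d
```

The tree bound counts one check at the root plus at most two sibling checks on each of the d levels below it. For d = 10 this is at most 1024 versus 21 checks, and for d = 20 at most 1 048 576 versus 41. The naive bound is a worst case, because `any` stops at the first hit. With overlapping intervals, as in a general BVH, more than one child can contain x, and the tree bound no longer holds.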

I therefore have the following questions:

  1. Is the nested tree structure a reasonably performant implementation in the first place?
  2. If so, is there an obvious way of making the code unrolling more efficient?
  3. Is it possible to wrap the unrolling into a @generated function? I naively thought I could do it like this: @generated inanyinterval_unrolled(x, h::Branch) = inanyinterval_unrolled_code(h). That doesn’t work. When the generated function builds the code, h is bound to the type Branch{...} itself (a DataType), not to an instance of Branch{...}. Meanwhile, the logic of my code generation function dispatches on Leaf/Branch instances.
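Regarding question 3: inside a @generated function, the argument names are bound to the argument *types*, so the generator has to walk the type parameters instead of the tree values. Because the tree shape is fully encoded in those parameters, that is enough to emit the same unrolled code. Below is a minimal sketch with hypothetical stand-in types. Here, Br{N,C} has no data-type parameter; for the post's Branch{T,N,C}, the methods would dispatch on Type{Branch{T,N,C}} and Type{<:Leaf} instead. The names _unroll_type and inany_generated are illustrative only.

```julia
# Hypothetical stand-ins for the post's Interval/Leaf/Branch (not the real definitions).
struct Iv
    min::Float64
    max::Float64
end
struct Lf
    data::Iv
end
struct Br{N,C}
    data::Iv
    children::NTuple{N,C}
end
iscontained(x, I::Iv) = I.min < x < I.max

# Code generation on *types*: `ex` is the expression that reaches the node at runtime.
_unroll_type(::Type{Lf}, ex) = :(iscontained(x, $ex.data) && return true)
function _unroll_type(::Type{Br{N,C}}, ex) where {N,C}
    body = [_unroll_type(C, :($ex.children[$i])) for i in 1:N]
    return Expr(:if, :(iscontained(x, $ex.data)), Expr(:block, body...))
end

# Inside the generator, `h` is the type (e.g. Br{2, Br{2, Lf}}), not a tree value.
@generated function inany_generated(x, h::Br)
    return quote
        $(_unroll_type(h, :h))
        return false
    end
end

tree = Br(Iv(0.1, 0.9), (Br(Iv(0.1, 0.4), (Lf(Iv(0.1, 0.2)), Lf(Iv(0.3, 0.4)))),
                         Br(Iv(0.5, 0.9), (Lf(Iv(0.5, 0.6)), Lf(Iv(0.7, 0.9))))))
inany_generated(0.15, tree)  # true
```

This removes the need for eval, and the code is generated once per tree type. It does not shrink the generated code, though. There is still one branch per node, so compile time for log2treesize ≈ 20 is likely to remain prohibitive.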

All feedback is highly appreciated!

Here is the self-contained code:

# optionally use AbstractTrees to print_tree
try
    import AbstractTrees: children, print_tree, printnode
catch
end

struct Interval{T}
    min::T
    max::T
end
Interval(x::T, y::T) where {T} = Interval{T}(x,y)
Interval{T}(Is::Interval{T}...) where {T} = Interval(minimum(I.min for I in Is), maximum(I.max for I in Is))
iscontained(x::T, I::Interval{T}) where {T} = I.min < x < I.max
Base.show(io::IO, I::Interval) = print(io, "[", I.min, ", ", I.max,"]")

# tree related implementation
struct Branch{T, N, C}
    data::T
    children::NTuple{N, C}
    Branch(data::T, children...) where {T} = new{T, length(children), eltype(children)}(data, children)
end

struct Leaf{T}
    data::T
end
data(n::Branch)     = n.data
data(n::Leaf)       = n.data
children(n::Branch) = n.children
children(n::Leaf)   = ()

# for pretty printing with Abstract Tree
printnode(io::IO, n::Leaf)   = print(io, n.data)
printnode(io::IO, n::Branch) = print(io, n.data)

"""
Iterate over the data of the leaves below `n` whose ancestor branches (excluding `n` itself)
all satisfy `f(data(branch))`. The leaves themselves are not filtered by `f`.
"""
leaves(f, n::Branch{T, N, C}) where {T, N, C<:Leaf} = (c.data for c in n.children)
leaves(f, n::Branch{T, N, C}) where {T, N, C<:Branch} = Iterators.flatten(leaves(f, c) for c in children(n) if f(c.data))

# functions for creating test data
combine(children::T...) where {T} = Branch(T.parameters[1]((c.data for c in children)...), children...)
function create_data(log2treesize)
    N = 2^log2treesize
    r = sort(rand(2*N))
    intervals = Interval.(r[1:2:end], r[2:2:end])
    branches = Leaf.(intervals)
    while N>2
        branches = [combine(branches[i], branches[i+1]) for i in 1:2:N]
        N>>=1
    end
    tree = combine(branches...)
    intervals, tree
end

# testing functions
f(x) = I->iscontained(x, I)
inanyinterval_naive(x, intervals) = any(f(x)(I) for I in intervals)
test_naive(xs, intervals) = [inanyinterval_naive(x, intervals) for x in xs]

inanyinterval_recursive(x, tree) = any(f(x)(I) for I in leaves(f(x), tree))
test_recursive(xs, tree) = [inanyinterval_recursive(x, tree) for x in xs]

function _unroll(::Leaf, expr)
    quote
        if iscontained(x, $expr.data)
            return true
        end
    end
end
function _unroll(b::Branch, expr)
    children_ = [_unroll(c, :($expr.children[$i])) for (i,c) in enumerate(children(b))]
    tmp = quote
        if iscontained(x, $expr.data)
        end
    end
    for c in children_
        push!(tmp.args[2].args[2].args, c.args[2])
    end
    tmp
end

function inanyinterval_unrolled_code(h::Branch)
    x = _unroll(h, :h) |> Base.remove_linenums!
    return quote
        function inanyinterval_unrolled(x, h::Branch)
            $(x.args[1])
            return false
        end
    end  |> Base.remove_linenums!
end
create_inanyinterval_unrolled(treehead::Branch) = eval(inanyinterval_unrolled_code(treehead))
test_unrolled(xs, treehead) = [inanyinterval_unrolled(x, treehead) for x in xs]

## testing
log2treesize = 6
intervals, treehead = create_data(log2treesize)
N = length(intervals)
inside  = [0.5*(I.min + I.max) for I in intervals]
outside = [0.5*(intervals[i].max + intervals[i+1].min) for i in 1:(length(intervals)-1)]
push!(outside, 1.0)
testdata = [inside outside]'[:]
truth = [trues(N) falses(N)]'[:]

##
println("test naive:")
@time res = test_naive(testdata, intervals)
@time res = test_naive(testdata, intervals)
@assert all(res .== truth)
# takes a long time for large log2treesize
#println("test recursive:")
#@time res = test_recursive(testdata, treehead)
#@time res = test_recursive(testdata, treehead)
#@assert all(res .== truth)
println("build inanyinterval_unrolled(x, tree::Branch):")
@time create_inanyinterval_unrolled(treehead)
@time create_inanyinterval_unrolled(treehead)
println("test unrolled:")
@time res = test_unrolled(testdata, treehead)
@time res = test_unrolled(testdata, treehead)
@assert all(res .== truth)

Perhaps only tangentially related, but why not store things in a flat array
instead of having to jump through pointers to get to the children? I do that in https://github.com/KristofferC/NearestNeighbors.jl.
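This is a minimal sketch of the flat-array idea, built on two assumptions. First, the number of leaves is a power of two (as in create_data). Second, nodes are stored in implicit heap order: node i has children 2i and 2i + 1, and the leaves sit at indices nleaves:2nleaves-1. FlatTree and inany_flat are hypothetical names, and the sketch is not taken from NearestNeighbors.jl.

```julia
# Hypothetical flat layout in heap order: node i has children 2i and 2i+1,
# and for n = 2^d leaves the leaves occupy indices n:2n-1.
struct FlatTree
    lo::Vector{Float64}   # lower bound of each node's interval
    hi::Vector{Float64}   # upper bound of each node's interval
    nleaves::Int
end

function FlatTree(mins::Vector{Float64}, maxs::Vector{Float64})
    n = length(mins)
    @assert ispow2(n) && length(maxs) == n
    lo = zeros(2n - 1)
    hi = zeros(2n - 1)
    lo[n:2n-1] .= mins
    hi[n:2n-1] .= maxs
    for i in n-1:-1:1           # parents bottom-up, as hulls of their two children
        lo[i] = min(lo[2i], lo[2i+1])
        hi[i] = max(hi[2i], hi[2i+1])
    end
    return FlatTree(lo, hi, n)
end

# Same pruning logic as the nested tree, but the shape lives in the data, not the type,
# so one compiled method serves trees of any size.
function inany_flat(x, t::FlatTree, i::Int = 1)
    t.lo[i] < x < t.hi[i] || return false   # prune: x outside this node's interval
    i >= t.nleaves && return true            # leaf whose interval contains x
    return inany_flat(x, t, 2i) || inany_flat(x, t, 2i + 1)
end

r = sort(rand(2 * 2^10))
t = FlatTree(r[1:2:end], r[2:2:end])
inany_flat(r[1] + (r[2] - r[1]) / 2, t)   # query the midpoint of the first interval
```

Because the tree shape is encoded in index arithmetic rather than in nested types, the compile cost no longer depends on log2treesize.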
