 How to unroll tree traversal efficiently?

Hi,

Consider a number of intervals on the real line (`Interval`), defined by their endpoints. I want to determine whether a given point `x` lies inside any of these intervals. I’m actually interested in a more complex case involving Bounding Volume Hierarchies as described here, but the stated problem is simple enough to allow the self-contained code example included below. I hope that this rephrasing of the question makes it easier to provide feedback.

To calculate ‘point in interval’ efficiently, I introduce a tree structure which holds the `Interval`s in its leaves (`Leaf`) and a covering interval in each parent node (`Branch`), such that the intervals of all child nodes lie inside the interval of their parent. The tree is implemented as a nested set of structs, so the tree structure is encoded in the type (see the implementation below). To illustrate the structure, here is a randomly generated tree of depth 3 (8 intervals):

``````
julia> print_tree(tree)
[0.02271979623487108, 0.9852042040171947]
├─ [0.02271979623487108, 0.48566702956650687]
│  ├─ [0.02271979623487108, 0.23317672198634987]
│  │  ├─ [0.02271979623487108, 0.1274017385601396]
│  │  └─ [0.16575577425266075, 0.23317672198634987]
│  └─ [0.2643166652257143, 0.48566702956650687]
│     ├─ [0.2643166652257143, 0.264794113763708]
│     └─ [0.27868113891671764, 0.48566702956650687]
└─ [0.5223866795888832, 0.9852042040171947]
   ├─ [0.5223866795888832, 0.7110029880048732]
   │  ├─ [0.5223866795888832, 0.5244640570593537]
   │  └─ [0.5276831868289535, 0.7110029880048732]
   └─ [0.783281134081528, 0.9852042040171947]
      ├─ [0.783281134081528, 0.8277029573388335]
      └─ [0.9110671243196529, 0.9852042040171947]
``````

The self-contained code example at the end of this post contains three alternative implementations of the `inanyinterval(x, intervalcontainer)` method:

1. `inanyinterval_naive(x, intervalvector)`: naively loop through all intervals.
2. `inanyinterval_recursive(x, intervaltree)`: recursively traverse the tree and only visit child intervals whose parent intervals cover `x`.
3. `inanyinterval_unrolled(x, intervaltree)`: use the same logic as 2, but unroll the traversal, since the tree structure is known at compile time.

To illustrate the unrolled code, consider the code generated for the tree displayed above (note that `.data` is an `Interval` and that `iscontained(x, I::Interval)` determines whether `x` lies strictly between the interval boundaries):

``````
julia> inanyinterval_unrolled_code(treehead)
quote
    function inanyinterval_unrolled(x, h::Branch)
        if iscontained(x, h.data)
            if iscontained(x, (h.children[1]).data)
                if iscontained(x, ((h.children[1]).children[1]).data)
                    if iscontained(x, (((h.children[1]).children[1]).children[1]).data)
                        return true
                    end
                    if iscontained(x, (((h.children[1]).children[1]).children[2]).data)
                        return true
                    end
                end
                if iscontained(x, ((h.children[1]).children[2]).data)
                    if iscontained(x, (((h.children[1]).children[2]).children[1]).data)
                        return true
                    end
                    if iscontained(x, (((h.children[1]).children[2]).children[2]).data)
                        return true
                    end
                end
            end
            if iscontained(x, (h.children[2]).data)
                if iscontained(x, ((h.children[2]).children[1]).data)
                    if iscontained(x, (((h.children[2]).children[1]).children[1]).data)
                        return true
                    end
                    if iscontained(x, (((h.children[2]).children[1]).children[2]).data)
                        return true
                    end
                end
                if iscontained(x, ((h.children[2]).children[2]).data)
                    if iscontained(x, (((h.children[2]).children[2]).children[1]).data)
                        return true
                    end
                    if iscontained(x, (((h.children[2]).children[2]).children[2]).data)
                        return true
                    end
                end
            end
        end
        return false
    end
end
``````

I then test the code while increasing `log2treesize` (I’m using Julia 1.1 on a MacBook Pro). I intentionally use `@time` rather than `@btime` in the code below, and run each call twice to get a rough estimate of the compile time.
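As a rough illustration of why each call is timed twice: in Julia, the first call of a method with new argument types includes JIT compilation, while the second call measures run time only. A minimal sketch, where `sumsq_demo` is a hypothetical stand-in for the functions under test and `@elapsed` is used instead of `@time` only so that the difference can be printed:

```julia
# hypothetical stand-in for the function under test
sumsq_demo(v) = sum(abs2, v)

data = rand(1000)
t_first  = @elapsed sumsq_demo(data)  # first call: compilation + run time
t_second = @elapsed sumsq_demo(data)  # second call: run time only
println("rough compile time estimate: ", t_first - t_second, " s")
```

The difference is only a rough estimate, since the first call may also include one-off costs other than compilation.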

With `log2treesize=6` (64 intervals), I get the following:

``````
test naive:
0.032667 seconds (146.51 k allocations: 6.946 MiB)
0.000008 seconds (7 allocations: 432 bytes)
test recursive:
1.886771 seconds (1.76 M allocations: 80.098 MiB, 1.26% gc time)
0.000647 seconds (2.65 k allocations: 2.098 MiB)
build inanyinterval_unrolled(x, tree::Branch):
0.793101 seconds (1.06 M allocations: 57.854 MiB, 2.21% gc time)
0.022648 seconds (10.99 k allocations: 717.089 KiB)
test unrolled:
0.234309 seconds (654.85 k allocations: 27.595 MiB, 2.61% gc time)
0.000005 seconds (6 allocations: 2.438 KiB)
``````

Unsurprisingly, the naive implementation does really well, but my unrolled version also seems to do a good job.

With `log2treesize=10` (1024 intervals), I get the following:

``````
test naive:
0.036071 seconds (146.51 k allocations: 6.948 MiB)
0.001023 seconds (7 allocations: 2.391 KiB)
test recursive:
662.104077 seconds (3.89 M allocations: 760.688 MiB, 0.01% gc time)
0.176274 seconds (89.35 k allocations: 603.809 MiB, 27.61% gc time)
build inanyinterval_unrolled(x, tree::Branch):
211.235229 seconds (2.03 M allocations: 112.390 MiB, 0.01% gc time)
0.644516 seconds (278.68 k allocations: 16.732 MiB, 2.86% gc time)
test unrolled:
87.097795 seconds (15.40 M allocations: 608.760 MiB, 0.93% gc time)
0.000071 seconds (6 allocations: 34.406 KiB)
``````

The unrolled version now clearly outperforms the naive method at run time, but the compilation time of both the method that generates the unrolled code and the unrolled method itself is getting annoying. The recursive method is unacceptably slow (I’ve commented it out in the code below) and is still outperformed by the naive implementation, even though it makes fewer calls to `iscontained`.
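For comparison, the pruning logic of the recursive method can also be written as plain recursion that dispatches on the node type, without the lazy `Iterators.flatten` generators. This is only a minimal sketch: the module `RecursionSketch`, the function `inany_plain` and the helper `demo_tree` are hypothetical names, the types are simplified copies of the ones in the full code, and its compile time and run time for large trees are untested here.

```julia
module RecursionSketch  # hypothetical module, keeps the sketch separate from the full code

struct Interval{T}
    min::T
    max::T
end
iscontained(x, I::Interval) = I.min < x < I.max

struct Leaf{T}
    data::T
end
struct Branch{T, N, C}
    data::T
    children::NTuple{N, C}
end

# a leaf matches if x lies inside its interval
inany_plain(x, n::Leaf) = iscontained(x, n.data)
# a branch is only descended into if its covering interval contains x
inany_plain(x, n::Branch) = iscontained(x, n.data) && any(c -> inany_plain(x, c), n.children)

# small hand-built tree over the intervals [0.1,0.2], [0.3,0.4], [0.5,0.6], [0.7,0.8]
function demo_tree()
    left  = Branch(Interval(0.1, 0.4), (Leaf(Interval(0.1, 0.2)), Leaf(Interval(0.3, 0.4))))
    right = Branch(Interval(0.5, 0.8), (Leaf(Interval(0.5, 0.6)), Leaf(Interval(0.7, 0.8))))
    Branch(Interval(0.1, 0.8), (left, right))
end

end # module

println(RecursionSketch.inany_plain(0.15, RecursionSketch.demo_tree()))  # 0.15 lies in [0.1, 0.2]
println(RecursionSketch.inany_plain(0.25, RecursionSketch.demo_tree()))  # 0.25 lies in a gap
```

The short-circuiting `&&` gives the same pruning as the `if` nesting in the unrolled code: a subtree is skipped as soon as its covering interval does not contain `x`.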

I would ideally use this tree structure for much larger trees (`log2treesize ≈ 20`, i.e. around 1 000 000 intervals).

I therefore have the following questions:

1. Is the nested tree structure a reasonably performant implementation in the first place?
2. If so, is there an obvious way of making the code unrolling more efficient?
3. Is it possible to wrap the unrolling into a `@generated` function? I naively thought I could do it like this: `@generated inanyinterval_unrolled(x, h::Branch) = inanyinterval_unrolled_code(h)`. That doesn’t work, because when the generated function builds the code, `h` is the type `Branch{...}` (a `DataType`) rather than an instance of `Branch{...}`, and the code generation function dispatches on `Leaf`/`Branch` instances.
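To make question 3 concrete, here is a minimal sketch of what a type-based variant might look like: the code generation dispatches on `Type{<:Leaf}` and `Type{Branch{T, N, C}}` instead of on instances, so it can be called from inside a `@generated` function. The module `GeneratedSketch` and the names `unroll`, `inanyinterval_generated` and `demo_tree` are hypothetical, the types are simplified copies of the ones in the full code, and the sketch uses `&&`/`||` chains rather than the nested `if ... return true` blocks shown above. The generated body still contains one check per node, so it is not obvious that this helps with compile time.

```julia
module GeneratedSketch  # hypothetical module, keeps the sketch separate from the full code

struct Interval{T}
    min::T
    max::T
end
iscontained(x, I::Interval) = I.min < x < I.max

struct Leaf{T}
    data::T
end
struct Branch{T, N, C}
    data::T
    children::NTuple{N, C}
end

# Code generation dispatches on node *types*, which is all a generated function sees.
unroll(::Type{<:Leaf}, expr) = :(iscontained(x, $expr.data))
function unroll(::Type{Branch{T, N, C}}, expr) where {T, N, C}
    # all N children share the type C, so their code differs only in the index
    checks = [unroll(C, :($expr.children[$i])) for i in 1:N]
    anychild = foldr((a, b) -> :($a || $b), checks)  # assumes N >= 1
    return :(iscontained(x, $expr.data) && $anychild)
end

# inside the generator, `h` is bound to the concrete type of the argument
@generated function inanyinterval_generated(x, h::Branch)
    return unroll(h, :h)
end

# small hand-built tree over the intervals [0.1,0.2], [0.3,0.4], [0.5,0.6], [0.7,0.8]
function demo_tree()
    left  = Branch(Interval(0.1, 0.4), (Leaf(Interval(0.1, 0.2)), Leaf(Interval(0.3, 0.4))))
    right = Branch(Interval(0.5, 0.8), (Leaf(Interval(0.5, 0.6)), Leaf(Interval(0.7, 0.8))))
    Branch(Interval(0.1, 0.8), (left, right))
end

end # module

println(GeneratedSketch.inanyinterval_generated(0.15, GeneratedSketch.demo_tree()))  # 0.15 lies in [0.1, 0.2]
println(GeneratedSketch.inanyinterval_generated(0.25, GeneratedSketch.demo_tree()))  # 0.25 lies in a gap
```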

All feedback is highly appreciated!

Here is the self-contained code:

``````
# optionally use AbstractTrees to print_tree
try
    import AbstractTrees: children, print_tree, printnode
catch
end

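# an interval on the real line, defined by its endpoints
struct Interval{T}
    min::T
    max::T
end
Interval(x::T, y::T) where {T} = Interval{T}(x,y)
# smallest interval covering all of the given intervals
Interval{T}(Is::Interval{T}...) where {T} = Interval(minimum(I.min for I in Is), maximum(I.max for I in Is))
# strict containment: the endpoints themselves do not count as inside
iscontained(x::T, I::Interval{T}) where {T} = I.min < x < I.max
Base.show(io::IO, I::Interval) = print(io, "[", I.min, ", ", I.max,"]")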

# tree related implementation
struct Branch{T, N, C}
    data::T
    children::NTuple{N, C}
    Branch(data::T, children...) where {T} = new{T, length(children), eltype(children)}(data, children)
end

struct Leaf{T}
    data::T
end
data(n::Branch)     = n.data
data(n::Leaf)       = n.data
children(n::Branch) = n.children
children(n::Leaf)   = ()

# for pretty printing with Abstract Tree
printnode(io::IO, n::Leaf)   = print(io, n.data)
printnode(io::IO, n::Branch) = print(io, n.data)

"""
Iterate over the data of the leaves below `n`, descending only into child
branches `c` for which `f(data(c))` is `true`. The leaf data itself is not
filtered by `f`.
"""
leaves(f, n::Branch{T, N, C}) where {T, N, C<:Leaf} = (c.data for c in n.children)
leaves(f, n::Branch{T, N, C}) where {T, N, C<:Branch} = Iterators.flatten(leaves(f, c) for c in children(n) if f(c.data))

# functions for creating test data
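# merge sibling nodes under a parent whose interval covers all of them
combine(children::T...) where {T} = Branch(T.parameters[1]((c.data for c in children)...), children...)
function create_data(log2treesize)
    N = 2^log2treesize
    # sorted endpoints give N disjoint intervals in increasing order
    r = sort(rand(2*N))
    intervals = Interval.(r[1:2:end], r[2:2:end])
    branches = Leaf.(intervals)
    # pair up neighbouring nodes level by level until two nodes remain
    while N>2
        branches = [combine(branches[i], branches[i+1]) for i in 1:2:N]
        N>>=1
    end
    tree = combine(branches...)
    intervals, tree
end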

# testing functions
f(x) = I->iscontained(x, I)
inanyinterval_naive(x, intervals) = any(f(x)(I) for I in intervals)
test_naive(xs, intervals) = [inanyinterval_naive(x, intervals) for x in xs]

inanyinterval_recursive(x, tree) = any(f(x)(I) for I in leaves(f(x), tree))
test_recursive(xs, tree) = [inanyinterval_recursive(x, tree) for x in xs]

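# a leaf contributes a single containment check
function _unroll(::Leaf, expr)
    quote
        if iscontained(x, $expr.data)
            return true
        end
    end
end
# a branch contributes a check of its covering interval, with the code of
# its children nested inside the `if` body
function _unroll(b::Branch, expr)
    children_ = [_unroll(c, :($expr.children[$i])) for (i,c) in enumerate(children(b))]
    tmp = quote
        if iscontained(x, $expr.data)
        end
    end
    for c in children_
        # args[2] skips the LineNumberNode in args[1] of each quote block
        push!(tmp.args[2].args[2].args, c.args[2])
    end
    tmp
end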

function inanyinterval_unrolled_code(h::Branch)
    x = _unroll(h, :h) |> Base.remove_linenums!
    return quote
        function inanyinterval_unrolled(x, h::Branch)
            $(x.args[1])
            return false
        end
    end  |> Base.remove_linenums!
end

## testing
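log2treesize = 6
intervals, tree = create_data(log2treesize)
N = length(intervals)
# midpoints of the intervals are inside, midpoints of the gaps are outside
inside  = [0.5*(I.min + I.max) for I in intervals]
outside = [0.5*(intervals[i].max + intervals[i+1].min) for i in 1:(length(intervals)-1)]
push!(outside, 1.0)
# interleave the points: inside, outside, inside, outside, ...
testdata = [inside outside]'[:]
truth = [trues(N) falses(N)]'[:]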

##
println("test naive:")
@time res = test_naive(testdata, intervals)
@time res = test_naive(testdata, intervals)
@assert all(res .== truth)
# takes a long time for large log2treesize
#println("test recursive:")
#@assert all(res .== truth)
println("build inanyinterval_unrolled(x, tree::Branch):") 