Get node indices in NearestNeighbors?

Hey all—

I’m building a KDTree using NearestNeighbors.jl like this:

```julia
using StaticArrays, NearestNeighbors
pts  = [SVector{2, Float64}(randn(2)) for _ in 1:1024]
tree = KDTree(pts, leafsize=25)
```

and I’m trying to find the correct way to get information about which indices of tree.data are in leaf node j. In this example there are 41 leaf nodes, and I’d love to have some vector leaf_indices such that leaf_indices[j] is a Vector{Int64} of length (at most) 25, or of course some other representation of the same information. I think that indices 1:25 are in leaf node 1, 26:50 are in leaf node 2, and so on, but I’m not completely certain.

If that’s not right: I see that there are some internal functions for getting the indices of the (reordered) data from the index of a leaf node, but I’m having trouble figuring out what the indices of the leaf nodes are. NearestNeighbors.isleaf just checks whether an index is larger than the number of non-leaf nodes (40 in this case), but I’m missing something, because NearestNeighbors.get_leaf_range(tree.tree_data, 41) is 450:474, and the ranges only increase as I increase the argument. No argument I’ve tried gives me 1:25, for example. So I’m clearly not understanding the design here.

Can anybody help me understand the right way to extract this information?

I’m not sure how bad-mannered this is, but @kristoffer.carlsson, could you help me out with this, assuming it’s as simple as I think it is?


It has been quite a long time since I worked properly with the internals of NearestNeighbors, so my memory is a bit fuzzy on the details. But perhaps this notebook can help: https://github.com/KristofferC/NearestNeighbors.jl/blob/master/examples/balltree_illustration.ipynb. It’s a very old notebook, though, so it probably doesn’t run anymore; I should update it.

Anyway, it shows some illustrations for a Ball Tree but I think it should be pretty much the same for a KDTree. For example:

```julia
# Skip non-leaf nodes
offset = tree.tree_data.n_internal_nodes + 1
nleafs = tree.tree_data.n_leafs

# Range of leaf nodes
index_range = offset:(offset + nleafs - 1)

leaf_indices = map(idx -> NearestNeighbors.get_leaf_range(tree.tree_data, idx), index_range)
41-element Vector{UnitRange{Int64}}:
 450:474
 475:499
 500:524
 525:549
 550:574
 575:599
 600:624
...
```
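Wrapping the snippet above in a function makes it easy to sanity-check the layout. This is a minimal sketch, assuming your NearestNeighbors.jl version still has the internals used above (tree.tree_data with n_internal_nodes and n_leafs, and NearestNeighbors.get_leaf_range). These are not public API and may change, and leaf_ranges is a made-up name:

```julia
using StaticArrays, NearestNeighbors

# Hypothetical helper (not part of NearestNeighbors.jl): the data range of every
# leaf node, in leaf-node-number order, using the same internals as above.
function leaf_ranges(tree)
    td = tree.tree_data
    first_leaf = td.n_internal_nodes + 1          # nodes 1:n_internal_nodes are internal
    leaf_nodes = first_leaf:(first_leaf + td.n_leafs - 1)
    return [NearestNeighbors.get_leaf_range(td, node) for node in leaf_nodes]
end

pts  = [SVector{2, Float64}(randn(2)) for _ in 1:1024]
tree = KDTree(pts, leafsize=25)
ranges = leaf_ranges(tree)

# Every position in 1:length(pts) should belong to exactly one leaf...
@assert sort(reduce(vcat, collect.(ranges))) == 1:length(pts)
# ...and at most one bucket is not completely filled.
@assert count(r -> length(r) != 25, ranges) <= 1
```

Since 1024 = 40 × 25 + 24, exactly one of the 41 buckets should hold 24 points here.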

This should give you the indices into the reordered data. If you want the indices into your original input data, it should be:

```julia
julia> [tree.indices[i] for i in leaf_indices]
41-element Vector{Vector{Int64}}:
 [513, 110, 792, 461, 670, 143, 198, 842, 246, 501  …  744, 945, 341, 807, 862, 825, 739, 26, 677, 96]
 [291, 680, 161, 882, 132, 36, 398, 1010, 727, 345  …  126, 65, 89, 510, 849, 195, 617, 471, 448, 381]
 [169, 783, 772, 477, 355, 323, 614, 684, 202, 768  …  750, 437, 1001, 780, 310, 328, 597, 482, 701, 681]
...
```
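To get exactly the leaf_indices asked for in the question (original indices per leaf), plus the points themselves, something like the following should work. It’s a sketch under the same assumptions about the internals as above; the points are taken from the original pts through those indices, so tree.data is never indexed directly:

```julia
using StaticArrays, NearestNeighbors

pts  = [SVector{2, Float64}(randn(2)) for _ in 1:1024]
tree = KDTree(pts, leafsize=25)
td   = tree.tree_data
leaf_nodes = (td.n_internal_nodes + 1):(td.n_internal_nodes + td.n_leafs)

# leaf_indices[j]: indices into the original pts of the points in the j-th leaf node
leaf_indices = [tree.indices[NearestNeighbors.get_leaf_range(td, n)] for n in leaf_nodes]
# The corresponding points, taken from the original input
leaf_points = [pts[idx] for idx in leaf_indices]

# Each original point should appear in exactly one leaf.
@assert sort(reduce(vcat, leaf_indices)) == 1:length(pts)
```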

Regarding the ranges only increasing as you increase the argument: they will increase until you reach tree.tree_data.cross_node, and from there they start again at 1:

```julia
julia> tree.tree_data.cross_node
64

julia> NearestNeighbors.get_leaf_range(tree.tree_data, 63)
1000:1024

julia> NearestNeighbors.get_leaf_range(tree.tree_data, 64)
1:25

julia> NearestNeighbors.get_leaf_range(tree.tree_data, 65)
26:50
```
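If you’d rather number the leaves in the order their buckets appear in the data (so the first holds 1:25, the second 26:50, and so on, as guessed in the question), you can sort the leaf nodes by the start of their range. A sketch under the same assumptions about the internals as above:

```julia
using StaticArrays, NearestNeighbors

pts  = [SVector{2, Float64}(randn(2)) for _ in 1:1024]
tree = KDTree(pts, leafsize=25)
td   = tree.tree_data
leaf_nodes = collect((td.n_internal_nodes + 1):(td.n_internal_nodes + td.n_leafs))

# Order leaf nodes by where their bucket starts in the reordered data.
sorted_nodes = sort(leaf_nodes; by = n -> first(NearestNeighbors.get_leaf_range(td, n)))
# Data-order buckets: chunks[1] == 1:25, chunks[2] == 26:50, ...
chunks = [NearestNeighbors.get_leaf_range(td, n) for n in sorted_nodes]

@assert first(sorted_nodes) == td.cross_node
```

For this tree the order should be node 64 and up, followed by nodes 41 through 63, which matches the jump at the cross node shown above.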

To give an example, let’s say we create a tree with 6 points and leafsize=1 for simplicity:

Note that in NearestNeighbors.jl, we always fill “from the left” and at most one leaf node will not have a completely filled “bucket”.

The black numbers are the node numbers. The green ones are the leaf node numbers. The “cross node” is 8, which is the node where the leaf numbering starts (but note that the node number for leaf node 1, i.e. node 8, is higher than the node number for leaf node 5, i.e. node 6). Doing this in the package:

```julia
julia> tree = KDTree(rand(3,6); leafsize=1);

# nodes 1-5 are internal (non-leaf)
julia> tree.tree_data.n_internal_nodes
5

# 8 is the "cross" node where the leaf numbering starts
julia> tree.tree_data.cross_node
8

# The range of (reordered) data indices in leaf node 8
julia> NearestNeighbors.get_leaf_range(tree.tree_data, 8)
1:1

# The range of (reordered) data indices in leaf node 6
julia> NearestNeighbors.get_leaf_range(tree.tree_data, 6)
5:5

# and the size of the last (possibly unfilled) bucket is, of course, 1 here
julia> tree.tree_data.last_node_size
1
```
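To make the wrap-around at the cross node concrete, here is a package-free sketch of the layout. It is a hypothetical reconstruction derived only from the ranges printed in this thread (6 points with leafsize 1, and 1024 points with leafsize 25), not code taken from NearestNeighbors.jl, and the name leaf_layout_sketch is made up. The assumed rule: buckets are filled from the cross node up to the last node (whose bucket may be partial), and then the remaining leaves below the cross node follow.

```julia
# Hypothetical reconstruction of the leaf layout described in this thread.
# Derived from the printed ranges above, not from the package source.
function leaf_layout_sketch(n_points::Int, leafsize::Int)
    n_leafs    = cld(n_points, leafsize)      # number of leaf buckets
    n_internal = n_leafs - 1                  # nodes 1:n_internal are internal
    n_nodes    = n_internal + n_leafs         # leaves are nodes n_internal+1:n_nodes
    # Cross node: first node of the deepest row. When n_leafs is a power of two,
    # all leaves sit on one row, which starts one level up.
    cross = 2 * prevpow(2, n_leafs)
    cross > n_nodes && (cross ÷= 2)
    last_size = n_points - (n_leafs - 1) * leafsize   # bucket of node n_nodes may be partial
    ranges = Dict{Int,UnitRange{Int}}()
    pos = 1
    # Assumed fill order: cross..n_nodes first, then n_internal+1..cross-1.
    for node in vcat(cross:n_nodes, (n_internal + 1):(cross - 1))
        len = node == n_nodes ? last_size : leafsize
        ranges[node] = pos:(pos + len - 1)
        pos += len
    end
    return (; n_internal, cross, ranges)
end

L = leaf_layout_sketch(6, 1)
L.ranges[8]   # 1:1
L.ranges[6]   # 5:5
```

With leafsize = 1 this reproduces the 8 => 1:1 and 6 => 5:5 ranges above, and for 1024 points with leafsize = 25 it gives 41 => 450:474, 63 => 1000:1024 and 64 => 1:25, as reported earlier in the thread.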

You can try to create a tree with 11 points and a leafsize of 2 to see how the last bucket will now be half full.
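For that experiment, a short script that lists every leaf node’s range makes the half-full bucket easy to spot. It relies on the same internal fields and get_leaf_range as the REPL session above, so it assumes your version still has them:

```julia
using NearestNeighbors

# 11 points with leafsize 2: cld(11, 2) = 6 leaves, i.e. five buckets of 2 and one of 1.
tree = KDTree(rand(3, 11); leafsize = 2)
td   = tree.tree_data

leaf_nodes = (td.n_internal_nodes + 1):(td.n_internal_nodes + td.n_leafs)
for node in leaf_nodes
    r = NearestNeighbors.get_leaf_range(td, node)
    println("leaf node ", node, " => ", r, " (", length(r), " point(s))")
end
println("last_node_size = ", td.last_node_size)
```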


@kristoffer.carlsson: Wow! I really can’t thank you enough. That’s incredibly helpful and definitely enough to work it out exactly for the KD-tree case. I’ll update this thread with a clean solution shortly.

Thanks again for the amazing response, and thanks for writing one of my favorite packages in the Julia ecosystem.
