[ANN] Shrike.jl: Julia beats optimized C++ (Fast accurate approximate nearest neighbor search)

I wrote a Julia package (Shrike.jl) for approximate nearest neighbor search. My algorithm draws heavily on the mrpt algorithm and uses the same principles behind annoy (a popular approximate nearest neighbor package by Spotify). My Julia implementation beat both of them! See the Shrike.jl/README.md for details.

I was shocked and thrilled! I am no Julia expert, but the fact that I was able to outperform about 2000 lines of C++ in 500 lines of Julia is a real testament to the power of the language.

As a caveat, I should mention that writing Shrike.jl wasn’t simple. I spent a lot of time learning about and testing different Julia optimization strategies.

This is my first time making a public Julia package and I am planning to figure out how to add it to the registry. Any suggestions or feedback on the package would be appreciated. For instance, where can I leave out type annotations and not have it impact performance?

59 Likes


This looks great!
Have you considered porting it to MLJ.jl?
They currently have 150+ supported models (including NearestNeighborModels.jl).

7 Likes

Great! Is this package comparable to the widely used and extremely efficient NearestNeighbors.jl? If yes, can you benchmark against it as well? That would be a good reference to have.

1 Like

He does this exact benchmark in the readme IIUC

2 Likes

I see the plots now. I wish they were numbers, though, with comparable scenarios in the sense that each parameter combination is “equivalent”. I cannot fully support the statement in the README that says:

It is important to note that NearestNeighbors.jl was designed to return the exact k-nearest-neighbors as quickly as possible, and does not approximate, hence the high accuracy and lower speed.

by just looking at the dots in the plot. Can we really compare the dots in the upper left to the dots in the right of the plot?

BTW, it is wonderful to see speed! I am just trying to be “scientifically careful” with the conclusions because nearest neighbors is a really useful tool and I depend a lot on NearestNeighbors.jl in my own packages. If @djpasseyjr could show that there is a real advantage here, I could consider taking a more in-depth look into Shrike.jl to use it in future projects.

6 Likes

Glancing at your code, it should be possible to make it faster still: it seems like you do a lot of “vectorized” programming (à la Matlab/NumPy), allocating arrays in what look like critical loops.
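To make the point concrete, here is a generic sketch of the pattern (not code taken from Shrike.jl; both function names are hypothetical). The same squared distance is written once in vectorized style, which materializes a temporary array, and once as a plain loop that accumulates into a scalar:

```julia
# Hypothetical example, not taken from Shrike.jl.

# "Vectorized" style: the subtraction and squaring fuse into one broadcast,
# but the result is still stored in a temporary array before `sum` runs.
sqdist_vectorized(x, y) = sum((x .- y) .^ 2)

# Loop style: accumulates into a scalar, so no temporary array is needed.
function sqdist_loop(x, y)
    s = zero(promote_type(eltype(x), eltype(y)))
    @inbounds for i in eachindex(x, y)
        d = x[i] - y[i]
        s += d * d
    end
    return s
end

# Measure allocations inside functions so that global-variable overhead
# does not distort the numbers.
alloc_vectorized(x, y) = @allocated sqdist_vectorized(x, y)
alloc_loop(x, y) = @allocated sqdist_loop(x, y)

x = rand(128); y = rand(128)
alloc_vectorized(x, y); alloc_loop(x, y)   # warm up so compilation is excluded

println("vectorized: ", alloc_vectorized(x, y), " bytes")
println("loop:       ", alloc_loop(x, y), " bytes")
```

Inside a hot loop that runs once per point or per tree node, the temporary array in the first version is created and garbage-collected on every iteration, which is where the time tends to go.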

13 Likes

Just as a note regarding NearestNeighbors.jl: due to the curse of dimensionality, the k-d tree becomes very inefficient for high-dimensional data (see k-d tree - Wikipedia). You want something like N >> 2^k to use it, where N is the number of points and k is the dimension.
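To put numbers on that rule of thumb, a small sketch (the helper `kd_tree_reasonable` and its safety factor of 10 are assumptions for illustration, not part of NearestNeighbors.jl):

```julia
# Rule of thumb: a k-d tree wants N >> 2^k points for dimension k.
# `factor` is an arbitrary stand-in for "much greater than".
kd_tree_reasonable(N, k; factor = 10) = N >= factor * big(2)^k

# `big(2)` avoids Int64 overflow for large k.
for k in (3, 10, 20, 100)
    println("k = ", k, ": 2^k = ", big(2)^k)
end

println(kd_tree_reasonable(10^4, 3))    # 10^4 points in 3 dimensions
println(kd_tree_reasonable(10^4, 20))   # 10^4 points in 20 dimensions
```

With 2^3 = 8, a few thousand points in 3D easily satisfy the condition, while 2^20 is already about a million, so the same dataset size in 20 dimensions falls far short.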

The original motivation of NearestNeighbors.jl was to run it on geometrical queries (so mostly 2D and 3D), so that is likely where it will perform best.

9 Likes

Super interested in this! Can you point me to some resources for learning how to do this? What parts of the code are you referring to?

2 Likes

Very nice work. I too have implemented ANN algorithms in Julia (unreleased so far) which perform comparably to C++ implementations in vastly fewer lines of code.

Have you seen http://ann-benchmarks.com/? It would be good to get this in there for a more comprehensive comparison.

1 Like

For example, in the function traverse_tree you could probably preallocate the mask array and the _get_split array, which would increase performance a little bit.
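A minimal sketch of what preallocation looks like in general. The function names and the splitting logic below are hypothetical and are not Shrike.jl's actual internals; the only point is that the mask buffer is created once and then overwritten in place:

```julia
# Hypothetical sketch; NOT the real traverse_tree or _get_split.

# Allocating version: builds a fresh mask array on every call.
count_left_alloc(proj, split) = count(proj .< split)

# Preallocated version: the caller owns `mask` and reuses it across calls.
function count_left!(mask::AbstractVector{Bool}, proj, split)
    mask .= proj .< split      # fused broadcast writes into `mask` in place
    return count(mask)
end

# The buffer is allocated once, outside the loop over split values.
function total_left(proj, splits)
    mask = Vector{Bool}(undef, length(proj))
    total = 0
    for s in splits
        total += count_left!(mask, proj, s)
    end
    return total
end

proj = rand(1_000)
println(total_left(proj, 0.1:0.1:0.9))
```

The trailing `!` follows the Julia convention for functions that mutate an argument, which makes it easy to spot where buffers are being reused.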

The way I see it, it depends on what you need. If you need the exact k nearest neighbors, then stick with NearestNeighbors.jl. If 80% of the exact nearest neighbors (with the remaining 20% being points that are close) is good enough, then you can get a speed boost by using an approximate nearest neighbors algorithm.
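The "80%" figure is a recall: the fraction of the exact neighbors that the approximate search also returns. A small sketch of how one might compute it, assuming both results are stored as index matrices with one column per query point (that layout is an assumption, not what either package necessarily returns):

```julia
# Fraction of the exact k nearest neighbors that the approximate search found.
# Assumed layout: column j holds the neighbor indices of query point j.
function recall(approx::AbstractMatrix{<:Integer}, exact::AbstractMatrix{<:Integer})
    size(approx) == size(exact) ||
        throw(DimensionMismatch("index matrices differ in size"))
    hits = 0
    for j in axes(exact, 2)
        # Order within a column does not matter, only set membership.
        hits += length(intersect(view(approx, :, j), view(exact, :, j)))
    end
    return hits / length(exact)
end

exact  = [1 4; 2 5; 3 6; 7 8; 9 10]    # 5 neighbors for each of 2 queries
approx = [1 4; 2 5; 3 6; 7 8; 11 12]   # one miss per query

println(recall(approx, exact))          # 8 of the 10 indices match
```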

In this paper, the authors construct a nearest neighbor graph by drawing edges between a point and its nearest neighbors. Because the graph just gives a topology to the data, and it doesn’t need to be exact, they use an approximate nearest neighbors search to speed things up.
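For anyone unfamiliar with the construction, here is a rough sketch of a k-nearest-neighbor graph. It uses an exact brute-force search purely for illustration (it is not the paper's method and not Shrike.jl's API; `knn_graph` is a hypothetical name). An approximate search would replace the inner distance scan:

```julia
# Brute-force kNN graph over the columns of X. Requires k <= size(X, 2) - 1.
# Returns directed edges (i, j), meaning "j is one of i's k nearest neighbors".
function knn_graph(X::AbstractMatrix{<:Real}, k::Integer)
    n = size(X, 2)
    edges = Tuple{Int,Int}[]
    dists = Vector{Float64}(undef, n)     # reused for every point
    for i in 1:n
        for j in 1:n
            d = 0.0
            for r in axes(X, 1)
                d += abs2(X[r, i] - X[r, j])
            end
            dists[j] = i == j ? Inf : d   # exclude the point itself
        end
        # Indices of the k smallest distances, without a full sort.
        for j in partialsortperm(dists, 1:k)
            push!(edges, (i, j))
        end
    end
    return edges
end

X = [0.0 1.0 10.0 11.0]    # four points on a line, stored as columns
println(knn_graph(X, 1))
```

The brute-force version costs O(n^2) distance evaluations, which is exactly the cost an approximate search is meant to avoid on large datasets.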

NearestNeighborDescent.jl is also a pure Julia package for approximate nearest neighbors; might be worth adding it to your benchmarks.

4 Likes

Other than thinking about ANN vs exact, I think it is also quite important to benchmark the two packages on your specific data set. As an example using 3-dimensional points (and please correct me if I messed up some of the hyperparameters for Shrike.jl; I used those I found in the README):

julia> dim = 3; n_points = 10^4; X = rand(dim, n_points); k = 10;

julia> using Shrike;

julia> @time shi = ShrikeIndex(X; depth=6, ntrees=5);
  0.199808 seconds (3.00 k allocations: 11.166 MiB)

julia> @time nn = allknn(shi, k; vote_cutoff=1, ne_iters=0);
 13.058128 seconds (379.93 k allocations: 358.101 MiB, 76.59% gc time)

julia> using NearestNeighbors;

julia> @time kd_tree = KDTree(X);
  0.006030 seconds (27 allocations: 657.781 KiB)

julia> @time nn2 = knn(kd_tree, X, k);
  0.017842 seconds (50.01 k allocations: 5.112 MiB)

8 Likes

Read the performance tips in the manual. Another productive thing to do is to post a small self-contained snippet of the most performance-critical section of your code, along with sample inputs, and ask for advice on speeding it up.

4 Likes