BK tree (graph) search algorithm

I’m attempting to implement BK trees with LightGraphs and SimpleWeightedGraphs, and I’m struggling to write a function that performs the search (my attempts so far don’t work properly). I have a simple function that builds the tree quite efficiently:

using LightGraphs
using SimpleWeightedGraphs
using StringDistances

dictionary = ["air","aid","army","adopt","allusion","ally","alter","amend","assay"]
N = length(dictionary)

d(a,b) = evaluate(Levenshtein(), a, b)

g = SimpleWeightedDiGraph(N)

# Establish root node and first child
add_edge!(g, 1, 2, d(dictionary[1], dictionary[2]))

function add_leaf!(g, p, c)
    while ne(g) < N - 1
        dist = d(dictionary[p], dictionary[c])
        if !in(dist, g.weights[:, p])
            add_edge!(g, p, c, dist)
        else
            add_leaf!(g, findfirst(x -> x == dist, g.weights[:, p])[1], c)
        end
        p = 1
        c += 1
    end
end

add_leaf!(g, 1, 3)

This results in the following graph/tree:

g

To query this tree with the target word “aide” and a tolerance of 2 (meaning I only want words with edit distance <= 2), my understanding (based on this) is that it should go something like this:

  1. Start with the root node (“air” in this case). If d(air, aide) <= 2, add “air” to the list of matched words to be returned.
  2. Collect all nodes that are connected to “air” where d(air, aide) - 2 <= d(air, node_i) <= d(air, aide) + 2.
  3. Visit each node and start the process over: compare distance between target and current node, add to matched words, collect connected nodes that meet above criteria, etc.
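The window in step 2 follows from the triangle inequality. As a short derivation (the symbols are introduced here for illustration and are not from the original post): let q be the target, x the current node, t the tolerance, and y any word in the subtree hanging off an edge of weight w, so that d(x, y) = w by the way the tree is built.

```latex
\begin{align*}
d(q, y) &\ge d(x, y) - d(q, x) = w - d(q, x), \\
d(q, y) &\ge d(q, x) - d(x, y) = d(q, x) - w, \\
d(q, y) \le t \;&\Longrightarrow\; |d(q, x) - w| \le t
  \;\Longleftrightarrow\; d(q, x) - t \le w \le d(q, x) + t .
\end{align*}
```

So a subtree whose edge weight falls outside the window cannot contain a match and can be skipped. For the root here, d(air, aide) = 2, so with a tolerance of 2 only edges with weight between 0 and 4 need to be followed.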

Recursive functions make my brain hurt more than anything else in programming, so I’d really appreciate any assistance with this :smile:. Here are a few things I’ve put together that I think could/should be used in the solution:

For #1 above, it’s easy:

dist = d(dictionary[i], target)
dist <= tol && push!(matches, dictionary[i])

To collect all nodes that meet the criteria in #2:

nodes = [e.dst for e in edges(g) if e.src == i && dist - tol <= e.weight <= dist + tol]

The hard bit is figuring out how to walk down a branch of the tree like this until you reach the end, then walk back up and keep track of everything…here’s a trainwreck of a function that is basically the point I got to before tapping out and writing this post :rofl:

function query_tree(g, target, tol, i, matches, checked=[1], leftover=[])
    length(checked) > 1 && length(leftover) == 0 && return matches
    dist = d(dictionary[i], target)
    dist <= tol && push!(matches, i)
    nodes = vcat([e.dst for e in edges(g) if e.src == i && dist - tol <= e.weight <= dist + tol], leftover)
    push!(checked, setdiff(outneighbors(g, i), nodes)...)

    for node in nodes
        new_dist = d(dictionary[node], target)
        new_dist <= tol && push!(matches, node)
        length(outneighbors(g, node)) == 0 && push!(checked, node)
    end

    new_nodes = setdiff(nodes, checked)
    query_tree(g, target, tol, new_nodes[1], matches, checked, new_nodes[2:end])
end

query_tree(g, target, tol, 1, [])
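For comparison, here is a minimal sketch of how the recursion could look if each call only handles one node and lets the call stack do the "walk back up" bookkeeping, so no `checked` or `leftover` lists are needed. The names `insert_word!` and `bk_query` are hypothetical, and the lookup `g.weights[child, i]` assumes the same internal layout (rows are destinations, columns are sources) that `g.weights[:, p]` in the build code relies on. It also assumes the dictionary has no duplicate words, since a distance of 0 cannot be stored as an edge weight.

```julia
using LightGraphs
using SimpleWeightedGraphs
using StringDistances

dictionary = ["air","aid","army","adopt","allusion","ally","alter","amend","assay"]
d(a, b) = evaluate(Levenshtein(), a, b)

# Hypothetical iterative insert: start at the root and descend until the
# current node has no child at the required distance.
function insert_word!(g, words, c)
    p = 1
    while true
        dist = d(words[p], words[c])
        # Children of p that already occupy the slot for this distance
        hit = [j for j in outneighbors(g, p) if g.weights[j, p] == dist]
        if isempty(hit)
            add_edge!(g, p, c, dist)   # free slot: attach c under p
            return g
        end
        p = first(hit)                 # slot taken: descend into that child
    end
end

g = SimpleWeightedDiGraph(length(dictionary))
for c in 2:length(dictionary)
    insert_word!(g, dictionary, c)
end

# Hypothetical recursive query: one call per visited node.
function bk_query(g, words, target, tol, i=1, matches=String[])
    dist = d(words[i], target)
    dist <= tol && push!(matches, words[i])            # step 1
    for child in outneighbors(g, i)
        w = g.weights[child, i]                        # weight of edge i -> child
        if dist - tol <= w <= dist + tol               # step 2
            bk_query(g, words, target, tol, child, matches)   # step 3
        end
    end
    return matches
end

matches = bk_query(g, dictionary, "aide", 2)
```

Looking only at `outneighbors(g, i)` avoids scanning every edge of the graph at each node, which the `edges(g)` comprehension does. With this dictionary, only “air” and “aid” are within distance 2 of “aide”.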

What would be really nice would be to come up with something that can run down the different branches of the tree in parallel…

This seems to work, but it is also an atrocity and painfully slow:

function continue_search(g, target, tol, matches, ret_points)
    length(ret_points) == 0 && return matches

    # Collect the children to visit next. Each return point needs its own
    # distance to the target for the pruning window.
    nodes = Int[]
    for i in ret_points
        dist = d(dictionary[i], target)
        append!(nodes, [e.dst for e in edges(g) if e.src == i && dist - tol <= e.weight <= dist + tol])
    end

    # Compare every collected node against the target
    for node in nodes
        dist = d(dictionary[node], target)
        dist <= tol && push!(matches, dictionary[node])
    end

    # Only nodes that have children need another pass
    ret_points = [node for node in nodes if length(outneighbors(g, node)) > 0]

    continue_search(g, target, tol, matches, ret_points)
end

function init_search(g, target, tol)
    matches = []

    dist = d(dictionary[1], target)

    dist <= tol && push!(matches, dictionary[1])

    nodes = [e.dst for e in edges(g) if e.src == 1 && dist - tol <= e.weight <= dist + tol]

    for node in nodes
        dist = d(dictionary[node], target)
        dist <= tol && push!(matches, dictionary[node])
    end

    ret_points = [node for node in nodes if length(outneighbors(g, node)) > 0]

    continue_search(g, target, tol, matches, ret_points)
end


I’m thinking a good approach might be to write something that first identifies which nodes qualify for comparison, store their indices, and then have another function that goes back and does the comparisons in parallel…
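One way that idea could be sketched is level by level: the nodes that qualify at the current depth are stored as a list of indices, their distances to the target are computed in parallel, and the children to visit next are then collected sequentially. The phases cannot be separated completely, because whether a child qualifies depends on its parent's distance to the target. To keep the block runnable without packages, it swaps the weighted graph for a plain `Dict` of children and StringDistances for a hand-written edit distance. All names here (`edit_distance`, `build_children`, `level_distances`, `frontier_search`) are hypothetical.

```julia
using Base.Threads: @threads

# Two-row dynamic-programming edit distance (stand-in for StringDistances).
function edit_distance(a, b)
    s, t = collect(a), collect(b)
    prev = collect(0:length(t))
    for i in eachindex(s)
        curr = similar(prev)
        curr[1] = i
        for j in eachindex(t)
            curr[j + 1] = min(prev[j + 1] + 1, curr[j] + 1, prev[j] + (s[i] != t[j]))
        end
        prev = curr
    end
    return prev[end]
end

# children[p][w] == c means there is an edge p -> c with weight w.
function build_children(words)
    children = Dict{Int,Dict{Int,Int}}()
    for c in 2:length(words)
        p = 1
        while true
            w = edit_distance(words[p], words[c])
            w == 0 && break                      # duplicate word: skip it
            slots = get!(children, p, Dict{Int,Int}())
            if haskey(slots, w)
                p = slots[w]                     # slot taken: descend
            else
                slots[w] = c                     # free slot: attach here
                break
            end
        end
    end
    return children
end

# Phase 1: distances for one level; each iteration writes its own slot.
function level_distances(words, frontier, target)
    dists = Vector{Int}(undef, length(frontier))
    @threads for k in eachindex(frontier)
        dists[k] = edit_distance(words[frontier[k]], target)
    end
    return dists
end

function frontier_search(children, words, target, tol)
    matches = String[]
    frontier = [1]                               # start at the root
    while !isempty(frontier)
        dists = level_distances(words, frontier, target)
        # Phase 2: record matches and collect the next level sequentially.
        next = Int[]
        for (k, i) in enumerate(frontier)
            dists[k] <= tol && push!(matches, words[i])
            haskey(children, i) || continue
            for (w, child) in children[i]
                dists[k] - tol <= w <= dists[k] + tol && push!(next, child)
            end
        end
        frontier = next
    end
    return matches
end

words = ["air","aid","army","adopt","allusion","ally","alter","amend","assay"]
children = build_children(words)
matches = frontier_search(children, words, "aide", 2)
```

Julia has to be started with more than one thread (for example `julia -t 4`) for the `@threads` loop to run in parallel. For a dictionary this small, and for short words in general, the threading overhead will probably outweigh any gain, so this would need benchmarking on a realistic word list.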
