Thanks – that example was useful.
I was able to figure out how to do things based on the knn / inrange search functions. Here’s my current code to implement the d3-force Barnes-Hut repulsive force computation based on the KDTree.
Sorry, it’s super messy at the moment, but basically the algorithm works in two phases. In the first phase, we compute the center of each non-leaf node (and its associated weight). In the second phase, we check whether the current rect/hyperrect from the KDTree is ‘far enough away’ ($\mathrm{diam}^2 < \theta^2 d^2$, where $d$ is the distance from the point to the node's center). If so, we use the approximation. If not, we recurse into each subtree.
If we get to a leaf, we compute the forces from its points directly (no approximation).
At the moment, there must be a few things causing the allocator to get involved, so the timing isn’t great (1 second for 100k points). I see no reason why this shouldn’t get down to a few milliseconds, since the original implementation computes it in a few ms – but that’s an optimization step.
For now, I can make headway with what I need. But if you want to think about an API for walking these trees, I can try to write some functions as I clean this mess up.
using NearestNeighbors
using StableRNGs
using GeometryBasics
using StaticArrays
using LinearAlgebra
# Reproducible random 2-D points (fixed StableRNG seed), indexed by a KD-tree
# built with leafsize = 10.
pts = rand(StableRNG(1), Point2f, 10_000)
T = KDTree(pts, leafsize=10)
##
"""
    build_centers(T::KDTree, pts; strength = -30)

Compute Barnes–Hut summaries for every internal node of `T`.

Returns `(centers, weights)`, both indexed by internal-node index `1:length(T.nodes)`:
- `weights[i]` is the summed weight of all points below node `i`, where every point
  carries weight `strength` (converted to `Float32`);
- `centers[i]` is the `|weight|`-weighted centroid of those points, stored as `Point2f`.

Leaves get no entry; `applyforces` evaluates leaf points directly. `pts` is only used
to choose the accumulator type; point coordinates are read from `T.data`.
"""
function build_centers(T::KDTree, pts; strength = -30)
    n = length(T.nodes)
    centers = Vector{Point2f}(undef, n)
    weights = Vector{Float32}(undef, n)
    pointweight = convert(eltype(weights), strength)
    # Float64 zero with the shape of a point: sums are accumulated unrounded.
    acc0 = 0.0 .* first(pts)
    _build_centers!(centers, weights, T, n, 1, acc0, pointweight)
    return centers, weights
end

# Post-order walk of the subtree rooted at `idx`: children are summarized first, then
# combined into the parent, which is written to `centers`/`weights` (internal nodes
# only). Returns the unrounded `(center, weight)` of the subtree so that both branches
# return the same type. This is a top-level function rather than a closure because a
# self-recursive closure captures its own (boxed) binding, which defeats inference.
function _build_centers!(centers, weights, T, n, idx, acc0, pointweight)
    if NearestNeighbors.isleaf(n, idx)
        center = acc0
        npts = 0
        for k in NearestNeighbors.get_leaf_range(T.tree_data, idx)
            npts += 1
            # With a reordered tree, T.data is stored in leaf order; otherwise map
            # through T.indices to the original position.
            center = center .+ T.data[T.reordered ? k : T.indices[k]]
        end
        # An empty leaf would yield 0/0 = NaN and poison every ancestor's center.
        npts == 0 && return (acc0, zero(pointweight))
        return (center ./ npts, npts * pointweight)
    end
    lcenter, lweight = _build_centers!(centers, weights, T, n,
                                       NearestNeighbors.getleft(idx), acc0, pointweight)
    rcenter, rweight = _build_centers!(centers, weights, T, n,
                                       NearestNeighbors.getright(idx), acc0, pointweight)
    wsum = abs(lweight) + abs(rweight)
    # |weight| acts as the mass; fall back to the midpoint if both children are massless.
    center = iszero(wsum) ? (lcenter .+ rcenter) ./ 2 :
             (abs(lweight) .* lcenter .+ abs(rweight) .* rcenter) ./ wsum
    centers[idx] = center
    weights[idx] = lweight + rweight
    return (center, weights[idx])
end
# Per-internal-node centroids and total weights used by the Barnes–Hut pass.
(centers, weights) = build_centers(T, pts)
## Visualize the centers to check...
"""
    plot_rects!(ax, T; level=2, index=1, hyper_rec=T.hyper_rec, expand=0.1,
                nodecenters=centers)

Draw the KD-tree `T` into `ax`, descending `level` levels from node `index`, whose
bounding box is `hyper_rec`.

- A leaf reached before `level` runs out is drawn as its bounding box plus its points,
  in the same color.
- A node at depth `level` is drawn as its bounding box plus `nodecenters[index]` as a
  large circle.

`nodecenters` defaults to the global `centers` from `build_centers`; pass it explicitly
when plotting a different tree. `expand` is accepted for compatibility but unused.
NOTE(review): relies on a Makie backend being loaded (`lines!`, `scatter!` are not
imported by this script).
"""
function plot_rects!(ax, T; level=2, index=1, hyper_rec=T.hyper_rec, expand=0.1,
                     nodecenters=centers)
    box = Rect2(hyper_rec.mins, hyper_rec.maxes - hyper_rec.mins)
    if NearestNeighbors.isleaf(length(T.nodes), index)
        line = lines!(ax, box)
        # One scatter call per leaf instead of one per point.
        leafpts = [T.data[T.reordered ? k : T.indices[k]]
                   for k in NearestNeighbors.get_leaf_range(T.tree_data, index)]
        scatter!(ax, leafpts, color=line.color)
    elseif level == 0
        line = lines!(ax, box)
        scatter!(ax, [nodecenters[index]], color=line.color, marker=:circle, markersize=15)
    else
        # Split the box at the node's hyperplane and recurse into both halves.
        node = T.nodes[index]
        split_val, split_dim = node.split_val, node.split_dim
        rect_left = NearestNeighbors.HyperRectangle(hyper_rec.mins,
            @inbounds(setindex(hyper_rec.maxes, split_val, split_dim)))
        rect_right = NearestNeighbors.HyperRectangle(
            @inbounds(setindex(hyper_rec.mins, split_val, split_dim)), hyper_rec.maxes)
        plot_rects!(ax, T; level=level - 1, index=NearestNeighbors.getleft(index),
                    hyper_rec=rect_left, expand=expand - 0.02, nodecenters=nodecenters)
        plot_rects!(ax, T; level=level - 1, index=NearestNeighbors.getright(index),
                    hyper_rec=rect_right, expand=expand - 0.02, nodecenters=nodecenters)
    end
    return nothing
end
# Sanity plot: every point drawn with an 'O' glyph, tree drawn two levels deep.
f = scatter(pts; marker='O', markersize=15, color=:black)
f.axis |> hidedecorations!
f.axis |> hidespines!
plot_rects!(f.axis, T; level=2)
f
##
# Output buffer for the Barnes–Hut velocities.
vel = similar(pts)
# Force exerted on `pt1` by a single point `pt2` (d3-force many-body rule):
# `d * strength * alpha / |d|²` with `d = pt2 - pt1`. For very close pairs, |d|² is
# softened to sqrt(min_distance2 * |d|²), and pairs at or beyond max_distance2 contribute
# nothing. Coincident points give 0 * Inf = NaN (d3 would "jiggle" them; not done here).
function _bh_pairforce(pt1, pt2, p)
    d = pt2 .- pt1
    d2 = dot(d, d)
    d2 < p.max_distance2 || return 0.0 .* pt1
    if d2 < p.min_distance2
        d2 = sqrt(p.min_distance2 * d2)
    end
    return d .* (p.strength * p.alpha / d2)
end

# Barnes–Hut force on point `target` (original index, located at `targetpt`) from all
# points under tree node `treeindex`, whose bounding box is `rect`. Written as a
# top-level function: a self-recursive closure captures its own (boxed) binding,
# which defeats inference and allocates on every call.
function _bh_nodeforce(T, centers, weights, p, target, targetpt, treeindex, rect)
    f = 0.0 .* targetpt
    if NearestNeighbors.isleaf(length(T.nodes), treeindex)
        # Leaf: sum exact pairwise forces, skipping the target itself.
        for k in NearestNeighbors.get_leaf_range(T.tree_data, treeindex)
            ptsidx = T.indices[k]    # original index of this point
            ptsidx == target && continue
            pt = T.data[T.reordered ? k : ptsidx]
            f = f .+ _bh_pairforce(targetpt, pt, p)
        end
        return f
    end
    d = centers[treeindex] .- targetpt
    d2 = dot(d, d)
    diam = maximum(rect.maxes .- rect.mins)
    if diam * diam / p.theta2 < d2
        # Far enough away: treat the whole subtree as one pseudo-particle of weight
        # weights[treeindex] at centers[treeindex]. The max_distance2 cutoff applies
        # here too, exactly as for individual points (as in d3-force).
        if d2 < p.max_distance2
            if d2 < p.min_distance2
                d2 = sqrt(p.min_distance2 * d2)
            end
            f = f .+ d .* (weights[treeindex] * p.alpha / d2)
        end
        return f
    end
    # Too close to approximate: split the box at the node's hyperplane and recurse.
    node = T.nodes[treeindex]
    split_val, split_dim = node.split_val, node.split_dim
    rect_left = NearestNeighbors.HyperRectangle(rect.mins,
        @inbounds(setindex(rect.maxes, split_val, split_dim)))
    rect_right = NearestNeighbors.HyperRectangle(
        @inbounds(setindex(rect.mins, split_val, split_dim)), rect.maxes)
    f = f .+ _bh_nodeforce(T, centers, weights, p, target, targetpt,
                           NearestNeighbors.getleft(treeindex), rect_left)
    f = f .+ _bh_nodeforce(T, centers, weights, p, target, targetpt,
                           NearestNeighbors.getright(treeindex), rect_right)
    return f
end

"""
    applyforces(T::KDTree, pts, vel, centers, weights; theta2=0.81, strength=-30.0,
                min_distance2=1.0, max_distance2=Inf, alpha=1.0)

Write into `vel[i]` the Barnes–Hut approximation of the many-body force on `pts[i]`,
using the node summaries `(centers, weights)` from `build_centers(T, pts)`.

A node is approximated by its center when `diam² / theta2 < dist²` (`diam` is the
largest side of its bounding box). Leaf points are summed exactly with per-point weight
`strength`. Pairs or pseudo-particles at squared distance `>= max_distance2` contribute
nothing. Returns `nothing`.
"""
function applyforces(T::KDTree, pts, vel, centers, weights;
        theta2=0.81, strength=-30.0,
        min_distance2 = 1.0,
        max_distance2 = Inf,
        alpha=1.0)
    # Bundle the parameters in an isbits NamedTuple so they can be passed to the helpers.
    p = (theta2 = theta2, strength = strength, min_distance2 = min_distance2,
         max_distance2 = max_distance2, alpha = alpha)
    for i in eachindex(T.data)
        vel[i] = _bh_nodeforce(T, centers, weights, p, i, pts[i], 1, T.hyper_rec)
    end
    return nothing
end
# Time three runs; the first one includes compilation.
for _ in 1:3
    @time applyforces(T, pts, vel, centers, weights)
end
## check forces
# Output buffer for the brute-force reference velocities.
vel2 = similar(pts)
"""
    simpleforces(pts, vel2; strength=-30.0, min_distance2=1.0, max_distance2=Inf,
                 alpha=1.0)

Brute-force O(n²) reference for the many-body force: for each `i`, store in `vel2[i]`
the sum over all `j != i` (in index order) of `d * strength * alpha / |d|²`, where
`d = pts[j] - pts[i]`. |d|² is softened to sqrt(min_distance2 * |d|²) when below
`min_distance2`, and pairs with |d|² >= `max_distance2` contribute zero. Returns `nothing`.
"""
function simpleforces(pts, vel2;
        strength=-30.0,
        min_distance2 = 1.0,
        max_distance2 = Inf,
        alpha=1.0)
    # Contribution of `other` to the force acting on `src`.
    function repel(src, other)
        delta = other .- src
        r2 = dot(delta, delta)
        r2 < max_distance2 || return 0.0 .* src
        r2 < min_distance2 && (r2 = sqrt(min_distance2 * r2))
        return delta .* (strength * alpha / r2)
    end
    for i in eachindex(pts)
        here = pts[i]
        others = (j for j in eachindex(pts) if j != i)
        vel2[i] = foldl((acc, j) -> acc .+ repel(here, pts[j]), others;
                        init = 0.0 .* here)
    end
    return nothing
end
# Time three brute-force runs; the first one includes compilation.
for _ in 1:3
    @time simpleforces(pts, vel2)
end
# Aggregate check over all points at once:
# norm(vel - vel2) <= 0.1 * max(norm(vel), norm(vel2)).
isapprox(vel, vel2; rtol=0.1)