I’ve been working on implementing the k nearest neighbors CUDA kernel from https://github.com/facebookresearch/pytorch3d/tree/master/pytorch3d/csrc/knn in Julia. I found KernelAbstractions.jl allows the code to be simplified, however I’m finding it difficult to know exactly what I can and can’t do inside my kernel.

I’ve put the contents of 3 files in code blocks below. The first, `knn_cpu.jl`, implements the pytorch3d knn kernel using KernelAbstractions.jl; however, it only works on the CPU because of problems with dynamic function calls. The second, `MinK.jl`, is a data structure for keeping track of the minimum k values for the knn kernel (this is for the CPU implementation). The third, `knn_gpu.jl`, is the same implementation, but with the entire `MinK` data structure implemented inside the kernel. This was the only way I could get it to work on the GPU (Nvidia GTX 1080 Ti).

For various reasons the CPU implementation is preferable (more readable, more modular, etc.). My question is: what are the minimal changes I need to make to `knn_cpu.jl` and `MinK.jl` to make them work on the GPU?

My more general question is: how do I know what I can use inside my GPU kernel? In my testing, I could not get any external function that I wrote to work inside the kernel. My understanding is that functions used inside the kernel must be inlined, but this also didn’t work for me.

Additionally, how would you recommend that I grow my knowledge of the topic? Is the Julia ecosystem mature enough to stay entirely within Julia, or will I basically need to learn CUDA C first and then apply that knowledge to Julia?

I’m new to Julia – all critiques are welcome.

`knn_cpu.jl`

```
using KernelAbstractions, Test
include("MinK.jl")
# Brute-force K-nearest-neighbours kernel, ported from the pytorch3d knn kernel.
#
# Launch with `ndrange = (P1, N)`: each work item owns query point `p1` of batch
# element `n`. It computes the squared Euclidean distance from `points1[:, p1, n]`
# to every `points2[:, p2, n]` with `p2 in 1:lengths2[n]`, and keeps the `K`
# smallest (distance, index) pairs in `dists[:, p1, n]` / `idxs[:, p1, n]` through
# a `MinK` accumulator. The kept pairs are not sorted. If `lengths2[n] < K`, the
# trailing output slots keep their initial values.
#
# `lengths1`/`lengths2` describe ragged batches: query points with
# `p1 > lengths1[n]` are padding and are skipped, which leaves their outputs untouched.
@kernel function knn_kernel(
    @Const(points1),  # (D, P1, N) query points
    @Const(points2),  # (D, P2, N) reference points
    @Const(lengths1), # (N,) number of valid query points per batch element
    @Const(lengths2), # (N,) number of valid reference points per batch element
    @Const(D),        # point dimensionality
    @Const(K),        # number of neighbours to keep
    dists,            # (K, P1, N) output: squared distances
    idxs              # (K, P1, N) output: 1-based indices into points2
)
    p1, n = @index(Global, NTuple)
    # Wrap the body in `if` rather than `return` early: on the CPU, KernelAbstractions
    # runs several work items inside one loop, so `return` would skip them too.
    if p1 <= lengths1[n]
        @views mink = MinK(dists[:, p1, n], idxs[:, p1, n], K)
        for p2 in 1:lengths2[n]
            dist = zero(eltype(dists))
            for d in 1:D
                dist += abs2(points1[d, p1, n] - points2[d, p2, n])
            end
            # `push!` requires exactly (eltype(dists), eltype(idxs)). Convert the
            # loop index so that a `lengths2` with a different integer eltype still works.
            push!(mink, (dist, convert(eltype(idxs), p2)))
        end
    end
    nothing
end
"""
Really simple nearest neighbor test
Check that the nearest neighbor to p1=1 in p2 = [-2, -1, 0, 1, 2]
is p1=1 at index 4 with distance 0
"""
function test_knn_kernel()
D, P1, N = 1, 1, 1
K, P2 = 1, 5
p1 = reshape([1f0], (D, P1, N))
p2 = reshape([-2f0, -1f0, 0f0, 1f0, 2f0], (D, P2, N))
lengths1 = [P1]
lengths2 = [P2]
dists = zeros(Float32, (K, P1, N))
idxs = zeros(Int64, (K, P1, N))
kernel! = knn_kernel(CPU(), 1)
event = kernel!(p1, p2, lengths1, lengths2, D, K, dists, idxs, ndrange=(P1, N))
wait(event)
@test dists[1, 1, 1] == 0f0
@test idxs[1, 1, 1] == 4
@test p2[1, 4, 1] == 1f0
end
test_knn_kernel()
"""
    main()

Run `knn_kernel` on the CPU for a small random batch: `N = 2` batch elements,
`D = 1`, `P1 = 2` query points and `P2 = 5` reference points per element, `K = 3`.
Return `(dists, idxs)`, both of size `(K, P1, N)`, so the result can be inspected.

NOTE(review): `MinK.jl`, included at the top of this file, also defines a
zero-argument `main()`, and this definition overwrites it.
"""
function main()
    N = 2
    D = 1
    P1 = 2
    P2 = 5
    K = 3
    p1 = randn(Float32, (D, P1, N))
    p2 = randn(Float32, (D, P2, N))
    lengths1 = fill(P1, (N,))
    lengths2 = fill(P2, (N,))
    dists = zeros(Float32, (K, P1, N))
    idxs = zeros(Int64, (K, P1, N))
    kernel! = knn_kernel(CPU(), 4)
    event = kernel!(p1, p2, lengths1, lengths2, D, K, dists, idxs, ndrange=(P1, N))
    wait(event)
    return dists, idxs
end
```

`MinK.jl`

```
import Base: getindex, setindex!, push!, sort!
using Base: findmax
"""
Adapted from:
https://github.com/facebookresearch/pytorch3d/blob/327bd2b9762c05ec7a6f74c2ec1e46f2a764e326/pytorch3d/csrc/utils/mink.cuh
"""
mutable struct MinK{kT,vT,iT}
keys::AbstractVector{kT} # i.e. distances
vals::AbstractVector{vT} # i.e. indices
size::iT
max_key::kT
max_idx::iT
end
"""
    MinK(keys, vals, K)

Create an empty `MinK` backed by `keys` and `vals`. Only the *type* of `K` is used:
it fixes the integer type of the `size` and `max_idx` fields. Its value is not
stored, because the capacity comes from `length(keys)`.
"""
@inline function MinK(
    keys::AbstractVector{kT}, vals::AbstractVector{vT}, K::iT
) where {kT,vT,iT}
    empty_count = zero(iT)
    return MinK{kT,vT,iT}(keys, vals, empty_count, zero(kT), empty_count)
end
"""
    mink[i] -> (key, val)

Return the `(key, val)` pair stored in slot `i`.
"""
@inline function getindex(mink::MinK, i)
    key = mink.keys[i]
    val = mink.vals[i]
    return (key, val)
end
"""
    mink[i] = (key, val)

Store `key` in `mink.keys[i]` and `val` in `mink.vals[i]`. The pair must have
exactly the key and value types of `mink`. Returns the stored pair.
"""
@inline function setindex!(mink::MinK{kT,vT}, pair::Tuple{kT,vT}, i) where {kT,vT}
    key, val = pair
    mink.keys[i] = key
    mink.vals[i] = val
    return pair
end
"""
    push!(mink::MinK, (key, val)) -> mink

Offer `(key, val)` to `mink`, which keeps only the `length(mink.keys)` smallest keys.

- While `mink` is not full, the pair is stored in the next free slot (O(1)).
- Once `mink` is full, the pair replaces the current largest key only if `key` is
  strictly smaller. The new largest key is then found with an O(K) scan.
- With zero capacity, the pair is discarded.

The rescan uses a plain loop instead of `findmax`, to keep the method simple for
GPU kernel compilation. For NaN-free keys it picks the same slot as `findmax`
(first occurrence wins on ties).
"""
@inline function push!(mink::MinK{kT,vT}, (key, val)::Tuple{kT,vT}) where {kT,vT}
    K = length(mink.keys)
    # Zero capacity: nothing can be stored. Without this check, `max_idx == 0`
    # would be used as a slot below.
    iszero(K) && return mink
    if mink.size < K # Runtime: O(1)
        slot = mink.size + 1
        mink[slot] = (key, val)
        if iszero(mink.size) || key > mink.max_key
            mink.max_key = key
            mink.max_idx = slot
        end
        mink.size += 1
    elseif key < mink.max_key # Runtime: O(K)
        # Current key replaces old max
        mink[mink.max_idx] = (key, val)
        # Find new max from all keys (every slot is occupied in this branch)
        max_key, max_idx = mink.keys[1], 1
        for i in 2:K
            if mink.keys[i] > max_key
                max_key, max_idx = mink.keys[i], i
            end
        end
        mink.max_key = max_key
        mink.max_idx = max_idx
    end
    return mink
end
"""
Bubble sort
Runtime: O(K²)
"""
@inline function sort!(mink::MinK)
for i = 1:mink.size
for j = 1:(mink.size - i)
if mink.keys[j + 1] < mink.keys[j]
mink[j], mink[j + 1] = mink[j + 1], mink[j]
end
end
end
end
"""
    main()

Demo of `MinK`. It backs a capacity-5 `MinK` with row 1 of two 5×5 buffers,
pushes four pairs with key `10.0`, then pushes five pairs with key `-1.0`
(all with value `55`), and returns the `MinK`.
"""
function main()
    K = 5
    key_buf = Array{Float64,2}(undef, (5, K))
    val_buf = Array{Int64,2}(undef, (5, K))
    mink = MinK(view(key_buf, 1, :), view(val_buf, 1, :), K)
    for key in (10.0, 10.0, 10.0, 10.0, -1.0, -1.0, -1.0, -1.0, -1.0)
        push!(mink, (key, 55))
    end
    return mink
end
```

`knn_gpu.jl`

```
using KernelAbstractions, CUDA
using Test
# Brute-force K-nearest-neighbours kernel. It does the same computation as the
# MinK-based version, but the MinK bookkeeping is written inline.
#
# Launch with `ndrange = (P1, N)`: each work item owns query point `p1` of batch
# element `n`. It keeps the `K` smallest squared distances to
# `points2[:, 1:lengths2[n], n]` in `dists[:, p1, n]` and their 1-based indices in
# `idxs[:, p1, n]` (unsorted). Query points with `p1 > lengths1[n]` are padding
# and are skipped.
@kernel function knn_kernel(
    @Const(points1),  # (D, P1, N) query points
    @Const(points2),  # (D, P2, N) reference points
    @Const(lengths1), # (N,) number of valid query points per batch element
    @Const(lengths2), # (N,) number of valid reference points per batch element
    @Const(D),        # point dimensionality
    @Const(K),        # number of neighbours to keep
    dists,            # (K, P1, N) output: squared distances
    idxs              # (K, P1, N) output: 1-based indices into points2
)
    p1, n = @index(Global, NTuple)
    # Skip padded query points. Also skip K == 0, where there are no output slots.
    # Use `if` rather than an early `return`: on the CPU backend, several work
    # items share one loop.
    if p1 <= lengths1[n] && K > 0
        # Inline MinK state for dists[:, p1, n] / idxs[:, p1, n]
        filled = 0                    # number of occupied slots (≤ K)
        max_key = zero(eltype(dists)) # placeholder, set by the first insertion
        max_idx = 0                   # placeholder, set by the first insertion
        for p2 in 1:lengths2[n]
            dist = zero(eltype(dists))
            for d in 1:D
                dist += abs2(points1[d, p1, n] - points2[d, p2, n])
            end
            # Add (dist, p2) to the inline MinK
            if filled < K # Runtime: O(1)
                dists[filled + 1, p1, n] = dist
                idxs[filled + 1, p1, n] = p2
                if filled == 0 || dist > max_key
                    max_key = dist
                    max_idx = filled + 1
                end
                filled += 1
            elseif dist < max_key # Runtime: O(K)
                # Current key replaces old max
                dists[max_idx, p1, n] = dist
                idxs[max_idx, p1, n] = p2
                # Find new max from all dists. The slot just written satisfies `≥`,
                # so max_idx never stays at -1.
                max_key, max_idx = dist, -1
                for i in 1:K
                    if dists[i, p1, n] ≥ max_key
                        max_key, max_idx = dists[i, p1, n], i
                    end
                end
            end
        end
    end
    nothing
end
"""
Really simple nearest neighbor test
Check that the nearest neighbor to p1=1 in p2 = [-2, -1, 0, 1, 2]
is p1=1 at index 4 with distance 0
"""
function test_knn_kernel()
D, P1, N = 1, 1, 1
K, P2 = 1, 5
p1 = reshape([1f0], (D, P1, N))
p2 = reshape([-2f0, -1f0, 0f0, 1f0, 2f0], (D, P2, N))
lengths1 = [P1]
lengths2 = [P2]
dists = zeros(Float32, (K, P1, N))
idxs = zeros(Int64, (K, P1, N))
kernel! = knn_kernel(CPU(), 1, (P1, N))
event = kernel!(p1, p2, lengths1, lengths2, D, K, dists, idxs, ndrange=(P1, N))
wait(event)
@test dists[1, 1, 1] == 0f0
@test idxs[1, 1, 1] == 4
@test p2[1, 4, 1] == 1f0
end
test_knn_kernel()
"""
    main()

Run `knn_kernel` on a CUDA device for a small random batch: `N = 2` batch elements,
`D = 1`, `P1 = 2` query points and `P2 = 5` reference points per element, `K = 3`.
Return the device arrays `(dists, idxs)`, both of size `(K, P1, N)`, so the result
can be inspected (e.g. after `Array(dists)`).
"""
function main()
    N = 2
    D = 1
    P1 = 2
    P2 = 5
    K = 3
    p1 = CUDA.randn(Float32, (D, P1, N))
    p2 = CUDA.randn(Float32, (D, P2, N))
    lengths1 = CUDA.fill(P1, (N,))
    lengths2 = CUDA.fill(P2, (N,))
    dists = CUDA.zeros(Float32, (K, P1, N))
    idxs = CUDA.zeros(Int64, (K, P1, N))
    kernel! = knn_kernel(CUDADevice(), 8)
    event = kernel!(p1, p2, lengths1, lengths2, D, K, dists, idxs, ndrange=(P1, N))
    wait(event)
    return dists, idxs
end
```