Using functions in a GPU kernel (via KernelAbstractions.jl) (k-nearest-neighbors kernel)

I’ve been working on implementing the k-nearest-neighbors CUDA kernel from https://github.com/facebookresearch/pytorch3d/tree/master/pytorch3d/csrc/knn in Julia. I found that KernelAbstractions.jl allows the code to be simplified; however, I’m finding it difficult to know exactly what I can and can’t do inside my kernel.

I’ve put the contents of 3 files in code blocks below. The first, knn_cpu.jl, implements the pytorch3d knn kernel using KernelAbstractions.jl; however, it only works on the CPU due to issues with dynamic function calls. The second, MinK.jl, is a data structure for keeping track of the minimum k values for the knn kernel (this is for the CPU implementation). The third, knn_gpu.jl, is the same implementation, but with the entire MinK data structure implemented directly inside the kernel. This was the only way I could get it to work on the GPU (an Nvidia GTX 1080 Ti).

For various reasons the CPU implementation is desirable (more readable, more modular, etc.). My question is: what are the minimal changes I need to make to knn_cpu.jl and MinK.jl to make them work on the GPU?

My more general question is: how do I know what I can use inside my GPU kernel? In my testing, I could not get any external function that I wrote to work inside the kernel. My understanding is that functions used inside the kernel must be inlined, but this also didn’t work for me.
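
To make this concrete, here is a hypothetical minimal example of the pattern I mean (sqdist and helper_kernel are names made up for illustration):

using KernelAbstractions

# A trivial external helper: the kind of function I could not get to
# work inside a GPU kernel.
@inline function sqdist(points1, points2, D, p1, p2, n)
    dist = zero(eltype(points1))
    for d = 1:D
        dist += (points1[d, p1, n] - points2[d, p2, n])^2
    end
    dist
end

@kernel function helper_kernel(@Const(points1), @Const(points2), @Const(D), out)
    p1, n = @index(Global, NTuple)
    # This works with the CPU() backend; calls like this are where my GPU
    # attempts fall over.
    out[p1, n] = sqdist(points1, points2, D, p1, 1, n)
end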

Additionally, how would you recommend I grow my knowledge of this topic? Is the Julia GPU ecosystem mature enough to stay entirely within Julia, or will I basically need to learn CUDA C first and then apply that knowledge to Julia?

I’m new to Julia – all critiques are welcome.

knn_cpu.jl

using KernelAbstractions, Test
include("MinK.jl")


@kernel function knn_kernel(
        @Const(points1),
        @Const(points2),
        @Const(lengths1),
        @Const(lengths2),
        @Const(D),
        @Const(K),
        dists,
        idxs)

    p1, n = @index(Global, NTuple)

    # Track the K smallest (distance, index) pairs for point p1 in cloud n
    @views mink = MinK(dists[:, p1, n], idxs[:, p1, n], K)

    for p2 = 1:lengths2[n]
        dist = eltype(dists)(0)
        for d = 1:D
            dist += (points1[d, p1, n] - points2[d, p2, n])^2
        end
        push!(mink, (dist, p2))
    end
    nothing
end

"""
Really simple nearest neighbor test

Check that the nearest neighbor to p1=1 in p2 = [-2, -1, 0, 1, 2]
is p2=1 at index 4 with distance 0
"""
function test_knn_kernel()
    D, P1, N = 1, 1, 1
    K, P2 = 1, 5

    p1 = reshape([1f0], (D, P1, N))
    p2 = reshape([-2f0, -1f0, 0f0, 1f0, 2f0], (D, P2, N))
    lengths1 = [P1]
    lengths2 = [P2]
    dists = zeros(Float32, (K, P1, N))
    idxs = zeros(Int64, (K, P1, N))
    kernel! = knn_kernel(CPU(), 1)
    event = kernel!(p1, p2, lengths1, lengths2, D, K, dists, idxs, ndrange=(P1, N))
    wait(event)

    @test dists[1, 1, 1] == 0f0
    @test idxs[1, 1, 1] == 4
    @test p2[1, 4, 1] == 1f0
end
test_knn_kernel()

function main()
    N = 2
    D = 1
    P1 = 2
    P2 = 5
    K = 3

    p1 = randn(Float32, (D, P1, N))
    p2 = randn(Float32, (D, P2, N))
    lengths1 = fill(P1, (N,))
    lengths2 = fill(P2, (N,))


    dists = zeros(Float32, (K, P1, N))
    idxs = zeros(Int64, (K, P1, N))

    kernel! = knn_kernel(CPU(), 4)
    event = kernel!(p1, p2, lengths1, lengths2, D, K, dists, idxs, ndrange=(P1, N))
    wait(event)
end

MinK.jl

import Base: getindex, setindex!, push!, sort!
using Base: findmax


"""
Adapted from:
https://github.com/facebookresearch/pytorch3d/blob/327bd2b9762c05ec7a6f74c2ec1e46f2a764e326/pytorch3d/csrc/utils/mink.cuh
"""
mutable struct MinK{kT,vT,iT}
    keys::AbstractVector{kT}        # i.e. distances
    vals::AbstractVector{vT}        # i.e. indices
    size::iT
    max_key::kT
    max_idx::iT
end


@inline function MinK(keys::AbstractVector{kT}, vals::AbstractVector{vT}, K::iT) where {kT,vT,iT}
    MinK{kT,vT,iT}(keys, vals, iT(0), kT(0), iT(0))
end


@inline getindex(mink::MinK, i) = (mink.keys[i], mink.vals[i])
@inline function setindex!(mink::MinK{kT,vT}, (key, val)::Tuple{kT,vT}, i) where {kT,vT}
    mink.keys[i], mink.vals[i] = (key, val)
end


@inline function push!(mink::MinK{kT,vT}, (key, val)::Tuple{kT,vT}) where {kT,vT}
    K = length(mink.keys)
    if mink.size < K                        # Runtime: O(1)
        mink[mink.size + 1] = (key, val)
        if mink.size == 0 || key > mink.max_key
            mink.max_key = key
            mink.max_idx = mink.size + 1
        end
        mink.size += 1
    elseif key < mink.max_key               # Runtime: O(K)
        # Current key replaces old max
        mink[mink.max_idx] = (key, val)
        # Find new max from all keys
        mink.max_key, mink.max_idx = findmax(mink.keys)
    end
end

"""
Bubble sort
Runtime: O(K²)
"""
@inline function sort!(mink::MinK)
    for i = 1:mink.size
        for j = 1:(mink.size - i)
            if mink.keys[j + 1] < mink.keys[j]
                mink[j], mink[j + 1] = mink[j + 1], mink[j]
            end
        end
    end
end


function main()
    K = 5
    keys = Array{Float64,2}(undef, (5, K))
    vals = Array{Int64,2}(undef, (5, K))
        
    @views mink = MinK(keys[1, :], vals[1, :], K)
    push!(mink, (10.0, 55))
    push!(mink, (10.0, 55))
    push!(mink, (10.0, 55))
    push!(mink, (10.0, 55))
    push!(mink, (-1.0, 55))
    push!(mink, (-1.0, 55))
    push!(mink, (-1.0, 55))
    push!(mink, (-1.0, 55))
    push!(mink, (-1.0, 55))
    # At this point every slot holds key -1.0 and val 55
    mink
end

knn_gpu.jl

using KernelAbstractions, CUDA
using Test


@kernel function knn_kernel(
        @Const(points1),
        @Const(points2),
        @Const(lengths1),
        @Const(lengths2),
        @Const(D),
        @Const(K),
        dists,
        idxs)

    p1, n = @index(Global, NTuple)

    # MinK parameters
    size = 0                    # Size of mink (when less than K)
    max_key = dists[1, p1, n]   # Placeholder for max_key
    max_idx = idxs[1, p1, n]    # Placeholder for max_idx

    # Runs in current thread
    for p2 = 1:lengths2[n]
        dist = eltype(dists)(0)
        for d = 1:D
            dist += (points1[d, p1, n] - points2[d, p2, n])^2
        end
        # Add (dist, p2) to MinK data structure
        if size < K           # Runtime: O(1)
            dists[size + 1, p1, n] = dist
            idxs[size + 1, p1, n] = p2
            if size == 0 || dist > max_key
                max_key = dist
                max_idx = size + 1
            end
            size += 1
        elseif dist < max_key       # Runtime: O(K)
            # Current key replaces old max
            dists[max_idx, p1, n] = dist
            idxs[max_idx, p1, n] = p2
            # Find new max from all dists
            max_key, max_idx = dist, -1
            for i = 1:K
                if dists[i, p1, n] ≥ max_key
                    max_key, max_idx = dists[i, p1, n], i
                end
            end
        end
    end
end

"""
Really simple nearest neighbor test

Check that the nearest neighbor to p1=1 in p2 = [-2, -1, 0, 1, 2]
is p2=1 at index 4 with distance 0
"""
function test_knn_kernel()
    D, P1, N = 1, 1, 1
    K, P2 = 1, 5

    p1 = reshape([1f0], (D, P1, N))
    p2 = reshape([-2f0, -1f0, 0f0, 1f0, 2f0], (D, P2, N))
    lengths1 = [P1]
    lengths2 = [P2]
    dists = zeros(Float32, (K, P1, N))
    idxs = zeros(Int64, (K, P1, N))

    kernel! = knn_kernel(CPU(), 1, (P1, N))
    event = kernel!(p1, p2, lengths1, lengths2, D, K, dists, idxs, ndrange=(P1, N))
    wait(event)

    @test dists[1, 1, 1] == 0f0
    @test idxs[1, 1, 1] == 4
    @test p2[1, 4, 1] == 1f0
end
test_knn_kernel()

function main()
    N = 2
    D = 1
    P1 = 2
    P2 = 5
    K = 3

    p1 = CUDA.randn(Float32, (D, P1, N))
    p2 = CUDA.randn(Float32, (D, P2, N))
    lengths1 = CUDA.fill(P1, (N,))
    lengths2 = CUDA.fill(P2, (N,))


    dists = CUDA.zeros(Float32, (K, P1, N))
    idxs = CUDA.zeros(Int64, (K, P1, N))

    kernel! = knn_kernel(CUDADevice(), 8)
    event = kernel!(p1, p2, lengths1, lengths2, D, K, dists, idxs, ndrange=(P1, N))
    wait(event)
end

That typically indicates an inference failure. You can inspect the generated code using CUDA.jl’s @device_code_warntype interactive=true, which uses Cthulhu to render the code and highlight any badly-inferred expressions.
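
For example, assuming the launch from knn_gpu.jl above, the invocation would look something like this:

CUDA.@device_code_warntype interactive=true kernel!(
    p1, p2, lengths1, lengths2, D, K, dists, idxs, ndrange=(P1, N))

In this case the likely culprits are the abstractly-typed fields of MinK (keys::AbstractVector{kT} and vals::AbstractVector{vT}), which the compiler cannot concretely infer, together with the fact that constructing a mutable struct inside a kernel requires GC-managed allocation, which GPU code does not support. A sketch of a GPU-friendlier MinK could look like the following (untested; note that push! now returns the updated struct instead of mutating it in place, so the kernel has to rebind it with mink = push!(mink, (dist, p2))):

# Sketch of a GPU-friendly MinK. Two changes relative to MinK.jl:
#  1. The container types are type parameters, so accesses to keys/vals
#     are concretely inferred instead of going through AbstractVector{kT}.
#  2. The struct is immutable: push! still writes into the backing arrays,
#     but returns a new struct carrying the updated size/max bookkeeping.
struct MinK{KV<:AbstractVector,VV<:AbstractVector,kT,iT}
    keys::KV        # i.e. distances
    vals::VV        # i.e. indices
    size::iT
    max_key::kT
    max_idx::iT
end

@inline function MinK(keys::AbstractVector, vals::AbstractVector, K::iT) where {iT}
    MinK(keys, vals, iT(0), zero(eltype(keys)), iT(0))
end

@inline function Base.push!(mink::MinK, (key, val))
    K = length(mink.keys)
    if mink.size < K                  # Runtime: O(1), still filling up
        mink.keys[mink.size + 1] = key
        mink.vals[mink.size + 1] = val
        if mink.size == 0 || key > mink.max_key
            return MinK(mink.keys, mink.vals, mink.size + 1, key, mink.size + 1)
        end
        return MinK(mink.keys, mink.vals, mink.size + 1, mink.max_key, mink.max_idx)
    elseif key < mink.max_key         # Runtime: O(K), replace current max
        mink.keys[mink.max_idx] = key
        mink.vals[mink.max_idx] = val
        # Explicit loop instead of findmax; simpler for the GPU compiler.
        max_key, max_idx = key, mink.max_idx
        for i = 1:K
            if mink.keys[i] > max_key
                max_key, max_idx = mink.keys[i], i
            end
        end
        return MinK(mink.keys, mink.vals, mink.size, max_key, max_idx)
    end
    return mink
end

Of the remaining helpers in MinK.jl, only setindex! pins the old parameter order (MinK{kT,vT}) and would need its signature adjusted; getindex and sort! only touch the backing arrays, so they keep working on an immutable struct unchanged.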