Cumulative sum on GPUArray using KernelAbstractions

Hello,

I need to implement the cumulative sum (cumsum) on a GPU array (CUDA.jl or Metal.jl). Looking at the CUDA.jl repository, I found this definition:

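```julia
using CUDA

# Block-level inclusive prefix sum (Hillis-Steele scan), called from inside a kernel.
# Thread `threadIdx().x` owns `sums[threadIdx().x]`, so the block needs exactly one
# thread per element, and `sums` must be visible to the whole block (e.g. shared memory).
function cumsum!(sums)
    shift = 1

    while shift < length(sums)
        # Read phase: fetch the value `shift` positions to the left.
        to_add = 0
        @inbounds if threadIdx().x - shift > 0
            to_add = sums[threadIdx().x - shift]
        end

        sync_threads()   # all reads finish before anyone writes
        @inbounds if threadIdx().x - shift > 0
            sums[threadIdx().x] += to_add
        end

        sync_threads()   # all writes finish before the next round reads
        shift *= 2
    end
end
```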

which can be executed inside a CUDA kernel.
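To see what the two `sync_threads()` barriers buy, here is a minimal serial sketch in plain Julia (no GPU packages; `hillis_steele_scan` is a hypothetical name) that replays the kernel's rounds: in each round every "thread" first reads the value `shift` positions to its left, and only then are the additions written back.

```julia
# Serial simulation of the block-level Hillis-Steele inclusive scan above.
# `hillis_steele_scan` is a hypothetical helper, for illustration only.
function hillis_steele_scan(x::AbstractVector{T}) where {T}
    sums = collect(x)              # working copy, one slot per "thread"
    n = length(sums)
    shift = 1
    while shift < n
        # Read phase (before the first sync_threads): every thread reads
        # the value `shift` positions to its left, if there is one.
        to_add = [i - shift > 0 ? sums[i - shift] : zero(T) for i in 1:n]
        # Write phase (between the two sync_threads): every thread adds what it read.
        for i in 1:n
            if i - shift > 0
                sums[i] += to_add[i]
            end
        end
        shift *= 2
    end
    return sums
end

hillis_steele_scan([1, 2, 3, 4, 5])   # returns [1, 3, 6, 10, 15]
```

Updating `sums` in place during the read phase would let later indices pick up values already modified in the same round, which is the race the first `sync_threads()` prevents; the second one keeps a fast thread from starting the next round's read before everyone has written.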

Now I need to implement it using KernelAbstractions.jl, but I’m not familiar with shared memory, especially in KernelAbstractions.jl. The simplest approach I could think of was something like this:

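```julia
using KernelAbstractions

@kernel function cumsum!(sums)
    idx = @index(Global)
    shift = 1

    while shift < length(sums)
        to_add = zero(eltype(sums))
        @inbounds if idx - shift > 0
            to_add = sums[idx - shift]
        end

        # @synchronize is a barrier for the work-items of a single workgroup only,
        # so at best this can work when all of `sums` fits in one workgroup.
        @synchronize
        @inbounds if idx - shift > 0
            sums[idx] += to_add
        end

        @synchronize
        shift *= 2
    end
end
```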

But I don’t know if this is correct, or if I have to use KernelAbstractions.@localmem or something else.
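On the `@localmem` question, a sketch: one way to get a correct scan within a single workgroup is to stage the data in `@localmem` and double-buffer it, so that no per-work-item variable has to live across a `@synchronize` (on the CPU backend such variables would otherwise call for `@private`). This assumes KernelAbstractions.jl 0.9 (`get_backend`, `KernelAbstractions.synchronize`), a static workgroup size (so the `@localmem` size is known at compile time) and an array length that is a multiple of it; `block_scan!` and `block_scan` are hypothetical names. Because `@synchronize` is only a barrier within a workgroup, several workgroups each produce an independent scan of their own block.

```julia
using KernelAbstractions

# Hypothetical kernel: inclusive Hillis-Steele scan inside each workgroup.
# Assumes a static workgroup size N and length(x) a multiple of N.
@kernel function block_scan!(y, @Const(x))
    N = @uniform prod(@groupsize())                           # work-items per group
    nsteps = @uniform 8 * sizeof(Int) - leading_zeros(N - 1)  # ceil(log2(N))
    I = @index(Global, Linear)     # position in x and y
    i = @index(Local, Linear)      # position inside the workgroup
    a = @localmem eltype(y) (N,)   # current partial sums
    b = @localmem eltype(y) (N,)   # scratch buffer for the next round

    @inbounds a[i] = x[I]
    @synchronize
    for k in 0:(nsteps - 1)
        # Write the new value into b, so no work-item overwrites an entry of a
        # that another work-item still has to read in this round.
        @inbounds b[i] = i > (1 << k) ? a[i] + a[i - (1 << k)] : a[i]
        @synchronize
        @inbounds a[i] = b[i]
        @synchronize
    end
    @inbounds y[I] = a[i]
end

# Hypothetical host-side wrapper.
function block_scan(x::AbstractVector; groupsize::Int = 64)
    @assert length(x) % groupsize == 0 "length must be a multiple of groupsize"
    y = similar(x)
    backend = get_backend(x)
    block_scan!(backend, groupsize)(y, x; ndrange = length(x))  # static groupsize
    KernelAbstractions.synchronize(backend)
    return y
end

x = collect(1:64)                  # e.g. CuArray(x) or MtlArray(x) on a GPU
y = block_scan(x; groupsize = 64)  # one workgroup covers x; compare with cumsum(x)
```

With `groupsize = 8` and 16 elements, elements 9 to 16 come out as `cumsum(x[9:16])` rather than continuing from element 8, so longer arrays need a second pass that adds each block's offset.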

This is a nice article that shows how to implement prefix scan in C/CUDA.
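For reference, this is what the `cumsum!` kernels above compute, written out (Hillis-Steele doubling; $s^{(k)}$ denotes the array after round $k$, notation introduced here for illustration):

$$
y_i = \sum_{j=1}^{i} x_j, \qquad i = 1, \dots, n
$$

$$
s^{(0)}_i = x_i, \qquad
s^{(k+1)}_i =
\begin{cases}
s^{(k)}_i + s^{(k)}_{i - 2^k}, & i > 2^k,\\
s^{(k)}_i, & \text{otherwise.}
\end{cases}
$$

Round $k$ uses `shift` $= 2^k$, and the condition `idx - shift > 0` is $i > 2^k$. By induction $s^{(k)}_i = \sum_{j=\max(1,\, i-2^k+1)}^{i} x_j$, so after $K = \lceil \log_2 n \rceil$ rounds $s^{(K)}_i = y_i$. Each round performs up to $n$ additions, roughly $n \lceil \log_2 n \rceil$ in total, whereas work-efficient scans such as Blelloch's up-sweep/down-sweep need only $O(n)$; the appeal of this scheme is that it takes only $\lceil \log_2 n \rceil$ parallel steps.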

You (probably) don’t have to implement your own kernel for this: CUDA.jl and Metal.jl (as well as AMDGPU.jl) implement `Base.accumulate` for their GPU array types, so you can just use `accumulate(+, gpu_array)`.
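A minimal sketch of that suggestion. It uses plain `Array`s so it runs anywhere; with CUDA.jl, Metal.jl or AMDGPU.jl loaded, the same calls on a `CuArray`, `MtlArray` or `ROCArray` should dispatch to the GPU `accumulate` implementations mentioned above. `dims` and the in-place `accumulate!` are standard Base API; that every GPU backend supports every keyword is an assumption worth checking for your package version.

```julia
# Inclusive prefix sums with Base.accumulate. Swap in a GPU array type
# (assumption: the corresponding package is loaded) to run on the GPU.
x = Float32[1, 2, 3, 4]
s = accumulate(+, x)               # [1, 3, 6, 10]

# In-place variant, closest to a `cumsum!`-style API.
out = similar(x)
accumulate!(+, out, x)             # out == [1, 3, 6, 10]

# Scan each row of a matrix (along dims = 2) instead of down the columns.
A = Float32[1 2 3; 4 5 6]
r = accumulate(+, A; dims = 2)     # [1 3 6; 4 9 15]
```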

And if you want to port this to KA.jl: the `cumsum!` function you’re looking at is an implementation detail of the sorting kernel, not something reusable. A more general cumsum kernel can be found in `src/accumulate.jl` of JuliaGPU/CUDA.jl (at commit 972f3f0a9d594c75431df96b027588f8279ad2de). It resembles our mapreduce implementation, so you can probably draw inspiration from the WIP KA.jl mapreduce implementation in JuliaGPU/GPUArrays.jl pull request #561, "Implement mapreduce" by vchuravy.
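For arrays longer than one workgroup, a common decomposition is: scan independent chunks in parallel, scan the much shorter vector of chunk totals, then add each chunk's offset to its elements. The sketch below applies that idea in KernelAbstractions.jl with one work-item scanning each chunk serially. It is a generic illustration, not a port of CUDA.jl's `accumulate.jl`; it assumes KernelAbstractions.jl 0.9, and `chunk_scan!`, `add_offsets!`, `chunked_cumsum` and the chunk length `C` are hypothetical names. The middle step simply calls `accumulate` on the totals.

```julia
using KernelAbstractions

# Phase 1 (hypothetical kernel): work-item `c` serially scans chunk `c`
# of length `C` and stores the chunk's total in `totals[c]`.
@kernel function chunk_scan!(y, totals, @Const(x), C, n)
    c = @index(Global, Linear)
    lo = (c - 1) * C + 1
    hi = min(c * C, n)
    acc = zero(eltype(y))
    for j in lo:hi
        @inbounds acc += x[j]
        @inbounds y[j] = acc
    end
    @inbounds totals[c] = acc
end

# Phase 3 (hypothetical kernel): add the sum of all earlier chunks to each
# element; `offsets` is the inclusive scan of `totals`.
@kernel function add_offsets!(y, @Const(offsets), C)
    j = @index(Global, Linear)
    c = (j - 1) ÷ C + 1            # chunk that element j belongs to
    if c > 1
        @inbounds y[j] += offsets[c - 1]
    end
end

# Hypothetical driver for any array type KernelAbstractions supports.
function chunked_cumsum(x::AbstractVector; C::Int = 256)
    n = length(x)
    y = similar(x)
    n == 0 && return y
    nchunks = cld(n, C)
    totals = similar(x, nchunks)
    backend = get_backend(x)

    chunk_scan!(backend)(y, totals, x, C, n; ndrange = nchunks)
    KernelAbstractions.synchronize(backend)   # totals must be complete

    offsets = accumulate(+, totals)           # Phase 2: scan the chunk totals

    add_offsets!(backend)(y, offsets, C; ndrange = n)
    KernelAbstractions.synchronize(backend)
    return y
end

x = collect(1:1000)                # e.g. CuArray(x) on an NVIDIA GPU
y = chunked_cumsum(x; C = 64)
```

`C` trades serial work per work-item against the number of chunks; a tuned implementation would instead scan each chunk cooperatively within a workgroup (for example with a `@localmem` Hillis-Steele scan) and recurse on the totals rather than calling `accumulate`.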