Scatter/gather operation with Metal

Hey everyone.

I would like to convert the following scatter/gather operation from the CPU to the GPU, using Metal.jl.

function gather!(y::AbstractVector, refs::AbstractVector, x::AbstractVector)
   # Loop over the inputs: refs[i] is the slot of y that x[i] is added to
   for i in eachindex(refs, x)
      y[refs[i]] += x[i]
   end
   return y
end
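
To make the semantics concrete, here is a small CPU-only sketch. The helper name scatter_add_reference! and the sample data are made up for illustration. It shows that repeated entries in refs accumulate into the same slot of y, which is why a parallel version needs atomics or some form of reduction.

```julia
# Hypothetical reference implementation, mirroring the loop above.
function scatter_add_reference!(y, refs, x)
    for i in eachindex(refs, x)
        y[refs[i]] += x[i]
    end
    return y
end

y = zeros(Float32, 3)
refs = [1, 3, 1, 2]        # index 1 appears twice
x = Float32[1, 2, 3, 4]
scatter_add_reference!(y, refs, x)
# y[1] has received both x[1] and x[3]
```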

I have written the following code:

using Metal
function gather!(y::MtlVector, refs::MtlVector, x::MtlVector; nthreads =
   nblocks = cld(length(refs), nthreads)  # one thread per element of refs/x
   Metal.@sync @metal threads=nthreads groups=nblocks gather_kernel!(y, refs, x)
   return y
end

function gather_kernel!(y, refs, x)
   i = thread_position_in_grid_1d()
   if i <= length(refs)
      # Atomic add, because several threads may target the same y[refs[i]]
      Metal.atomic_fetch_add_explicit(pointer(y, refs[i]), x[i])
   end
   return nothing
end

This works, but it is much slower than the CPU version. Is there a way to speed it up? Tim suggested I look into threadgroup memory, but I don’t know what that means (I really don’t know much about GPUs!). See the GitHub thread for reference.

You’ll have to educate yourself then 🙂 There’s no easy way to write kernels without knowing about parallel programming. Luckily, there are plenty of resources online, and you can mostly refer to CUDA material and substitute the intrinsics for Metal ones (shared memory → threadgroup memory, threadIdx() and blockIdx() → the thread_position_* and threadgroup_position_* intrinsics, etc.). For example, see the thread “scatter and gather with CUDA?” in the CUDA Programming and Performance category of the NVIDIA Developer Forums.

In general, this kind of pattern requires an efficient parallel reduction; you can’t just hammer global memory with atomics and expect the kernel to perform well. See for example our mapreduce implementation, src/mapreduce.jl on the main branch of JuliaGPU/Metal.jl on GitHub, where we use several tricks to avoid global memory traffic (threadgroup memory, SIMD intrinsics, etc.).
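
To give an idea of what “threadgroup memory” looks like in practice, here is a minimal sketch of a per-threadgroup sum. It is not the scatter kernel itself and not the code from mapreduce.jl, just an illustration of the pattern. The names partial_sum_kernel! and block_sum are made up, the group size is assumed to be a fixed power of two (256), and the code has not been benchmarked.

```julia
using Metal

# Hypothetical kernel: every threadgroup reduces its slice of `x` in fast
# threadgroup memory and writes a single partial sum to global memory.
function partial_sum_kernel!(partials, x)
    tid = thread_position_in_threadgroup_1d()   # index within the group
    gid = threadgroup_position_in_grid_1d()     # index of the group
    nt  = threads_per_threadgroup_1d()
    i   = thread_position_in_grid_1d()          # global index

    # Size must be a compile-time constant; assumes groups of 256 threads
    shared = MtlThreadGroupArray(Float32, 256)
    shared[tid] = i <= length(x) ? x[i] : 0f0
    threadgroup_barrier(Metal.MemoryFlagThreadGroup)

    # Tree reduction: halve the number of active threads at each step
    stride = nt ÷ 2
    while stride >= 1
        if tid <= stride
            shared[tid] += shared[tid + stride]
        end
        threadgroup_barrier(Metal.MemoryFlagThreadGroup)
        stride ÷= 2
    end

    # One global write per threadgroup instead of one per thread
    if tid == 1
        partials[gid] = shared[1]
    end
    return nothing
end

function block_sum(x::MtlVector{Float32})
    nthreads = 256                              # must match the kernel
    ngroups = cld(length(x), nthreads)
    partials = Metal.zeros(Float32, ngroups)
    Metal.@sync @metal threads=nthreads groups=ngroups partial_sum_kernel!(partials, x)
    return sum(Array(partials))                 # finish on the CPU
end
```

The point is the memory traffic: each thread reads global memory once, all intermediate additions happen in threadgroup memory, and only one value per group is written back. A scatter-add can use the same idea by first combining contributions that land on the same output within a group.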

As an alternative, try to rephrase your problem in terms of existing array operations that we’ve implemented for you (like mapreducedim). But do know that for Metal.jl, these haven’t been as optimized as for CUDA.jl.
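
As a rough illustration of such a rephrasing, the scatter-add can be written as a broadcast followed by a sum along one dimension. This is only a sketch: the helper name scatter_add_dense! is made up, it materializes a length(y) × length(x) matrix, so it is only reasonable when y is small, and it is checked here with plain CPU arrays only. It uses nothing but generic array operations, so it should in principle also accept MtlVector inputs, but that is an assumption.

```julia
# Hypothetical helper: scatter-add using only broadcast and sum(...; dims)
function scatter_add_dense!(y::AbstractVector, refs::AbstractVector, x::AbstractVector)
    n = length(y)
    # contrib[j, i] is x[i] if element i belongs to slot j of y, else zero
    contrib = ifelse.(reshape(refs, 1, :) .== (1:n), reshape(x, 1, :), zero(eltype(x)))
    # Summing each row collects all contributions for one slot of y
    y .+= vec(sum(contrib; dims = 2))
    return y
end

y = zeros(Float32, 3)
scatter_add_dense!(y, [1, 3, 1, 2], Float32[1, 2, 3, 4])
```

Every output slot is computed independently here, so no atomics are needed; the price is the extra work and memory of the dense intermediate.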