Writing a Metal Kernel

I have the following kernel written with CUDA.jl:

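using CUDA

function inv!(scale, nthreads)
	# Round up so that nblocks * nthreads >= length(scale)
	nblocks = cld(length(scale), nthreads)
	@cuda threads=nthreads blocks=nblocks inv_kernel!(scale)
end

function inv_kernel!(scale)
	index = (blockIdx().x - 1) * blockDim().x + threadIdx().x
	stride = blockDim().x * gridDim().x    # total number of threads in the grid
	# Grid-stride loop: each thread handles index, index + stride, ...
	@inbounds for i = index:stride:length(scale)
		scale[i] = 1 / sqrt(scale[i])
	end
	return nothing
end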

How can I write an equivalent kernel in Metal?

Did you look at the examples? See examples/vadd.jl in the JuliaGPU/Metal.jl repository on GitHub.

Thanks, but these examples do not seem to cover the groups keyword (the equivalent of blocks in CUDA?). That is, how should I modify the linked example to launch several groups of threads in parallel?

Here is what I have written. Is that correct / efficient?

using Metal

function inv!(scale::MtlVector, nthreads::Integer)
	# Round up so that nblocks * nthreads >= length(scale)
	nblocks = cld(length(scale), nthreads)
	@metal threads=nthreads groups=nblocks inv_kernel!(scale)
end

function inv_kernel!(scale)
	i = thread_position_in_grid_1d()
	if i <= length(scale)    # the last group may contain threads past the end
		scale[i] = 1 / sqrt(scale[i])
	end
	return nothing
end
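To see why the if i <= length(scale) guard matters, here is a CPU-only sketch that emulates the launch in inv! sequentially. It assumes that thread_position_in_grid_1d() equals (group - 1) * threads_per_group + thread with 1-based positions, which is the same arithmetic as the CUDA index; the helper emulate_metal_inv! is hypothetical and does not use Metal.jl at all.

```julia
# CPU emulation of `@metal threads=nthreads groups=nblocks inv_kernel!(scale)`.
# Assumption: thread_position_in_grid_1d() == (group - 1) * nthreads + thread,
# with 1-based positions. Hypothetical helper, runs without a GPU.
function emulate_metal_inv!(scale::Vector{Float32}, nthreads::Integer)
    nblocks = cld(length(scale), nthreads)   # same rounding up as in inv!
    launched = 0
    for group in 1:nblocks, thread in 1:nthreads
        i = (group - 1) * nthreads + thread  # global position in the grid
        launched += 1
        if i <= length(scale)                # threads past the end do nothing
            scale[i] = 1 / sqrt(scale[i])
        end
    end
    return launched                          # total number of threads "launched"
end

x = Float32[4, 16, 25, 100, 1]               # 5 elements
launched = emulate_metal_inv!(x, 4)          # 2 groups of 4 threads = 8 threads
```

With 5 elements and 4 threads per group, two groups (8 threads) are launched, so three threads must skip the body; without the guard they would index past the end of the array.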

Hi @matthieu!

I think so; I would have done something very similar. However, is there any reason not to use broadcasting, scale .= 1 ./ sqrt.(scale)?

julia> vec = MtlVector([10.0f0^(n - 1) for n in 1:20])
20-element MtlVector{Float32}:
      1.0
     10.0
    100.0
   1000.0
  10000.0
 100000.0
      1.0f6
      1.0f7
      1.0f8
      1.0f9
      1.0f10
      1.0f11
      1.0f12
      1.0f13
      1.0f14
      1.0f15
      1.0f16
      1.0f17
      1.0f18
      1.0f19

julia> vec .= 1 ./ sqrt.(vec)
20-element MtlVector{Float32}:
 1.0
 0.31622776
 0.1
 0.03162278
 0.01
 0.0031622779
 0.001
 0.0003162278
 0.0001
 3.1622774f-5
 1.0f-5
 3.1622778f-6
 1.0f-6
 3.1622776f-7
 1.0f-7
 3.1622776f-8
 1.0f-8
 3.1622776f-9
 1.0f-9
 3.1622777f-10

Right, this is just a simple example; the real kernel I need to write involves indices and multiple arrays.

Ah, OK! Then you are on the right track.

A related question: the CUDA example above uses a “stride loop” (grid-stride loop), which decouples the number of launched threads from the array length:

index = (blockIdx().x - 1) * blockDim().x + threadIdx().x
stride = blockDim().x * gridDim().x    # total number of threads in the grid

What would be the correct way to do this in Metal.jl? Is it as simple as

index = thread_position_in_grid_1d()
stride = threads_per_grid_1d()

or do I need to factor in the thread groups?

Those per_grid/in_grid indexing functions should be sufficient; for an example, see how Broadcast is implemented in Metal.jl: src/broadcast.jl in the JuliaGPU/Metal.jl repository on GitHub (commit adac3bd000be7dc6b64725a8e8a7a47a154c4dfe).
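Following that answer, a grid-stride version of inv_kernel! in Metal.jl would use index = thread_position_in_grid_1d() and stride = threads_per_grid_1d(), leaving the number of groups free to choose. Running that needs Apple hardware, so this sketch checks the logic on the CPU instead. The helpers inv_stride! and emulate_metal_launch! are hypothetical, and the emulator assumes 1-based positions and threads_per_grid_1d() == threads * groups.

```julia
# Loop body shared by a GPU kernel and the CPU emulation (hypothetical helper).
# In a Metal.jl kernel one would call it as
#   inv_stride!(scale, thread_position_in_grid_1d(), threads_per_grid_1d())
function inv_stride!(scale, index, stride)
    for i in index:stride:length(scale)
        scale[i] = 1 / sqrt(scale[i])
    end
    return nothing
end

# Sequential stand-in for `@metal threads=nthreads groups=ngroups ...`.
# Assumption: positions are 1-based and threads_per_grid == nthreads * ngroups.
function emulate_metal_launch!(scale, nthreads, ngroups)
    threads_per_grid = nthreads * ngroups
    for group in 1:ngroups, thread in 1:nthreads
        index = (group - 1) * nthreads + thread   # position in the grid
        inv_stride!(scale, index, threads_per_grid)
    end
    return scale
end

x = fill(4.0f0, 1000)
emulate_metal_launch!(x, 32, 4)   # only 128 threads for 1000 elements
```

Every element should end up at 1/sqrt(4) = 0.5 exactly once even though only 128 threads are launched; an element skipped would stay at 4, and one visited twice would become 1/sqrt(0.5).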

Thanks!