Metal Kernel 3D indices

Hello everyone, I am trying to write a Metal kernel that does some operations on 3D arrays, like:

"""
    vadd!(nz, ny, nx, a, b, c)

Metal kernel: write `a + b` into `c` elementwise for all indices with
`z in 2:nz` (the `z == 1` slice of `c` is left untouched).

Each thread handles the single element at its 3D grid position, so the
launch must cover the array (`threads`/`groups` arguments to `@metal`).
"""
function vadd!(
nz::Int64,ny::Int64,nx::Int64,
a::MtlDeviceArray{Float32, 3, 1},
b::MtlDeviceArray{Float32, 3, 1},
c::MtlDeviceArray{Float32, 3, 1})

(z,y,x) = thread_position_in_grid_3d()

# Guard BOTH ends of every dimension: launch configurations commonly round
# the grid up (cld), so some threads land outside the array. Checking only
# `z > 1` lets those threads index out of bounds and corrupts the result.
if 1 < z <= nz && 1 <= y <= ny && 1 <= x <= nx
    c[z,y,x] = a[z,y,x] + b[z,y,x]
end

return nothing
end

which should do the same as:

"""
    vadd!(nz, ny, nx, a, b, c)

CPU reference implementation: write `a + b` into `c` elementwise for
`z in 2:nz` (the `z == 1` slice of `c` is left untouched), over the full
`y` and `x` extents. Returns `nothing`; mutates `c` in place.
"""
function vadd!(
nz::Int64,ny::Int64,nx::Int64,
a::Array{Float32, 3},
b::Array{Float32, 3},
c::Array{Float32, 3})

# Julia arrays are column-major: iterate the FIRST index (z) innermost so
# memory is accessed contiguously. Iteration order does not change the
# result because every element is written independently.
for x in 1:nx
    for y in 1:ny
        for z in 2:nz
            c[z,y,x] = a[z,y,x] + b[z,y,x]
        end
    end
end
return nothing
end;

Later I need to vary the ranges of the loops. However, when I try to skip a “z” like above, the kernel function doesn't work properly anymore.

So my question is: how can I set the range of the 3D indices in the kernel function?

You control the thread positions by means of the threads and groups arguments to the kernel, or to @metal. It’s important to know that there’s a limit on the total number of threads in a group, though. See our broadcast implementation for an example: Metal.jl/src/broadcast.jl at de7739909bd2849594c4508cb31d7f5c16608b34 · JuliaGPU/Metal.jl · GitHub

Thanks for the response; unfortunately I still cannot implement a simple vadd kernel function for 3D arrays with Metal. Using your example, my function should look like:

function vadd!(
nz::Int64,ny::Int64,nx::Int64,
a::MtlDeviceArray{Float32, 3, 1},
b::MtlDeviceArray{Float32, 3, 1},
c::MtlDeviceArray{Float32, 3, 1})

is = Tuple(thread_position_in_grid_3d())
stride = threads_per_grid_3d()
while 1 <= is[1] <= nz &&
         1 <= is[2] <= ny &&
         1 <= is[3] <= nx
I = CartesianIndex(is)
@inbounds c[I] = a[I] + b[I]
is = (is[1] + stride[1], is[2] + stride[2], is[3] + stride[3])
end
   
return nothing
end

However, the results are wrong. I still use small arrays (10 x 10 x 10), so GPU memory and space should not be the problem.

Please post actually executable code so that it’s easier to help you.

The following seems to work fine:

using Test
using Metal

# Metal kernel: c .= a .+ b over the whole array.
# Each thread starts at its own 3D grid position and then advances all three
# indices by the total grid extent per dimension, so elements beyond the
# initial grid are still covered. NOTE(review): because all three indices
# stride together, full coverage relies on the launch configuration spanning
# the array in each dimension (as in main() below, one thread per element) —
# confirm before reusing with other launch shapes.
function vadd(a, b, c)
    i0 = Tuple(thread_position_in_grid_3d())   # this thread's (z, y, x), 1-based
    stride = Tuple(threads_per_grid_3d())      # total threads per grid dimension
    is = i0
    # Stop as soon as any dimension leaves the array bounds.
    while 1 <= is[1] <= size(a, 1) &&
          1 <= is[2] <= size(a, 2) &&
          1 <= is[3] <= size(a, 3)
        I = CartesianIndex(is)
        c[I] = a[I] + b[I]
        is = (is[1] + stride[1],
              is[2] + stride[2],
              is[3] + stride[3])
    end
    return
end

# Small end-to-end demo: upload two random arrays, run the vadd kernel with
# one thread per element, and check the result against a + b on the CPU.
function main()
    dims = (3,4,5)
    a = round.(rand(Float32, dims) * 100)
    b = round.(rand(Float32, dims) * 100)
    c = similar(a)

    d_a = MtlArray(a)
    d_b = MtlArray(b)
    d_c = MtlArray(c)

    # One thread per element. Only valid while prod(dims) stays within the
    # device's per-threadgroup thread limit (fine for this small example).
    @metal threads=dims vadd(d_a, d_b, d_c)
    c = Array(d_c)
    @test a+b ≈ c
end

Obviously still needs to be generalized to selecting a launch configuration that’s compatible with the device; this will only work for small inputs.

Thanks, this works. It looks very similar to what I tried, but I must have made a mistake somewhere. Anyway, I attach executable code that solves the mini-problem I tried to describe above, with varying index ranges in the kernel function.

using Test
using Metal

# Metal kernel: c .= a .+ b over the restricted range z in 4:nz-5, with y and
# x covering their full extents. Each thread starts at its own grid position
# and advances every index by the total grid extent per dimension.
function vadd_metal(nz,ny,nx,a, b, c)
    (z, y, x) = Tuple(thread_position_in_grid_3d())
    (sz, sy, sx) = Tuple(threads_per_grid_3d())
    # Loop ends as soon as any index leaves its allowed range.
    while 4 <= z <= nz - 5 && 1 <= y <= ny && 1 <= x <= nx
        c[z, y, x] = a[z, y, x] + b[z, y, x]
        z += sz
        y += sy
        x += sx
    end
    return
end

# CPU reference for vadd_metal: c .= a .+ b for z in 4:nz-5 and full y/x
# extents; everything outside that z range is left untouched. Mutates c.
function vadd_cpu!(nz,ny,nx,a, b, c)
    # Julia arrays are column-major: iterate the FIRST index (z) innermost
    # for contiguous memory access. Iterations are independent, so the
    # reordering does not change the result.
    for x in 1:nx
        for y in 1:ny
            for z in 4:nz-5
                c[z,y,x] = a[z,y,x] + b[z,y,x]
            end
        end
    end
    return nothing
end

# Driver: run vadd_metal on the GPU with a device-derived launch
# configuration and compare against the CPU reference vadd_cpu!.
function main()

nx = 500
ny = 600
nz = 700
dims = (nz,ny,nx)

a = round.(rand(Float32, dims) * 100)
b = round.(rand(Float32, dims) * 100)
c = zeros(Float32,dims)

d_a = MtlArray(a)
d_b = MtlArray(b)
d_c = MtlArray(c)

# Compile without launching so the pipeline's thread limits can be queried.
kernel = @metal launch=false vadd_metal(nz,ny,nx,d_a, d_b, d_c)

# Sort the array extents largest-first, then size the threadgroup.
# BUG FIX: the original used size(dim_arg_sort, k), i.e. the SHAPE of the
# 3-element vector — (3, 1, 1) — instead of its elements, which produced
# tiny 3x1x1 threadgroups. Index the sorted extents directly.
dim_arg_sort = sort(collect(size(c)),rev=true)
w = min(dim_arg_sort[1], kernel.pipeline.threadExecutionWidth)
h = min(dim_arg_sort[2], kernel.pipeline.threadExecutionWidth,
                         kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w)
d = min(dim_arg_sort[3], kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ (w*h))

threads = (w, h, d)
# Round the group count up so the whole array is covered; the kernel's own
# bounds check handles the overhang threads.
groups = cld.(size(c), threads)

kernel(nz,ny,nx,d_a, d_b, d_c, threads=threads, groups=groups)
d_c = Array(d_c)

vadd_cpu!(nz,ny,nx,a, b, c)

@test c ≈ d_c

end

main()