I can’t change this Float32 → Int32
Why not? Doing so just works…
apparently @cuDynamicSharedMem can only be 1 dimensional?
No, it can be higher-dimensional, but you need to pass the size as a tuple. The @cu...SharedMem macros are a bit of a hack and really should be a proper array type, at which point they should probably implement all of the common constructor syntaxes.
it allows me to say linear_index :: Int.. but not linear_index :: Int32..
Again, that just works here. Be aware, though, that you generally want to avoid such patterns: a typed declaration like this inserts a checked conversion, which generates quite ugly code (branches for checked arithmetic, allocations, and calls to exception methods) that is typically unwanted in hot code. GPU intrinsics return Int32 values, but the Julia functions return Int (for codegen reasons), so you can safely use unsafe_trunc(Int32, ...) if you really want a 32-bit value.
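To make the difference concrete, here is a minimal CPU-only sketch contrasting the checked conversion with unsafe_trunc. The names checked and unchecked are made up for illustration, and the literal 42 merely stands in for a value such as blockIdx().x; this assumes a 64-bit input.

```julia
# Checked: Int32(i) (like a typed declaration) goes through convert, which
# throws InexactError when the value does not fit in 32 bits.
checked(i::Int64) = Int32(i)

# Unchecked: no range check and no exception path. Only appropriate when
# the value is known to fit, as with thread and block indices.
unchecked(i::Int64) = unsafe_trunc(Int32, i)

idx = Int64(42)                          # stands in for blockIdx().x
same = checked(idx) === unchecked(idx)   # identical result for in-range input

# Out-of-range input: the checked version throws instead of truncating.
too_big = Int64(typemax(Int32)) + 1
threw = try
    checked(too_big)
    false
catch err
    err isa InexactError
end
```

For in-range values both give the same Int32; only the checked version carries the exception path that you do not want inside a kernel.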
Maybe you forgot to reserve memory for the dynamic shared memory (specified using the shmem keyword argument to @cuda)? Do note that your use of dynamic shared memory is rather peculiar: it is normally shared between the threads of a block, but you only launch a single thread per block. Doing so will probably hurt your occupancy. If you want truly local memory, why not use StaticArrays?
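As a rough sketch of how the shmem value could be computed: the byte count is the element size times the number of elements in the shared array. shmem_bytes is a hypothetical helper name, not something CUDAnative provides.

```julia
# Hypothetical helper (not part of CUDAnative): number of bytes needed to
# back a dynamic shared memory array with element type T and size dims.
shmem_bytes(::Type{T}, dims::Tuple) where {T} = sizeof(T) * prod(dims)

# Example: an Int32 array of size (w, h, 3) on a 10 x 10 grid.
w, h = 10, 10
nbytes = shmem_bytes(Int32, (w, h, 3))
```

The resulting value is what you would pass as shmem=nbytes to @cuda.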
Anyway, here’s the code that works:
using CUDAdrv, CUDAnative, CuArrays
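# Zero-based linear index of (x, y, z) in a column-major array of width w and height h.
function l_0(x, y, z, w, h)
    return x + y*w + z*w*h
end

# One-based wrapper: shift to zero-based, linearize, shift back.
function l(x, y, z, w, h)
    _x = x - 1
    _y = y - 1
    _z = z - 1
    return l_0(_x, _y, _z, w, h) + 1
end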
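function kernel(out)
    x = blockIdx().x
    y = blockIdx().y
    w = gridDim().x
    h = gridDim().y
    z = 1
    # Multi-dimensional dynamic shared memory: pass the size as a tuple.
    arr = @cuDynamicSharedMem(Int32, (w, h, 3))
    arr[x,y,z] = Int32(1)
    # Typed declaration: converts the Int result to Int32 (checked conversion).
    linear_index::Int32 = l(x,y,z,w,h)
    arr[linear_index] = Int32(1) # still works
    out[x, y] = linear_index
    return nothing
end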
function make_matrix(width :: Int, height :: Int)
    grid = (width, height)
    threads = (1,)
    cu_out = CuArray{Int32, 2}(undef, width, height)
    # Reserve enough dynamic shared memory for the full (w, h, 3) array used in the kernel.
    @cuda blocks=grid threads=threads shmem=sizeof(Int32)*prod(grid)*3 kernel(cu_out)
    out = Array{Int32, 2}(cu_out)
    return out
end
function main()
    width = 10
    height = 10
    matrix = make_matrix(width, height)
    println(matrix)
end
main()
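If you want to sanity-check the index helper without a GPU, a small CPU-only sketch like the following compares it against Julia's built-in column-major linear indexing. The helpers are repeated in short form so the snippet runs on its own; matches_linear_indices is a name made up for this check.

```julia
# Same index helpers as above, in short form so this snippet is self-contained.
l_0(x, y, z, w, h) = x + y*w + z*w*h
l(x, y, z, w, h) = l_0(x - 1, y - 1, z - 1, w, h) + 1

# Compare l against LinearIndices for every position of a (w, h, d) array.
function matches_linear_indices(w, h, d)
    li = LinearIndices((w, h, d))
    return all(l(x, y, z, w, h) == li[x, y, z] for x in 1:w, y in 1:h, z in 1:d)
end

ok = matches_linear_indices(10, 10, 3)
```

This is why arr[linear_index] and arr[x,y,z] refer to the same element in the kernel.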