I would like to compute, on the GPU, a 4-D array A that holds the values of a two-component function f over a large multi-dimensional domain (X, Y, T). A will therefore have size (2, NX, NY, NT).
X, Y and T are 1-D arrays of lengths NX, NY and NT.
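To make the target concrete, here is a minimal CPU-only sketch of the result I am after. The names `f_ref`, `fill_reference!`, `Xs`, `Ys`, `Ts` and `Aref` are hypothetical and used only for illustration; they are not part of the GPU code, and the grid sizes are arbitrary.

```julia
# CPU reference: Aref[:, i, j, k] = f(X[i], Y[j], T[k]) for every grid point.
# All names here are illustrative assumptions, not part of the GPU attempt.
f_ref(x, y, t) = (x + y, y - t)   # returns a tuple, so no array is allocated per point

function fill_reference!(A, f, X, Y, T)
    for k in eachindex(T), j in eachindex(Y), i in eachindex(X)
        v = f(X[i], Y[j], T[k])
        A[1, i, j, k] = v[1]      # first component of f
        A[2, i, j, k] = v[2]      # second component of f
    end
    return A
end

Xs, Ys, Ts = rand(3), rand(4), rand(5)
Aref = fill_reference!(zeros(2, 3, 4, 5), f_ref, Xs, Ys, Ts)
```

The GPU version should produce the same values as `Aref`, with one thread per (i, j, k) grid point.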
I tried to adapt @SimonDanisch's code from "An Introduction to GPU Programming in Julia" (Nextjournal), using a simple map kernel.
This is what I wrote:
using GPUArrays, CuArrays
using SharedArrays
using DistributedArrays
using BenchmarkTools
# Overloading the Julia Base map! function for GPUArrays
function Base.map!(f::Function, A::GPUArray, X::GPUArray, Y::GPUArray, T::GPUArray)
    # our function that will run on the gpu
    function kernel(state, f, A, X, Y, T)
        # If launch parameters aren't specified, linear_index gets the index
        # into the Array passed as second argument to gpu_call (`A`)
        I = CartesianIndex(state)
        if I[2]*I[3]*I[4] <= length(A)/2
            @inbounds A[1:2,I[2],I[3],I[4]] = f(X[I[2]], Y[I[3]], T[I[4]])
        end
    end
    # call kernel on the gpu
    gpu_call(kernel, A, (f, A, X, Y, T))
end
# on the GPU:
NX, NY, NT = 10, 15, 13
X, Y, T = rand(NX), rand(NY), rand(NT)
a = zeros(2, NX, NY, NT)
kernel(x,y,t) = [x+y, y-t]
xgpu, ygpu, tgpu, agpu = cu(X), cu(Y), cu(T), cu(a)
gpu_t = @belapsed begin
    map!($kernel, $agpu, $xgpu, $ygpu, $tgpu)
end
I get the following error:
GPU compilation of kernel(CuArrays.CuKernelState, typeof(kernel), CUDAnative.CuDeviceArray{Float32,4,CUDAnative.AS.Global}, CUDAnative.CuDeviceArray{Float32,1,CUDAnative.AS.Global}, CUDAnative.CuDeviceArray{Float32,1,CUDAnative.AS.Global}, CUDAnative.CuDeviceArray{Float32,1,CUDAnative.AS.Global}) failed
KernelError: kernel returns a value of type `Union{}`
Make sure your kernel function ends in `return`, `return nothing` or `nothing`.
If the returned value is of type `Union{}`, your Julia code probably throws an exception.
Inspect the code with `@device_code_warntype` for more details.
I don't know how to use the macro `@device_code_warntype` to debug the code.
I am working on a cluster. I set JULIA_DEBUG=CUDAnative in the terminal before launching the Julia REPL, but I get `UndefVarError: @device_code_warntype not defined` when I run the code above.