Is there any good way to call functions from a set of functions in a CUDA kernel?

Hi! I have a problem calling functions from a function set inside CUDA kernels. For example, if I define a function set like the one below and then launch the kernel,

```julia
using CUDA

# Define functions f1, f2, f3, f4, f5
...

# Create a NamedTuple of functions
functions = (f1=f1, f2=f2, f3=f3, f4=f4, f5=f5)

# Define the CUDA kernel
function kernel!(arg0, arg1, arg2, arg3, getIndex::Function, functions::NamedTuple)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    j = (blockIdx().y - 1) * blockDim().y + threadIdx().y
    k = (blockIdx().z - 1) * blockDim().z + threadIdx().z

    ... # Some if conditions
    index = getIndex(i, j, k, arg1, arg2, arg3) # Compute the index from i, j, k and some other arguments
    f = functions[index] # Select the target function (which one depends on the runtime index)
    ... # Use f to do something

    return nothing
end
```

it sometimes fails with the error

```
Reason: unsupported dynamic function invocation (call to f1) # or f2, f3, f4, f5
```

I think this is because the kernel signature does not specify the type of any of the functions in `functions`. But I cannot specify them there, because they are not fixed: which one is called is decided at runtime inside the kernel. Is there a good way to make a kernel like this (i.e., one that calls functions from a function set) work? Thanks!


I think you need to replace the dynamic dispatch with a good old-fashioned branch. For example, you could collapse all your functions into one, like this:

```julia
function f(index, ...)
    if index == :f1
        # code from f1
    elseif index == :f2
        # code from f2
    ...
        # and so on
    end
end

function kernel!(...)
    ...
    index = getIndex(...)
    f(index, ...)
    ...
end
```
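For a quick CPU check of the collapsed-function idea, here is a minimal sketch. It assumes `getIndex` returns an `Int` (the reply compares against symbols such as `:f1`, which works the same way on the CPU), and the bodies standing in for `f1`, `f2`, `f3` are made-up placeholders. In a real CUDA.jl kernel each thread would call `f(index, x)`; only the CPU path is exercised here.

```julia
# Hypothetical stand-ins for f1, f2, f3, collapsed into a single function.
# Every branch is a statically known computation, so the compiler never has
# to dispatch on a function value chosen at runtime.
function f(index::Int, x::Float64)
    if index == 1
        return x + 1.0      # body of f1 (placeholder)
    elseif index == 2
        return 2.0 * x      # body of f2 (placeholder)
    else
        return x * x        # body of f3 (placeholder), also used for any other index
    end
end

# Simulate per-thread work on the CPU: each element picks its branch at runtime,
# the way getIndex(...) would pick one inside the kernel.
xs   = [1.0, 2.0, 3.0]
idxs = [1, 2, 3]
out  = map(f, idxs, xs)
```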

If you prefer, you can of course keep the separate functions and put the branch inside `kernel!` instead:

```julia
function kernel!(...)
    ...
    index = getIndex(...)
    if index == :f1
        functions[:f1](...)
    elseif index == :f2
        functions[:f2](...)
    ...
    end
    ...
end
```

The point is that the compiler needs to know exactly which function/method gets called at every line of code in the kernel. The line `f = functions[index]` in your current implementation breaks this inferability: every function has its own type, so the type of `f` depends on the runtime value of `index`, and calling `f` would then require dynamic dispatch, which GPU kernels do not support.
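To see the inferability problem without a GPU, one option is to ask Julia's type inference what it expects each version to return. This sketch uses `Base.return_types` with hypothetical `f1`/`f2`/`f3`/`pick`/`apply_branch`; the exact type reported for the runtime-index version may differ between Julia versions, but it should not be concrete, while the branch version should infer `Float64`.

```julia
f1(x) = x + 1.0
f2(x) = 2.0 * x
f3(x) = x * x
functions = (f1 = f1, f2 = f2, f3 = f3)   # NamedTuple, as in the question

# Runtime index: the result could be any of three different function types.
pick(fns, i) = fns[i]

# Explicit branch: each call site refers to one statically known function.
function apply_branch(fns, i, x)
    if i == 1
        return fns.f1(x)
    elseif i == 2
        return fns.f2(x)
    else
        return fns.f3(x)
    end
end

T   = typeof(functions)
dyn = only(Base.return_types(pick, (T, Int)))                  # inferred type of fns[i]
br  = only(Base.return_types(apply_branch, (T, Int, Float64))) # inferred type of the branch
println(dyn, "  concrete: ", isconcretetype(dyn))
println(br, "  concrete: ", isconcretetype(br))
```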


Hello, I ran into a similar issue in my codebase and it took me literally months to solve. Here’s my solution (with my understanding, which may or may not be correct).

To be clear: this is actually a relatively tricky question that doesn’t have a clear correct answer. The core issue is that you cannot index a Tuple of mixed types on the GPU with a runtime index (that is, you cannot do `mixed_tuple[i]`, but you can do `mixed_tuple[1]` or `mixed_tuple[2]`), because the type of `mixed_tuple[i]` depends on the value of `i`. One way around this is to use a generated function with `@nif`, like so:

```julia
@generated function stable_nonsense(fns, idx, val)
    N = length(fns.parameters)
    quote
        Base.Cartesian.@nif $N d->d==idx d-> return fns[d](val)
    end
end
```
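The `stable_nonsense` pattern can also be exercised on the CPU. This minimal sketch assumes the functions are stored in a plain `Tuple`: for a `NamedTuple`, `fns.parameters` holds the field names and the field-type tuple rather than one entry per function, so `N` would be wrong (`values(functions)` gives the underlying `Tuple`). The three anonymous functions are placeholders.

```julia
# Same pattern as stable_nonsense: @nif writes an if/elseif chain of length N,
# so every branch calls fns[d] with a literal d and has a concrete type.
@generated function stable_nonsense(fns::Tuple, idx::Integer, val)
    N = length(fns.parameters)   # here `fns` is the tuple *type*, e.g. Tuple{F1, F2, F3}
    quote
        Base.Cartesian.@nif $N (d -> d == idx) (d -> return fns[d](val))
    end
end

fns = (x -> x + 1, x -> 2x, x -> x^2)   # placeholder functions

# idx is a runtime value here; inside a kernel it would come from getIndex(...).
results = [stable_nonsense(fns, i, 4.0) for i in 1:3]
```

Note that `@nif` turns the last function into the `else` branch, so an index outside `1:N` silently calls `fns[N]`.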

Here is a full example using KernelAbstractions. Note that KernelAbstractions ultimately compiles down to CUDA (or AMDGPU, Metal, oneAPI, parallel CPU, whatever), and I did test everything on an NVIDIA GPU (as well as an AMD one).

I don’t know if you are familiar with `@generated` functions (I certainly wasn’t before running into this issue), but they work entirely in the type domain: inside the function body, the arguments are bound to their types rather than their values, and the body returns an expression that the compiler then uses as the method body for those argument types. This happens at an earlier stage of compilation than for ordinary functions. In this case, `@nif` essentially writes code that looks like:

```julia
function stable_nonsense(fns, idx, val)
    if idx == 1
        fns[1](val)
    elseif idx == 2
        fns[2](val)
    elseif idx == 3
        fns[3](val)
    ...
    else
        fns[N](val) # Where N is not a variable, but a number like the others
    end
end
```
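To make the "type domain" point concrete, here is a tiny hypothetical generated function. Its body sees only the tuple type, computes the length at compile time, and returns an expression that is just that constant. (Julia may run a generator more than once, so generator bodies should have no side effects.)

```julia
# Inside a @generated function body, arguments are bound to their *types*.
# The body must return an expression, which becomes the method body for
# that combination of argument types.
@generated function ntuple_length(t::Tuple)
    n = length(t.parameters)   # e.g. Tuple{Int64, Float64, String} has 3 parameters
    return :($n)               # the compiled method simply returns this constant
end

a = ntuple_length((1, 2.0, "three"))
b = ntuple_length((sin, cos))
```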

If you want more flexibility (like passing in per-function keyword arguments as well), you can write an `@generated` function like so:

```julia
@generated function call_pt_fx(fxs, args, kwargs, idx)
    exs = Expr[]
    for i = 1:length(fxs.parameters)
        ex = quote
            if idx == $i
                pt = fxs[$i](args...; kwargs[$i]...)
            end
        end
        push!(exs, ex)
    end
    # `pt` is only assigned if some branch matched, so idx must lie in 1:length(fxs)
    push!(exs, :(return pt))

    return Expr(:block, exs...)
end
```

Note that in your case, your index seems to depend on some additional parameters (whatever `arg1`, `arg2`, and `arg3` are in `getIndex(...)`), so you might also need another `@generated` function to “unroll the loop.” Basically, instead of writing a runtime loop like `for i = 1:10`, you generate one copy of the loop body per iteration, something like:

```julia
@generated function pt_loop(fxs, args, fnums, kwargs)
    exs = Expr[]

    # One copy of the body per entry in fnums, emitted at compile time
    for i = 1:length(fnums.parameters)
        ex = quote
            idx = getIndex(...)
            pt = call_pt_fx(fxs, args, kwargs, idx)
        end
        push!(exs, ex)
    end

    push!(exs, :(return pt))

    return Expr(:block, exs...)
end
```

Here `fnums` is a set of “function numbers” to iterate over. For example, I might have one set of 3 functions and another set of 4, and want to randomly select one function from the first set and then another from the second set (whatever `getIndex(...)` does).
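A CPU-runnable sketch of the unrolled loop, assuming the function sets are stored as a tuple of tuples and that the per-set indices (what `getIndex(...)` would produce) are already known. `select_call` and `apply_sets` are hypothetical names: they follow the same idea as `call_pt_fx`/`pt_loop` but drop the keyword arguments, and they build one `if`/`elseif` chain so that an out-of-range index falls back to the last function of the set.

```julia
# Hypothetical helper: branch over one tuple of functions with a runtime index.
@generated function select_call(fns::Tuple, idx::Integer, x)
    N = length(fns.parameters)
    ex = :(fns[$N](x))                     # last function is the `else` branch
    for i in N-1:-1:1
        ex = :(idx == $i ? fns[$i](x) : $ex)
    end
    return ex
end

# Hypothetical helper: unroll the loop over function *sets*. Each set k gets its
# own statement with a literal k, so sets[k] always has a concrete type.
@generated function apply_sets(sets::Tuple, idxs::Tuple, x)
    exs = Expr[]
    for k in 1:length(sets.parameters)
        push!(exs, :(x = select_call(sets[$k], idxs[$k], x)))
    end
    push!(exs, :(return x))
    return Expr(:block, exs...)
end

# One set of 3 functions and one set of 4 (placeholder functions).
sets = ((x -> x + 1, x -> x - 1, x -> 10x),
        (x -> 2x, x -> x^2, x -> -x, x -> x / 2))

# idxs would come from getIndex(...) at runtime; here they are fixed by hand.
y1 = apply_sets(sets, (1, 2), 3.0)   # (3 + 1)^2
y2 = apply_sets(sets, (3, 3), 2.0)   # -(10 * 2)
```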

This code is quite messy, but here’s a “working” example from my codebase: https://github.com/leios/Fable.jl/blob/main/src/run/fractal_flame.jl

Anyway, I hope this helps and doesn’t send you down a crazy rabbit-hole.


Thank you all for the replies! That would certainly be helpful :slight_smile: