Is there any good way to call functions from a set of functions in a CUDA kernel?

Hello, I ran into a similar issue in my codebase and it took me literally months to solve. Here’s my solution (with my understanding, which may or may not be correct).

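To be clear: this is actually a relatively tricky question that doesn’t have a clear correct answer. The core issue is that you cannot index a Tuple of mixed type on the GPU with a runtime index (as in, you cannot do mixed_tuple[i], but you can do mixed_tuple[1] or mixed_tuple[2]). With a literal index, the compiler knows exactly which element, and therefore which type, you get. With a runtime i it generally doesn't, so calling mixed_tuple[i](...) needs dynamic dispatch, which GPU kernels don't support. One way to get around this is to use a generated function with @nif like so: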

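```julia
@generated function stable_nonsense(fns, idx, val)
    # fns is the *type* of the tuple here, so N is a plain compile-time number
    N = length(fns.parameters)
    quote
        # Unrolls into one branch per function, each using a literal index
        Base.Cartesian.@nif $N d->(d == idx) d->return fns[d](val)
    end
end
```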
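A quick CPU-side sanity check of `stable_nonsense`, as a sketch: `fns` is three made-up anonymous functions (each has its own type, so the tuple is mixed-type), and `pick` is a hypothetical helper that indexes the tuple directly. `Base.return_types` shows the difference between the two approaches:

```julia
@generated function stable_nonsense(fns, idx, val)
    # `fns` is the tuple *type* here, so N is known at compile time
    N = length(fns.parameters)
    quote
        # Unrolls into an if/elseif chain with one literal index per branch
        Base.Cartesian.@nif $N d -> (d == idx) d -> return fns[d](val)
    end
end

# Three anonymous functions, each with a distinct type: a mixed-type tuple
fns = (x -> x + 1, x -> 2x, x -> x^2)

for idx in 1:3
    println(idx, " => ", stable_nonsense(fns, idx, 3))
end
# 1 => 4
# 2 => 6
# 3 => 9

# Hypothetical helper: indexing the tuple with a runtime index
pick(fns, i) = fns[i]
# The inferred result is not a single concrete type (a Union or abstract type)
println(Base.return_types(pick, (typeof(fns), Int)))
# The unrolled version should infer one concrete return type (Int64 here)
println(Base.return_types(stable_nonsense, (typeof(fns), Int, Int)))
```

The non-concrete result for `pick` is the CPU-side symptom of the GPU problem: calling whatever it returns would need dynamic dispatch.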

Here is a full example using KernelAbstractions. Note that KernelAbstractions ultimately compiles down to CUDA (or AMDGPU, Metal, parallel CPU, oneAPI, whatever), and I did test everything on an NVIDIA GPU (as well as AMD).
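A minimal sketch of calling `stable_nonsense` from inside a KernelAbstractions kernel, assuming KernelAbstractions.jl 0.9 or later (`@kernel`, `@index`, `get_backend`, `KernelAbstractions.synchronize`). The kernel name `apply_fns!` and the data are hypothetical. As written it runs on the CPU backend:

```julia
using KernelAbstractions

@generated function stable_nonsense(fns, idx, val)
    N = length(fns.parameters)  # tuple length, known from the type
    quote
        Base.Cartesian.@nif $N d -> (d == idx) d -> return fns[d](val)
    end
end

# Hypothetical kernel: work-item i applies the function selected by idxs[i]
@kernel function apply_fns!(out, fns, idxs, vals)
    i = @index(Global)
    @inbounds out[i] = stable_nonsense(fns, idxs[i], vals[i])
end

fns  = (sin, cos, x -> 2x)     # mixed-type tuple of functions
vals = Float32[0, 0, 3]
idxs = [1, 2, 3]               # runtime indices, one per element
out  = similar(vals)

backend = get_backend(out)     # CPU() for an Array; CUDABackend() for a CuArray
apply_fns!(backend, 64)(out, fns, idxs, vals; ndrange = length(out))
KernelAbstractions.synchronize(backend)
println(out)  # Float32[0.0, 1.0, 6.0]
```

To target an NVIDIA GPU instead, the idea would be to load CUDA.jl and build `vals`, `idxs`, and `out` as `CuArray`s, so that `get_backend(out)` picks the CUDA backend; the kernel itself shouldn't need to change.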

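I don’t know if you are familiar with @generated functions (I certainly wasn’t before running into this issue), but they work entirely in the type domain: inside the body, the arguments are bound to their types rather than their values, and the body returns an expression that becomes the actual method code. That code is generated at compile time for each combination of argument types, which is why length(fns.parameters) is available as a plain number. In this case @nif essentially writes code that looks like: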

```julia
function stable_nonsense(fns, idx, val)
    if idx == 1
        fns[1](val)
    elseif idx == 2
        fns[2](val)
    elseif idx == 3
        fns[3](val)
    # ... one branch per element of fns ...
    else
        fns[N](val) # N is not a variable here, but a literal number like the others
    end
end
```
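To check what `@nif` actually generates, `@macroexpand` can be applied to it with a fixed `N` (3 here, chosen arbitrarily). At this stage `fns`, `idx`, and `val` are only symbols, so nothing needs to be defined:

```julia
# Expand @nif for N = 3; nothing is evaluated, so fns/idx/val need not exist
ex = @macroexpand Base.Cartesian.@nif 3 d -> (d == idx) d -> fns[d](val)
Base.remove_linenums!(ex)  # drop line-number nodes for readability
println(ex)  # nested if/else branches calling fns[1], fns[2], and finally fns[3]
```

The last branch has no `3 == idx` test: it is a plain `else`, so any index that matches none of the earlier conditions (including an out-of-range one) falls through to the last function. Check bounds before calling if that matters.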

If you want some more flexibility (like potentially passing in variable kwargs as well), you can write an @generated function like so:

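```julia
# fxs: tuple of functions; args: tuple of positional arguments;
# kwargs: tuple with one NamedTuple of keyword arguments per function
@generated function call_pt_fx(fxs, args, kwargs, idx)
    N = length(fxs.parameters)
    exs = Expr[]
    for i = 1:N-1
        # idx is a runtime value, but $i is a literal, so fxs[$i] is type-stable
        push!(exs, :(idx == $i && return fxs[$i](args...; kwargs[$i]...)))
    end
    # As with @nif, the last function is the fall-through branch,
    # so every path returns a value
    push!(exs, :(return fxs[$N](args...; kwargs[$N]...)))

    return Expr(:block, exs...)
end
```

Note that in your case, it looks like your index might depend on some additional parameters (whatever arg1, arg2, and arg3 are in getIndex(...)), so you might also need another @generated function to “unroll the loop.” Basically, instead of writing a runtime loop like for i = 1:10, you do something like:

```julia
@generated function pt_loop(fxs, args, fnums, kwargs)
    exs = Expr[]

    # One unrolled step per entry of fnums, emitted at compile time
    for i = 1:length(fnums.parameters)
        ex = quote
            idx = getIndex(...)  # placeholder: however your code picks the index
            pt = call_pt_fx(fxs, args, kwargs, idx)
        end
        push!(exs, ex)
    end

    push!(exs, :(return pt))

    return Expr(:block, exs...)
end
```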
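A small self-contained check of the kwargs variant, as a sketch assuming the signature `call_pt_fx(fxs, args, kwargs, idx)`, with `args` as a tuple of positional arguments and `kwargs` as a tuple holding one NamedTuple per function. `shift` and `scale` are hypothetical example functions:

```julia
@generated function call_pt_fx(fxs, args, kwargs, idx)
    N = length(fxs.parameters)  # number of functions, known from the tuple type
    exs = Expr[]
    for i = 1:N-1
        # Literal index $i keeps fxs[$i] and kwargs[$i] type-stable
        push!(exs, :(idx == $i && return fxs[$i](args...; kwargs[$i]...)))
    end
    push!(exs, :(return fxs[$N](args...; kwargs[$N]...)))  # fall-through branch
    return Expr(:block, exs...)
end

# Hypothetical functions, each taking its own keyword argument
shift(x; dx = 0) = x + dx
scale(x; s = 1) = s * x

fxs    = (shift, scale)
kwargs = ((dx = 10,), (s = 3,))   # one NamedTuple of kwargs per function

println(call_pt_fx(fxs, (2,), kwargs, 1))  # shift(2; dx = 10) -> 12
println(call_pt_fx(fxs, (2,), kwargs, 2))  # scale(2; s = 3)   -> 6
```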

Here fnums is a set of “function numbers” to iterate over. For example, maybe I have a set of 3 functions and another set of 4, and I want to randomly select one function from the first set and then one from the second (that’s whatever getIndex(...) does).
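The two-set idea can be sketched with the same unrolling trick. This is a hypothetical illustration, not taken from Fable.jl: `apply_sets`, `set1`, and `set2` are made-up names, and explicit per-set indices stand in for `getIndex(...)`. The loop over sets is unrolled at compile time, so each `sets[$i]` uses a literal index, while the choice within each set stays a runtime index handled by `@nif`:

```julia
# sets: tuple of function tuples; idxs: one runtime index per set
@generated function apply_sets(sets, idxs, x)
    exs = Expr[]
    for i = 1:length(sets.parameters)
        # sets.parameters[i] is the type of the i-th inner tuple
        N = length(sets.parameters[i].parameters)
        # Apply the idxs[i]-th function of set i to the current x
        push!(exs, :(x = Base.Cartesian.@nif $N d -> (d == idxs[$i]) d -> sets[$i][d](x)))
    end
    push!(exs, :(return x))
    return Expr(:block, exs...)
end

set1 = (x -> x + 1, x -> x - 1, x -> -x)    # 3 functions
set2 = (x -> 2x, x -> x ÷ 2, x -> x^2, abs)  # 4 functions
sets = (set1, set2)

println(apply_sets(sets, (1, 3), 4))  # set1[1] then set2[3]: (4 + 1)^2 = 25
println(apply_sets(sets, (rand(1:3), rand(1:4)), 4))  # random pick from each set
```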

This code is quite messy, but here’s a “working” example from my codebase: https://github.com/leios/Fable.jl/blob/main/src/run/fractal_flame.jl

Anyway, I hope this helps and doesn’t send you down a crazy rabbit-hole.
