Hello, I ran into a similar issue in my codebase and it took me literally months to solve. Here’s my solution (with my understanding, which may or may not be correct).
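To be clear: this is actually a relatively tricky question that doesn’t have a clear correct answer. The core issue is that you cannot index a Tuple of mixed types with a runtime index on the GPU (as in, you cannot do mixed_tuple[i] when i is only known at runtime, but you can do mixed_tuple[1] or mixed_tuple[2]). As far as I understand, that’s because the compiler can’t infer the type of mixed_tuple[i], and GPU code can’t fall back on dynamic dispatch. One way to get around this is to use a generated function with Base.Cartesian.@nif like so: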
@generated function stable_nonsense(fns, idx, val)
    N = length(fns.parameters)
    quote
        Base.Cartesian.@nif $N d->d==idx d-> return fns[d](val)
    end
end
Here is a full example using KernelAbstractions. Note that KernelAbstractions ultimately boils down to CUDA (or AMDGPU, Metal, parallel CPU, oneAPI, whatever), and I did test everything on an NVIDIA GPU (as well as AMD).
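A minimal sketch of what such a kernel could look like, assuming the KernelAbstractions 0.9-style API (launch the kernel, then KernelAbstractions.synchronize) and the CPU backend so that it runs without a GPU. The kernel name apply_fns!, the three lambdas, and the array names are made up for illustration.

```julia
using KernelAbstractions

# Same generated function as above, with the @nif arguments parenthesized.
@generated function stable_nonsense(fns, idx, val)
    N = length(fns.parameters)
    quote
        Base.Cartesian.@nif $N d->(d == idx) d->(return fns[d](val))
    end
end

# Hypothetical kernel: each work item picks one function by a runtime index.
@kernel function apply_fns!(out, fns, idxs, vals)
    i = @index(Global, Linear)
    @inbounds out[i] = stable_nonsense(fns, idxs[i], vals[i])
end

backend = CPU()                                   # assumption: CPU backend for portability
fns  = (x -> x + 1f0, x -> 2f0 * x, x -> x * x)   # tuple of mixed function types
vals = Float32[1, 2, 3, 4]
idxs = Int32[1, 2, 3, 1]                          # which function each element uses
out  = similar(vals)

kernel! = apply_fns!(backend, 4)                  # 4 = workgroup size
kernel!(out, fns, idxs, vals; ndrange = length(out))
KernelAbstractions.synchronize(backend)
```

To run this on a GPU, the idea would be to swap the backend and the array types (for example CUDABackend() from CUDA.jl together with CuArrays) and leave the kernel itself untouched.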
I don’t know if you are familiar with @generated functions (I certainly wasn’t before running into this issue), but their bodies work entirely in the type domain: inside the body, each argument name refers to the argument’s type rather than its value (which is why fns.parameters works above), and the expression you return gets compiled as the actual function body for those argument types. In this case, @nif essentially writes code that looks like:
function stable_nonsense(fns, idx, val)
    if idx == 1
        fns[1](val)
    elseif idx == 2
        fns[2](val)
    elseif idx == 3
        fns[3](val)
    ...
    else
        fns[N](val) # Where N is not a variable, but a number like the others
    end
end
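A quick way to convince yourself of this on the CPU is to compare the generated version against a hand-written chain. This is only a sketch: by_hand and the three functions in fns are made up for illustration, and the @nif arguments are parenthesized here in the style of the Base.Cartesian docs.

```julia
@generated function stable_nonsense(fns, idx, val)
    N = length(fns.parameters)   # fns is the Tuple type here
    quote
        Base.Cartesian.@nif $N d->(d == idx) d->(return fns[d](val))
    end
end

# Hand-written equivalent for exactly three functions (hypothetical helper).
function by_hand(fns, idx, val)
    if idx == 1
        fns[1](val)
    elseif idx == 2
        fns[2](val)
    else
        fns[3](val)
    end
end

fns = (sqrt, x -> 2x, string)   # three different function (and return) types

for idx in 1:3
    @assert stable_nonsense(fns, idx, 4.0) == by_hand(fns, idx, 4.0)
end

# The final branch is a plain `else`, so an out-of-range index
# silently calls the last function instead of throwing.
fallback = stable_nonsense(fns, 7, 4.0)
```

One thing this makes visible: because the last branch is an else, an index outside 1:N ends up calling fns[N], so it is worth validating idx before the call if that matters for you.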
If you want some more flexibility (like potentially passing in variable kwargs as well), you can write an @generated function like so:
@generated function call_pt_fx(fxs, args, kwargs, idx)
    exs = Expr[]
    # fxs is a Tuple type here, so fxs.parameters has one entry per function
    for i = 1:length(fxs.parameters)
        ex = quote
            if idx == $i
                pt = fxs[$i](args...; kwargs[$i]...)
            end
        end
        push!(exs, ex)
    end
    # pt is only assigned if idx matched one of 1:length(fxs)
    push!(exs, :(return pt))
    return Expr(:block, exs...)
end
Note that in your case, it looks like your index might depend on some additional parameters (whatever arg1, arg2, and arg3 are in getIndex(...)), so you might need to also create another @generated function to “unroll the loop.” Basically, instead of going for i = 1:10, you do something like:
@generated function pt_loop(fxs, args, fnums, kwargs)
    exs = Expr[]
    # one copy of the loop body per entry in fnums
    for i = 1:length(fnums.parameters)
        ex = quote
            idx = getIndex(...)
            pt = call_pt_fx(fxs, args, kwargs, idx)
        end
        push!(exs, ex)
    end
    push!(exs, :(return pt))
    return Expr(:block, exs...)
end
Here fnums is a set of “function numbers” to iterate over. Like maybe I have a set of 3 functions then another set of 4 and want to randomly select one of the first set and then another from the second set (whatever getIndex(...) does).
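To make the two pieces concrete, here is a small CPU-only sketch under a few assumptions: fxs is passed in explicitly as a flat tuple of functions, kwargs is a tuple with one NamedTuple per function, and pick_index is a hypothetical stand-in for getIndex(...) that takes an explicit choice per set instead of a random one. The functions shift, scale, and flip are invented for the example, and the loop feeds each result into the next call, which may not be what your code needs.

```julia
@generated function call_pt_fx(fxs, args, kwargs, idx)
    exs = Expr[]
    for i = 1:length(fxs.parameters)
        push!(exs, quote
            if idx == $i
                pt = fxs[$i](args...; kwargs[$i]...)
            end
        end)
    end
    push!(exs, :(return pt))
    return Expr(:block, exs...)
end

# Hypothetical stand-in for getIndex(...): maps choice r within set i
# to an index into the flat tuple of functions.
@inline function pick_index(fnums, i, r)
    offset = 0
    for j = 1:i-1
        offset += fnums[j]
    end
    return offset + r
end

# Unrolled loop: one call per set in fnums, chaining the result.
@generated function pt_loop(fxs, pt0, fnums, kwargs, choices)
    exs = Expr[:(pt = pt0)]
    for i = 1:length(fnums.parameters)
        push!(exs, quote
            idx = pick_index(fnums, $i, choices[$i])
            pt = call_pt_fx(fxs, (pt,), kwargs, idx)
        end)
    end
    push!(exs, :(return pt))
    return Expr(:block, exs...)
end

shift(p; by = 0.0) = p + by
scale(p; factor = 1.0) = p * factor
flip(p) = -p

fxs    = (shift, scale, flip)                          # flat tuple of all functions
kwargs = ((by = 1.0,), (factor = 3.0,), NamedTuple())  # one NamedTuple per function
fnums  = (2, 1)                                        # set 1: shift, scale; set 2: flip

shifted = call_pt_fx(fxs, (2.0,), kwargs, 1)           # shift(2.0; by = 1.0)
chained = pt_loop(fxs, 2.0, fnums, kwargs, (2, 1))     # flip(scale(2.0; factor = 3.0))
```

Here the first unrolled step picks the second function of set 1 (scale) and the second step picks the only function of set 2 (flip).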
This code is quite messy, but here’s a “working” example from my codebase: https://github.com/leios/Fable.jl/blob/main/src/run/fractal_flame.jl
Anyway, I hope this helps and doesn’t send you down a crazy rabbit-hole.