Hello, I ran into a similar issue in my codebase and it took me literally months to solve. Here’s my solution (with my understanding, which may or may not be correct).
To be clear: this is actually a relatively tricky question that doesn’t have a clear correct answer. The core issue is that you cannot index a Tuple of mixed type with a runtime value on the GPU (as in, you cannot do mixed_tuple[i], but you can do mixed_tuple[1] or mixed_tuple[2]), because the compiler cannot know the element type when the index is only known at runtime. One way to get around this is to use a generated function with @nif like so:
@generated function stable_nonsense(fns, idx, val)
    # fns is a Tuple type here, so its parameters are the function types
    N = length(fns.parameters)
    quote
        Base.Cartesian.@nif $N d->d==idx d-> return fns[d](val)
    end
end
Here is a full example using KernelAbstractions. Note that KernelAbstractions ultimately boils down to CUDA (or AMDGPU, Metal, oneAPI, parallel CPU, whatever), and I did test everything on an NVIDIA GPU (as well as AMD).
I don’t know if you are familiar with @generated functions (I certainly wasn’t before running into this issue), but they work entirely in the type domain: the body only sees the types of its arguments and returns an expression, which is then compiled as the actual function body for those types. In this case @nif essentially writes code that looks like:
function stable_nonsense(fns, idx, val)
    if idx == 1
        fns[1](val)
    elseif idx == 2
        fns[2](val)
    elseif idx == 3
        fns[3](val)
    ...
    else
        fns[N](val) # Where N is not a variable, but a number like the others
    end
end
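To check the branching logic on the CPU before moving it into a kernel, here is a minimal sketch. The functions double, negate, and add_one are hypothetical stand-ins for whatever is actually in your tuple; nothing here touches the GPU.

```julia
# CPU-only sketch; `double`, `negate`, and `add_one` are hypothetical stand-ins.
double(x) = 2x
negate(x) = -x
add_one(x) = x + 1

@generated function stable_nonsense(fns, idx, val)
    # fns is a Tuple type here; one branch is generated per element
    N = length(fns.parameters)
    quote
        Base.Cartesian.@nif $N d->d==idx d-> return fns[d](val)
    end
end

# Heterogeneous tuple: every function has its own type
fns = (double, negate, add_one)

for idx in 1:3
    println(stable_nonsense(fns, idx, 10))   # prints 20, -10, 11
end
```

One thing to keep in mind: the last branch is the else, so any idx that matches none of the earlier branches (including an out-of-range one) ends up calling the last function in the tuple.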
If you want some more flexibility (like potentially passing in variable kwargs as well), you can write an @generated function like so:
@generated function call_pt_fx(fxs, args, kwargs, idx)
    exs = Expr[]
    # fxs is a Tuple type here; build one `if` per function in the tuple
    for i = 1:length(fxs.parameters)
        ex = quote
            if idx == $i
                pt = fxs[$i](args...; kwargs[$i]...)
            end
        end
        push!(exs, ex)
    end
    # pt is only defined if idx matched one of the branches above
    push!(exs, :(return pt))
    return Expr(:block, exs...)
end
Note that in your case, it looks like your index might depend on some additional parameters (whatever arg1, arg2, and arg3 are in getIndex(...)), so you might need to also create another @generated function to “unroll the loop.” Basically, instead of going for i = 1:10, you do something like:
@generated function pt_loop(fxs, args, fnums, kwargs)
    exs = Expr[]
    # One copy of the loop body per entry in fnums
    for i = 1:length(fnums.parameters)
        ex = quote
            idx = getIndex(...)
            pt = call_pt_fx(fxs, args, kwargs, idx)
        end
        push!(exs, ex)
    end
    push!(exs, :(return pt))
    return Expr(:block, exs...)
end
Here fnums is a set of “function numbers” to iterate over. Like maybe I have a set of 3 functions then another set of 4 and want to randomly select one of the first set and then another from the second set (whatever getIndex(...) does).
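For the kwargs variant, here is a small CPU-only sketch of how the pieces might fit together. It assumes the tuple of functions is passed in as a first argument called fxs and that kwargs is a tuple of NamedTuples, one per function; shift and scale are hypothetical functions made up for the example.

```julia
# CPU-only sketch. `shift` and `scale` are hypothetical, and passing the
# function tuple as the first argument `fxs` is an assumption.
shift(p; dx = 0.0) = p + dx
scale(p; factor = 1.0) = p * factor

@generated function call_pt_fx(fxs, args, kwargs, idx)
    exs = Expr[]
    # One `if` per function; $i is a literal, so fxs[$i] is a static index
    for i = 1:length(fxs.parameters)
        ex = quote
            if idx == $i
                pt = fxs[$i](args...; kwargs[$i]...)
            end
        end
        push!(exs, ex)
    end
    # pt is only defined if idx matched one of the branches above
    push!(exs, :(return pt))
    return Expr(:block, exs...)
end

fxs = (shift, scale)
kwargs = ((dx = 1.0,), (factor = 3.0,))   # one NamedTuple per function
args = (2.0,)

println(call_pt_fx(fxs, args, kwargs, 1))   # shift(2.0; dx = 1.0)
println(call_pt_fx(fxs, args, kwargs, 2))   # scale(2.0; factor = 3.0)
```

Unlike the @nif version, there is no else branch here, so an idx that matches nothing leaves pt undefined. If that can happen in your code, it is probably worth clamping idx before the call.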
This code is quite messy, but here’s a “working” example from my codebase: https://github.com/leios/Fable.jl/blob/main/src/run/fractal_flame.jl
Anyway, I hope this helps and doesn’t send you down a crazy rabbit-hole.