I was able to solve this problem with KernelAbstractions.jl in “JAX style”, where the same kernel runs on both the CPU and GPU (CUDA, ROCm, oneAPI) backends:
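To illustrate the pattern, here is a minimal sketch of a backend-agnostic kernel. It is a generic elementwise example, not the actual kernel for this problem: the names `axpy_kernel!` and `axpy!` are hypothetical, and it assumes KernelAbstractions.jl 0.9 or later, where `get_backend` picks the backend from the array type.

```julia
using KernelAbstractions

# Hypothetical elementwise kernel: y[i] = a * x[i] + y[i].
# @Const marks x as read-only inside the kernel.
@kernel function axpy_kernel!(y, a, @Const(x))
    i = @index(Global)              # global linear index of this work item
    @inbounds y[i] = a * x[i] + y[i]
end

# Hypothetical wrapper: the backend is inferred from the array,
# so the same code path serves CPU and GPU arrays.
function axpy!(y, a, x)
    backend = get_backend(y)        # CPU() for a plain Array
    kernel! = axpy_kernel!(backend) # instantiate the kernel for that backend
    kernel!(y, a, x; ndrange = length(y))
    KernelAbstractions.synchronize(backend)  # wait for the launch to finish
    return y
end

# CPU usage with plain Arrays
x = ones(Float32, 1024)
y = fill(2f0, 1024)
axpy!(y, 3f0, x)                    # every element of y is now 5f0
```

Passing GPU arrays instead (for example `CuArray`s, assuming CUDA.jl is loaded) should dispatch the same kernel to the GPU backend with no change to `axpy!`.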