I’m going through the recent paper comparing batched KernelAbstractions kernels to standard array batching in JAX and PyTorch. In Section 5.1.2, I found this interesting tidbit:
> KernelAbstractions.jl performs a limited form of auto-tuning by optimizing the launch parameters for occupancy.
I went back to the docs to see if I could find anything describing this, but came up empty-handed.
Perhaps I’m looking in the wrong place. Does anyone have references that describe this functionality (and how well it works across different hardware platforms)?
Thanks!