I guess you are assuming that the outermost axis is much longer than `Threads.nthreads()` and the workload per element is more-or-less constant? I suppose that’s reasonable when the user is asking to vectorize the loop. But if you want to parallelize all but the innermost axis, it may be useful to use the `halve` function from SplittablesBase.jl (which can handle `IndexCartesian`-style arrays and `zip` of them). This is how I support something like `ThreadsX.foreach(f, A, A')`. I think this approach is flexible enough to handle “non-rectangular” iterations like the upper-triangular part of a matrix.
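For concreteness, here is a minimal sketch of the `halve`-based divide-and-conquer pattern, assuming only the two-function SplittablesBase API (`halve` and `amount`); `tforeach` and `basesize` are illustrative names, not the actual ThreadsX internals:

```julia
using SplittablesBase: amount, halve
using Base.Threads: @spawn

# Recursively split the iterable until each chunk is small enough,
# then run the sequential base case on each chunk in its own task.
function tforeach(f, xs; basesize = 1024)
    if amount(xs) <= basesize
        foreach(f, xs)           # sequential base case
    else
        left, right = halve(xs)  # splits arrays, zips, etc.
        t = @spawn tforeach(f, right; basesize = basesize)
        tforeach(f, left; basesize = basesize)
        wait(t)
    end
end

# A = rand(4, 4)
# Iterate all pairs from A and A' (conceptually like ThreadsX.foreach(f, A, A')):
# tforeach(((a, b),) -> nothing, zip(A, A'))
```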
It’s interesting, as I thought reduction would be the easiest part since there is no mutation (with mutation, e.g., it sounds hard to chunk the iteration space appropriately to avoid simultaneously writing to the shared region). Though maybe there is something subtle when mixing with vectorization? Naively, I’d imagine it’d be implemented as a transformation to `mapreduce`:
- Separate the loop body into the mapping part (`(m, n) -> x[m] * A[m,n] * y[n]`) and the reduction part (`+`), and generate functions for them.
- Determine the unroll factor and SIMD vector width.
- Feed those functions and parameters to a parallelized and vectorized `mapreduce` (see the sketch after this list).
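Here is a hand-written sketch of what that transformation could produce for `x' * A * y`, with a toy divide-and-conquer reduction standing in for a real parallel `mapreduce` (e.g. `ThreadsX.mapreduce`); all names are illustrative, and `@simd` stands in for an explicitly chosen unroll factor and vector width:

```julia
using Base.Threads: @spawn

# Step 1: the mapping part extracted from the loop body
# (the reduction part is just `+`); `mapf` is an illustrative name.
mapf(x, A, y, m, n) = x[m] * A[m, n] * y[n]

function bilinear(x, A, y, nrange = axes(A, 2); basesize = 64)
    if length(nrange) <= basesize
        # Sequential base case, vectorized with `@simd`; a real
        # implementation would apply the unroll factor and vector
        # width chosen in step 2 here.
        s = zero(eltype(A))
        for n in nrange
            @simd for m in axes(A, 1)
                s += mapf(x, A, y, m, n)
            end
        end
        return s
    else
        # Step 3: divide-and-conquer parallel reduction with `+`.
        half = length(nrange) ÷ 2
        t = @spawn bilinear(x, A, y, nrange[half+1:end]; basesize = basesize)
        return bilinear(x, A, y, nrange[1:half]; basesize = basesize) + fetch(t)
    end
end

# x = rand(100); A = rand(100, 100); y = rand(100)
# bilinear(x, A, y) ≈ x' * A * y
```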