Hello! A version of the following function is a major bottleneck in a project I am working on, and I could use help speeding it up. In all the code below, A is a k × n matrix, l is a length-n vector, and q is a length-n vector of integers indexing columns of A. All values of q are distinct, so q represents a permutation of the columns of A. I only really care about performance when A is several gigabytes.
I want to find the index pair (i, j) that maximizes A[i, q[j]]^2 + (1 - l[i])*(1 + l[j]) over 1 <= i <= k and k+1 <= j <= n. The naive implementation is func1:
```julia
function func1(A, l, q)
    k, n = size(A)
    mi, mj, mf = 0, 0, -Inf
    for i in 1:k
        for j in k+1:n
            nf = A[i, q[j]]^2 + (1 - l[i]) * (1 + l[j])
            if nf > mf
                mi, mj, mf = i, j, nf
            end
        end
    end
    return mi, mj, mf
end
```
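One thing worth noting about func1: Julia stores arrays column-major, so with `i` in the outer loop the inner loop strides across columns of A. For reference, here is the straightforward loop-reordered variant (identical result, just `j` outermost so each column `A[:, q[j]]` is read contiguously; the name `func1_colmajor` is mine):

```julia
# Same reduction as func1, but with j outermost so each fixed
# column A[:, q[j]] is scanned contiguously (Julia is column-major).
function func1_colmajor(A, l, q)
    k, n = size(A)
    mi, mj, mf = 0, 0, -Inf
    @inbounds for j in k+1:n
        c = q[j]
        s = 1 + l[j]          # constant for the whole column
        for i in 1:k
            nf = A[i, c]^2 + (1 - l[i]) * s
            if nf > mf
                mi, mj, mf = i, j, nf
            end
        end
    end
    return mi, mj, mf
end
```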
I played around for a while with Base.Threads, LoopVectorization.jl, OhMyThreads.jl, etc., but the fastest version I found simply uses mapreduce from ThreadsX.jl, as follows.
```julia
using ThreadsX

function func2(A, l, q)
    k, n = size(A)
    pairs = Iterators.product(1:k, (k+1):n)
    red = (x, y) -> x[3] > y[3] ? x : y
    mi, mj, mf = ThreadsX.mapreduce(
        ((i, j),) -> begin
            nf = @inbounds A[i, q[j]]^2 + (1 - l[i]) * (1 + l[j])
            (i, j, nf)
        end,
        red,
        pairs;
        init = (0, 0, -Inf),
    )
    return mi, mj, mf
end
```
This works well: on my mediocre laptop, with A of size 5000 × 50000, func1 takes about 2 s while func2 takes about 0.045 s. However, since the rest of my code relies so heavily on BLAS, func2 still takes around 30× longer than everything else, even though in theory the other part dominates the time complexity.
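To make the question concrete, here is the kind of restructuring I have in mind: for fixed j, (1 + l[j]) is a scalar s, so the inner scan reduces to max over i of A[i, c]^2 + s*(1 - l[i]), a tight per-column reduction over contiguous memory, with columns split across threads. A rough, untuned sketch (func3 is my naming; the `:static` schedule is just one way to keep the per-thread accumulators valid):

```julia
# Per-column reduction: for fixed j, (1 + l[j]) is a scalar s, so the
# inner scan is max_i A[i, c]^2 + s * b[i] with b = 1 .- l.  Splitting
# columns across threads keeps each thread on contiguous memory.
function func3(A, l, q)
    k, n = size(A)
    b = 1 .- l                            # hoisted (1 - l[i]) terms
    best = fill((0, 0, -Inf), Threads.nthreads())
    Threads.@threads :static for j in k+1:n
        t = Threads.threadid()            # stable under :static scheduling
        c = q[j]
        s = 1 + l[j]
        mi, mf = 0, -Inf
        @inbounds for i in 1:k            # contiguous scan of A[:, c]
            nf = A[i, c]^2 + s * b[i]
            if nf > mf
                mi, mf = i, nf
            end
        end
        if mf > best[t][3]
            best[t] = (mi, j, mf)
        end
    end
    return reduce((x, y) -> x[3] > y[3] ? x : y, best)
end
```

I have only checked this for correctness on small inputs, not tuned it at scale.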
I feel like I don't know enough about SIMD, prefetching, cache behavior, etc. to get this running faster and would appreciate y'all's help.