Multithreading non-contiguous mapreduction over huge matrices

Hello! A version of the following function is a major bottleneck in a project I am working on, and I could use help speeding it up. In all the code below, A is a k x n matrix, l is a length-n vector, and q is a length-n vector of integers indexing columns of A. All values of q are distinct, so q represents a permutation of the columns of A. I only really care about performance when A is several gigabytes.

I want to find the index (i, j) that maximizes A[i, q[j]]^2 + (1 - l[i])*(1 + l[j]) for 1 <= i <= k and k+1 <= j <= n. The naive implementation is func1.

function func1(A, l, q)
    k, n = size(A)
    mi, mj, mf = 0, 0, -Inf
    for i in 1:k
        for j in k+1:n
            nf = A[i, q[j]]^2 + (1 - l[i])*(1 + l[j])
            if nf > mf
                mi, mj, mf = i, j, nf
            end
        end
    end
    return mi, mj, mf
end

I played around for a while with Base.Threads, LoopVectorization.jl, OhMyThreads.jl, etc., but the fastest I could get was simply using mapreduce from ThreadsX.jl, as follows.

using ThreadsX
using Base.Iterators: product

function func2(A, l, q)
    k, n = size(A)

    pairs = product(1:k, (k+1):n)

    red = (x, y) -> @inbounds x[3] > y[3] ? x : y
    mi, mj, mf = ThreadsX.mapreduce(
        ((i, j),) -> begin
            nf = @inbounds A[i, q[j]]^2 + (1 - l[i])*(1 + l[j])
            (i, j, nf)
        end,
        red,
        pairs;
        init = (0, 0, -Inf)
    )

    return mi, mj, mf
end

This works well. On my mediocre laptop, when A is 5000 x 50000, func1 takes about 2 s while func2 takes 0.045 s. However, since the rest of my code relies so heavily on BLAS, func2 still takes around 30x longer than everything else, even though in theory the other parts dominate the time complexity.
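For scale (rough arithmetic, not a measurement), the loop touches most of A:

sizeof(Float64) * 5000 * 50_000 / 1e9            # ≈ 2.0 GB for all of A
sizeof(Float64) * 5000 * (50_000 - 5_000) / 1e9  # ≈ 1.8 GB for the columns with j in k+1:n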

I feel like I don’t know enough about SIMD, prefetching, etc. to get this running faster and would appreciate y’all’s help.

I can’t test it out now, but I think I would have inverted the q permutation into an invq, and done the computation like

     nf = @inbounds A[i, j]^2 + (1 - l[i])*(1 + l[invq[j]])

with the inner loop on i. This will give better data locality, i.e. the matrix A will be looped through in memory order, which is usually a good thing.
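Roughly like this (untested sketch; invperm is from Base, while the function name and the skip of columns whose j is <= k are just my additions):

function func_memory_order(A, l, q)
    k, n = size(A)
    invq = invperm(q)              # invq[c] is the j with q[j] == c
    mi, mj, mf = 0, 0, -Inf
    for c in 1:n                   # walk the columns of A in memory order
        j = invq[c]
        j <= k && continue         # only j in k+1:n are candidates
        ljp = 1 + l[j]
        @inbounds for i in 1:k     # inner loop runs down a contiguous column
            nf = A[i, c]^2 + (1 - l[i]) * ljp
            if nf > mf
                mi, mj, mf = i, j, nf
            end
        end
    end
    return mi, mj, mf
end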

I think this would help if I were trying to compute nf for all j, but that is not the case. I only want k+1 <= j <= n, so the portion of A that I care about is A[:, q[k+1:n]].

Separately, in case it helps, I think I am able to work with A' instead of A. Then q will instead index rows of A', which will be contiguous.

Yes, after some tests I don’t think you can beat the parallel version of ThreadsX (though you could also try OhMyThreads.jl). Here are some tests — the func2 below is a serial @simd version, which may be the best choice if you want to leave your threads to BLAS.

function func1(A, l, q)
    k, n = size(A)
    mi, mj, mf = 0, 0, -Inf
    for i in 1:k
        for j in k+1:n
            nf = A[i, q[j]]^2 + (1 - l[i])*(1 + l[j])
            if nf > mf
                mi, mj, mf = i, j, nf
            end
        end
    end
    return mi, mj, mf
end
@fastmath function func2(A, l, q)
    k, n = size(A)
    mi,mj,mf = 0,0,-Inf
    @inbounds @simd for j in 1:n-k
        jp = j+k
        qj = q[jp]
        ljp = 1+l[jp]
        @inbounds @simd for i in 1:k
            li = l[i]
            A2 = A[i,qj]
            nf = A2*A2 + (1 - li)*ljp
            if nf > mf
                mi, mj, mf = i, jp, nf
            end
        end
    end
    return mi, mj, mf
end

@fastmath function func3(A, l, q)
    k, n = size(A)
    part = 16
    it = Iterators.partition(1:n-k,cld(n-k,part))
    cache = Vector{Tuple{Int,Int,Float64}}(undef, length(it))  # one slot per chunk (length(it) may be < part)
    @sync for (id,J) in enumerate(it)
        Threads.@spawn begin 
            miloc,mjloc,mfloc = 0,0,-Inf
            @inbounds @simd for j in J
                jp = j+k
                qj = q[jp]
                ljp = 1+l[jp]
                @inbounds @simd for i in 1:k
                    li = l[i]
                    A2 = A[i,qj]
                    nf = A2*A2 + (1 - li)*ljp
                    if nf > mfloc
                        miloc, mjloc, mfloc = i, jp, nf
                    end
                end
            end
            cache[id] = (miloc,mjloc,mfloc)
        end
    end
    mi,mj,mf = 0,0,-Inf
    for i in eachindex(cache)
        mii,mji,nfi = cache[i]
        if nfi > mf
            mi,mj,mf = mii,mji,nfi
        end
    end
    return mi, mj, mf
end



A = rand(5000,50_000);
l = rand(50_000);
using Random
q = randperm(50_000);  # a true permutation, matching the problem statement

using BenchmarkTools
@btime func1($A,$l,$q)
@btime func2($A,$l,$q)
@btime func3($A,$l,$q)
  1.578 s (0 allocations: 0 bytes)
  231.710 ms (0 allocations: 0 bytes)
  85.590 ms (108 allocations: 12.00 KiB)

While func3 is close, the ThreadsX version is still a hair faster. What does ThreadsX do so well? Might a smart C implementation be any better, or do you think ThreadsX is achieving essentially the best performance I can hope for?

There are two ways to perform the reduction here: I do a per-task reduction, but maybe a binary (tree) reduction would be better, and ThreadsX may use that. You could also write a C or Rust version just to know the best you should be able to get. But maybe this really is the biggest operation of your function, even if you don’t realise it. The question is whether you want this part serial or not, i.e. do you prefer keeping your threads at max power for BLAS, or using them here? The fact that the ThreadsX code made you lose performance overall suggests the former, but I’m not sure.
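For reference, here is a rough sketch of the binary-reduction idea (untested; the 2048-column cutoff and the names tree_reduce / func4 are just placeholders):

function tree_reduce(A, l, q, k, jlo, jhi)
    if jhi - jlo < 2048                    # small enough chunk: reduce serially
        mi, mj, mf = 0, 0, -Inf
        @inbounds for j in jlo:jhi
            qj = q[j]
            ljp = 1 + l[j]
            for i in 1:k
                nf = A[i, qj]^2 + (1 - l[i]) * ljp
                if nf > mf
                    mi, mj, mf = i, j, nf
                end
            end
        end
        return mi, mj, mf
    end
    mid = (jlo + jhi) >> 1                          # split the j-range in half
    left_task = Threads.@spawn tree_reduce(A, l, q, k, jlo, mid)
    right = tree_reduce(A, l, q, k, mid + 1, jhi)   # reduce the other half on this task
    left = fetch(left_task)
    return left[3] > right[3] ? left : right        # keep the larger objective value
end

func4(A, l, q) = tree_reduce(A, l, q, size(A, 1), size(A, 1) + 1, size(A, 2))

To see how your threads are currently split between the two, you can compare Threads.nthreads() with LinearAlgebra.BLAS.get_num_threads().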