How to use MPI.jl to scatter a matrix to a 2D processor grid?

For example, I have a 5 by 5 matrix
A = [1 2 3 4 5; 1 2 3 4 5; 1 2 3 4 5; 1 2 3 4 5; 1 2 3 4 5]
and I want to scatter A to four processors arranged in a 2 by 2 grid, [2 3; 4 5].
More precisely, I want to send [1 2 3; 1 2 3; 1 2 3] to processor 2, [4 5; 4 5; 4 5] to processor 3, [1 2 3; 1 2 3] to processor 4, and [4 5; 4 5] to processor 5.
I tried MPI.Scatter and MPI.Scatterv!, but it seems they can only split the matrix along the last dimension (columns).
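To make this concrete, the sub-blocks I want (written as plain Julia indexing, no MPI involved) are:

A[1:3, 1:3]   # -> processor 2: [1 2 3; 1 2 3; 1 2 3]
A[1:3, 4:5]   # -> processor 3: [4 5; 4 5; 4 5]
A[4:5, 1:3]   # -> processor 4: [1 2 3; 1 2 3]
A[4:5, 4:5]   # -> processor 5: [4 5; 4 5]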

Is there any way to do this using MPI.jl?

Yes, have you checked out the example at Scatterv and Gatherv · MPI.jl?

I found it more straightforward to accomplish the same task with a combination of Isend and Irecv! on SubArrays, since the memory layout is 2D.
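As far as I can tell, this works because MPI.Buffer can wrap a non-contiguous SubArray and describe its strides with a derived MPI datatype, so a 2D block can be sent straight from a @view without copying. A tiny sketch (assuming MPI has already been initialized):

A = [j for i in 1:5, j in 1:5]   # same 5x5 matrix as in the question (constant columns 1..5)
block = @view A[4:5, 1:3]        # a strided, non-contiguous 2x3 block
buf = MPI.Buffer(block)          # MPI.jl builds a datatype describing the strides
# `buf` can then be passed to MPI.Isend / MPI.Irecv! like any other buffer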

I wrote a package MPIHaloArrays.jl that does this sort of operation (shameless plug :smile:).

This is a slightly modified snippet from the MPIHaloArrays/src/scattergather.jl file that does the same thing if you don’t want to use the MPIHaloArray type.
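The snippet relies on a few things that are set up elsewhere in the package: comm, rank, nprocs, root, the per-rank index ranges in indices, and the local block size remote_size. For the 5 by 5 example on 4 ranks, something along these lines would do (the index ranges are hard-coded here purely for illustration):

using MPI
MPI.Init()

comm   = MPI.COMM_WORLD
rank   = MPI.Comm_rank(comm)
nprocs = MPI.Comm_size(comm)   # run with 4 ranks for the 2x2 grid
root   = 0

# the global 5x5 array; only root's copy is actually scattered,
# but defining it on every rank keeps the sketch simple
A = [1 2 3 4 5; 1 2 3 4 5; 1 2 3 4 5; 1 2 3 4 5; 1 2 3 4 5]

# (ilo, ihi, jlo, jhi) index ranges owned by each rank: the 3+2 row/column split
indices = [(1, 3, 1, 3),   # rank 0 gets A[1:3, 1:3]
           (1, 3, 4, 5),   # rank 1 gets A[1:3, 4:5]
           (4, 5, 1, 3),   # rank 2 gets A[4:5, 1:3]
           (4, 5, 4, 5)]   # rank 3 gets A[4:5, 4:5]

# size of this rank's local block
bounds = indices[rank + 1]
remote_size = (bounds[2] - bounds[1] + 1, bounds[4] - bounds[3] + 1)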


# A is the array you wish to split up
# A_local is the smaller array that lives on each MPI proc
A_local = zeros(eltype(A), remote_size)

# receive buffer for this rank's chunk of the original array A
remote_buf = MPI.Buffer(A_local)

# vector of MPI send/recv requests
reqs = Vector{MPI.Request}(undef, 0) 

# loop through and send each small chunk to the different MPI procs
if rank == root
    for sendrank in 0:nprocs-1 # the MPI library is 0-based

        # Get the indices on the root buffer to send to the remote buffer
        ilo, ihi, jlo, jhi = indices[sendrank + 1]

        # A is the global array you want to scatter
        data_on_root = @view A[ilo:ihi, jlo:jhi] 

        root_buf = MPI.Buffer(data_on_root)
        sendtag = sendrank + 1000
        sreq = MPI.Isend(root_buf, sendrank, sendtag, comm)
        push!(reqs, sreq)
    end
end

# tag must match the one root used when sending to this rank
recvtag = rank + 1000
rreq = MPI.Irecv!(remote_buf, root, recvtag, comm)
push!(reqs, rreq)
  
MPI.Waitall!(reqs)

# now you have `A_local` on each MPI process that contains a subset of the original 2D array
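If you put the setup above and this snippet together in a file, say scatter2d.jl (the name is just for the example), you can launch it on 4 ranks with MPI.jl's mpiexecjl wrapper:

mpiexecjl -n 4 julia scatter2d.jl

Running the same Isend/Irecv! pattern in the opposite direction (each rank sending its A_local back into the corresponding view of A on root) gives you the matching gather.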