Scattering and gathering arrays

I am trying to translate a Python script that scatters and gathers 2D arrays, as can be found here: python - Along what axis does mpi4py Scatterv function split a numpy array? - Stack Overflow

I have managed to get something working, but it needs some improvements. One problem in particular is how to split the array before I scatter it. For example, suppose the matrix is 5x4 and I want to split it between two cores. The code below works only when the communicator size evenly divides the number of rows.

But I would like to know how to generalize this to any kind of split. Other advice on how to improve the code would also be greatly appreciated.
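One way to generalize `find_split_size` is a minimal sketch like the following, assuming rows are dealt out as evenly as possible. The helper name `balanced_split` is hypothetical. Rank `i` (0-based) gets rows `div(N*i, p)+1` through `div(N*(i+1), p)`, so counts differ by at most one and always sum to `N`, and each displacement is just the running total of the counts before it.

```julia
# Hypothetical replacement for find_split_size that works for any N and number
# of ranks. Rank i (0-based) owns rows div(N*i, nparts)+1 : div(N*(i+1), nparts).
function balanced_split(N::Integer, nparts::Integer)
    # rows per rank: differ by at most one, sum to N (some may be 0 if N < nparts)
    counts = [div(N * (i + 1), nparts) - div(N * i, nparts) for i in 0:nparts-1]
    # 0-based offset of each rank's first row (telescoping sum of earlier counts)
    displs = [div(N * i, nparts) for i in 0:nparts-1]
    return counts, displs
end

counts, displs = balanced_split(5, 2)   # the 5-row, 2-core case
println(counts, " ", displs)
```

On rank 0 this could be called as `balanced_split(size(test, 1), comm_size)`. Note that `array_split` and the non-root allocation of `split` also assume every rank has `counts[1]` rows, so they would need to change as well once the counts are uneven.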

using MPI
using Printf

MPI.Init()

comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)
comm_size = MPI.Comm_size(comm)

# This assumes that N is divisible by comm_size!!!!
function find_split_size(test, comm_size, comm)

    N, M = size(test)

    counts = zeros(Int64, comm_size)
    displs = zeros(Int64, comm_size)

    counts[:] .= Int64(div(N, comm_size))
    displs[:]  = cumsum(append!([0], counts))[1:comm_size]

    return counts, displs
end

function array_split(test, counts, displs, comm_size, comm)

    N, M = size(test)

    split = zeros(Float64, comm_size, counts[1], M)
    for i in (1:comm_size)        
        split[i, :, :] = test[displs[i]+1:displs[i]+counts[i],:]
    end

    return split
end

N, M = 4, 3

if rank == 0
    # test = repeat(reshape(Float64.(1:M), 1, M), N, 1)        # test problem from the Python example
    test         = permutedims(reshape((1.:(M*N)), M, N))      # a better test problem
    outputData_T = zeros((M, N))

    @printf("Original matrix: \n")
    @printf("================ \n")
    @printf("test is %s\n\n", test)

    counts, displs = find_split_size(test, comm_size, comm)      # need to generalize
else
    test         = zeros(Float64, N, M)
    counts       = zeros(Int64, comm_size)
    displs       = zeros(Int64, comm_size)
    outputData_T = zeros((M, N))
end

test_T = permutedims(test)

MPI.Bcast!(counts, 0, comm)
MPI.Bcast!(displs, 0, comm)

if rank == 0
    split = array_split(test, counts, displs, comm_size, comm) 

    split_sizes_input  = counts*M
    split_sizes_output = counts*M

    displacements_input  = cumsum(append!([0], split_sizes_input))[1:comm_size]
    displacements_output = cumsum(append!([0], split_sizes_output))[1:comm_size]
    
    @printf("Scatter information:\n")
    @printf("====================\n")
    @printf("Input data split into vectors of sizes %s\n",       split_sizes_input)
    @printf("Input data split into displacements of sizes %s\n", displacements_input)
    @printf("\nSplit is of size %s\n\n", size(split))
else
    split = zeros(Float64, comm_size, counts[1], M)

    split_sizes_input    = zeros(Int64, comm_size)
    split_sizes_output   = zeros(Int64, comm_size)

    displacements_input  = zeros(Int64, comm_size)
    displacements_output = zeros(Int64, comm_size)
end

MPI.Bcast!(split, 0, comm)
MPI.Bcast!(split_sizes_output, 0, comm)
MPI.Bcast!(displacements_output, 0, comm)

# each rank allocates its own receive buffer: M x (number of rows it receives)
output_chunk_T = zeros(Float64, M, counts[rank+1])

MPI.Scatterv!(rank == 0 ? VBuffer(test_T, split_sizes_input, displacements_input) : nothing, output_chunk_T, 0, comm)

output = permutedims(output_chunk_T)
if rank == 0
    @printf("Scattered chunks:\n")
    @printf("=================\n")
end
@printf("rank = %s  output = %s \n", rank, output)

MPI.Barrier(comm)

MPI.Gatherv!(output_chunk_T, rank == 0 ? VBuffer(outputData_T, split_sizes_output, displacements_output) : nothing, 0, comm)
rank == 0 && @printf("\nGathered outputData = %s\n", permutedims(outputData_T))

MPI.Finalize()
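The reason the script scatters `test_T` rather than `test` is that Julia arrays are column-major: the rows of `test` are not contiguous in memory, but the columns of `permutedims(test)` are. The serial sketch below (no MPI calls) mimics the slicing that `Scatterv!` performs on a contiguous buffer, assuming an uneven split of 5 rows as `[3, 2]` over two ranks; `row_counts` and the other names are illustrative only. It shows that the element counts and displacements are simply the row counts and row offsets multiplied by `M`, so uneven splits need no special handling there.

```julia
# Serial sketch (no MPI) of scattering row blocks through a transposed buffer.
N, M = 5, 3
test   = permutedims(reshape(1.0:(M*N), M, N))   # N x M, rows are 1 2 3; 4 5 6; ...
test_T = permutedims(test)                        # M x N, each column is a row of test

row_counts = [3, 2]                               # assumed uneven split over 2 ranks
row_displs = cumsum([0; row_counts[1:end-1]])     # 0-based first row of each block

# element counts/displacements that would describe test_T in a VBuffer
elem_counts = row_counts .* M
elem_displs = row_displs .* M

for r in eachindex(row_counts)
    # contiguous segment of the flattened buffer that rank r-1 would receive
    seg   = vec(test_T)[elem_displs[r]+1 : elem_displs[r]+elem_counts[r]]
    chunk = permutedims(reshape(seg, M, row_counts[r]))
    rows  = row_displs[r]+1 : row_displs[r]+row_counts[r]
    @assert chunk == test[rows, :]                # the segment is exactly a row block
    println("rank ", r - 1, " gets rows ", rows)
end
```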

It has been a while since I last touched MPI, but I am somewhat surprised that the MPI.jl package doesn't provide higher-level constructs to handle the splits for you. If I remember correctly, when I wrote the C++ Boost.MPI Scatterv and Gatherv operations, we handled uneven splits for the user. I wonder whether the MPI.jl developers have this feature in mind. I am also surprised that the code relies on low-level buffer types like VBuffer, Buffer, UBuffer, … I thought the MPI.jl package had an abstraction layer to scatter/gather Julia arrays directly.

Thanks @juliohm for sharing your thoughts. I don't believe that MPI.jl has this feature yet. I, for one, think it would be a nice addition, as it would make this much easier for non-experts like me. I would be happy to help put something together once I figure out how to do this.

Is this the software you are referring to? Function scatterv - 1.59.0

NumPy has array_split, which does the splitting for you and is very convenient; it returns the pieces as a list. I tried using a vector of arrays or an ntuple of arrays, since that seemed similar, but I don't think these objects are accepted because the isbitstype check fails.
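A rough Julia analogue of `numpy.array_split` along the first dimension might look like this sketch. The name `array_split_rows` is hypothetical, not part of Base or MPI.jl. It follows NumPy's rule that, for `N` rows split into `p` sections, the first `N % p` sections get one extra row, and it returns views so nothing is copied.

```julia
# Hypothetical helper mimicking numpy.array_split along the first dimension:
# N rows into p sections, where the first N % p sections get one extra row.
function array_split_rows(A::AbstractMatrix, p::Integer)
    q, r = divrem(size(A, 1), p)
    counts = [q + (i <= r) for i in 1:p]        # rows in each section
    starts = cumsum([1; counts[1:end-1]])       # first row of each section
    # views avoid copying; each block is a SubArray of A
    return [view(A, s:s+n-1, :) for (s, n) in zip(starts, counts)]
end

A = [10i + j for i in 1:5, j in 1:4]            # the 5x4 case from the question
blocks = array_split_rows(A, 2)
println(size.(blocks))                           # block sizes: 3x4 and 2x4
@assert vcat(blocks...) == A                     # the blocks reassemble A
```

Such a vector of views is handy on the root for inspection, but, as far as I understand, it cannot be passed to `Scatterv!` as a send buffer: MPI needs one contiguous buffer of isbits elements, so the split has to be described by counts over a single array instead.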

Exactly, boost::mpi::scatterv and boost::mpi::gatherv accept a C++ std::vector as input, and you don't need to deal with buffers manually. All you need to do is provide the size of each chunk.

Regarding the splitting of arrays in Julia, you can use Iterators.partition:

julia> collect(Iterators.partition([1,2,3,4,5], 2))
3-element Vector{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}:
 [1, 2]
 [3, 4]
 [5]
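One caveat when using `Iterators.partition` to derive per-rank counts: it fixes the chunk size, not the number of chunks. With the natural choice `cld(N, p)` as the chunk size, the sketch below shows that 5 rows over 4 ranks yields only 3 chunks, so one rank would receive nothing while the others are unevenly loaded.

```julia
# Iterators.partition fixes the chunk *size*; the number of chunks follows from it.
N = 5
for p in (2, 4)
    chunks = collect(Iterators.partition(1:N, cld(N, p)))
    # length.(chunks) would be the per-rank counts, but only if length(chunks) == p
    println("p = ", p, ": ", length(chunks), " chunks with lengths ", length.(chunks))
end
```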

My concern was more about the low-level API that MPI.jl is offering.

The implementation in Boost has the typical C++ verbosity, but you could try to follow it and check whether the MPI.jl developers plan to add similar support for native Julia Vectors: mpi/scatterv.hpp at develop · boostorg/mpi · GitHub

Thanks to this discussion, an example has been created that shows how to scatter and gather a 2D array; it can be found here.

We still need VBuffer, but it only takes the data and a counts array as inputs. The problem I had before, trying to construct a single array out of pieces of different sizes, is avoided by simply having each rank's scattered piece be the output buffer.
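A minimal serial sketch of why a counts array can be enough, assuming the matrix is split along its columns (the last dimension): in column-major storage each block of consecutive columns is already a contiguous segment of `vec(A)`, so no transpose is needed, and each rank can allocate its own `M x ncols` receive buffer. My understanding is that `VBuffer(data, counts)` without displacements treats the blocks as packed back to back, which corresponds to the exclusive prefix sum `displs` computed below; the sizes `M, N, p = 4, 7, 3` are arbitrary.

```julia
# Serial sketch (no MPI): split an M x N matrix by columns over p ranks.
M, N, p = 4, 7, 3
A = Float64[10i + j for i in 1:M, j in 1:N]

ncols  = [div(N * (k + 1), p) - div(N * k, p) for k in 0:p-1]  # columns per rank
counts = M .* ncols                                           # elements per rank
displs = cumsum([0; counts[1:end-1]])                         # blocks packed back to back

for k in 1:p
    # what rank k-1 would receive into a locally allocated M x ncols[k] buffer
    local_block = reshape(vec(A)[displs[k]+1 : displs[k]+counts[k]], M, ncols[k])
    first_col = div(displs[k], M) + 1
    @assert local_block == A[:, first_col : first_col+ncols[k]-1]
end
println("counts = ", counts, ", displs = ", displs)
```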

This code makes a lot of sense to me and I can certainly use it moving forward, but if others have suggestions, please let me know.