I am trying to translate a Python script that scatters and gathers 2D arrays, which can be found here: "Along what axis does mpi4py Scatterv function split a numpy array?" (Stack Overflow).
I have managed to get something working, but it needs some improvements. One problem in particular is how to split the array before I scatter it. For example, consider a 5x4 matrix
that I want to split between two cores. I have code that works when the number of rows is divisible by the communicator size, and you can find it below.
But I would like to know how to generalize this to any split. Other advice on how to improve the code would also be greatly appreciated.
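In the 5x4 case the rows cannot be shared equally, so the per-rank counts have to differ (for example 3 rows on rank 0 and 2 on rank 1). One common way to compute such a split is sketched below in plain Julia. The helper name `balanced_split` is hypothetical, and the convention that the first `rem(N, P)` ranks take the extra rows is an assumption, not something MPI requires.

```julia
# Hypothetical helper: split N rows over P ranks as evenly as possible.
# Assumption: the first rem(N, P) ranks each receive one extra row.
function balanced_split(N::Integer, P::Integer)
    q, r = divrem(N, P)
    # counts[i] = number of rows owned by rank i-1
    counts = [i <= r ? q + 1 : q for i in 1:P]
    # displs[i] = number of rows that precede rank i-1's block (0-based offset)
    displs = [0; cumsum(counts)[1:end-1]]
    return counts, displs
end

counts, displs = balanced_split(5, 2)
println(counts)  # [3, 2]
println(displs)  # [0, 3]
```

With N = 4 and two ranks this gives counts `[2, 2]` and displs `[0, 2]`, the same values `find_split_size` returns in the divisible case. These are row counts; the element counts a `VBuffer` expects would be `counts .* M`.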
using MPI
using Printf
MPI.Init()
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)
comm_size = MPI.Comm_size(comm)
# This assumes that N is divisible by comm_size!!!!
function find_split_size(test, comm_size, comm)
    N, M = size(test)
    counts = zeros(Int64, comm_size)
    displs = zeros(Int64, comm_size)
    counts[:] .= Int64(div(N, comm_size))
    displs[:] = cumsum(append!([0], counts))[1:comm_size]
    return counts, displs
end
function array_split(test, counts, displs, comm_size, comm)
    N, M = size(test)
    split = zeros(Float64, comm_size, counts[1], M)
    for i in 1:comm_size
        split[i, :, :] = test[displs[i]+1:displs[i]+counts[i], :]
    end
    return split
end
N, M = 4, 3
if rank == 0
    test = reshape(Float64.(1:M), 1, M)
    test = repeat(test, N, 1)  # from the Python example
    test = permutedims(reshape((1.:(M*N)), M, N))  # a better test problem
    outputData_T = zeros((M, N))
    @printf("Original matrix: \n")
    @printf("================ \n")
    @printf("test is %s\n\n", test)
    counts, displs = find_split_size(test, comm_size, comm)  # need to generalize
else
    # placeholders on non-root ranks; counts and displs are filled by Bcast! below
    test = zeros(Float64, N, M)
    counts = zeros(Int64, comm_size)
    displs = zeros(Int64, comm_size)
    outputData_T = zeros((M, N))
end
test_T = permutedims(test)
MPI.Bcast!(counts, 0, comm)
MPI.Bcast!(displs, 0, comm)
if rank == 0
    split = array_split(test, counts, displs, comm_size, comm)
    split_sizes_input = counts*M
    split_sizes_output = counts*M
    displacements_input = cumsum(append!([0], split_sizes_input))[1:comm_size]
    displacements_output = cumsum(append!([0], split_sizes_output))[1:comm_size]
    @printf("Scatter information:\n")
    @printf("Input data split into vectors of sizes %s\n", split_sizes_input)
    @printf("Input data split into displacements of sizes %s\n", displacements_input)
    @printf("\nSplit is of size %s\n\n", size(split))
else
    split = zeros(Float64, comm_size, counts[1], M)
    split_sizes_input = zeros(Int64, comm_size)
    split_sizes_output = zeros(Int64, comm_size)
    displacements_input = zeros(Int64, comm_size)
    displacements_output = zeros(Int64, comm_size)
end
MPI.Bcast!(split, 0, comm)
MPI.Bcast!(split_sizes_output, 0, comm)
MPI.Bcast!(displacements_output, 0, comm)
output_chunk_T = permutedims(zeros(size(split[rank+1, :, :])))
MPI.Scatterv!(rank == 0 ? VBuffer(test_T, split_sizes_input, displacements_input) : nothing, output_chunk_T, 0, comm)
output = permutedims(output_chunk_T)  # back to (rows, M) layout on each rank
if rank == 0
    @printf("Scattered chunks:\n")
end
@printf("rank = %s output = %s \n", rank, output)
MPI.Gatherv!(output_chunk_T, rank == 0 ? VBuffer(outputData_T, split_sizes_output, displacements_output) : nothing, 0, comm)
rank == 0 ? @printf("\nGathered outputData = %s\n", permutedims(outputData_T)) : nothing
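The `permutedims` calls matter because Julia stores arrays in column-major order: the rows of an N-by-M matrix are not contiguous in memory, but the columns of its M-by-N transpose are, and `Scatterv!`/`Gatherv!` with a `VBuffer` hand each rank a contiguous run of `counts[i]` elements starting at offset `displs[i]`. The sketch below imitates that slicing on a single process for the uneven 5x4 case on two pretend ranks, without calling MPI at all. The variable names are illustrative, and the balanced-split rule (the first `rem(N, P)` ranks get an extra row) is an assumption.

```julia
# Pure-Julia illustration (no MPI): mimic how Scatterv/Gatherv slice the
# flat, column-major memory of test_T when the row split is uneven.
N, M = 5, 4
P = 2                                           # pretend communicator size
test = permutedims(reshape(1.0:(M*N), M, N))    # 5x4; row i holds 4(i-1)+1 : 4i
test_T = permutedims(test)                      # 4x5; column j is row j of test

# Balanced row split (assumption: first rem(N, P) ranks take one extra row)
q, r = divrem(N, P)
row_counts = [i <= r ? q + 1 : q for i in 1:P]
row_displs = [0; cumsum(row_counts)[1:end-1]]

# Element counts and 0-based offsets, as they would be passed to VBuffer
elem_counts = row_counts .* M
elem_displs = row_displs .* M

flat = vec(test_T)
# "Scatter": rank i-1 receives a contiguous slice, viewed as M x rows
chunks_T = [reshape(flat[elem_displs[i]+1:elem_displs[i]+elem_counts[i]], M, row_counts[i])
            for i in 1:P]
chunks = [permutedims(c) for c in chunks_T]     # rows x M on each rank

# "Gather": concatenating the transposed chunks column-wise rebuilds test_T
gathered = permutedims(hcat(chunks_T...))
println(size.(chunks))     # [(3, 4), (2, 4)]
println(gathered == test)  # true
```

In an MPI program, `elem_counts` and `elem_displs` would presumably take the place of `split_sizes_input` and `displacements_input`, and each rank could allocate its receive buffer as `zeros(M, row_counts[rank+1])` directly, rather than deriving its shape from a broadcast copy of `split`.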