I am trying to translate a Python script that scatters and gathers 2D arrays, which can be found here: python - Along what axis does mpi4py Scatterv function split a numpy array? - Stack Overflow

I have managed to get something working, but it needs some improvements. One problem in particular is how to split the array before I scatter it. Consider the example where the matrix is, say, `5x4`

and I want to split it between two cores. I have code that works when the communicator size evenly divides the number of rows; you can find it below.

But I would like to know how to generalize this to any kind of splitting. Any other advice on how to improve the code would also be greatly appreciated.

```
using MPI
using Printf
# Start MPI before any other MPI call is made.
MPI.Init()
# Communicator used for every collective call in this script.
comm = MPI.COMM_WORLD
# Zero-based id of this process; rank 0 acts as the root further down.
rank = MPI.Comm_rank(comm)
# Number of processes in `comm`; the rows of the matrix are divided among them.
comm_size = MPI.Comm_size(comm)
"""
    find_split_size(test, comm_size, comm)

Decide how many rows of `test` each of the `comm_size` ranks receives.

The first dimension of `test` is divided as evenly as possible: every rank gets
`div(N, comm_size)` rows and the first `rem(N, comm_size)` ranks get one extra row,
so `N` does not have to be divisible by `comm_size`. When it is divisible, all
counts are equal.

# Arguments
- `test`: array whose first dimension is distributed.
- `comm_size`: number of ranks to split across.
- `comm`: accepted for interface compatibility; not used.

# Returns
A tuple `(counts, displs)` of `Vector{Int64}`, each of length `comm_size`:
`counts[i]` is the number of rows for rank `i - 1`, and `displs[i]` is the
zero-based index of that rank's first row.
"""
function find_split_size(test, comm_size, comm)
    N = size(test, 1)
    base, extra = divrem(N, comm_size)
    # The lowest `extra` ranks absorb the remainder, one row each.
    counts = Int64[base + (i <= extra ? 1 : 0) for i in 1:comm_size]
    # Exclusive prefix sum: the offset of a rank is the total of all earlier counts.
    displs = cumsum(counts) .- counts
    return counts, displs
end
"""
    array_split(test, counts, displs, comm_size, comm)

Copy the row blocks of the matrix `test` into a 3-D array with one slab per rank.

Slab `i` holds rows `displs[i]+1 : displs[i]+counts[i]` of `test`. The slabs are
sized for the largest block, so the counts may differ between ranks; rows of a
slab beyond `counts[i]` stay zero.

# Arguments
- `test`: matrix to split along its first dimension.
- `counts`: number of rows per rank.
- `displs`: zero-based index of the first row of each rank's block.
- `comm_size`: number of ranks (slabs).
- `comm`: accepted for interface compatibility; not used.

# Returns
A `Float64` array of size `(comm_size, maximum(counts), size(test, 2))`.
"""
function array_split(test, counts, displs, comm_size, comm)
    M = size(test, 2)
    # Size every slab for the largest block so uneven splits fit.
    maxrows = isempty(counts) ? 0 : maximum(counts)
    split = zeros(Float64, comm_size, maxrows, M)
    for i in 1:comm_size
        rows = (displs[i] + 1):(displs[i] + counts[i])
        # Fill only the rows this rank owns; any padding rows remain zero.
        split[i, 1:counts[i], :] .= view(test, rows, :)
    end
    return split
end
# Global problem size: the N rows are distributed across ranks, each row has M columns.
N, M = 4, 3

if rank == 0
    # Test matrix whose i-th row is (i-1)*M+1 : i*M, so each row is easy to recognise.
    test = permutedims(reshape((1.:(M*N)), M, N))
    @printf("Original matrix: \n")
    @printf("================ \n")
    @printf("test is %s\n\n", test)
    counts, displs = find_split_size(test, comm_size, comm)
else
    # Placeholders of the right shape; counts/displs are filled by the broadcasts below.
    test = zeros(Float64, N, M)
    counts = zeros(Int64, comm_size)
    displs = zeros(Int64, comm_size)
end

# Julia arrays are column-major, so the data is scattered/gathered in transposed form:
# each row of `test` is one contiguous column of `test_T`.
outputData_T = zeros((M, N))
test_T = permutedims(test)

MPI.Bcast!(counts, 0, comm)
MPI.Bcast!(displs, 0, comm)

# Element counts and offsets follow from the row counts (M elements per row), so every
# rank derives them locally instead of receiving them through extra broadcasts.
split_sizes_input = counts .* M
split_sizes_output = counts .* M
displacements_input = cumsum(split_sizes_input) .- split_sizes_input
displacements_output = cumsum(split_sizes_output) .- split_sizes_output

if rank == 0
    # `split` is only used for this diagnostic output; the scatter itself works directly
    # on `test_T`, so the full array is not broadcast to the other ranks.
    split = array_split(test, counts, displs, comm_size, comm)
    @printf("Scatter information:\n")
    @printf("====================\n")
    @printf("Input data split into vectors of sizes %s\n", split_sizes_input)
    @printf("Input data split into displacements of sizes %s\n", displacements_input)
    @printf("\nSplit is of size %s\n\n", size(split))
end

# Receive buffer sized from this rank's own row count, so uneven splits are supported.
output_chunk_T = zeros(Float64, M, counts[rank + 1])
MPI.Scatterv!(rank == 0 ? VBuffer(test_T, split_sizes_input, displacements_input) : nothing, output_chunk_T, 0, comm)
output = permutedims(output_chunk_T)

if rank == 0
    @printf("Gathered array:\n")
    @printf("===============\n")
end
@printf("rank = %s output = %s \n", rank, output)
MPI.Barrier(comm)

MPI.Gatherv!(output_chunk_T, rank == 0 ? VBuffer(outputData_T, split_sizes_output, displacements_output) : nothing, 0, comm)
if rank == 0
    @printf("\nGathered outputData = %s\n", permutedims(outputData_T))
end
MPI.Finalize()
```