MPI.API.MPI_Reduce_scatter ERRORS

Thanks @carstenbauer and @simonbyrne for the help. I figured it out.
The code below works when launched with exactly 3 MPI ranks (one rank per entry of `counts`).

# testReduce_scatter.jl
using MPI, LinearAlgebra
# Minimal example of calling the low-level MPI.API.MPI_Reduce_scatter wrapper:
# every rank contributes the same nrows x sum(counts) matrix V, the matrices are
# summed element-wise across ranks, and rank r receives counts[r+1] columns of
# the result in Y.
MPI.Init()
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)
nranks = MPI.Comm_size(comm)

# Number of result columns each rank receives (rank r gets counts[r+1]).
counts = [2, 2, 1]

# MPI_Reduce_scatter needs one receive count per rank. Without this check, running
# with more ranks fails with a BoundsError on `counts[rank + 1]` on the extra ranks
# only, while the others enter the collective. Every rank evaluates the same
# condition, so all of them stop together.
nranks == length(counts) ||
    error("this example needs exactly $(length(counts)) MPI ranks, got $nranks")

nrows = 3
V = ones(nrows, sum(counts))

# Julia arrays are column-major, so a block of k columns is k * nrows contiguous
# elements; the per-rank element counts are therefore counts .* nrows.
V_vbuf = VBuffer(V, counts .* nrows)

# Receive buffer sized for this rank's share of the columns.
Y = zeros(nrows, counts[rank + 1])
Y_buf = MPI.Buffer(Y)

add = MPI.Op(+, Float64, iscommutative=true)

# The counts and datatype are taken from the buffer wrappers rather than passed as
# plain Julia values; presumably this supplies the C-compatible representations
# the raw API call expects (verify against the MPI.jl version in use).
MPI.API.MPI_Reduce_scatter(V, Y, V_vbuf.counts, Y_buf.datatype, add, comm)

# Print one rank at a time; the barrier after each turn is intended to keep the
# output in rank order.
for i in 0:(nranks - 1)
    if i == rank
        println(rank, Y)
    end
    MPI.Barrier(comm)
end
% mpiexecjl -n 3 julia testReduce_scatter.jl
0[3.0 3.0; 3.0 3.0; 3.0 3.0]
1[3.0 3.0; 3.0 3.0; 3.0 3.0]
2[3.0; 3.0; 3.0;;]
1 Like