1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 
8 /* Algorithm: Noncommutative recursive doubling
9  *
10  * Restrictions: This function currently only implements support for the
11  * power-of-2, block-regular case (all receive counts are equal).
12  *
13  * Implements the reduce-scatter butterfly algorithm described in J. L. Traff's
14  * "An Improved Algorithm for (Non-commutative) Reduce-Scatter with an
15  * Application" from EuroPVM/MPI 2005.
16  *
17  * It takes lgp steps. At step 1, processes exchange (n-n/p) amount of
18  * data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
19  * amount of data, and so forth.
20  *
21  * Cost = lgp.alpha + n.(lgp-(p-1)/p).beta + n.(lgp-(p-1)/p).gamma
22  */
/*
 * Reduce-scatter for a non-commutative operation using pairwise exchanges
 * with peer = rank XOR 2^k for k = 0 .. log2(comm_size) - 1.
 *
 * Parameters:
 *   sendbuf    - input data; if equal to MPI_IN_PLACE the input is read
 *                from recvbuf instead.
 *   recvbuf    - destination for this rank's reduced block (and the input
 *                buffer in the MPI_IN_PLACE case).
 *   recvcounts - per-rank element counts. Only recvcounts[0] is used as the
 *                block size; when HAVE_ERROR_CHECKING is defined all entries
 *                are asserted to be equal.
 *   datatype   - datatype of the elements.
 *   op         - reduction operation handed to MPIR_Reduce_local.
 *   comm_ptr   - communicator; its local_size must be a power of two
 *                (asserted when HAVE_ERROR_CHECKING is defined).
 *   errflag    - in/out error flag; updated when MPIC_Sendrecv reports an
 *                error.
 *
 * Returns: the accumulated communication error (mpi_errno_ret) if one was
 * recorded; otherwise mpi_errno, which is set to a "**coll_fail" error when
 * *errflag is not MPIR_ERR_NONE on exit.
 */
int MPIR_Reduce_scatter_intra_noncommutative(const void *sendbuf, void *recvbuf,
                                             const int recvcounts[], MPI_Datatype datatype,
                                             MPI_Op op, MPIR_Comm * comm_ptr,
                                             MPIR_Errflag_t * errflag)
{
    int mpi_errno = MPI_SUCCESS;
    int mpi_errno_ret = MPI_SUCCESS;
    int comm_size = comm_ptr->local_size;
    int rank = comm_ptr->rank;
    int pof2;
    int log2_comm_size;
    int i, k;
    int recv_offset, send_offset;
    int block_size, total_count, size;
    MPI_Aint true_extent, true_lb;
    int buf0_was_inout;
    void *tmp_buf0;
    void *tmp_buf1;
    void *result_ptr;
    MPIR_CHKLMEM_DECL(3);

    MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);

    /* smallest power of two >= comm_size, and its base-2 logarithm */
    pof2 = 1;
    log2_comm_size = 0;
    while (pof2 < comm_size) {
        pof2 <<= 1;
        ++log2_comm_size;
    }

#ifdef HAVE_ERROR_CHECKING
    /* begin error checking */
    MPIR_Assert(pof2 == comm_size);     /* FIXME this version only works for power of 2 procs */

    for (i = 0; i < (comm_size - 1); ++i) {
        MPIR_Assert(recvcounts[i] == recvcounts[i + 1]);
    }
    /* end error checking */
#endif

    /* size of a block (count of datatype per block, NOT bytes per block) */
    block_size = recvcounts[0];
    total_count = block_size * comm_size;

    MPIR_CHKLMEM_MALLOC(tmp_buf0, void *, true_extent * total_count, mpi_errno, "tmp_buf0",
                        MPL_MEM_BUFFER);
    MPIR_CHKLMEM_MALLOC(tmp_buf1, void *, true_extent * total_count, mpi_errno, "tmp_buf1",
                        MPL_MEM_BUFFER);
    /* adjust for potential negative lower bound in datatype */
    tmp_buf0 = (void *) ((char *) tmp_buf0 - true_lb);
    tmp_buf1 = (void *) ((char *) tmp_buf1 - true_lb);

    /* Copy our send data to tmp_buf0.  We do this one block at a time and
     * permute the blocks as we go according to the mirror permutation. */
    for (i = 0; i < comm_size; ++i) {
        mpi_errno =
            MPIR_Localcopy((char *) (sendbuf ==
                                     MPI_IN_PLACE ? (const void *) recvbuf : sendbuf) +
                           (i * true_extent * block_size), block_size, datatype,
                           (char *) tmp_buf0 +
                           (MPL_mirror_permutation(i, log2_comm_size) * true_extent * block_size),
                           block_size, datatype);
        MPIR_ERR_CHECK(mpi_errno);
    }
    /* tmp_buf0 currently holds our working (in/out) data */
    buf0_was_inout = 1;

    send_offset = 0;
    recv_offset = 0;
    size = total_count;
    for (k = 0; k < log2_comm_size; ++k) {
        /* use a double-buffering scheme to avoid local copies */
        char *incoming_data = (buf0_was_inout ? tmp_buf1 : tmp_buf0);
        char *outgoing_data = (buf0_was_inout ? tmp_buf0 : tmp_buf1);
        int peer = rank ^ (0x1 << k);
        /* the active region is halved in every round */
        size /= 2;

        if (rank > peer) {
            /* we have the higher rank: send top half, recv bottom half */
            recv_offset += size;
        } else {
            /* we have the lower rank: recv top half, send bottom half */
            send_offset += size;
        }

        mpi_errno = MPIC_Sendrecv(outgoing_data + send_offset * true_extent,
                                  size, datatype, peer, MPIR_REDUCE_SCATTER_TAG,
                                  incoming_data + recv_offset * true_extent,
                                  size, datatype, peer, MPIR_REDUCE_SCATTER_TAG,
                                  comm_ptr, MPI_STATUS_IGNORE, errflag);
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
            *errflag =
                MPIX_ERR_PROC_FAILED ==
                MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
            MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
            MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
        }
        /* always perform the reduction at recv_offset, the data at send_offset
         * is now our peer's responsibility */
        if (rank > peer) {
            /* higher ranked value so need to call op(received_data, my_data) */
            mpi_errno = MPIR_Reduce_local(incoming_data + recv_offset * true_extent,
                                          outgoing_data + recv_offset * true_extent,
                                          size, datatype, op);
            MPIR_ERR_CHECK(mpi_errno);
        } else {
            /* lower ranked value so need to call op(my_data, received_data) */
            /* The result must be stored in mpi_errno: otherwise the check
             * below would test the stale MPIC_Sendrecv status, aborting on an
             * already-recorded communication error and ignoring a failure of
             * the reduction itself. */
            mpi_errno = MPIR_Reduce_local(outgoing_data + recv_offset * true_extent,
                                          incoming_data + recv_offset * true_extent,
                                          size, datatype, op);
            MPIR_ERR_CHECK(mpi_errno);
            /* the reduced result now lives in the other buffer, so swap the
             * roles of the two temporaries for the next round */
            buf0_was_inout = !buf0_was_inout;
        }

        /* the next round of send/recv needs to happen within the block (of size
         * "size") that we just received and reduced */
        send_offset = recv_offset;
    }

    MPIR_Assert(size == recvcounts[rank]);

    /* copy the reduced data to the recvbuf */
    result_ptr = (char *) (buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
    mpi_errno = MPIR_Localcopy(result_ptr, size, datatype, recvbuf, size, datatype);
    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    MPIR_CHKLMEM_FREEALL();
    /* a recorded communication error takes precedence over mpi_errno */
    if (mpi_errno_ret)
        mpi_errno = mpi_errno_ret;
    else if (*errflag != MPIR_ERR_NONE)
        MPIR_ERR_SET(mpi_errno, *errflag, "**coll_fail");
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
158