1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 
/* to_comm_rank(cr_, gr_): store in the lvalue cr_ the rank, within
 * comm_ptr->local_group, of the process whose rank in group_ptr is gr_.
 * Relies on the caller's mpi_errno, group_ptr and comm_ptr variables, checks the
 * translation result with MPIR_ERR_CHECK (so the caller needs the usual
 * fn_fail error path), and asserts that gr_ actually maps to a communicator rank. */
#define to_comm_rank(cr_, gr_)                                                        \
    do {                                                                              \
        int group_rank_in_ = (gr_);                                                   \
        mpi_errno = MPIR_Group_translate_ranks_impl(group_ptr, 1, &group_rank_in_,    \
                                                    comm_ptr->local_group, &(cr_));   \
        MPIR_ERR_CHECK(mpi_errno);                                                    \
        MPIR_Assert((cr_) != MPI_UNDEFINED);                                          \
    } while (0)
17 
/* Local utility macro: records the point-to-point failure held in err_ without
 * aborting the collective.  The failure is classified into *(errflag_p_)
 * (MPIR_ERR_PROC_FAILED when its class is MPIX_ERR_PROC_FAILED, MPIR_ERR_OTHER
 * otherwise), tagged with "**fail", and folded into errret_ so it is reported
 * when the collective returns. */
#define record_comm_error(err_, errret_, errflag_p_)                                 \
    do {                                                                             \
        *(errflag_p_) = (MPIX_ERR_PROC_FAILED == MPIR_ERR_GET_CLASS(err_)) ?         \
            MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;                                   \
        MPIR_ERR_SET(err_, *(errflag_p_), "**fail");                                 \
        MPIR_ERR_ADD(errret_, err_);                                                 \
    } while (0)

/* Allreduce among the members of group_ptr, using comm_ptr for the
 * point-to-point traffic.  Only members of group_ptr may call this
 * (a non-member gets MPI_ERR_OTHER / "**rank").
 *
 * Structure:
 *   1. If the group size is not a power of two, the first 2*rem group ranks are
 *      paired; each even rank sends its data to the odd rank above it and sits
 *      out step 2.
 *   2. The remaining pof2 processes run either recursive doubling (short
 *      messages, non-builtin ops, or count < pof2) or a reduce-scatter followed
 *      by an allgather.
 *   3. Each odd rank of a step-1 pair sends the result back to its even partner.
 *
 * sendbuf   - local contribution, or MPI_IN_PLACE to use the contents of recvbuf
 * recvbuf   - receives the reduced result
 * count, datatype, op - element count, datatype and reduction operation
 * comm_ptr  - communicator carrying the messages; group ranks are translated
 *             into ranks of comm_ptr->local_group
 * group_ptr - participating group; the caller must be a member
 * tag       - tag used for every message of this operation
 * errflag   - in/out; set when a point-to-point step fails.  Communication
 *             failures are recorded and the algorithm keeps going, skipping any
 *             reduction whose input was not received.
 *
 * Returns MPI_SUCCESS or an MPI error code.
 */
int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, int count,
                               MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr,
                               MPIR_Group * group_ptr, int tag, MPIR_Errflag_t * errflag)
{
    MPI_Aint type_size;
    int mpi_errno = MPI_SUCCESS;
    int mpi_errno_ret = MPI_SUCCESS;
    /* newrank is this process's index among the pof2 processes taking part in
     * the main phase (-1 if it sits that phase out); dst is a rank in group_ptr */
    int mask, dst, is_commutative, pof2, newrank, rem, newdst, i,
        send_idx, recv_idx, last_idx, send_cnt, recv_cnt, *cnts, *disps;
    MPI_Aint true_extent, true_lb, extent;
    void *tmp_buf;
    int group_rank, group_size;
    int cdst, csrc;             /* peer ranks translated into comm_ptr */
    MPIR_CHKLMEM_DECL(3);       /* tmp_buf, cnts, disps */

    group_rank = group_ptr->rank;
    group_size = group_ptr->size;
    MPIR_ERR_CHKANDJUMP(group_rank == MPI_UNDEFINED, mpi_errno, MPI_ERR_OTHER, "**rank");

    is_commutative = MPIR_Op_is_commutative(op);

    /* need to allocate temporary buffer to store incoming data */
    MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
    MPIR_Datatype_get_extent_macro(datatype, extent);

    MPIR_CHKLMEM_MALLOC(tmp_buf, void *, count * (MPL_MAX(extent, true_extent)), mpi_errno,
                        "temporary buffer", MPL_MEM_BUFFER);

    /* adjust for potential negative lower bound in datatype */
    tmp_buf = (void *) ((char *) tmp_buf - true_lb);

    /* copy local data into recvbuf */
    if (sendbuf != MPI_IN_PLACE) {
        mpi_errno = MPIR_Localcopy(sendbuf, count, datatype, recvbuf, count, datatype);
        MPIR_ERR_CHECK(mpi_errno);
    }

    MPIR_Datatype_get_size_macro(datatype, type_size);

    /* get nearest power-of-two less than or equal to group_size */
    pof2 = MPL_pof2(group_size);

    rem = group_size - pof2;

    /* In the non-power-of-two case, all even-numbered
     * processes of rank < 2*rem send their data to
     * (rank+1). These even-numbered processes no longer
     * participate in the algorithm until the very end. The
     * remaining processes form a nice power-of-two. */

    if (group_rank < 2 * rem) {
        if (group_rank % 2 == 0) {      /* even */
            to_comm_rank(cdst, group_rank + 1);
            mpi_errno = MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                record_comm_error(mpi_errno, mpi_errno_ret, errflag);
            }

            /* temporarily set the rank to -1 so that this
             * process does not participate in recursive
             * doubling */
            newrank = -1;
        } else {        /* odd */
            to_comm_rank(csrc, group_rank - 1);
            mpi_errno = MPIC_Recv(tmp_buf, count,
                                  datatype, csrc, tag, comm_ptr, MPI_STATUS_IGNORE, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue;
                 * tmp_buf may not hold valid data, so do not reduce it */
                record_comm_error(mpi_errno, mpi_errno_ret, errflag);
            } else {
                /* do the reduction on received data. since the
                 * ordering is right, it doesn't matter whether
                 * the operation is commutative or not. */
                mpi_errno = MPIR_Reduce_local(tmp_buf, recvbuf, count, datatype, op);
                MPIR_ERR_CHECK(mpi_errno);
            }

            /* change the rank */
            newrank = group_rank / 2;
        }
    } else {    /* rank >= 2*rem */
        newrank = group_rank - rem;
    }

    /* If the message is short, op is user-defined, or count is less than
     * pof2, use the recursive doubling algorithm. Otherwise do a
     * reduce-scatter followed by allgather. (If op is user-defined,
     * derived datatypes are allowed and the user could pass basic
     * datatypes on one process and derived on another as long as
     * the type maps are the same. Breaking up derived
     * datatypes to do the reduce-scatter is tricky, therefore
     * using recursive doubling in that case.) */

    if (newrank != -1) {
        if ((count * type_size <= MPIR_CVAR_ALLREDUCE_SHORT_MSG_SIZE) ||
            (!HANDLE_IS_BUILTIN(op)) || (count < pof2)) {
            /* use recursive doubling */
            mask = 0x1;
            while (mask < pof2) {
                newdst = newrank ^ mask;
                /* find real rank of dest */
                dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
                to_comm_rank(cdst, dst);

                /* Send the most current data, which is in recvbuf. Recv
                 * into tmp_buf */
                mpi_errno = MPIC_Sendrecv(recvbuf, count, datatype,
                                          cdst, tag, tmp_buf,
                                          count, datatype, cdst,
                                          tag, comm_ptr, MPI_STATUS_IGNORE, errflag);
                if (mpi_errno) {
                    /* for communication errors, just record the error but continue */
                    record_comm_error(mpi_errno, mpi_errno_ret, errflag);
                } else {
                    /* tmp_buf contains data received in this step.
                     * recvbuf contains data accumulated so far */

                    if (is_commutative || (dst < group_rank)) {
                        /* op is commutative OR the order is already right */
                        mpi_errno = MPIR_Reduce_local(tmp_buf, recvbuf, count, datatype, op);
                        MPIR_ERR_CHECK(mpi_errno);
                    } else {
                        /* op is noncommutative and the order is not right */
                        mpi_errno = MPIR_Reduce_local(recvbuf, tmp_buf, count, datatype, op);
                        MPIR_ERR_CHECK(mpi_errno);

                        /* copy result back into recvbuf */
                        mpi_errno = MPIR_Localcopy(tmp_buf, count, datatype,
                                                   recvbuf, count, datatype);
                        MPIR_ERR_CHECK(mpi_errno);
                    }
                }
                mask <<= 1;
            }
        } else {
            /* do a reduce-scatter followed by allgather */

            /* for the reduce-scatter, calculate the count that
             * each process receives and the displacement within
             * the buffer */

            MPIR_CHKLMEM_MALLOC(cnts, int *, pof2 * sizeof(int), mpi_errno, "counts",
                                MPL_MEM_BUFFER);
            MPIR_CHKLMEM_MALLOC(disps, int *, pof2 * sizeof(int), mpi_errno, "displacements",
                                MPL_MEM_BUFFER);

            for (i = 0; i < (pof2 - 1); i++)
                cnts[i] = count / pof2;
            cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);

            if (pof2)
                disps[0] = 0;
            for (i = 1; i < pof2; i++)
                disps[i] = disps[i - 1] + cnts[i - 1];

            mask = 0x1;
            send_idx = recv_idx = 0;
            last_idx = pof2;
            while (mask < pof2) {
                newdst = newrank ^ mask;
                /* find real rank of dest */
                dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
                to_comm_rank(cdst, dst);

                send_cnt = recv_cnt = 0;
                if (newrank < newdst) {
                    send_idx = recv_idx + pof2 / (mask * 2);
                    for (i = send_idx; i < last_idx; i++)
                        send_cnt += cnts[i];
                    for (i = recv_idx; i < send_idx; i++)
                        recv_cnt += cnts[i];
                } else {
                    recv_idx = send_idx + pof2 / (mask * 2);
                    for (i = send_idx; i < recv_idx; i++)
                        send_cnt += cnts[i];
                    for (i = recv_idx; i < last_idx; i++)
                        recv_cnt += cnts[i];
                }

                /* Send data from recvbuf. Recv into tmp_buf */
                mpi_errno = MPIC_Sendrecv((char *) recvbuf +
                                          disps[send_idx] * extent,
                                          send_cnt, datatype,
                                          cdst, tag,
                                          (char *) tmp_buf +
                                          disps[recv_idx] * extent,
                                          recv_cnt, datatype, cdst,
                                          tag, comm_ptr, MPI_STATUS_IGNORE, errflag);
                if (mpi_errno) {
                    /* for communication errors, just record the error but continue;
                     * tmp_buf may not hold this step's data, so skip the reduction
                     * (as the recursive-doubling branch does) */
                    record_comm_error(mpi_errno, mpi_errno_ret, errflag);
                } else {
                    /* tmp_buf contains data received in this step.
                     * recvbuf contains data accumulated so far */

                    /* This algorithm is used only for predefined ops
                     * and predefined ops are always commutative. */
                    mpi_errno = MPIR_Reduce_local(((char *) tmp_buf + disps[recv_idx] * extent),
                                                  ((char *) recvbuf + disps[recv_idx] * extent),
                                                  recv_cnt, datatype, op);
                    MPIR_ERR_CHECK(mpi_errno);
                }

                /* update send_idx for next iteration */
                send_idx = recv_idx;
                mask <<= 1;

                /* update last_idx, but not in last iteration
                 * because the value is needed in the allgather
                 * step below. */
                if (mask < pof2)
                    last_idx = recv_idx + pof2 / mask;
            }

            /* now do the allgather */

            mask >>= 1;
            while (mask > 0) {
                newdst = newrank ^ mask;
                /* find real rank of dest */
                dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
                to_comm_rank(cdst, dst);

                send_cnt = recv_cnt = 0;
                if (newrank < newdst) {
                    /* update last_idx except on first iteration */
                    if (mask != pof2 / 2)
                        last_idx = last_idx + pof2 / (mask * 2);

                    recv_idx = send_idx + pof2 / (mask * 2);
                    for (i = send_idx; i < recv_idx; i++)
                        send_cnt += cnts[i];
                    for (i = recv_idx; i < last_idx; i++)
                        recv_cnt += cnts[i];
                } else {
                    recv_idx = send_idx - pof2 / (mask * 2);
                    for (i = send_idx; i < last_idx; i++)
                        send_cnt += cnts[i];
                    for (i = recv_idx; i < send_idx; i++)
                        recv_cnt += cnts[i];
                }

                mpi_errno = MPIC_Sendrecv((char *) recvbuf +
                                          disps[send_idx] * extent,
                                          send_cnt, datatype,
                                          cdst, tag,
                                          (char *) recvbuf +
                                          disps[recv_idx] * extent,
                                          recv_cnt, datatype, cdst,
                                          tag, comm_ptr, MPI_STATUS_IGNORE, errflag);
                if (mpi_errno) {
                    /* for communication errors, just record the error but continue */
                    record_comm_error(mpi_errno, mpi_errno_ret, errflag);
                }

                if (newrank > newdst)
                    send_idx = recv_idx;

                mask >>= 1;
            }
        }
    }

    /* In the non-power-of-two case, all odd-numbered
     * processes of rank < 2*rem send the result to
     * (rank-1), the ranks who didn't participate above. */
    if (group_rank < 2 * rem) {
        if (group_rank % 2) {   /* odd */
            to_comm_rank(cdst, group_rank - 1);
            mpi_errno = MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, errflag);
        } else {        /* even */
            to_comm_rank(csrc, group_rank + 1);
            mpi_errno = MPIC_Recv(recvbuf, count,
                                  datatype, csrc, tag, comm_ptr, MPI_STATUS_IGNORE, errflag);
        }
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
            record_comm_error(mpi_errno, mpi_errno_ret, errflag);
        }
    }

  fn_exit:
    MPIR_CHKLMEM_FREEALL();
    /* a recorded communication error takes precedence; otherwise report a
     * collective failure if any process flagged one via errflag */
    if (mpi_errno_ret)
        mpi_errno = mpi_errno_ret;
    else if (*errflag != MPIR_ERR_NONE)
        MPIR_ERR_SET(mpi_errno, *errflag, "**coll_fail");
    return (mpi_errno);

  fn_fail:
    goto fn_exit;
}

#undef record_comm_error
339 
/* Allreduce restricted to the members of group_ptr.  Only intracommunicators
 * are accepted: any other communicator kind fails with MPI_ERR_OTHER
 * ("**commnotintra").  Otherwise the work is handed to
 * MPII_Allreduce_group_intra, whose result is error-checked and returned. */
int MPII_Allreduce_group(void *sendbuf, void *recvbuf, int count,
                         MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr,
                         MPIR_Group * group_ptr, int tag, MPIR_Errflag_t * errflag)
{
    int mpi_errno = MPI_SUCCESS;
    const int not_intracomm = (comm_ptr->comm_kind != MPIR_COMM_KIND__INTRACOMM);

    MPIR_ERR_CHKANDJUMP(not_intracomm, mpi_errno, MPI_ERR_OTHER, "**commnotintra");

    /* the intracommunicator algorithm does all of the real work */
    mpi_errno = MPII_Allreduce_group_intra(sendbuf, recvbuf, count, datatype, op,
                                           comm_ptr, group_ptr, tag, errflag);
    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
358