1 /*
2 * Copyright (C) by Argonne National Laboratory
3 * See COPYRIGHT in top-level directory
4 */
5
6 #include "mpiimpl.h"
7
/* to_comm_rank(cr_, gr_): translate gr_, a rank in group_ptr, into the
 * corresponding rank of comm_ptr->local_group and store it in the lvalue cr_.
 *
 * The macro is not self-contained. It uses the enclosing function's
 * group_ptr, comm_ptr and mpi_errno. A failed translation call is reported
 * through MPIR_ERR_CHECK, so the caller must provide the usual fn_fail
 * label. A group rank with no counterpart in the communicator
 * (MPI_UNDEFINED) trips MPIR_Assert. */
#define to_comm_rank(cr_, gr_) \
    do { \
        int src_rank_ = (gr_); \
        mpi_errno = MPIR_Group_translate_ranks_impl(group_ptr, 1, &src_rank_, \
                                                    comm_ptr->local_group, &(cr_)); \
        MPIR_ERR_CHECK(mpi_errno); \
        MPIR_Assert((cr_) != MPI_UNDEFINED); \
    } while (0)
17
/* record_comm_error(): record a failed point-to-point call made inside the
 * collective without aborting it.
 *
 * - *errflag becomes MPIR_ERR_PROC_FAILED if the error class is
 *   MPIX_ERR_PROC_FAILED, and MPIR_ERR_OTHER otherwise.
 * - mpi_errno is wrapped with "**fail" and accumulated into mpi_errno_ret.
 *
 * Uses the enclosing function's mpi_errno, mpi_errno_ret and errflag. */
#define record_comm_error() \
    do { \
        *errflag = (MPIR_ERR_GET_CLASS(mpi_errno) == MPIX_ERR_PROC_FAILED) ? \
            MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER; \
        MPIR_ERR_SET(mpi_errno, *errflag, "**fail"); \
        MPIR_ERR_ADD(mpi_errno_ret, mpi_errno); \
    } while (0)

/* MPII_Allreduce_group_intra: allreduce among the members of group_ptr,
 * using comm_ptr for the point-to-point traffic.
 *
 * Ranks named group_rank/dst/newrank are ranks in group_ptr (newrank is the
 * rank among the pof2 processes of the main phase). cdst/csrc are the same
 * peers translated into comm_ptr->local_group via to_comm_rank().
 *
 * Phases:
 *   1. Non-power-of-two fix-up: among the first 2*rem group ranks, each even
 *      rank sends its data to rank+1 and then sits out (newrank = -1). This
 *      leaves exactly pof2 active processes.
 *   2. Main phase. Recursive doubling is used for short messages,
 *      user-defined ops, or count < pof2. Otherwise a reduce-scatter
 *      (recursive halving) is followed by an allgather (recursive doubling).
 *   3. Each odd rank from phase 1 sends the final result back to rank-1.
 *
 * Communication failures are recorded (record_comm_error) rather than
 * aborting, so every participant still walks the full message pattern.
 * When a receive fails, the local reduction for that step is skipped
 * instead of combining a buffer whose contents are undefined.
 *
 * sendbuf   - local contribution, or MPI_IN_PLACE to take it from recvbuf
 * recvbuf   - receives the reduced result on every group member
 * count, datatype, op - as for MPI_Allreduce
 * comm_ptr  - communicator whose local_group contains group_ptr's processes
 * group_ptr - participating processes; the caller must be a member,
 *             otherwise "**rank" is returned
 * tag       - tag used for every message of this collective
 * errflag   - in/out; set when a communication error was recorded
 *
 * Returns MPI_SUCCESS, the accumulated communication error, a
 * "**coll_fail" code if *errflag is set on exit, or a local error
 * (allocation, datatype copy, reduction). */
int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, int count,
                               MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr,
                               MPIR_Group * group_ptr, int tag, MPIR_Errflag_t * errflag)
{
    MPI_Aint type_size;
    int mpi_errno = MPI_SUCCESS;
    int mpi_errno_ret = MPI_SUCCESS;
    /* newrank is a rank among the pof2 processes of the main phase, or -1 */
    int mask, dst, is_commutative, pof2, newrank, rem, newdst, i,
        send_idx, recv_idx, last_idx, send_cnt, recv_cnt, *cnts, *disps;
    MPI_Aint true_extent, true_lb, extent;
    void *tmp_buf;
    int group_rank, group_size;
    int cdst, csrc;             /* peers translated into comm_ptr ranks */
    MPIR_CHKLMEM_DECL(3);       /* tmp_buf, cnts, disps */

    group_rank = group_ptr->rank;
    group_size = group_ptr->size;
    MPIR_ERR_CHKANDJUMP(group_rank == MPI_UNDEFINED, mpi_errno, MPI_ERR_OTHER, "**rank");

    is_commutative = MPIR_Op_is_commutative(op);

    /* need to allocate temporary buffer to store incoming data */
    MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
    MPIR_Datatype_get_extent_macro(datatype, extent);

    MPIR_CHKLMEM_MALLOC(tmp_buf, void *, count * (MPL_MAX(extent, true_extent)), mpi_errno,
                        "temporary buffer", MPL_MEM_BUFFER);

    /* adjust for potential negative lower bound in datatype */
    tmp_buf = (void *) ((char *) tmp_buf - true_lb);

    /* copy local data into recvbuf; from here on recvbuf holds the
     * partial result accumulated so far */
    if (sendbuf != MPI_IN_PLACE) {
        mpi_errno = MPIR_Localcopy(sendbuf, count, datatype, recvbuf, count, datatype);
        MPIR_ERR_CHECK(mpi_errno);
    }

    MPIR_Datatype_get_size_macro(datatype, type_size);

    /* get nearest power-of-two less than or equal to group_size */
    pof2 = MPL_pof2(group_size);

    rem = group_size - pof2;

    /* In the non-power-of-two case, all even-numbered
     * processes of rank < 2*rem send their data to
     * (rank+1). These even-numbered processes no longer
     * participate in the algorithm until the very end. The
     * remaining processes form a nice power-of-two. */
    if (group_rank < 2 * rem) {
        if (group_rank % 2 == 0) {      /* even */
            to_comm_rank(cdst, group_rank + 1);
            mpi_errno = MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                record_comm_error();
            }

            /* temporarily set the rank to -1 so that this
             * process does not participate in recursive
             * doubling */
            newrank = -1;
        } else {        /* odd */
            to_comm_rank(csrc, group_rank - 1);
            mpi_errno = MPIC_Recv(tmp_buf, count,
                                  datatype, csrc, tag, comm_ptr, MPI_STATUS_IGNORE, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but
                 * continue; tmp_buf is undefined, so do not reduce it */
                record_comm_error();
            } else {
                /* do the reduction on received data. since the
                 * ordering is right, it doesn't matter whether
                 * the operation is commutative or not. */
                mpi_errno = MPIR_Reduce_local(tmp_buf, recvbuf, count, datatype, op);
                MPIR_ERR_CHECK(mpi_errno);
            }

            /* change the rank */
            newrank = group_rank / 2;
        }
    } else {    /* rank >= 2*rem */
        newrank = group_rank - rem;
    }

    /* Use recursive doubling if the message is short, the op is
     * user-defined, or count is less than pof2. Otherwise do a
     * reduce-scatter followed by an allgather. (If op is user-defined,
     * derived datatypes are allowed and the user could pass basic
     * datatypes on one process and derived on another as long as
     * the type maps are the same. Breaking up derived
     * datatypes to do the reduce-scatter is tricky, therefore
     * recursive doubling is used in that case.) */
    if (newrank != -1) {
        if ((count * type_size <= MPIR_CVAR_ALLREDUCE_SHORT_MSG_SIZE) ||
            (!HANDLE_IS_BUILTIN(op)) || (count < pof2)) {
            /* use recursive doubling */
            mask = 0x1;
            while (mask < pof2) {
                newdst = newrank ^ mask;
                /* find real rank of dest */
                dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
                to_comm_rank(cdst, dst);

                /* Send the most current data, which is in recvbuf. Recv
                 * into tmp_buf */
                mpi_errno = MPIC_Sendrecv(recvbuf, count, datatype,
                                          cdst, tag, tmp_buf,
                                          count, datatype, cdst,
                                          tag, comm_ptr, MPI_STATUS_IGNORE, errflag);
                if (mpi_errno) {
                    /* for communication errors, just record the error but continue */
                    record_comm_error();
                } else {
                    /* tmp_buf contains data received in this step.
                     * recvbuf contains data accumulated so far */
                    if (is_commutative || (dst < group_rank)) {
                        /* op is commutative OR the order is already right */
                        mpi_errno = MPIR_Reduce_local(tmp_buf, recvbuf, count, datatype, op);
                        MPIR_ERR_CHECK(mpi_errno);
                    } else {
                        /* op is noncommutative and the order is not right */
                        mpi_errno = MPIR_Reduce_local(recvbuf, tmp_buf, count, datatype, op);
                        MPIR_ERR_CHECK(mpi_errno);

                        /* copy result back into recvbuf */
                        mpi_errno = MPIR_Localcopy(tmp_buf, count, datatype,
                                                   recvbuf, count, datatype);
                        MPIR_ERR_CHECK(mpi_errno);
                    }
                }
                mask <<= 1;
            }
        } else {
            /* do a reduce-scatter followed by allgather */

            /* for the reduce-scatter, calculate the count that
             * each process receives and the displacement within
             * the buffer; the last block absorbs the remainder */
            MPIR_CHKLMEM_MALLOC(cnts, int *, pof2 * sizeof(int), mpi_errno, "counts",
                                MPL_MEM_BUFFER);
            MPIR_CHKLMEM_MALLOC(disps, int *, pof2 * sizeof(int), mpi_errno, "displacements",
                                MPL_MEM_BUFFER);

            for (i = 0; i < (pof2 - 1); i++)
                cnts[i] = count / pof2;
            cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);

            if (pof2)
                disps[0] = 0;
            for (i = 1; i < pof2; i++)
                disps[i] = disps[i - 1] + cnts[i - 1];

            mask = 0x1;
            send_idx = recv_idx = 0;
            last_idx = pof2;
            while (mask < pof2) {
                newdst = newrank ^ mask;
                /* find real rank of dest */
                dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
                to_comm_rank(cdst, dst);

                /* the lower rank of the pair keeps the lower half of the
                 * current block range, the higher rank the upper half */
                send_cnt = recv_cnt = 0;
                if (newrank < newdst) {
                    send_idx = recv_idx + pof2 / (mask * 2);
                    for (i = send_idx; i < last_idx; i++)
                        send_cnt += cnts[i];
                    for (i = recv_idx; i < send_idx; i++)
                        recv_cnt += cnts[i];
                } else {
                    recv_idx = send_idx + pof2 / (mask * 2);
                    for (i = send_idx; i < recv_idx; i++)
                        send_cnt += cnts[i];
                    for (i = recv_idx; i < last_idx; i++)
                        recv_cnt += cnts[i];
                }

                /* Send data from recvbuf. Recv into tmp_buf */
                mpi_errno = MPIC_Sendrecv((char *) recvbuf +
                                          disps[send_idx] * extent,
                                          send_cnt, datatype,
                                          cdst, tag,
                                          (char *) tmp_buf +
                                          disps[recv_idx] * extent,
                                          recv_cnt, datatype, cdst,
                                          tag, comm_ptr, MPI_STATUS_IGNORE, errflag);
                if (mpi_errno) {
                    /* for communication errors, just record the error but
                     * continue; the received slice of tmp_buf is undefined,
                     * so do not reduce it */
                    record_comm_error();
                } else {
                    /* tmp_buf contains data received in this step.
                     * recvbuf contains data accumulated so far.
                     * This algorithm is used only for predefined ops
                     * and predefined ops are always commutative. */
                    mpi_errno = MPIR_Reduce_local(((char *) tmp_buf + disps[recv_idx] * extent),
                                                  ((char *) recvbuf + disps[recv_idx] * extent),
                                                  recv_cnt, datatype, op);
                    MPIR_ERR_CHECK(mpi_errno);
                }

                /* update send_idx for next iteration */
                send_idx = recv_idx;
                mask <<= 1;

                /* update last_idx, but not in last iteration
                 * because the value is needed in the allgather
                 * step below. */
                if (mask < pof2)
                    last_idx = recv_idx + pof2 / mask;
            }

            /* now do the allgather, walking the masks in reverse order */
            mask >>= 1;
            while (mask > 0) {
                newdst = newrank ^ mask;
                /* find real rank of dest */
                dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
                to_comm_rank(cdst, dst);

                send_cnt = recv_cnt = 0;
                if (newrank < newdst) {
                    /* update last_idx except on first iteration */
                    if (mask != pof2 / 2)
                        last_idx = last_idx + pof2 / (mask * 2);

                    recv_idx = send_idx + pof2 / (mask * 2);
                    for (i = send_idx; i < recv_idx; i++)
                        send_cnt += cnts[i];
                    for (i = recv_idx; i < last_idx; i++)
                        recv_cnt += cnts[i];
                } else {
                    recv_idx = send_idx - pof2 / (mask * 2);
                    for (i = send_idx; i < last_idx; i++)
                        send_cnt += cnts[i];
                    for (i = recv_idx; i < send_idx; i++)
                        recv_cnt += cnts[i];
                }

                mpi_errno = MPIC_Sendrecv((char *) recvbuf +
                                          disps[send_idx] * extent,
                                          send_cnt, datatype,
                                          cdst, tag,
                                          (char *) recvbuf +
                                          disps[recv_idx] * extent,
                                          recv_cnt, datatype, cdst,
                                          tag, comm_ptr, MPI_STATUS_IGNORE, errflag);
                if (mpi_errno) {
                    /* for communication errors, just record the error but continue */
                    record_comm_error();
                }

                if (newrank > newdst)
                    send_idx = recv_idx;

                mask >>= 1;
            }
        }
    }

    /* In the non-power-of-two case, all odd-numbered
     * processes of rank < 2*rem send the result to
     * (rank-1), the ranks who didn't participate above. */
    if (group_rank < 2 * rem) {
        if (group_rank % 2) {   /* odd */
            to_comm_rank(cdst, group_rank - 1);
            mpi_errno = MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, errflag);
        } else {        /* even */
            to_comm_rank(csrc, group_rank + 1);
            mpi_errno = MPIC_Recv(recvbuf, count,
                                  datatype, csrc, tag, comm_ptr, MPI_STATUS_IGNORE, errflag);
        }
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
            record_comm_error();
        }
    }

  fn_exit:
    MPIR_CHKLMEM_FREEALL();
    /* a recorded communication error takes precedence; otherwise turn a
     * set errflag into a collective-failure code */
    if (mpi_errno_ret)
        mpi_errno = mpi_errno_ret;
    else if (*errflag != MPIR_ERR_NONE)
        MPIR_ERR_SET(mpi_errno, *errflag, "**coll_fail");
    return (mpi_errno);

  fn_fail:
    goto fn_exit;
}

#undef record_comm_error
339
/* MPII_Allreduce_group: public entry point of the group-based allreduce.
 *
 * Only intracommunicators are accepted; any other comm_kind fails with
 * "**commnotintra". Otherwise every argument is forwarded unchanged to
 * MPII_Allreduce_group_intra, and any error it returns is passed back
 * through MPIR_ERR_CHECK. */
int MPII_Allreduce_group(void *sendbuf, void *recvbuf, int count,
                         MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr,
                         MPIR_Group * group_ptr, int tag, MPIR_Errflag_t * errflag)
{
    int mpi_errno = MPI_SUCCESS;
    int is_intracomm = (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM);

    MPIR_ERR_CHKANDJUMP(!is_intracomm, mpi_errno, MPI_ERR_OTHER, "**commnotintra");

    mpi_errno = MPII_Allreduce_group_intra(sendbuf, recvbuf, count, datatype, op,
                                           comm_ptr, group_ptr, tag, errflag);
    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
358