/*
 * Copyright (C) by Argonne National Laboratory
 * See COPYRIGHT in top-level directory
 */

#include "mpiimpl.h"

/* Algorithm: Noncommutative recursive doubling
 *
 * Restrictions: This function currently only implements support for the
 * power-of-2, block-regular case (all receive counts are equal).
 *
 * Implements the reduce-scatter butterfly algorithm described in J. L. Traff's
 * "An Improved Algorithm for (Non-commutative) Reduce-Scatter with an
 * Application" from EuroPVM/MPI 2005.
 *
 * It takes lgp steps. At step 1, processes exchange (n-n/p) amount of
 * data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
 * amount of data, and so forth.
 *
 * Cost = lgp.alpha + n.(lgp-(p-1)/p).beta + n.(lgp-(p-1)/p).gamma
 */
/* Noncommutative reduce-scatter (butterfly / mirror-permutation algorithm).
 *
 * sendbuf    - send buffer, or MPI_IN_PLACE (then recvbuf holds the input)
 * recvbuf    - receive buffer; on return holds this rank's reduced block
 * recvcounts - per-rank block sizes; all entries must be equal here
 * datatype   - element datatype
 * op         - reduction operation (order of operands is preserved, so
 *              noncommutative ops are handled correctly)
 * comm_ptr   - intracommunicator; local_size must be a power of two
 * errflag    - in/out collective error flag
 *
 * Returns MPI_SUCCESS or an MPI error code; communication errors are
 * recorded in *errflag and the algorithm continues where possible.
 */
int MPIR_Reduce_scatter_intra_noncommutative(const void *sendbuf, void *recvbuf,
                                             const int recvcounts[], MPI_Datatype datatype,
                                             MPI_Op op, MPIR_Comm * comm_ptr,
                                             MPIR_Errflag_t * errflag)
{
    int mpi_errno = MPI_SUCCESS;
    int mpi_errno_ret = MPI_SUCCESS;
    int comm_size = comm_ptr->local_size;
    int rank = comm_ptr->rank;
    int pof2;
    int log2_comm_size;
    int i, k;
    int recv_offset, send_offset;
    int block_size, total_count, size;
    MPI_Aint true_extent, true_lb;
    int buf0_was_inout;
    void *tmp_buf0;
    void *tmp_buf1;
    void *result_ptr;
    MPIR_CHKLMEM_DECL(3);

    MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);

    /* find the smallest power of two >= comm_size and its log */
    pof2 = 1;
    log2_comm_size = 0;
    while (pof2 < comm_size) {
        pof2 <<= 1;
        ++log2_comm_size;
    }

#ifdef HAVE_ERROR_CHECKING
    /* begin error checking */
    MPIR_Assert(pof2 == comm_size);     /* FIXME this version only works for power of 2 procs */

    for (i = 0; i < (comm_size - 1); ++i) {
        MPIR_Assert(recvcounts[i] == recvcounts[i + 1]);
    }
    /* end error checking */
#endif

    /* size of a block (count of datatype per block, NOT bytes per block) */
    block_size = recvcounts[0];
    total_count = block_size * comm_size;

    MPIR_CHKLMEM_MALLOC(tmp_buf0, void *, true_extent * total_count, mpi_errno, "tmp_buf0",
                        MPL_MEM_BUFFER);
    MPIR_CHKLMEM_MALLOC(tmp_buf1, void *, true_extent * total_count, mpi_errno, "tmp_buf1",
                        MPL_MEM_BUFFER);
    /* adjust for potential negative lower bound in datatype */
    tmp_buf0 = (void *) ((char *) tmp_buf0 - true_lb);
    tmp_buf1 = (void *) ((char *) tmp_buf1 - true_lb);

    /* Copy our send data to tmp_buf0.  We do this one block at a time and
     * permute the blocks as we go according to the mirror permutation. */
    for (i = 0; i < comm_size; ++i) {
        mpi_errno =
            MPIR_Localcopy((char *) (sendbuf ==
                                     MPI_IN_PLACE ? (const void *) recvbuf : sendbuf) +
                           (i * true_extent * block_size), block_size, datatype,
                           (char *) tmp_buf0 +
                           (MPL_mirror_permutation(i, log2_comm_size) * true_extent * block_size),
                           block_size, datatype);
        MPIR_ERR_CHECK(mpi_errno);
    }
    buf0_was_inout = 1;

    send_offset = 0;
    recv_offset = 0;
    size = total_count;
    for (k = 0; k < log2_comm_size; ++k) {
        /* use a double-buffering scheme to avoid local copies */
        char *incoming_data = (buf0_was_inout ? tmp_buf1 : tmp_buf0);
        char *outgoing_data = (buf0_was_inout ? tmp_buf0 : tmp_buf1);
        int peer = rank ^ (0x1 << k);
        size /= 2;

        if (rank > peer) {
            /* we have the higher rank: send top half, recv bottom half */
            recv_offset += size;
        } else {
            /* we have the lower rank: recv top half, send bottom half */
            send_offset += size;
        }

        mpi_errno = MPIC_Sendrecv(outgoing_data + send_offset * true_extent,
                                  size, datatype, peer, MPIR_REDUCE_SCATTER_TAG,
                                  incoming_data + recv_offset * true_extent,
                                  size, datatype, peer, MPIR_REDUCE_SCATTER_TAG,
                                  comm_ptr, MPI_STATUS_IGNORE, errflag);
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
            *errflag =
                MPIX_ERR_PROC_FAILED ==
                MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
            MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
            MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
        }
        /* always perform the reduction at recv_offset, the data at send_offset
         * is now our peer's responsibility */
        if (rank > peer) {
            /* higher ranked value so need to call op(received_data, my_data) */
            mpi_errno = MPIR_Reduce_local(incoming_data + recv_offset * true_extent,
                                          outgoing_data + recv_offset * true_extent,
                                          size, datatype, op);
            MPIR_ERR_CHECK(mpi_errno);
        } else {
            /* lower ranked value so need to call op(my_data, received_data) */
            /* NOTE: the return value must be captured here, otherwise
             * MPIR_ERR_CHECK below tests a stale mpi_errno and errors from
             * the local reduction are silently dropped */
            mpi_errno = MPIR_Reduce_local(outgoing_data + recv_offset * true_extent,
                                          incoming_data + recv_offset * true_extent,
                                          size, datatype, op);
            MPIR_ERR_CHECK(mpi_errno);
            buf0_was_inout = !buf0_was_inout;
        }

        /* the next round of send/recv needs to happen within the block (of size
         * "size") that we just received and reduced */
        send_offset = recv_offset;
    }

    MPIR_Assert(size == recvcounts[rank]);

    /* copy the reduced data to the recvbuf */
    result_ptr = (char *) (buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
    mpi_errno = MPIR_Localcopy(result_ptr, size, datatype, recvbuf, size, datatype);
    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    MPIR_CHKLMEM_FREEALL();
    if (mpi_errno_ret)
        mpi_errno = mpi_errno_ret;
    else if (*errflag != MPIR_ERR_NONE)
        MPIR_ERR_SET(mpi_errno, *errflag, "**coll_fail");
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
158