1 /*
2 * Copyright (C) by Argonne National Laboratory
3 * See COPYRIGHT in top-level directory
4 */
5
6 #include "mpiimpl.h"
7
8 /* Intercommunicator Allgather
9 *
10 * Each group does a gather to local root with the local
11 * intracommunicator, and then does an intercommunicator
12 * broadcast.
13 */
14
/*
 * Algorithm: each group first gathers its contributions (as raw bytes) to
 * local rank 0 over the local intracommunicator; then the low group's rank 0
 * broadcasts that packed data to the high group over the intercommunicator,
 * and afterwards the high group's rank 0 broadcasts its packed data back.
 *
 * sendbuf/sendcount/sendtype - this process's contribution
 * recvbuf/recvcount/recvtype - receives recvcount elements of recvtype from
 *                              each of the remote_size remote processes
 * comm_ptr                   - the intercommunicator
 * errflag                    - in/out collective error flag; errors from the
 *                              Gather/Bcast steps are recorded here and the
 *                              algorithm keeps going
 *
 * Returns MPI_SUCCESS, the combined error code of every failed Gather/Bcast
 * step, or an error from buffer allocation / local-communicator setup.
 */
int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, int sendcount,
                                                   MPI_Datatype sendtype, void *recvbuf,
                                                   int recvcount, MPI_Datatype recvtype,
                                                   MPIR_Comm * comm_ptr, MPIR_Errflag_t * errflag)
{
    int rank, local_size, remote_size, mpi_errno = MPI_SUCCESS, root;
    int mpi_errno_ret = MPI_SUCCESS;
    MPI_Aint sendtype_sz;
    void *tmp_buf = NULL;
    MPIR_Comm *newcomm_ptr = NULL;

    MPIR_CHKLMEM_DECL(1);

    local_size = comm_ptr->local_size;
    remote_size = comm_ptr->remote_size;
    rank = comm_ptr->rank;

    if ((rank == 0) && (sendcount != 0)) {
        /* In each group, only rank 0 allocates a temp. buffer that holds the
         * packed contributions of the whole local group. */
        MPIR_Datatype_get_size_macro(sendtype, sendtype_sz);
        MPIR_CHKLMEM_MALLOC(tmp_buf, void *, sendcount * sendtype_sz * local_size, mpi_errno,
                            "tmp_buf", MPL_MEM_BUFFER);
    } else {
        /* On non-zero ranks sendtype_sz only feeds the Gather receive count
         * (significant at the root only) and the count of a Bcast issued with
         * root MPI_PROC_NULL (ignored), so 0 is harmless; it also silences
         * -Wmaybe-uninitialized due to MPIR_{Gather,Bcast} calls by non-zero ranks. */
        sendtype_sz = 0;
    }

    /* Get the local intracommunicator, creating it on first use.  Its error
     * code must not be dropped: on failure local_comm would stay NULL and be
     * handed to MPIR_Gather below. */
    if (!comm_ptr->local_comm) {
        mpi_errno = MPII_Setup_intercomm_localcomm(comm_ptr);
        if (mpi_errno) {
            MPIR_ERR_POP(mpi_errno);
        }
    }

    newcomm_ptr = comm_ptr->local_comm;

    if (sendcount != 0) {
        /* Pack every local contribution into rank 0's tmp_buf as bytes. */
        mpi_errno = MPIR_Gather(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz,
                                MPI_BYTE, 0, newcomm_ptr, errflag);
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
            *errflag =
                MPIX_ERR_PROC_FAILED ==
                MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
            MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
            MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
        }
    }

    /* first broadcast from left to right group, then from right to
     * left group; the two groups issue the calls in opposite order so the
     * intercommunicator broadcasts pair up. */
    if (comm_ptr->is_low_group) {
        /* bcast to right: local rank 0 is the root, other local ranks
         * pass MPI_PROC_NULL */
        if (sendcount != 0) {
            root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
            mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size,
                                   MPI_BYTE, root, comm_ptr, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag =
                    MPIX_ERR_PROC_FAILED ==
                    MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
                MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
                MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
            }
        }

        /* receive bcast from right (remote rank 0 is the root) */
        if (recvcount != 0) {
            root = 0;
            /* NOTE(review): recvcount * remote_size is an int product; very
             * large counts could overflow — confirm the callee's count limits. */
            mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size,
                                   recvtype, root, comm_ptr, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag =
                    MPIX_ERR_PROC_FAILED ==
                    MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
                MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
                MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
            }
        }
    } else {
        /* receive bcast from left (remote rank 0 is the root) */
        if (recvcount != 0) {
            root = 0;
            mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size,
                                   recvtype, root, comm_ptr, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag =
                    MPIX_ERR_PROC_FAILED ==
                    MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
                MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
                MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
            }
        }

        /* bcast to left: local rank 0 is the root, other local ranks
         * pass MPI_PROC_NULL */
        if (sendcount != 0) {
            root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
            mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size,
                                   MPI_BYTE, root, comm_ptr, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag =
                    MPIX_ERR_PROC_FAILED ==
                    MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
                MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
                MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
            }
        }
    }

  fn_exit:
    MPIR_CHKLMEM_FREEALL();
    /* Prefer the accumulated communication errors; otherwise, if some other
     * process reported a failure via errflag, produce a collective error. */
    if (mpi_errno_ret)
        mpi_errno = mpi_errno_ret;
    else if (*errflag != MPIR_ERR_NONE)
        MPIR_ERR_SET(mpi_errno, *errflag, "**coll_fail");

    return mpi_errno;

  fn_fail:
    goto fn_exit;
}
138