1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 
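/* Intercommunicator Allgather
 *
 * Each group first gathers the contributions of all its members to its
 * local rank 0 over the local intracommunicator.  The two local roots then
 * exchange the gathered data with intercommunicator broadcasts: the low
 * group broadcasts to the high group first, then the high group broadcasts
 * to the low group.
 */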
14 
/*
 * MPIR_Allgather_inter_local_gather_remote_bcast
 *
 * Intercommunicator allgather built from a local gather followed by two
 * intercommunicator broadcasts:
 *   1. Every member of a group sends sendcount elements of sendtype to its
 *      local rank 0 over the local intracommunicator (comm_ptr->local_comm,
 *      created on first use).  Rank 0 receives them as raw bytes into a
 *      temporary buffer of sendcount * size(sendtype) * local_size bytes.
 *   2. The low group's rank 0 broadcasts that buffer to the high group, then
 *      the high group's rank 0 broadcasts its buffer to the low group.  Each
 *      process receives recvcount * remote_size elements of recvtype into
 *      recvbuf.
 *
 * Parameters:
 *   sendbuf, sendcount, sendtype - this process' contribution
 *   recvbuf, recvcount, recvtype - receive buffer for the remote group's data
 *                                  (recvcount elements per remote process)
 *   comm_ptr                     - the intercommunicator
 *   errflag                      - in/out collective error flag; a failing
 *                                  gather/bcast sets it and the algorithm
 *                                  continues with the remaining steps
 *
 * Returns MPI_SUCCESS or an MPI error code.  Errors recorded while
 * continuing are accumulated and returned.  If *errflag is set but no local
 * error was recorded, a "**coll_fail" code is returned.  Failures of the
 * temporary-buffer allocation or of local-communicator setup abort early
 * through fn_fail.
 */
int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, int sendcount,
                                                   MPI_Datatype sendtype, void *recvbuf,
                                                   int recvcount, MPI_Datatype recvtype,
                                                   MPIR_Comm * comm_ptr, MPIR_Errflag_t * errflag)
{
    int rank, local_size, remote_size, root, phase;
    int mpi_errno = MPI_SUCCESS;
    int mpi_errno_ret = MPI_SUCCESS;
    /* Only rank 0 (with sendcount != 0) queries the real size.  Other ranks
     * keep 0, which also silences -Wmaybe-uninitialized.  Their
     * MPIR_{Gather,Bcast} calls still pass expressions involving this value,
     * but only as receive/root-side arguments they do not use. */
    MPI_Aint sendtype_sz = 0;
    void *tmp_buf = NULL;
    MPIR_Comm *newcomm_ptr = NULL;

    MPIR_CHKLMEM_DECL(1);

    local_size = comm_ptr->local_size;
    remote_size = comm_ptr->remote_size;
    rank = comm_ptr->rank;

    if ((rank == 0) && (sendcount != 0)) {
        /* In each group, rank 0 allocates the buffer that receives the local
         * gather and is later broadcast to the remote group. */
        MPIR_Datatype_get_size_macro(sendtype, sendtype_sz);
        MPIR_CHKLMEM_MALLOC(tmp_buf, void *, sendcount * sendtype_sz * local_size, mpi_errno,
                            "tmp_buf", MPL_MEM_BUFFER);
    }

    /* Get the local intracommunicator, creating it on first use.  The setup
     * routine reports failure through its return code.  Propagate it rather
     * than continuing with a local_comm that may not have been set up. */
    if (!comm_ptr->local_comm) {
        mpi_errno = MPII_Setup_intercomm_localcomm(comm_ptr);
        if (mpi_errno)
            goto fn_fail;
    }

    newcomm_ptr = comm_ptr->local_comm;

    /* Step 1: gather all local contributions to local rank 0 as bytes. */
    if (sendcount != 0) {
        mpi_errno = MPIR_Gather(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz,
                                MPI_BYTE, 0, newcomm_ptr, errflag);
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
            *errflag =
                MPIX_ERR_PROC_FAILED ==
                MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
            MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
            MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
        }
    }

    /* Step 2: exchange the gathered data between the groups in a fixed
     * order that both sides follow.  Phase 0 broadcasts low -> high and
     * phase 1 broadcasts high -> low.  In each phase this group either sends
     * (is the root group) or receives. */
    for (phase = 0; phase < 2; phase++) {
        int local_group_sends = (phase == 0) ? comm_ptr->is_low_group : !comm_ptr->is_low_group;

        if (local_group_sends) {
            if (sendcount == 0)
                continue;
            /* Only local rank 0 holds the gathered buffer and acts as root.
             * The other local ranks pass MPI_PROC_NULL. */
            root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
            /* NOTE(review): the byte count is a product of int/MPI_Aint
             * values; confirm it fits MPIR_Bcast's count type for large
             * inputs. */
            mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size,
                                   MPI_BYTE, root, comm_ptr, errflag);
        } else {
            if (recvcount == 0)
                continue;
            /* Receive from rank 0 of the remote group.  NOTE(review):
             * recvcount * remote_size is an int product and may overflow
             * for very large counts. */
            root = 0;
            mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size,
                                   recvtype, root, comm_ptr, errflag);
        }

        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
            *errflag =
                MPIX_ERR_PROC_FAILED ==
                MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
            MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
            MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
        }
    }

  fn_exit:
    MPIR_CHKLMEM_FREEALL();
    /* Prefer the accumulated communication errors.  Otherwise report a
     * collective failure if any participant flagged one. */
    if (mpi_errno_ret)
        mpi_errno = mpi_errno_ret;
    else if (*errflag != MPIR_ERR_NONE)
        MPIR_ERR_SET(mpi_errno, *errflag, "**coll_fail");

    return mpi_errno;

  fn_fail:
    goto fn_exit;
}
138