1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 
8 /* Local gather remote send
9  *
10  * Remote group does a local intracommunicator gather to rank 0. Rank
11  * 0 then sends data to root.
12  *
13  * Cost: (lgp+1).alpha + n.((p-1)/p).beta + n.beta
14  */
15 
/* Intercommunicator gather: local gather on the sending group, then one
 * send from that group's rank 0 to the root.
 *
 * Behavior by value of 'root':
 *   - MPI_PROC_NULL: this process takes no part; returns MPI_SUCCESS.
 *   - MPI_ROOT:      this process is the root; it receives
 *                    recvcount * remote_size elements of recvtype from
 *                    rank 0 of the remote group into recvbuf.
 *   - otherwise:     this process is in the sending group; it contributes
 *                    sendcount elements of sendtype from sendbuf to a gather
 *                    over comm_ptr->local_comm rooted at rank 0, and rank 0
 *                    forwards the packed bytes to 'root'.
 *
 * Communication errors are recorded in *errflag and accumulated into the
 * return value, but do not stop the algorithm.  Failures reported through
 * MPIR_ERR_CHECK / MPIR_CHKLMEM_MALLOC presumably jump to fn_fail (macro
 * definitions are not visible in this file).
 *
 * Returns MPI_SUCCESS or an MPI error code.
 */
int MPIR_Gather_inter_local_gather_remote_send(const void *sendbuf, int sendcount,
                                               MPI_Datatype sendtype, void *recvbuf, int recvcount,
                                               MPI_Datatype recvtype, int root,
                                               MPIR_Comm * comm_ptr, MPIR_Errflag_t * errflag)
{
    int rank, local_size, remote_size, mpi_errno = MPI_SUCCESS;
    int mpi_errno_ret = MPI_SUCCESS;
    MPI_Status status;
    MPIR_Comm *newcomm_ptr = NULL;
    MPIR_CHKLMEM_DECL(1);

    if (root == MPI_PROC_NULL) {
        /* local processes other than root do nothing */
        return MPI_SUCCESS;
    }

    remote_size = comm_ptr->remote_size;
    local_size = comm_ptr->local_size;

    if (root == MPI_ROOT) {
        /* Widen before multiplying: recvcount * remote_size can exceed
         * INT_MAX, and signed int overflow is undefined behavior. */
        MPI_Aint total_recvcount = (MPI_Aint) recvcount * remote_size;

        /* root receives data from rank 0 on remote group */
        mpi_errno =
            MPIC_Recv(recvbuf, total_recvcount, recvtype, 0, MPIR_GATHER_TAG, comm_ptr,
                      &status, errflag);
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
            *errflag =
                MPIX_ERR_PROC_FAILED ==
                MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
            MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
            MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
        }
    } else {
        /* remote group. Rank 0 allocates temporary buffer, does
         * local intracommunicator gather, and then sends the data
         * to root. */

        /* Stays 0 on non-zero ranks, which never allocate or send; the
         * initializer also avoids -Wmaybe-uninitialized at the MPIR_Gather
         * call below. */
        MPI_Aint sendtype_sz = 0;
        /* Size in bytes of the staging buffer; only meaningful on rank 0. */
        MPI_Aint tmp_buf_size = 0;
        void *tmp_buf = NULL;

        rank = comm_ptr->rank;

        if (rank == 0) {
            MPIR_Datatype_get_size_macro(sendtype, sendtype_sz);
            /* Widen the first operand so the whole product is computed in
             * MPI_Aint; sendcount * local_size alone may overflow int. */
            tmp_buf_size = (MPI_Aint) sendcount * local_size * sendtype_sz;
            MPIR_CHKLMEM_MALLOC(tmp_buf, void *,
                                tmp_buf_size, mpi_errno, "tmp_buf", MPL_MEM_BUFFER);
        }

        /* all processes in remote group form new intracommunicator */
        if (!comm_ptr->local_comm) {
            mpi_errno = MPII_Setup_intercomm_localcomm(comm_ptr);
            MPIR_ERR_CHECK(mpi_errno);
        }

        newcomm_ptr = comm_ptr->local_comm;

        /* now do a local gather on this intracommunicator; the receive side
         * is described as raw bytes so tmp_buf holds packed data.
         * NOTE(review): sendcount * sendtype_sz is MPI_Aint; confirm that
         * MPIR_Gather's recvcount parameter does not narrow it. */
        mpi_errno = MPIR_Gather(sendbuf, sendcount, sendtype,
                                tmp_buf, sendcount * sendtype_sz, MPI_BYTE, 0, newcomm_ptr,
                                errflag);
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
            *errflag =
                MPIX_ERR_PROC_FAILED ==
                MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
            MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
            MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
        }

        if (rank == 0) {
            /* forward everything gathered locally to the root in one message */
            mpi_errno = MPIC_Send(tmp_buf, tmp_buf_size, MPI_BYTE,
                                  root, MPIR_GATHER_TAG, comm_ptr, errflag);
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag =
                    MPIX_ERR_PROC_FAILED ==
                    MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
                MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
                MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
            }
        }
    }

  fn_exit:
    MPIR_CHKLMEM_FREEALL();
    /* Accumulated communication errors take precedence over mpi_errno; if
     * none were accumulated but *errflag is set, report a collective failure. */
    if (mpi_errno_ret)
        mpi_errno = mpi_errno_ret;
    else if (*errflag != MPIR_ERR_NONE)
        MPIR_ERR_SET(mpi_errno, *errflag, "**coll_fail");
    return mpi_errno;

  fn_fail:
    goto fn_exit;
}
113