1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 
8 /*
9 === BEGIN_MPI_T_CVAR_INFO_BLOCK ===
10 
11 cvars:
12     - name        : MPIR_CVAR_NEIGHBOR_ALLTOALL_INTRA_ALGORITHM
13       category    : COLLECTIVE
14       type        : enum
15       default     : auto
16       class       : none
17       verbosity   : MPI_T_VERBOSITY_USER_BASIC
18       scope       : MPI_T_SCOPE_ALL_EQ
19       description : |-
20         Variable to select neighbor_alltoall algorithm
21         auto - Internal algorithm selection (can be overridden with MPIR_CVAR_COLL_SELECTION_TUNING_JSON_FILE)
22         nb   - Force nb algorithm
23 
24     - name        : MPIR_CVAR_NEIGHBOR_ALLTOALL_INTER_ALGORITHM
25       category    : COLLECTIVE
26       type        : enum
27       default     : auto
28       class       : none
29       verbosity   : MPI_T_VERBOSITY_USER_BASIC
30       scope       : MPI_T_SCOPE_ALL_EQ
31       description : |-
32         Variable to select neighbor_alltoall algorithm
33         auto - Internal algorithm selection (can be overridden with MPIR_CVAR_COLL_SELECTION_TUNING_JSON_FILE)
34         nb   - Force nb algorithm
35 
36     - name        : MPIR_CVAR_NEIGHBOR_ALLTOALL_DEVICE_COLLECTIVE
37       category    : COLLECTIVE
38       type        : boolean
39       default     : true
40       class       : none
41       verbosity   : MPI_T_VERBOSITY_USER_BASIC
42       scope       : MPI_T_SCOPE_ALL_EQ
43       description : >-
44         This CVAR is only used when MPIR_CVAR_DEVICE_COLLECTIVES
45         is set to "percoll".  If set to true, MPI_Neighbor_alltoall will
46         allow the device to override the MPIR-level collective
47         algorithms.  The device might still call the MPIR-level
48         algorithms manually.  If set to false, the device-override
49         will be disabled.
50 
51 === END_MPI_T_CVAR_INFO_BLOCK ===
52 */
53 
54 /* -- Begin Profiling Symbol Block for routine MPI_Neighbor_alltoall */
55 #if defined(HAVE_PRAGMA_WEAK)
56 #pragma weak MPI_Neighbor_alltoall = PMPI_Neighbor_alltoall
57 #elif defined(HAVE_PRAGMA_HP_SEC_DEF)
58 #pragma _HP_SECONDARY_DEF PMPI_Neighbor_alltoall  MPI_Neighbor_alltoall
59 #elif defined(HAVE_PRAGMA_CRI_DUP)
60 #pragma _CRI duplicate MPI_Neighbor_alltoall as PMPI_Neighbor_alltoall
61 #elif defined(HAVE_WEAK_ATTRIBUTE)
62 int MPI_Neighbor_alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
63                           void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
64     __attribute__ ((weak, alias("PMPI_Neighbor_alltoall")));
65 #endif
66 /* -- End Profiling Symbol Block */
67 
68 /* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build
69    the MPI routines */
70 #ifndef MPICH_MPI_FROM_PMPI
71 #undef MPI_Neighbor_alltoall
72 #define MPI_Neighbor_alltoall PMPI_Neighbor_alltoall
73 
74 /* any non-MPI functions go here, especially non-static ones */
75 
76 
int MPIR_Neighbor_alltoall_allcomm_auto(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                                        void *recvbuf, int recvcount, MPI_Datatype recvtype,
                                        MPIR_Comm * comm_ptr)
{
    int mpi_errno = MPI_SUCCESS;

    /* Describe this call to the collective-selection (Csel) framework so it
     * can pick an algorithm (possibly driven by a tuning JSON file). */
    MPIR_Csel_coll_sig_s coll_sig = {
        .coll_type = MPIR_CSEL_COLL_TYPE__NEIGHBOR_ALLTOALL,
        .comm_ptr = comm_ptr,

        .u.neighbor_alltoall.sendbuf = sendbuf,
        .u.neighbor_alltoall.sendcount = sendcount,
        .u.neighbor_alltoall.sendtype = sendtype,
        .u.neighbor_alltoall.recvbuf = recvbuf,
        .u.neighbor_alltoall.recvcount = recvcount,
        .u.neighbor_alltoall.recvtype = recvtype,
    };

    MPII_Csel_container_s *cnt = MPIR_Csel_search(comm_ptr->csel_comm, coll_sig);
    MPIR_Assert(cnt);

    /* Only one algorithm (nonblocking-based) exists for neighbor_alltoall;
     * any other container id from the selector is a logic error. */
    if (cnt->id == MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Neighbor_alltoall_allcomm_nb) {
        mpi_errno = MPIR_Neighbor_alltoall_allcomm_nb(sendbuf, sendcount, sendtype,
                                                      recvbuf, recvcount, recvtype, comm_ptr);
    } else {
        MPIR_Assert(0);
    }

    return mpi_errno;
}
111 
int MPIR_Neighbor_alltoall_impl(const void *sendbuf, int sendcount,
                                MPI_Datatype sendtype, void *recvbuf,
                                int recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr)
{
    int mpi_errno = MPI_SUCCESS;

    /* Dispatch on the user-selected algorithm CVAR; intra- and
     * inter-communicators have independent CVARs (and enum constants). */
    if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
        if (MPIR_CVAR_NEIGHBOR_ALLTOALL_INTRA_ALGORITHM ==
            MPIR_CVAR_NEIGHBOR_ALLTOALL_INTRA_ALGORITHM_nb) {
            mpi_errno = MPIR_Neighbor_alltoall_allcomm_nb(sendbuf, sendcount, sendtype,
                                                          recvbuf, recvcount, recvtype, comm_ptr);
        } else if (MPIR_CVAR_NEIGHBOR_ALLTOALL_INTRA_ALGORITHM ==
                   MPIR_CVAR_NEIGHBOR_ALLTOALL_INTRA_ALGORITHM_auto) {
            mpi_errno = MPIR_Neighbor_alltoall_allcomm_auto(sendbuf, sendcount, sendtype,
                                                            recvbuf, recvcount, recvtype, comm_ptr);
        } else {
            /* unrecognized CVAR value */
            MPIR_Assert(0);
        }
    } else {
        if (MPIR_CVAR_NEIGHBOR_ALLTOALL_INTER_ALGORITHM ==
            MPIR_CVAR_NEIGHBOR_ALLTOALL_INTER_ALGORITHM_nb) {
            mpi_errno = MPIR_Neighbor_alltoall_allcomm_nb(sendbuf, sendcount, sendtype,
                                                          recvbuf, recvcount, recvtype, comm_ptr);
        } else if (MPIR_CVAR_NEIGHBOR_ALLTOALL_INTER_ALGORITHM ==
                   MPIR_CVAR_NEIGHBOR_ALLTOALL_INTER_ALGORITHM_auto) {
            mpi_errno = MPIR_Neighbor_alltoall_allcomm_auto(sendbuf, sendcount, sendtype,
                                                            recvbuf, recvcount, recvtype, comm_ptr);
        } else {
            /* unrecognized CVAR value */
            MPIR_Assert(0);
        }
    }
    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
156 
int MPIR_Neighbor_alltoall(const void *sendbuf, int sendcount,
                           MPI_Datatype sendtype, void *recvbuf, int recvcount,
                           MPI_Datatype recvtype, MPIR_Comm * comm_ptr)
{
    int mpi_errno = MPI_SUCCESS;

    /* Let the device override the MPIR-level algorithms when device
     * collectives are enabled globally, or per-collective via this
     * collective's own CVAR.  BUG FIX: this previously tested
     * MPIR_CVAR_BARRIER_DEVICE_COLLECTIVE (a copy-paste from barrier),
     * so MPIR_CVAR_NEIGHBOR_ALLTOALL_DEVICE_COLLECTIVE — declared in the
     * CVAR info block of this file — had no effect. */
    if ((MPIR_CVAR_DEVICE_COLLECTIVES == MPIR_CVAR_DEVICE_COLLECTIVES_all) ||
        ((MPIR_CVAR_DEVICE_COLLECTIVES == MPIR_CVAR_DEVICE_COLLECTIVES_percoll) &&
         MPIR_CVAR_NEIGHBOR_ALLTOALL_DEVICE_COLLECTIVE)) {
        mpi_errno =
            MPID_Neighbor_alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype,
                                   comm_ptr);
    } else {
        mpi_errno = MPIR_Neighbor_alltoall_impl(sendbuf, sendcount, sendtype,
                                                recvbuf, recvcount, recvtype, comm_ptr);
    }

    return mpi_errno;
}
176 
177 #endif /* MPICH_MPI_FROM_PMPI */
178 
179 /*@
180 MPI_Neighbor_alltoall - In this function, each process i receives data items
181 from each process j if an edge (j,i) exists in the topology graph or Cartesian
182 topology.  Similarly, each process i sends data items to all processes j where an
183 edge (i,j) exists. This call is more general than MPI_NEIGHBOR_ALLGATHER in that
184 different data items can be sent to each neighbor. The k-th block in send buffer
185 is sent to the k-th neighboring process and the l-th block in the receive buffer
186 is received from the l-th neighbor.
187 
188 Input Parameters:
189 + sendbuf - starting address of the send buffer (choice)
190 . sendcount - number of elements sent to each neighbor (non-negative integer)
191 . sendtype - data type of send buffer elements (handle)
192 . recvcount - number of elements received from each neighbor (non-negative integer)
193 . recvtype - data type of receive buffer elements (handle)
194 - comm - communicator (handle)
195 
196 Output Parameters:
197 . recvbuf - starting address of the receive buffer (choice)
198 
199 .N ThreadSafe
200 
201 .N Fortran
202 
203 .N Errors
204 @*/
int MPI_Neighbor_alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf,
                          int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
{
    int mpi_errno = MPI_SUCCESS;
    MPIR_Comm *comm_ptr = NULL;
    MPIR_FUNC_TERSE_STATE_DECL(MPID_STATE_MPI_NEIGHBOR_ALLTOALL);

    MPID_THREAD_CS_ENTER(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
    MPIR_FUNC_TERSE_ENTER(MPID_STATE_MPI_NEIGHBOR_ALLTOALL);

    /* Validate parameters, especially handles needing to be converted.
     * Phase 1: only checks that are safe before the comm handle has been
     * converted to an object pointer (handle-kind sanity of the datatypes
     * and the communicator handle itself). */
#ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            MPIR_ERRTEST_DATATYPE(sendtype, "sendtype", mpi_errno);
            MPIR_ERRTEST_DATATYPE(recvtype, "recvtype", mpi_errno);
            MPIR_ERRTEST_COMM(comm, mpi_errno);

            /* TODO more checks may be appropriate */
        }
        MPID_END_ERROR_CHECKS;
    }
#endif /* HAVE_ERROR_CHECKING */

    /* Convert MPI object handles to object pointers */
    MPIR_Comm_get_ptr(comm, comm_ptr);

    /* Validate parameters and objects (post conversion).
     * Phase 2: these checks dereference the objects, so they must come
     * after the handle conversion above. */
#ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            /* Builtin datatypes are always valid and committed; only
             * user-defined datatypes need the pointer-level checks. */
            if (!HANDLE_IS_BUILTIN(sendtype)) {
                MPIR_Datatype *sendtype_ptr = NULL;
                MPIR_Datatype_get_ptr(sendtype, sendtype_ptr);
                MPIR_Datatype_valid_ptr(sendtype_ptr, mpi_errno);
                MPIR_Datatype_committed_ptr(sendtype_ptr, mpi_errno);
            }

            if (!HANDLE_IS_BUILTIN(recvtype)) {
                MPIR_Datatype *recvtype_ptr = NULL;
                MPIR_Datatype_get_ptr(recvtype, recvtype_ptr);
                MPIR_Datatype_valid_ptr(recvtype_ptr, mpi_errno);
                MPIR_Datatype_committed_ptr(recvtype_ptr, mpi_errno);
            }

            MPIR_Comm_valid_ptr(comm_ptr, mpi_errno, FALSE);
            /* TODO more checks may be appropriate (counts, in_place, buffer aliasing, etc) */
            if (mpi_errno != MPI_SUCCESS)
                goto fn_fail;
        }
        MPID_END_ERROR_CHECKS;
    }
#endif /* HAVE_ERROR_CHECKING */

    /* ... body of routine ...  */

    /* Delegate to the MPIR layer, which decides between the device
     * implementation and the MPIR-level algorithms. */
    mpi_errno = MPIR_Neighbor_alltoall(sendbuf, sendcount, sendtype, recvbuf,
                                       recvcount, recvtype, comm_ptr);
    MPIR_ERR_CHECK(mpi_errno);

    /* ... end of body of routine ... */

  fn_exit:
    MPIR_FUNC_TERSE_EXIT(MPID_STATE_MPI_NEIGHBOR_ALLTOALL);
    MPID_THREAD_CS_EXIT(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
    return mpi_errno;

  fn_fail:
    /* --BEGIN ERROR HANDLING-- */
#ifdef HAVE_ERROR_CHECKING
    {
        /* Wrap the error with this routine's name and arguments so the
         * reported message identifies the failing call site. */
        mpi_errno =
            MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, __func__, __LINE__, MPI_ERR_OTHER,
                                 "**mpi_neighbor_alltoall",
                                 "**mpi_neighbor_alltoall %p %d %D %p %d %D %C", sendbuf, sendcount,
                                 sendtype, recvbuf, recvcount, recvtype, comm);
    }
#endif
    mpi_errno = MPIR_Err_return_comm(NULL, __func__, mpi_errno);
    goto fn_exit;
    /* --END ERROR HANDLING-- */
}
289