1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 
8 /*
9 === BEGIN_MPI_T_CVAR_INFO_BLOCK ===
10 
11 cvars:
12     - name        : MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTRA_ALGORITHM
13       category    : COLLECTIVE
14       type        : enum
15       default     : auto
16       class       : none
17       verbosity   : MPI_T_VERBOSITY_USER_BASIC
18       scope       : MPI_T_SCOPE_ALL_EQ
19       description : |-
        Variable to select ineighbor_alltoallw algorithm for intracommunicators
21         auto - Internal algorithm selection (can be overridden with MPIR_CVAR_COLL_SELECTION_TUNING_JSON_FILE)
22         sched_auto - Internal algorithm selection for sched-based algorithms
23         sched_linear          - Force linear algorithm
24         gentran_linear        - Force generic transport based linear algorithm
25 
26     - name        : MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTER_ALGORITHM
27       category    : COLLECTIVE
28       type        : enum
29       default     : auto
30       class       : none
31       verbosity   : MPI_T_VERBOSITY_USER_BASIC
32       scope       : MPI_T_SCOPE_ALL_EQ
33       description : |-
        Variable to select ineighbor_alltoallw algorithm for intercommunicators
35         auto - Internal algorithm selection (can be overridden with MPIR_CVAR_COLL_SELECTION_TUNING_JSON_FILE)
36         sched_auto - Internal algorithm selection for sched-based algorithms
37         sched_linear          - Force linear algorithm
38         gentran_linear        - Force generic transport based linear algorithm
39 
40     - name        : MPIR_CVAR_INEIGHBOR_ALLTOALLW_DEVICE_COLLECTIVE
41       category    : COLLECTIVE
42       type        : boolean
43       default     : true
44       class       : none
45       verbosity   : MPI_T_VERBOSITY_USER_BASIC
46       scope       : MPI_T_SCOPE_ALL_EQ
47       description : >-
48         This CVAR is only used when MPIR_CVAR_DEVICE_COLLECTIVES
49         is set to "percoll".  If set to true, MPI_Ineighbor_alltoallw will
50         allow the device to override the MPIR-level collective
51         algorithms.  The device might still call the MPIR-level
52         algorithms manually.  If set to false, the device-override
53         will be disabled.
54 
55 === END_MPI_T_CVAR_INFO_BLOCK ===
56 */
57 
/* -- Begin Profiling Symbol Block for routine MPI_Ineighbor_alltoallw */
/* Make MPI_Ineighbor_alltoallw an alias (weak where the mechanism supports it)
 * of PMPI_Ineighbor_alltoallw.  At most one of the mechanisms below is used,
 * chosen by whichever HAVE_* macro the build configured; if none is defined,
 * the MPICH_MPI_FROM_PMPI section further down builds the MPI_ name instead.
 * NOTE(review): this is the usual PMPI profiling-interface arrangement, which
 * lets a tool override MPI_ while still reaching PMPI_ — confirm against the
 * build docs. */
#if defined(HAVE_PRAGMA_WEAK)
#pragma weak MPI_Ineighbor_alltoallw = PMPI_Ineighbor_alltoallw
#elif defined(HAVE_PRAGMA_HP_SEC_DEF)
#pragma _HP_SECONDARY_DEF PMPI_Ineighbor_alltoallw  MPI_Ineighbor_alltoallw
#elif defined(HAVE_PRAGMA_CRI_DUP)
#pragma _CRI duplicate MPI_Ineighbor_alltoallw as PMPI_Ineighbor_alltoallw
#elif defined(HAVE_WEAK_ATTRIBUTE)
/* GCC-style attribute: declare MPI_ as a weak alias of the PMPI_ definition. */
int MPI_Ineighbor_alltoallw(const void *sendbuf, const int sendcounts[],
                            const MPI_Aint sdispls[], const MPI_Datatype sendtypes[],
                            void *recvbuf, const int recvcounts[], const MPI_Aint rdispls[],
                            const MPI_Datatype recvtypes[], MPI_Comm comm, MPI_Request * request)
    __attribute__ ((weak, alias("PMPI_Ineighbor_alltoallw")));
#endif
/* -- End Profiling Symbol Block */
73 
74 /* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build
75    the MPI routines */
76 #ifndef MPICH_MPI_FROM_PMPI
77 #undef MPI_Ineighbor_alltoallw
78 #define MPI_Ineighbor_alltoallw PMPI_Ineighbor_alltoallw
79 
80 
/* Pick an ineighbor_alltoallw algorithm through the collective-selection (Csel)
 * tree attached to comm_ptr, then run it.
 *
 * Serves both intra- and intercommunicators: the Csel container may name the
 * gentran algorithm, either sched_auto variant, or the linear sched algorithm.
 * The gentran algorithm fills *request itself.  Sched-based algorithms are run
 * through MPII_SCHED_WRAPPER.  NOTE(review): judging by its arguments, that
 * macro drives the *_sched function on a schedule and fills *request, and it
 * may jump to fn_fail on error — confirm against its definition.
 *
 * Returns MPI_SUCCESS or an MPI error code. */
int MPIR_Ineighbor_alltoallw_allcomm_auto(const void *sendbuf, const int sendcounts[],
                                          const MPI_Aint sdispls[], const MPI_Datatype sendtypes[],
                                          void *recvbuf, const int recvcounts[],
                                          const MPI_Aint rdispls[], const MPI_Datatype recvtypes[],
                                          MPIR_Comm * comm_ptr, MPIR_Request ** request)
{
    int mpi_errno = MPI_SUCCESS;

    /* Describe this call so Csel can match it against its selection rules. */
    MPIR_Csel_coll_sig_s coll_sig = {
        .coll_type = MPIR_CSEL_COLL_TYPE__INEIGHBOR_ALLTOALLW,
        .comm_ptr = comm_ptr,

        .u.ineighbor_alltoallw.sendbuf = sendbuf,
        .u.ineighbor_alltoallw.sendcounts = sendcounts,
        .u.ineighbor_alltoallw.sdispls = sdispls,
        .u.ineighbor_alltoallw.sendtypes = sendtypes,
        .u.ineighbor_alltoallw.recvbuf = recvbuf,
        .u.ineighbor_alltoallw.recvcounts = recvcounts,
        .u.ineighbor_alltoallw.rdispls = rdispls,
        .u.ineighbor_alltoallw.recvtypes = recvtypes,
    };

    MPII_Csel_container_s *cnt = MPIR_Csel_search(comm_ptr->csel_comm, coll_sig);
    MPIR_Assert(cnt);

    switch (cnt->id) {
        case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ineighbor_alltoallw_allcomm_gentran_linear:
            mpi_errno =
                MPIR_Ineighbor_alltoallw_allcomm_gentran_linear(sendbuf, sendcounts, sdispls,
                                                                sendtypes, recvbuf, recvcounts,
                                                                rdispls, recvtypes, comm_ptr,
                                                                request);
            break;

        case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ineighbor_alltoallw_intra_sched_auto:
            MPII_SCHED_WRAPPER(MPIR_Ineighbor_alltoallw_intra_sched_auto, comm_ptr, request,
                               sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts,
                               rdispls, recvtypes);
            break;

        case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ineighbor_alltoallw_inter_sched_auto:
            MPII_SCHED_WRAPPER(MPIR_Ineighbor_alltoallw_inter_sched_auto, comm_ptr, request,
                               sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts,
                               rdispls, recvtypes);
            break;

        case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ineighbor_alltoallw_allcomm_sched_linear:
            MPII_SCHED_WRAPPER(MPIR_Ineighbor_alltoallw_allcomm_sched_linear, comm_ptr, request,
                               sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts,
                               rdispls, recvtypes);
            break;

        default:
            MPIR_Assert(0);
    }

    /* Send a failure from the directly-called gentran algorithm through fn_fail
     * too (pushing an error-stack entry), consistent with the other branches and
     * with MPIR_Ineighbor_alltoallw_impl. */
    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
142 
/* Schedule-based "auto" selection for intracommunicators.
 *
 * Only one schedule algorithm is available here, so this always appends the
 * linear algorithm's operations to schedule s.  Returns MPI_SUCCESS or the
 * error reported by the linear algorithm. */
int MPIR_Ineighbor_alltoallw_intra_sched_auto(const void *sendbuf, const int sendcounts[],
                                              const MPI_Aint sdispls[],
                                              const MPI_Datatype sendtypes[], void *recvbuf,
                                              const int recvcounts[], const MPI_Aint rdispls[],
                                              const MPI_Datatype recvtypes[], MPIR_Comm * comm_ptr,
                                              MPIR_Sched_t s)
{
    int mpi_errno = MPIR_Ineighbor_alltoallw_allcomm_sched_linear(sendbuf, sendcounts, sdispls,
                                                                  sendtypes, recvbuf, recvcounts,
                                                                  rdispls, recvtypes, comm_ptr, s);
    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
164 
/* Schedule-based "auto" selection for intercommunicators.
 *
 * As on the intracommunicator side, the linear schedule algorithm is the only
 * candidate, so it is used unconditionally on schedule s.  Returns MPI_SUCCESS
 * or the error reported by the linear algorithm. */
int MPIR_Ineighbor_alltoallw_inter_sched_auto(const void *sendbuf, const int sendcounts[],
                                              const MPI_Aint sdispls[],
                                              const MPI_Datatype sendtypes[], void *recvbuf,
                                              const int recvcounts[], const MPI_Aint rdispls[],
                                              const MPI_Datatype recvtypes[], MPIR_Comm * comm_ptr,
                                              MPIR_Sched_t s)
{
    int mpi_errno = MPIR_Ineighbor_alltoallw_allcomm_sched_linear(sendbuf, sendcounts, sdispls,
                                                                  sendtypes, recvbuf, recvcounts,
                                                                  rdispls, recvtypes, comm_ptr, s);
    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
186 
/* Schedule-based "auto" entry point.
 *
 * Routes to the intercommunicator variant for any communicator whose kind is
 * not MPIR_COMM_KIND__INTRACOMM, and to the intracommunicator variant
 * otherwise.  The chosen variant's return code is passed through unchanged. */
int MPIR_Ineighbor_alltoallw_sched_auto(const void *sendbuf, const int sendcounts[],
                                        const MPI_Aint sdispls[], const MPI_Datatype sendtypes[],
                                        void *recvbuf, const int recvcounts[],
                                        const MPI_Aint rdispls[], const MPI_Datatype recvtypes[],
                                        MPIR_Comm * comm_ptr, MPIR_Sched_t s)
{
    if (comm_ptr->comm_kind != MPIR_COMM_KIND__INTRACOMM) {
        return MPIR_Ineighbor_alltoallw_inter_sched_auto(sendbuf, sendcounts, sdispls, sendtypes,
                                                         recvbuf, recvcounts, rdispls, recvtypes,
                                                         comm_ptr, s);
    }

    return MPIR_Ineighbor_alltoallw_intra_sched_auto(sendbuf, sendcounts, sdispls, sendtypes,
                                                     recvbuf, recvcounts, rdispls, recvtypes,
                                                     comm_ptr, s);
}
209 
/* MPIR-level nonblocking neighbor alltoallw (no device override).
 *
 * *request is reset to NULL first.  The algorithm then comes from
 * MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTER_ALGORITHM for communicators that are not
 * intracommunicators, or from MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTRA_ALGORITHM
 * for intracommunicators:
 *   auto           -> MPIR_Ineighbor_alltoallw_allcomm_auto (Csel-driven)
 *   sched_auto     -> the matching intra/inter sched_auto, via MPII_SCHED_WRAPPER
 *   sched_linear   -> MPIR_Ineighbor_alltoallw_allcomm_sched_linear, via MPII_SCHED_WRAPPER
 *   gentran_linear -> MPIR_Ineighbor_alltoallw_allcomm_gentran_linear
 * Any other value trips MPIR_Assert.
 *
 * TODO (carried over): the MPIR_Sched-based algorithms are meant to be replaced
 * by transport-enabled ones once enough performance testing and replacement
 * algorithms exist.
 *
 * Returns MPI_SUCCESS or an MPI error code.  NOTE(review): the fn_fail label is
 * also targeted from inside MPII_SCHED_WRAPPER — keep it. */
int MPIR_Ineighbor_alltoallw_impl(const void *sendbuf, const int sendcounts[],
                                  const MPI_Aint sdispls[],
                                  const MPI_Datatype sendtypes[],
                                  void *recvbuf, const int recvcounts[],
                                  const MPI_Aint rdispls[],
                                  const MPI_Datatype recvtypes[],
                                  MPIR_Comm * comm_ptr, MPIR_Request ** request)
{
    int mpi_errno = MPI_SUCCESS;

    *request = NULL;

    if (comm_ptr->comm_kind != MPIR_COMM_KIND__INTRACOMM) {
        /* Intercommunicator selection */
        switch (MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTER_ALGORITHM) {
            case MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTER_ALGORITHM_auto:
                mpi_errno =
                    MPIR_Ineighbor_alltoallw_allcomm_auto(sendbuf, sendcounts, sdispls, sendtypes,
                                                          recvbuf, recvcounts, rdispls, recvtypes,
                                                          comm_ptr, request);
                break;

            case MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTER_ALGORITHM_sched_auto:
                MPII_SCHED_WRAPPER(MPIR_Ineighbor_alltoallw_inter_sched_auto, comm_ptr, request,
                                   sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts,
                                   rdispls, recvtypes);
                break;

            case MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTER_ALGORITHM_sched_linear:
                MPII_SCHED_WRAPPER(MPIR_Ineighbor_alltoallw_allcomm_sched_linear, comm_ptr, request,
                                   sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts,
                                   rdispls, recvtypes);
                break;

            case MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTER_ALGORITHM_gentran_linear:
                mpi_errno =
                    MPIR_Ineighbor_alltoallw_allcomm_gentran_linear(sendbuf, sendcounts, sdispls,
                                                                    sendtypes, recvbuf, recvcounts,
                                                                    rdispls, recvtypes, comm_ptr,
                                                                    request);
                break;

            default:
                MPIR_Assert(0);
        }
    } else {
        /* Intracommunicator selection */
        switch (MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTRA_ALGORITHM) {
            case MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTRA_ALGORITHM_auto:
                mpi_errno =
                    MPIR_Ineighbor_alltoallw_allcomm_auto(sendbuf, sendcounts, sdispls, sendtypes,
                                                          recvbuf, recvcounts, rdispls, recvtypes,
                                                          comm_ptr, request);
                break;

            case MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTRA_ALGORITHM_sched_auto:
                MPII_SCHED_WRAPPER(MPIR_Ineighbor_alltoallw_intra_sched_auto, comm_ptr, request,
                                   sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts,
                                   rdispls, recvtypes);
                break;

            case MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTRA_ALGORITHM_sched_linear:
                MPII_SCHED_WRAPPER(MPIR_Ineighbor_alltoallw_allcomm_sched_linear, comm_ptr, request,
                                   sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts,
                                   rdispls, recvtypes);
                break;

            case MPIR_CVAR_INEIGHBOR_ALLTOALLW_INTRA_ALGORITHM_gentran_linear:
                mpi_errno =
                    MPIR_Ineighbor_alltoallw_allcomm_gentran_linear(sendbuf, sendcounts, sdispls,
                                                                    sendtypes, recvbuf, recvcounts,
                                                                    rdispls, recvtypes, comm_ptr,
                                                                    request);
                break;

            default:
                MPIR_Assert(0);
        }
    }

    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
299 
/* Internal entry point for nonblocking neighbor alltoallw.
 *
 * The operation goes to the device (MPID_Ineighbor_alltoallw) when
 * MPIR_CVAR_DEVICE_COLLECTIVES is "all", or when it is "percoll" and
 * MPIR_CVAR_INEIGHBOR_ALLTOALLW_DEVICE_COLLECTIVE is true.  Otherwise the
 * MPIR-level implementation (MPIR_Ineighbor_alltoallw_impl) runs.
 *
 * Returns the MPI error code produced by whichever path was taken. */
int MPIR_Ineighbor_alltoallw(const void *sendbuf, const int sendcounts[],
                             const MPI_Aint sdispls[],
                             const MPI_Datatype sendtypes[], void *recvbuf,
                             const int recvcounts[], const MPI_Aint rdispls[],
                             const MPI_Datatype recvtypes[],
                             MPIR_Comm * comm_ptr, MPIR_Request ** request)
{
    int mpi_errno = MPI_SUCCESS;

    /* In "percoll" mode the per-collective switch must be this collective's own
     * CVAR (MPIR_CVAR_INEIGHBOR_ALLTOALLW_DEVICE_COLLECTIVE), not the barrier one
     * that was previously consulted here by mistake. */
    if ((MPIR_CVAR_DEVICE_COLLECTIVES == MPIR_CVAR_DEVICE_COLLECTIVES_all) ||
        ((MPIR_CVAR_DEVICE_COLLECTIVES == MPIR_CVAR_DEVICE_COLLECTIVES_percoll) &&
         MPIR_CVAR_INEIGHBOR_ALLTOALLW_DEVICE_COLLECTIVE)) {
        mpi_errno =
            MPID_Ineighbor_alltoallw(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts,
                                     rdispls, recvtypes, comm_ptr, request);
    } else {
        mpi_errno = MPIR_Ineighbor_alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf,
                                                  recvcounts, rdispls, recvtypes, comm_ptr,
                                                  request);
    }

    return mpi_errno;
}
323 
324 #endif /* MPICH_MPI_FROM_PMPI */
325 
326 /*@
327 MPI_Ineighbor_alltoallw - Nonblocking version of MPI_Neighbor_alltoallw.
328 
329 Input Parameters:
330 + sendbuf - starting address of the send buffer (choice)
331 . sendcounts - non-negative integer array (of length outdegree) specifying the number of elements to send to each neighbor
332 . sdispls - integer array (of length outdegree).  Entry j specifies the displacement in bytes (relative to sendbuf) from which to take the outgoing data destined for neighbor j (array of integers)
333 . sendtypes - array of datatypes (of length outdegree).  Entry j specifies the type of data to send to neighbor j (array of handles)
334 . recvcounts - non-negative integer array (of length indegree) specifying the number of elements that are received from each neighbor
335 . rdispls - integer array (of length indegree).  Entry i specifies the displacement in bytes (relative to recvbuf) at which to place the incoming data from neighbor i (array of integers).
336 . recvtypes - array of datatypes (of length indegree).  Entry i specifies the type of data received from neighbor i (array of handles).
337 - comm - communicator with topology structure (handle)
338 
339 Output Parameters:
340 + recvbuf - starting address of the receive buffer (choice)
341 - request - communication request (handle)
342 
343 .N ThreadSafe
344 
345 .N Fortran
346 
347 .N Errors
348 @*/
int MPI_Ineighbor_alltoallw(const void *sendbuf, const int sendcounts[], const MPI_Aint sdispls[],
                            const MPI_Datatype sendtypes[], void *recvbuf, const int recvcounts[],
                            const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], MPI_Comm comm,
                            MPI_Request * request)
{
    int mpi_errno = MPI_SUCCESS;
    /* Stays NULL until the handle is converted; the error path relies on this. */
    MPIR_Comm *comm_ptr = NULL;
    MPIR_Request *request_ptr = NULL;
    MPIR_FUNC_TERSE_STATE_DECL(MPID_STATE_MPI_INEIGHBOR_ALLTOALLW);

    MPID_THREAD_CS_ENTER(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
    MPIR_FUNC_TERSE_ENTER(MPID_STATE_MPI_INEIGHBOR_ALLTOALLW);

    /* Validate parameters, especially handles needing to be converted */
#ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            MPIR_ERRTEST_COMM(comm, mpi_errno);

            /* TODO more checks may be appropriate */
        }
        MPID_END_ERROR_CHECKS;
    }
#endif /* HAVE_ERROR_CHECKING */

    /* Convert MPI object handles to object pointers */
    MPIR_Comm_get_ptr(comm, comm_ptr);
    MPIR_Assert(comm_ptr != NULL);

    /* Validate parameters and objects (post conversion) */
#ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            MPIR_Comm_valid_ptr(comm_ptr, mpi_errno, FALSE);
            MPIR_ERRTEST_ARGNULL(request, "request", mpi_errno);
            /* TODO more checks may be appropriate (counts, in_place, buffer aliasing, etc) */
            if (mpi_errno != MPI_SUCCESS)
                goto fn_fail;
        }
        MPID_END_ERROR_CHECKS;
    }
#endif /* HAVE_ERROR_CHECKING */

    /* ... body of routine ...  */

    mpi_errno =
        MPIR_Ineighbor_alltoallw(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts,
                                 rdispls, recvtypes, comm_ptr, &request_ptr);
    MPIR_ERR_CHECK(mpi_errno);

    /* A NULL request from the lower layers means nothing is outstanding, so
     * hand the user an already-completed request instead. */
    if (!request_ptr)
        request_ptr = MPIR_Request_create_complete(MPIR_REQUEST_KIND__COLL);
    /* return the handle of the request to the user */
    *request = request_ptr->handle;

    /* ... end of body of routine ... */

  fn_exit:
    MPIR_FUNC_TERSE_EXIT(MPID_STATE_MPI_INEIGHBOR_ALLTOALLW);
    MPID_THREAD_CS_EXIT(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
    return mpi_errno;

  fn_fail:
    /* --BEGIN ERROR HANDLING-- */
#ifdef HAVE_ERROR_CHECKING
    {
        mpi_errno =
            MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, __func__, __LINE__, MPI_ERR_OTHER,
                                 "**mpi_ineighbor_alltoallw",
                                 "**mpi_ineighbor_alltoallw %p %p %p %p %p %p %p %p %C %p", sendbuf,
                                 sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls,
                                 recvtypes, comm, request);
    }
#endif
    /* Report through the error handler attached to comm.  Passing NULL here
     * bypassed it.  comm_ptr is still NULL when the failure happened before
     * the handle was converted, which preserves the old behavior in that case. */
    mpi_errno = MPIR_Err_return_comm(comm_ptr, __func__, mpi_errno);
    goto fn_exit;
    /* --END ERROR HANDLING-- */
}
430