1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 
8 /*
9 === BEGIN_MPI_T_CVAR_INFO_BLOCK ===
10 
11 cvars:
12     - name        : MPIR_CVAR_IEXSCAN_INTRA_ALGORITHM
13       category    : COLLECTIVE
14       type        : enum
15       default     : auto
16       class       : none
17       verbosity   : MPI_T_VERBOSITY_USER_BASIC
18       scope       : MPI_T_SCOPE_ALL_EQ
19       description : |-
20         Variable to select iexscan algorithm
21         auto - Internal algorithm selection (can be overridden with MPIR_CVAR_COLL_SELECTION_TUNING_JSON_FILE)
22         sched_auto - Internal algorithm selection for sched-based algorithms
23         sched_recursive_doubling - Force recursive doubling algorithm
24 
25     - name        : MPIR_CVAR_IEXSCAN_DEVICE_COLLECTIVE
26       category    : COLLECTIVE
27       type        : boolean
28       default     : true
29       class       : none
30       verbosity   : MPI_T_VERBOSITY_USER_BASIC
31       scope       : MPI_T_SCOPE_ALL_EQ
32       description : >-
33         This CVAR is only used when MPIR_CVAR_DEVICE_COLLECTIVES
34         is set to "percoll".  If set to true, MPI_Iexscan will
35         allow the device to override the MPIR-level collective
36         algorithms.  The device might still call the MPIR-level
37         algorithms manually.  If set to false, the device-override
38         will be disabled.
39 
40 === END_MPI_T_CVAR_INFO_BLOCK ===
41 */
42 
43 /* -- Begin Profiling Symbol Block for routine MPI_Iexscan */
44 #if defined(HAVE_PRAGMA_WEAK)
45 #pragma weak MPI_Iexscan = PMPI_Iexscan
46 #elif defined(HAVE_PRAGMA_HP_SEC_DEF)
47 #pragma _HP_SECONDARY_DEF PMPI_Iexscan  MPI_Iexscan
48 #elif defined(HAVE_PRAGMA_CRI_DUP)
49 #pragma _CRI duplicate MPI_Iexscan as PMPI_Iexscan
50 #elif defined(HAVE_WEAK_ATTRIBUTE)
51 int MPI_Iexscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
52                 MPI_Op op, MPI_Comm comm, MPI_Request * request)
53     __attribute__ ((weak, alias("PMPI_Iexscan")));
54 #endif
55 /* -- End Profiling Symbol Block */
56 
57 /* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build
58    the MPI routines */
59 #ifndef MPICH_MPI_FROM_PMPI
60 #undef MPI_Iexscan
61 #define MPI_Iexscan PMPI_Iexscan
62 
63 
/*
 * Selection-driven dispatcher for MPI_Iexscan ("auto" CVAR setting).
 *
 * Describes this call as an MPIR_Csel_coll_sig_s, looks it up in the
 * communicator's selection structure (comm_ptr->csel_comm), and runs the
 * schedule-based algorithm named by the returned container through
 * MPII_SCHED_WRAPPER, which is handed comm_ptr and request.
 *
 * Returns MPI_SUCCESS or an MPI error code.  An unknown container id is a
 * programming error and trips MPIR_Assert.
 */
int MPIR_Iexscan_allcomm_auto(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
                              MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Request ** request)
{
    int mpi_errno = MPI_SUCCESS;
    MPII_Csel_container_s *selected;
    MPIR_Csel_coll_sig_s coll_sig = {
        .coll_type = MPIR_CSEL_COLL_TYPE__IEXSCAN,
        .comm_ptr = comm_ptr,
        .u.iexscan = {
            .sendbuf = sendbuf,
            .recvbuf = recvbuf,
            .count = count,
            .datatype = datatype,
            .op = op,
        },
    };

    selected = MPIR_Csel_search(comm_ptr->csel_comm, coll_sig);
    MPIR_Assert(selected);

    if (selected->id == MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Iexscan_intra_sched_auto) {
        MPII_SCHED_WRAPPER(MPIR_Iexscan_intra_sched_auto, comm_ptr, request, sendbuf, recvbuf,
                           count, datatype, op);
    } else if (selected->id ==
               MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Iexscan_intra_sched_recursive_doubling) {
        MPII_SCHED_WRAPPER(MPIR_Iexscan_intra_sched_recursive_doubling, comm_ptr, request,
                           sendbuf, recvbuf, count, datatype, op);
    } else {
        /* The selection structure returned a container this collective cannot run. */
        MPIR_Assert(0);
    }

    /* fn_fail is the jump target used by the error-checking macros. */
  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
103 
/*
 * Schedule-based "auto" algorithm choice for MPI_Iexscan.
 *
 * Recursive doubling is the only schedule-based iexscan algorithm referenced
 * in this file, so the choice is unconditional: every argument, including
 * the schedule s, is forwarded unchanged.
 *
 * Returns MPI_SUCCESS or the (error-stack-extended) code from the callee.
 */
int MPIR_Iexscan_intra_sched_auto(const void *sendbuf, void *recvbuf, int count,
                                  MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr,
                                  MPIR_Sched_t s)
{
    int mpi_errno = MPIR_Iexscan_intra_sched_recursive_doubling(sendbuf, recvbuf, count,
                                                                datatype, op, comm_ptr, s);
    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
121 
/*
 * MPIR-level implementation of MPI_Iexscan.
 *
 * Picks the algorithm named by MPIR_CVAR_IEXSCAN_INTRA_ALGORITHM:
 *   auto                     - selection-driven (MPIR_Iexscan_allcomm_auto)
 *   sched_auto               - MPIR_Iexscan_intra_sched_auto via MPII_SCHED_WRAPPER
 *   sched_recursive_doubling - recursive doubling via MPII_SCHED_WRAPPER
 *
 * *request is reset to NULL before dispatch; a path that creates no request
 * leaves it NULL.  Returns MPI_SUCCESS or an MPI error code.
 */
int MPIR_Iexscan_impl(const void *sendbuf, void *recvbuf, int count,
                      MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr,
                      MPIR_Request ** request)
{
    int mpi_errno = MPI_SUCCESS;

    *request = NULL;

    const int algorithm = MPIR_CVAR_IEXSCAN_INTRA_ALGORITHM;

    if (algorithm == MPIR_CVAR_IEXSCAN_INTRA_ALGORITHM_auto) {
        mpi_errno = MPIR_Iexscan_allcomm_auto(sendbuf, recvbuf, count, datatype, op, comm_ptr,
                                              request);
    } else if (algorithm == MPIR_CVAR_IEXSCAN_INTRA_ALGORITHM_sched_auto) {
        MPII_SCHED_WRAPPER(MPIR_Iexscan_intra_sched_auto, comm_ptr, request, sendbuf, recvbuf,
                           count, datatype, op);
    } else if (algorithm == MPIR_CVAR_IEXSCAN_INTRA_ALGORITHM_sched_recursive_doubling) {
        MPII_SCHED_WRAPPER(MPIR_Iexscan_intra_sched_recursive_doubling, comm_ptr, request,
                           sendbuf, recvbuf, count, datatype, op);
    } else {
        /* The CVAR holds a value outside the documented enum. */
        MPIR_Assert(0);
    }

    /* fn_fail is the jump target used by the error-checking macros. */
  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
155 
/*
 * Internal entry point for MPI_Iexscan.
 *
 * MPIR_Coll_host_buffer_alloc may supply host staging copies of the user
 * buffers; when it does, the collective runs on those copies instead.
 * Afterwards MPIR_Coll_host_buffer_swap_back receives the staging buffers,
 * the caller's original recvbuf and the resulting request (presumably to
 * copy results back and release the staging copies -- confirm against that
 * helper).
 *
 * The device implementation (MPID_Iexscan) is used when
 * MPIR_CVAR_DEVICE_COLLECTIVES is "all", or when it is "percoll" and
 * MPIR_CVAR_IEXSCAN_DEVICE_COLLECTIVE is true; otherwise MPIR_Iexscan_impl
 * runs.  Returns the error code of whichever implementation ran.
 */
int MPIR_Iexscan(const void *sendbuf, void *recvbuf, int count,
                 MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Request ** request)
{
    int mpi_errno;
    void *const user_recvbuf = recvbuf;
    void *host_sendbuf;
    void *host_recvbuf;

    MPIR_Coll_host_buffer_alloc(sendbuf, recvbuf, count, datatype, &host_sendbuf, &host_recvbuf);

    const void *send_ptr = host_sendbuf ? host_sendbuf : sendbuf;
    void *recv_ptr = host_recvbuf ? host_recvbuf : recvbuf;

    const int use_device =
        MPIR_CVAR_DEVICE_COLLECTIVES == MPIR_CVAR_DEVICE_COLLECTIVES_all ||
        (MPIR_CVAR_DEVICE_COLLECTIVES == MPIR_CVAR_DEVICE_COLLECTIVES_percoll &&
         MPIR_CVAR_IEXSCAN_DEVICE_COLLECTIVE);

    if (use_device) {
        mpi_errno = MPID_Iexscan(send_ptr, recv_ptr, count, datatype, op, comm_ptr, request);
    } else {
        mpi_errno = MPIR_Iexscan_impl(send_ptr, recv_ptr, count, datatype, op, comm_ptr,
                                      request);
    }

    MPIR_Coll_host_buffer_swap_back(host_sendbuf, host_recvbuf, user_recvbuf, count, datatype,
                                    *request);

    return mpi_errno;
}
183 
184 #endif /* MPICH_MPI_FROM_PMPI */
185 
186 /*@
187 MPI_Iexscan - Computes the exclusive scan (partial reductions) of data on a
188               collection of processes in a nonblocking way
189 
190 
191 Input Parameters:
192 + sendbuf - starting address of the send buffer (choice)
193 . count - number of elements in input buffer (non-negative integer)
194 . datatype - data type of elements of input buffer (handle)
195 . op - operation (handle)
196 - comm - communicator (handle)
197 
198 Output Parameters:
199 + recvbuf - starting address of the receive buffer (choice)
200 - request - communication request (handle)
201 
202 .N ThreadSafe
203 
204 .N Fortran
205 
206 .N Errors
207 @*/
/*
 * Public binding for MPI_Iexscan.
 *
 * Validates the arguments (when HAVE_ERROR_CHECKING is defined), converts
 * the communicator handle, and delegates to MPIR_Iexscan.  If no request
 * object comes back, an already-complete request of kind
 * MPIR_REQUEST_KIND__COLL is created so the user always receives a valid
 * handle in *request.  On failure the error is decorated with the call's
 * arguments and routed through the communicator's error handler
 * (MPIR_Err_return_comm).
 */
int MPI_Iexscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
                MPI_Op op, MPI_Comm comm, MPI_Request * request)
{
    int mpi_errno = MPI_SUCCESS;
    MPIR_Comm *comm_ptr = NULL;
    MPIR_Request *request_ptr = NULL;
    MPIR_FUNC_TERSE_STATE_DECL(MPID_STATE_MPI_IEXSCAN);

    MPID_THREAD_CS_ENTER(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
    MPIR_FUNC_TERSE_ENTER(MPID_STATE_MPI_IEXSCAN);

    /* Validate parameters, especially handles needing to be converted */
#ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            MPIR_ERRTEST_DATATYPE(datatype, "datatype", mpi_errno);
            MPIR_ERRTEST_OP(op, mpi_errno);
            MPIR_ERRTEST_COMM(comm, mpi_errno);

            /* TODO more checks may be appropriate */
        }
        MPID_END_ERROR_CHECKS;
    }
#endif /* HAVE_ERROR_CHECKING */

    /* Convert MPI object handles to object pointers */
    MPIR_Comm_get_ptr(comm, comm_ptr);
    MPIR_Assert(comm_ptr != NULL);

    /* Validate parameters and objects (post conversion) */
#ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            MPIR_Comm_valid_ptr(comm_ptr, mpi_errno, FALSE);
            /* Stop here, like the datatype and op checks below do.  Otherwise
             * the remaining checks run against a communicator that just
             * failed validation, and for a builtin op the assignment from
             * MPIR_OP_HDL_TO_DTYPE_FN overwrites mpi_errno -- silently
             * discarding the communicator error. */
            if (mpi_errno != MPI_SUCCESS)
                goto fn_fail;
            MPIR_ERRTEST_COMM_INTRA(comm_ptr, mpi_errno);
            if (!HANDLE_IS_BUILTIN(datatype)) {
                MPIR_Datatype *datatype_ptr = NULL;
                MPIR_Datatype_get_ptr(datatype, datatype_ptr);
                MPIR_Datatype_valid_ptr(datatype_ptr, mpi_errno);
                if (mpi_errno != MPI_SUCCESS)
                    goto fn_fail;
                MPIR_Datatype_committed_ptr(datatype_ptr, mpi_errno);
                if (mpi_errno != MPI_SUCCESS)
                    goto fn_fail;
            }

            if (!HANDLE_IS_BUILTIN(op)) {
                MPIR_Op *op_ptr = NULL;
                MPIR_Op_get_ptr(op, op_ptr);
                MPIR_Op_valid_ptr(op_ptr, mpi_errno);
            } else {
                /* Builtin ops: ask the op's type checker whether it accepts datatype. */
                mpi_errno = (*MPIR_OP_HDL_TO_DTYPE_FN(op)) (datatype);
            }
            if (mpi_errno != MPI_SUCCESS)
                goto fn_fail;

            MPIR_ERRTEST_ARGNULL(request, "request", mpi_errno);

            /* Aliasing is only an error when both buffers are really used. */
            if (sendbuf != MPI_IN_PLACE && count != 0)
                MPIR_ERRTEST_ALIAS_COLL(sendbuf, recvbuf, mpi_errno);
            /* TODO more checks may be appropriate (counts, in_place, etc) */
        }
        MPID_END_ERROR_CHECKS;
    }
#endif /* HAVE_ERROR_CHECKING */

    /* ... body of routine ...  */

    mpi_errno = MPIR_Iexscan(sendbuf, recvbuf, count, datatype, op, comm_ptr, &request_ptr);
    MPIR_ERR_CHECK(mpi_errno);

    /* create a complete request, if needed */
    if (!request_ptr)
        request_ptr = MPIR_Request_create_complete(MPIR_REQUEST_KIND__COLL);
    /* return the handle of the request to the user */
    *request = request_ptr->handle;

    /* ... end of body of routine ... */

  fn_exit:
    MPIR_FUNC_TERSE_EXIT(MPID_STATE_MPI_IEXSCAN);
    MPID_THREAD_CS_EXIT(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
    return mpi_errno;

  fn_fail:
    /* --BEGIN ERROR HANDLING-- */
#ifdef HAVE_ERROR_CHECKING
    {
        mpi_errno =
            MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, __func__, __LINE__, MPI_ERR_OTHER,
                                 "**mpi_iexscan", "**mpi_iexscan %p %p %d %D %O %C %p", sendbuf,
                                 recvbuf, count, datatype, op, comm, request);
    }
#endif
    mpi_errno = MPIR_Err_return_comm(comm_ptr, __func__, mpi_errno);
    goto fn_exit;
    /* --END ERROR HANDLING-- */
}
308