1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 
8 /* -- Begin Profiling Symbol Block for routine MPI_Get_elements_x */
9 #if defined(HAVE_PRAGMA_WEAK)
10 #pragma weak MPI_Get_elements_x = PMPI_Get_elements_x
11 #elif defined(HAVE_PRAGMA_HP_SEC_DEF)
12 #pragma _HP_SECONDARY_DEF PMPI_Get_elements_x  MPI_Get_elements_x
13 #elif defined(HAVE_PRAGMA_CRI_DUP)
14 #pragma _CRI duplicate MPI_Get_elements_x as PMPI_Get_elements_x
15 #elif defined(HAVE_WEAK_ATTRIBUTE)
16 int MPI_Get_elements_x(const MPI_Status * status, MPI_Datatype datatype, MPI_Count * count)
17     __attribute__ ((weak, alias("PMPI_Get_elements_x")));
18 #endif
19 /* -- End Profiling Symbol Block */
20 
21 /* Internal helper routines.  If you want to get the number of elements from
22  * within the MPI library, call MPIR_Get_elements_x_impl instead. */
23 PMPI_LOCAL MPI_Count MPIR_Type_get_basic_type_elements(MPI_Count * bytes_p,
24                                                        MPI_Count count, MPI_Datatype datatype);
25 PMPI_LOCAL MPI_Count MPIR_Type_get_elements(MPI_Count * bytes_p,
26                                             MPI_Count count, MPI_Datatype datatype);
27 
28 /* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build
29    the MPI routines */
30 #ifndef MPICH_MPI_FROM_PMPI
31 #undef MPI_Get_elements_x
32 #define MPI_Get_elements_x PMPI_Get_elements_x
33 
34 /* any non-MPI functions go here, especially non-static ones */
35 
36 /* MPIR_Type_get_basic_type_elements()
37  *
38  * Arguments:
39  * - bytes_p - input/output byte count
 * - count - maximum number of this type to subtract from the bytes; a negative
 *           count (e.g. -1) indicates use as many as we like
42  * - datatype - input datatype
43  *
44  * Returns number of elements available given the two constraints of number of
45  * bytes and count of types.  Also reduces the byte count by the amount taken
46  * up by the types.
47  *
48  * Assumptions:
49  * - the type passed to this function must be a basic *or* a pairtype
50  *   (which aren't basic types)
51  * - the count is not zero (otherwise we can't tell between a "no more
52  *   complete types" case and a "zero count" case)
53  *
54  * As per section 4.9.3 of the MPI 1.1 specification, the two-part reduction
55  * types are to be treated as structs of the constituent types.  So we have to
56  * do something special to handle them correctly in here.
57  *
58  * As per section 3.12.5 get_count and get_elements report the same value for
59  * basic datatypes; I'm currently interpreting this to *not* include these
60  * reduction types, as they are considered structs.
61  */
MPIR_Type_get_basic_type_elements(MPI_Count * bytes_p,MPI_Count count,MPI_Datatype datatype)62 PMPI_LOCAL MPI_Count MPIR_Type_get_basic_type_elements(MPI_Count * bytes_p,
63                                                        MPI_Count count, MPI_Datatype datatype)
64 {
65     MPI_Count elements, usable_bytes, used_bytes, type1_sz, type2_sz;
66 
67     if (count == 0)
68         return 0;
69 
70     /* determine the maximum number of bytes we should take from the
71      * byte count.
72      */
73     if (count < 0) {
74         usable_bytes = *bytes_p;
75     } else {
76         usable_bytes = MPL_MIN(*bytes_p, count * MPIR_Datatype_get_basic_size(datatype));
77     }
78 
79     switch (datatype) {
80             /* we don't get valid fortran datatype handles in all cases... */
81 #ifdef HAVE_FORTRAN_BINDING
82         case MPI_2REAL:
83             type1_sz = type2_sz = MPIR_Datatype_get_basic_size(MPI_REAL);
84             break;
85         case MPI_2DOUBLE_PRECISION:
86             type1_sz = type2_sz = MPIR_Datatype_get_basic_size(MPI_DOUBLE_PRECISION);
87             break;
88         case MPI_2INTEGER:
89             type1_sz = type2_sz = MPIR_Datatype_get_basic_size(MPI_INTEGER);
90             break;
91 #endif
92         case MPI_2INT:
93             type1_sz = type2_sz = MPIR_Datatype_get_basic_size(MPI_INT);
94             break;
95         case MPI_FLOAT_INT:
96             type1_sz = MPIR_Datatype_get_basic_size(MPI_FLOAT);
97             type2_sz = MPIR_Datatype_get_basic_size(MPI_INT);
98             break;
99         case MPI_DOUBLE_INT:
100             type1_sz = MPIR_Datatype_get_basic_size(MPI_DOUBLE);
101             type2_sz = MPIR_Datatype_get_basic_size(MPI_INT);
102             break;
103         case MPI_LONG_INT:
104             type1_sz = MPIR_Datatype_get_basic_size(MPI_LONG);
105             type2_sz = MPIR_Datatype_get_basic_size(MPI_INT);
106             break;
107         case MPI_SHORT_INT:
108             type1_sz = MPIR_Datatype_get_basic_size(MPI_SHORT);
109             type2_sz = MPIR_Datatype_get_basic_size(MPI_INT);
110             break;
111         case MPI_LONG_DOUBLE_INT:
112             type1_sz = MPIR_Datatype_get_basic_size(MPI_LONG_DOUBLE);
113             type2_sz = MPIR_Datatype_get_basic_size(MPI_INT);
114             break;
115         default:
116             /* all other types.  this is more complicated than
117              * necessary for handling these types, but it puts us in the
118              * same code path for all the basics, so we stick with it.
119              */
120             type1_sz = type2_sz = MPIR_Datatype_get_basic_size(datatype);
121             break;
122     }
123 
124     /* determine the number of elements in the region */
125     elements = 2 * (usable_bytes / (type1_sz + type2_sz));
126     if (usable_bytes % (type1_sz + type2_sz) >= type1_sz)
127         elements++;
128 
129     /* determine how many bytes we used up with those elements */
130     used_bytes = ((elements / 2) * (type1_sz + type2_sz));
131     if (elements % 2 == 1)
132         used_bytes += type1_sz;
133 
134     *bytes_p -= used_bytes;
135 
136     return elements;
137 }
138 
139 
140 /* MPIR_Type_get_elements
141  *
142  * Arguments:
143  * - bytes_p - input/output byte count
144  * - count - maximum number of this type to subtract from the bytes; a count
145  *           of <0 indicates use as many as we like
146  * - datatype - input datatype
147  *
148  * Returns number of elements available given the two constraints of number of
149  * bytes and count of types.  Also reduces the byte count by the amount taken
150  * up by the types.
151  *
 * This is called from MPIR_Get_elements_x_impl() when it sees a type with
 * multiple element types (datatype_ptr->builtin_element_size == -1).  This
 * function calls itself too.
154  */
MPIR_Type_get_elements(MPI_Count * bytes_p,MPI_Count count,MPI_Datatype datatype)155 PMPI_LOCAL MPI_Count MPIR_Type_get_elements(MPI_Count * bytes_p,
156                                             MPI_Count count, MPI_Datatype datatype)
157 {
158     MPIR_Datatype *datatype_ptr = NULL;
159 
160     MPIR_Datatype_get_ptr(datatype, datatype_ptr);      /* invalid if builtin */
161 
162     /* if we have gotten down to a type with only one element type,
163      * call MPIR_Type_get_basic_type_elements() and return.
164      */
165     if (HANDLE_IS_BUILTIN(datatype) ||
166         datatype == MPI_FLOAT_INT ||
167         datatype == MPI_DOUBLE_INT ||
168         datatype == MPI_LONG_INT || datatype == MPI_SHORT_INT || datatype == MPI_LONG_DOUBLE_INT) {
169         return MPIR_Type_get_basic_type_elements(bytes_p, count, datatype);
170     } else if (datatype_ptr->builtin_element_size >= 0) {
171         MPI_Datatype basic_type = MPI_DATATYPE_NULL;
172         MPIR_Datatype_get_basic_type(datatype_ptr->basic_type, basic_type);
173         return MPIR_Type_get_basic_type_elements(bytes_p,
174                                                  count * datatype_ptr->n_builtin_elements,
175                                                  basic_type);
176     } else {
177         /* we have bytes left and still don't have a single element size; must
178          * recurse.
179          */
180         int i, j, *ints;
181         MPI_Count typecount = 0, nr_elements = 0, last_nr_elements;
182         MPI_Aint *aints;
183         MPI_Datatype *types;
184 
185         /* Establish locations of arrays */
186         MPIR_Type_access_contents(datatype_ptr->handle, &ints, &aints, &types);
187         if (!ints || !aints || !types)
188             return MPI_ERR_TYPE;
189 
190         switch (datatype_ptr->contents->combiner) {
191             case MPI_COMBINER_NAMED:
192             case MPI_COMBINER_DUP:
193             case MPI_COMBINER_RESIZED:
194                 return MPIR_Type_get_elements(bytes_p, count, *types);
195                 break;
196             case MPI_COMBINER_CONTIGUOUS:
197             case MPI_COMBINER_VECTOR:
198             case MPI_COMBINER_HVECTOR_INTEGER:
199             case MPI_COMBINER_HVECTOR:
200             case MPI_COMBINER_SUBARRAY:
201                 /* count is first in ints array */
202                 return MPIR_Type_get_elements(bytes_p, count * (*ints), *types);
203                 break;
204             case MPI_COMBINER_INDEXED_BLOCK:
205             case MPI_COMBINER_HINDEXED_BLOCK:
206                 /* count is first in ints array, blocklength is second */
207                 return MPIR_Type_get_elements(bytes_p, count * ints[0] * ints[1], *types);
208                 break;
209             case MPI_COMBINER_INDEXED:
210             case MPI_COMBINER_HINDEXED_INTEGER:
211             case MPI_COMBINER_HINDEXED:
212                 for (i = 0; i < (*ints); i++) {
213                     /* add up the blocklengths to get a max. # of the next type */
214                     typecount += ints[i + 1];
215                 }
216                 return MPIR_Type_get_elements(bytes_p, count * typecount, *types);
217                 break;
218             case MPI_COMBINER_STRUCT_INTEGER:
219             case MPI_COMBINER_STRUCT:
220                 /* In this case we can't simply multiply the count of the next
221                  * type by the count of the current type, because we need to
222                  * cycle through the types just as the struct would.  thus the
223                  * nested loops.
224                  *
225                  * We need to keep going until we get less elements than expected
226                  * or we run out of bytes.
227                  */
228 
229 
230                 last_nr_elements = 1;   /* seed value */
231                 for (j = 0; (count < 0 || j < count) && *bytes_p > 0 && last_nr_elements > 0; j++) {
232                     /* recurse on each type; bytes are reduced in calls */
233                     for (i = 0; i < (*ints); i++) {
234                         /* skip zero-count elements of the struct */
235                         if (ints[i + 1] == 0)
236                             continue;
237 
238                         last_nr_elements = MPIR_Type_get_elements(bytes_p, ints[i + 1], types[i]);
239                         nr_elements += last_nr_elements;
240 
241                         MPIR_Assert(last_nr_elements >= 0);
242 
243                         if (last_nr_elements < ints[i + 1])
244                             break;
245                     }
246                 }
247                 return nr_elements;
248                 break;
249             case MPI_COMBINER_DARRAY:
250             case MPI_COMBINER_F90_REAL:
251             case MPI_COMBINER_F90_COMPLEX:
252             case MPI_COMBINER_F90_INTEGER:
253             default:
254                 /* --BEGIN ERROR HANDLING-- */
255                 MPIR_Assert(0);
256                 return -1;
257                 break;
258                 /* --END ERROR HANDLING-- */
259         }
260     }
261 }
262 
263 /* MPIR_Get_elements_x_impl
264  *
265  * Arguments:
266  * - byte_count - input/output byte count
267  * - datatype - input datatype
268  * - elements - Number of basic elements this byte_count would contain
269  *
 * Stores in *elements the number of basic elements that fit in the given
 * number of bytes (or MPI_UNDEFINED when no such count can be given) and
 * reduces *byte_count by the amount taken up by those elements.  Returns an
 * MPI error code (MPI_SUCCESS).
273  */
int MPIR_Get_elements_x_impl(MPI_Count * byte_count, MPI_Datatype datatype, MPI_Count * elements)
{
    MPIR_Datatype *dt_ptr = NULL;

    /* Case 1: predefined type.  Behaves just like MPI_Get_count: a byte
     * count that is not a whole multiple of the type size has no answer. */
    if (HANDLE_IS_BUILTIN(datatype)) {
        MPI_Count type_sz;

        MPIR_Datatype_get_size_macro(datatype, type_sz);
        if ((*byte_count % type_sz) == 0)
            *elements = MPIR_Type_get_basic_type_elements(byte_count, -1, datatype);
        else
            *elements = MPI_UNDEFINED;

        MPIR_Assert(*byte_count >= 0);
        return MPI_SUCCESS;
    }

    /* everything below deals with derived types, which have an object */
    MPIR_Datatype_get_ptr(datatype, dt_ptr);

    if (dt_ptr->builtin_element_size != -1 && dt_ptr->size > 0) {
        /* Case 2: derived type made of a single basic type.  The number of
         * types that might be in the bytes is not limited (count of -1). */
        MPI_Datatype elem_type = MPI_DATATYPE_NULL;

        MPIR_Datatype_get_basic_type(dt_ptr->basic_type, elem_type);
        *elements = MPIR_Type_get_basic_type_elements(byte_count, -1, elem_type);

        MPIR_Assert(*byte_count >= 0);
    } else if (dt_ptr->size == 0) {
        /* Case 3: derived type with a zero size */
        if (*byte_count > 0) {
            /* --BEGIN ERROR HANDLING-- */
            /* a zero-size datatype together with a positive byte count
             * should never happen */
            *elements = MPI_UNDEFINED;
            /* --END ERROR HANDLING-- */
        } else {
            /* This is ambiguous.  However, discussions on MPI Forum
             * reached a consensus that this is the correct return
             * value
             */
            *elements = 0;
        }
    } else {
        /* Case 4: derived type with several element types (or a weird
         * size); walk the type's contents */
        MPIR_Assert(dt_ptr->builtin_element_size == -1);

        *elements = MPIR_Type_get_elements(byte_count, -1, datatype);
    }

    return MPI_SUCCESS;
}
331 
332 #endif /* MPICH_MPI_FROM_PMPI */
333 
/* N.B. "count" is the name mandated by the MPI-3 standard, but it should
 * probably be called "elements" instead and is handled that way in the _impl
 * routine [goodell@ 2012-11-05] */
337 /*@
338 MPI_Get_elements_x - Returns the number of basic elements
339                      in a datatype
340 
341 Input Parameters:
342 + status - return status of receive operation (Status)
343 - datatype - datatype used by receive operation (handle)
344 
345 Output Parameters:
346 . count - number of received basic elements (integer)
347 
348 .N ThreadSafe
349 
350 .N Fortran
351 
352 .N Errors
353 @*/
int MPI_Get_elements_x(const MPI_Status * status, MPI_Datatype datatype, MPI_Count * count)
{
    /* Public entry point: validates the datatype handle (when error checking
     * is compiled in), reads the byte count recorded in the status, and lets
     * MPIR_Get_elements_x_impl() turn it into a number of basic elements
     * stored in *count.  Returns an MPI error code. */
    int mpi_errno = MPI_SUCCESS;
    MPI_Count byte_count;
    MPIR_FUNC_TERSE_STATE_DECL(MPID_STATE_MPI_GET_ELEMENTS_X);

    MPID_THREAD_CS_ENTER(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
    MPIR_FUNC_TERSE_ENTER(MPID_STATE_MPI_GET_ELEMENTS_X);

    /* Validate parameters, especially handles needing to be converted */
#ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            MPIR_ERRTEST_DATATYPE(datatype, "datatype", mpi_errno);

            /* TODO more checks may be appropriate */
            if (mpi_errno != MPI_SUCCESS)
                goto fn_fail;
        }
        MPID_END_ERROR_CHECKS;
    }
#endif /* HAVE_ERROR_CHECKING */

    /* Convert MPI object handles to object pointers */

    /* Validate parameters and objects (post conversion) */
#ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
            if (!HANDLE_IS_BUILTIN(datatype)) {
                MPIR_Datatype *datatype_ptr = NULL;
                MPIR_Datatype_get_ptr(datatype, datatype_ptr);
                MPIR_Datatype_valid_ptr(datatype_ptr, mpi_errno);
                /* Stop here if the pointer was rejected: the committed
                 * check must not run on a pointer already flagged as
                 * invalid, and must not replace the first error code. */
                if (mpi_errno != MPI_SUCCESS)
                    goto fn_fail;
                MPIR_Datatype_committed_ptr(datatype_ptr, mpi_errno);
            }

            /* TODO more checks may be appropriate (counts, in_place, buffer aliasing, etc) */
            if (mpi_errno != MPI_SUCCESS)
                goto fn_fail;
        }
        MPID_END_ERROR_CHECKS;
    }
#endif /* HAVE_ERROR_CHECKING */

    /* ... body of routine ...  */

    /* the status carries a byte count; the impl converts it to elements */
    byte_count = MPIR_STATUS_GET_COUNT(*status);
    mpi_errno = MPIR_Get_elements_x_impl(&byte_count, datatype, count);
    MPIR_ERR_CHECK(mpi_errno);

    /* ... end of body of routine ... */

  fn_exit:
    MPIR_FUNC_TERSE_EXIT(MPID_STATE_MPI_GET_ELEMENTS_X);
    MPID_THREAD_CS_EXIT(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
    return mpi_errno;

  fn_fail:
    /* --BEGIN ERROR HANDLING-- */
#ifdef HAVE_ERROR_CHECKING
    {
        mpi_errno =
            MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, __func__, __LINE__, MPI_ERR_OTHER,
                                 "**mpi_get_elements_x", "**mpi_get_elements_x %p %D %p", status,
                                 datatype, count);
    }
#endif
    mpi_errno = MPIR_Err_return_comm(NULL, __func__, mpi_errno);
    goto fn_exit;
    /* --END ERROR HANDLING-- */
}
427