/*
 * Copyright (C) by Argonne National Laboratory
 *     See COPYRIGHT in top-level directory
 */

#include <mpir_pmi.h>
#include <mpiimpl.h>
#include "mpir_nodemap.h"

static int build_nodemap(int *nodemap, int sz, int *p_max_node_id);
static int build_locality(void);

static int pmi_version = 1;
static int pmi_subversion = 1;

static int pmi_max_key_size;
static int pmi_max_val_size;

#ifdef USE_PMI1_API
static int pmi_max_kvs_name_length;
static char *pmi_kvs_name;
#elif defined USE_PMI2_API
static char *pmi_jobid;
#elif defined USE_PMIX_API
static pmix_proc_t pmix_proc;
static pmix_proc_t pmix_wcproc;
#endif
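
/* Overview: exactly one PMI flavor is selected at build time via USE_PMI1_API,
 * USE_PMI2_API, or USE_PMIX_API, and the MPIR_pmi_* wrappers below hide the
 * differences among the three interfaces.  A typical caller-side sequence
 * (illustrative only; the real callers live in the MPIR and device init paths)
 * is:
 *
 *   MPIR_pmi_init();
 *   ... exchange addresses via MPIR_pmi_kvs_put/get, MPIR_pmi_bcast,
 *       or MPIR_pmi_allgather ...
 *   MPIR_pmi_finalize();
 */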
int MPIR_pmi_init(void)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;

    /* See if the user wants to override our default values */
    MPL_env2int("PMI_VERSION", &pmi_version);
    MPL_env2int("PMI_SUBVERSION", &pmi_subversion);

    int has_parent, rank, size, appnum;
#ifdef USE_PMI1_API
    pmi_errno = PMI_Init(&has_parent);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_init", "**pmi_init %d", pmi_errno);
    pmi_errno = PMI_Get_rank(&rank);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_get_rank", "**pmi_get_rank %d", pmi_errno);
    pmi_errno = PMI_Get_size(&size);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_get_size", "**pmi_get_size %d", pmi_errno);
    pmi_errno = PMI_Get_appnum(&appnum);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_get_appnum", "**pmi_get_appnum %d", pmi_errno);

    pmi_errno = PMI_KVS_Get_name_length_max(&pmi_max_kvs_name_length);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_get_name_length_max",
                         "**pmi_kvs_get_name_length_max %d", pmi_errno);
    pmi_kvs_name = (char *) MPL_malloc(pmi_max_kvs_name_length, MPL_MEM_OTHER);
    pmi_errno = PMI_KVS_Get_my_name(pmi_kvs_name, pmi_max_kvs_name_length);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_get_my_name", "**pmi_kvs_get_my_name %d", pmi_errno);

    pmi_errno = PMI_KVS_Get_key_length_max(&pmi_max_key_size);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_get_key_length_max",
                         "**pmi_kvs_get_key_length_max %d", pmi_errno);
    pmi_errno = PMI_KVS_Get_value_length_max(&pmi_max_val_size);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_get_value_length_max",
                         "**pmi_kvs_get_value_length_max %d", pmi_errno);

#elif defined USE_PMI2_API
    pmi_max_key_size = PMI2_MAX_KEYLEN;
    pmi_max_val_size = PMI2_MAX_VALLEN;

    pmi_errno = PMI2_Init(&has_parent, &size, &rank, &appnum);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_init", "**pmi_init %d", pmi_errno);

    pmi_jobid = (char *) MPL_malloc(PMI2_MAX_VALLEN, MPL_MEM_OTHER);
    pmi_errno = PMI2_Job_GetId(pmi_jobid, PMI2_MAX_VALLEN);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_job_getid", "**pmi_job_getid %d", pmi_errno);

#elif defined USE_PMIX_API
    pmi_max_key_size = PMIX_MAX_KEYLEN;
    pmi_max_val_size = 1024;    /* this is what PMI2_MAX_VALLEN is currently set to */

    pmix_value_t *pvalue = NULL;

    pmi_errno = PMIx_Init(&pmix_proc, NULL, 0);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_init", "**pmix_init %d", pmi_errno);

    rank = pmix_proc.rank;
    PMIX_PROC_CONSTRUCT(&pmix_wcproc);
    MPL_strncpy(pmix_wcproc.nspace, pmix_proc.nspace, PMIX_MAX_NSLEN);
    pmix_wcproc.rank = PMIX_RANK_WILDCARD;

    pmi_errno = PMIx_Get(&pmix_wcproc, PMIX_JOB_SIZE, NULL, 0, &pvalue);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_get", "**pmix_get %d", pmi_errno);
    size = pvalue->data.uint32;
    PMIX_VALUE_RELEASE(pvalue);

    /* appnum and has_parent are not set for now */
    appnum = 0;
    has_parent = 0;

#endif
    MPIR_Process.has_parent = has_parent;
    MPIR_Process.rank = rank;
    MPIR_Process.size = size;
    MPIR_Process.appnum = appnum;

    static int g_max_node_id = -1;
    MPIR_Process.node_map = (int *) MPL_malloc(size * sizeof(int), MPL_MEM_ADDRESS);

    mpi_errno = build_nodemap(MPIR_Process.node_map, size, &g_max_node_id);
    MPIR_ERR_CHECK(mpi_errno);
    MPIR_Process.num_nodes = g_max_node_id + 1;

    /* allocate and populate MPIR_Process.node_local_map and MPIR_Process.node_root_map */
    mpi_errno = build_locality();

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

void MPIR_pmi_finalize(void)
{
#ifdef USE_PMI1_API
    PMI_Finalize();
    MPL_free(pmi_kvs_name);
#elif defined(USE_PMI2_API)
    PMI2_Finalize();
    MPL_free(pmi_jobid);
#elif defined(USE_PMIX_API)
    PMIx_Finalize(NULL, 0);
    /* pmix_proc does not need to be freed */
#endif

    MPL_free(MPIR_Process.node_map);
    MPL_free(MPIR_Process.node_root_map);
    MPL_free(MPIR_Process.node_local_map);
}

void MPIR_pmi_abort(int exit_code, const char *error_msg)
{
#ifdef USE_PMI1_API
    PMI_Abort(exit_code, error_msg);
#elif defined(USE_PMI2_API)
    PMI2_Abort(TRUE, error_msg);
#elif defined(USE_PMIX_API)
    PMIx_Abort(exit_code, error_msg, NULL, 0);
#endif
}

/* getters for internal constants */
int MPIR_pmi_max_key_size(void)
{
    return pmi_max_key_size;
}

int MPIR_pmi_max_val_size(void)
{
    return pmi_max_val_size;
}

const char *MPIR_pmi_job_id(void)
{
#ifdef USE_PMI1_API
    return (const char *) pmi_kvs_name;
#elif defined USE_PMI2_API
    return (const char *) pmi_jobid;
#elif defined USE_PMIX_API
    return (const char *) pmix_proc.nspace;
#endif
}

/* wrapper functions */
int MPIR_pmi_kvs_put(const char *key, const char *val)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;

#ifdef USE_PMI1_API
    pmi_errno = PMI_KVS_Put(pmi_kvs_name, key, val);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_put", "**pmi_kvs_put %d", pmi_errno);
    pmi_errno = PMI_KVS_Commit(pmi_kvs_name);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_commit", "**pmi_kvs_commit %d", pmi_errno);
#elif defined(USE_PMI2_API)
    pmi_errno = PMI2_KVS_Put(key, val);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvsput", "**pmi_kvsput %d", pmi_errno);
#elif defined(USE_PMIX_API)
    pmix_value_t value;
    value.type = PMIX_STRING;
    value.data.string = (char *) val;
    pmi_errno = PMIx_Put(PMIX_GLOBAL, key, &value);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_put", "**pmix_put %d", pmi_errno);
    pmi_errno = PMIx_Commit();
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_commit", "**pmix_commit %d", pmi_errno);
#endif

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

/* NOTE: src is a hint, use src = -1 if not known */
int MPIR_pmi_kvs_get(int src, const char *key, char *val, int val_size)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;

#ifdef USE_PMI1_API
    /* src is not used in PMI1 */
    pmi_errno = PMI_KVS_Get(pmi_kvs_name, key, val, val_size);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_get", "**pmi_kvs_get %d", pmi_errno);
#elif defined(USE_PMI2_API)
    if (src < 0)
        src = PMI2_ID_NULL;
    int out_len;
    pmi_errno = PMI2_KVS_Get(pmi_jobid, src, key, val, val_size, &out_len);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvsget", "**pmi_kvsget %d", pmi_errno);
#elif defined(USE_PMIX_API)
    pmix_value_t *pvalue;
    if (src < 0) {
        pmi_errno = PMIx_Get(NULL, key, NULL, 0, &pvalue);
    } else {
        pmix_proc_t proc;
        PMIX_PROC_CONSTRUCT(&proc);
        proc.rank = src;

        pmi_errno = PMIx_Get(&proc, key, NULL, 0, &pvalue);
    }
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_get", "**pmix_get %d", pmi_errno);
    strncpy(val, pvalue->data.string, val_size);
    PMIX_VALUE_RELEASE(pvalue);
#endif

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
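
/* Illustrative put/get exchange (a sketch only; the key and value below are
 * made up for the example): the writer publishes a value, everyone
 * synchronizes, and the readers fetch it afterwards.
 *
 *   char val[64];
 *   if (MPIR_Process.rank == 0) {
 *       MPIR_pmi_kvs_put("example-key", "example-value");
 *   }
 *   MPIR_pmi_barrier();
 *   if (MPIR_Process.rank != 0) {
 *       MPIR_pmi_kvs_get(0, "example-key", val, sizeof(val));
 *   }
 */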

/* ---- utility functions ---- */

int MPIR_pmi_barrier(void)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;

#ifdef USE_PMI1_API
    pmi_errno = PMI_Barrier();
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_barrier", "**pmi_barrier %d", pmi_errno);
#elif defined(USE_PMI2_API)
    pmi_errno = PMI2_KVS_Fence();
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvsfence", "**pmi_kvsfence %d", pmi_errno);
#elif defined(USE_PMIX_API)
    pmix_info_t *info;
    PMIX_INFO_CREATE(info, 1);
    int flag = 1;
    PMIX_INFO_LOAD(info, PMIX_COLLECT_DATA, &flag, PMIX_BOOL);

    /* use the global wildcard proc set */
    pmi_errno = PMIx_Fence(&pmix_wcproc, 1, info, 1);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_fence", "**pmix_fence %d", pmi_errno);
    PMIX_INFO_FREE(info, 1);
#endif

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

int MPIR_pmi_barrier_local(void)
{
#if defined(USE_PMIX_API)
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;
    int local_size = MPIR_Process.local_size;
    pmix_proc_t *procs = MPL_malloc(local_size * sizeof(pmix_proc_t), MPL_MEM_OTHER);
    for (int i = 0; i < local_size; i++) {
        PMIX_PROC_CONSTRUCT(&procs[i]);
        strncpy(procs[i].nspace, pmix_proc.nspace, PMIX_MAX_NSLEN);
        procs[i].rank = MPIR_Process.node_local_map[i];
    }

    pmix_info_t *info;
    int flag = 1;
    PMIX_INFO_CREATE(info, 1);
    PMIX_INFO_LOAD(info, PMIX_COLLECT_DATA, &flag, PMIX_BOOL);

    pmi_errno = PMIx_Fence(procs, local_size, info, 1);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER, "**pmix_fence",
                         "**pmix_fence %d", pmi_errno);

    PMIX_INFO_FREE(info, 1);
    MPL_free(procs);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
#else
    /* If a local barrier is not supported (PMI1 and PMI2), simply fall back */
    return MPIR_pmi_barrier();
#endif
}

/* declare static functions used in bcast/allgather */
static void encode(int size, const char *src, char *dest);
static void decode(int size, const char *src, char *dest);

/* is_local is a hint that we optimize for node-local access when we can */
static int optimized_put(const char *key, const char *val, int is_local)
{
    int mpi_errno = MPI_SUCCESS;
#if defined(USE_PMI1_API)
    mpi_errno = MPIR_pmi_kvs_put(key, val);
#elif defined(USE_PMI2_API)
    if (!is_local) {
        mpi_errno = MPIR_pmi_kvs_put(key, val);
    } else {
        int pmi_errno = PMI2_Info_PutNodeAttr(key, val);
        MPIR_ERR_CHKANDJUMP(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                            "**pmi_putnodeattr");
    }
#elif defined(USE_PMIX_API)
    int pmi_errno;
    pmix_value_t value;
    value.type = PMIX_STRING;
    value.data.string = (char *) val;
    pmi_errno = PMIx_Put(is_local ? PMIX_LOCAL : PMIX_GLOBAL, key, &value);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_put", "**pmix_put %d", pmi_errno);
    pmi_errno = PMIx_Commit();
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_commit", "**pmix_commit %d", pmi_errno);
#endif

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

static int optimized_get(int src, const char *key, char *val, int valsize, int is_local)
{
#if defined(USE_PMI1_API)
    return MPIR_pmi_kvs_get(src, key, val, valsize);
#elif defined(USE_PMI2_API)
    if (is_local) {
        int mpi_errno = MPI_SUCCESS;
        int found;
        int pmi_errno = PMI2_Info_GetNodeAttr(key, val, valsize, &found, TRUE);
        if (pmi_errno != PMI2_SUCCESS) {
            MPIR_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**pmi_getnodeattr");
        } else if (!found) {
            MPIR_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**pmi_getnodeattr");
        }
        return mpi_errno;
    } else {
        return MPIR_pmi_kvs_get(src, key, val, valsize);
    }
#else
    return MPIR_pmi_kvs_get(src, key, val, valsize);
#endif
}

/* higher-level binary put/get:
 * 1. binary encoding/decoding
 * 2. chopping long values into multiple segments
 * 3. using optimized_put/get for node-level access
 */
static int put_ex(const char *key, const void *buf, int bufsize, int is_local)
{
    int mpi_errno = MPI_SUCCESS;
#if defined(USE_PMI1_API) || defined(USE_PMI2_API)
    char *val = MPL_malloc(pmi_max_val_size, MPL_MEM_OTHER);
    /* reserve some space for '\0' and possibly newlines
     * (depends on the PMI implementation, and may not be sufficient) */
    int segsize = (pmi_max_val_size - 2) / 2;
    if (bufsize < segsize) {
        encode(bufsize, buf, val);
        mpi_errno = optimized_put(key, val, is_local);
        MPIR_ERR_CHECK(mpi_errno);
    } else {
        int num_segs = bufsize / segsize;
        if (bufsize % segsize > 0) {
            num_segs++;
        }
        MPL_snprintf(val, pmi_max_val_size, "segments=%d", num_segs);
        mpi_errno = MPIR_pmi_kvs_put(key, val);
        MPIR_ERR_CHECK(mpi_errno);
        for (int i = 0; i < num_segs; i++) {
            char seg_key[50];
            sprintf(seg_key, "%s-seg-%d/%d", key, i + 1, num_segs);
            int n = segsize;
            if (i == num_segs - 1) {
                n = bufsize - segsize * (num_segs - 1);
            }
            encode(n, (char *) buf + i * segsize, val);
            mpi_errno = optimized_put(seg_key, val, is_local);
            MPIR_ERR_CHECK(mpi_errno);
        }
    }
#elif defined(USE_PMIX_API)
    int n = bufsize * 2 + 1;
    char *val = MPL_malloc(n, MPL_MEM_OTHER);
    encode(bufsize, buf, val);
    mpi_errno = optimized_put(key, val, is_local);
    MPIR_ERR_CHECK(mpi_errno);
#endif
  fn_exit:
    MPL_free(val);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

static int get_ex(int src, const char *key, void *buf, int *p_size, int is_local)
{
    int mpi_errno = MPI_SUCCESS;
    char *val = MPL_malloc(pmi_max_val_size, MPL_MEM_OTHER);
    int segsize = (pmi_max_val_size - 1) / 2;

    MPIR_Assert(p_size);
    MPIR_Assert(*p_size > 0);
    int bufsize = *p_size;
    int got_size;

    mpi_errno = optimized_get(src, key, val, pmi_max_val_size, is_local);
    MPIR_ERR_CHECK(mpi_errno);
    if (strncmp(val, "segments=", 9) == 0) {
        int num_segs = atoi(val + 9);
        got_size = 0;
        for (int i = 0; i < num_segs; i++) {
            char seg_key[50];
            sprintf(seg_key, "%s-seg-%d/%d", key, i + 1, num_segs);
            mpi_errno = optimized_get(src, seg_key, val, pmi_max_val_size, is_local);
            MPIR_ERR_CHECK(mpi_errno);
            int n = strlen(val) / 2;    /* 2-to-1 decode */
            if (i < num_segs - 1) {
                MPIR_Assert(n == segsize);
            } else {
                MPIR_Assert(n <= segsize);
            }
            decode(n, val, (char *) buf + i * segsize);
            got_size += n;
        }
    } else {
        int n = strlen(val) / 2;        /* 2-to-1 decode */
        decode(n, val, (char *) buf);
        got_size = n;
    }
    MPIR_Assert(got_size <= bufsize);
    if (got_size < bufsize) {
        ((char *) buf)[got_size] = '\0';
    }

    *p_size = got_size;

  fn_exit:
    MPL_free(val);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
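
/* Example of the segmentation scheme used by put_ex/get_ex (numbers are
 * illustrative): with pmi_max_val_size = 1024, put_ex uses a segment size of
 * (1024 - 2) / 2 = 511 bytes of binary data per PMI value.  A 1200-byte
 * buffer stored under key "foo" then becomes:
 *
 *   "foo"          -> "segments=3"
 *   "foo-seg-1/3"  -> hex encoding of bytes [0, 511)
 *   "foo-seg-2/3"  -> hex encoding of bytes [511, 1022)
 *   "foo-seg-3/3"  -> hex encoding of bytes [1022, 1200)
 *
 * get_ex recognizes the "segments=" prefix and reassembles the buffer from
 * the individual segments.
 */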

static int optional_bcast_barrier(MPIR_PMI_DOMAIN domain)
{
#if defined(USE_PMI1_API)
    /* skip the barrier when the bcast is skipped altogether */
    if (domain == MPIR_PMI_DOMAIN_ALL && MPIR_Process.size == 1) {
        return MPI_SUCCESS;
    } else if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS && MPIR_Process.num_nodes == 1) {
        return MPI_SUCCESS;
    } else if (domain == MPIR_PMI_DOMAIN_LOCAL && MPIR_Process.size == MPIR_Process.num_nodes) {
        return MPI_SUCCESS;
    }
#elif defined(USE_PMI2_API)
    if (domain == MPIR_PMI_DOMAIN_ALL && MPIR_Process.size == 1) {
        return MPI_SUCCESS;
    } else if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS && MPIR_Process.num_nodes == 1) {
        return MPI_SUCCESS;
    } else if (domain == MPIR_PMI_DOMAIN_LOCAL) {
        /* PMI2 local uses Put/GetNodeAttr, no need for a barrier */
        return MPI_SUCCESS;
    }
#elif defined(USE_PMIX_API)
    /* PMIx will block/wait, so a barrier is unnecessary */
    return MPI_SUCCESS;
#endif
    return MPIR_pmi_barrier();
}

int MPIR_pmi_bcast(void *buf, int bufsize, MPIR_PMI_DOMAIN domain)
{
    int mpi_errno = MPI_SUCCESS;

    int rank = MPIR_Process.rank;
    int local_node_id = MPIR_Process.node_map[rank];
    int node_root = MPIR_Process.node_root_map[local_node_id];
    int is_node_root = (node_root == rank);

    int in_domain, is_root, is_local, bcast_size;
    if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS && !is_node_root) {
        in_domain = 0;
    } else {
        in_domain = 1;
    }
    if (rank == 0 || (domain == MPIR_PMI_DOMAIN_LOCAL && is_node_root)) {
        is_root = 1;
    } else {
        is_root = 0;
    }
    is_local = (domain == MPIR_PMI_DOMAIN_LOCAL);

    bcast_size = MPIR_Process.size;
    if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS) {
        bcast_size = MPIR_Process.num_nodes;
    } else if (domain == MPIR_PMI_DOMAIN_LOCAL) {
        bcast_size = MPIR_Process.local_size;
    }
    if (bcast_size == 1) {
        in_domain = 0;
    }

    char key[50];
    int root;
    static int bcast_seq = 0;

    if (!in_domain) {
        /* PMI_Barrier may require all processes to participate */
        mpi_errno = optional_bcast_barrier(domain);
        MPIR_ERR_CHECK(mpi_errno);
    } else {
        MPIR_Assert(buf);
        MPIR_Assert(bufsize > 0);

        bcast_seq++;

        root = 0;
        if (domain == MPIR_PMI_DOMAIN_LOCAL) {
            root = node_root;
        }
        /* add root to the key since we may have multiple roots
         * on a single node due to odd/even cliques */
        sprintf(key, "-bcast-%d-%d", bcast_seq, root);

        if (is_root) {
            mpi_errno = put_ex(key, buf, bufsize, is_local);
            MPIR_ERR_CHECK(mpi_errno);
        }

        mpi_errno = optional_bcast_barrier(domain);
        MPIR_ERR_CHECK(mpi_errno);

        if (!is_root) {
            int got_size = bufsize;
            mpi_errno = get_ex(root, key, buf, &got_size, is_local);
            MPIR_ERR_CHECK(mpi_errno);
        }
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
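
/* Illustrative use of MPIR_pmi_bcast (a sketch; the buffer contents and the
 * helper that fills it are hypothetical): every process in the domain calls
 * with the same bufsize, the root supplies the data, and the others receive
 * it through the PMI key-value space.
 *
 *   char card[128];
 *   if (MPIR_Process.rank == 0) {
 *       fill_business_card(card, sizeof(card));    // hypothetical helper
 *   }
 *   MPIR_pmi_bcast(card, sizeof(card), MPIR_PMI_DOMAIN_ALL);
 */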

int MPIR_pmi_allgather(const void *sendbuf, int sendsize, void *recvbuf, int recvsize,
                       MPIR_PMI_DOMAIN domain)
{
    int mpi_errno = MPI_SUCCESS;

    MPIR_Assert(domain != MPIR_PMI_DOMAIN_LOCAL);

    int local_node_id = MPIR_Process.node_map[MPIR_Process.rank];
    int is_node_root = (MPIR_Process.node_root_map[local_node_id] == MPIR_Process.rank);
    int in_domain = 1;
    if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS && !is_node_root) {
        in_domain = 0;
    }

    static int allgather_seq = 0;
    allgather_seq++;

    char key[50];
    sprintf(key, "-allgather-%d-%d", allgather_seq, MPIR_Process.rank);

    if (in_domain) {
        mpi_errno = put_ex(key, sendbuf, sendsize, 0);
        MPIR_ERR_CHECK(mpi_errno);
    }
#ifndef USE_PMIX_API
    /* PMIx will wait, so a barrier is unnecessary */
    mpi_errno = MPIR_pmi_barrier();
    MPIR_ERR_CHECK(mpi_errno);
#endif

    if (in_domain) {
        int domain_size = MPIR_Process.size;
        if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS) {
            domain_size = MPIR_Process.num_nodes;
        }
        for (int i = 0; i < domain_size; i++) {
            int rank = i;
            if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS) {
                rank = MPIR_Process.node_root_map[i];
            }
            sprintf(key, "-allgather-%d-%d", allgather_seq, rank);
            int got_size = recvsize;
            mpi_errno = get_ex(rank, key, (unsigned char *) recvbuf + i * recvsize, &got_size, 0);
            MPIR_ERR_CHECK(mpi_errno);
        }
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
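
/* Key naming in MPIR_pmi_allgather: each participant publishes its
 * contribution under "-allgather-<seq>-<rank>", where <seq> is the static
 * allgather_seq counter, so repeated calls do not collide.  For example
 * (illustrative), the second allgather on rank 3 publishes under
 * "-allgather-2-3", and every reader rebuilds the same key from the peer's
 * rank when gathering.
 */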

/* This version assumes shm_buf is shared across local procs. Each process
 * participates in the gather part; the work is distributed over the local procs.
 *
 * NOTE: the behavior differs from MPIR_pmi_allgather when domain is
 * MPIR_PMI_DOMAIN_NODE_ROOTS. With MPIR_pmi_allgather, only the node roots participate.
 */
int MPIR_pmi_allgather_shm(const void *sendbuf, int sendsize, void *shm_buf, int recvsize,
                           MPIR_PMI_DOMAIN domain)
{
    int mpi_errno = MPI_SUCCESS;

    MPIR_Assert(domain != MPIR_PMI_DOMAIN_LOCAL);

    int rank = MPIR_Process.rank;
    int size = MPIR_Process.size;
    int local_size = MPIR_Process.local_size;
    int local_rank = MPIR_Process.local_rank;
    int local_node_id = MPIR_Process.node_map[rank];
    int node_root = MPIR_Process.node_root_map[local_node_id];
    int is_node_root = (node_root == MPIR_Process.rank);

    static int allgather_shm_seq = 0;
    allgather_shm_seq++;

    char key[50];
    sprintf(key, "-allgather-shm-%d-%d", allgather_shm_seq, rank);

    /* in the node-roots-only domain, non-root processes skip the put */
    if (domain != MPIR_PMI_DOMAIN_NODE_ROOTS || is_node_root) {
        mpi_errno = put_ex(key, (unsigned char *) sendbuf, sendsize, 0);
        MPIR_ERR_CHECK(mpi_errno);
    }

    mpi_errno = MPIR_pmi_barrier();
    MPIR_ERR_CHECK(mpi_errno);

    /* Each rank needs to get values from "size" ranks; divide the task evenly over the local ranks */
    if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS) {
        size = MPIR_Process.num_nodes;
    }
    int per_local_rank = size / local_size;
    if (per_local_rank * local_size < size) {
        per_local_rank++;
    }
    int start = local_rank * per_local_rank;
    int end = start + per_local_rank;
    if (end > size) {
        end = size;
    }
    for (int i = start; i < end; i++) {
        int src = i;
        if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS) {
            src = MPIR_Process.node_root_map[i];
        }
        sprintf(key, "-allgather-shm-%d-%d", allgather_shm_seq, src);
        int got_size = recvsize;
        mpi_errno = get_ex(src, key, (unsigned char *) shm_buf + i * recvsize, &got_size, 0);
        MPIR_ERR_CHECK(mpi_errno);
        MPIR_Assert(got_size <= recvsize);
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
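
/* Work division in MPIR_pmi_allgather_shm (illustrative numbers): with
 * size = 10 and local_size = 4, per_local_rank = 3, so local ranks 0..3
 * fetch index ranges [0,3), [3,6), [6,9), and [9,10) respectively, each
 * writing into shm_buf at offset i * recvsize.  The results become visible
 * to all local processes through the shared buffer rather than through
 * additional PMI reads.
 */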

int MPIR_pmi_get_universe_size(int *universe_size)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;

#ifdef USE_PMI1_API
    pmi_errno = PMI_Get_universe_size(universe_size);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_get_universe_size", "**pmi_get_universe_size %d", pmi_errno);
#elif defined(USE_PMI2_API)
    char val[PMI2_MAX_VALLEN];
    int found = 0;
    char *endptr;

    pmi_errno = PMI2_Info_GetJobAttr("universeSize", val, sizeof(val), &found);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_getjobattr", "**pmi_getjobattr %d", pmi_errno);
    if (!found) {
        *universe_size = MPIR_UNIVERSE_SIZE_NOT_AVAILABLE;
    } else {
        *universe_size = strtol(val, &endptr, 0);
        MPIR_ERR_CHKINTERNAL(endptr - val != strlen(val), mpi_errno, "can't parse universe size");
    }
#elif defined(USE_PMIX_API)
    pmix_value_t *pvalue = NULL;

    pmi_errno = PMIx_Get(&pmix_wcproc, PMIX_UNIV_SIZE, NULL, 0, &pvalue);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_get", "**pmix_get %d", pmi_errno);
    *universe_size = pvalue->data.uint32;
    PMIX_VALUE_RELEASE(pvalue);
#endif
  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

char *MPIR_pmi_get_failed_procs(void)
{
    int pmi_errno;
    char *failed_procs_string = NULL;

    failed_procs_string = MPL_malloc(pmi_max_val_size, MPL_MEM_OTHER);
    MPIR_Assert(failed_procs_string);
#ifdef USE_PMI1_API
    pmi_errno = PMI_KVS_Get(pmi_kvs_name, "PMI_dead_processes",
                            failed_procs_string, pmi_max_val_size);
    if (pmi_errno != PMI_SUCCESS)
        goto fn_fail;
#elif defined(USE_PMI2_API)
    int out_len;
    pmi_errno = PMI2_KVS_Get(pmi_jobid, PMI2_ID_NULL, "PMI_dead_processes",
                             failed_procs_string, pmi_max_val_size, &out_len);
    if (pmi_errno != PMI2_SUCCESS)
        goto fn_fail;
#elif defined(USE_PMIX_API)
    goto fn_fail;
#endif

  fn_exit:
    return failed_procs_string;
  fn_fail:
    /* FIXME: appropriate error messages here? */
    MPL_free(failed_procs_string);
    failed_procs_string = NULL;
    goto fn_exit;
}

/* static functions only for MPIR_pmi_spawn_multiple */
#if defined(USE_PMI1_API) || defined(USE_PMI2_API)
/* PMI_keyval_t is only defined in PMI1 or PMI2 */
static int mpi_to_pmi_keyvals(MPIR_Info * info_ptr, PMI_keyval_t ** kv_ptr, int *nkeys_ptr);
static void free_pmi_keyvals(PMI_keyval_t ** kv, int size, int *counts);
#endif

/* NOTE: MPIR_pmi_spawn_multiple is to be called by a single root spawning process */
int MPIR_pmi_spawn_multiple(int count, char *commands[], char **argvs[],
                            const int maxprocs[], MPIR_Info * info_ptrs[],
                            int num_preput_keyval, struct MPIR_PMI_KEYVAL *preput_keyvals,
                            int *pmi_errcodes)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;

#ifdef USE_PMI1_API
    int *info_keyval_sizes = NULL;
    PMI_keyval_t **info_keyval_vectors = NULL;

    info_keyval_sizes = (int *) MPL_malloc(count * sizeof(int), MPL_MEM_BUFFER);
    MPIR_ERR_CHKANDJUMP(!info_keyval_sizes, mpi_errno, MPI_ERR_OTHER, "**nomem");

    info_keyval_vectors =
        (PMI_keyval_t **) MPL_malloc(count * sizeof(PMI_keyval_t *), MPL_MEM_BUFFER);
    MPIR_ERR_CHKANDJUMP(!info_keyval_vectors, mpi_errno, MPI_ERR_OTHER, "**nomem");

    if (!info_ptrs) {
        for (int i = 0; i < count; i++) {
            info_keyval_vectors[i] = 0;
            info_keyval_sizes[i] = 0;
        }
    } else {
        for (int i = 0; i < count; i++) {
            mpi_errno = mpi_to_pmi_keyvals(info_ptrs[i], &info_keyval_vectors[i],
                                           &info_keyval_sizes[i]);
            MPIR_ERR_CHECK(mpi_errno);
        }
    }

    pmi_errno = PMI_Spawn_multiple(count, (const char **) commands, (const char ***) argvs,
                                   maxprocs,
                                   info_keyval_sizes,
                                   (const PMI_keyval_t **) info_keyval_vectors,
                                   num_preput_keyval, (PMI_keyval_t *) preput_keyvals,
                                   pmi_errcodes);

    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_spawn_multiple", "**pmi_spawn_multiple %d", pmi_errno);
#elif defined(USE_PMI2_API)
    /* not supported yet */
    MPIR_Assert(0);
#elif defined(USE_PMIX_API)
    /* not supported yet */
    MPIR_Assert(0);
#endif

  fn_exit:
#ifdef USE_PMI1_API
    if (info_keyval_vectors) {
        free_pmi_keyvals(info_keyval_vectors, count, info_keyval_sizes);
        MPL_free(info_keyval_vectors);
    }

    MPL_free(info_keyval_sizes);
#endif

    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

/* ---- static functions ---- */

/* The following static function declarations are only for build_nodemap() */
static int get_option_no_local(void);
static int get_option_num_cliques(void);
static int build_nodemap_nolocal(int *nodemap, int sz, int *p_max_node_id);
static int build_nodemap_roundrobin(int num_cliques, int *nodemap, int sz, int *p_max_node_id);

#ifdef USE_PMI1_API
static int build_nodemap_pmi1(int *nodemap, int sz, int *p_max_node_id);
static int build_nodemap_fallback(int *nodemap, int sz, int *p_max_node_id);
#elif defined(USE_PMI2_API)
static int build_nodemap_pmi2(int *nodemap, int sz, int *p_max_node_id);
#elif defined(USE_PMIX_API)
static int build_nodemap_pmix(int *nodemap, int sz, int *p_max_node_id);
#endif

static int build_nodemap(int *nodemap, int sz, int *p_max_node_id)
{
    int mpi_errno = MPI_SUCCESS;

    if (sz == 1 || get_option_no_local()) {
        mpi_errno = build_nodemap_nolocal(nodemap, sz, p_max_node_id);
        goto fn_exit;
    }
#ifdef USE_PMI1_API
    mpi_errno = build_nodemap_pmi1(nodemap, sz, p_max_node_id);
#elif defined(USE_PMI2_API)
    mpi_errno = build_nodemap_pmi2(nodemap, sz, p_max_node_id);
#elif defined(USE_PMIX_API)
    mpi_errno = build_nodemap_pmix(nodemap, sz, p_max_node_id);
#endif
    MPIR_ERR_CHECK(mpi_errno);

    int num_cliques = get_option_num_cliques();
    if (num_cliques > sz) {
        num_cliques = sz;
    }
    if (*p_max_node_id == 0 && num_cliques > 1) {
        mpi_errno = build_nodemap_roundrobin(num_cliques, nodemap, sz, p_max_node_id);
        MPIR_ERR_CHECK(mpi_errno);
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
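
/* Example nodemap layout (illustrative): with 6 processes spread over 2
 * hosts as {h0, h0, h0, h1, h1, h1}, build_nodemap() produces
 * nodemap = {0, 0, 0, 1, 1, 1} and *p_max_node_id = 1, from which
 * MPIR_pmi_init() derives MPIR_Process.num_nodes = 2.
 */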

static int get_option_no_local(void)
{
    /* Used for debugging only.  This disables communication over shared memory */
#ifdef ENABLE_NO_LOCAL
    return 1;
#else
    return MPIR_CVAR_NOLOCAL;
#endif
}

static int get_option_num_cliques(void)
{
    /* Used for debugging on a single machine: split procs into num_cliques nodes.
     * If ODD_EVEN_CLIQUES is enabled, split procs into 2 nodes.
     */
    if (MPIR_CVAR_NUM_CLIQUES > 1) {
        return MPIR_CVAR_NUM_CLIQUES;
    } else {
        return MPIR_CVAR_ODD_EVEN_CLIQUES ? 2 : 1;
    }
}

/* one process per node */
static int build_nodemap_nolocal(int *nodemap, int sz, int *p_max_node_id)
{
    for (int i = 0; i < sz; ++i) {
        nodemap[i] = i;
    }
    *p_max_node_id = sz - 1;
    return MPI_SUCCESS;
}

/* assign processes to num_cliques nodes in a round-robin fashion */
static int build_nodemap_roundrobin(int num_cliques, int *nodemap, int sz, int *p_max_node_id)
{
    for (int i = 0; i < sz; ++i) {
        nodemap[i] = i % num_cliques;
    }
    *p_max_node_id = num_cliques - 1;
    return MPI_SUCCESS;
}
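
/* Example (illustrative): with MPIR_CVAR_NUM_CLIQUES = 2 and 5 processes that
 * all share one physical node, build_nodemap() detects *p_max_node_id == 0 and
 * calls build_nodemap_roundrobin(), which overrides the map with
 * nodemap = {0, 1, 0, 1, 0} and *p_max_node_id = 1, so the processes behave
 * as if they were placed on two separate nodes.
 */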

#ifdef USE_PMI1_API

/* build nodemap based on allgathered hostnames */
/* FIXME: migrate the function */
static int build_nodemap_fallback(int *nodemap, int sz, int *p_max_node_id)
{
    return MPIR_NODEMAP_build_nodemap_fallback(sz, MPIR_Process.rank, nodemap, p_max_node_id);
}

/* build nodemap using PMI1 process_mapping or fall back to hostnames */
static int build_nodemap_pmi1(int *nodemap, int sz, int *p_max_node_id)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;
    int did_map = 0;
    if (pmi_version == 1 && pmi_subversion == 1) {
        char *process_mapping = MPL_malloc(pmi_max_val_size, MPL_MEM_ADDRESS);
        pmi_errno = PMI_KVS_Get(pmi_kvs_name, "PMI_process_mapping",
                                process_mapping, pmi_max_val_size);
        if (pmi_errno == PMI_SUCCESS) {
            mpi_errno = MPIR_NODEMAP_populate_ids_from_mapping(process_mapping, sz, nodemap,
                                                               p_max_node_id, &did_map);
            MPIR_ERR_CHECK(mpi_errno);
            MPIR_ERR_CHKINTERNAL(!did_map, mpi_errno,
                                 "unable to populate node ids from PMI_process_mapping");
        }
        MPL_free(process_mapping);
    }
    if (!did_map) {
        mpi_errno = build_nodemap_fallback(nodemap, sz, p_max_node_id);
    }
  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#elif defined USE_PMI2_API

/* build nodemap using PMI2 process_mapping or error */
static int build_nodemap_pmi2(int *nodemap, int sz, int *p_max_node_id)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;
    char process_mapping[PMI2_MAX_VALLEN];
    int found;

    pmi_errno = PMI2_Info_GetJobAttr("PMI_process_mapping", process_mapping, PMI2_MAX_VALLEN,
                                     &found);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi2_info_getjobattr", "**pmi2_info_getjobattr %d", pmi_errno);
    MPIR_ERR_CHKINTERNAL(!found, mpi_errno, "PMI_process_mapping attribute not found");

    int did_map;
    mpi_errno = MPIR_NODEMAP_populate_ids_from_mapping(process_mapping, sz, nodemap,
                                                       p_max_node_id, &did_map);
    MPIR_ERR_CHECK(mpi_errno);
    MPIR_ERR_CHKINTERNAL(!did_map, mpi_errno,
                         "unable to populate node ids from PMI_process_mapping");
  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#elif defined USE_PMIX_API

/* build nodemap using PMIx_Resolve_nodes */
static int build_nodemap_pmix(int *nodemap, int sz, int *p_max_node_id)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;
    char *nodelist = NULL, *node = NULL;
    pmix_proc_t *procs = NULL;
    size_t nprocs, node_id = 0;

    pmi_errno = PMIx_Resolve_nodes(pmix_proc.nspace, &nodelist);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_resolve_nodes", "**pmix_resolve_nodes %d", pmi_errno);
    MPIR_Assert(nodelist);

    node = strtok(nodelist, ",");
    while (node) {
        pmi_errno = PMIx_Resolve_peers(node, pmix_proc.nspace, &procs, &nprocs);
        MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                             "**pmix_resolve_peers", "**pmix_resolve_peers %d", pmi_errno);
        for (int i = 0; i < nprocs; i++) {
            nodemap[procs[i].rank] = node_id;
        }
        node_id++;
        node = strtok(NULL, ",");
    }
    *p_max_node_id = node_id - 1;
    /* Recent PMIx releases add pmix_free. We should switch to that at some point */
    MPL_external_free(nodelist);
    PMIX_PROC_FREE(procs, nprocs);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#endif

/* allocate and populate MPIR_Process.node_local_map and MPIR_Process.node_root_map */
static int build_locality(void)
{
    int local_rank = -1;
    int local_size = 0;
    int *node_root_map, *node_local_map;

    int rank = MPIR_Process.rank;
    int size = MPIR_Process.size;
    int *node_map = MPIR_Process.node_map;
    int num_nodes = MPIR_Process.num_nodes;
    int local_node_id = node_map[rank];

    node_root_map = MPL_malloc(num_nodes * sizeof(int), MPL_MEM_ADDRESS);
    for (int i = 0; i < num_nodes; i++) {
        node_root_map[i] = -1;
    }

    for (int i = 0; i < size; i++) {
        int node_id = node_map[i];
        if (node_root_map[node_id] < 0) {
            node_root_map[node_id] = i;
        }
        if (node_id == local_node_id) {
            local_size++;
        }
    }

    node_local_map = MPL_malloc(local_size * sizeof(int), MPL_MEM_ADDRESS);
    int j = 0;
    for (int i = 0; i < size; i++) {
        int node_id = node_map[i];
        if (node_id == local_node_id) {
            node_local_map[j] = i;
            if (i == rank) {
                local_rank = j;
            }
            j++;
        }
    }

    MPIR_Process.node_root_map = node_root_map;
    MPIR_Process.node_local_map = node_local_map;
    MPIR_Process.local_size = local_size;
    MPIR_Process.local_rank = local_rank;

    return MPI_SUCCESS;
}
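
/* Example (illustrative): with node_map = {0, 0, 1, 1} and rank = 2,
 * build_locality() yields node_root_map = {0, 2} (the lowest rank on each
 * node), node_local_map = {2, 3} (the ranks sharing rank 2's node),
 * local_size = 2, and local_rank = 0.
 */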

/* similar to functions in mpl/src/str/mpl_argstr.c, but much simpler */
static int hex(unsigned char c)
{
    if (c >= '0' && c <= '9') {
        return c - '0';
    } else if (c >= 'a' && c <= 'f') {
        return 10 + c - 'a';
    } else if (c >= 'A' && c <= 'F') {
        return 10 + c - 'A';
    } else {
        MPIR_Assert(0);
        return -1;
    }
}

static void encode(int size, const char *src, char *dest)
{
    for (int i = 0; i < size; i++) {
        MPL_snprintf(dest, 3, "%02X", (unsigned char) *src);
        src++;
        dest += 2;
    }
}

static void decode(int size, const char *src, char *dest)
{
    for (int i = 0; i < size; i++) {
        *dest = (char) (hex(src[0]) << 4) + hex(src[1]);
        src += 2;
        dest++;
    }
}
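
/* Example (illustrative): encode(3, "abc", dest) writes the string "616263"
 * into dest (two hex characters per input byte, plus the trailing '\0' from
 * MPL_snprintf), and decode(3, "616263", buf) restores the original three
 * bytes.  This round trip is what put_ex/get_ex rely on to move arbitrary
 * binary data through string-valued PMI keys.
 */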

/* static functions used in MPIR_pmi_spawn_multiple */
#if defined(USE_PMI1_API) || defined(USE_PMI2_API)
static int mpi_to_pmi_keyvals(MPIR_Info * info_ptr, PMI_keyval_t ** kv_ptr, int *nkeys_ptr)
{
    char key[MPI_MAX_INFO_KEY];
    PMI_keyval_t *kv = 0;
    int nkeys = 0, vallen, flag, mpi_errno = MPI_SUCCESS;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPI_TO_PMI_KEYVALS);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPI_TO_PMI_KEYVALS);

    if (!info_ptr || info_ptr->handle == MPI_INFO_NULL)
        goto fn_exit;

    MPIR_Info_get_nkeys_impl(info_ptr, &nkeys);

    if (nkeys == 0)
        goto fn_exit;

    kv = (PMI_keyval_t *) MPL_malloc(nkeys * sizeof(PMI_keyval_t), MPL_MEM_BUFFER);

    for (int i = 0; i < nkeys; i++) {
        mpi_errno = MPIR_Info_get_nthkey_impl(info_ptr, i, key);
        MPIR_ERR_CHECK(mpi_errno);
        MPIR_Info_get_valuelen_impl(info_ptr, key, &vallen, &flag);
        kv[i].key = MPL_strdup(key);
        kv[i].val = (char *) MPL_malloc(vallen + 1, MPL_MEM_BUFFER);
        MPIR_Info_get_impl(info_ptr, key, vallen + 1, kv[i].val, &flag);
    }

  fn_exit:
    *kv_ptr = kv;
    *nkeys_ptr = nkeys;
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPI_TO_PMI_KEYVALS);
    return mpi_errno;

  fn_fail:
    goto fn_exit;
}

static void free_pmi_keyvals(PMI_keyval_t ** kv, int size, int *counts)
{
    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_FREE_PMI_KEYVALS);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_FREE_PMI_KEYVALS);

    for (int i = 0; i < size; i++) {
        for (int j = 0; j < counts[i]; j++) {
            MPL_free((char *) kv[i][j].key);
            MPL_free(kv[i][j].val);
        }
        MPL_free(kv[i]);
    }

    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_FREE_PMI_KEYVALS);
}
#endif /* USE_PMI1_API or USE_PMI2_API */