/*
 * Copyright (C) by Argonne National Laboratory
 * See COPYRIGHT in top-level directory
 */

#include <mpir_pmi.h>
#include <mpiimpl.h>
#include "mpir_nodemap.h"

static int build_nodemap(int *nodemap, int sz, int *p_max_node_id);
static int build_locality(void);

static int pmi_version = 1;
static int pmi_subversion = 1;

static int pmi_max_key_size;
static int pmi_max_val_size;

#ifdef USE_PMI1_API
static int pmi_max_kvs_name_length;
static char *pmi_kvs_name;
#elif defined USE_PMI2_API
static char *pmi_jobid;
#elif defined USE_PMIX_API
static pmix_proc_t pmix_proc;
static pmix_proc_t pmix_wcproc;
#endif
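
/* Exactly one of USE_PMI1_API, USE_PMI2_API, or USE_PMIX_API is expected to
 * be defined (presumably chosen at configure time); the #if/#elif chains in
 * this file rely on that. */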

int MPIR_pmi_init(void)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;

    /* See if the user wants to override our default values */
    MPL_env2int("PMI_VERSION", &pmi_version);
    MPL_env2int("PMI_SUBVERSION", &pmi_subversion);

    int has_parent, rank, size, appnum;
#ifdef USE_PMI1_API
    pmi_errno = PMI_Init(&has_parent);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_init", "**pmi_init %d", pmi_errno);
    pmi_errno = PMI_Get_rank(&rank);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_get_rank", "**pmi_get_rank %d", pmi_errno);
    pmi_errno = PMI_Get_size(&size);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_get_size", "**pmi_get_size %d", pmi_errno);
    pmi_errno = PMI_Get_appnum(&appnum);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_get_appnum", "**pmi_get_appnum %d", pmi_errno);

    pmi_errno = PMI_KVS_Get_name_length_max(&pmi_max_kvs_name_length);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_get_name_length_max",
                         "**pmi_kvs_get_name_length_max %d", pmi_errno);
    pmi_kvs_name = (char *) MPL_malloc(pmi_max_kvs_name_length, MPL_MEM_OTHER);
    pmi_errno = PMI_KVS_Get_my_name(pmi_kvs_name, pmi_max_kvs_name_length);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_get_my_name", "**pmi_kvs_get_my_name %d", pmi_errno);

    pmi_errno = PMI_KVS_Get_key_length_max(&pmi_max_key_size);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_get_key_length_max",
                         "**pmi_kvs_get_key_length_max %d", pmi_errno);
    pmi_errno = PMI_KVS_Get_value_length_max(&pmi_max_val_size);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_get_value_length_max",
                         "**pmi_kvs_get_value_length_max %d", pmi_errno);

#elif defined USE_PMI2_API
    pmi_max_key_size = PMI2_MAX_KEYLEN;
    pmi_max_val_size = PMI2_MAX_VALLEN;

    pmi_errno = PMI2_Init(&has_parent, &size, &rank, &appnum);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_init", "**pmi_init %d", pmi_errno);

    pmi_jobid = (char *) MPL_malloc(PMI2_MAX_VALLEN, MPL_MEM_OTHER);
    pmi_errno = PMI2_Job_GetId(pmi_jobid, PMI2_MAX_VALLEN);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_job_getid", "**pmi_job_getid %d", pmi_errno);

#elif defined USE_PMIX_API
    pmi_max_key_size = PMIX_MAX_KEYLEN;
    pmi_max_val_size = 1024;    /* this is what PMI2_MAX_VALLEN is currently set to */

    pmix_value_t *pvalue = NULL;

    pmi_errno = PMIx_Init(&pmix_proc, NULL, 0);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_init", "**pmix_init %d", pmi_errno);

    rank = pmix_proc.rank;
    PMIX_PROC_CONSTRUCT(&pmix_wcproc);
    MPL_strncpy(pmix_wcproc.nspace, pmix_proc.nspace, PMIX_MAX_NSLEN);
    pmix_wcproc.rank = PMIX_RANK_WILDCARD;

    pmi_errno = PMIx_Get(&pmix_wcproc, PMIX_JOB_SIZE, NULL, 0, &pvalue);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_get", "**pmix_get %d", pmi_errno);
    size = pvalue->data.uint32;
    PMIX_VALUE_RELEASE(pvalue);

    /* appnum and has_parent are not set for now */
    appnum = 0;
    has_parent = 0;

#endif
    MPIR_Process.has_parent = has_parent;
    MPIR_Process.rank = rank;
    MPIR_Process.size = size;
    MPIR_Process.appnum = appnum;

    static int g_max_node_id = -1;
    MPIR_Process.node_map = (int *) MPL_malloc(size * sizeof(int), MPL_MEM_ADDRESS);

    mpi_errno = build_nodemap(MPIR_Process.node_map, size, &g_max_node_id);
    MPIR_ERR_CHECK(mpi_errno);
    MPIR_Process.num_nodes = g_max_node_id + 1;

    /* allocate and populate MPIR_Process.node_local_map and MPIR_Process.node_root_map */
    mpi_errno = build_locality();

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

void MPIR_pmi_finalize(void)
{
#ifdef USE_PMI1_API
    PMI_Finalize();
    MPL_free(pmi_kvs_name);
#elif defined(USE_PMI2_API)
    PMI2_Finalize();
    MPL_free(pmi_jobid);
#elif defined(USE_PMIX_API)
    PMIx_Finalize(NULL, 0);
    /* pmix_proc does not need to be freed */
#endif

    MPL_free(MPIR_Process.node_map);
    MPL_free(MPIR_Process.node_root_map);
    MPL_free(MPIR_Process.node_local_map);
}

void MPIR_pmi_abort(int exit_code, const char *error_msg)
{
#ifdef USE_PMI1_API
    PMI_Abort(exit_code, error_msg);
#elif defined(USE_PMI2_API)
    PMI2_Abort(TRUE, error_msg);
#elif defined(USE_PMIX_API)
    PMIx_Abort(exit_code, error_msg, NULL, 0);
#endif
}

/* getters for internal constants */
int MPIR_pmi_max_key_size(void)
{
    return pmi_max_key_size;
}

int MPIR_pmi_max_val_size(void)
{
    return pmi_max_val_size;
}

const char *MPIR_pmi_job_id(void)
{
#ifdef USE_PMI1_API
    return (const char *) pmi_kvs_name;
#elif defined USE_PMI2_API
    return (const char *) pmi_jobid;
#elif defined USE_PMIX_API
    return (const char *) pmix_proc.nspace;
#endif
}

/* wrapper functions */
int MPIR_pmi_kvs_put(const char *key, const char *val)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;

#ifdef USE_PMI1_API
    pmi_errno = PMI_KVS_Put(pmi_kvs_name, key, val);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_put", "**pmi_kvs_put %d", pmi_errno);
    pmi_errno = PMI_KVS_Commit(pmi_kvs_name);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_commit", "**pmi_kvs_commit %d", pmi_errno);
#elif defined(USE_PMI2_API)
    pmi_errno = PMI2_KVS_Put(key, val);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvsput", "**pmi_kvsput %d", pmi_errno);
#elif defined(USE_PMIX_API)
    pmix_value_t value;
    value.type = PMIX_STRING;
    value.data.string = (char *) val;
    pmi_errno = PMIx_Put(PMIX_GLOBAL, key, &value);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_put", "**pmix_put %d", pmi_errno);
    pmi_errno = PMIx_Commit();
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_commit", "**pmix_commit %d", pmi_errno);
#endif

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

/* NOTE: src is a hint, use src = -1 if not known */
int MPIR_pmi_kvs_get(int src, const char *key, char *val, int val_size)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;

#ifdef USE_PMI1_API
    /* src is not used in PMI1 */
    pmi_errno = PMI_KVS_Get(pmi_kvs_name, key, val, val_size);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvs_get", "**pmi_kvs_get %d", pmi_errno);
#elif defined(USE_PMI2_API)
    if (src < 0)
        src = PMI2_ID_NULL;
    int out_len;
    pmi_errno = PMI2_KVS_Get(pmi_jobid, src, key, val, val_size, &out_len);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvsget", "**pmi_kvsget %d", pmi_errno);
#elif defined(USE_PMIX_API)
    pmix_value_t *pvalue;
    if (src < 0) {
        pmi_errno = PMIx_Get(NULL, key, NULL, 0, &pvalue);
    } else {
        pmix_proc_t proc;
        PMIX_PROC_CONSTRUCT(&proc);
        proc.rank = src;

        pmi_errno = PMIx_Get(&proc, key, NULL, 0, &pvalue);
    }
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_get", "**pmix_get %d", pmi_errno);
    strncpy(val, pvalue->data.string, val_size);
    PMIX_VALUE_RELEASE(pvalue);
#endif

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
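
/* Illustrative usage (a sketch, not part of this file; the key name "mykey"
 * is hypothetical): a typical exchange pairs a put on one process with a
 * barrier and gets on the others, e.g.
 *
 *     char buf[64];
 *     if (MPIR_Process.rank == 0)
 *         mpi_errno = MPIR_pmi_kvs_put("mykey", "hello");
 *     mpi_errno = MPIR_pmi_barrier();
 *     mpi_errno = MPIR_pmi_kvs_get(0, "mykey", buf, sizeof(buf));
 */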

/* ---- utils functions ---- */

int MPIR_pmi_barrier(void)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;

#ifdef USE_PMI1_API
    pmi_errno = PMI_Barrier();
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_barrier", "**pmi_barrier %d", pmi_errno);
#elif defined(USE_PMI2_API)
    pmi_errno = PMI2_KVS_Fence();
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_kvsfence", "**pmi_kvsfence %d", pmi_errno);
#elif defined(USE_PMIX_API)
    pmix_info_t *info;
    PMIX_INFO_CREATE(info, 1);
    int flag = 1;
    PMIX_INFO_LOAD(info, PMIX_COLLECT_DATA, &flag, PMIX_BOOL);

    /* use global wildcard proc set */
    pmi_errno = PMIx_Fence(&pmix_wcproc, 1, info, 1);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_fence", "**pmix_fence %d", pmi_errno);
    PMIX_INFO_FREE(info, 1);
#endif

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

int MPIR_pmi_barrier_local(void)
{
#if defined(USE_PMIX_API)
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;
    int local_size = MPIR_Process.local_size;
    pmix_proc_t *procs = MPL_malloc(local_size * sizeof(pmix_proc_t), MPL_MEM_OTHER);
    for (int i = 0; i < local_size; i++) {
        PMIX_PROC_CONSTRUCT(&procs[i]);
        strncpy(procs[i].nspace, pmix_proc.nspace, PMIX_MAX_NSLEN);
        procs[i].rank = MPIR_Process.node_local_map[i];
    }

    pmix_info_t *info;
    int flag = 1;
    PMIX_INFO_CREATE(info, 1);
    PMIX_INFO_LOAD(info, PMIX_COLLECT_DATA, &flag, PMIX_BOOL);

    pmi_errno = PMIx_Fence(procs, local_size, info, 1);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER, "**pmix_fence",
                         "**pmix_fence %d", pmi_errno);

    PMIX_INFO_FREE(info, 1);
    MPL_free(procs);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
#else
    /* If a local barrier is not supported (PMI1 and PMI2), simply fall back */
    return MPIR_pmi_barrier();
#endif
}

/* declare static functions used in bcast/allgather */
static void encode(int size, const char *src, char *dest);
static void decode(int size, const char *src, char *dest);

/* is_local is a hint; we optimize for node-local access when we can */
static int optimized_put(const char *key, const char *val, int is_local)
{
    int mpi_errno = MPI_SUCCESS;
#if defined(USE_PMI1_API)
    mpi_errno = MPIR_pmi_kvs_put(key, val);
#elif defined(USE_PMI2_API)
    if (!is_local) {
        mpi_errno = MPIR_pmi_kvs_put(key, val);
    } else {
        int pmi_errno = PMI2_Info_PutNodeAttr(key, val);
        MPIR_ERR_CHKANDJUMP(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                            "**pmi_putnodeattr");
    }
#elif defined(USE_PMIX_API)
    int pmi_errno;
    pmix_value_t value;
    value.type = PMIX_STRING;
    value.data.string = (char *) val;
    pmi_errno = PMIx_Put(is_local ? PMIX_LOCAL : PMIX_GLOBAL, key, &value);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_put", "**pmix_put %d", pmi_errno);
    pmi_errno = PMIx_Commit();
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_commit", "**pmix_commit %d", pmi_errno);
#endif

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

static int optimized_get(int src, const char *key, char *val, int valsize, int is_local)
{
#if defined(USE_PMI1_API)
    return MPIR_pmi_kvs_get(src, key, val, valsize);
#elif defined(USE_PMI2_API)
    if (is_local) {
        int mpi_errno = MPI_SUCCESS;
        int found;
        int pmi_errno = PMI2_Info_GetNodeAttr(key, val, valsize, &found, TRUE);
        if (pmi_errno != PMI2_SUCCESS) {
            MPIR_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**pmi_getnodeattr");
        } else if (!found) {
            MPIR_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**pmi_getnodeattr");
        }
        return mpi_errno;
    } else {
        return MPIR_pmi_kvs_get(src, key, val, valsize);
    }
#else
    return MPIR_pmi_kvs_get(src, key, val, valsize);
#endif
}

/* higher-level binary put/get:
 * 1. binary encoding/decoding
 * 2. chops long values into multiple segments
 * 3. uses optimized_put/get for the case of node-level access
 */
static int put_ex(const char *key, const void *buf, int bufsize, int is_local)
{
    int mpi_errno = MPI_SUCCESS;
#if defined(USE_PMI1_API) || defined(USE_PMI2_API)
    char *val = MPL_malloc(pmi_max_val_size, MPL_MEM_OTHER);
    /* reserve some space for '\0' and maybe newlines
     * (depends on the PMI implementation, and may not be sufficient) */
    int segsize = (pmi_max_val_size - 2) / 2;
    if (bufsize < segsize) {
        encode(bufsize, buf, val);
        mpi_errno = optimized_put(key, val, is_local);
        MPIR_ERR_CHECK(mpi_errno);
    } else {
        int num_segs = bufsize / segsize;
        if (bufsize % segsize > 0) {
            num_segs++;
        }
        MPL_snprintf(val, pmi_max_val_size, "segments=%d", num_segs);
        mpi_errno = MPIR_pmi_kvs_put(key, val);
        MPIR_ERR_CHECK(mpi_errno);
        for (int i = 0; i < num_segs; i++) {
            char seg_key[50];
            sprintf(seg_key, "%s-seg-%d/%d", key, i + 1, num_segs);
            int n = segsize;
            if (i == num_segs - 1) {
                n = bufsize - segsize * (num_segs - 1);
            }
            encode(n, (char *) buf + i * segsize, val);
            mpi_errno = optimized_put(seg_key, val, is_local);
            MPIR_ERR_CHECK(mpi_errno);
        }
    }
#elif defined(USE_PMIX_API)
    int n = bufsize * 2 + 1;
    char *val = MPL_malloc(n, MPL_MEM_OTHER);
    encode(bufsize, buf, val);
    mpi_errno = optimized_put(key, val, is_local);
    MPIR_ERR_CHECK(mpi_errno);
#endif
  fn_exit:
    MPL_free(val);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
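
/* Worked example for put_ex/get_ex (illustrative; assumes the PMI1/PMI2 path
 * with pmi_max_val_size == 1024, so segsize == (1024 - 2) / 2 == 511): putting
 * a 2000-byte buffer under key "k" stores "segments=4" at "k" and four encoded
 * segments under "k-seg-1/4" .. "k-seg-4/4", the last carrying
 * 2000 - 3 * 511 = 467 bytes; get_ex below reassembles them. The PMIx path
 * instead puts the whole encoded value at once. */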

static int get_ex(int src, const char *key, void *buf, int *p_size, int is_local)
{
    int mpi_errno = MPI_SUCCESS;
    char *val = MPL_malloc(pmi_max_val_size, MPL_MEM_OTHER);
    int segsize = (pmi_max_val_size - 1) / 2;

    MPIR_Assert(p_size);
    MPIR_Assert(*p_size > 0);
    int bufsize = *p_size;
    int got_size;

    mpi_errno = optimized_get(src, key, val, pmi_max_val_size, is_local);
    MPIR_ERR_CHECK(mpi_errno);
    if (strncmp(val, "segments=", 9) == 0) {
        int num_segs = atoi(val + 9);
        got_size = 0;
        for (int i = 0; i < num_segs; i++) {
            char seg_key[50];
            sprintf(seg_key, "%s-seg-%d/%d", key, i + 1, num_segs);
            mpi_errno = optimized_get(src, seg_key, val, pmi_max_val_size, is_local);
            MPIR_ERR_CHECK(mpi_errno);
            int n = strlen(val) / 2;    /* 2-to-1 decode */
            if (i < num_segs - 1) {
                MPIR_Assert(n == segsize);
            } else {
                MPIR_Assert(n <= segsize);
            }
            decode(n, val, (char *) buf + i * segsize);
            got_size += n;
        }
    } else {
        int n = strlen(val) / 2;        /* 2-to-1 decode */
        decode(n, val, (char *) buf);
        got_size = n;
    }
    MPIR_Assert(got_size <= bufsize);
    if (got_size < bufsize) {
        ((char *) buf)[got_size] = '\0';
    }

    *p_size = got_size;

  fn_exit:
    MPL_free(val);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

static int optional_bcast_barrier(MPIR_PMI_DOMAIN domain)
{
#if defined(USE_PMI1_API)
    /* unless bcast is skipped altogether */
    if (domain == MPIR_PMI_DOMAIN_ALL && MPIR_Process.size == 1) {
        return MPI_SUCCESS;
    } else if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS && MPIR_Process.num_nodes == 1) {
        return MPI_SUCCESS;
    } else if (domain == MPIR_PMI_DOMAIN_LOCAL && MPIR_Process.size == MPIR_Process.num_nodes) {
        return MPI_SUCCESS;
    }
#elif defined(USE_PMI2_API)
    if (domain == MPIR_PMI_DOMAIN_ALL && MPIR_Process.size == 1) {
        return MPI_SUCCESS;
    } else if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS && MPIR_Process.num_nodes == 1) {
        return MPI_SUCCESS;
    } else if (domain == MPIR_PMI_DOMAIN_LOCAL) {
        /* PMI2 local uses Put/GetNodeAttr, no need for barrier */
        return MPI_SUCCESS;
    }
#elif defined(USE_PMIX_API)
    /* PMIx will block/wait, so barrier unnecessary */
    return MPI_SUCCESS;
#endif
    return MPIR_pmi_barrier();
}

int MPIR_pmi_bcast(void *buf, int bufsize, MPIR_PMI_DOMAIN domain)
{
    int mpi_errno = MPI_SUCCESS;

    int rank = MPIR_Process.rank;
    int local_node_id = MPIR_Process.node_map[rank];
    int node_root = MPIR_Process.node_root_map[local_node_id];
    int is_node_root = (node_root == rank);

    int in_domain, is_root, is_local, bcast_size;
    if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS && !is_node_root) {
        in_domain = 0;
    } else {
        in_domain = 1;
    }
    if (rank == 0 || (domain == MPIR_PMI_DOMAIN_LOCAL && is_node_root)) {
        is_root = 1;
    } else {
        is_root = 0;
    }
    is_local = (domain == MPIR_PMI_DOMAIN_LOCAL);

    bcast_size = MPIR_Process.size;
    if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS) {
        bcast_size = MPIR_Process.num_nodes;
    } else if (domain == MPIR_PMI_DOMAIN_LOCAL) {
        bcast_size = MPIR_Process.local_size;
    }
    if (bcast_size == 1) {
        in_domain = 0;
    }

    char key[50];
    int root;
    static int bcast_seq = 0;

    if (!in_domain) {
        /* PMI_Barrier may require all processes to participate */
        mpi_errno = optional_bcast_barrier(domain);
        MPIR_ERR_CHECK(mpi_errno);
    } else {
        MPIR_Assert(buf);
        MPIR_Assert(bufsize > 0);

        bcast_seq++;

        root = 0;
        if (domain == MPIR_PMI_DOMAIN_LOCAL) {
            root = node_root;
        }
        /* add root to the key since potentially we may have multiple roots
         * on a single node due to odd-even-cliques */
        sprintf(key, "-bcast-%d-%d", bcast_seq, root);

        if (is_root) {
            mpi_errno = put_ex(key, buf, bufsize, is_local);
            MPIR_ERR_CHECK(mpi_errno);
        }

        mpi_errno = optional_bcast_barrier(domain);
        MPIR_ERR_CHECK(mpi_errno);

        if (!is_root) {
            int got_size = bufsize;
            mpi_errno = get_ex(root, key, buf, &got_size, is_local);
            MPIR_ERR_CHECK(mpi_errno);
        }
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
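
/* Illustrative usage (a sketch; the buffer contents are hypothetical): every
 * rank calls the function, the root fills the buffer, and the others receive
 * a copy, e.g.
 *
 *     char buf[128];
 *     if (MPIR_Process.rank == 0)
 *         MPL_snprintf(buf, sizeof(buf), "root-data");
 *     mpi_errno = MPIR_pmi_bcast(buf, sizeof(buf), MPIR_PMI_DOMAIN_ALL);
 */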

int MPIR_pmi_allgather(const void *sendbuf, int sendsize, void *recvbuf, int recvsize,
                       MPIR_PMI_DOMAIN domain)
{
    int mpi_errno = MPI_SUCCESS;

    MPIR_Assert(domain != MPIR_PMI_DOMAIN_LOCAL);

    int local_node_id = MPIR_Process.node_map[MPIR_Process.rank];
    int is_node_root = (MPIR_Process.node_root_map[local_node_id] == MPIR_Process.rank);
    int in_domain = 1;
    if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS && !is_node_root) {
        in_domain = 0;
    }

    static int allgather_seq = 0;
    allgather_seq++;

    char key[50];
    sprintf(key, "-allgather-%d-%d", allgather_seq, MPIR_Process.rank);

    if (in_domain) {
        mpi_errno = put_ex(key, sendbuf, sendsize, 0);
        MPIR_ERR_CHECK(mpi_errno);
    }
#ifndef USE_PMIX_API
    /* PMIx will wait, so barrier unnecessary */
    mpi_errno = MPIR_pmi_barrier();
    MPIR_ERR_CHECK(mpi_errno);
#endif

    if (in_domain) {
        int domain_size = MPIR_Process.size;
        if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS) {
            domain_size = MPIR_Process.num_nodes;
        }
        for (int i = 0; i < domain_size; i++) {
            int rank = i;
            if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS) {
                rank = MPIR_Process.node_root_map[i];
            }
            sprintf(key, "-allgather-%d-%d", allgather_seq, rank);
            int got_size = recvsize;
            mpi_errno = get_ex(rank, key, (unsigned char *) recvbuf + i * recvsize, &got_size, 0);
            MPIR_ERR_CHECK(mpi_errno);
        }
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
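
/* Illustrative usage (a sketch; the entry size 64 is hypothetical): every
 * process contributes one fixed-size entry and receives all entries indexed
 * by rank, e.g.
 *
 *     char card[64];
 *     char *all = MPL_malloc(MPIR_Process.size * 64, MPL_MEM_OTHER);
 *     mpi_errno = MPIR_pmi_allgather(card, 64, all, 64, MPIR_PMI_DOMAIN_ALL);
 */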

/* This version assumes shm_buf is shared across local procs. Each process
 * participates in the gather part, with the task distributed over local procs.
 *
 * NOTE: the behavior differs from MPIR_pmi_allgather when the domain is
 * MPIR_PMI_DOMAIN_NODE_ROOTS. With MPIR_pmi_allgather, only the node roots
 * participate.
 */
int MPIR_pmi_allgather_shm(const void *sendbuf, int sendsize, void *shm_buf, int recvsize,
                           MPIR_PMI_DOMAIN domain)
{
    int mpi_errno = MPI_SUCCESS;

    MPIR_Assert(domain != MPIR_PMI_DOMAIN_LOCAL);

    int rank = MPIR_Process.rank;
    int size = MPIR_Process.size;
    int local_size = MPIR_Process.local_size;
    int local_rank = MPIR_Process.local_rank;
    int local_node_id = MPIR_Process.node_map[rank];
    int node_root = MPIR_Process.node_root_map[local_node_id];
    int is_node_root = (node_root == MPIR_Process.rank);

    static int allgather_shm_seq = 0;
    allgather_shm_seq++;

    char key[50];
    sprintf(key, "-allgather-shm-%d-%d", allgather_shm_seq, rank);

    /* in the roots-only domain, non-roots skip the put */
    if (domain != MPIR_PMI_DOMAIN_NODE_ROOTS || is_node_root) {
        mpi_errno = put_ex(key, (unsigned char *) sendbuf, sendsize, 0);
        MPIR_ERR_CHECK(mpi_errno);
    }

    mpi_errno = MPIR_pmi_barrier();
    MPIR_ERR_CHECK(mpi_errno);

    /* Each rank needs to get values from "size" ranks; divide the task evenly over local ranks */
    if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS) {
        size = MPIR_Process.num_nodes;
    }
    int per_local_rank = size / local_size;
    if (per_local_rank * local_size < size) {
        per_local_rank++;
    }
    int start = local_rank * per_local_rank;
    int end = start + per_local_rank;
    if (end > size) {
        end = size;
    }
    for (int i = start; i < end; i++) {
        int src = i;
        if (domain == MPIR_PMI_DOMAIN_NODE_ROOTS) {
            src = MPIR_Process.node_root_map[i];
        }
        sprintf(key, "-allgather-shm-%d-%d", allgather_shm_seq, src);
        int got_size = recvsize;
        mpi_errno = get_ex(src, key, (unsigned char *) shm_buf + i * recvsize, &got_size, 0);
        MPIR_ERR_CHECK(mpi_errno);
        MPIR_Assert(got_size <= recvsize);
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
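
/* Worked example of the division above (illustrative): with size = 10 entries
 * to fetch and local_size = 4, per_local_rank = ceil(10 / 4) = 3, so local
 * ranks 0..3 fetch index ranges [0,3), [3,6), [6,9), and [9,10) into the
 * shared buffer. */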

int MPIR_pmi_get_universe_size(int *universe_size)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;

#ifdef USE_PMI1_API
    pmi_errno = PMI_Get_universe_size(universe_size);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_get_universe_size", "**pmi_get_universe_size %d", pmi_errno);
#elif defined(USE_PMI2_API)
    char val[PMI2_MAX_VALLEN];
    int found = 0;
    char *endptr;

    pmi_errno = PMI2_Info_GetJobAttr("universeSize", val, sizeof(val), &found);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_getjobattr", "**pmi_getjobattr %d", pmi_errno);
    if (!found) {
        *universe_size = MPIR_UNIVERSE_SIZE_NOT_AVAILABLE;
    } else {
        *universe_size = strtol(val, &endptr, 0);
        MPIR_ERR_CHKINTERNAL(endptr - val != strlen(val), mpi_errno, "can't parse universe size");
    }
#elif defined(USE_PMIX_API)
    pmix_value_t *pvalue = NULL;

    pmi_errno = PMIx_Get(&pmix_wcproc, PMIX_UNIV_SIZE, NULL, 0, &pvalue);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_get", "**pmix_get %d", pmi_errno);
    *universe_size = pvalue->data.uint32;
    PMIX_VALUE_RELEASE(pvalue);
#endif
  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

char *MPIR_pmi_get_failed_procs(void)
{
    int pmi_errno;
    char *failed_procs_string = NULL;

    failed_procs_string = MPL_malloc(pmi_max_val_size, MPL_MEM_OTHER);
    MPIR_Assert(failed_procs_string);
#ifdef USE_PMI1_API
    pmi_errno = PMI_KVS_Get(pmi_kvs_name, "PMI_dead_processes",
                            failed_procs_string, pmi_max_val_size);
    if (pmi_errno != PMI_SUCCESS)
        goto fn_fail;
#elif defined(USE_PMI2_API)
    int out_len;
    pmi_errno = PMI2_KVS_Get(pmi_jobid, PMI2_ID_NULL, "PMI_dead_processes",
                             failed_procs_string, pmi_max_val_size, &out_len);
    if (pmi_errno != PMI2_SUCCESS)
        goto fn_fail;
#elif defined(USE_PMIX_API)
    goto fn_fail;
#endif

  fn_exit:
    return failed_procs_string;
  fn_fail:
    /* FIXME: appropriate error messages here? */
    MPL_free(failed_procs_string);
    failed_procs_string = NULL;
    goto fn_exit;
}

/* static functions only for MPIR_pmi_spawn_multiple */
#if defined(USE_PMI1_API) || defined(USE_PMI2_API)
/* PMI_keyval_t is only defined in PMI1 or PMI2 */
static int mpi_to_pmi_keyvals(MPIR_Info * info_ptr, PMI_keyval_t ** kv_ptr, int *nkeys_ptr);
static void free_pmi_keyvals(PMI_keyval_t ** kv, int size, int *counts);
#endif

/* NOTE: MPIR_pmi_spawn_multiple is to be called by a single root spawning process */
int MPIR_pmi_spawn_multiple(int count, char *commands[], char **argvs[],
                            const int maxprocs[], MPIR_Info * info_ptrs[],
                            int num_preput_keyval, struct MPIR_PMI_KEYVAL *preput_keyvals,
                            int *pmi_errcodes)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;

#ifdef USE_PMI1_API
    int *info_keyval_sizes = NULL;
    PMI_keyval_t **info_keyval_vectors = NULL;

    info_keyval_sizes = (int *) MPL_malloc(count * sizeof(int), MPL_MEM_BUFFER);
    MPIR_ERR_CHKANDJUMP(!info_keyval_sizes, mpi_errno, MPI_ERR_OTHER, "**nomem");

    info_keyval_vectors =
        (PMI_keyval_t **) MPL_malloc(count * sizeof(PMI_keyval_t *), MPL_MEM_BUFFER);
    MPIR_ERR_CHKANDJUMP(!info_keyval_vectors, mpi_errno, MPI_ERR_OTHER, "**nomem");

    if (!info_ptrs) {
        for (int i = 0; i < count; i++) {
            info_keyval_vectors[i] = 0;
            info_keyval_sizes[i] = 0;
        }
    } else {
        for (int i = 0; i < count; i++) {
            mpi_errno = mpi_to_pmi_keyvals(info_ptrs[i], &info_keyval_vectors[i],
                                           &info_keyval_sizes[i]);
            MPIR_ERR_CHECK(mpi_errno);
        }
    }

    pmi_errno = PMI_Spawn_multiple(count, (const char **) commands, (const char ***) argvs,
                                   maxprocs,
                                   info_keyval_sizes,
                                   (const PMI_keyval_t **) info_keyval_vectors,
                                   num_preput_keyval, (PMI_keyval_t *) preput_keyvals,
                                   pmi_errcodes);

    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi_spawn_multiple", "**pmi_spawn_multiple %d", pmi_errno);
#elif defined(USE_PMI2_API)
    /* not supported yet */
    MPIR_Assert(0);
#elif defined(USE_PMIX_API)
    /* not supported yet */
    MPIR_Assert(0);
#endif

  fn_exit:
#ifdef USE_PMI1_API
    if (info_keyval_vectors) {
        free_pmi_keyvals(info_keyval_vectors, count, info_keyval_sizes);
        MPL_free(info_keyval_vectors);
    }

    MPL_free(info_keyval_sizes);
#endif

    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

/* ---- static functions ---- */

/* The following static function declarations are only for build_nodemap() */
static int get_option_no_local(void);
static int get_option_num_cliques(void);
static int build_nodemap_nolocal(int *nodemap, int sz, int *p_max_node_id);
static int build_nodemap_roundrobin(int num_cliques, int *nodemap, int sz, int *p_max_node_id);

#ifdef USE_PMI1_API
static int build_nodemap_pmi1(int *nodemap, int sz, int *p_max_node_id);
static int build_nodemap_fallback(int *nodemap, int sz, int *p_max_node_id);
#elif defined(USE_PMI2_API)
static int build_nodemap_pmi2(int *nodemap, int sz, int *p_max_node_id);
#elif defined(USE_PMIX_API)
static int build_nodemap_pmix(int *nodemap, int sz, int *p_max_node_id);
#endif

static int build_nodemap(int *nodemap, int sz, int *p_max_node_id)
{
    int mpi_errno = MPI_SUCCESS;

    if (sz == 1 || get_option_no_local()) {
        mpi_errno = build_nodemap_nolocal(nodemap, sz, p_max_node_id);
        goto fn_exit;
    }
#ifdef USE_PMI1_API
    mpi_errno = build_nodemap_pmi1(nodemap, sz, p_max_node_id);
#elif defined(USE_PMI2_API)
    mpi_errno = build_nodemap_pmi2(nodemap, sz, p_max_node_id);
#elif defined(USE_PMIX_API)
    mpi_errno = build_nodemap_pmix(nodemap, sz, p_max_node_id);
#endif
    MPIR_ERR_CHECK(mpi_errno);

    int num_cliques = get_option_num_cliques();
    if (num_cliques > sz) {
        num_cliques = sz;
    }
    if (*p_max_node_id == 0 && num_cliques > 1) {
        mpi_errno = build_nodemap_roundrobin(num_cliques, nodemap, sz, p_max_node_id);
        MPIR_ERR_CHECK(mpi_errno);
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

static int get_option_no_local(void)
{
    /* Used for debugging only. This disables communication over shared memory */
#ifdef ENABLE_NO_LOCAL
    return 1;
#else
    return MPIR_CVAR_NOLOCAL;
#endif
}

static int get_option_num_cliques(void)
{
    /* Used for debugging on a single machine: split procs into num_cliques nodes.
     * If ODD_EVEN_CLIQUES is enabled, split procs into 2 nodes.
     */
    if (MPIR_CVAR_NUM_CLIQUES > 1) {
        return MPIR_CVAR_NUM_CLIQUES;
    } else {
        return MPIR_CVAR_ODD_EVEN_CLIQUES ? 2 : 1;
    }
}

/* one process per node */
static int build_nodemap_nolocal(int *nodemap, int sz, int *p_max_node_id)
{
    for (int i = 0; i < sz; ++i) {
        nodemap[i] = i;
    }
    *p_max_node_id = sz - 1;
    return MPI_SUCCESS;
}

/* assign processes to num_cliques nodes in a round-robin fashion */
static int build_nodemap_roundrobin(int num_cliques, int *nodemap, int sz, int *p_max_node_id)
{
    for (int i = 0; i < sz; ++i) {
        nodemap[i] = i % num_cliques;
    }
    *p_max_node_id = num_cliques - 1;
    return MPI_SUCCESS;
}
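
/* For example (illustrative): num_cliques = 3 and sz = 8 yields
 * nodemap = {0, 1, 2, 0, 1, 2, 0, 1} and *p_max_node_id = 2. */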

#ifdef USE_PMI1_API

/* build nodemap based on allgather hostnames */
/* FIXME: migrate the function */
static int build_nodemap_fallback(int *nodemap, int sz, int *p_max_node_id)
{
    return MPIR_NODEMAP_build_nodemap_fallback(sz, MPIR_Process.rank, nodemap, p_max_node_id);
}

/* build nodemap using PMI1 process_mapping or fallback with hostnames */
static int build_nodemap_pmi1(int *nodemap, int sz, int *p_max_node_id)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;
    int did_map = 0;
    if (pmi_version == 1 && pmi_subversion == 1) {
        char *process_mapping = MPL_malloc(pmi_max_val_size, MPL_MEM_ADDRESS);
        pmi_errno = PMI_KVS_Get(pmi_kvs_name, "PMI_process_mapping",
                                process_mapping, pmi_max_val_size);
        if (pmi_errno == PMI_SUCCESS) {
            mpi_errno = MPIR_NODEMAP_populate_ids_from_mapping(process_mapping, sz, nodemap,
                                                               p_max_node_id, &did_map);
            MPIR_ERR_CHECK(mpi_errno);
            MPIR_ERR_CHKINTERNAL(!did_map, mpi_errno,
                                 "unable to populate node ids from PMI_process_mapping");
        }
        MPL_free(process_mapping);
    }
    if (!did_map) {
        mpi_errno = build_nodemap_fallback(nodemap, sz, p_max_node_id);
    }
  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#elif defined USE_PMI2_API

/* build nodemap using PMI2 process_mapping or error */
static int build_nodemap_pmi2(int *nodemap, int sz, int *p_max_node_id)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;
    char process_mapping[PMI2_MAX_VALLEN];
    int found;

    pmi_errno = PMI2_Info_GetJobAttr("PMI_process_mapping", process_mapping, PMI2_MAX_VALLEN,
                                     &found);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMI2_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmi2_info_getjobattr", "**pmi2_info_getjobattr %d", pmi_errno);
    MPIR_ERR_CHKINTERNAL(!found, mpi_errno, "PMI_process_mapping attribute not found");

    int did_map;
    mpi_errno = MPIR_NODEMAP_populate_ids_from_mapping(process_mapping, sz, nodemap,
                                                       p_max_node_id, &did_map);
    MPIR_ERR_CHECK(mpi_errno);
    MPIR_ERR_CHKINTERNAL(!did_map, mpi_errno,
                         "unable to populate node ids from PMI_process_mapping");
  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#elif defined USE_PMIX_API

/* build nodemap using PMIx_Resolve_nodes */
static int build_nodemap_pmix(int *nodemap, int sz, int *p_max_node_id)
{
    int mpi_errno = MPI_SUCCESS;
    int pmi_errno;
    char *nodelist = NULL, *node = NULL;
    pmix_proc_t *procs = NULL;
    size_t nprocs, node_id = 0;

    pmi_errno = PMIx_Resolve_nodes(pmix_proc.nspace, &nodelist);
    MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                         "**pmix_resolve_nodes", "**pmix_resolve_nodes %d", pmi_errno);
    MPIR_Assert(nodelist);

    node = strtok(nodelist, ",");
    while (node) {
        pmi_errno = PMIx_Resolve_peers(node, pmix_proc.nspace, &procs, &nprocs);
        MPIR_ERR_CHKANDJUMP1(pmi_errno != PMIX_SUCCESS, mpi_errno, MPI_ERR_OTHER,
                             "**pmix_resolve_peers", "**pmix_resolve_peers %d", pmi_errno);
        for (int i = 0; i < nprocs; i++) {
            nodemap[procs[i].rank] = node_id;
        }
        node_id++;
        node = strtok(NULL, ",");
    }
    *p_max_node_id = node_id - 1;
    /* PMIx latest adds pmix_free. We should switch to that at some point */
    MPL_external_free(nodelist);
    PMIX_PROC_FREE(procs, nprocs);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

#endif

/* allocate and populate MPIR_Process.node_local_map and MPIR_Process.node_root_map */
static int build_locality(void)
{
    int local_rank = -1;
    int local_size = 0;
    int *node_root_map, *node_local_map;

    int rank = MPIR_Process.rank;
    int size = MPIR_Process.size;
    int *node_map = MPIR_Process.node_map;
    int num_nodes = MPIR_Process.num_nodes;
    int local_node_id = node_map[rank];

    node_root_map = MPL_malloc(num_nodes * sizeof(int), MPL_MEM_ADDRESS);
    for (int i = 0; i < num_nodes; i++) {
        node_root_map[i] = -1;
    }

    for (int i = 0; i < size; i++) {
        int node_id = node_map[i];
        if (node_root_map[node_id] < 0) {
            node_root_map[node_id] = i;
        }
        if (node_id == local_node_id) {
            local_size++;
        }
    }

    node_local_map = MPL_malloc(local_size * sizeof(int), MPL_MEM_ADDRESS);
    int j = 0;
    for (int i = 0; i < size; i++) {
        int node_id = node_map[i];
        if (node_id == local_node_id) {
            node_local_map[j] = i;
            if (i == rank) {
                local_rank = j;
            }
            j++;
        }
    }

    MPIR_Process.node_root_map = node_root_map;
    MPIR_Process.node_local_map = node_local_map;
    MPIR_Process.local_size = local_size;
    MPIR_Process.local_rank = local_rank;

    return MPI_SUCCESS;
}
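
/* Worked example (illustrative): with size = 4 and node_map = {0, 0, 1, 1},
 * rank 3 ends up with node_root_map = {0, 2}, node_local_map = {2, 3},
 * local_size = 2, and local_rank = 1. */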

/* similar to functions in mpl/src/str/mpl_argstr.c, but much simpler */
static int hex(unsigned char c)
{
    if (c >= '0' && c <= '9') {
        return c - '0';
    } else if (c >= 'a' && c <= 'f') {
        return 10 + c - 'a';
    } else if (c >= 'A' && c <= 'F') {
        return 10 + c - 'A';
    } else {
        MPIR_Assert(0);
        return -1;
    }
}

static void encode(int size, const char *src, char *dest)
{
    for (int i = 0; i < size; i++) {
        MPL_snprintf(dest, 3, "%02X", (unsigned char) *src);
        src++;
        dest += 2;
    }
}

static void decode(int size, const char *src, char *dest)
{
    for (int i = 0; i < size; i++) {
        *dest = (char) ((hex(src[0]) << 4) + hex(src[1]));
        src += 2;
        dest++;
    }
}
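
/* Round-trip example (illustrative): encode(2, "Hi", val) writes "4869"
 * ('H' = 0x48, 'i' = 0x69) plus a terminating '\0'; decode(2, "4869", buf)
 * restores the two original bytes. */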

/* static functions used in MPIR_pmi_spawn_multiple */
#if defined(USE_PMI1_API) || defined(USE_PMI2_API)
static int mpi_to_pmi_keyvals(MPIR_Info * info_ptr, PMI_keyval_t ** kv_ptr, int *nkeys_ptr)
{
    char key[MPI_MAX_INFO_KEY];
    PMI_keyval_t *kv = 0;
    int nkeys = 0, vallen, flag, mpi_errno = MPI_SUCCESS;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPI_TO_PMI_KEYVALS);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPI_TO_PMI_KEYVALS);

    if (!info_ptr || info_ptr->handle == MPI_INFO_NULL)
        goto fn_exit;

    MPIR_Info_get_nkeys_impl(info_ptr, &nkeys);

    if (nkeys == 0)
        goto fn_exit;

    kv = (PMI_keyval_t *) MPL_malloc(nkeys * sizeof(PMI_keyval_t), MPL_MEM_BUFFER);

    for (int i = 0; i < nkeys; i++) {
        mpi_errno = MPIR_Info_get_nthkey_impl(info_ptr, i, key);
        MPIR_ERR_CHECK(mpi_errno);
        MPIR_Info_get_valuelen_impl(info_ptr, key, &vallen, &flag);
        kv[i].key = MPL_strdup(key);
        kv[i].val = (char *) MPL_malloc(vallen + 1, MPL_MEM_BUFFER);
        MPIR_Info_get_impl(info_ptr, key, vallen + 1, kv[i].val, &flag);
    }

  fn_exit:
    *kv_ptr = kv;
    *nkeys_ptr = nkeys;
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPI_TO_PMI_KEYVALS);
    return mpi_errno;

  fn_fail:
    goto fn_exit;
}

static void free_pmi_keyvals(PMI_keyval_t ** kv, int size, int *counts)
{
    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_FREE_PMI_KEYVALS);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_FREE_PMI_KEYVALS);

    for (int i = 0; i < size; i++) {
        for (int j = 0; j < counts[i]; j++) {
            MPL_free((char *) kv[i][j].key);
            MPL_free(kv[i][j].val);
        }
        MPL_free(kv[i]);
    }

    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_FREE_PMI_KEYVALS);
}
#endif /* USE_PMI1_API or USE_PMI2_API */