/*****************************************************************************\
 **  setup.c - PMI2 server setup
 *****************************************************************************
 *  Copyright (C) 2011-2012 National University of Defense Technology.
 *  Written by Hongjia Cao <hjcao@nudt.edu.cn>.
 *  All rights reserved.
 *  Portions copyright (C) 2015 Mellanox Technologies Inc.
 *  Written by Artem Y. Polyakov <artemp@mellanox.com>.
 *  All rights reserved.
 *  Portions copyright (C) 2017 SchedMD LLC.
 *
 *  This file is part of Slurm, a resource management program.
 *  For details, see <https://slurm.schedmd.com/>.
 *  Please also read the included file: DISCLAIMER.
 *
 *  Slurm is free software; you can redistribute it and/or modify it under
 *  the terms of the GNU General Public License as published by the Free
 *  Software Foundation; either version 2 of the License, or (at your option)
 *  any later version.
 *
 *  In addition, as a special exception, the copyright holders give permission
 *  to link the code of portions of this program with the OpenSSL library under
 *  certain conditions as described in each individual source file, and
 *  distribute linked combinations including the two. You must obey the GNU
 *  General Public License in all respects for all of the code used other than
 *  OpenSSL. If you modify file(s) with this exception, you may extend this
 *  exception to your version of the file(s), but you are not obligated to do
 *  so. If you do not wish to do so, delete this exception statement from your
 *  version.  If you delete this exception statement from all source files in
 *  the program, then also delete it here.
 *
 *  Slurm is distributed in the hope that it will be useful, but WITHOUT ANY
 *  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 *  FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
 *  details.
 *
 *  You should have received a copy of the GNU General Public License along
 *  with Slurm; if not, write to the Free Software Foundation, Inc.,
 *  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301  USA.
\*****************************************************************************/

#if defined(__DragonFly__)
#include <sys/socket.h>	/* AF_INET */
#endif

#include <dlfcn.h>
#include <fcntl.h>
#include <poll.h>
#include <signal.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>

#include "src/common/slurm_xlator.h"
#include "src/common/net.h"
#include "src/common/proc_args.h"
#include "src/common/slurm_mpi.h"
#include "src/common/xstring.h"
#include "src/slurmd/slurmstepd/slurmstepd_job.h"
#include "src/slurmd/common/reverse_tree_math.h"

#include "setup.h"
#include "tree.h"
#include "pmi.h"
#include "spawn.h"
#include "kvs.h"
#include "ring.h"

#define PMI2_SOCK_ADDR_FMT "%s/sock.pmi2.%u.%u"
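/*
 * The tree socket path expands to "<spooldir>/sock.pmi2.<jobid>.<stepid>",
 * e.g. "/var/spool/slurmd/sock.pmi2.1234.0" (an illustrative path; the
 * actual directory comes from the SlurmdSpoolDir configuration).
 */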


extern char **environ;

static bool run_in_stepd = false;

int  tree_sock;
int *task_socks;
char tree_sock_addr[128];
pmi2_job_info_t job_info;
pmi2_tree_info_t tree_info;

static char *fmt_tree_sock_addr = NULL;

extern bool
in_stepd(void)
{
	return run_in_stepd;
}

static void
_remove_tree_sock(void)
{
	if (fmt_tree_sock_addr) {
		unlink(fmt_tree_sock_addr);
		xfree(fmt_tree_sock_addr);
	}
}

static int
_setup_stepd_job_info(const stepd_step_rec_t *job, char ***env)
{
	char *p;
	int i;

	memset(&job_info, 0, sizeof(job_info));

	if (job->het_job_id && (job->het_job_id != NO_VAL)) {
		job_info.jobid  = job->het_job_id;
		job_info.stepid = job->stepid;
		job_info.nnodes = job->het_job_nnodes;
		job_info.nodeid = job->nodeid + job->het_job_node_offset;
		job_info.ntasks = job->het_job_ntasks;
		job_info.ltasks = job->node_tasks;
		job_info.gtids = xmalloc(job_info.ltasks * sizeof(uint32_t));
		for (i = 0; i < job_info.ltasks; i ++) {
			job_info.gtids[i] = job->task[i]->gtid +
					    job->het_job_task_offset;
		}
	} else {
		job_info.jobid  = job->jobid;
		job_info.stepid = job->stepid;
		job_info.nnodes = job->nnodes;
		job_info.nodeid = job->nodeid;
		job_info.ntasks = job->ntasks;
		job_info.ltasks = job->node_tasks;
		job_info.gtids = xmalloc(job_info.ltasks * sizeof(uint32_t));
		for (i = 0; i < job_info.ltasks; i ++) {
			job_info.gtids[i] = job->task[i]->gtid;
		}
	}

	p = getenvp(*env, PMI2_PMI_DEBUGGED_ENV);
	if (p) {
		job_info.pmi_debugged = atoi(p);
	} else {
		job_info.pmi_debugged = 0;
	}
	p = getenvp(*env, PMI2_SPAWN_SEQ_ENV);
	if (p) { 		/* spawned */
		job_info.spawn_seq = atoi(p);
		unsetenvp(*env, PMI2_SPAWN_SEQ_ENV);
		p = getenvp(*env, PMI2_SPAWNER_JOBID_ENV);
		job_info.spawner_jobid = xstrdup(p);
		unsetenvp(*env, PMI2_SPAWNER_JOBID_ENV);
	} else {
		job_info.spawn_seq = 0;
		job_info.spawner_jobid = NULL;
	}
	p = getenvp(*env, PMI2_PMI_JOBID_ENV);
	if (p) {
		job_info.pmi_jobid = xstrdup(p);
		unsetenvp(*env, PMI2_PMI_JOBID_ENV);
	} else {
		xstrfmtcat(job_info.pmi_jobid, "%u.%u", job_info.jobid,
			   job_info.stepid);
	}
	p = getenvp(*env, PMI2_STEP_NODES_ENV);
	if (!p) {
		error("mpi/pmi2: unable to find nodes in job environment");
		return SLURM_ERROR;
	} else {
		job_info.step_nodelist = xstrdup(p);
		unsetenvp(*env, PMI2_STEP_NODES_ENV);
	}
	/*
	 * how to get the mapping info from stepd directly?
	 * there is the task distribution info in the launch_tasks_request_msg_t,
	 * but it is not stored in the stepd_step_rec_t.
	 */
	p = getenvp(*env, PMI2_PROC_MAPPING_ENV);
	if (!p) {
		error("PMI2_PROC_MAPPING_ENV not found");
		return SLURM_ERROR;
	} else {
		job_info.proc_mapping = xstrdup(p);
		unsetenvp(*env, PMI2_PROC_MAPPING_ENV);
	}

	job_info.job_env = env_array_copy((const char **)*env);

	job_info.MPIR_proctable = NULL;
	job_info.srun_opt = NULL;

	/* get the SLURM_STEP_RESV_PORTS */
	p = getenvp(*env, SLURM_STEP_RESV_PORTS);
	if (!p) {
		debug("%s: %s not found in env", __func__, SLURM_STEP_RESV_PORTS);
	} else {
		job_info.resv_ports = xstrdup(p);
		info("%s: SLURM_STEP_RESV_PORTS found %s", __func__, p);
	}
	return SLURM_SUCCESS;
}

static int
_setup_stepd_tree_info(char ***env)
{
	hostlist_t hl;
	char *srun_host;
	uint16_t port;
	char *p;
	int tree_width;

	/* job info available */

	memset(&tree_info, 0, sizeof(tree_info));

	hl = hostlist_create(job_info.step_nodelist);
	p = hostlist_nth(hl, job_info.nodeid); /* strdup-ed */
	tree_info.this_node = xstrdup(p);
	free(p);

	/* this only controls the upward communication tree width */
	p = getenvp(*env, PMI2_TREE_WIDTH_ENV);
	if (p) {
		tree_width = atoi(p);
		if (tree_width < 2) {
			info("invalid PMI2 tree width value (%d) detected. "
			     "Falling back to default value.", tree_width);
			tree_width = slurm_get_tree_width();
		}
	} else {
		tree_width = slurm_get_tree_width();
	}

	/* TODO: cannot launch 0 tasks on node */

	/*
	 * In tree position calculation, root of the tree is srun with id 0.
	 * Stepd's id will be its nodeid plus 1.
	 */
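	/*
	 * Illustration: the stepd on nodeid 0 occupies tree id 1, whose
	 * parent is tree id 0 (srun); after the decrement below its
	 * parent_id becomes -1 and parent_node stays NULL, so that stepd
	 * reports to srun directly.
	 */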
	reverse_tree_info(job_info.nodeid + 1, job_info.nnodes + 1,
			  tree_width, &tree_info.parent_id,
			  &tree_info.num_children, &tree_info.depth,
			  &tree_info.max_depth);
	tree_info.parent_id --;	       /* restore real nodeid */
	if (tree_info.parent_id < 0) {	/* parent is srun */
		tree_info.parent_node = NULL;
	} else {
		p = hostlist_nth(hl, tree_info.parent_id);
		tree_info.parent_node = xstrdup(p);
		free(p);
	}
	hostlist_destroy(hl);

	tree_info.pmi_port = 0;	/* not used */

	srun_host = getenvp(*env, "SLURM_SRUN_COMM_HOST");
	if (!srun_host) {
		error("mpi/pmi2: unable to find srun comm ifhn in env");
		return SLURM_ERROR;
	}
	p = getenvp(*env, PMI2_SRUN_PORT_ENV);
	if (!p) {
		error("mpi/pmi2: unable to find srun pmi2 port in env");
		return SLURM_ERROR;
	}
	port = atoi(p);

	tree_info.srun_addr = xmalloc(sizeof(slurm_addr_t));
	slurm_set_addr(tree_info.srun_addr, port, srun_host);

	unsetenvp(*env, PMI2_SRUN_PORT_ENV);

	/* init kvs seq to 0. TODO: reduce array size */
	tree_info.children_kvs_seq = xmalloc(sizeof(uint32_t) *
					     job_info.nnodes);

	return SLURM_SUCCESS;
}

/*
 * setup sockets for slurmstepd
 */
static int
_setup_stepd_sockets(const stepd_step_rec_t *job, char ***env)
{
	struct sockaddr_un sa;
	int i;
	char *spool;

	debug("mpi/pmi2: setup sockets");

	tree_sock = socket(AF_UNIX, SOCK_STREAM, 0);
	if (tree_sock < 0) {
		error("mpi/pmi2: failed to create tree socket: %m");
		return SLURM_ERROR;
	}
	sa.sun_family = PF_UNIX;

	/*
	 * tree_sock_addr has to remain unformatted since the formatting
	 * happens on the slurmd side
	 */
	spool = slurm_get_slurmd_spooldir(NULL);
	snprintf(tree_sock_addr, sizeof(tree_sock_addr), PMI2_SOCK_ADDR_FMT,
		 spool, job_info.jobid, job_info.stepid);
	/*
	 * Make sure we adjust for the spool dir coming in on the address to
	 * point to the right spot.
	 * We need to unlink this later so we need a formatted version of the
	 * string to unlink.
	 */
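	/*
	 * For instance, a SlurmdSpoolDir of "/var/spool/slurmd/%n" on node
	 * "node01" (illustrative values) becomes "/var/spool/slurmd/node01"
	 * below, while tree_sock_addr above keeps the "%n" for slurmd to
	 * expand on its side.
	 */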
	xstrsubstitute(spool, "%n", job->node_name);
	xstrsubstitute(spool, "%h", job->node_name);
	xstrfmtcat(fmt_tree_sock_addr, PMI2_SOCK_ADDR_FMT, spool,
		   job_info.jobid, job_info.stepid);
	/*
	 * If socket name would be truncated, emit error and exit
	 */
	if (strlen(fmt_tree_sock_addr) >= sizeof(sa.sun_path)) {
		error("%s: Unix socket path '%s' is too long. (%ld > %ld)",
		      __func__, fmt_tree_sock_addr,
		      (long int)(strlen(fmt_tree_sock_addr) + 1),
		      (long int)sizeof(sa.sun_path));
		xfree(spool);
		xfree(fmt_tree_sock_addr);
		return SLURM_ERROR;
	}

	strlcpy(sa.sun_path, fmt_tree_sock_addr, sizeof(sa.sun_path));

	unlink(sa.sun_path);    /* remove possible old socket */
	xfree(spool);

	if (bind(tree_sock, (struct sockaddr *)&sa, SUN_LEN(&sa)) < 0) {
		error("mpi/pmi2: failed to bind tree socket: %m");
		unlink(sa.sun_path);
		return SLURM_ERROR;
	}
	if (listen(tree_sock, 64) < 0) {
		error("mpi/pmi2: failed to listen tree socket: %m");
		unlink(sa.sun_path);
		return SLURM_ERROR;
	}

	task_socks = xmalloc(2 * job->node_tasks * sizeof(int));
	for (i = 0; i < job->node_tasks; i ++) {
		socketpair(AF_UNIX, SOCK_STREAM, 0, &task_socks[i * 2]);
		/* this must be delayed after the tasks have been forked */
/* 		close(TASK_PMI_SOCK(i)); */
	}
	return SLURM_SUCCESS;
}

static int
_setup_stepd_kvs(char ***env)
{
	int rc = SLURM_SUCCESS, i = 0, pp_cnt = 0;
	char *p, env_key[32], *ppkey, *ppval;

	kvs_seq = 1;
	rc = temp_kvs_init();
	if (rc != SLURM_SUCCESS)
		return rc;

	rc = kvs_init();
	if (rc != SLURM_SUCCESS)
		return rc;

	/* preput */
	p = getenvp(*env, PMI2_PREPUT_CNT_ENV);
	if (p) {
		pp_cnt = atoi(p);
	}

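	/*
	 * Each preput pair is read from a numbered pair of environment
	 * variables: the i-th key and value live in PMI2_PPKEY_ENV<i> and
	 * PMI2_PPVAL_ENV<i> (macro prefixes defined elsewhere in this
	 * plugin), as the loop below shows.
	 */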
	for (i = 0; i < pp_cnt; i ++) {
		snprintf(env_key, 32, PMI2_PPKEY_ENV"%d", i);
		p = getenvp(*env, env_key);
		ppkey = p; /* getenvp will not modify p */
		snprintf(env_key, 32, PMI2_PPVAL_ENV"%d", i);
		p = getenvp(*env, env_key);
		ppval = p;
		kvs_put(ppkey, ppval);
	}

	/*
	 * For PMI 1.1.
	 * A better logic would be to put PMI_process_mapping in KVS only if
	 * the task distribution method is not "arbitrary", because in
	 * "arbitrary" distribution the process mapping variable is not correct.
	 * MPICH2 may deduce the clique info from the hostnames. But that
	 * is rather costly.
	 */
	kvs_put("PMI_process_mapping", job_info.proc_mapping);

	return SLURM_SUCCESS;
}

extern int
pmi2_setup_stepd(const stepd_step_rec_t *job, char ***env)
{
	int rc;

	run_in_stepd = true;

	/* job info */
	rc = _setup_stepd_job_info(job, env);
	if (rc != SLURM_SUCCESS)
		return rc;

	/* tree info */
	rc = _setup_stepd_tree_info(env);
	if (rc != SLURM_SUCCESS)
		return rc;

	/* sockets */
	rc = _setup_stepd_sockets(job, env);
	if (rc != SLURM_SUCCESS)
		return rc;

	/* kvs */
	rc = _setup_stepd_kvs(env);
	if (rc != SLURM_SUCCESS)
		return rc;

	/* TODO: finalize pmix_ring state somewhere */
	/* initialize pmix_ring state */
	rc = pmix_ring_init(&job_info, env);
	if (rc != SLURM_SUCCESS)
		return rc;

	return SLURM_SUCCESS;
}

extern void
pmi2_cleanup_stepd(void)
{
	close(tree_sock);
	_remove_tree_sock();
}
/**************************************************************/

/* returned string should be xfree-ed by caller */
static char *
_get_proc_mapping(const mpi_plugin_client_info_t *job)
{
	uint32_t node_cnt, task_cnt, task_mapped, node_task_cnt, **tids;
	uint32_t task_dist, block;
	uint16_t *tasks, *rounds;
	int i, start_id, end_id;
	char *mapping = NULL;

	node_cnt = job->step_layout->node_cnt;
	task_cnt = job->step_layout->task_cnt;
	task_dist = job->step_layout->task_dist & SLURM_DIST_STATE_BASE;
	tasks = job->step_layout->tasks;
	tids = job->step_layout->tids;

	/* for now, PMI2 only supports vector processor mapping */
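	/*
	 * The mapping string uses the PMI "vector" format: each triplet
	 * (base, count, procs) describes `count' consecutive nodes starting
	 * at node `base', each running `procs' tasks.  For example, a block
	 * distribution of 8 tasks over 4 nodes yields "(vector,(0,4,2))".
	 */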

	if ((task_dist & SLURM_DIST_NODEMASK) == SLURM_DIST_NODECYCLIC) {
		mapping = xstrdup("(vector");

		rounds = xmalloc (node_cnt * sizeof(uint16_t));
		task_mapped = 0;
		while (task_mapped < task_cnt) {
			start_id = 0;
			/* find start_id */
			while (start_id < node_cnt) {
				while (start_id < node_cnt &&
				       ( rounds[start_id] >= tasks[start_id] ||
					 (task_mapped !=
					  tids[start_id][rounds[start_id]]) )) {
					start_id ++;
				}
				if (start_id >= node_cnt)
					break;
				/* block is always 1 */
				/* find end_id */
				end_id = start_id;
				while (end_id < node_cnt &&
				       ( rounds[end_id] < tasks[end_id] &&
					 (task_mapped ==
					  tids[end_id][rounds[end_id]]) )) {
					rounds[end_id] ++;
					task_mapped ++;
					end_id ++;
				}
				xstrfmtcat(mapping, ",(%u,%u,1)", start_id,
					   end_id - start_id);
				start_id = end_id;
			}
		}
		xfree(rounds);
		xstrcat(mapping, ")");
	} else if (task_dist == SLURM_DIST_ARBITRARY) {
		/*
		 * MPICH2 will think that each task runs on a separate node.
		 * The program will run, but no SHM will be used for
		 * communication.
		 */
		mapping = xstrdup("(vector");
		xstrfmtcat(mapping, ",(0,%u,1)", job->step_layout->task_cnt);
		xstrcat(mapping, ")");

	} else if (task_dist == SLURM_DIST_PLANE) {
		mapping = xstrdup("(vector");

		rounds = xmalloc (node_cnt * sizeof(uint16_t));
		task_mapped = 0;
		while (task_mapped < task_cnt) {
			start_id = 0;
			/* find start_id */
			while (start_id < node_cnt) {
				while (start_id < node_cnt &&
				       ( rounds[start_id] >= tasks[start_id] ||
					 (task_mapped !=
					  tids[start_id][rounds[start_id]]) )) {
					start_id ++;
				}
				if (start_id >= node_cnt)
					break;
				/* find start block. block may be less
				 * than plane size */
				block = 0;
				while (rounds[start_id] < tasks[start_id] &&
				       (task_mapped ==
					tids[start_id][rounds[start_id]])) {
					block ++;
					rounds[start_id] ++;
					task_mapped ++;
				}
				/* find end_id */
				end_id = start_id + 1;
				while (end_id < node_cnt &&
				       (rounds[end_id] + block - 1 <
					tasks[end_id])) {
					for (i = 0;
					     i < tasks[end_id] - rounds[end_id];
					     i ++) {
						if (task_mapped + i !=
						    tids[end_id][rounds[end_id]
								 + i]) {
							break;
						}
					}
					if (i != block)
						break;
					rounds[end_id] += block;
					task_mapped += block;
					end_id ++;
				}
				xstrfmtcat(mapping, ",(%u,%u,%u)", start_id,
					   end_id - start_id, block);
				start_id = end_id;
			}
		}
		xfree(rounds);
		xstrcat(mapping, ")");

	} else {		/* BLOCK mode */
		mapping = xstrdup("(vector");
		start_id = 0;
		node_task_cnt = tasks[start_id];
		for (i = start_id + 1; i < node_cnt; i ++) {
			if (node_task_cnt == tasks[i])
				continue;
			xstrfmtcat(mapping, ",(%u,%u,%u)", start_id,
				   i - start_id, node_task_cnt);
			start_id = i;
			node_task_cnt = tasks[i];
		}
		xstrfmtcat(mapping, ",(%u,%u,%u))", start_id, i - start_id,
			   node_task_cnt);
	}

	debug("mpi/pmi2: processor mapping: %s", mapping);
	return mapping;
}

static int
_setup_srun_job_info(const mpi_plugin_client_info_t *job)
{
	char *p;
	void *handle = NULL, *sym = NULL;

	memset(&job_info, 0, sizeof(job_info));

	if (job->het_job_id && (job->het_job_id != NO_VAL)) {
		job_info.jobid  = job->het_job_id;
		job_info.stepid = job->stepid;
		job_info.nnodes = job->step_layout->node_cnt;
		job_info.ntasks = job->step_layout->task_cnt;
	} else {
		job_info.jobid  = job->jobid;
		job_info.stepid = job->stepid;
		job_info.nnodes = job->step_layout->node_cnt;
		job_info.ntasks = job->step_layout->task_cnt;
	}
	job_info.nodeid = -1;	/* id in tree. not used. */
	job_info.ltasks = 0;	/* not used */
	job_info.gtids = NULL;	/* not used */

	p = getenv(PMI2_PMI_DEBUGGED_ENV);
	if (p) {
		job_info.pmi_debugged = atoi(p);
	} else {
		job_info.pmi_debugged = 0;
	}
	p = getenv(PMI2_SPAWN_SEQ_ENV);
	if (p) { 		/* spawned */
		job_info.spawn_seq = atoi(p);
		p = getenv(PMI2_SPAWNER_JOBID_ENV);
		job_info.spawner_jobid = xstrdup(p);
		/* env unset in stepd */
	} else {
		job_info.spawn_seq = 0;
		job_info.spawner_jobid = NULL;
	}

	job_info.step_nodelist = xstrdup(job->step_layout->node_list);
	job_info.proc_mapping = _get_proc_mapping(job);
	if (job_info.proc_mapping == NULL) {
		return SLURM_ERROR;
	}
	p = getenv(PMI2_PMI_JOBID_ENV);
	if (p) {		/* spawned */
		job_info.pmi_jobid = xstrdup(p);
	} else {
		xstrfmtcat(job_info.pmi_jobid, "%u.%u", job_info.jobid,
			   job_info.stepid);
	}
	job_info.job_env = env_array_copy((const char **)environ);

	/* hjcao: this is really dirty.
	   But writing a new launcher is not desirable. */
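	/*
	 * dlopen(NULL) returns a handle to the running srun executable, so
	 * the MPIR_proctable debugger symbol and srun's opt structure can be
	 * looked up at run time without a link-time dependency.
	 */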
	handle = dlopen(NULL, RTLD_LAZY);
	if (handle == NULL) {
		error("mpi/pmi2: failed to dlopen()");
		return SLURM_ERROR;
	}
	sym = dlsym(handle, "MPIR_proctable");
	if (sym == NULL) {
		/* if called directly in API, there may be no symbol available */
		verbose ("mpi/pmi2: failed to find symbol 'MPIR_proctable'");
		job_info.MPIR_proctable = NULL;
	} else {
		job_info.MPIR_proctable = *(MPIR_PROCDESC **)sym;
	}
	sym = dlsym(handle, "opt");
	if (sym == NULL) {
		verbose("mpi/pmi2: failed to find symbol 'opt'");
		job_info.srun_opt = NULL;
	} else {
		job_info.srun_opt = (slurm_opt_t *)sym;
	}
	dlclose(handle);

	return SLURM_SUCCESS;
}

static int
_setup_srun_tree_info(void)
{
	char *p;
	uint16_t p_port;
	char *spool;

	memset(&tree_info, 0, sizeof(tree_info));

	tree_info.this_node = "launcher"; /* not used */
	tree_info.parent_id = -2;   /* not used */
	tree_info.parent_node = NULL; /* not used */
	tree_info.num_children = job_info.nnodes;
	tree_info.depth = 0;	 /* not used */
	tree_info.max_depth = 0; /* not used */
	/* pmi_port set in _setup_srun_sockets */
	p = getenv(PMI2_SPAWNER_PORT_ENV);
	if (p) {		/* spawned */
		p_port = atoi(p);
		tree_info.srun_addr = xmalloc(sizeof(slurm_addr_t));
		/* assume there is always a lo interface */
		slurm_set_addr(tree_info.srun_addr, p_port, "127.0.0.1");
	} else
		tree_info.srun_addr = NULL;

	/*
	 * FIXME: We need to handle %n and %h in the spool dir, but don't have
	 * the node name here
	 */
	spool = slurm_get_slurmd_spooldir(NULL);
	snprintf(tree_sock_addr, 128, PMI2_SOCK_ADDR_FMT,
		 spool, job_info.jobid, job_info.stepid);
	xfree(spool);

	/* init kvs seq to 0. TODO: reduce array size */
	tree_info.children_kvs_seq = xmalloc(sizeof(uint32_t) *
					     job_info.nnodes);

	return SLURM_SUCCESS;
}

static int
_setup_srun_socket(const mpi_plugin_client_info_t *job)
{
	if (net_stream_listen(&tree_sock,
			      &tree_info.pmi_port) < 0) {
		error("mpi/pmi2: Failed to create tree socket");
		return SLURM_ERROR;
	}
	debug("mpi/pmi2: srun pmi port: %hu", tree_info.pmi_port);

	return SLURM_SUCCESS;
}

static int
_setup_srun_kvs(void)
{
	int rc;

	kvs_seq = 1;
	rc = temp_kvs_init();
	return rc;
}

static int
_setup_srun_environ(const mpi_plugin_client_info_t *job, char ***env)
{
	/* ifhn will be set in SLURM_SRUN_COMM_HOST by slurmd */
	env_array_overwrite_fmt(env, PMI2_SRUN_PORT_ENV, "%hu",
				tree_info.pmi_port);
	env_array_overwrite_fmt(env, PMI2_STEP_NODES_ENV, "%s",
				job_info.step_nodelist);
	env_array_overwrite_fmt(env, PMI2_PROC_MAPPING_ENV, "%s",
				job_info.proc_mapping);
	return SLURM_SUCCESS;
}

inline static int
_tasks_launched (void)
{
	int i, all_launched = 1;
	if (job_info.MPIR_proctable == NULL)
		return 1;

	for (i = 0; i < job_info.ntasks; i ++) {
		if (job_info.MPIR_proctable[i].pid == 0) {
			all_launched = 0;
			break;
		}
	}
	return all_launched;
}

static void *
_task_launch_detection(void *unused)
{
	spawn_resp_t *resp;
	time_t start;
	int rc = 0;

	/*
	 * mpir_init() is called in plugins/launch/slurm/launch_slurm.c before
	 * mpi_hook_client_prelaunch() is called in api/step_launch.c
	 */
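	/* Poll MPIR_proctable every 50 ms, giving up (rc = 1) after 600 s. */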
	start = time(NULL);
	while (_tasks_launched() == 0) {
		usleep(1000*50);
		if (time(NULL) - start > 600) {
			rc = 1;
			break;
		}
	}

	/* send a resp to spawner srun */
	resp = spawn_resp_new();
	resp->seq = job_info.spawn_seq;
	resp->jobid = xstrdup(job_info.pmi_jobid);
	resp->error_cnt = 0;	/* TODO */
	resp->rc = rc;
	resp->pmi_port = tree_info.pmi_port;

	spawn_resp_send_to_srun(resp);
	spawn_resp_free(resp);
	return NULL;
}

extern int
pmi2_setup_srun(const mpi_plugin_client_info_t *job, char ***env)
{
	static pthread_mutex_t setup_mutex = PTHREAD_MUTEX_INITIALIZER;
	static pthread_cond_t setup_cond  = PTHREAD_COND_INITIALIZER;
	static int global_rc = NO_VAL16;
	int rc = SLURM_SUCCESS;

	run_in_stepd = false;
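	/*
	 * For heterogeneous jobs, only the first component (het_job_task_offset
	 * 0) performs the setup; the other components wait below until
	 * global_rc is published under setup_mutex.
	 */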
	if ((job->het_job_id == NO_VAL) || (job->het_job_task_offset == 0)) {
		rc = _setup_srun_job_info(job);
		if (rc == SLURM_SUCCESS)
			rc = _setup_srun_tree_info();
		if (rc == SLURM_SUCCESS)
			rc = _setup_srun_socket(job);
		if (rc == SLURM_SUCCESS)
			rc = _setup_srun_kvs();
		if (rc == SLURM_SUCCESS)
			rc = _setup_srun_environ(job, env);
		if ((rc == SLURM_SUCCESS) && job_info.spawn_seq) {
			slurm_thread_create_detached(NULL,
						     _task_launch_detection,
						     NULL);
		}
		slurm_mutex_lock(&setup_mutex);
		global_rc = rc;
		slurm_cond_broadcast(&setup_cond);
		slurm_mutex_unlock(&setup_mutex);
	} else {
		slurm_mutex_lock(&setup_mutex);
		while (global_rc == NO_VAL16)
			slurm_cond_wait(&setup_cond, &setup_mutex);
		rc = global_rc;
		slurm_mutex_unlock(&setup_mutex);
		if (rc == SLURM_SUCCESS)
			rc = _setup_srun_environ(job, env);
	}

	return rc;
}