1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpidimpl.h"
7 #include "ofi_impl.h"
8 #include "mpidu_bc.h"
9 #include "ofi_noinline.h"
10 
11 /*
12 === BEGIN_MPI_T_CVAR_INFO_BLOCK ===
13 
14 categories :
15     - name : CH4_OFI
16       description : A category for CH4 OFI netmod variables
17 
18 cvars:
19     - name        : MPIR_CVAR_CH4_OFI_CAPABILITY_SETS_DEBUG
20       category    : CH4_OFI
21       type        : int
22       default     : 0
23       class       : none
24       verbosity   : MPI_T_VERBOSITY_USER_BASIC
25       scope       : MPI_T_SCOPE_LOCAL
26       description : >-
27         Prints out the configuration of each capability selected via the capability sets interface.
28 
29     - name        : MPIR_CVAR_OFI_SKIP_IPV6
30       category    : DEVELOPER
31       type        : boolean
32       default     : true
33       class       : none
34       verbosity   : MPI_T_VERBOSITY_USER_BASIC
35       scope       : MPI_T_SCOPE_LOCAL
36       description : >-
37         Skip IPv6 providers.
38 
39     - name        : MPIR_CVAR_CH4_OFI_ENABLE_AV_TABLE
40       category    : CH4_OFI
41       type        : int
42       default     : -1
43       class       : none
44       verbosity   : MPI_T_VERBOSITY_USER_BASIC
45       scope       : MPI_T_SCOPE_LOCAL
46       description : >-
47         If true, the OFI addressing information will be stored with an FI_AV_TABLE.
48         If false, an FI_AV_MAP will be used.
49 
50     - name        : MPIR_CVAR_CH4_OFI_ENABLE_SCALABLE_ENDPOINTS
51       category    : CH4_OFI
52       type        : int
53       default     : -1
54       class       : none
55       verbosity   : MPI_T_VERBOSITY_USER_BASIC
56       scope       : MPI_T_SCOPE_LOCAL
57       description : >-
58         If true, use OFI scalable endpoints.
59 
60     - name        : MPIR_CVAR_CH4_OFI_ENABLE_SHARED_CONTEXTS
61       category    : CH4_OFI
62       type        : int
63       default     : 0
64       class       : none
65       verbosity   : MPI_T_VERBOSITY_USER_BASIC
66       scope       : MPI_T_SCOPE_LOCAL
67       description : >-
68         If set to false (zero), MPICH does not use OFI shared contexts.
69         If set to -1, it is determined by the OFI capability sets based on the provider.
70         Otherwise, MPICH tries to use OFI shared contexts. If they are unavailable,
71         it'll fall back to the mode without shared contexts.
72 
73     - name        : MPIR_CVAR_CH4_OFI_ENABLE_MR_SCALABLE
74       category    : CH4_OFI
75       type        : int
76       default     : -1
77       class       : none
78       verbosity   : MPI_T_VERBOSITY_USER_BASIC
79       scope       : MPI_T_SCOPE_LOCAL
80       description : >-
81         This variable is only provided for backward compatibility. When using OFI versions 1.5+, use
82         the other memory region variables.
83 
84         If true, MR_SCALABLE for OFI memory regions.
85         If false, MR_BASIC for OFI memory regions.
86 
87     - name        : MPIR_CVAR_CH4_OFI_ENABLE_MR_VIRT_ADDRESS
88       category    : CH4_OFI
89       type        : int
90       default     : -1
91       class       : none
92       verbosity   : MPI_T_VERBOSITY_USER_BASIC
93       scope       : MPI_T_SCOPE_LOCAL
94       description : >-
95         If true, enable virtual addressing for OFI memory regions. This variable is only meaningful
        for OFI versions 1.5+. It is equivalent to using FI_MR_BASIC in versions of
97         OFI older than 1.5.
98 
99     - name        : MPIR_CVAR_CH4_OFI_ENABLE_MR_ALLOCATED
100       category    : CH4_OFI
101       type        : int
102       default     : -1
103       class       : none
104       verbosity   : MPI_T_VERBOSITY_USER_BASIC
105       scope       : MPI_T_SCOPE_LOCAL
106       description : >-
107         If true, require all OFI memory regions must be backed by physical memory pages
108         at the time the registration call is made. This variable is only meaningful
        for OFI versions 1.5+. It is equivalent to using FI_MR_BASIC in versions of
110         OFI older than 1.5.
111 
112     - name        : MPIR_CVAR_CH4_OFI_ENABLE_MR_PROV_KEY
113       category    : CH4_OFI
114       type        : int
115       default     : -1
116       class       : none
117       verbosity   : MPI_T_VERBOSITY_USER_BASIC
118       scope       : MPI_T_SCOPE_LOCAL
119       description : >-
120         If true, enable provider supplied key for OFI memory regions. This variable is only
        meaningful for OFI versions 1.5+. It is equivalent to using FI_MR_BASIC in versions of OFI
122         older than 1.5.
123 
124     - name        : MPIR_CVAR_CH4_OFI_ENABLE_TAGGED
125       category    : CH4_OFI
126       type        : int
127       default     : -1
128       class       : none
129       verbosity   : MPI_T_VERBOSITY_USER_BASIC
130       scope       : MPI_T_SCOPE_LOCAL
131       description : >-
132         If true, use tagged message transmission functions in OFI.
133 
134     - name        : MPIR_CVAR_CH4_OFI_ENABLE_AM
135       category    : CH4_OFI
136       type        : int
137       default     : -1
138       class       : none
139       verbosity   : MPI_T_VERBOSITY_USER_BASIC
140       scope       : MPI_T_SCOPE_LOCAL
141       description : >-
142         If true, enable OFI active message support.
143 
144     - name        : MPIR_CVAR_CH4_OFI_ENABLE_RMA
145       category    : CH4_OFI
146       type        : int
147       default     : -1
148       class       : none
149       verbosity   : MPI_T_VERBOSITY_USER_BASIC
150       scope       : MPI_T_SCOPE_LOCAL
151       description : >-
152         If true, enable OFI RMA support for MPI RMA operations. OFI support for basic RMA is always
        required to implement large message transfers in the active message code path.
154 
155     - name        : MPIR_CVAR_CH4_OFI_ENABLE_ATOMICS
156       category    : CH4_OFI
157       type        : int
158       default     : -1
159       class       : none
160       verbosity   : MPI_T_VERBOSITY_USER_BASIC
161       scope       : MPI_T_SCOPE_LOCAL
162       description : >-
163         If true, enable OFI Atomics support.
164 
165     - name        : MPIR_CVAR_CH4_OFI_FETCH_ATOMIC_IOVECS
166       category    : CH4_OFI
167       type        : int
168       default     : -1
169       class       : none
170       verbosity   : MPI_T_VERBOSITY_USER_BASIC
171       scope       : MPI_T_SCOPE_LOCAL
172       description : >-
173         Specifies the maximum number of iovecs that can be used by the OFI provider
174         for fetch_atomic operations. The default value is -1, indicating that
175         no value is set.
176 
177     - name        : MPIR_CVAR_CH4_OFI_ENABLE_DATA_AUTO_PROGRESS
178       category    : CH4_OFI
179       type        : int
180       default     : -1
181       class       : none
182       verbosity   : MPI_T_VERBOSITY_USER_BASIC
183       scope       : MPI_T_SCOPE_LOCAL
184       description : >-
185         If true, enable MPI data auto progress.
186 
187     - name        : MPIR_CVAR_CH4_OFI_ENABLE_CONTROL_AUTO_PROGRESS
188       category    : CH4_OFI
189       type        : int
190       default     : -1
191       class       : none
192       verbosity   : MPI_T_VERBOSITY_USER_BASIC
193       scope       : MPI_T_SCOPE_LOCAL
194       description : >-
195         If true, enable MPI control auto progress.
196 
197     - name        : MPIR_CVAR_CH4_OFI_ENABLE_PT2PT_NOPACK
198       category    : CH4_OFI
199       type        : int
200       default     : -1
201       class       : none
202       verbosity   : MPI_T_VERBOSITY_USER_BASIC
203       scope       : MPI_T_SCOPE_LOCAL
204       description : >-
205         If true, enable iovec for pt2pt.
206 
207     - name        : MPIR_CVAR_CH4_OFI_CONTEXT_ID_BITS
208       category    : CH4_OFI
209       type        : int
210       default     : -1
211       class       : none
212       verbosity   : MPI_T_VERBOSITY_USER_BASIC
213       scope       : MPI_T_SCOPE_LOCAL
214       description : >-
215         Specifies the number of bits that will be used for matching the context
216         ID. The default value is -1, indicating that no value is set and that
217         the default will be defined in the ofi_types.h file.
218 
219     - name        : MPIR_CVAR_CH4_OFI_RANK_BITS
220       category    : CH4_OFI
221       type        : int
222       default     : -1
223       class       : none
224       verbosity   : MPI_T_VERBOSITY_USER_BASIC
225       scope       : MPI_T_SCOPE_LOCAL
226       description : >-
227         Specifies the number of bits that will be used for matching the MPI
228         rank. The default value is -1, indicating that no value is set and that
229         the default will be defined in the ofi_types.h file.
230 
231     - name        : MPIR_CVAR_CH4_OFI_TAG_BITS
232       category    : CH4_OFI
233       type        : int
234       default     : -1
235       class       : none
236       verbosity   : MPI_T_VERBOSITY_USER_BASIC
237       scope       : MPI_T_SCOPE_LOCAL
238       description : >-
239         Specifies the number of bits that will be used for matching the user
240         tag. The default value is -1, indicating that no value is set and that
241         the default will be defined in the ofi_types.h file.
242 
243     - name        : MPIR_CVAR_CH4_OFI_MAJOR_VERSION
244       category    : CH4_OFI
245       type        : int
246       default     : -1
247       class       : none
248       verbosity   : MPI_T_VERBOSITY_USER_BASIC
249       scope       : MPI_T_SCOPE_LOCAL
250       description : >-
251         Specifies the major version of the OFI library. The default is the
252         major version of the OFI library used with MPICH. If using this CVAR,
253         it is recommended that the user also specifies a specific OFI provider.
254 
255     - name        : MPIR_CVAR_CH4_OFI_MINOR_VERSION
256       category    : CH4_OFI
257       type        : int
258       default     : -1
259       class       : none
260       verbosity   : MPI_T_VERBOSITY_USER_BASIC
261       scope       : MPI_T_SCOPE_LOCAL
262       description : >-
        Specifies the minor version of the OFI library. The default is the
264         minor version of the OFI library used with MPICH. If using this CVAR,
265         it is recommended that the user also specifies a specific OFI provider.
266 
267     - name        : MPIR_CVAR_CH4_OFI_MAX_VNIS
268       category    : CH4_OFI
269       type        : int
270       default     : 0
271       class       : none
272       verbosity   : MPI_T_VERBOSITY_USER_BASIC
273       scope       : MPI_T_SCOPE_LOCAL
274       description : >-
275         If set to positive, this CVAR specifies the maximum number of CH4 VNIs
276         that OFI netmod exposes. If set to 0 (the default) or bigger than
277         MPIR_CVAR_CH4_NUM_VCIS, the number of exposed VNIs is set to MPIR_CVAR_CH4_NUM_VCIS.
278 
279     - name        : MPIR_CVAR_CH4_OFI_MAX_RMA_SEP_CTX
280       category    : CH4_OFI
281       type        : int
282       default     : 0
283       class       : none
284       verbosity   : MPI_T_VERBOSITY_USER_BASIC
285       scope       : MPI_T_SCOPE_LOCAL
286       description : >-
287         If set to positive, this CVAR specifies the maximum number of transmit
288         contexts RMA can utilize in a scalable endpoint.
289         This value is effective only when scalable endpoint is available, otherwise
290         it will be ignored.
291 
292     - name        : MPIR_CVAR_CH4_OFI_MAX_EAGAIN_RETRY
293       category    : CH4_OFI
294       type        : int
295       default     : -1
296       class       : none
297       verbosity   : MPI_T_VERBOSITY_USER_BASIC
298       scope       : MPI_T_SCOPE_LOCAL
299       description : >-
300         If set to positive, this CVAR specifies the maximum number of retries
301         of an ofi operations before returning MPIX_ERR_EAGAIN. This value is
302         effective only when the communicator has the MPI_OFI_set_eagain info
303         hint set to true.
304 
305     - name        : MPIR_CVAR_CH4_OFI_NUM_AM_BUFFERS
306       category    : CH4_OFI
307       type        : int
308       default     : -1
309       class       : none
310       verbosity   : MPI_T_VERBOSITY_USER_BASIC
311       scope       : MPI_T_SCOPE_LOCAL
312       description : >-
313         Specifies the number of buffers for receiving active messages.
314 
315     - name        : MPIR_CVAR_CH4_OFI_RMA_PROGRESS_INTERVAL
316       category    : CH4_OFI
317       type        : int
318       default     : 100
319       class       : none
320       verbosity   : MPI_T_VERBOSITY_USER_BASIC
321       scope       : MPI_T_SCOPE_LOCAL
322       description : >-
323         Specifies the interval for manually flushing RMA operations when automatic progress is not
        enabled. If the underlying OFI provider supports auto data progress, this value is ignored.
325         If the value is -1, this optimization will be turned off.
326 
327     - name        : MPIR_CVAR_CH4_OFI_RMA_IOVEC_MAX
328       category    : CH4_OFI
329       type        : int
330       default     : 16384
331       class       : none
332       verbosity   : MPI_T_VERBOSITY_USER_BASIC
333       scope       : MPI_T_SCOPE_LOCAL
334       description : >-
335         Specifies the maximum number of iovecs to allocate for RMA operations
336         to/from noncontiguous buffers.
337 
338     - name        : MPIR_CVAR_CH4_OFI_NUM_PACK_BUFFERS_PER_CHUNK
339       category    : CH4_OFI
340       type        : int
341       default     : 16
342       class       : none
343       verbosity   : MPI_T_VERBOSITY_USER_BASIC
344       scope       : MPI_T_SCOPE_LOCAL
345       description : >-
346         Specifies the number of buffers for packing/unpacking messages in
347         each block of the pool.
348 
349     - name        : MPIR_CVAR_CH4_OFI_MAX_NUM_PACK_BUFFERS
350       category    : CH4_OFI
351       type        : int
352       default     : 256
353       class       : none
354       verbosity   : MPI_T_VERBOSITY_USER_BASIC
355       scope       : MPI_T_SCOPE_LOCAL
356       description : >-
357         Specifies the max number of buffers for packing/unpacking messages
358         in the pool.
359 
360     - name        : MPIR_CVAR_CH4_OFI_EAGER_MAX_MSG_SIZE
361       category    : CH4_OFI
362       type        : int
363       default     : -1
364       class       : none
365       verbosity   : MPI_T_VERBOSITY_USER_BASIC
366       scope       : MPI_T_SCOPE_LOCAL
367       description : >-
368         This cvar controls the message size at which OFI native path switches from eager to
369         rendezvous mode. It does not affect the AM path eager limit. Having this gives a way to
370         reliably test native non-path.
371         If the number is positive, OFI will init the MPIDI_OFI_global.max_msg_size to the value of
        cvar. If the number is negative, OFI will init the MPIDI_OFI_global.max_msg_size using
373         whatever provider gives (which might be unlimited for socket provider).
374 
375 === END_MPI_T_CVAR_INFO_BLOCK ===
376 */
377 
378 static int get_ofi_version(void);
379 static int open_fabric(void);
380 static int create_vni_context(int vni);
381 static int destroy_vni_context(int vni);
382 
383 static int conn_manager_init(void);
384 static int conn_manager_destroy(void);
385 static int dynproc_send_disconnect(int conn_id);
386 
387 static int addr_exchange_root_vni(MPIR_Comm * init_comm);
388 static int addr_exchange_all_vnis(void);
389 
390 static void *host_alloc(uintptr_t size);
391 static void *host_alloc_registered(uintptr_t size);
392 static void host_free(void *ptr);
393 static void host_free_registered(void *ptr);
394 
/* Allocate a plain (unregistered) host buffer of 'size' bytes.
 * Returns NULL on failure, matching MPL_malloc semantics. */
static void *host_alloc(uintptr_t size)
{
    void *buf = MPL_malloc(size, MPL_MEM_BUFFER);
    return buf;
}
399 
/* Allocate a host buffer of 'size' bytes and register it with the GPU
 * runtime (so it can be used efficiently for device transfers).
 * Asserts on allocation failure rather than returning NULL. */
static void *host_alloc_registered(uintptr_t size)
{
    void *buf = MPL_malloc(size, MPL_MEM_BUFFER);
    MPIR_Assert(buf != NULL);
    MPL_gpu_register_host(buf, size);
    return buf;
}
407 
/* Release a buffer previously obtained from host_alloc. */
static void host_free(void *ptr)
{
    MPL_free(ptr);
}
412 
/* Release a buffer previously obtained from host_alloc_registered.
 * The GPU registration must be dropped before the memory is freed. */
static void host_free_registered(void *ptr)
{
    MPL_gpu_unregister_host(ptr);
    MPL_free(ptr);
}
418 
get_ofi_version(void)419 static int get_ofi_version(void)
420 {
421     if (MPIDI_OFI_MAJOR_VERSION != -1 && MPIDI_OFI_MINOR_VERSION != -1)
422         return FI_VERSION(MPIDI_OFI_MAJOR_VERSION, MPIDI_OFI_MINOR_VERSION);
423     else
424         return FI_VERSION(FI_MAJOR_VERSION, FI_MINOR_VERSION);
425 }
426 
conn_manager_init()427 static int conn_manager_init()
428 {
429     int mpi_errno = MPI_SUCCESS, i;
430 
431     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_CONN_MANAGER_INIT);
432     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_CONN_MANAGER_INIT);
433 
434     MPIDI_OFI_global.conn_mgr.max_n_conn = 1;
435     MPIDI_OFI_global.conn_mgr.next_conn_id = 0;
436     MPIDI_OFI_global.conn_mgr.n_conn = 0;
437 
438     MPIDI_OFI_global.conn_mgr.conn_list =
439         (MPIDI_OFI_conn_t *) MPL_malloc(8 * 4 * 1024 /* FIXME: what is this size? */ ,
440                                         MPL_MEM_ADDRESS);
441     MPIR_ERR_CHKANDSTMT(MPIDI_OFI_global.conn_mgr.conn_list == NULL, mpi_errno, MPI_ERR_NO_MEM,
442                         goto fn_fail, "**nomem");
443 
444     MPIDI_OFI_global.conn_mgr.free_conn_id =
445         (int *) MPL_malloc(MPIDI_OFI_global.conn_mgr.max_n_conn * sizeof(int), MPL_MEM_ADDRESS);
446     MPIR_ERR_CHKANDSTMT(MPIDI_OFI_global.conn_mgr.free_conn_id == NULL, mpi_errno,
447                         MPI_ERR_NO_MEM, goto fn_fail, "**nomem");
448 
449     for (i = 0; i < MPIDI_OFI_global.conn_mgr.max_n_conn; ++i) {
450         MPIDI_OFI_global.conn_mgr.free_conn_id[i] = i + 1;
451     }
452     MPIDI_OFI_global.conn_mgr.free_conn_id[MPIDI_OFI_global.conn_mgr.max_n_conn - 1] = -1;
453 
454   fn_exit:
455     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_CONN_MANAGER_INIT);
456     return mpi_errno;
457   fn_fail:
458     goto fn_exit;
459 }
460 
conn_manager_destroy()461 static int conn_manager_destroy()
462 {
463     int mpi_errno = MPI_SUCCESS, i, j;
464     MPIDI_OFI_dynamic_process_request_t *req;
465     fi_addr_t *conn;
466     int max_n_conn = MPIDI_OFI_global.conn_mgr.max_n_conn;
467     int *close_msg;
468     uint64_t match_bits = 0;
469     uint64_t mask_bits = 0;
470     MPIR_Context_id_t context_id = 0xF000;
471     MPIR_CHKLMEM_DECL(3);
472 
473     MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_CONN_MANAGER_DESTROY);
474     MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_CONN_MANAGER_DESTROY);
475 
476     match_bits = MPIDI_OFI_init_recvtag(&mask_bits, context_id, 1);
477     match_bits |= MPIDI_OFI_DYNPROC_SEND;
478 
479     if (max_n_conn > 0) {
480         /* try wait/close connections */
481         MPIR_CHKLMEM_MALLOC(req, MPIDI_OFI_dynamic_process_request_t *,
482                             max_n_conn * sizeof(MPIDI_OFI_dynamic_process_request_t), mpi_errno,
483                             "req", MPL_MEM_BUFFER);
484         MPIR_CHKLMEM_MALLOC(conn, fi_addr_t *, max_n_conn * sizeof(fi_addr_t), mpi_errno, "conn",
485                             MPL_MEM_BUFFER);
486         MPIR_CHKLMEM_MALLOC(close_msg, int *, max_n_conn * sizeof(int), mpi_errno, "int",
487                             MPL_MEM_BUFFER);
488 
489         j = 0;
490         for (i = 0; i < max_n_conn; ++i) {
491             switch (MPIDI_OFI_global.conn_mgr.conn_list[i].state) {
492                 case MPIDI_OFI_DYNPROC_CONNECTED_CHILD:
493                     mpi_errno = dynproc_send_disconnect(i);
494                     MPIR_ERR_CHECK(mpi_errno);
495                     break;
496                 case MPIDI_OFI_DYNPROC_LOCAL_DISCONNECTED_PARENT:
497                 case MPIDI_OFI_DYNPROC_CONNECTED_PARENT:
498                     MPL_DBG_MSG_FMT(MPIDI_CH4_DBG_GENERAL, VERBOSE,
499                                     (MPL_DBG_FDEST, "Wait for close of conn_id=%d", i));
500                     conn[j] = MPIDI_OFI_global.conn_mgr.conn_list[i].dest;
501                     req[j].done = 0;
502                     req[j].event_id = MPIDI_OFI_EVENT_DYNPROC_DONE;
503                     MPIDI_OFI_CALL_RETRY(fi_trecv(MPIDI_OFI_global.ctx[0].rx,
504                                                   &close_msg[j],
505                                                   sizeof(int),
506                                                   NULL,
507                                                   conn[j],
508                                                   match_bits,
509                                                   mask_bits, &req[j].context), 0, trecv, FALSE);
510                     j++;
511                     break;
512                 default:
513                     break;
514             }
515         }
516 
517         for (i = 0; i < j; ++i) {
518             MPIDI_OFI_PROGRESS_WHILE(!req[i].done, 0);
519             MPIDI_OFI_global.conn_mgr.conn_list[i].state = MPIDI_OFI_DYNPROC_DISCONNECTED;
520             MPL_DBG_MSG_FMT(MPIDI_CH4_DBG_GENERAL, VERBOSE,
521                             (MPL_DBG_FDEST, "conn_id=%d closed", i));
522         }
523 
524         MPIR_CHKLMEM_FREEALL();
525     }
526 
527     MPL_free(MPIDI_OFI_global.conn_mgr.conn_list);
528     MPL_free(MPIDI_OFI_global.conn_mgr.free_conn_id);
529 
530   fn_exit:
531     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_CONN_MANAGER_DESTROY);
532     return mpi_errno;
533   fn_fail:
534     goto fn_exit;
535 }
536 
/* Send a disconnect notification for connection 'conn_id' and update its
 * state.  Only a CONNECTED_CHILD actively sends the close message (the
 * parent side merely waits for it in conn_manager_destroy); afterwards both
 * child and parent states are downgraded to their LOCAL_DISCONNECTED
 * counterparts.  Returns MPI_SUCCESS or an MPI error code. */
static int dynproc_send_disconnect(int conn_id)
{
    int mpi_errno = MPI_SUCCESS;

    /* Reserved context id for dynamic-process control traffic; must match
     * the receive side in conn_manager_destroy. */
    MPIR_Context_id_t context_id = 0xF000;
    MPIDI_OFI_dynamic_process_request_t req;
    uint64_t match_bits = 0;
    unsigned int close_msg = 0xcccccccc;
    struct fi_msg_tagged msg;
    struct iovec msg_iov;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_DYNPROC_SEND_DISCONNECT);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_DYNPROC_SEND_DISCONNECT);

    if (MPIDI_OFI_global.conn_mgr.conn_list[conn_id].state == MPIDI_OFI_DYNPROC_CONNECTED_CHILD) {
        MPL_DBG_MSG_FMT(MPIDI_CH4_DBG_GENERAL, VERBOSE,
                        (MPL_DBG_FDEST, " send disconnect msg conn_id=%d from child side",
                         conn_id));
        match_bits = MPIDI_OFI_init_sendtag(context_id, 1, MPIDI_OFI_DYNPROC_SEND);

        /* fi_av_map here is not quite right for some providers */
        /* we need to get this connection from the sockname     */
        req.done = 0;
        req.event_id = MPIDI_OFI_EVENT_DYNPROC_DONE;
        msg_iov.iov_base = &close_msg;
        msg_iov.iov_len = sizeof(close_msg);
        msg.msg_iov = &msg_iov;
        msg.desc = NULL;
        /* NOTE(review): iov_count is 0, so the close_msg iovec is never
         * actually transmitted and the peer's trecv completes with a
         * zero-byte message -- confirm this is intended rather than 1. */
        msg.iov_count = 0;
        msg.addr = MPIDI_OFI_global.conn_mgr.conn_list[conn_id].dest;
        msg.tag = match_bits;
        /* NOTE(review): 'ignore' is a receive-side mask field; setting it to
         * context_id on a send looks like a leftover -- verify against the
         * libfabric fi_tagged contract. */
        msg.ignore = context_id;
        msg.context = (void *) &req.context;
        msg.data = 0;
        /* Blocking send: retry on EAGAIN, then spin progress until the
         * completion event flips req.done. */
        MPIDI_OFI_CALL_RETRY(fi_tsendmsg(MPIDI_OFI_global.ctx[0].tx, &msg,
                                         FI_COMPLETION | FI_TRANSMIT_COMPLETE | FI_REMOTE_CQ_DATA),
                             0, tsendmsg, FALSE);
        MPIDI_OFI_PROGRESS_WHILE(!req.done, 0);
    }

    /* Downgrade the connection state to locally-disconnected. */
    switch (MPIDI_OFI_global.conn_mgr.conn_list[conn_id].state) {
        case MPIDI_OFI_DYNPROC_CONNECTED_CHILD:
            MPIDI_OFI_global.conn_mgr.conn_list[conn_id].state =
                MPIDI_OFI_DYNPROC_LOCAL_DISCONNECTED_CHILD;
            break;
        case MPIDI_OFI_DYNPROC_CONNECTED_PARENT:
            MPIDI_OFI_global.conn_mgr.conn_list[conn_id].state =
                MPIDI_OFI_DYNPROC_LOCAL_DISCONNECTED_PARENT;
            break;
        default:
            break;
    }

    MPL_DBG_MSG_FMT(MPIDI_CH4_DBG_GENERAL, VERBOSE,
                    (MPL_DBG_FDEST, " local_disconnected conn_id=%d state=%d",
                     conn_id, MPIDI_OFI_global.conn_mgr.conn_list[conn_id].state));

  fn_exit:
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_DYNPROC_SEND_DISCONNECT);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
600 
/* Netmod init hook: bring up the OFI (libfabric) netmod.
 *
 * Opens the fabric, creates one transport context per VNI, performs the
 * address exchange, and initializes active messages, the pack-buffer pool,
 * the RMA key allocator, and the dynamic-process connection manager.
 *
 * Parameters:
 *   rank, size, appnum - standard init information (not used directly here)
 *   tag_bits  - [out] set to MPIDI_OFI_TAG_BITS on every exit path
 *   init_comm - communicator used for the root-VNI address exchange
 *
 * Fix: the return value of conn_manager_init() was silently discarded;
 * it is now checked like every other fallible call in this function.
 *
 * Returns MPI_SUCCESS or an MPI error code. */
int MPIDI_OFI_mpi_init_hook(int rank, int size, int appnum, int *tag_bits, MPIR_Comm * init_comm)
{
    int mpi_errno = MPI_SUCCESS, i;
    size_t optlen;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_OFI_MPI_INIT_HOOK);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_OFI_MPI_INIT_HOOK);

    /* All netmod request variants must overlay the generic CH4 request at
     * the same offset so completion handlers can decode them uniformly. */
    MPL_COMPILE_TIME_ASSERT(offsetof(struct MPIR_Request, dev.ch4.netmod) ==
                            offsetof(MPIDI_OFI_chunk_request, context));
    MPL_COMPILE_TIME_ASSERT(offsetof(struct MPIR_Request, dev.ch4.netmod) ==
                            offsetof(MPIDI_OFI_huge_recv_t, context));
    MPL_COMPILE_TIME_ASSERT(offsetof(struct MPIR_Request, dev.ch4.netmod) ==
                            offsetof(MPIDI_OFI_am_repost_request_t, context));
    MPL_COMPILE_TIME_ASSERT(offsetof(struct MPIR_Request, dev.ch4.netmod) ==
                            offsetof(MPIDI_OFI_ssendack_request_t, context));
    MPL_COMPILE_TIME_ASSERT(offsetof(struct MPIR_Request, dev.ch4.netmod) ==
                            offsetof(MPIDI_OFI_dynamic_process_request_t, context));
    MPL_COMPILE_TIME_ASSERT(offsetof(struct MPIR_Request, dev.ch4.am.netmod_am.ofi.context) ==
                            offsetof(struct MPIR_Request, dev.ch4.netmod.ofi.context));
    MPL_COMPILE_TIME_ASSERT(sizeof(MPIDI_Devreq_t) >= sizeof(MPIDI_OFI_request_t));

    int err;
    MPID_Thread_mutex_create(&MPIDI_OFI_THREAD_UTIL_MUTEX, &err);
    MPIR_Assert(err == 0);

    MPID_Thread_mutex_create(&MPIDI_OFI_THREAD_PROGRESS_MUTEX, &err);
    MPIR_Assert(err == 0);

    MPID_Thread_mutex_create(&MPIDI_OFI_THREAD_FI_MUTEX, &err);
    MPIR_Assert(err == 0);

    MPID_Thread_mutex_create(&MPIDI_OFI_THREAD_SPAWN_MUTEX, &err);
    MPIR_Assert(err == 0);

    mpi_errno = open_fabric();
    MPIR_ERR_CHECK(mpi_errno);

    /* ------------------------------------------------------------------------ */
    /* Create transport level communication contexts.                           */
    /* ------------------------------------------------------------------------ */

    /* Clamp the requested number of VNIs to the CH4 VCI count. */
    int num_vnis = 1;
    if (MPIR_CVAR_CH4_OFI_MAX_VNIS == 0 || MPIR_CVAR_CH4_OFI_MAX_VNIS > MPIDI_global.n_vcis) {
        num_vnis = MPIDI_global.n_vcis;
    } else {
        num_vnis = MPIR_CVAR_CH4_OFI_MAX_VNIS;
    }

    /* TODO: update num_vnis according to provider capabilities, such as
     * prov_use->domain_attr->{tx,rx}_ctx_cnt
     */
    if (num_vnis > MPIDI_OFI_MAX_VNIS) {
        num_vnis = MPIDI_OFI_MAX_VNIS;
    }
    /* for best performance, we ensure 1-to-1 vci/vni mapping. ref: MPIDI_OFI_vci_to_vni */
    /* TODO: allow less num_vnis. Option 1. runtime MOD; 2. overide MPIDI_global.n_vcis */
    MPIR_Assert(num_vnis == MPIDI_global.n_vcis);

    /* Multiple vni without using domain require MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS */
#ifndef MPIDI_OFI_VNI_USE_DOMAIN
    MPIR_Assert(num_vnis == 1 || MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS);
#endif

    /* WorkQ only works with single vni for now */
#ifdef MPIDI_CH4_USE_WORK_QUEUES
    MPIR_Assert(num_vnis == 1);
#endif

    MPIDI_OFI_global.num_vnis = num_vnis;

    /* Create MPIDI_OFI_global.ctx[0] first  */
    mpi_errno = create_vni_context(0);
    MPIR_ERR_CHECK(mpi_errno);

    /* Creating the additional vni contexts.
     * This code maybe moved to a later stage */
    for (i = 1; i < MPIDI_OFI_global.num_vnis; i++) {
        mpi_errno = create_vni_context(i);
        MPIR_ERR_CHECK(mpi_errno);
    }

    /* ------------------------------------------------------------------------ */
    /* Address exchange (essentially activating the vnis)                       */
    /* ------------------------------------------------------------------------ */

    /* A named AV (restored from a previous job) already carries addresses. */
    if (!MPIDI_OFI_global.got_named_av) {
        mpi_errno = addr_exchange_root_vni(init_comm);
        MPIR_ERR_CHECK(mpi_errno);
    }

    /* -------------------------------- */
    /* Create the id to object maps     */
    /* -------------------------------- */
    MPIDIU_map_create(&MPIDI_OFI_global.win_map, MPL_MEM_RMA);
    MPIDIU_map_create(&MPIDI_OFI_global.req_map, MPL_MEM_OTHER);

    /* ---------------------------------- */
    /* Initialize Active Message          */
    /* ---------------------------------- */
    if (MPIDI_OFI_ENABLE_AM) {
        /* Maximum possible message size for short message send (=eager send)
         * See MPIDI_OFI_do_am_isend for short/long switching logic */
        MPIR_Assert(MPIDI_OFI_DEFAULT_SHORT_SEND_SIZE <= MPIDI_OFI_global.max_msg_size);
        MPL_COMPILE_TIME_ASSERT(sizeof(MPIDI_OFI_am_request_header_t)
                                < MPIDI_OFI_AM_HDR_POOL_CELL_SIZE);
        MPL_COMPILE_TIME_ASSERT(MPIDI_OFI_AM_HDR_POOL_CELL_SIZE
                                >= sizeof(MPIDI_OFI_am_send_pipeline_request_t));
        mpi_errno =
            MPIDU_genq_private_pool_create_unsafe(MPIDI_OFI_AM_HDR_POOL_CELL_SIZE,
                                                  MPIDI_OFI_AM_HDR_POOL_NUM_CELLS_PER_CHUNK,
                                                  MPIDI_OFI_AM_HDR_POOL_MAX_NUM_CELLS,
                                                  host_alloc, host_free,
                                                  &MPIDI_OFI_global.am_hdr_buf_pool);
        MPIR_ERR_CHECK(mpi_errno);

        MPIDI_OFI_global.cq_buffered_dynamic_head = MPIDI_OFI_global.cq_buffered_dynamic_tail =
            NULL;
        MPIDI_OFI_global.cq_buffered_static_head = MPIDI_OFI_global.cq_buffered_static_tail = 0;
        optlen = MPIDI_OFI_DEFAULT_SHORT_SEND_SIZE;

        /* Keep multi-recv buffers posted while at least optlen bytes remain. */
        MPIDI_OFI_CALL(fi_setopt(&(MPIDI_OFI_global.ctx[0].rx->fid),
                                 FI_OPT_ENDPOINT,
                                 FI_OPT_MIN_MULTI_RECV, &optlen, sizeof(optlen)), setopt);

        MPIDIU_map_create(&MPIDI_OFI_global.am_recv_seq_tracker, MPL_MEM_BUFFER);
        MPIDIU_map_create(&MPIDI_OFI_global.am_send_seq_tracker, MPL_MEM_BUFFER);
        MPIDI_OFI_global.am_unordered_msgs = NULL;

        /* Pre-post the multi-receive buffers for incoming active messages. */
        for (i = 0; i < MPIDI_OFI_NUM_AM_BUFFERS; i++) {
            MPL_gpu_malloc_host(&(MPIDI_OFI_global.am_bufs[i]), MPIDI_OFI_AM_BUFF_SZ);
            MPIDI_OFI_global.am_reqs[i].event_id = MPIDI_OFI_EVENT_AM_RECV;
            MPIDI_OFI_global.am_reqs[i].index = i;
            MPIDI_OFI_global.am_iov[i].iov_base = MPIDI_OFI_global.am_bufs[i];
            MPIDI_OFI_global.am_iov[i].iov_len = MPIDI_OFI_AM_BUFF_SZ;
            MPIDI_OFI_global.am_msg[i].msg_iov = &MPIDI_OFI_global.am_iov[i];
            MPIDI_OFI_global.am_msg[i].desc = NULL;
            MPIDI_OFI_global.am_msg[i].addr = FI_ADDR_UNSPEC;
            MPIDI_OFI_global.am_msg[i].context = &MPIDI_OFI_global.am_reqs[i].context;
            MPIDI_OFI_global.am_msg[i].iov_count = 1;
            MPIDI_OFI_CALL_RETRY(fi_recvmsg(MPIDI_OFI_global.ctx[0].rx,
                                            &MPIDI_OFI_global.am_msg[i],
                                            FI_MULTI_RECV | FI_COMPLETION), 0, prepost, FALSE);
        }

        MPIDIG_am_reg_cb(MPIDI_OFI_INTERNAL_HANDLER_CONTROL, NULL, &MPIDI_OFI_control_handler);
    }
    MPL_atomic_store_int(&MPIDI_OFI_global.am_inflight_inject_emus, 0);
    MPL_atomic_store_int(&MPIDI_OFI_global.am_inflight_rma_send_mrs, 0);

    /* Create pack buffer pool */
    mpi_errno =
        MPIDU_genq_private_pool_create_unsafe(MPIDI_OFI_DEFAULT_SHORT_SEND_SIZE,
                                              MPIR_CVAR_CH4_OFI_NUM_PACK_BUFFERS_PER_CHUNK,
                                              MPIR_CVAR_CH4_OFI_MAX_NUM_PACK_BUFFERS,
                                              host_alloc_registered,
                                              host_free_registered,
                                              &MPIDI_OFI_global.pack_buf_pool);
    MPIR_ERR_CHECK(mpi_errno);

    /* Initalize RMA keys allocator */
    MPIDI_OFI_mr_key_allocator_init();

    /* ------------------------------------------------- */
    /* Initialize Connection Manager for Dynamic Tasking */
    /* ------------------------------------------------- */
    mpi_errno = conn_manager_init();
    MPIR_ERR_CHECK(mpi_errno);

    MPIR_Comm_register_hint(MPIR_COMM_HINT_EAGAIN, "eagain", NULL, MPIR_COMM_HINT_TYPE_BOOL, 0);

    /* index datatypes for RMA atomics */
    MPIDI_OFI_index_datatypes();

    MPIDI_OFI_global.deferred_am_isend_q = NULL;

  fn_exit:
    /* tag_bits is reported on both success and failure paths. */
    *tag_bits = MPIDI_OFI_TAG_BITS;

    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_OFI_MPI_INIT_HOOK);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
784 
/* Netmod finalize hook: tears down everything the OFI init hook set up, in
 * reverse dependency order -- dynamic-process connections, in-flight AM
 * traffic, the RMA key allocator, VNI contexts/endpoints, the fabric,
 * lookup maps, AM buffers and pools, and finally the module mutexes.
 *
 * Returns MPI_SUCCESS or an MPI error code.  The MPIDI_OFI_CALL and
 * MPIR_ERR_CHECK macros branch to fn_fail on error. */
int MPIDI_OFI_mpi_finalize_hook(void)
{
    int mpi_errno = MPI_SUCCESS;
    int i = 0;
    int barrier[2] = { 0 };
    MPIR_Errflag_t errflag = MPIR_ERR_NONE;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_OFI_MPI_FINALIZE_HOOK);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_OFI_MPI_FINALIZE_HOOK);

    /* clean dynamic process connections */
    mpi_errno = conn_manager_destroy();
    MPIR_ERR_CHECK(mpi_errno);

    /* Progress until we drain all inflight RMA send long buffers */
    /* NOTE: am currently only use vni 0. Need update once that changes */
    while (MPL_atomic_load_int(&MPIDI_OFI_global.am_inflight_rma_send_mrs) > 0)
        MPIDI_OFI_PROGRESS(0);

    /* Destroy RMA key allocator */
    MPIDI_OFI_mr_key_allocator_destroy();

    /* Barrier over allreduce, but force non-immediate send */
    /* Zeroing max_buffered_send forces the allreduce traffic off the inject
     * path, so completing it guarantees peers have progressed before we
     * start tearing endpoints down. */
    MPIDI_OFI_global.max_buffered_send = 0;
    mpi_errno = MPIR_Allreduce_allcomm_auto(&barrier[0], &barrier[1], 1, MPI_INT, MPI_SUM,
                                            MPIR_Process.comm_world, &errflag);
    MPIR_ERR_CHECK(mpi_errno);

    /* Progress until we drain all inflight injection emulation requests */
    /* NOTE: am currently only use vni 0. Need update once that changes */
    while (MPL_atomic_load_int(&MPIDI_OFI_global.am_inflight_inject_emus) > 0)
        MPIDI_OFI_PROGRESS(0);
    MPIR_Assert(MPL_atomic_load_int(&MPIDI_OFI_global.am_inflight_inject_emus) == 0);

    /* Tearing down endpoints */
    for (i = 1; i < MPIDI_OFI_global.num_vnis; i++) {
        mpi_errno = destroy_vni_context(i);
        MPIR_ERR_CHECK(mpi_errno);
    }
    /* 0th ctx is special, synonymous to global context */
    /* Destroyed last because the other VNIs may share its domain/av/ep. */
    mpi_errno = destroy_vni_context(0);
    MPIR_ERR_CHECK(mpi_errno);

    MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.fabric->fid), fabricclose);

    fi_freeinfo(MPIDI_OFI_global.prov_use);

    MPIDIU_map_destroy(MPIDI_OFI_global.win_map);
    MPIDIU_map_destroy(MPIDI_OFI_global.req_map);

    if (MPIDI_OFI_ENABLE_AM) {
        /* Drop any still-queued unordered AM messages before freeing the
         * sequence trackers that ordered them. */
        while (MPIDI_OFI_global.am_unordered_msgs) {
            MPIDI_OFI_am_unordered_msg_t *uo_msg = MPIDI_OFI_global.am_unordered_msgs;
            DL_DELETE(MPIDI_OFI_global.am_unordered_msgs, uo_msg);
        }
        MPIDIU_map_destroy(MPIDI_OFI_global.am_send_seq_tracker);
        MPIDIU_map_destroy(MPIDI_OFI_global.am_recv_seq_tracker);

        for (i = 0; i < MPIDI_OFI_NUM_AM_BUFFERS; i++)
            MPL_gpu_free_host(MPIDI_OFI_global.am_bufs[i]);

        MPIDU_genq_private_pool_destroy_unsafe(MPIDI_OFI_global.am_hdr_buf_pool);

        /* Both buffered-CQ queues must have been fully drained by now. */
        MPIR_Assert(MPIDI_OFI_global.cq_buffered_static_head ==
                    MPIDI_OFI_global.cq_buffered_static_tail);
        MPIR_Assert(NULL == MPIDI_OFI_global.cq_buffered_dynamic_head);
    }

    MPIDU_genq_private_pool_destroy_unsafe(MPIDI_OFI_global.pack_buf_pool);

    int err;
    MPID_Thread_mutex_destroy(&MPIDI_OFI_THREAD_UTIL_MUTEX, &err);
    MPIR_Assert(err == 0);

    MPID_Thread_mutex_destroy(&MPIDI_OFI_THREAD_PROGRESS_MUTEX, &err);
    MPIR_Assert(err == 0);

    MPID_Thread_mutex_destroy(&MPIDI_OFI_THREAD_FI_MUTEX, &err);
    MPIR_Assert(err == 0);

    MPID_Thread_mutex_destroy(&MPIDI_OFI_THREAD_SPAWN_MUTEX, &err);
    MPIR_Assert(err == 0);

  fn_exit:
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_OFI_MPI_FINALIZE_HOOK);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
874 
MPIDI_OFI_post_init(void)875 int MPIDI_OFI_post_init(void)
876 {
877     int mpi_errno = MPI_SUCCESS;
878     if (MPIDI_OFI_global.num_vnis > 1) {
879         mpi_errno = addr_exchange_all_vnis();
880     }
881     return mpi_errno;
882 }
883 
/* Report the capabilities of a VCI.  The OFI netmod currently exposes only
 * VCI 0, which supports both transmit and receive. */
int MPIDI_OFI_get_vci_attr(int vci)
{
    MPIR_Assert(0 <= vci && vci < 1);
    return MPIDI_VCI_TX | MPIDI_VCI_RX;
}
889 
/* MPI_Alloc_mem hook: the OFI netmod adds nothing of its own and simply
 * delegates to the generic CH4 implementation. */
void *MPIDI_OFI_mpi_alloc_mem(size_t size, MPIR_Info * info_ptr)
{
    void *buf = MPIDIG_mpi_alloc_mem(size, info_ptr);
    return buf;
}
894 
/* MPI_Free_mem hook: counterpart of MPIDI_OFI_mpi_alloc_mem; delegates to
 * the generic CH4 implementation and returns its error code. */
int MPIDI_OFI_mpi_free_mem(void *ptr)
{
    int rc = MPIDIG_mpi_free_mem(ptr);
    return rc;
}
899 
900 /* ---- static functions for vni contexts ---- */
901 static int create_vni_domain(struct fid_domain **p_domain, struct fid_av **p_av,
902                              struct fid_cntr **p_cntr);
903 static int create_cq(struct fid_domain *domain, struct fid_cq **p_cq);
904 static int create_sep_tx(struct fid_ep *ep, int idx, struct fid_ep **p_tx,
905                          struct fid_cq *cq, struct fid_cntr *cntr);
906 static int create_sep_rx(struct fid_ep *ep, int idx, struct fid_ep **p_rx,
907                          struct fid_cq *cq, struct fid_cntr *cntr);
908 static int try_open_shared_av(struct fid_domain *domain, struct fid_av **p_av);
909 static int create_rma_stx_ctx(struct fid_domain *domain, struct fid_stx **p_rma_stx_ctx);
910 
/* Create the VNI context numbered 'vni': its domain, address vector,
 * completion queue, RMA completion counter, and tx/rx endpoints.  Which of
 * these are private versus shared with VNI 0 depends on the compile-time
 * MPIDI_OFI_VNI_USE_DOMAIN setting; see the comment block below.
 *
 * Returns MPI_SUCCESS or an MPI error code (macros branch to fn_fail). */
static int create_vni_context(int vni)
{
    int mpi_errno = MPI_SUCCESS;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_CREATE_VNI_CONTEXT);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_CREATE_VNI_CONTEXT);

    struct fi_info *prov_use = MPIDI_OFI_global.prov_use;

    /* Each VNI context consists of domain, av, cq, cntr, etc.
     *
     * If MPIDI_OFI_VNI_USE_DOMAIN is true, each context is a separate domain,
     * within which are each separate av, cq, cntr, ..., everything. Within the
     * VNI context, it still can use either simple endpoint or scalable endpoint.
     *
     * If MPIDI_OFI_VNI_USE_DOMAIN is false, then all the VNI contexts will share
     * the same domain and av, and use a single scalable endpoint. Separate VNI
     * context will have its separate cq and separate tx and rx with the SEP.
     *
     * To accomodate both configurations, each context structure will have all fields
     * including domain, av, cq, ... For "VNI_USE_DOMAIN", they are not shared.
     * When not "VNI_USE_DOMAIN" or "VNI_USE_SEPCTX", domain, av, and ep are shared
     * with the root (or 0th) VNI context.
     */
    struct fid_domain *domain;
    struct fid_av *av;
    struct fid_cntr *rma_cmpl_cntr;
    struct fid_cq *cq;

    struct fid_ep *ep;
    struct fid_ep *tx;
    struct fid_ep *rx;

#ifdef MPIDI_OFI_VNI_USE_DOMAIN
    /* Fully private context: own domain, av, cntr, and cq. */
    mpi_errno = create_vni_domain(&domain, &av, &rma_cmpl_cntr);
    MPIR_ERR_CHECK(mpi_errno);
    mpi_errno = create_cq(domain, &cq);
    MPIR_ERR_CHECK(mpi_errno);

    if (MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS) {
        /* Scalable EP: tx/rx are separate contexts carved out of the SEP. */
        MPIDI_OFI_CALL(fi_scalable_ep(domain, prov_use, &ep, NULL), ep);
        MPIDI_OFI_CALL(fi_scalable_ep_bind(ep, &av->fid, 0), bind);
        MPIDI_OFI_CALL(fi_enable(ep), ep_enable);

        mpi_errno = create_sep_tx(ep, 0, &tx, cq, rma_cmpl_cntr);
        MPIR_ERR_CHECK(mpi_errno);
        mpi_errno = create_sep_rx(ep, 0, &rx, cq, rma_cmpl_cntr);
        MPIR_ERR_CHECK(mpi_errno);
    } else {
        /* Simple EP: one endpoint serves as both tx and rx. */
        MPIDI_OFI_CALL(fi_endpoint(domain, prov_use, &ep, NULL), ep);
        MPIDI_OFI_CALL(fi_ep_bind(ep, &av->fid, 0), bind);
        MPIDI_OFI_CALL(fi_ep_bind(ep, &cq->fid, FI_SEND | FI_RECV | FI_SELECTIVE_COMPLETION), bind);
        MPIDI_OFI_CALL(fi_ep_bind(ep, &rma_cmpl_cntr->fid, FI_READ | FI_WRITE), bind);
        MPIDI_OFI_CALL(fi_enable(ep), ep_enable);
        tx = ep;
        rx = ep;
    }
    MPIDI_OFI_global.ctx[vni].domain = domain;
    MPIDI_OFI_global.ctx[vni].av = av;
    MPIDI_OFI_global.ctx[vni].rma_cmpl_cntr = rma_cmpl_cntr;
    MPIDI_OFI_global.ctx[vni].ep = ep;
    MPIDI_OFI_global.ctx[vni].cq = cq;
    MPIDI_OFI_global.ctx[vni].tx = tx;
    MPIDI_OFI_global.ctx[vni].rx = rx;

#else /* MPIDI_OFI_VNI_USE_SEPCTX */
    /* Shared mode: only VNI 0 creates the domain/av/cntr; others reuse them. */
    if (vni == 0) {
        mpi_errno = create_vni_domain(&domain, &av, &rma_cmpl_cntr);
        MPIR_ERR_CHECK(mpi_errno);
    } else {
        domain = MPIDI_OFI_global.ctx[0].domain;
        av = MPIDI_OFI_global.ctx[0].av;
        rma_cmpl_cntr = MPIDI_OFI_global.ctx[0].rma_cmpl_cntr;
    }
    /* Every VNI still gets its own completion queue. */
    mpi_errno = create_cq(domain, &cq);
    MPIR_ERR_CHECK(mpi_errno);

    if (vni == 0) {
        if (MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS) {
            MPIDI_OFI_CALL(fi_scalable_ep(domain, prov_use, &ep, NULL), ep);
            MPIDI_OFI_CALL(fi_scalable_ep_bind(ep, &av->fid, 0), bind);
            MPIDI_OFI_CALL(fi_enable(ep), ep_enable);
        } else {
            MPIDI_OFI_CALL(fi_endpoint(domain, prov_use, &ep, NULL), ep);
            MPIDI_OFI_CALL(fi_ep_bind(ep, &av->fid, 0), bind);
            MPIDI_OFI_CALL(fi_ep_bind(ep, &cq->fid, FI_SEND | FI_RECV | FI_SELECTIVE_COMPLETION),
                           bind);
            MPIDI_OFI_CALL(fi_ep_bind(ep, &rma_cmpl_cntr->fid, FI_READ | FI_WRITE), bind);
            MPIDI_OFI_CALL(fi_enable(ep), ep_enable);
        }
    } else {
        ep = MPIDI_OFI_global.ctx[0].ep;
    }

    if (MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS) {
        /* Per-VNI tx/rx contexts indexed by 'vni' within the shared SEP. */
        mpi_errno = create_sep_tx(ep, vni, &tx, cq, rma_cmpl_cntr);
        MPIR_ERR_CHECK(mpi_errno);
        mpi_errno = create_sep_rx(ep, vni, &rx, cq, rma_cmpl_cntr);
        MPIR_ERR_CHECK(mpi_errno);
    } else {
        tx = ep;
        rx = ep;
    }

    if (vni == 0) {
        MPIDI_OFI_global.ctx[0].domain = domain;
        MPIDI_OFI_global.ctx[0].av = av;
        MPIDI_OFI_global.ctx[0].rma_cmpl_cntr = rma_cmpl_cntr;
        MPIDI_OFI_global.ctx[0].ep = ep;
    }
    MPIDI_OFI_global.ctx[vni].cq = cq;
    MPIDI_OFI_global.ctx[vni].tx = tx;
    MPIDI_OFI_global.ctx[vni].rx = rx;
#endif

    /* ------------------------------------------------------------------------ */
    /* Construct:  Shared TX Context for RMA                                    */
    /* ------------------------------------------------------------------------ */
    if (vni == 0) {
        mpi_errno = create_rma_stx_ctx(domain, &MPIDI_OFI_global.rma_stx_ctx);
        MPIR_ERR_CHECK(mpi_errno);
    }

  fn_exit:
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_CREATE_VNI_CONTEXT);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
1040 
1041 /* ---------------------------------------------------------- */
1042 /* Provider Selections and open_fabric()                      */
1043 /* ---------------------------------------------------------- */
1044 static int find_provider(struct fi_info *hints);
1045 static void update_global_settings(struct fi_info *prov, struct fi_info *hints);
1046 static void dump_global_settings(void);
1047 
1048 /* set MPIDI_OFI_global.settings based on provider-set */
1049 static void init_global_settings(const char *prov_name);
1050 /* set hints based on MPIDI_OFI_global.settings */
1051 static void init_hints(struct fi_info *hints);
1052 /* whether prov matches MPIDI_OFI_global.settings */
1053 bool match_global_settings(struct fi_info *prov);
1054 /* picks one matching provider from the list or return NULL */
1055 static struct fi_info *pick_provider_from_list(const char *provname, struct fi_info *prov_list);
1056 static struct fi_info *pick_provider_by_name(const char *provname, struct fi_info *prov_list);
1057 static struct fi_info *pick_provider_by_global_settings(struct fi_info *prov_list);
1058 
destroy_vni_context(int vni)1059 static int destroy_vni_context(int vni)
1060 {
1061     int mpi_errno = MPI_SUCCESS;
1062 
1063 #ifdef MPIDI_OFI_VNI_USE_DOMAIN
1064     if (MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS) {
1065         MPIDI_OFI_CALL(fi_close((fid_t) MPIDI_OFI_global.ctx[vni].tx), epclose);
1066         MPIDI_OFI_CALL(fi_close((fid_t) MPIDI_OFI_global.ctx[vni].rx), epclose);
1067         MPIDI_OFI_CALL(fi_close((fid_t) MPIDI_OFI_global.ctx[vni].cq), cqclose);
1068 
1069         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].ep->fid), epclose);
1070         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].av->fid), avclose);
1071         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].rma_cmpl_cntr->fid), cntrclose);
1072         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].domain->fid), domainclose);
1073     } else {    /* normal endpoint */
1074         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].ep->fid), epclose);
1075         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].cq->fid), cqclose);
1076         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].av->fid), avclose);
1077         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].rma_cmpl_cntr->fid), cntrclose);
1078         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].domain->fid), domainclose);
1079     }
1080 
1081 #else /* MPIDI_OFI_VNI_USE_SEPCTX */
1082     if (MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS) {
1083         MPIDI_OFI_CALL(fi_close((fid_t) MPIDI_OFI_global.ctx[vni].tx), epclose);
1084         MPIDI_OFI_CALL(fi_close((fid_t) MPIDI_OFI_global.ctx[vni].rx), epclose);
1085         MPIDI_OFI_CALL(fi_close((fid_t) MPIDI_OFI_global.ctx[vni].cq), cqclose);
1086         if (vni == 0) {
1087             MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].ep->fid), epclose);
1088             MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].av->fid), avclose);
1089             MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].rma_cmpl_cntr->fid), cntrclose);
1090             MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].domain->fid), domainclose);
1091         }
1092     } else {    /* normal endpoint */
1093         MPIR_Assert(vni == 0);
1094         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].ep->fid), epclose);
1095         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].cq->fid), cqclose);
1096         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].av->fid), avclose);
1097         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].rma_cmpl_cntr->fid), cntrclose);
1098         MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.ctx[vni].domain->fid), domainclose);
1099     }
1100 #endif
1101     if (vni == 0) {
1102         /* Close RMA scalable EP. */
1103         if (MPIDI_OFI_global.rma_sep) {
1104             /* All transmit contexts on RMA must be closed. */
1105             MPIR_Assert(utarray_len(MPIDI_OFI_global.rma_sep_idx_array) ==
1106                         MPIDI_OFI_global.max_rma_sep_tx_cnt);
1107             utarray_free(MPIDI_OFI_global.rma_sep_idx_array);
1108             MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.rma_sep->fid), epclose);
1109         }
1110 
1111         if (MPIDI_OFI_global.rma_stx_ctx != NULL) {
1112             MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_global.rma_stx_ctx->fid), stx_ctx_close);
1113         }
1114     }
1115 
1116   fn_exit:
1117     MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_DESTROY_VNI_CONTEXT);
1118     return mpi_errno;
1119   fn_fail:
1120     goto fn_exit;
1121 }
1122 
/* Open a domain on the global fabric plus the objects that can be shared at
 * domain scope: an address vector (named/shared if available, otherwise a
 * fresh one) and an RMA completion counter.
 *
 * Outputs: *p_domain, *p_av, *p_cntr.  Returns MPI_SUCCESS or an MPI error
 * code (macros branch to fn_fail). */
static int create_vni_domain(struct fid_domain **p_domain, struct fid_av **p_av,
                             struct fid_cntr **p_cntr)
{
    int mpi_errno = MPI_SUCCESS;

    /* ---- domain ---- */
    struct fid_domain *domain;
    MPIDI_OFI_CALL(fi_domain(MPIDI_OFI_global.fabric, MPIDI_OFI_global.prov_use, &domain, NULL),
                   opendomain);
    *p_domain = domain;

    /* ---- av ---- */
    /* ----
     * Attempt to open a shared address vector read-only.
     * The open will fail if the address vector does not exist.
     * Otherwise, set MPIDI_OFI_global.got_named_av and
     * copy the map_addr.
     */
    if (try_open_shared_av(domain, p_av)) {
        MPIDI_OFI_global.got_named_av = 1;
    }

    if (!MPIDI_OFI_global.got_named_av) {
        /* No shared AV available: open a private one. */
        struct fi_av_attr av_attr;
        memset(&av_attr, 0, sizeof(av_attr));
        if (MPIDI_OFI_ENABLE_AV_TABLE) {
            av_attr.type = FI_AV_TABLE;
        } else {
            av_attr.type = FI_AV_MAP;
        }
        /* Reserve address bits for scalable-endpoint rx-context indexing. */
        av_attr.rx_ctx_bits = MPIDI_OFI_MAX_ENDPOINTS_BITS;
        av_attr.count = MPIR_Process.size;

        av_attr.name = NULL;
        av_attr.flags = 0;
        MPIDI_OFI_CALL(fi_av_open(domain, &av_attr, p_av, NULL), avopen);
    }

    /* ---- other sharable objects ---- */
    struct fi_cntr_attr cntr_attr;
    memset(&cntr_attr, 0, sizeof(cntr_attr));
    cntr_attr.events = FI_CNTR_EVENTS_COMP;
    cntr_attr.wait_obj = FI_WAIT_UNSPEC;
    MPIDI_OFI_CALL(fi_cntr_open(domain, &cntr_attr, p_cntr, NULL), openct);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
1173 
create_cq(struct fid_domain * domain,struct fid_cq ** p_cq)1174 static int create_cq(struct fid_domain *domain, struct fid_cq **p_cq)
1175 {
1176     int mpi_errno = MPI_SUCCESS;
1177     struct fi_cq_attr cq_attr;
1178     memset(&cq_attr, 0, sizeof(cq_attr));
1179     cq_attr.format = FI_CQ_FORMAT_TAGGED;
1180     MPIDI_OFI_CALL(fi_cq_open(domain, &cq_attr, p_cq, NULL), opencq);
1181 
1182   fn_exit:
1183     return mpi_errno;
1184   fn_fail:
1185     goto fn_exit;
1186 }
1187 
/* Create transmit context number 'idx' on scalable endpoint 'ep', bind it to
 * the given cq (send completions, selective) and counter (RMA read/write
 * counting), enable it, and return it in *p_tx.
 *
 * Returns MPI_SUCCESS or an MPI error code (macros branch to fn_fail). */
static int create_sep_tx(struct fid_ep *ep, int idx, struct fid_ep **p_tx,
                         struct fid_cq *cq, struct fid_cntr *cntr)
{
    int mpi_errno = MPI_SUCCESS;

    /* Start from the provider's negotiated tx attributes, then narrow the
     * capabilities to exactly what this netmod uses. */
    struct fi_tx_attr tx_attr;
    tx_attr = *(MPIDI_OFI_global.prov_use->tx_attr);
    tx_attr.op_flags = FI_COMPLETION;
    if (MPIDI_OFI_ENABLE_RMA || MPIDI_OFI_ENABLE_ATOMICS)
        tx_attr.op_flags |= FI_DELIVERY_COMPLETE;
    tx_attr.caps = 0;

    if (MPIDI_OFI_ENABLE_TAGGED)
        tx_attr.caps = FI_TAGGED;

    /* RMA */
    if (MPIDI_OFI_ENABLE_RMA)
        tx_attr.caps |= FI_RMA;
    if (MPIDI_OFI_ENABLE_ATOMICS)
        tx_attr.caps |= FI_ATOMICS;
    /* MSG */
    tx_attr.caps |= FI_MSG;
    tx_attr.caps |= FI_NAMED_RX_CTX;    /* Required for scalable endpoints indexing */

    MPIDI_OFI_CALL(fi_tx_context(ep, idx, &tx_attr, p_tx, NULL), ep);
    MPIDI_OFI_CALL(fi_ep_bind(*p_tx, &cq->fid, FI_SEND | FI_SELECTIVE_COMPLETION), bind);
    MPIDI_OFI_CALL(fi_ep_bind(*p_tx, &cntr->fid, FI_WRITE | FI_READ), bind);
    MPIDI_OFI_CALL(fi_enable(*p_tx), ep_enable);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
1222 
/* Create receive context number 'idx' on scalable endpoint 'ep', bind it to
 * the given cq for receive completions, enable it, and return it in *p_rx.
 * The 'cntr' parameter is currently unused on the rx side.
 *
 * Returns MPI_SUCCESS or an MPI error code (macros branch to fn_fail). */
static int create_sep_rx(struct fid_ep *ep, int idx, struct fid_ep **p_rx,
                         struct fid_cq *cq, struct fid_cntr *cntr)
{
    int mpi_errno = MPI_SUCCESS;

    /* Start from the provider's negotiated rx attributes, then narrow the
     * capabilities to exactly what this netmod uses. */
    struct fi_rx_attr rx_attr;
    rx_attr = *(MPIDI_OFI_global.prov_use->rx_attr);
    rx_attr.caps = 0;

    if (MPIDI_OFI_ENABLE_TAGGED) {
        rx_attr.caps |= FI_TAGGED;
        rx_attr.caps |= FI_DIRECTED_RECV;
    }

    if (MPIDI_OFI_ENABLE_RMA)
        rx_attr.caps |= FI_RMA | FI_REMOTE_READ | FI_REMOTE_WRITE;
    if (MPIDI_OFI_ENABLE_ATOMICS)
        rx_attr.caps |= FI_ATOMICS;
    rx_attr.caps |= FI_MSG;
    rx_attr.caps |= FI_MULTI_RECV;      /* needed by the AM multi-recv buffers */
    rx_attr.caps |= FI_NAMED_RX_CTX;    /* Required for scalable endpoints indexing */

    MPIDI_OFI_CALL(fi_rx_context(ep, idx, &rx_attr, p_rx, NULL), ep);
    MPIDI_OFI_CALL(fi_ep_bind(*p_rx, &cq->fid, FI_RECV), bind);
    MPIDI_OFI_CALL(fi_enable(*p_rx), ep_enable);

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
1254 
/* Try to open a pre-existing named (shared) address vector read-only.
 * On success, copies the mapped fi_addr_t table into the MPIDI_OFI_AV
 * entries and returns 1; returns 0 if no shared AV exists or when
 * compiled with per-VNI domains (a shared AV cannot span domains). */
static int try_open_shared_av(struct fid_domain *domain, struct fid_av **p_av)
{
#ifdef MPIDI_OFI_VNI_USE_DOMAIN
    /* shared/named av table cannot be used when multiple fi_domain is enabled */
    return 0;
#else
    struct fi_av_attr av_attr;
    memset(&av_attr, 0, sizeof(av_attr));
    if (MPIDI_OFI_ENABLE_AV_TABLE) {
        av_attr.type = FI_AV_TABLE;
    } else {
        av_attr.type = FI_AV_MAP;
    }
    av_attr.rx_ctx_bits = MPIDI_OFI_MAX_ENDPOINTS_BITS;
    av_attr.count = MPIR_Process.size;

    /* NOTE(review): the trailing "\n" embedded in the AV name looks
     * unintentional.  It is harmless for matching as long as every process
     * uses this same format string, but verify against any external
     * consumer of the FI_NAMED_AV_* name before changing it. */
    char av_name[128];
    MPL_snprintf(av_name, sizeof(av_name), "FI_NAMED_AV_%d\n", MPIR_Process.appnum);
    av_attr.name = av_name;
    av_attr.flags = FI_READ;    /* read-only open of an existing named AV */
    av_attr.map_addr = 0;

    if (0 == fi_av_open(domain, &av_attr, p_av, NULL)) {
        /* TODO - the copy from the pre-existing av map into the 'MPIDI_OFI_AV' */
        /* is wasteful and should be changed so that the 'MPIDI_OFI_AV' object  */
        /* directly references the mapped fi_addr_t array instead               */
        fi_addr_t *mapped_table = (fi_addr_t *) av_attr.map_addr;
        for (int i = 0; i < MPIR_Process.size; i++) {
            MPIDI_OFI_AV(&MPIDIU_get_av(0, i)).dest[0][0] = mapped_table[i];
            MPL_DBG_MSG_FMT(MPIDI_CH4_DBG_MAP, VERBOSE,
                            (MPL_DBG_FDEST, " grank mapped to: rank=%d, av=%p, dest=%" PRIu64,
                             i, (void *) &MPIDIU_get_av(0, i), mapped_table[i]));
        }
        return 1;
    } else {
        return 0;
    }
#endif
}
1294 
create_rma_stx_ctx(struct fid_domain * domain,struct fid_stx ** p_rma_stx_ctx)1295 static int create_rma_stx_ctx(struct fid_domain *domain, struct fid_stx **p_rma_stx_ctx)
1296 {
1297     int mpi_errno = MPI_SUCCESS;
1298 
1299     if (MPIDI_OFI_ENABLE_SHARED_CONTEXTS) {
1300         int ret;
1301         struct fi_tx_attr tx_attr;
1302         memset(&tx_attr, 0, sizeof(tx_attr));
1303         /* A shared transmit context’s attributes must be a union of all associated
1304          * endpoints' transmit capabilities. */
1305         tx_attr.caps = FI_RMA | FI_WRITE | FI_READ | FI_ATOMIC;
1306         tx_attr.msg_order = FI_ORDER_RAR | FI_ORDER_RAW | FI_ORDER_WAR | FI_ORDER_WAW;
1307         tx_attr.op_flags = FI_DELIVERY_COMPLETE | FI_COMPLETION;
1308         MPIDI_OFI_CALL_RETURN(fi_stx_context(domain, &tx_attr, p_rma_stx_ctx, NULL), ret);
1309         if (ret < 0) {
1310             MPL_DBG_MSG(MPIDI_CH4_DBG_GENERAL, VERBOSE,
1311                         "Failed to create shared TX context for RMA, "
1312                         "falling back to global EP/counter scheme");
1313             *p_rma_stx_ctx = NULL;
1314         }
1315     }
1316 
1317   fn_exit:
1318     return mpi_errno;
1319   fn_fail:
1320     goto fn_exit;
1321 }
1322 
/* Select a provider, derive the global limits from it, and open the fabric.
 *
 * Steps: (1) find_provider() prepares 'hints'; (2) fi_getinfo() yields the
 * concrete provider list, filtered again by addr_format; (3) global settings
 * and size limits (max_msg_size, iov limits, mr key size, ...) are captured
 * from the chosen provider; (4) fi_fabric() opens the fabric into
 * MPIDI_OFI_global.fabric.
 *
 * Returns MPI_SUCCESS or an MPI error code; prov_list and hints are freed
 * on all paths through fn_exit. */
static int open_fabric(void)
{
    int mpi_errno = MPI_SUCCESS;
    struct fi_info *prov_list = NULL;

    /* First, find the provider and prepare the hints */
    struct fi_info *hints = fi_allocinfo();
    MPIR_Assert(hints != NULL);

    mpi_errno = find_provider(hints);
    MPIR_ERR_CHECK(mpi_errno);

    /* Second, get the actual fi_info * prov */
    MPIDI_OFI_CALL(fi_getinfo(get_ofi_version(), NULL, NULL, 0ULL, hints, &prov_list), getinfo);

    struct fi_info *prov = prov_list;
    /* fi_getinfo may ignore the addr_format in hints, filter it again */
    if (hints->addr_format != FI_FORMAT_UNSPEC) {
        while (prov && prov->addr_format != hints->addr_format) {
            prov = prov->next;
        }
    }
    MPIR_ERR_CHKANDJUMP(prov == NULL, mpi_errno, MPI_ERR_OTHER, "**ofid_getinfo");
    if (!MPIDI_OFI_ENABLE_RUNTIME_CHECKS) {
        /* Compile-time capability set must match what the provider offers. */
        int set_number = MPIDI_OFI_get_set_number(prov->fabric_attr->prov_name);
        MPIR_ERR_CHKANDJUMP(MPIDI_OFI_SET_NUMBER != set_number,
                            mpi_errno, MPI_ERR_OTHER, "**ofi_provider_mismatch");
    }

    /* Third, update global settings */
    if (MPIDI_OFI_ENABLE_RUNTIME_CHECKS) {
        update_global_settings(prov, hints);
    }

    /* Keep a private copy: prov_list is freed at fn_exit. */
    MPIDI_OFI_global.prov_use = fi_dupinfo(prov);
    MPIR_Assert(MPIDI_OFI_global.prov_use);

    MPIDI_OFI_global.max_buffered_send = prov->tx_attr->inject_size;
    MPIDI_OFI_global.max_buffered_write = prov->tx_attr->inject_size;
    if (MPIR_CVAR_CH4_OFI_EAGER_MAX_MSG_SIZE > 0 &&
        MPIR_CVAR_CH4_OFI_EAGER_MAX_MSG_SIZE <= prov->ep_attr->max_msg_size) {
        /* Truncate max_msg_size to a user-selected value */
        MPIDI_OFI_global.max_msg_size = MPIR_CVAR_CH4_OFI_EAGER_MAX_MSG_SIZE;
    } else {
        MPIDI_OFI_global.max_msg_size = MPL_MIN(prov->ep_attr->max_msg_size, MPIR_AINT_MAX);
    }
    MPIDI_OFI_global.max_order_raw = prov->ep_attr->max_order_raw_size;
    MPIDI_OFI_global.max_order_war = prov->ep_attr->max_order_war_size;
    MPIDI_OFI_global.max_order_waw = prov->ep_attr->max_order_waw_size;
    MPIDI_OFI_global.tx_iov_limit = MIN(prov->tx_attr->iov_limit, MPIDI_OFI_IOV_MAX);
    MPIDI_OFI_global.rx_iov_limit = MIN(prov->rx_attr->iov_limit, MPIDI_OFI_IOV_MAX);
    MPIDI_OFI_global.rma_iov_limit = MIN(prov->tx_attr->rma_iov_limit, MPIDI_OFI_IOV_MAX);
    MPIDI_OFI_global.max_mr_key_size = prov->domain_attr->mr_key_size;

    /* if using extended context id, check that selected provider can support it */
    MPIR_Assert(MPIR_CONTEXT_ID_BITS <= MPIDI_OFI_CONTEXT_BITS);
    /* Check that the desired number of ranks is possible and abort if not */
    if (MPIDI_OFI_MAX_RANK_BITS < 32 && MPIR_Process.size > (1 << MPIDI_OFI_MAX_RANK_BITS)) {
        MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**ch4|too_many_ranks");
    }

    if (MPIR_CVAR_CH4_OFI_CAPABILITY_SETS_DEBUG && MPIR_Process.rank == 0) {
        dump_global_settings();
    }

    /* Finally open the fabric */
    MPIDI_OFI_CALL(fi_fabric(prov->fabric_attr, &MPIDI_OFI_global.fabric, NULL), fabric);

  fn_exit:
    if (prov_list) {
        fi_freeinfo(prov_list);
    }

    /* prov_name is from MPL_strdup, can't let fi_freeinfo to free it */
    MPL_free(hints->fabric_attr->prov_name);
    hints->fabric_attr->prov_name = NULL;
    fi_freeinfo(hints);

    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
1405 
/* Decide which libfabric provider to use and fill in 'hints' accordingly.
 *
 * With runtime checks enabled, queries all providers, picks the best match
 * via pick_provider_from_list() (which also updates
 * MPIDI_OFI_global.settings), and seeds hints from the winner.  Without
 * runtime checks, the capability set was fixed at configure time and only a
 * user-specified provider name is validated against it.
 *
 * Returns MPI_SUCCESS or an MPI error code (macros branch to fn_fail). */
static int find_provider(struct fi_info *hints)
{
    int mpi_errno = MPI_SUCCESS;

    const char *provname = MPIR_CVAR_OFI_USE_PROVIDER;
    int ofi_version = get_ofi_version();

    if (!MPIDI_OFI_ENABLE_RUNTIME_CHECKS) {
        init_global_settings(MPIR_CVAR_OFI_USE_PROVIDER);
    } else {
        init_global_settings(MPIR_CVAR_OFI_USE_PROVIDER ? MPIR_CVAR_OFI_USE_PROVIDER :
                             MPIDI_OFI_SET_NAME_DEFAULT);
    }

    if (MPIDI_OFI_ENABLE_RUNTIME_CHECKS) {
        /* Ensure that we aren't trying to shove too many bits into the match_bits.
         * Currently, this needs to fit into a uint64_t and we take 4 bits for protocol. */
        MPIR_Assert(MPIDI_OFI_CONTEXT_BITS + MPIDI_OFI_SOURCE_BITS + MPIDI_OFI_TAG_BITS <= 60);

        struct fi_info *prov_list, *prov_use;
        /* NULL hints: enumerate every available provider. */
        MPIDI_OFI_CALL(fi_getinfo(ofi_version, NULL, NULL, 0ULL, NULL, &prov_list), getinfo);

        prov_use = pick_provider_from_list(provname, prov_list);

        MPIR_ERR_CHKANDJUMP(prov_use == NULL, mpi_errno, MPI_ERR_OTHER, "**ofid_getinfo");

        /* Initialize hints based on MPIDI_OFI_global.settings (updated by pick_provider_from_list()) */
        init_hints(hints);
        hints->fabric_attr->prov_name = MPL_strdup(prov_use->fabric_attr->prov_name);
        hints->caps = prov_use->caps;
        hints->addr_format = prov_use->addr_format;

        fi_freeinfo(prov_list);
    } else {
        /* Make sure that the user-specified provider matches the configure-specified provider. */
        MPIR_ERR_CHKANDJUMP(provname != NULL &&
                            MPIDI_OFI_SET_NUMBER != MPIDI_OFI_get_set_number(provname),
                            mpi_errno, MPI_ERR_OTHER, "**ofi_provider_mismatch");
        /* Initialize hints based on MPIDI_OFI_global.settings (config macros) */
        init_hints(hints);
        hints->fabric_attr->prov_name = provname ? MPL_strdup(provname) : NULL;
    }
  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_fail_dummy;     /* never reached; see note below */
  fn_fail_dummy:
    goto fn_exit;
}
1453 
/* Log the outcome of one provider-selection round.  'round' must be a string
 * literal tag (e.g. "[match name]"); the macro reads the caller's local
 * variable 'prov_use', so it is only usable inside pick_provider_from_list(). */
#define DBG_TRY_PICK_PROVIDER(round) /* round is a str, eg "Round 1" */ \
    if (NULL == prov_use) { \
        MPL_DBG_MSG_FMT(MPIDI_CH4_DBG_GENERAL, VERBOSE, \
                        (MPL_DBG_FDEST, round ": find_provider returned NULL\n")); \
    } else { \
        MPL_DBG_MSG_FMT(MPIDI_CH4_DBG_GENERAL, VERBOSE, \
                        (MPL_DBG_FDEST, round ": find_provider returned %s\n", \
                        prov_use->fabric_attr->prov_name)); \
    }
1463 
/* Pick the best provider from 'prov_list' in up to three rounds (exact name
 * match, default capability set, minimal capability set).  Each fallback
 * round re-initializes MPIDI_OFI_global.settings as a side effect, so the
 * settings left behind correspond to the round that succeeded.  Returns the
 * chosen fi_info (still owned by prov_list) or NULL.
 *
 * Note: the local variable MUST be named 'prov_use' because
 * DBG_TRY_PICK_PROVIDER references it by name. */
static struct fi_info *pick_provider_from_list(const char *provname, struct fi_info *prov_list)
{
    struct fi_info *prov_use = NULL;
    /* We'll try to pick the best provider three times.
     * 1 - Check to see if any provider matches an existing capability set (e.g. sockets)
     * 2 - Check to see if any provider meets the default capability set
     * 3 - Check to see if any provider meets the minimal capability set
     */
    /* provname counts as "set" only if it is a real provider name, not one
     * of the special default/minimal capability-set names. */
    bool provname_is_set = (provname &&
                            strcmp(provname, MPIDI_OFI_SET_NAME_DEFAULT) != 0 &&
                            strcmp(provname, MPIDI_OFI_SET_NAME_MINIMAL) != 0);
    if (NULL == prov_use && provname_is_set) {
        prov_use = pick_provider_by_name((char *) provname, prov_list);
        DBG_TRY_PICK_PROVIDER("[match name]");
    }

    bool provname_is_minimal = (provname && strcmp(provname, MPIDI_OFI_SET_NAME_MINIMAL) == 0);
    if (NULL == prov_use && !provname_is_minimal) {
        init_global_settings(MPIDI_OFI_SET_NAME_DEFAULT);
        prov_use = pick_provider_by_global_settings(prov_list);
        DBG_TRY_PICK_PROVIDER("[match default]");
    }

    if (NULL == prov_use) {
        init_global_settings(MPIDI_OFI_SET_NAME_MINIMAL);
        prov_use = pick_provider_by_global_settings(prov_list);
        DBG_TRY_PICK_PROVIDER("[match minimal]");
    }

    return prov_use;
}
1495 
pick_provider_by_name(const char * provname,struct fi_info * prov_list)1496 static struct fi_info *pick_provider_by_name(const char *provname, struct fi_info *prov_list)
1497 {
1498     struct fi_info *prov, *prov_use = NULL;
1499 
1500     prov = prov_list;
1501     while (NULL != prov) {
1502         /* Match provider name exactly */
1503         if (0 != strcmp(provname, prov->fabric_attr->prov_name)) {
1504             MPL_DBG_MSG_FMT(MPIDI_CH4_DBG_GENERAL, VERBOSE,
1505                             (MPL_DBG_FDEST, "Skipping provider: name mismatch"));
1506             prov = prov->next;
1507             continue;
1508         }
1509 
1510         init_global_settings(prov->fabric_attr->prov_name);
1511 
1512         if (!match_global_settings(prov)) {
1513             prov = prov->next;
1514             continue;
1515         }
1516 
1517         prov_use = prov;
1518 
1519         break;
1520     }
1521 
1522     return prov_use;
1523 }
1524 
pick_provider_by_global_settings(struct fi_info * prov_list)1525 static struct fi_info *pick_provider_by_global_settings(struct fi_info *prov_list)
1526 {
1527     struct fi_info *prov, *prov_use = NULL;
1528 
1529     prov = prov_list;
1530     while (NULL != prov) {
1531         if (!match_global_settings(prov)) {
1532             prov = prov->next;
1533             continue;
1534         } else {
1535             prov_use = prov;
1536             break;
1537         }
1538     }
1539 
1540     return prov_use;
1541 }
1542 
/* If capability SETTING is requested by the current global settings, reject
 * the candidate provider when cond_bad holds: log which setting failed and
 * return false from the ENCLOSING function.  Because of the embedded
 * 'return false', this macro may only be used inside a bool-returning
 * function (match_global_settings below). */
#define CHECK_CAP(SETTING, cond_bad) \
    if (SETTING) { \
        if (cond_bad) { \
            MPL_DBG_MSG_FMT(MPIDI_CH4_DBG_GENERAL, VERBOSE, \
                            (MPL_DBG_FDEST, "provider failed " #SETTING)); \
            return false; \
        } \
    }
1551 
/* Return true iff 'prov' satisfies every capability currently enabled in the
 * global settings (as seeded by init_global_settings()).  NOTE: this is not a
 * pure predicate -- when tagged messaging is enabled it ORs FI_DIRECTED_RECV
 * into prov->caps as a side effect (see inline comment below). */
bool match_global_settings(struct fi_info * prov)
{
    MPL_DBG_MSG_FMT(MPIDI_CH4_DBG_GENERAL, VERBOSE, (MPL_DBG_FDEST, "Provider name: %s",
                                                     prov->fabric_attr->prov_name));

    /* Optionally reject IPv6-addressed providers (MPIR_CVAR_OFI_SKIP_IPV6,
     * default true per the CVAR block at the top of this file). */
    if (MPIR_CVAR_OFI_SKIP_IPV6) {
        if (prov->addr_format == FI_SOCKADDR_IN6) {
            return false;
        }
    }
    /* Scalable endpoints need more than one tx context plus named rx contexts. */
    CHECK_CAP(MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS,
              prov->domain_attr->max_ep_tx_ctx <= 1 ||
              (prov->caps & FI_NAMED_RX_CTX) != FI_NAMED_RX_CTX);

    /* From the fi_getinfo manpage: "FI_TAGGED implies the ability to send and receive
     * tagged messages." Therefore no need to specify FI_SEND|FI_RECV.  Moreover FI_SEND
     * and FI_RECV are mutually exclusive, so they should never be set both at the same
     * time. */
    /* This capability set also requires the ability to receive data in the completion
     * queue object (at least 32 bits). Previously, this was a separate capability set,
     * but as more and more providers supported this feature, the decision was made to
     * require it. */
    CHECK_CAP(MPIDI_OFI_ENABLE_TAGGED,
              !(prov->caps & FI_TAGGED) || prov->domain_attr->cq_data_size < 4);

    /* OFI provider doesn't expose FI_DIRECTED_RECV by default for performance consideration.
     * MPICH should request this flag to enable it. */
    /* NOTE(review): side effect -- this mutates the candidate's caps, which the
     * caller later copies into the getinfo hints; confirm this is intentional. */
    if (MPIDI_OFI_ENABLE_TAGGED)
        prov->caps |= FI_DIRECTED_RECV;

    /* Active messages require both message-queue APIs and shared receive buffers. */
    CHECK_CAP(MPIDI_OFI_ENABLE_AM,
              (prov->caps & (FI_MSG | FI_MULTI_RECV)) != (FI_MSG | FI_MULTI_RECV));

    CHECK_CAP(MPIDI_OFI_ENABLE_RMA, !(prov->caps & FI_RMA));
#ifdef FI_HMEM
    CHECK_CAP(MPIDI_OFI_ENABLE_HMEM, !(prov->caps & FI_HMEM));
#endif
    /* Atomics additionally require the full set of message-ordering flags. */
    uint64_t msg_order = MPIDI_OFI_ATOMIC_ORDER_FLAGS;
    CHECK_CAP(MPIDI_OFI_ENABLE_ATOMICS,
              !(prov->caps & FI_ATOMICS) || (prov->tx_attr->msg_order & msg_order) != msg_order);

    CHECK_CAP(MPIDI_OFI_ENABLE_CONTROL_AUTO_PROGRESS,
              !(prov->domain_attr->control_progress & FI_PROGRESS_AUTO));

    CHECK_CAP(MPIDI_OFI_ENABLE_DATA_AUTO_PROGRESS,
              !(prov->domain_attr->data_progress & FI_PROGRESS_AUTO));

    /* Reliable-datagram endpoints are unconditionally required. */
    int MPICH_REQUIRE_RDM = 1;  /* hack to use CHECK_CAP macro */
    CHECK_CAP(MPICH_REQUIRE_RDM, prov->ep_attr->type != FI_EP_RDM);

    return true;
}
1604 
/* Resolve one capability setting with the precedence:
 *   1. an explicit CVAR override (any value other than -1) wins;
 *   2. otherwise, when a provider name is known, take the value from that
 *      provider's entry in MPIDI_OFI_caps_list;
 *   3. otherwise the CVAR value is kept as-is (at this point it is -1,
 *      i.e. "don't care", to be resolved later).
 * Relies on locals 'prov_name' and 'prov_caps' in the expanding function. */
#define UPDATE_SETTING_BY_CAP(cap, CVAR) \
    MPIDI_OFI_global.settings.cap = (CVAR != -1) ? CVAR : \
                                    prov_name ? prov_caps->cap : \
                                    CVAR
1609 
/* Seed MPIDI_OFI_global.settings for the capability set named by 'prov_name'
 * (or the default/minimal runtime sets), honoring CVAR overrides.  See
 * UPDATE_SETTING_BY_CAP for the exact override precedence. */
static void init_global_settings(const char *prov_name)
{
    /* Map the provider name to its row in the static capability table. */
    int prov_idx = MPIDI_OFI_get_set_number(prov_name);
    MPIDI_OFI_capabilities_t *prov_caps = &MPIDI_OFI_caps_list[prov_idx];

    /* Seed the global settings values for cases where we are using runtime sets */
    UPDATE_SETTING_BY_CAP(enable_av_table, MPIR_CVAR_CH4_OFI_ENABLE_AV_TABLE);
    UPDATE_SETTING_BY_CAP(enable_scalable_endpoints, MPIR_CVAR_CH4_OFI_ENABLE_SCALABLE_ENDPOINTS);
    /* If the user specifies -1 (=don't care) and the provider supports it, then try to use STX
     * and fall back if necessary in the RMA init code */
    UPDATE_SETTING_BY_CAP(enable_shared_contexts, MPIR_CVAR_CH4_OFI_ENABLE_SHARED_CONTEXTS);

    /* As of OFI version 1.5, FI_MR_SCALABLE and FI_MR_BASIC are deprecated. Internally, we now use
     * FI_MR_VIRT_ADDRESS and FI_MR_PROV_KEY so set them appropriately depending on the OFI version
     * being used. */
    if (get_ofi_version() < FI_VERSION(1, 5)) {
        /* If the OFI library is 1.5 or less, query whether or not to use FI_MR_SCALABLE and set
         * FI_MR_VIRT_ADDRESS, FI_MR_ALLOCATED, and FI_MR_PROV_KEY as the opposite values. */
        UPDATE_SETTING_BY_CAP(enable_mr_virt_address, MPIR_CVAR_CH4_OFI_ENABLE_MR_SCALABLE);
        /* The three MR flags are the logical negation of the "scalable" answer
         * just stored in enable_mr_virt_address. */
        MPIDI_OFI_global.settings.enable_mr_virt_address =
            MPIDI_OFI_global.settings.enable_mr_prov_key =
            MPIDI_OFI_global.settings.enable_mr_allocated =
            !MPIDI_OFI_global.settings.enable_mr_virt_address;
    } else {
        UPDATE_SETTING_BY_CAP(enable_mr_virt_address, MPIR_CVAR_CH4_OFI_ENABLE_MR_VIRT_ADDRESS);
        UPDATE_SETTING_BY_CAP(enable_mr_allocated, MPIR_CVAR_CH4_OFI_ENABLE_MR_ALLOCATED);
        UPDATE_SETTING_BY_CAP(enable_mr_prov_key, MPIR_CVAR_CH4_OFI_ENABLE_MR_PROV_KEY);
    }
    UPDATE_SETTING_BY_CAP(enable_tagged, MPIR_CVAR_CH4_OFI_ENABLE_TAGGED);
    UPDATE_SETTING_BY_CAP(enable_am, MPIR_CVAR_CH4_OFI_ENABLE_AM);
    UPDATE_SETTING_BY_CAP(enable_rma, MPIR_CVAR_CH4_OFI_ENABLE_RMA);
    /* try to enable atomics only when RMA is enabled */
    if (MPIDI_OFI_ENABLE_RMA) {
        UPDATE_SETTING_BY_CAP(enable_atomics, MPIR_CVAR_CH4_OFI_ENABLE_ATOMICS);
    } else {
        MPIDI_OFI_global.settings.enable_atomics = 0;
    }
    UPDATE_SETTING_BY_CAP(fetch_atomic_iovecs, MPIR_CVAR_CH4_OFI_FETCH_ATOMIC_IOVECS);
    UPDATE_SETTING_BY_CAP(enable_data_auto_progress, MPIR_CVAR_CH4_OFI_ENABLE_DATA_AUTO_PROGRESS);
    UPDATE_SETTING_BY_CAP(enable_control_auto_progress,
                          MPIR_CVAR_CH4_OFI_ENABLE_CONTROL_AUTO_PROGRESS);
    UPDATE_SETTING_BY_CAP(enable_pt2pt_nopack, MPIR_CVAR_CH4_OFI_ENABLE_PT2PT_NOPACK);
    UPDATE_SETTING_BY_CAP(context_bits, MPIR_CVAR_CH4_OFI_CONTEXT_ID_BITS);
    UPDATE_SETTING_BY_CAP(source_bits, MPIR_CVAR_CH4_OFI_RANK_BITS);
    UPDATE_SETTING_BY_CAP(tag_bits, MPIR_CVAR_CH4_OFI_TAG_BITS);
    UPDATE_SETTING_BY_CAP(major_version, MPIR_CVAR_CH4_OFI_MAJOR_VERSION);
    UPDATE_SETTING_BY_CAP(minor_version, MPIR_CVAR_CH4_OFI_MINOR_VERSION);
    UPDATE_SETTING_BY_CAP(num_am_buffers, MPIR_CVAR_CH4_OFI_NUM_AM_BUFFERS);
    /* Clamp the AM buffer count into [0, MPIDI_OFI_MAX_NUM_AM_BUFFERS]. */
    if (MPIDI_OFI_global.settings.num_am_buffers < 0) {
        MPIDI_OFI_global.settings.num_am_buffers = 0;
    }
    if (MPIDI_OFI_global.settings.num_am_buffers > MPIDI_OFI_MAX_NUM_AM_BUFFERS) {
        MPIDI_OFI_global.settings.num_am_buffers = MPIDI_OFI_MAX_NUM_AM_BUFFERS;
    }
}
1665 
/* Keep capability 'cap' enabled only if the selected provider supports it:
 * AND the current setting with the provider-derived condition.  'info_cond'
 * is parenthesized in the expansion so callers may pass arbitrary
 * expressions (e.g. containing '||' or ',') without precedence surprises. */
#define UPDATE_SETTING_BY_INFO(cap, info_cond) \
    MPIDI_OFI_global.settings.cap = MPIDI_OFI_global.settings.cap && (info_cond)
1668 
/* After a provider has been chosen, trim the global settings down to what
 * 'prov_use' (and the progress model requested in 'hints') can actually
 * deliver: each setting stays enabled only if the provider supports it. */
static void update_global_settings(struct fi_info *prov_use, struct fi_info *hints)
{
    /* ------------------------------------------------------------------------ */
    /* Set global attributes attributes based on the provider choice            */
    /* ------------------------------------------------------------------------ */
    UPDATE_SETTING_BY_INFO(enable_av_table, prov_use->domain_attr->av_type == FI_AV_TABLE);
    UPDATE_SETTING_BY_INFO(enable_scalable_endpoints,
                           prov_use->domain_attr->max_ep_tx_ctx > 1 &&
                           (prov_use->caps & FI_NAMED_RX_CTX) == FI_NAMED_RX_CTX);
    /* As of OFI version 1.5, FI_MR_SCALABLE and FI_MR_BASIC are deprecated. Internally, we now use
     * FI_MR_VIRT_ADDRESS and FI_MR_PROV_KEY so set them appropriately depending on the OFI version
     * being used. */
    if (get_ofi_version() < FI_VERSION(1, 5)) {
        /* If the OFI library is 1.5 or less, query whether or not to use FI_MR_SCALABLE and set
         * FI_MR_VIRT_ADDRESS, FI_MR_ALLOCATED, and FI_MR_PROV_KEY as the opposite values. */
        UPDATE_SETTING_BY_INFO(enable_mr_virt_address,
                               prov_use->domain_attr->mr_mode != FI_MR_SCALABLE);
        UPDATE_SETTING_BY_INFO(enable_mr_allocated,
                               prov_use->domain_attr->mr_mode != FI_MR_SCALABLE);
        UPDATE_SETTING_BY_INFO(enable_mr_prov_key,
                               prov_use->domain_attr->mr_mode != FI_MR_SCALABLE);
    } else {
        /* Post-1.5, mr_mode is a bit mask rather than an enum value. */
        UPDATE_SETTING_BY_INFO(enable_mr_virt_address,
                               prov_use->domain_attr->mr_mode & FI_MR_VIRT_ADDR);
        UPDATE_SETTING_BY_INFO(enable_mr_allocated,
                               prov_use->domain_attr->mr_mode & FI_MR_ALLOCATED);
        UPDATE_SETTING_BY_INFO(enable_mr_prov_key, prov_use->domain_attr->mr_mode & FI_MR_PROV_KEY);
    }
    /* Tagged messaging additionally requires directed receive and >=32-bit
     * completion data (cq_data_size). */
    UPDATE_SETTING_BY_INFO(enable_tagged,
                           (prov_use->caps & FI_TAGGED) &&
                           (prov_use->caps & FI_DIRECTED_RECV) &&
                           (prov_use->domain_attr->cq_data_size >= 4));
    UPDATE_SETTING_BY_INFO(enable_am,
                           (prov_use->caps & (FI_MSG | FI_MULTI_RECV | FI_READ)) ==
                           (FI_MSG | FI_MULTI_RECV | FI_READ));
    UPDATE_SETTING_BY_INFO(enable_rma, prov_use->caps & FI_RMA);
    UPDATE_SETTING_BY_INFO(enable_atomics, prov_use->caps & FI_ATOMICS);
#ifdef FI_HMEM
    UPDATE_SETTING_BY_INFO(enable_hmem, prov_use->caps & FI_HMEM);
#endif
    /* Progress settings come from the hints we requested, not the provider. */
    UPDATE_SETTING_BY_INFO(enable_data_auto_progress,
                           hints->domain_attr->data_progress & FI_PROGRESS_AUTO);
    UPDATE_SETTING_BY_INFO(enable_control_auto_progress,
                           hints->domain_attr->control_progress & FI_PROGRESS_AUTO);

    /* Endpoint limits depend on whether scalable endpoints survived the checks. */
    if (MPIDI_OFI_global.settings.enable_scalable_endpoints) {
        MPIDI_OFI_global.settings.max_endpoints = MPIDI_OFI_MAX_ENDPOINTS_SCALABLE;
        MPIDI_OFI_global.settings.max_endpoints_bits = MPIDI_OFI_MAX_ENDPOINTS_BITS_SCALABLE;
    } else {
        MPIDI_OFI_global.settings.max_endpoints = MPIDI_OFI_MAX_ENDPOINTS_REGULAR;
        MPIDI_OFI_global.settings.max_endpoints_bits = MPIDI_OFI_MAX_ENDPOINTS_BITS_REGULAR;
    }
}
1722 
1723 /* Initializes hint structure based MPIDI_OFI_global.settings (or config macros) */
static void init_hints(struct fi_info *hints)
{
    int ofi_version = get_ofi_version();
    MPIR_Assert(hints != NULL);

    /* ------------------------------------------------------------------------ */
    /* Hints to filter providers                                                */
    /* See man fi_getinfo for a list                                            */
    /* of all filters                                                           */
    /* mode:  Select capabilities that this netmod will support                 */
    /*        FI_CONTEXT(2):  This netmod will pass in context into communication */
    /*        to optimize storage locality between MPI requests and OFI opaque  */
    /*        data structures.                                                  */
    /*        FI_ASYNC_IOV:  MPICH will provide storage for iovecs on           */
    /*        communication calls, avoiding the OFI provider needing to require */
    /*        a copy.                                                           */
    /*        FI_LOCAL_MR unset:  Note that we do not set FI_LOCAL_MR,          */
    /*        which means this netmod does not support exchange of memory       */
    /*        regions on communication calls.                                   */
    /* caps:     Capabilities required from the provider.  The bits specified   */
    /*           with buffered receive, cancel, and remote complete implements  */
    /*           MPI semantics.                                                 */
    /*           Tagged: used to support tag matching, 2-sided                  */
    /*           RMA|Atomics:  supports MPI 1-sided                             */
    /*           MSG|MULTI_RECV:  Supports synchronization protocol for 1-sided */
    /*           FI_DIRECTED_RECV: Support not putting the source in the match  */
    /*                             bits                                         */
    /*           FI_NAMED_RX_CTX: Necessary to specify receiver-side SEP index  */
    /*                            when scalable endpoint (SEP) is enabled.      */
    /*           We expect to register all memory up front for use with this    */
    /*           endpoint, so the netmod requires dynamic memory regions        */
    /* ------------------------------------------------------------------------ */
    hints->mode = FI_CONTEXT | FI_ASYNC_IOV | FI_RX_CQ_DATA;    /* We can handle contexts  */

    /* FI_CONTEXT2 (two context slots per operation) only exists from OFI 1.5 on. */
    if (ofi_version >= FI_VERSION(1, 5)) {
#ifdef FI_CONTEXT2
        hints->mode |= FI_CONTEXT2;
#endif
    }
    /* Start from an empty capability set; add only what the active settings require. */
    hints->caps = 0ULL;

    /* RMA interface is used in AM and in native modes,
     * it should be supported by OFI provider in any case */
    hints->caps |= FI_RMA;      /* RMA(read/write)         */
    hints->caps |= FI_WRITE;    /* We need to specify all of the extra
                                 * capabilities because we need to be
                                 * specific later when we create tx/rx
                                 * contexts. If we leave this off, the
                                 * context creation fails because it's not
                                 * a subset of this. */
    hints->caps |= FI_READ;
    hints->caps |= FI_REMOTE_WRITE;
    hints->caps |= FI_REMOTE_READ;

    if (MPIDI_OFI_ENABLE_ATOMICS) {
        hints->caps |= FI_ATOMICS;      /* Atomics capabilities    */
    }

    if (MPIDI_OFI_ENABLE_TAGGED) {
        hints->caps |= FI_TAGGED;       /* Tag matching interface  */
        hints->caps |= FI_DIRECTED_RECV;        /* Match source address    */
        hints->domain_attr->cq_data_size = 4;   /* Minimum size for completion data entry */
    }

    if (MPIDI_OFI_ENABLE_AM) {
        hints->caps |= FI_MSG;  /* Message Queue apis      */
        hints->caps |= FI_MULTI_RECV;   /* Shared receive buffer   */
    }

    /* With scalable endpoints, FI_NAMED_RX_CTX is needed to specify a destination receive context
     * index */
    if (MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS)
        hints->caps |= FI_NAMED_RX_CTX;

    /* ------------------------------------------------------------------------ */
    /* Set object options to be filtered by getinfo                             */
    /* domain_attr:  domain attribute requirements                              */
    /* op_flags:     persistent flag settings for an endpoint                   */
    /* endpoint type:  see FI_EP_RDM                                            */
    /* Filters applied (for this netmod, we need providers that can support):   */
    /* THREAD_DOMAIN:  Progress serialization is handled by netmod (locking)    */
    /* PROGRESS_AUTO:  request providers that make progress without requiring   */
    /*                 the ADI to dedicate a thread to advance the state        */
    /* FI_DELIVERY_COMPLETE:  RMA operations are visible in remote memory       */
    /* FI_COMPLETION:  Selective completions of RMA ops                         */
    /* FI_EP_RDM:  Reliable datagram                                            */
    /* ------------------------------------------------------------------------ */
    hints->addr_format = FI_FORMAT_UNSPEC;
    hints->domain_attr->threading = FI_THREAD_DOMAIN;
    if (MPIDI_OFI_ENABLE_DATA_AUTO_PROGRESS) {
        hints->domain_attr->data_progress = FI_PROGRESS_AUTO;
    } else {
        hints->domain_attr->data_progress = FI_PROGRESS_MANUAL;
    }
    if (MPIDI_OFI_ENABLE_CONTROL_AUTO_PROGRESS) {
        hints->domain_attr->control_progress = FI_PROGRESS_AUTO;
    } else {
        hints->domain_attr->control_progress = FI_PROGRESS_MANUAL;
    }
    hints->domain_attr->resource_mgmt = FI_RM_ENABLED;
    hints->domain_attr->av_type = MPIDI_OFI_ENABLE_AV_TABLE ? FI_AV_TABLE : FI_AV_MAP;

    if (ofi_version >= FI_VERSION(1, 5)) {
        hints->domain_attr->mr_mode = 0;
#ifdef FI_RESTRICTED_COMP
        hints->domain_attr->mode = FI_RESTRICTED_COMP;
#endif
        /* avoid using FI_MR_SCALABLE and FI_MR_BASIC because they are only
         * for backward compatibility (pre OFI version 1.5), and they don't allow any other
         * mode bits to be added */
#ifdef FI_MR_VIRT_ADDR
        if (MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS) {
            hints->domain_attr->mr_mode |= FI_MR_VIRT_ADDR;
        }
#endif

#ifdef FI_MR_ALLOCATED
        if (MPIDI_OFI_ENABLE_MR_ALLOCATED) {
            hints->domain_attr->mr_mode |= FI_MR_ALLOCATED;
        }
#endif

#ifdef FI_MR_PROV_KEY
        if (MPIDI_OFI_ENABLE_MR_PROV_KEY) {
            hints->domain_attr->mr_mode |= FI_MR_PROV_KEY;
        }
#endif
    } else {
        /* Pre-1.5 libraries only understand the legacy scalable/basic enum. */
        if (MPIDI_OFI_ENABLE_MR_SCALABLE)
            hints->domain_attr->mr_mode = FI_MR_SCALABLE;
        else
            hints->domain_attr->mr_mode = FI_MR_BASIC;
    }
    hints->tx_attr->op_flags = FI_COMPLETION;
    hints->tx_attr->msg_order = FI_ORDER_SAS;
    /* direct RMA operations supported only with delivery complete mode,
     * else (AM mode) delivery complete is not required */
    if (MPIDI_OFI_ENABLE_RMA || MPIDI_OFI_ENABLE_ATOMICS) {
        hints->tx_attr->op_flags |= FI_DELIVERY_COMPLETE;
        /* Apply most restricted msg order in hints for RMA ATOMICS. */
        if (MPIDI_OFI_ENABLE_ATOMICS)
            hints->tx_attr->msg_order |= MPIDI_OFI_ATOMIC_ORDER_FLAGS;
    }
    hints->tx_attr->comp_order = FI_ORDER_NONE;
    hints->rx_attr->op_flags = FI_COMPLETION;
    hints->rx_attr->total_buffered_recv = 0;    /* FI_RM_ENABLED ensures buffering of unexpected messages */
    hints->ep_attr->type = FI_EP_RDM;
    /* Partition the match bits according to whether source ranks are encoded
     * in them (no source bits means the tag field gets the space instead). */
    hints->ep_attr->mem_tag_format = MPIDI_OFI_SOURCE_BITS ?
        /*     PROTOCOL         |  CONTEXT  |        SOURCE         |       TAG          */
        MPIDI_OFI_PROTOCOL_MASK | 0 | MPIDI_OFI_SOURCE_MASK | 0 /* With source bits */ :
        MPIDI_OFI_PROTOCOL_MASK | 0 | 0 | MPIDI_OFI_TAG_MASK /* No source bits */ ;
}
1876 
1877 /* ---------------------------------------------------------- */
1878 /* Debug Routines                                             */
1879 /* ---------------------------------------------------------- */
1880 
/* Print the selected provider and the resolved capability-set configuration
 * to stdout.  Presumably gated by MPIR_CVAR_CH4_OFI_CAPABILITY_SETS_DEBUG
 * (declared in this file's CVAR block) -- confirm at the call site. */
static void dump_global_settings(void)
{
    fprintf(stdout, "==== Capability set configuration ====\n");
    fprintf(stdout, "libfabric provider: %s\n", MPIDI_OFI_global.prov_use->fabric_attr->prov_name);
    fprintf(stdout, "MPIDI_OFI_ENABLE_AV_TABLE: %d\n", MPIDI_OFI_ENABLE_AV_TABLE);
    fprintf(stdout, "MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS: %d\n",
            MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS);
    fprintf(stdout, "MPIDI_OFI_ENABLE_SHARED_CONTEXTS: %d\n", MPIDI_OFI_ENABLE_SHARED_CONTEXTS);
    fprintf(stdout, "MPIDI_OFI_ENABLE_MR_SCALABLE: %d\n", MPIDI_OFI_ENABLE_MR_SCALABLE);
    fprintf(stdout, "MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS: %d\n", MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS);
    fprintf(stdout, "MPIDI_OFI_ENABLE_MR_ALLOCATED: %d\n", MPIDI_OFI_ENABLE_MR_ALLOCATED);
    fprintf(stdout, "MPIDI_OFI_ENABLE_MR_PROV_KEY: %d\n", MPIDI_OFI_ENABLE_MR_PROV_KEY);
    fprintf(stdout, "MPIDI_OFI_ENABLE_TAGGED: %d\n", MPIDI_OFI_ENABLE_TAGGED);
    fprintf(stdout, "MPIDI_OFI_ENABLE_AM: %d\n", MPIDI_OFI_ENABLE_AM);
    fprintf(stdout, "MPIDI_OFI_ENABLE_RMA: %d\n", MPIDI_OFI_ENABLE_RMA);
    fprintf(stdout, "MPIDI_OFI_ENABLE_ATOMICS: %d\n", MPIDI_OFI_ENABLE_ATOMICS);
    fprintf(stdout, "MPIDI_OFI_FETCH_ATOMIC_IOVECS: %d\n", MPIDI_OFI_FETCH_ATOMIC_IOVECS);
    fprintf(stdout, "MPIDI_OFI_ENABLE_DATA_AUTO_PROGRESS: %d\n",
            MPIDI_OFI_ENABLE_DATA_AUTO_PROGRESS);
    fprintf(stdout, "MPIDI_OFI_ENABLE_CONTROL_AUTO_PROGRESS: %d\n",
            MPIDI_OFI_ENABLE_CONTROL_AUTO_PROGRESS);
    fprintf(stdout, "MPIDI_OFI_ENABLE_PT2PT_NOPACK: %d\n", MPIDI_OFI_ENABLE_PT2PT_NOPACK);
    fprintf(stdout, "MPIDI_OFI_ENABLE_HMEM: %d\n", MPIDI_OFI_ENABLE_HMEM);
    fprintf(stdout, "MPIDI_OFI_NUM_AM_BUFFERS: %d\n", MPIDI_OFI_NUM_AM_BUFFERS);
    fprintf(stdout, "MPIDI_OFI_CONTEXT_BITS: %d\n", MPIDI_OFI_CONTEXT_BITS);
    fprintf(stdout, "MPIDI_OFI_SOURCE_BITS: %d\n", MPIDI_OFI_SOURCE_BITS);
    fprintf(stdout, "MPIDI_OFI_TAG_BITS: %d\n", MPIDI_OFI_TAG_BITS);
    fprintf(stdout, "======================================\n");

    /* Discover the maximum number of ranks. If the source shift is not
     * defined, there are 32 bits in use due to the uint32_t used in
     * ofi_send.h */
    fprintf(stdout, "MAXIMUM SUPPORTED RANKS: %ld\n", (long int) 1 << MPIDI_OFI_MAX_RANK_BITS);

    /* Discover the tag_ub */
    /* NOTE(review): 1UL << TAG_BITS is the number of representable tags; the
     * largest usable tag value is one less -- confirm the intended meaning. */
    fprintf(stdout, "MAXIMUM TAG: %lu\n", 1UL << MPIDI_OFI_TAG_BITS);
    fprintf(stdout, "======================================\n");
}
1919 
1920 /* static address exchange routines */
/* Exchange vni-0 endpoint addresses among all ranks and insert them into the
 * local address vector.  When MPIR_CVAR_CH4_ROOTS_ONLY_PMI is set, only node
 * roots publish via PMI and the remaining addresses are gathered over
 * 'init_comm'; otherwise everyone's address comes from the PMI table.
 * Returns MPI_SUCCESS or an MPI error code. */
static int addr_exchange_root_vni(MPIR_Comm * init_comm)
{
    int mpi_errno = MPI_SUCCESS;
    int size = MPIR_Process.size;
    int rank = MPIR_Process.rank;

    /* No pre-published address table, need do address exchange. */
    /* First, each get its own name */
    MPIDI_OFI_global.addrnamelen = FI_NAME_MAX;
    MPIDI_OFI_CALL(fi_getname((fid_t) MPIDI_OFI_global.ctx[0].ep, MPIDI_OFI_global.addrname,
                              &MPIDI_OFI_global.addrnamelen), getname);
    MPIR_Assert(MPIDI_OFI_global.addrnamelen <= FI_NAME_MAX);

    /* Second, exchange names using PMI */
    /* If MPIR_CVAR_CH4_ROOTS_ONLY_PMI is true, we only collect a table of node-roots.
     * Otherwise, we collect a table of everyone. */
    void *table = NULL;
    int ret_bc_len;
    mpi_errno = MPIDU_bc_table_create(rank, size, MPIDI_global.node_map[0],
                                      &MPIDI_OFI_global.addrname, MPIDI_OFI_global.addrnamelen,
                                      TRUE, MPIR_CVAR_CH4_ROOTS_ONLY_PMI, &table, &ret_bc_len);
    MPIR_ERR_CHECK(mpi_errno);
    /* MPIR_Assert(ret_bc_len == MPIDI_OFI_global.addrnamelen); */

    /* Third, each fi_av_insert those addresses */
    if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI) {
        /* if "ROOTS_ONLY", we do a two stage bootstrapping ... */
        int num_nodes = MPIR_Process.num_nodes;
        int *node_roots = MPIR_Process.node_root_map;
        int *rank_map, recv_bc_len;

        /* First, insert address of node-roots, init_comm become useful */
        fi_addr_t *mapped_table;
        mapped_table = (fi_addr_t *) MPL_malloc(num_nodes * sizeof(fi_addr_t), MPL_MEM_ADDRESS);
        MPIDI_OFI_CALL(fi_av_insert
                       (MPIDI_OFI_global.ctx[0].av, table, num_nodes, mapped_table, 0ULL, NULL),
                       avmap);

        for (int i = 0; i < num_nodes; i++) {
            MPIR_Assert(mapped_table[i] != FI_ADDR_NOTAVAIL);
            MPIDI_OFI_AV(&MPIDIU_get_av(0, node_roots[i])).dest[0][0] = mapped_table[i];
        }
        MPL_free(mapped_table);
        /* Then, allgather all address names using init_comm */
        MPIDU_bc_allgather(init_comm, MPIDI_OFI_global.addrname, MPIDI_OFI_global.addrnamelen,
                           TRUE, &table, &rank_map, &recv_bc_len);

        /* Insert the rest of the addresses */
        for (int i = 0; i < MPIR_Process.size; i++) {
            if (rank_map[i] >= 0) {
                fi_addr_t addr;
                char *addrname = (char *) table + recv_bc_len * rank_map[i];
                MPIDI_OFI_CALL(fi_av_insert(MPIDI_OFI_global.ctx[0].av,
                                            addrname, 1, &addr, 0ULL, NULL), avmap);
                /* BUGFIX: store under rank i (the rank whose address this is),
                 * not the local rank -- the old code overwrote our own AV
                 * entry on every iteration and left the others unset. */
                MPIDI_OFI_AV(&MPIDIU_get_av(0, i)).dest[0][0] = addr;
            }
        }
        MPIDU_bc_table_destroy();
    } else {
        /* not "ROOTS_ONLY", we already have everyone's address name, insert all of them */
        fi_addr_t *mapped_table;
        mapped_table = (fi_addr_t *) MPL_malloc(size * sizeof(fi_addr_t), MPL_MEM_ADDRESS);
        MPIDI_OFI_CALL(fi_av_insert
                       (MPIDI_OFI_global.ctx[0].av, table, size, mapped_table, 0ULL, NULL), avmap);

        for (int i = 0; i < size; i++) {
            MPIR_Assert(mapped_table[i] != FI_ADDR_NOTAVAIL);
            MPIDI_OFI_AV(&MPIDIU_get_av(0, i)).dest[0][0] = mapped_table[i];
        }
        MPL_free(mapped_table);
        MPIDU_bc_table_destroy();
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
1999 
addr_exchange_all_vnis(void)2000 static int addr_exchange_all_vnis(void)
2001 {
2002     int mpi_errno = MPI_SUCCESS;
2003 
2004 #ifdef MPIDI_OFI_VNI_USE_DOMAIN
2005     int size = MPIR_Process.size;
2006     int rank = MPIR_Process.rank;
2007     int num_vnis = MPIDI_OFI_global.num_vnis;
2008 
2009     /* get addr name length */
2010     size_t name_len = 0;
2011     int ret = fi_getname((fid_t) MPIDI_OFI_global.ctx[0].ep, NULL, &name_len);
2012     MPIR_Assert(ret == -FI_ETOOSMALL);
2013     MPIR_Assert(name_len > 0);
2014 
2015     int my_len = num_vnis * name_len;
2016     char *all_names = MPL_malloc(size * my_len, MPL_MEM_ADDRESS);
2017     MPIR_Assert(all_names);
2018 
2019     char *my_names = all_names + rank * my_len;
2020 
2021     /* put in my addrnames */
2022     for (int i = 0; i < num_vnis; i++) {
2023         size_t actual_name_len = name_len;
2024         char *vni_addrname = my_names + i * name_len;
2025         MPIDI_OFI_CALL(fi_getname((fid_t) MPIDI_OFI_global.ctx[i].ep, vni_addrname,
2026                                   &actual_name_len), getname);
2027         MPIR_Assert(actual_name_len == name_len);
2028     }
2029     /* Allgather */
2030     MPIR_Comm *comm = MPIR_Process.comm_world;
2031     MPIR_Errflag_t errflag = MPIR_ERR_NONE;
2032     mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPI_BYTE,
2033                                             all_names, my_len, MPI_BYTE, comm, &errflag);
2034     /* insert the addresses */
2035     fi_addr_t *mapped_table;
2036     mapped_table = (fi_addr_t *) MPL_malloc(size * num_vnis * sizeof(fi_addr_t), MPL_MEM_ADDRESS);
2037     for (int vni_local = 0; vni_local < num_vnis; vni_local++) {
2038         MPIDI_OFI_CALL(fi_av_insert(MPIDI_OFI_global.ctx[vni_local].av, all_names, size * num_vnis,
2039                                     mapped_table, 0ULL, NULL), avmap);
2040         for (int r = 0; r < size; r++) {
2041             MPIDI_OFI_addr_t *av = &MPIDI_OFI_AV(&MPIDIU_get_av(0, r));
2042             for (int vni_remote = 0; vni_remote < num_vnis; vni_remote++) {
2043                 if (vni_local == 0 && vni_remote == 0) {
2044                     /* don't overwrite existing addr, or bad things will happen */
2045                     continue;
2046                 }
2047                 int idx = r * num_vnis + vni_remote;
2048                 MPIR_Assert(mapped_table[idx] != FI_ADDR_NOTAVAIL);
2049                 av->dest[vni_local][vni_remote] = mapped_table[idx];
2050             }
2051         }
2052     }
2053 #endif
2054   fn_exit:
2055     return mpi_errno;
2056   fn_fail:
2057     goto fn_exit;
2058 }
2059