/*
 * Copyright (C) by Argonne National Laboratory
 *     See COPYRIGHT in top-level directory
 */

#include "mpidimpl.h"
#include "ucx_impl.h"
#include "mpidu_bc.h"
#include <ucp/api/ucp.h>
#include <limits.h>
#include <string.h>

/*
=== BEGIN_MPI_T_CVAR_INFO_BLOCK ===

categories :
    - name : CH4_UCX
      description : A category for CH4 UCX netmod variables

cvars:
    - name        : MPIR_CVAR_CH4_UCX_MAX_VNIS
      category    : CH4_UCX
      type        : int
      default     : 0
      class       : none
      verbosity   : MPI_T_VERBOSITY_USER_BASIC
      scope       : MPI_T_SCOPE_LOCAL
      description : >-
        If set to positive, this CVAR specifies the maximum number of CH4 VNIs
        that UCX netmod exposes. If set to 0 (the default) or bigger than
        MPIR_CVAR_CH4_NUM_VCIS, the number of exposed VNIs is set to MPIR_CVAR_CH4_NUM_VCIS.

=== END_MPI_T_CVAR_INFO_BLOCK ===
*/

34 static void request_init_callback(void *request);
35 
request_init_callback(void * request)36 static void request_init_callback(void *request)
37 {
38 
39     MPIDI_UCX_ucp_request_t *ucp_request = (MPIDI_UCX_ucp_request_t *) request;
40     ucp_request->req = NULL;
41 
42 }
43 
init_num_vnis(void)44 static void init_num_vnis(void)
45 {
46     int num_vnis = 1;
47     if (MPIR_CVAR_CH4_UCX_MAX_VNIS == 0 || MPIR_CVAR_CH4_UCX_MAX_VNIS > MPIDI_global.n_vcis) {
48         num_vnis = MPIDI_global.n_vcis;
49     } else {
50         num_vnis = MPIR_CVAR_CH4_UCX_MAX_VNIS;
51     }
52 
53     /* for best performance, we ensure 1-to-1 vci/vni mapping. ref: MPIDI_OFI_vci_to_vni */
54     /* TODO: allow less num_vnis. Option 1. runtime MOD; 2. overide MPIDI_global.n_vcis */
55     MPIR_Assert(num_vnis == MPIDI_global.n_vcis);
56 
57     MPIDI_UCX_global.num_vnis = num_vnis;
58 }
59 
init_worker(int vni)60 static int init_worker(int vni)
61 {
62     int mpi_errno = MPI_SUCCESS;
63 
64     ucp_worker_params_t worker_params;
65     worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
66     worker_params.thread_mode = UCS_THREAD_MODE_SERIALIZED;
67 
68     ucs_status_t ucx_status;
69     ucx_status = ucp_worker_create(MPIDI_UCX_global.context, &worker_params,
70                                    &MPIDI_UCX_global.ctx[vni].worker);
71     MPIDI_UCX_CHK_STATUS(ucx_status);
72     ucx_status = ucp_worker_get_address(MPIDI_UCX_global.ctx[vni].worker,
73                                         &MPIDI_UCX_global.ctx[vni].if_address,
74                                         &MPIDI_UCX_global.ctx[vni].addrname_len);
75     MPIDI_UCX_CHK_STATUS(ucx_status);
76     MPIR_Assert(MPIDI_UCX_global.ctx[vni].addrname_len <= INT_MAX);
77 
78   fn_exit:
79     return mpi_errno;
80   fn_fail:
81     goto fn_exit;
82 }
83 
/* Exchange vni-0 worker addresses through the business-card (PMI) mechanism
 * and create the initial endpoints on worker 0.
 *
 * With MPIR_CVAR_CH4_ROOTS_ONLY_PMI, only node roots publish via PMI:
 * endpoints to the roots are created first, then the remaining addresses
 * are gathered over those connections with MPIDU_bc_allgather. */
static int initial_address_exchange(MPIR_Comm * init_comm)
{
    int mpi_errno = MPI_SUCCESS;
    ucs_status_t ucx_status;

    void *table;
    int recv_bc_len;
    int size = MPIR_Process.size;
    int rank = MPIR_Process.rank;
    mpi_errno = MPIDU_bc_table_create(rank, size, MPIDI_global.node_map[0],
                                      MPIDI_UCX_global.ctx[0].if_address,
                                      (int) MPIDI_UCX_global.ctx[0].addrname_len, FALSE,
                                      MPIR_CVAR_CH4_ROOTS_ONLY_PMI, &table, &recv_bc_len);
    MPIR_ERR_CHECK(mpi_errno);

    ucp_ep_params_t ep_params;
    if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI) {
        int *node_roots = MPIR_Process.node_root_map;
        int num_nodes = MPIR_Process.num_nodes;
        int *rank_map;

        /* connect to each node root first; their addresses came via PMI */
        for (int i = 0; i < num_nodes; i++) {
            ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
            ep_params.address = (ucp_address_t *) ((char *) table + i * recv_bc_len);
            ucx_status =
                ucp_ep_create(MPIDI_UCX_global.ctx[0].worker, &ep_params,
                              &MPIDI_UCX_AV(&MPIDIU_get_av(0, node_roots[i])).dest[0][0]);
            MPIDI_UCX_CHK_STATUS(ucx_status);
        }
        /* gather everyone's address over the root connections; rank_map
         * marks which table slots hold newly received addresses */
        MPIDU_bc_allgather(init_comm, MPIDI_UCX_global.ctx[0].if_address,
                           (int) MPIDI_UCX_global.ctx[0].addrname_len, FALSE,
                           (void **) &table, &rank_map, &recv_bc_len);

        /* insert new addresses, skipping over node roots */
        for (int i = 0; i < MPIR_Process.size; i++) {
            if (rank_map[i] >= 0) {
                ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
                ep_params.address = (ucp_address_t *) ((char *) table + i * recv_bc_len);
                ucx_status = ucp_ep_create(MPIDI_UCX_global.ctx[0].worker, &ep_params,
                                           &MPIDI_UCX_AV(&MPIDIU_get_av(0, i)).dest[0][0]);
                MPIDI_UCX_CHK_STATUS(ucx_status);
            }
        }
        /* NOTE(review): rank_map looks like it is owned by the bc-table
         * machinery — confirm MPIDU_bc_table_destroy releases it */
        MPIDU_bc_table_destroy();
    } else {
        /* full PMI exchange: every rank's address is already in the table */
        for (int i = 0; i < size; i++) {
            ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
            ep_params.address = (ucp_address_t *) ((char *) table + i * recv_bc_len);
            ucx_status =
                ucp_ep_create(MPIDI_UCX_global.ctx[0].worker, &ep_params,
                              &MPIDI_UCX_AV(&MPIDIU_get_av(0, i)).dest[0][0]);
            MPIDI_UCX_CHK_STATUS(ucx_status);
        }
        MPIDU_bc_table_destroy();
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

all_vnis_address_exchange(void)147 static int all_vnis_address_exchange(void)
148 {
149     int mpi_errno = MPI_SUCCESS;
150 
151     int size = MPIR_Process.size;
152     int rank = MPIR_Process.rank;
153     int num_vnis = MPIDI_UCX_global.num_vnis;
154 
155     /* ucx address lengths are non-uniform, use MPID_MAX_BC_SIZE */
156     size_t name_len = MPID_MAX_BC_SIZE;
157 
158     int my_len = num_vnis * name_len;
159     char *all_names = MPL_malloc(size * my_len, MPL_MEM_ADDRESS);
160     MPIR_Assert(all_names);
161 
162     char *my_names = all_names + rank * my_len;
163 
164     /* put in my addrnames */
165     for (int i = 0; i < num_vnis; i++) {
166         char *vni_addrname = my_names + i * name_len;
167         memcpy(vni_addrname, MPIDI_UCX_global.ctx[i].if_address,
168                MPIDI_UCX_global.ctx[i].addrname_len);
169     }
170     /* Allgather */
171     MPIR_Comm *comm = MPIR_Process.comm_world;
172     MPIR_Errflag_t errflag = MPIR_ERR_NONE;
173     mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPI_BYTE,
174                                             all_names, my_len, MPI_BYTE, comm, &errflag);
175     MPIR_ERR_CHECK(mpi_errno);
176 
177     /* insert the addresses */
178     ucp_ep_params_t ep_params;
179     for (int vni_local = 0; vni_local < num_vnis; vni_local++) {
180         for (int r = 0; r < size; r++) {
181             MPIDI_UCX_addr_t *av = &MPIDI_UCX_AV(&MPIDIU_get_av(0, r));
182             for (int vni_remote = 0; vni_remote < num_vnis; vni_remote++) {
183                 if (vni_local == 0 && vni_remote == 0) {
184                     /* don't overwrite existing addr, or bad things will happen */
185                     continue;
186                 }
187                 int idx = r * num_vnis + vni_remote;
188                 ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
189                 ep_params.address = (ucp_address_t *) (all_names + idx * name_len);
190 
191                 ucs_status_t ucx_status;
192                 ucx_status = ucp_ep_create(MPIDI_UCX_global.ctx[vni_local].worker,
193                                            &ep_params, &av->dest[vni_local][vni_remote]);
194                 MPIDI_UCX_CHK_STATUS(ucx_status);
195             }
196         }
197     }
198   fn_exit:
199     MPL_free(all_names);
200     return mpi_errno;
201   fn_fail:
202     goto fn_exit;
203 }
204 
/* Netmod init hook: read the UCX configuration, initialize the UCP context,
 * create the vni-0 worker, and perform the initial address exchange.
 * Workers for vni > 0 are created later in MPIDI_UCX_post_init. */
int MPIDI_UCX_mpi_init_hook(int rank, int size, int appnum, int *tag_bits, MPIR_Comm * init_comm)
{
    int mpi_errno = MPI_SUCCESS;
    ucp_config_t *config;
    ucs_status_t ucx_status;
    uint64_t features = 0;
    ucp_params_t ucp_params;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_UCX_MPI_INIT_HOOK);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_UCX_MPI_INIT_HOOK);

    init_num_vnis();

    /* unable to support extended context id in current match bit configuration */
    MPL_COMPILE_TIME_ASSERT(MPIR_CONTEXT_ID_BITS <= MPIDI_UCX_CONTEXT_TAG_BITS);

    ucx_status = ucp_config_read(NULL, NULL, &config);
    MPIDI_UCX_CHK_STATUS(ucx_status);

    /* For now use only the tag feature */
    features = UCP_FEATURE_TAG | UCP_FEATURE_RMA;
    ucp_params.features = features;
    ucp_params.request_size = sizeof(MPIDI_UCX_ucp_request_t);
    ucp_params.request_init = request_init_callback;
    ucp_params.request_cleanup = NULL;
    ucp_params.estimated_num_eps = size;

    ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES |
        UCP_PARAM_FIELD_REQUEST_SIZE |
        UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_REQUEST_INIT;

    if (MPIDI_UCX_global.num_vnis > 1) {
        /* workers from the same context may be driven by different threads */
        ucp_params.mt_workers_shared = 1;
        ucp_params.field_mask |= UCP_PARAM_FIELD_MT_WORKERS_SHARED;
    }

    ucx_status = ucp_init(&ucp_params, config, &MPIDI_UCX_global.context);
    /* release the config before the error check: it is only consumed by
     * ucp_init, and checking first would leak it on the fn_fail path */
    ucp_config_release(config);
    MPIDI_UCX_CHK_STATUS(ucx_status);

    /* initialize worker for vni 0 */
    mpi_errno = init_worker(0);
    MPIR_ERR_CHECK(mpi_errno);

    mpi_errno = initial_address_exchange(init_comm);
    MPIR_ERR_CHECK(mpi_errno);

    *tag_bits = MPIR_TAG_BITS_DEFAULT;

  fn_exit:
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_UCX_MPI_INIT_HOOK);
    return mpi_errno;
  fn_fail:
    /* tear down whatever was created before the failure */
    if (MPIDI_UCX_global.ctx[0].worker != NULL)
        ucp_worker_destroy(MPIDI_UCX_global.ctx[0].worker);

    if (MPIDI_UCX_global.context != NULL)
        ucp_cleanup(MPIDI_UCX_global.context);

    goto fn_exit;
}

MPIDI_UCX_mpi_finalize_hook(void)267 int MPIDI_UCX_mpi_finalize_hook(void)
268 {
269     int mpi_errno = MPI_SUCCESS;
270     MPIR_Comm *comm;
271     ucs_status_ptr_t ucp_request;
272     ucs_status_ptr_t *pending;
273 
274     comm = MPIR_Process.comm_world;
275     int n = MPIDI_UCX_global.num_vnis;
276     pending = MPL_malloc(sizeof(ucs_status_ptr_t) * comm->local_size * n * n, MPL_MEM_OTHER);
277 
278     int p = 0;
279     for (int i = 0; i < comm->local_size; i++) {
280         MPIDI_UCX_addr_t *av = &MPIDI_UCX_AV(&MPIDIU_get_av(0, i));
281         for (int vni_local = 0; vni_local < MPIDI_UCX_global.num_vnis; vni_local++) {
282             for (int vni_remote = 0; vni_remote < MPIDI_UCX_global.num_vnis; vni_remote++) {
283                 ucp_request = ucp_disconnect_nb(av->dest[vni_local][vni_remote]);
284                 MPIDI_UCX_CHK_REQUEST(ucp_request);
285                 if (ucp_request != UCS_OK) {
286                     pending[p] = ucp_request;
287                     p++;
288                 }
289             }
290         }
291     }
292 
293     /* now complete the outstaning requests! Important: call progress inbetween, otherwise we
294      * deadlock! */
295     int completed;
296     do {
297         for (int i = 0; i < MPIDI_UCX_global.num_vnis; i++) {
298             ucp_worker_progress(MPIDI_UCX_global.ctx[i].worker);
299         }
300         completed = p;
301         for (int i = 0; i < p; i++) {
302             if (ucp_request_is_completed(pending[i]) != 0)
303                 completed -= 1;
304         }
305     } while (completed != 0);
306 
307     for (int i = 0; i < p; i++) {
308         ucp_request_release(pending[i]);
309     }
310 
311     mpi_errno = MPIR_pmi_barrier();
312     MPIR_ERR_CHECK(mpi_errno);
313 
314     for (int i = 0; i < MPIDI_UCX_global.num_vnis; i++) {
315         if (MPIDI_UCX_global.ctx[i].worker != NULL) {
316             ucp_worker_destroy(MPIDI_UCX_global.ctx[i].worker);
317         }
318     }
319 
320     if (MPIDI_UCX_global.context != NULL)
321         ucp_cleanup(MPIDI_UCX_global.context);
322 
323   fn_exit:
324     MPL_free(pending);
325     return mpi_errno;
326   fn_fail:
327     goto fn_exit;
328 
329 }
330 
331 /* static functions for MPIDI_UCX_post_init */
flush_cb(void * request,ucs_status_t status)332 static void flush_cb(void *request, ucs_status_t status)
333 {
334 }
335 
flush_all(void)336 static void flush_all(void)
337 {
338     void *reqs[MPIDI_CH4_MAX_VCIS];
339     for (int vni = 0; vni < MPIDI_UCX_global.num_vnis; vni++) {
340         reqs[vni] = ucp_worker_flush_nb(MPIDI_UCX_global.ctx[vni].worker, 0, &flush_cb);
341     }
342     for (int vni = 0; vni < MPIDI_UCX_global.num_vnis; vni++) {
343         if (reqs[vni] == NULL) {
344             continue;
345         } else if (UCS_PTR_IS_ERR(reqs[vni])) {
346             continue;
347         } else {
348             ucs_status_t status;
349             do {
350                 MPID_Progress_test(NULL);
351                 status = ucp_request_check_status(reqs[vni]);
352             } while (status == UCS_INPROGRESS);
353             ucp_request_release(reqs[vni]);
354         }
355     }
356 }
357 
MPIDI_UCX_post_init(void)358 int MPIDI_UCX_post_init(void)
359 {
360     int mpi_errno = MPI_SUCCESS;
361 
362     for (int i = 1; i < MPIDI_UCX_global.num_vnis; i++) {
363         mpi_errno = init_worker(i);
364         MPIR_ERR_CHECK(mpi_errno);
365     }
366     mpi_errno = all_vnis_address_exchange();
367     MPIR_ERR_CHECK(mpi_errno);
368 
369     /* flush all pending wireup operations or it may interfere with RMA flush_ops count */
370     flush_all();
371 
372   fn_exit:
373     return mpi_errno;
374   fn_fail:
375     goto fn_exit;
376 }
377 
MPIDI_UCX_get_vci_attr(int vci)378 int MPIDI_UCX_get_vci_attr(int vci)
379 {
380     MPIR_Assert(0 <= vci && vci < 1);
381     return MPIDI_VCI_TX | MPIDI_VCI_RX;
382 }
383 
/* Dynamic-process support (UPIDs) is not implemented in the UCX netmod;
 * reaching this is a bug, hence the hard assert. */
int MPIDI_UCX_get_local_upids(MPIR_Comm * comm, size_t ** local_upid_size, char **local_upids)
{
    MPIR_Assert(0);
    return MPI_SUCCESS;
}

/* Dynamic-process support (UPID translation) is not implemented in the UCX
 * netmod; reaching this is a bug, hence the hard assert. */
int MPIDI_UCX_upids_to_lupids(int size, size_t * remote_upid_size, char *remote_upids,
                              int **remote_lupids)
{
    MPIR_Assert(0);
    return MPI_SUCCESS;
}

/* Intercomm creation from LPIDs: nothing to do at the netmod level for UCX;
 * deliberately a successful no-op. */
int MPIDI_UCX_create_intercomm_from_lpids(MPIR_Comm * newcomm_ptr, int size, const int lpids[])
{
    return MPI_SUCCESS;
}

/* MPI_Free_mem hook: delegate to the generic (MPIDIG) implementation. */
int MPIDI_UCX_mpi_free_mem(void *ptr)
{
    return MPIDIG_mpi_free_mem(ptr);
}

/* MPI_Alloc_mem hook: delegate to the generic (MPIDIG) implementation. */
void *MPIDI_UCX_mpi_alloc_mem(size_t size, MPIR_Info * info_ptr)
{
    return MPIDIG_mpi_alloc_mem(size, info_ptr);
}