1 /*
2 * Copyright (C) by Argonne National Laboratory
3 * See COPYRIGHT in top-level directory
4 */
5
6 #include "mpidimpl.h"
7 #include "ucx_impl.h"
8 #include "mpidu_bc.h"
9 #include <ucp/api/ucp.h>
10
11 /*
12 === BEGIN_MPI_T_CVAR_INFO_BLOCK ===
13
14 categories :
15 - name : CH4_UCX
16 description : A category for CH4 UCX netmod variables
17
18 cvars:
19 - name : MPIR_CVAR_CH4_UCX_MAX_VNIS
20 category : CH4_UCX
21 type : int
22 default : 0
23 class : none
24 verbosity : MPI_T_VERBOSITY_USER_BASIC
25 scope : MPI_T_SCOPE_LOCAL
26 description : >-
27 If set to positive, this CVAR specifies the maximum number of CH4 VNIs
28 that UCX netmod exposes. If set to 0 (the default) or bigger than
29 MPIR_CVAR_CH4_NUM_VCIS, the number of exposed VNIs is set to MPIR_CVAR_CH4_NUM_VCIS.
30
31 === END_MPI_T_CVAR_INFO_BLOCK ===
32 */
33
34 static void request_init_callback(void *request);
35
request_init_callback(void * request)36 static void request_init_callback(void *request)
37 {
38
39 MPIDI_UCX_ucp_request_t *ucp_request = (MPIDI_UCX_ucp_request_t *) request;
40 ucp_request->req = NULL;
41
42 }
43
init_num_vnis(void)44 static void init_num_vnis(void)
45 {
46 int num_vnis = 1;
47 if (MPIR_CVAR_CH4_UCX_MAX_VNIS == 0 || MPIR_CVAR_CH4_UCX_MAX_VNIS > MPIDI_global.n_vcis) {
48 num_vnis = MPIDI_global.n_vcis;
49 } else {
50 num_vnis = MPIR_CVAR_CH4_UCX_MAX_VNIS;
51 }
52
53 /* for best performance, we ensure 1-to-1 vci/vni mapping. ref: MPIDI_OFI_vci_to_vni */
54 /* TODO: allow less num_vnis. Option 1. runtime MOD; 2. overide MPIDI_global.n_vcis */
55 MPIR_Assert(num_vnis == MPIDI_global.n_vcis);
56
57 MPIDI_UCX_global.num_vnis = num_vnis;
58 }
59
init_worker(int vni)60 static int init_worker(int vni)
61 {
62 int mpi_errno = MPI_SUCCESS;
63
64 ucp_worker_params_t worker_params;
65 worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
66 worker_params.thread_mode = UCS_THREAD_MODE_SERIALIZED;
67
68 ucs_status_t ucx_status;
69 ucx_status = ucp_worker_create(MPIDI_UCX_global.context, &worker_params,
70 &MPIDI_UCX_global.ctx[vni].worker);
71 MPIDI_UCX_CHK_STATUS(ucx_status);
72 ucx_status = ucp_worker_get_address(MPIDI_UCX_global.ctx[vni].worker,
73 &MPIDI_UCX_global.ctx[vni].if_address,
74 &MPIDI_UCX_global.ctx[vni].addrname_len);
75 MPIDI_UCX_CHK_STATUS(ucx_status);
76 MPIR_Assert(MPIDI_UCX_global.ctx[vni].addrname_len <= INT_MAX);
77
78 fn_exit:
79 return mpi_errno;
80 fn_fail:
81 goto fn_exit;
82 }
83
/* Establish the vni-0 endpoints to all peers during bootstrap.
 *
 * The vni-0 worker address is published through the business-card (bc)
 * exchange.  When MPIR_CVAR_CH4_ROOTS_ONLY_PMI is set, only node roots
 * publish via PMI; the remaining addresses are gathered over init_comm in a
 * second step.  Otherwise the bc table already holds every rank's address.
 *
 * Returns MPI_SUCCESS or an MPI error code (via fn_fail from the CHK macros).
 */
static int initial_address_exchange(MPIR_Comm * init_comm)
{
    int mpi_errno = MPI_SUCCESS;
    ucs_status_t ucx_status;

    void *table;
    int recv_bc_len;
    int size = MPIR_Process.size;
    int rank = MPIR_Process.rank;
    /* publish this rank's vni-0 worker address; on return, table holds the
     * gathered addresses at fixed recv_bc_len stride */
    mpi_errno = MPIDU_bc_table_create(rank, size, MPIDI_global.node_map[0],
                                      MPIDI_UCX_global.ctx[0].if_address,
                                      (int) MPIDI_UCX_global.ctx[0].addrname_len, FALSE,
                                      MPIR_CVAR_CH4_ROOTS_ONLY_PMI, &table, &recv_bc_len);
    MPIR_ERR_CHECK(mpi_errno);

    ucp_ep_params_t ep_params;
    if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI) {
        int *node_roots = MPIR_Process.node_root_map;
        int num_nodes = MPIR_Process.num_nodes;
        int *rank_map;

        /* pass 1: wire up endpoints to the node roots (only their addresses
         * are in the PMI-published table) */
        for (int i = 0; i < num_nodes; i++) {
            ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
            ep_params.address = (ucp_address_t *) ((char *) table + i * recv_bc_len);
            ucx_status =
                ucp_ep_create(MPIDI_UCX_global.ctx[0].worker, &ep_params,
                              &MPIDI_UCX_AV(&MPIDIU_get_av(0, node_roots[i])).dest[0][0]);
            MPIDI_UCX_CHK_STATUS(ucx_status);
        }
        /* gather the addresses of all remaining ranks over init_comm;
         * rank_map flags which table slots are new (>= 0) */
        MPIDU_bc_allgather(init_comm, MPIDI_UCX_global.ctx[0].if_address,
                           (int) MPIDI_UCX_global.ctx[0].addrname_len, FALSE,
                           (void **) &table, &rank_map, &recv_bc_len);

        /* insert new addresses, skipping over node roots */
        for (int i = 0; i < MPIR_Process.size; i++) {
            if (rank_map[i] >= 0) {

                ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
                ep_params.address = (ucp_address_t *) ((char *) table + i * recv_bc_len);
                ucx_status = ucp_ep_create(MPIDI_UCX_global.ctx[0].worker, &ep_params,
                                           &MPIDI_UCX_AV(&MPIDIU_get_av(0, i)).dest[0][0]);
                MPIDI_UCX_CHK_STATUS(ucx_status);
            }
        }
        MPIDU_bc_table_destroy();
    } else {
        /* full-table exchange: every rank's address is already present */
        for (int i = 0; i < size; i++) {
            ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
            ep_params.address = (ucp_address_t *) ((char *) table + i * recv_bc_len);
            ucx_status =
                ucp_ep_create(MPIDI_UCX_global.ctx[0].worker, &ep_params,
                              &MPIDI_UCX_AV(&MPIDIU_get_av(0, i)).dest[0][0]);
            MPIDI_UCX_CHK_STATUS(ucx_status);
        }
        MPIDU_bc_table_destroy();
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
146
all_vnis_address_exchange(void)147 static int all_vnis_address_exchange(void)
148 {
149 int mpi_errno = MPI_SUCCESS;
150
151 int size = MPIR_Process.size;
152 int rank = MPIR_Process.rank;
153 int num_vnis = MPIDI_UCX_global.num_vnis;
154
155 /* ucx address lengths are non-uniform, use MPID_MAX_BC_SIZE */
156 size_t name_len = MPID_MAX_BC_SIZE;
157
158 int my_len = num_vnis * name_len;
159 char *all_names = MPL_malloc(size * my_len, MPL_MEM_ADDRESS);
160 MPIR_Assert(all_names);
161
162 char *my_names = all_names + rank * my_len;
163
164 /* put in my addrnames */
165 for (int i = 0; i < num_vnis; i++) {
166 char *vni_addrname = my_names + i * name_len;
167 memcpy(vni_addrname, MPIDI_UCX_global.ctx[i].if_address,
168 MPIDI_UCX_global.ctx[i].addrname_len);
169 }
170 /* Allgather */
171 MPIR_Comm *comm = MPIR_Process.comm_world;
172 MPIR_Errflag_t errflag = MPIR_ERR_NONE;
173 mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPI_BYTE,
174 all_names, my_len, MPI_BYTE, comm, &errflag);
175 MPIR_ERR_CHECK(mpi_errno);
176
177 /* insert the addresses */
178 ucp_ep_params_t ep_params;
179 for (int vni_local = 0; vni_local < num_vnis; vni_local++) {
180 for (int r = 0; r < size; r++) {
181 MPIDI_UCX_addr_t *av = &MPIDI_UCX_AV(&MPIDIU_get_av(0, r));
182 for (int vni_remote = 0; vni_remote < num_vnis; vni_remote++) {
183 if (vni_local == 0 && vni_remote == 0) {
184 /* don't overwrite existing addr, or bad things will happen */
185 continue;
186 }
187 int idx = r * num_vnis + vni_remote;
188 ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
189 ep_params.address = (ucp_address_t *) (all_names + idx * name_len);
190
191 ucs_status_t ucx_status;
192 ucx_status = ucp_ep_create(MPIDI_UCX_global.ctx[vni_local].worker,
193 &ep_params, &av->dest[vni_local][vni_remote]);
194 MPIDI_UCX_CHK_STATUS(ucx_status);
195 }
196 }
197 }
198 fn_exit:
199 MPL_free(all_names);
200 return mpi_errno;
201 fn_fail:
202 goto fn_exit;
203 }
204
/* Netmod init hook: read the UCP configuration, create the UCP context and
 * the vni-0 worker, and run the initial (vni-0) address exchange.
 *
 * Parameters follow the CH4 netmod init-hook convention; appnum is unused
 * here, and *tag_bits is set to the default tag layout on success.
 *
 * Fix: the ucp_config_t from ucp_config_read() was leaked whenever a later
 * step failed before ucp_config_release() was reached; config now starts
 * NULL and is released on the failure path as well. */
int MPIDI_UCX_mpi_init_hook(int rank, int size, int appnum, int *tag_bits, MPIR_Comm * init_comm)
{
    int mpi_errno = MPI_SUCCESS;
    ucp_config_t *config = NULL;        /* NULL until ucp_config_read succeeds */
    ucs_status_t ucx_status;
    uint64_t features = 0;
    ucp_params_t ucp_params;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIDI_UCX_MPI_INIT_HOOK);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIDI_UCX_MPI_INIT_HOOK);

    init_num_vnis();

    /* unable to support extended context id in current match bit configuration */
    MPL_COMPILE_TIME_ASSERT(MPIR_CONTEXT_ID_BITS <= MPIDI_UCX_CONTEXT_TAG_BITS);

    ucx_status = ucp_config_read(NULL, NULL, &config);
    MPIDI_UCX_CHK_STATUS(ucx_status);

    /* For now use only the tag and RMA features */
    features = UCP_FEATURE_TAG | UCP_FEATURE_RMA;
    ucp_params.features = features;
    ucp_params.request_size = sizeof(MPIDI_UCX_ucp_request_t);
    ucp_params.request_init = request_init_callback;
    ucp_params.request_cleanup = NULL;
    ucp_params.estimated_num_eps = size;

    ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES |
        UCP_PARAM_FIELD_REQUEST_SIZE |
        UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_REQUEST_INIT;

    if (MPIDI_UCX_global.num_vnis > 1) {
        /* multiple workers of this context may be driven concurrently */
        ucp_params.mt_workers_shared = 1;
        ucp_params.field_mask |= UCP_PARAM_FIELD_MT_WORKERS_SHARED;
    }

    ucx_status = ucp_init(&ucp_params, config, &MPIDI_UCX_global.context);
    MPIDI_UCX_CHK_STATUS(ucx_status);
    ucp_config_release(config);
    config = NULL;

    /* initialize worker for vni 0 */
    mpi_errno = init_worker(0);
    MPIR_ERR_CHECK(mpi_errno);

    mpi_errno = initial_address_exchange(init_comm);
    MPIR_ERR_CHECK(mpi_errno);

    *tag_bits = MPIR_TAG_BITS_DEFAULT;

  fn_exit:
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIDI_UCX_MPI_INIT_HOOK);
    return mpi_errno;
  fn_fail:
    /* undo whatever was set up before the failure */
    if (config != NULL)
        ucp_config_release(config);

    if (MPIDI_UCX_global.ctx[0].worker != NULL)
        ucp_worker_destroy(MPIDI_UCX_global.ctx[0].worker);

    if (MPIDI_UCX_global.context != NULL)
        ucp_cleanup(MPIDI_UCX_global.context);

    goto fn_exit;
}
266
MPIDI_UCX_mpi_finalize_hook(void)267 int MPIDI_UCX_mpi_finalize_hook(void)
268 {
269 int mpi_errno = MPI_SUCCESS;
270 MPIR_Comm *comm;
271 ucs_status_ptr_t ucp_request;
272 ucs_status_ptr_t *pending;
273
274 comm = MPIR_Process.comm_world;
275 int n = MPIDI_UCX_global.num_vnis;
276 pending = MPL_malloc(sizeof(ucs_status_ptr_t) * comm->local_size * n * n, MPL_MEM_OTHER);
277
278 int p = 0;
279 for (int i = 0; i < comm->local_size; i++) {
280 MPIDI_UCX_addr_t *av = &MPIDI_UCX_AV(&MPIDIU_get_av(0, i));
281 for (int vni_local = 0; vni_local < MPIDI_UCX_global.num_vnis; vni_local++) {
282 for (int vni_remote = 0; vni_remote < MPIDI_UCX_global.num_vnis; vni_remote++) {
283 ucp_request = ucp_disconnect_nb(av->dest[vni_local][vni_remote]);
284 MPIDI_UCX_CHK_REQUEST(ucp_request);
285 if (ucp_request != UCS_OK) {
286 pending[p] = ucp_request;
287 p++;
288 }
289 }
290 }
291 }
292
293 /* now complete the outstaning requests! Important: call progress inbetween, otherwise we
294 * deadlock! */
295 int completed;
296 do {
297 for (int i = 0; i < MPIDI_UCX_global.num_vnis; i++) {
298 ucp_worker_progress(MPIDI_UCX_global.ctx[i].worker);
299 }
300 completed = p;
301 for (int i = 0; i < p; i++) {
302 if (ucp_request_is_completed(pending[i]) != 0)
303 completed -= 1;
304 }
305 } while (completed != 0);
306
307 for (int i = 0; i < p; i++) {
308 ucp_request_release(pending[i]);
309 }
310
311 mpi_errno = MPIR_pmi_barrier();
312 MPIR_ERR_CHECK(mpi_errno);
313
314 for (int i = 0; i < MPIDI_UCX_global.num_vnis; i++) {
315 if (MPIDI_UCX_global.ctx[i].worker != NULL) {
316 ucp_worker_destroy(MPIDI_UCX_global.ctx[i].worker);
317 }
318 }
319
320 if (MPIDI_UCX_global.context != NULL)
321 ucp_cleanup(MPIDI_UCX_global.context);
322
323 fn_exit:
324 MPL_free(pending);
325 return mpi_errno;
326 fn_fail:
327 goto fn_exit;
328
329 }
330
331 /* static functions for MPIDI_UCX_post_init */
/* Completion callback passed to ucp_worker_flush_nb; intentionally empty.
 * flush_all() polls the returned requests with ucp_request_check_status
 * instead of relying on the callback. */
static void flush_cb(void *request, ucs_status_t status)
{
}
335
flush_all(void)336 static void flush_all(void)
337 {
338 void *reqs[MPIDI_CH4_MAX_VCIS];
339 for (int vni = 0; vni < MPIDI_UCX_global.num_vnis; vni++) {
340 reqs[vni] = ucp_worker_flush_nb(MPIDI_UCX_global.ctx[vni].worker, 0, &flush_cb);
341 }
342 for (int vni = 0; vni < MPIDI_UCX_global.num_vnis; vni++) {
343 if (reqs[vni] == NULL) {
344 continue;
345 } else if (UCS_PTR_IS_ERR(reqs[vni])) {
346 continue;
347 } else {
348 ucs_status_t status;
349 do {
350 MPID_Progress_test(NULL);
351 status = ucp_request_check_status(reqs[vni]);
352 } while (status == UCS_INPROGRESS);
353 ucp_request_release(reqs[vni]);
354 }
355 }
356 }
357
MPIDI_UCX_post_init(void)358 int MPIDI_UCX_post_init(void)
359 {
360 int mpi_errno = MPI_SUCCESS;
361
362 for (int i = 1; i < MPIDI_UCX_global.num_vnis; i++) {
363 mpi_errno = init_worker(i);
364 MPIR_ERR_CHECK(mpi_errno);
365 }
366 mpi_errno = all_vnis_address_exchange();
367 MPIR_ERR_CHECK(mpi_errno);
368
369 /* flush all pending wireup operations or it may interfere with RMA flush_ops count */
370 flush_all();
371
372 fn_exit:
373 return mpi_errno;
374 fn_fail:
375 goto fn_exit;
376 }
377
MPIDI_UCX_get_vci_attr(int vci)378 int MPIDI_UCX_get_vci_attr(int vci)
379 {
380 MPIR_Assert(0 <= vci && vci < 1);
381 return MPIDI_VCI_TX | MPIDI_VCI_RX;
382 }
383
/* UPID extraction (dynamic-process support) is not implemented in the UCX
 * netmod; reaching this function aborts via the assertion. */
int MPIDI_UCX_get_local_upids(MPIR_Comm * comm, size_t ** local_upid_size, char **local_upids)
{
    MPIR_Assert(0);
    return MPI_SUCCESS;
}
389
/* UPID-to-LUPID translation (dynamic-process support) is not implemented in
 * the UCX netmod; reaching this function aborts via the assertion. */
int MPIDI_UCX_upids_to_lupids(int size, size_t * remote_upid_size, char *remote_upids,
                              int **remote_lupids)
{
    MPIR_Assert(0);
    return MPI_SUCCESS;
}
396
/* Intentionally a no-op: no netmod-specific work is done here when building
 * an intercommunicator from LPIDs; always reports success. */
int MPIDI_UCX_create_intercomm_from_lpids(MPIR_Comm * newcomm_ptr, int size, const int lpids[])
{
    return MPI_SUCCESS;
}
401
/* MPI_Free_mem hook: the UCX netmod has no special registered-memory path
 * here, so simply delegate to MPIDIG_mpi_free_mem. */
int MPIDI_UCX_mpi_free_mem(void *ptr)
{
    return MPIDIG_mpi_free_mem(ptr);
}
406
/* MPI_Alloc_mem hook: the UCX netmod has no special registered-memory path
 * here, so simply delegate to MPIDIG_mpi_alloc_mem.  Returns the allocated
 * pointer (NULL semantics are those of MPIDIG_mpi_alloc_mem). */
void *MPIDI_UCX_mpi_alloc_mem(size_t size, MPIR_Info * info_ptr)
{
    return MPIDIG_mpi_alloc_mem(size, info_ptr);
}
411