/*
 * Copyright (C) Mellanox Technologies Ltd. 2001-2011.  ALL RIGHTS RESERVED.
 * Copyright (c) 2016      The University of Tennessee and The University
 *                         of Tennessee Research Foundation.  All rights
 *                         reserved.
 * $COPYRIGHT$
 *
 * Additional copyrights may follow
 *
 * $HEADER$
 */

#include "pml_ucx.h"

#include "opal/runtime/opal.h"
#include "opal/mca/pmix/pmix.h"
#include "ompi/attribute/attribute.h"
#include "ompi/message/message.h"
#include "ompi/mca/pml/base/pml_base_bsend.h"
#include "pml_ucx_request.h"

#include <inttypes.h>


#define PML_UCX_TRACE_SEND(_msg, _buf, _count, _datatype, _dst, _tag, _mode, _comm, ...) \
    PML_UCX_VERBOSE(8, _msg " buf %p count %zu type '%s' dst %d tag %d mode %s comm %d '%s'", \
                    __VA_ARGS__, \
                    (_buf), (_count), (_datatype)->name, (_dst), (_tag), \
                    mca_pml_ucx_send_mode_name(_mode), (_comm)->c_contextid, \
                    (_comm)->c_name);

#define PML_UCX_TRACE_RECV(_msg, _buf, _count, _datatype, _src, _tag, _comm, ...) \
    PML_UCX_VERBOSE(8, _msg " buf %p count %zu type '%s' src %d tag %d comm %d '%s'", \
                    __VA_ARGS__, \
                    (_buf), (_count), (_datatype)->name, (_src), (_tag), \
                    (_comm)->c_contextid, (_comm)->c_name);

#define PML_UCX_TRACE_PROBE(_msg, _src, _tag, _comm) \
    PML_UCX_VERBOSE(8, _msg " src %d tag %d comm %d '%s'", \
                    (_src), (_tag), (_comm)->c_contextid, (_comm)->c_name);

#define PML_UCX_TRACE_MRECV(_msg, _buf, _count, _datatype, _message) \
    PML_UCX_VERBOSE(8, _msg " buf %p count %zu type '%s' msg *%p=%p (%p)", \
                    (_buf), (_count), (_datatype)->name, (void*)(_message), \
                    (void*)*(_message), (*(_message))->req_ptr);

#define MODEX_KEY "pml-ucx"

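/* PML module entry points and limits. The last two fields of the base table
 * are the maximum context id and maximum tag, derived from the PML_UCX tag
 * bit layout. */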
mca_pml_ucx_module_t ompi_pml_ucx = {
    {
        mca_pml_ucx_add_procs,
        mca_pml_ucx_del_procs,
        mca_pml_ucx_enable,
        NULL,
        mca_pml_ucx_add_comm,
        mca_pml_ucx_del_comm,
        mca_pml_ucx_irecv_init,
        mca_pml_ucx_irecv,
        mca_pml_ucx_recv,
        mca_pml_ucx_isend_init,
        mca_pml_ucx_isend,
        mca_pml_ucx_send,
        mca_pml_ucx_iprobe,
        mca_pml_ucx_probe,
        mca_pml_ucx_start,
        mca_pml_ucx_improbe,
        mca_pml_ucx_mprobe,
        mca_pml_ucx_imrecv,
        mca_pml_ucx_mrecv,
        mca_pml_ucx_dump,
        NULL, /* FT */
        1ul << (PML_UCX_CONTEXT_BITS),
        1ul << (PML_UCX_TAG_BITS - 1),
    },
    NULL,   /* ucp_context */
    NULL    /* ucp_worker */
};

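/* Allocate a UCX request on the stack. The UCX "nbr" calls expect the passed
 * request pointer to have at least ompi_pml_ucx.request_size bytes of usable
 * space *before* it, hence the pointer is advanced past the allocated block. */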
#define PML_UCX_REQ_ALLOCA() \
    ((char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size)


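/* Publish the local UCP worker address through the modex, so remote peers can
 * create endpoints to this process. */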
static int mca_pml_ucx_send_worker_address(void)
{
    ucp_address_t *address;
    ucs_status_t status;
    size_t addrlen;
    int rc;

    status = ucp_worker_get_address(ompi_pml_ucx.ucp_worker, &address, &addrlen);
    if (UCS_OK != status) {
        PML_UCX_ERROR("Failed to get worker address");
        return OMPI_ERROR;
    }

    OPAL_MODEX_SEND(rc, OPAL_PMIX_GLOBAL,
                    &mca_pml_ucx_component.pmlm_version, (void*)address, addrlen);
    if (OMPI_SUCCESS != rc) {
        PML_UCX_ERROR("Open MPI couldn't distribute EP connection details");
        return OMPI_ERROR;
    }

    ucp_worker_release_address(ompi_pml_ucx.ucp_worker, address);

    return OMPI_SUCCESS;
}

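/* Retrieve a peer's UCP worker address from the modex. The returned buffer is
 * allocated by the modex layer and must be freed by the caller. */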
static int mca_pml_ucx_recv_worker_address(ompi_proc_t *proc,
                                           ucp_address_t **address_p,
                                           size_t *addrlen_p)
{
    int ret;

    *address_p = NULL;
    OPAL_MODEX_RECV(ret, &mca_pml_ucx_component.pmlm_version, &proc->super.proc_name,
                    (void**)address_p, addrlen_p);
    if (ret < 0) {
        PML_UCX_ERROR("Failed to receive UCX worker address: %s (%d)",
                      opal_strerror(ret), ret);
    }
    return ret;
}

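/* Component open: read the UCX configuration (under the "MPI" prefix), create
 * a UCP context with tag-matching support, and query the size UCX needs per
 * request so the PML can allocate requests itself. */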
int mca_pml_ucx_open(void)
{
    ucp_context_attr_t attr;
    ucp_params_t params;
    ucp_config_t *config;
    ucs_status_t status;

    PML_UCX_VERBOSE(1, "mca_pml_ucx_open");

    /* Read options */
    status = ucp_config_read("MPI", NULL, &config);
    if (UCS_OK != status) {
        return OMPI_ERROR;
    }

    /* Initialize UCX context */
    params.field_mask        = UCP_PARAM_FIELD_FEATURES |
                               UCP_PARAM_FIELD_REQUEST_SIZE |
                               UCP_PARAM_FIELD_REQUEST_INIT |
                               UCP_PARAM_FIELD_REQUEST_CLEANUP |
                               UCP_PARAM_FIELD_TAG_SENDER_MASK |
                               UCP_PARAM_FIELD_MT_WORKERS_SHARED |
                               UCP_PARAM_FIELD_ESTIMATED_NUM_EPS;
    params.features          = UCP_FEATURE_TAG;
    params.request_size      = sizeof(ompi_request_t);
    params.request_init      = mca_pml_ucx_request_init;
    params.request_cleanup   = mca_pml_ucx_request_cleanup;
    params.tag_sender_mask   = PML_UCX_SPECIFIC_SOURCE_MASK;
    params.mt_workers_shared = 0; /* we do not need MT support for the context,
                                     since it will be protected by the worker */
    params.estimated_num_eps = ompi_proc_world_size();

    status = ucp_init(&params, config, &ompi_pml_ucx.ucp_context);
    ucp_config_release(config);

    if (UCS_OK != status) {
        return OMPI_ERROR;
    }

    /* Query UCX attributes */
    attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
    status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
    if (UCS_OK != status) {
        ucp_cleanup(ompi_pml_ucx.ucp_context);
        ompi_pml_ucx.ucp_context = NULL;
        return OMPI_ERROR;
    }

    ompi_pml_ucx.request_size = attr.request_size;

    return OMPI_SUCCESS;
}

int mca_pml_ucx_close(void)
{
    PML_UCX_VERBOSE(1, "mca_pml_ucx_close");

    if (ompi_pml_ucx.ucp_context != NULL) {
        ucp_cleanup(ompi_pml_ucx.ucp_context);
        ompi_pml_ucx.ucp_context = NULL;
    }
    return OMPI_SUCCESS;
}

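/* Component init: create the UCP worker with a thread mode matching the MPI
 * threading level, verify that UCX granted that mode, publish the worker
 * address, and set up the free lists and the pre-completed send request. */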
int mca_pml_ucx_init(void)
{
    ucp_worker_params_t params;
    ucp_worker_attr_t attr;
    ucs_status_t status;
    int i, rc;

    PML_UCX_VERBOSE(1, "mca_pml_ucx_init");

    /* TODO check MPI thread mode */
    params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
    if (ompi_mpi_thread_multiple) {
        params.thread_mode = UCS_THREAD_MODE_MULTI;
    } else {
        params.thread_mode = UCS_THREAD_MODE_SINGLE;
    }

    status = ucp_worker_create(ompi_pml_ucx.ucp_context, &params,
                               &ompi_pml_ucx.ucp_worker);
    if (UCS_OK != status) {
        PML_UCX_ERROR("Failed to create UCP worker");
        rc = OMPI_ERROR;
        goto err;
    }

    attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE;
    status = ucp_worker_query(ompi_pml_ucx.ucp_worker, &attr);
    if (UCS_OK != status) {
        PML_UCX_ERROR("Failed to query UCP worker thread level");
        rc = OMPI_ERROR;
        goto err_destroy_worker;
    }

    if (ompi_mpi_thread_multiple && (attr.thread_mode != UCS_THREAD_MODE_MULTI)) {
        /* UCX does not support multithreading; disqualify this PML for now */
        /* TODO: we should let OMPI fall back to THREAD_SINGLE mode */
        PML_UCX_ERROR("UCP worker does not support MPI_THREAD_MULTIPLE");
        rc = OMPI_ERR_NOT_SUPPORTED;
        goto err_destroy_worker;
    }

    rc = mca_pml_ucx_send_worker_address();
    if (rc < 0) {
        goto err_destroy_worker;
    }

    ompi_pml_ucx.datatype_attr_keyval = MPI_KEYVAL_INVALID;
    for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) {
        ompi_pml_ucx.predefined_types[i] = PML_UCX_DATATYPE_INVALID;
    }

    /* Initialize the free lists */
    OBJ_CONSTRUCT(&ompi_pml_ucx.persistent_reqs, mca_pml_ucx_freelist_t);
    OBJ_CONSTRUCT(&ompi_pml_ucx.convs, mca_pml_ucx_freelist_t);

    /* Create a completed request to be returned from isend */
    OBJ_CONSTRUCT(&ompi_pml_ucx.completed_send_req, ompi_request_t);
    mca_pml_ucx_completed_request_init(&ompi_pml_ucx.completed_send_req);

    opal_progress_register(mca_pml_ucx_progress);

    PML_UCX_VERBOSE(2, "created ucp context %p, worker %p",
                    (void *)ompi_pml_ucx.ucp_context,
                    (void *)ompi_pml_ucx.ucp_worker);
    return rc;

err_destroy_worker:
    ucp_worker_destroy(ompi_pml_ucx.ucp_worker);
    ompi_pml_ucx.ucp_worker = NULL;
err:
    return OMPI_ERROR;
}

int mca_pml_ucx_cleanup(void)
{
    int i;

    PML_UCX_VERBOSE(1, "mca_pml_ucx_cleanup");

    opal_progress_unregister(mca_pml_ucx_progress);

    if (ompi_pml_ucx.datatype_attr_keyval != MPI_KEYVAL_INVALID) {
        ompi_attr_free_keyval(TYPE_ATTR, &ompi_pml_ucx.datatype_attr_keyval, false);
    }

    for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) {
        if (ompi_pml_ucx.predefined_types[i] != PML_UCX_DATATYPE_INVALID) {
            ucp_dt_destroy(ompi_pml_ucx.predefined_types[i]);
            ompi_pml_ucx.predefined_types[i] = PML_UCX_DATATYPE_INVALID;
        }
    }

    ompi_pml_ucx.completed_send_req.req_state = OMPI_REQUEST_INVALID;
    OMPI_REQUEST_FINI(&ompi_pml_ucx.completed_send_req);
    OBJ_DESTRUCT(&ompi_pml_ucx.completed_send_req);

    OBJ_DESTRUCT(&ompi_pml_ucx.convs);
    OBJ_DESTRUCT(&ompi_pml_ucx.persistent_reqs);

    if (ompi_pml_ucx.ucp_worker) {
        ucp_worker_destroy(ompi_pml_ucx.ucp_worker);
        ompi_pml_ucx.ucp_worker = NULL;
    }

    return OMPI_SUCCESS;
}

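/* Create a UCP endpoint to the given proc from its modex-published worker
 * address, and cache the endpoint on the proc for later lookups. */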
static ucp_ep_h mca_pml_ucx_add_proc_common(ompi_proc_t *proc)
{
    ucp_ep_params_t ep_params;
    ucp_address_t *address;
    ucs_status_t status;
    size_t addrlen;
    ucp_ep_h ep;
    int ret;

    ret = mca_pml_ucx_recv_worker_address(proc, &address, &addrlen);
    if (ret < 0) {
        return NULL;
    }

    PML_UCX_VERBOSE(2, "connecting to proc. %d", proc->super.proc_name.vpid);

    ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
    ep_params.address    = address;

    status = ucp_ep_create(ompi_pml_ucx.ucp_worker, &ep_params, &ep);
    free(address);
    if (UCS_OK != status) {
        PML_UCX_ERROR("ucp_ep_create(proc=%d) failed: %s",
                      proc->super.proc_name.vpid,
                      ucs_status_string(status));
        return NULL;
    }

    proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
    return ep;
}

static ucp_ep_h mca_pml_ucx_add_proc(ompi_communicator_t *comm, int dst)
{
    ompi_proc_t *proc0     = ompi_comm_peer_lookup(comm, 0);
    ompi_proc_t *proc_peer = ompi_comm_peer_lookup(comm, dst);
    int ret;

    /* Note: mca_pml_base_pml_check_selected() does not use the 3rd argument */
    if (OMPI_SUCCESS != (ret = mca_pml_base_pml_check_selected("ucx",
                                                               &proc0,
                                                               dst))) {
        return NULL;
    }

    return mca_pml_ucx_add_proc_common(proc_peer);
}

int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
{
    ompi_proc_t *proc;
    ucp_ep_h ep;
    size_t i;
    int ret;

    if (OMPI_SUCCESS != (ret = mca_pml_base_pml_check_selected("ucx",
                                                               procs,
                                                               nprocs))) {
        return ret;
    }

    for (i = 0; i < nprocs; ++i) {
        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
        ep   = mca_pml_ucx_add_proc_common(proc);
        if (ep == NULL) {
            return OMPI_ERROR;
        }
    }

    return OMPI_SUCCESS;
}

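/* Look up the UCP endpoint for a peer rank, creating it on demand for peers
 * (e.g. on communicators built after startup) whose endpoints were not
 * established by mca_pml_ucx_add_procs(). */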
static inline ucp_ep_h mca_pml_ucx_get_ep(ompi_communicator_t *comm, int rank)
{
    ucp_ep_h ep;

    ep = ompi_comm_peer_lookup(comm, rank)->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
    if (OPAL_LIKELY(ep != NULL)) {
        return ep;
    }

    ep = mca_pml_ucx_add_proc(comm, rank);
    if (OPAL_LIKELY(ep != NULL)) {
        return ep;
    }

    if (rank >= ompi_comm_size(comm)) {
        PML_UCX_ERROR("Rank number (%d) is larger than communicator size (%d)",
                      rank, ompi_comm_size(comm));
    } else {
        PML_UCX_ERROR("Failed to resolve UCX endpoint for rank %d", rank);
    }

    return NULL;
}

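/* Progress the UCX worker until all pending disconnect requests complete,
 * then release them and reset the count. */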
static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
{
    ucs_status_t status;
    size_t i;

    PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", (int)*count_p);
    for (i = 0; i < *count_p; ++i) {
        do {
            opal_progress();
            status = ucp_request_test(reqs[i], NULL);
        } while (status == UCS_INPROGRESS);
        if (status != UCS_OK) {
            PML_UCX_ERROR("disconnect request failed: %s",
                          ucs_status_string(status));
        }
        ucp_request_free(reqs[i]);
        reqs[i] = NULL;
    }

    *count_p = 0;
}

static void mca_pml_fence_complete_cb(int status, void *fenced)
{
    *(int*)fenced = 1;
}

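/* Disconnect from all procs. Each rank starts iterating at its own vpid so
 * that disconnects are staggered across peers, and at most num_disconnect
 * requests are kept in flight at a time. The PMIx fence at the end ensures no
 * peer is still using our endpoints when we return. */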
int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
{
    int fenced = 0;
    ompi_proc_t *proc;
    size_t num_reqs, max_reqs;
    void *dreq, **dreqs;
    ucp_ep_h ep;
    size_t i;

    max_reqs = ompi_pml_ucx.num_disconnect;
    if (max_reqs > nprocs) {
        max_reqs = nprocs;
    }

    dreqs = malloc(sizeof(*dreqs) * max_reqs);
    if (dreqs == NULL) {
        return OMPI_ERR_OUT_OF_RESOURCE;
    }

    num_reqs = 0;

    for (i = 0; i < nprocs; ++i) {
        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
        ep   = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
        if (ep == NULL) {
            continue;
        }

        PML_UCX_VERBOSE(2, "disconnecting from rank %d", proc->super.proc_name.vpid);
        dreq = ucp_disconnect_nb(ep);
        if (dreq != NULL) {
            if (UCS_PTR_IS_ERR(dreq)) {
                PML_UCX_ERROR("ucp_disconnect_nb(%d) failed: %s",
                              proc->super.proc_name.vpid,
                              ucs_status_string(UCS_PTR_STATUS(dreq)));
            } else {
                dreqs[num_reqs++] = dreq;
            }
        }

        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;

        if ((int)num_reqs >= ompi_pml_ucx.num_disconnect) {
            mca_pml_ucx_waitall(dreqs, &num_reqs);
        }
    }

    mca_pml_ucx_waitall(dreqs, &num_reqs);
    free(dreqs);

    opal_pmix.fence_nb(NULL, 0, mca_pml_fence_complete_cb, &fenced);
    while (!fenced) {
        ucp_worker_progress(ompi_pml_ucx.ucp_worker);
    }

    return OMPI_SUCCESS;
}

int mca_pml_ucx_enable(bool enable)
{
    ompi_attribute_fn_ptr_union_t copy_fn;
    ompi_attribute_fn_ptr_union_t del_fn;
    int ret;

    /* Create a key for adding custom attributes to datatypes */
    copy_fn.attr_datatype_copy_fn =
        (MPI_Type_internal_copy_attr_function*)MPI_TYPE_NULL_COPY_FN;
    del_fn.attr_datatype_delete_fn = mca_pml_ucx_datatype_attr_del_fn;
    ret = ompi_attr_create_keyval(TYPE_ATTR, copy_fn, del_fn,
                                  &ompi_pml_ucx.datatype_attr_keyval, NULL, 0,
                                  NULL);
    if (ret != OMPI_SUCCESS) {
        PML_UCX_ERROR("Failed to create keyval for UCX datatypes: %d", ret);
        return ret;
    }

    PML_UCX_FREELIST_INIT(&ompi_pml_ucx.persistent_reqs,
                          mca_pml_ucx_persistent_request_t,
                          128, -1, 128);
    PML_UCX_FREELIST_INIT(&ompi_pml_ucx.convs,
                          mca_pml_ucx_convertor_t,
                          128, -1, 128);
    return OMPI_SUCCESS;
}

int mca_pml_ucx_progress(void)
{
    ucp_worker_progress(ompi_pml_ucx.ucp_worker);
    return OMPI_SUCCESS;
}

int mca_pml_ucx_add_comm(struct ompi_communicator_t* comm)
{
    return OMPI_SUCCESS;
}

int mca_pml_ucx_del_comm(struct ompi_communicator_t* comm)
{
    return OMPI_SUCCESS;
}

int mca_pml_ucx_irecv_init(void *buf, size_t count, ompi_datatype_t *datatype,
                           int src, int tag, struct ompi_communicator_t* comm,
                           struct ompi_request_t **request)
{
    mca_pml_ucx_persistent_request_t *req;

    req = (mca_pml_ucx_persistent_request_t *)PML_UCX_FREELIST_GET(&ompi_pml_ucx.persistent_reqs);
    if (req == NULL) {
        return OMPI_ERR_OUT_OF_RESOURCE;
    }

    PML_UCX_TRACE_RECV("irecv_init request *%p=%p", buf, count, datatype, src,
                       tag, comm, (void*)request, (void*)req);

    req->ompi.req_state = OMPI_REQUEST_INACTIVE;
    req->flags          = 0;
    req->buffer         = buf;
    req->count          = count;
    req->datatype       = mca_pml_ucx_get_datatype(datatype);

    PML_UCX_MAKE_RECV_TAG(req->tag, req->recv.tag_mask, tag, src, comm);

    *request = &req->ompi;
    return OMPI_SUCCESS;
}

int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
                      int src, int tag, struct ompi_communicator_t* comm,
                      struct ompi_request_t **request)
{
    ucp_tag_t ucp_tag, ucp_tag_mask;
    ompi_request_t *req;

    PML_UCX_TRACE_RECV("irecv request *%p", buf, count, datatype, src, tag, comm,
                       (void*)request);

    PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
    req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
                                           mca_pml_ucx_get_datatype(datatype),
                                           ucp_tag, ucp_tag_mask,
                                           mca_pml_ucx_recv_completion);
    if (UCS_PTR_IS_ERR(req)) {
        PML_UCX_ERROR("ucx recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
        return OMPI_ERROR;
    }

    PML_UCX_VERBOSE(8, "got request %p", (void*)req);
    *request = req;
    return OMPI_SUCCESS;
}

int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src,
                     int tag, struct ompi_communicator_t* comm,
                     ompi_status_public_t* mpi_status)
{
    ucp_tag_t ucp_tag, ucp_tag_mask;
    ucp_tag_recv_info_t info;
    ucs_status_t status;
    void *req;

    PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");

    /* Post the receive with an on-stack request, then poll it to completion */
    PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
    req    = PML_UCX_REQ_ALLOCA();
    status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
                              mca_pml_ucx_get_datatype(datatype),
                              ucp_tag, ucp_tag_mask, req);

    for (;;) {
        status = ucp_request_test(req, &info);
        if (status != UCS_INPROGRESS) {
            mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info);
            return OMPI_SUCCESS;
        }
        opal_progress();
    }
}

static inline const char *mca_pml_ucx_send_mode_name(mca_pml_base_send_mode_t mode)
{
    switch (mode) {
    case MCA_PML_BASE_SEND_SYNCHRONOUS:
        return "sync";
    case MCA_PML_BASE_SEND_COMPLETE:
        return "complete";
    case MCA_PML_BASE_SEND_BUFFERED:
        return "buffered";
    case MCA_PML_BASE_SEND_READY:
        return "ready";
    case MCA_PML_BASE_SEND_STANDARD:
        return "standard";
    case MCA_PML_BASE_SEND_SIZE:
        return "size";
    default:
        return "unknown";
    }
}

int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datatype,
                           int dst, int tag, mca_pml_base_send_mode_t mode,
                           struct ompi_communicator_t* comm,
                           struct ompi_request_t **request)
{
    mca_pml_ucx_persistent_request_t *req;
    ucp_ep_h ep;

    /* Resolve the endpoint before taking a request from the free list, so a
     * failed lookup does not leak a free-list element */
    ep = mca_pml_ucx_get_ep(comm, dst);
    if (OPAL_UNLIKELY(NULL == ep)) {
        return OMPI_ERROR;
    }

    req = (mca_pml_ucx_persistent_request_t *)PML_UCX_FREELIST_GET(&ompi_pml_ucx.persistent_reqs);
    if (req == NULL) {
        return OMPI_ERR_OUT_OF_RESOURCE;
    }

    PML_UCX_TRACE_SEND("isend_init request *%p=%p", buf, count, datatype, dst,
                       tag, mode, comm, (void*)request, (void*)req)

    req->ompi.req_state = OMPI_REQUEST_INACTIVE;
    req->flags          = MCA_PML_UCX_REQUEST_FLAG_SEND;
    req->buffer         = (void *)buf;
    req->count          = count;
    req->tag            = PML_UCX_MAKE_SEND_TAG(tag, comm);
    req->send.mode      = mode;
    req->send.ep        = ep;
    if (MCA_PML_BASE_SEND_BUFFERED == mode) {
        req->ompi_datatype = datatype;
        OBJ_RETAIN(datatype);
    } else {
        req->datatype = mca_pml_ucx_get_datatype(datatype);
    }

    *request = &req->ompi;
    return OMPI_SUCCESS;
}

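/* Buffered send: pack the user data into a buffer taken from the attached
 * bsend buffer using an OPAL convertor, then send the packed copy as a
 * contiguous byte stream. The packed buffer is released when the send
 * completes in place, or later by the completion callback, which receives it
 * via req_complete_cb_data. */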
static ucs_status_ptr_t
mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
                  ompi_datatype_t *datatype, uint64_t pml_tag)
{
    ompi_request_t *req;
    void *packed_data;
    size_t packed_length;
    size_t offset;
    uint32_t iov_count;
    struct iovec iov;
    opal_convertor_t opal_conv;

    OBJ_CONSTRUCT(&opal_conv, opal_convertor_t);
    opal_convertor_copy_and_prepare_for_send(ompi_proc_local_proc->super.proc_convertor,
                                             &datatype->super, count, buf, 0,
                                             &opal_conv);
    opal_convertor_get_packed_size(&opal_conv, &packed_length);

    packed_data = mca_pml_base_bsend_request_alloc_buf(packed_length);
    if (OPAL_UNLIKELY(NULL == packed_data)) {
        OBJ_DESTRUCT(&opal_conv);
        PML_UCX_ERROR("bsend: failed to allocate buffer");
        return UCS_STATUS_PTR(OMPI_ERROR);
    }

    iov_count    = 1;
    iov.iov_base = packed_data;
    iov.iov_len  = packed_length;

    PML_UCX_VERBOSE(8, "bsend of packed buffer %p len %zu", packed_data, packed_length);
    offset = 0;
    opal_convertor_set_position(&opal_conv, &offset);
    if (0 > opal_convertor_pack(&opal_conv, &iov, &iov_count, &packed_length)) {
        mca_pml_base_bsend_request_free(packed_data);
        OBJ_DESTRUCT(&opal_conv);
        PML_UCX_ERROR("bsend: failed to pack user datatype");
        return UCS_STATUS_PTR(OMPI_ERROR);
    }

    OBJ_DESTRUCT(&opal_conv);

    req = (ompi_request_t*)ucp_tag_send_nb(ep, packed_data, packed_length,
                                           ucp_dt_make_contig(1), pml_tag,
                                           mca_pml_ucx_bsend_completion);
    if (NULL == req) {
        /* request was completed in place */
        mca_pml_base_bsend_request_free(packed_data);
        return NULL;
    }

    if (OPAL_UNLIKELY(UCS_PTR_IS_ERR(req))) {
        mca_pml_base_bsend_request_free(packed_data);
        PML_UCX_ERROR("ucx bsend failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
        return UCS_STATUS_PTR(OMPI_ERROR);
    }

    req->req_complete_cb_data = packed_data;
    return NULL;
}

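/* Dispatch a send according to its MPI send mode: buffered sends go through
 * the bsend packing path, synchronous sends use ucp_tag_send_sync_nb(), and
 * all other modes use the standard ucp_tag_send_nb(). */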
static inline ucs_status_ptr_t mca_pml_ucx_common_send(ucp_ep_h ep, const void *buf,
                                                       size_t count,
                                                       ompi_datatype_t *datatype,
                                                       ucp_datatype_t ucx_datatype,
                                                       ucp_tag_t tag,
                                                       mca_pml_base_send_mode_t mode,
                                                       ucp_send_callback_t cb)
{
    if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
        return mca_pml_ucx_bsend(ep, buf, count, datatype, tag);
    } else if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_SYNCHRONOUS == mode)) {
        return ucp_tag_send_sync_nb(ep, buf, count, ucx_datatype, tag, cb);
    } else {
        return ucp_tag_send_nb(ep, buf, count, ucx_datatype, tag, cb);
    }
}

int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
                      int dst, int tag, mca_pml_base_send_mode_t mode,
                      struct ompi_communicator_t* comm,
                      struct ompi_request_t **request)
{
    ompi_request_t *req;
    ucp_ep_h ep;

    PML_UCX_TRACE_SEND("i%ssend request *%p",
                       buf, count, datatype, dst, tag, mode, comm,
                       mode == MCA_PML_BASE_SEND_BUFFERED ? "b" : "",
                       (void*)request)

    ep = mca_pml_ucx_get_ep(comm, dst);
    if (OPAL_UNLIKELY(NULL == ep)) {
        return OMPI_ERROR;
    }

    req = (ompi_request_t*)mca_pml_ucx_common_send(ep, buf, count, datatype,
                                                   mca_pml_ucx_get_datatype(datatype),
                                                   PML_UCX_MAKE_SEND_TAG(tag, comm), mode,
                                                   mca_pml_ucx_send_completion);

    if (req == NULL) {
        PML_UCX_VERBOSE(8, "returning completed request");
        *request = &ompi_pml_ucx.completed_send_req;
        return OMPI_SUCCESS;
    } else if (!UCS_PTR_IS_ERR(req)) {
        PML_UCX_VERBOSE(8, "got request %p", (void*)req);
        *request = req;
        return OMPI_SUCCESS;
    } else {
        PML_UCX_ERROR("ucx send failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
        return OMPI_ERROR;
    }
}

static inline __opal_attribute_always_inline__ int
mca_pml_ucx_send_nb(ucp_ep_h ep, const void *buf, size_t count,
                    ompi_datatype_t *datatype, ucp_datatype_t ucx_datatype,
                    ucp_tag_t tag, mca_pml_base_send_mode_t mode,
                    ucp_send_callback_t cb)
{
    ompi_request_t *req;

    req = (ompi_request_t*)mca_pml_ucx_common_send(ep, buf, count, datatype,
                                                   mca_pml_ucx_get_datatype(datatype),
                                                   tag, mode,
                                                   mca_pml_ucx_send_completion);

    if (OPAL_LIKELY(req == NULL)) {
        return OMPI_SUCCESS;
    } else if (!UCS_PTR_IS_ERR(req)) {
        PML_UCX_VERBOSE(8, "got request %p", (void*)req);
        ucp_worker_progress(ompi_pml_ucx.ucp_worker);
        ompi_request_wait(&req, MPI_STATUS_IGNORE);
        return OMPI_SUCCESS;
    } else {
        PML_UCX_ERROR("ucx send failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
        return OMPI_ERROR;
    }
}

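/* Blocking-send fast path: when UCX provides ucp_tag_send_nbr(), post the
 * send with an on-stack request and poll it to completion, avoiding the
 * heap-managed UCX request and the ompi_request_wait() machinery. */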
#if HAVE_DECL_UCP_TAG_SEND_NBR
static inline __opal_attribute_always_inline__ int
mca_pml_ucx_send_nbr(ucp_ep_h ep, const void *buf, size_t count,
                     ucp_datatype_t ucx_datatype, ucp_tag_t tag)
{
    void *req;
    ucs_status_t status;

    req    = PML_UCX_REQ_ALLOCA();
    status = ucp_tag_send_nbr(ep, buf, count, ucx_datatype, tag, req);
    if (OPAL_LIKELY(status == UCS_OK)) {
        return OMPI_SUCCESS;
    }

    ucp_worker_progress(ompi_pml_ucx.ucp_worker);
    while ((status = ucp_request_check_status(req)) == UCS_INPROGRESS) {
        opal_progress();
    }

    return OPAL_LIKELY(UCS_OK == status) ? OMPI_SUCCESS : OMPI_ERROR;
}
#endif

int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, int dst,
                     int tag, mca_pml_base_send_mode_t mode,
                     struct ompi_communicator_t* comm)
{
    ucp_ep_h ep;

    PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm,
                       mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send");

    ep = mca_pml_ucx_get_ep(comm, dst);
    if (OPAL_UNLIKELY(NULL == ep)) {
        return OMPI_ERROR;
    }

#if HAVE_DECL_UCP_TAG_SEND_NBR
    if (OPAL_LIKELY((MCA_PML_BASE_SEND_BUFFERED != mode) &&
                    (MCA_PML_BASE_SEND_SYNCHRONOUS != mode))) {
        return mca_pml_ucx_send_nbr(ep, buf, count,
                                    mca_pml_ucx_get_datatype(datatype),
                                    PML_UCX_MAKE_SEND_TAG(tag, comm));
    }
#endif

    return mca_pml_ucx_send_nb(ep, buf, count, datatype,
                               mca_pml_ucx_get_datatype(datatype),
                               PML_UCX_MAKE_SEND_TAG(tag, comm), mode,
                               mca_pml_ucx_send_completion);
}

int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
                       int *matched, ompi_status_public_t* mpi_status)
{
    ucp_tag_t ucp_tag, ucp_tag_mask;
    ucp_tag_recv_info_t info;
    ucp_tag_message_h ucp_msg;

    PML_UCX_TRACE_PROBE("iprobe", src, tag, comm);

    PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
    ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag, ucp_tag_mask,
                               0, &info);
    if (ucp_msg != NULL) {
        *matched = 1;
        mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
    } else {
        opal_progress();
        *matched = 0;
    }
    return OMPI_SUCCESS;
}

int mca_pml_ucx_probe(int src, int tag, struct ompi_communicator_t* comm,
                      ompi_status_public_t* mpi_status)
{
    ucp_tag_t ucp_tag, ucp_tag_mask;
    ucp_tag_recv_info_t info;
    ucp_tag_message_h ucp_msg;

    PML_UCX_TRACE_PROBE("probe", src, tag, comm);

    PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
    for (;;) {
        ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag, ucp_tag_mask,
                                   0, &info);
        if (ucp_msg != NULL) {
            mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
            return OMPI_SUCCESS;
        }

        opal_progress();
    }
}

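/* The matched-probe variants pass 1 for the "remove" argument of
 * ucp_tag_probe_nb(), so a matched message is dequeued and handed back as a
 * message handle to be received later with ucp_tag_msg_recv_nb(). */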
int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
                        int *matched, struct ompi_message_t **message,
                        ompi_status_public_t* mpi_status)
{
    ucp_tag_t ucp_tag, ucp_tag_mask;
    ucp_tag_recv_info_t info;
    ucp_tag_message_h ucp_msg;

    PML_UCX_TRACE_PROBE("improbe", src, tag, comm);

    PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
    ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag, ucp_tag_mask,
                               1, &info);
    if (ucp_msg != NULL) {
        PML_UCX_MESSAGE_NEW(comm, ucp_msg, &info, message);
        PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
        *matched = 1;
        mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
    } else {
        opal_progress();
        *matched = 0;
    }
    return OMPI_SUCCESS;
}

int mca_pml_ucx_mprobe(int src, int tag, struct ompi_communicator_t* comm,
                       struct ompi_message_t **message,
                       ompi_status_public_t* mpi_status)
{
    ucp_tag_t ucp_tag, ucp_tag_mask;
    ucp_tag_recv_info_t info;
    ucp_tag_message_h ucp_msg;

    PML_UCX_TRACE_PROBE("mprobe", src, tag, comm);

    PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
    for (;;) {
        ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag, ucp_tag_mask,
                                   1, &info);
        if (ucp_msg != NULL) {
            PML_UCX_MESSAGE_NEW(comm, ucp_msg, &info, message);
            PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
            mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
            return OMPI_SUCCESS;
        }

        opal_progress();
    }
}

int mca_pml_ucx_imrecv(void *buf, size_t count, ompi_datatype_t *datatype,
                       struct ompi_message_t **message,
                       struct ompi_request_t **request)
{
    ompi_request_t *req;

    PML_UCX_TRACE_MRECV("imrecv", buf, count, datatype, message);

    req = (ompi_request_t*)ucp_tag_msg_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
                                               mca_pml_ucx_get_datatype(datatype),
                                               (*message)->req_ptr,
                                               mca_pml_ucx_recv_completion);
    if (UCS_PTR_IS_ERR(req)) {
        PML_UCX_ERROR("ucx msg recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
        return OMPI_ERROR;
    }

    PML_UCX_VERBOSE(8, "got request %p", (void*)req);
    PML_UCX_MESSAGE_RELEASE(message);
    *request = req;
    return OMPI_SUCCESS;
}

int mca_pml_ucx_mrecv(void *buf, size_t count, ompi_datatype_t *datatype,
                      struct ompi_message_t **message,
                      ompi_status_public_t* status)
{
    ompi_request_t *req;

    PML_UCX_TRACE_MRECV("mrecv", buf, count, datatype, message);

    req = (ompi_request_t*)ucp_tag_msg_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
                                               mca_pml_ucx_get_datatype(datatype),
                                               (*message)->req_ptr,
                                               mca_pml_ucx_recv_completion);
    if (UCS_PTR_IS_ERR(req)) {
        PML_UCX_ERROR("ucx msg recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
        return OMPI_ERROR;
    }

    PML_UCX_MESSAGE_RELEASE(message);

    ompi_request_wait(&req, status);
    return OMPI_SUCCESS;
}

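/* Start persistent requests: each active start posts a fresh UCX operation.
 * If it completes in place the persistent request is completed directly;
 * otherwise the temporary UCX request is linked to the persistent one and
 * completes it from the completion callback. */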
int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
{
    mca_pml_ucx_persistent_request_t *preq;
    ompi_request_t *tmp_req;
    size_t i;

    for (i = 0; i < count; ++i) {
        preq = (mca_pml_ucx_persistent_request_t *)requests[i];

        if ((preq == NULL) || (OMPI_REQUEST_PML != preq->ompi.req_type)) {
            /* Skip irrelevant requests */
            continue;
        }

        PML_UCX_ASSERT(preq->ompi.req_state != OMPI_REQUEST_INVALID);
        preq->ompi.req_state = OMPI_REQUEST_ACTIVE;
        mca_pml_ucx_request_reset(&preq->ompi);

        if (preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND) {
            tmp_req = (ompi_request_t*)mca_pml_ucx_common_send(preq->send.ep,
                                                               preq->buffer,
                                                               preq->count,
                                                               preq->ompi_datatype,
                                                               preq->datatype,
                                                               preq->tag,
                                                               preq->send.mode,
                                                               mca_pml_ucx_psend_completion);
        } else {
            PML_UCX_VERBOSE(8, "start recv request %p", (void*)preq);
            tmp_req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker,
                                                       preq->buffer, preq->count,
                                                       preq->datatype, preq->tag,
                                                       preq->recv.tag_mask,
                                                       mca_pml_ucx_precv_completion);
        }

        if (tmp_req == NULL) {
            /* Only a send can complete immediately */
            PML_UCX_ASSERT(preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND);

            PML_UCX_VERBOSE(8, "send completed immediately, completing persistent request %p",
                            (void*)preq);
            mca_pml_ucx_set_send_status(&preq->ompi.req_status, UCS_OK);
            ompi_request_complete(&preq->ompi, true);
        } else if (!UCS_PTR_IS_ERR(tmp_req)) {
            if (REQUEST_COMPLETE(tmp_req)) {
                /* tmp_req is already completed */
                PML_UCX_VERBOSE(8, "completing persistent request %p", (void*)preq);
                mca_pml_ucx_persistent_request_complete(preq, tmp_req);
            } else {
                /* tmp_req will be completed by the callback, which in turn
                 * triggers completion of preq */
                PML_UCX_VERBOSE(8, "temporary request %p will complete persistent request %p",
                                (void*)tmp_req, (void*)preq);
                tmp_req->req_complete_cb_data = preq;
                preq->tmp_req                 = tmp_req;
            }
        } else {
            PML_UCX_ERROR("ucx %s failed: %s",
                          (preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND) ? "send" : "recv",
                          ucs_status_string(UCS_PTR_STATUS(tmp_req)));
            return OMPI_ERROR;
        }
    }

    return OMPI_SUCCESS;
}

int mca_pml_ucx_dump(struct ompi_communicator_t* comm, int verbose)
{
    return OMPI_SUCCESS;
}