1 /**
2  * Copyright (C) Mellanox Technologies Ltd. 2001-2019.  ALL RIGHTS RESERVED.
3  *
4  * See file LICENSE for terms.
5  */
6 
7 #ifdef HAVE_CONFIG_H
8 #  include "config.h"
9 #endif
10 
11 #include "rndv.h"
12 #include "tag_match.inl"
13 #include "offload.h"
14 
15 #include <ucp/proto/proto_am.inl>
16 #include <ucs/datastruct/queue.h>
17 
/*
 * Decide whether the receive side must stage the rendezvous transfer through
 * host-memory bounce buffers (recv-side pipeline protocol).
 *
 * Returns nonzero when pipelining is needed: there must be at least one
 * host-memory bw lane available for staging, and either a previous GET zcopy
 * attempt already failed, or none of the bw memory domains can register all
 * memory types involved in the transfer.
 */
static int ucp_rndv_is_recv_pipeline_needed(ucp_request_t *rndv_req,
                                            const ucp_rndv_rts_hdr_t *rndv_rts_hdr,
                                            ucs_memory_type_t mem_type,
                                            int is_get_zcopy_failed)
{
    const ucp_ep_config_t *ep_config = ucp_ep_config(rndv_req->send.ep);
    ucp_context_h context            = rndv_req->send.ep->worker->context;
    int host_lane_found              = 0;
    ucp_md_index_t md_index;
    uct_md_attr_t *md_attr;
    uint64_t mem_type_bits;
    int idx;

    /* look for at least one bw lane whose MD accesses host memory */
    for (idx = 0; idx < UCP_MAX_LANES; idx++) {
        if (ep_config->key.rma_bw_lanes[idx] == UCP_NULL_LANE) {
            break;
        }

        md_index = ep_config->md_index[ep_config->key.rma_bw_lanes[idx]];
        if (context->tl_mds[md_index].attr.cap.access_mem_type ==
            UCS_MEMORY_TYPE_HOST) {
            host_lane_found = 1;
            break;
        }
    }

    /* no host bw lanes for pipeline staging */
    if (!host_lane_found) {
        return 0;
    }

    /* after a failed get_zcopy, the pipeline is the remaining option */
    if (is_get_zcopy_failed) {
        return 1;
    }

    /* disqualify recv side pipeline if
     * a mem_type bw lane exist AND
     * lane can do RMA on remote mem_type
     */
    mem_type_bits = UCS_BIT(mem_type);
    if (rndv_rts_hdr->address) {
        mem_type_bits |= UCS_BIT(ucp_rkey_packed_mem_type(rndv_rts_hdr + 1));
    }

    ucs_for_each_bit(md_index, ep_config->key.rma_bw_md_map) {
        md_attr = &context->tl_mds[md_index].attr;
        if (ucs_test_all_flags(md_attr->cap.reg_mem_types, mem_type_bits)) {
            return 0;
        }
    }

    return 1;
}
68 
/*
 * Decide whether the sender must fall back to the PUT pipeline protocol.
 * GET zcopy is unusable when the remote address is not exposed, GET zcopy is
 * disabled (max size 0), the message is shorter than the GET zcopy minimum,
 * or a previous GET zcopy attempt already failed.
 */
static int ucp_rndv_is_put_pipeline_needed(uintptr_t remote_address,
                                           size_t length, size_t min_get_zcopy,
                                           size_t max_get_zcopy,
                                           int is_get_zcopy_failed)
{
    if (is_get_zcopy_failed) {
        return 1;
    }

    /* fallback to PUT pipeline if remote mem type is non-HOST memory OR
     * can't do GET ZCOPY */
    return (remote_address == 0) || (max_get_zcopy == 0) ||
           (length < min_get_zcopy);
}
79 
/*
 * Pack callback for the rendezvous Request-To-Send (RTS) active message.
 *
 * Fills the RTS header from the send request; when the send buffer is
 * contiguous and GET zcopy is enabled, also packs the remote keys right after
 * the header so the target can fetch the data itself. Returns the total
 * number of bytes packed into 'dest'.
 */
size_t ucp_tag_rndv_rts_pack(void *dest, void *arg)
{
    ucp_request_t *sreq              = arg;   /* send request */
    ucp_rndv_rts_hdr_t *rndv_rts_hdr = dest;
    ucp_worker_h worker              = sreq->send.ep->worker;
    ssize_t packed_rkey_size;

    rndv_rts_hdr->super.tag        = sreq->send.msg_proto.tag.tag;
    rndv_rts_hdr->sreq.reqptr      = (uintptr_t)sreq;
    rndv_rts_hdr->sreq.ep_ptr      = ucp_request_get_dest_ep_ptr(sreq);
    rndv_rts_hdr->size             = sreq->send.length;

    /* Pack remote keys (which can be empty list) */
    if (UCP_DT_IS_CONTIG(sreq->send.datatype) &&
        ucp_rndv_is_get_zcopy(sreq, worker->context)) {
        /* pack rkey, ask target to do get_zcopy */
        rndv_rts_hdr->address = (uintptr_t)sreq->send.buffer;
        packed_rkey_size = ucp_rkey_pack_uct(worker->context,
                                             sreq->send.state.dt.dt.contig.md_map,
                                             sreq->send.state.dt.dt.contig.memh,
                                             sreq->send.mem_type,
                                             rndv_rts_hdr + 1);
        if (packed_rkey_size < 0) {
            /* cannot report an error through a size_t pack callback - abort */
            ucs_fatal("failed to pack rendezvous remote key: %s",
                      ucs_status_string((ucs_status_t)packed_rkey_size));
        }

        ucs_assert(packed_rkey_size <=
                   ucp_ep_config(sreq->send.ep)->tag.rndv.rkey_size);
    } else {
        /* zero address tells the target there is no rkey appended */
        rndv_rts_hdr->address = 0;
        packed_rkey_size      = 0;
    }

    return sizeof(*rndv_rts_hdr) + packed_rkey_size;
}
116 
/* Pending-queue progress callback: send the RTS active message for a
 * rendezvous send request. */
UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_progress_rndv_rts, (self),
                 uct_pending_req_t *self)
{
    ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
    size_t packed_rkey_size;

    /* send the RTS. the pack_cb will pack all the necessary fields in the RTS */
    packed_rkey_size = ucp_ep_config(sreq->send.ep)->tag.rndv.rkey_size;
    return ucp_do_am_single(self, UCP_AM_ID_RNDV_RTS, ucp_tag_rndv_rts_pack,
                            sizeof(ucp_rndv_rts_hdr_t) + packed_rkey_size);
}
128 
ucp_tag_rndv_rtr_pack(void * dest,void * arg)129 static size_t ucp_tag_rndv_rtr_pack(void *dest, void *arg)
130 {
131     ucp_request_t *rndv_req          = arg;
132     ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = dest;
133     ucp_request_t *rreq              = rndv_req->send.rndv_rtr.rreq;
134     ssize_t packed_rkey_size;
135 
136     rndv_rtr_hdr->sreq_ptr = rndv_req->send.rndv_rtr.remote_request;
137     rndv_rtr_hdr->rreq_ptr = (uintptr_t)rreq; /* request of receiver side */
138 
139     /* Pack remote keys (which can be empty list) */
140     if (UCP_DT_IS_CONTIG(rreq->recv.datatype)) {
141         rndv_rtr_hdr->address = (uintptr_t)rreq->recv.buffer;
142         rndv_rtr_hdr->size    = rndv_req->send.rndv_rtr.length;
143         rndv_rtr_hdr->offset  = rndv_req->send.rndv_rtr.offset;
144 
145         packed_rkey_size = ucp_rkey_pack_uct(rndv_req->send.ep->worker->context,
146                                              rreq->recv.state.dt.contig.md_map,
147                                              rreq->recv.state.dt.contig.memh,
148                                              rreq->recv.mem_type,
149                                              rndv_rtr_hdr + 1);
150         if (packed_rkey_size < 0) {
151             return packed_rkey_size;
152         }
153     } else {
154         rndv_rtr_hdr->address = 0;
155         rndv_rtr_hdr->size    = 0;
156         rndv_rtr_hdr->offset  = 0;
157         packed_rkey_size      = 0;
158     }
159 
160     return sizeof(*rndv_rtr_hdr) + packed_rkey_size;
161 }
162 
/* Pending-queue progress callback: send the RTR active message, then release
 * the rndv request once the send succeeded. */
UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_progress_rndv_rtr, (self),
                 uct_pending_req_t *self)
{
    ucp_request_t *rndv_req = ucs_container_of(self, ucp_request_t, send.uct);
    size_t packed_rkey_size;
    ucs_status_t status;

    /* send the RTR. the pack_cb will pack all the necessary fields in the RTR */
    packed_rkey_size = ucp_ep_config(rndv_req->send.ep)->tag.rndv.rkey_size;
    status = ucp_do_am_single(self, UCP_AM_ID_RNDV_RTR, ucp_tag_rndv_rtr_pack,
                              sizeof(ucp_rndv_rtr_hdr_t) + packed_rkey_size);
    if (status == UCS_OK) {
        /* release rndv request */
        ucp_request_put(rndv_req);
    }

    return status;
}
181 
ucp_tag_rndv_reg_send_buffer(ucp_request_t * sreq)182 ucs_status_t ucp_tag_rndv_reg_send_buffer(ucp_request_t *sreq)
183 {
184     ucp_ep_h ep = sreq->send.ep;
185     ucp_md_map_t md_map;
186     ucs_status_t status;
187 
188     if (UCP_DT_IS_CONTIG(sreq->send.datatype) &&
189         ucp_rndv_is_get_zcopy(sreq, ep->worker->context)) {
190 
191         /* register a contiguous buffer for rma_get */
192         md_map = ucp_ep_config(ep)->key.rma_bw_md_map;
193 
194         /* Pass UCT_MD_MEM_FLAG_HIDE_ERRORS flag, because registration may fail
195          * if md does not support send memory type (e.g. CUDA memory). In this
196          * case RTS will be sent with empty key, and sender will fallback to
197          * PUT or pipeline protocols. */
198         status = ucp_request_send_buffer_reg(sreq, md_map,
199                                              UCT_MD_MEM_FLAG_HIDE_ERRORS);
200         if (status != UCS_OK) {
201             return status;
202         }
203     }
204 
205     return UCS_OK;
206 }
207 
/*
 * Switch a tag send request to the rendezvous protocol.
 *
 * Resolves the destination endpoint pointer first (required by the RTS
 * header), then either starts hardware tag-offload rendezvous, or prepares a
 * software RTS on the AM lane and registers the send buffer for GET zcopy.
 */
ucs_status_t ucp_tag_send_start_rndv(ucp_request_t *sreq)
{
    ucp_ep_h ep = sreq->send.ep;
    ucs_status_t status;

    ucp_trace_req(sreq, "start_rndv to %s buffer %p length %zu",
                  ucp_ep_peer_name(ep), sreq->send.buffer,
                  sreq->send.length);
    UCS_PROFILE_REQUEST_EVENT(sreq, "start_rndv", sreq->send.length);

    status = ucp_ep_resolve_dest_ep_ptr(ep, sreq->send.lane);
    if (status != UCS_OK) {
        return status;
    }

    if (ucp_ep_is_tag_offload_enabled(ucp_ep_config(ep))) {
        /* HW tag matching path */
        status = ucp_tag_offload_start_rndv(sreq);
    } else {
        /* SW path: RTS goes out on the AM lane */
        ucs_assert(sreq->send.lane == ucp_ep_get_am_lane(ep));
        sreq->send.uct.func = ucp_proto_progress_rndv_rts;
        status              = ucp_tag_rndv_reg_send_buffer(sreq);
    }

    return status;
}
233 
234 static UCS_F_ALWAYS_INLINE size_t
ucp_rndv_adjust_zcopy_length(size_t min_zcopy,size_t max_zcopy,size_t align,size_t send_length,size_t offset,size_t length)235 ucp_rndv_adjust_zcopy_length(size_t min_zcopy, size_t max_zcopy, size_t align,
236                              size_t send_length, size_t offset, size_t length)
237 {
238     size_t result_length, tail;
239 
240     ucs_assert(length > 0);
241 
242     /* ensure that the current length is over min_zcopy */
243     result_length = ucs_max(length, min_zcopy);
244 
245     /* ensure that the current length is less than max_zcopy */
246     result_length = ucs_min(result_length, max_zcopy);
247 
248     /* ensure that tail (rest of message) is over min_zcopy */
249     ucs_assertv(send_length >= (offset + result_length),
250                 "send_length=%zu, offset=%zu, length=%zu",
251                 send_length, offset, result_length);
252     tail = send_length - (offset + result_length);
253     if (ucs_unlikely((tail != 0) && (tail < min_zcopy))) {
254         /* ok, tail is less zcopy minimal & could not be processed as
255          * standalone operation */
256         /* check if we have room to increase current part and not
257          * step over max_zcopy */
258         if (result_length < (max_zcopy - tail)) {
259             /* if we can increase length by min_zcopy - let's do it to
260              * avoid small tail (we have limitation on minimal get zcopy) */
261             result_length += tail;
262         } else {
263             /* reduce current length by align or min_zcopy value
264              * to process it on next round */
265             ucs_assert(result_length > ucs_max(min_zcopy, align));
266             result_length -= ucs_max(min_zcopy, align);
267         }
268     }
269 
270     ucs_assertv(result_length >= min_zcopy, "length=%zu, min_zcopy=%zu",
271                 result_length, min_zcopy);
272     ucs_assertv(((send_length - (offset + result_length)) == 0) ||
273                 ((send_length - (offset + result_length)) >= min_zcopy),
274                 "send_length=%zu, offset=%zu, length=%zu, min_zcopy=%zu",
275                 send_length, offset, result_length, min_zcopy);
276 
277     return result_length;
278 }
279 
/* Finish a rendezvous send request: release generic-datatype state,
 * deregister the send buffer, and complete the request towards the user. */
static void ucp_rndv_complete_send(ucp_request_t *sreq, ucs_status_t status)
{
    ucp_request_send_generic_dt_finish(sreq);
    ucp_request_send_buffer_dereg(sreq);
    ucp_request_complete_send(sreq, status);
}
286 
/*
 * Queue an ATS (ack-to-sender) message telling the sender its rendezvous
 * request is finished on this side. rndv_req is released by the completion
 * callback once the ATS has been sent.
 */
static void ucp_rndv_req_send_ats(ucp_request_t *rndv_req, ucp_request_t *rreq,
                                  uintptr_t remote_request, ucs_status_t status)
{
    ucp_trace_req(rndv_req, "send ats remote_request 0x%lx", remote_request);
    UCS_PROFILE_REQUEST_EVENT(rreq, "send_ats", 0);

    rndv_req->send.lane                 = ucp_ep_get_am_lane(rndv_req->send.ep);
    rndv_req->send.uct.func             = ucp_proto_progress_am_single;
    rndv_req->send.proto.am_id          = UCP_AM_ID_RNDV_ATS;
    rndv_req->send.proto.status         = status;
    rndv_req->send.proto.remote_request = remote_request;
    rndv_req->send.proto.comp_cb        = ucp_request_put; /* frees rndv_req */

    ucp_request_send(rndv_req, 0);
}
302 
/* Complete a sender-side PUT zcopy rendezvous: deregister the send buffer
 * and complete the send request with success. */
UCS_PROFILE_FUNC_VOID(ucp_rndv_complete_rma_put_zcopy, (sreq),
                      ucp_request_t *sreq)
{
    ucp_trace_req(sreq, "rndv_put completed");
    UCS_PROFILE_REQUEST_EVENT(sreq, "complete_rndv_put", 0);

    ucp_request_send_buffer_dereg(sreq);
    ucp_request_complete_send(sreq, UCS_OK);
}
312 
/*
 * Queue an ATP (ack-to-put) message notifying the receiver that all PUT data
 * has been sent. Must be called only after the whole message was put
 * (asserted below). The completion callback finishes the send request.
 */
static void ucp_rndv_send_atp(ucp_request_t *sreq, uintptr_t remote_request)
{
    ucs_assertv(sreq->send.state.dt.offset == sreq->send.length,
                "sreq=%p offset=%zu length=%zu", sreq,
                sreq->send.state.dt.offset, sreq->send.length);

    ucp_trace_req(sreq, "send atp remote_request 0x%lx", remote_request);
    UCS_PROFILE_REQUEST_EVENT(sreq, "send_atp", 0);

    /* destroy rkey before it gets overridden by ATP protocol data */
    ucp_rkey_destroy(sreq->send.rndv_put.rkey);

    sreq->send.lane                 = ucp_ep_get_am_lane(sreq->send.ep);
    sreq->send.uct.func             = ucp_proto_progress_am_single;
    sreq->send.proto.am_id          = UCP_AM_ID_RNDV_ATP;
    sreq->send.proto.status         = UCS_OK;
    sreq->send.proto.remote_request = remote_request;
    sreq->send.proto.comp_cb        = ucp_rndv_complete_rma_put_zcopy;

    ucp_request_send(sreq, 0);
}
334 
/* Completion of one PUT fragment: account its length on the original send
 * request, free the fragment request, and complete the original request once
 * all fragments have finished. */
UCS_PROFILE_FUNC_VOID(ucp_rndv_complete_frag_rma_put_zcopy, (fsreq),
                      ucp_request_t *fsreq)
{
    ucp_request_t *sreq = fsreq->send.proto.sreq;

    sreq->send.state.dt.offset += fsreq->send.length;

    /* delete fragments send request */
    ucp_request_put(fsreq);

    /* complete send request after put completions of all fragments */
    if (sreq->send.state.dt.offset == sreq->send.length) {
        ucp_rndv_complete_rma_put_zcopy(sreq);
    }
}
350 
/*
 * Queue an ATP for a single PUT fragment. The completion callback folds the
 * fragment back into the original send request (see
 * ucp_rndv_complete_frag_rma_put_zcopy).
 */
static void ucp_rndv_send_frag_atp(ucp_request_t *fsreq, uintptr_t remote_request)
{
    ucp_trace_req(fsreq, "send frag atp remote_request 0x%lx", remote_request);
    UCS_PROFILE_REQUEST_EVENT(fsreq, "send_frag_atp", 0);

    /* destroy rkey before it gets overridden by ATP protocol data */
    ucp_rkey_destroy(fsreq->send.rndv_put.rkey);

    fsreq->send.lane                 = ucp_ep_get_am_lane(fsreq->send.ep);
    fsreq->send.uct.func             = ucp_proto_progress_am_single;
    fsreq->send.proto.sreq           = fsreq->send.rndv_put.sreq;
    fsreq->send.proto.am_id          = UCP_AM_ID_RNDV_ATP;
    fsreq->send.proto.status         = UCS_OK;
    fsreq->send.proto.remote_request = remote_request;
    fsreq->send.proto.comp_cb        = ucp_rndv_complete_frag_rma_put_zcopy;

    ucp_request_send(fsreq, 0);
}
369 
/* Deregister the receive buffer and complete the user's tag-receive request. */
static void ucp_rndv_zcopy_recv_req_complete(ucp_request_t *req, ucs_status_t status)
{
    ucp_request_recv_buffer_dereg(req);
    ucp_request_complete_tag_recv(req, status);
}
375 
/*
 * Complete a receiver-side GET zcopy rendezvous: destroy the unpacked rkey,
 * deregister the local buffer, send ATS to the sender on success (on error
 * the rndv request is simply released), and complete the receive request.
 */
static void ucp_rndv_complete_rma_get_zcopy(ucp_request_t *rndv_req,
                                            ucs_status_t status)
{
    ucp_request_t *rreq = rndv_req->send.rndv_get.rreq;

    /* all chunks must have been fetched before completing */
    ucs_assertv(rndv_req->send.state.dt.offset == rndv_req->send.length,
                "rndv_req=%p offset=%zu length=%zu", rndv_req,
                rndv_req->send.state.dt.offset, rndv_req->send.length);

    ucp_trace_req(rndv_req, "rndv_get completed with status %s",
                  ucs_status_string(status));
    UCS_PROFILE_REQUEST_EVENT(rreq, "complete_rndv_get", 0);

    ucp_rkey_destroy(rndv_req->send.rndv_get.rkey);
    ucp_request_send_buffer_dereg(rndv_req);

    if (status == UCS_OK) {
        ucp_rndv_req_send_ats(rndv_req, rreq,
                              rndv_req->send.rndv_get.remote_request, UCS_OK);
    } else {
        /* if completing RNDV with the error, just release RNDV request */
        ucp_request_put(rndv_req);
    }

    ucp_rndv_zcopy_recv_req_complete(rreq, status);
}
402 
/* Prepare a receive request to accumulate 'size' bytes of incoming
 * rendezvous data. */
static void ucp_rndv_recv_data_init(ucp_request_t *rreq, size_t size)
{
    rreq->status             = UCS_OK;
    rreq->recv.tag.remaining = size; /* bytes still expected */
}
408 
/*
 * Queue an RTR (ready-to-receive) message asking the sender to transfer
 * recv_length bytes starting at 'offset' of the message. sender_reqptr is
 * the sender-side request handle received in the RTS.
 */
static void ucp_rndv_req_send_rtr(ucp_request_t *rndv_req, ucp_request_t *rreq,
                                  uintptr_t sender_reqptr, size_t recv_length,
                                  size_t offset)
{
    ucp_trace_req(rndv_req, "send rtr remote sreq 0x%lx rreq %p", sender_reqptr,
                  rreq);

    rndv_req->send.lane                    = ucp_ep_get_am_lane(rndv_req->send.ep);
    rndv_req->send.uct.func                = ucp_proto_progress_rndv_rtr;
    rndv_req->send.rndv_rtr.remote_request = sender_reqptr;
    rndv_req->send.rndv_rtr.rreq           = rreq;
    rndv_req->send.rndv_rtr.length         = recv_length;
    rndv_req->send.rndv_rtr.offset         = offset;

    ucp_request_send(rndv_req, 0);
}
425 
/*
 * Return the next GET zcopy lane to use (the lowest set bit of the
 * available-lanes bitmap) and fill *uct_rkey with the matching uct rkey, or
 * UCT_INVALID_RKEY for lanes that need no key. Returns UCP_NULL_LANE when no
 * lane can reach the remote memory (empty lane map).
 */
static ucp_lane_index_t
ucp_rndv_get_zcopy_get_lane(ucp_request_t *rndv_req, uct_rkey_t *uct_rkey)
{
    ucp_lane_index_t lane_idx;
    ucp_ep_config_t *ep_config;
    ucp_rkey_h rkey;
    uint8_t rkey_index;

    if (ucs_unlikely(!rndv_req->send.rndv_get.lanes_map_all)) {
        return UCP_NULL_LANE;
    }

    lane_idx   = ucs_ffs64_safe(rndv_req->send.rndv_get.lanes_map_avail);
    ucs_assert(lane_idx < UCP_MAX_LANES);
    rkey       = rndv_req->send.rndv_get.rkey;
    rkey_index = rndv_req->send.rndv_get.rkey_index[lane_idx];
    *uct_rkey  = (rkey_index != UCP_NULL_RESOURCE) ?
                 rkey->tl_rkey[rkey_index].rkey.rkey : UCT_INVALID_RKEY;
    ep_config  = ucp_ep_config(rndv_req->send.ep);
    return ep_config->tag.rndv.get_zcopy_lanes[lane_idx];
}
447 
/*
 * Advance to the next GET zcopy lane: drop the lane just used (lowest set
 * bit) from the available map; once all lanes were used, restart a new round
 * over the full lane map.
 */
static void ucp_rndv_get_zcopy_next_lane(ucp_request_t *rndv_req)
{
    ucp_lane_map_t avail = rndv_req->send.rndv_get.lanes_map_avail;

    /* clear the lowest set bit - the lane that was just used */
    avail &= avail - 1;
    if (avail == 0) {
        avail = rndv_req->send.rndv_get.lanes_map_all;
    }

    rndv_req->send.rndv_get.lanes_map_avail = avail;
}
455 
/*
 * Progress a receiver-side GET zcopy rendezvous: issue uct_ep_get_zcopy()
 * for the next chunk of the remote buffer on the currently selected bw lane.
 * Returns UCS_OK when the last chunk was issued (or when the protocol fell
 * back to RTR), UCS_INPROGRESS when more chunks remain, or an error status.
 */
UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_progress_rma_get_zcopy, (self),
                 uct_pending_req_t *self)
{
    ucp_request_t *rndv_req = ucs_container_of(self, ucp_request_t, send.uct);
    ucp_ep_h ep             = rndv_req->send.ep;
    ucp_ep_config_t *config = ucp_ep_config(ep);
    const size_t max_iovcnt = 1;
    uct_iface_attr_t* attrs;
    ucs_status_t status;
    size_t offset, length, ucp_mtu, remaining, align, chunk;
    uct_iov_t iov[max_iovcnt];
    size_t iovcnt;
    ucp_rsc_index_t rsc_index;
    ucp_dt_state_t state;
    uct_rkey_t uct_rkey;
    size_t min_zcopy;
    size_t max_zcopy;
    int pending_add_res;
    ucp_lane_index_t lane;

    /* Figure out which lane to use for get operation */
    rndv_req->send.lane = lane = ucp_rndv_get_zcopy_get_lane(rndv_req, &uct_rkey);

    if (lane == UCP_NULL_LANE) {
        /* If can't perform get_zcopy - switch to active-message.
         * NOTE: we do not register memory and do not send our keys. */
        ucp_trace_req(rndv_req, "remote memory unreachable, switch to rtr");
        ucp_rkey_destroy(rndv_req->send.rndv_get.rkey);
        ucp_rndv_recv_data_init(rndv_req->send.rndv_get.rreq,
                                rndv_req->send.length);
        /* Update statistics counters from get_zcopy to rtr */
        UCP_WORKER_STAT_RNDV(ep->worker, GET_ZCOPY, -1);
        UCP_WORKER_STAT_RNDV(ep->worker, SEND_RTR,  +1);
        ucp_rndv_req_send_rtr(rndv_req, rndv_req->send.rndv_get.rreq,
                              rndv_req->send.rndv_get.remote_request,
                              rndv_req->send.length, 0ul);
        return UCS_OK;
    }

    ucs_assert_always(rndv_req->send.rndv_get.lanes_count > 0);

    if (!rndv_req->send.mdesc) {
        /* no bounce buffer - register the user buffer on this lane */
        status = ucp_send_request_add_reg_lane(rndv_req, lane);
        ucs_assert_always(status == UCS_OK);
    }

    rsc_index = ucp_ep_get_rsc_index(ep, lane);
    attrs     = ucp_worker_iface_get_attr(ep->worker, rsc_index);
    align     = attrs->cap.get.opt_zcopy_align;
    ucp_mtu   = attrs->cap.get.align_mtu;
    min_zcopy = config->tag.rndv.min_get_zcopy;
    max_zcopy = config->tag.rndv.max_get_zcopy;

    offset    = rndv_req->send.state.dt.offset;
    remaining = (uintptr_t)rndv_req->send.buffer % align;

    if ((offset == 0) && (remaining > 0) && (rndv_req->send.length > ucp_mtu)) {
        /* first chunk is shortened so subsequent chunks start aligned */
        length = ucp_mtu - remaining;
    } else {
        /* divide the message between the lanes, scaled per-lane, and clip to
         * what is left of the message */
        chunk = ucs_align_up((size_t)(rndv_req->send.length /
                                      rndv_req->send.rndv_get.lanes_count
                                      * config->tag.rndv.scale[lane]),
                             align);
        length = ucs_min(chunk, rndv_req->send.length - offset);
    }

    length = ucp_rndv_adjust_zcopy_length(min_zcopy, max_zcopy, align,
                                          rndv_req->send.length, offset,
                                          length);

    ucs_trace_data("req %p: offset %zu remainder %zu rma-get to %p len %zu lane %d",
                   rndv_req, offset, remaining,
                   UCS_PTR_BYTE_OFFSET(rndv_req->send.buffer, offset),
                   length, lane);

    state = rndv_req->send.state.dt;
    /* TODO: is this correct? memh array may skip MD's where
     * registration is not supported. for now SHM may avoid registration,
     * but it will work on single lane */
    ucp_dt_iov_copy_uct(ep->worker->context, iov, &iovcnt, max_iovcnt, &state,
                        rndv_req->send.buffer, ucp_dt_make_contig(1), length,
                        ucp_ep_md_index(ep, lane),
                        rndv_req->send.mdesc);

    for (;;) {
        status = uct_ep_get_zcopy(ep->uct_eps[lane],
                                  iov, iovcnt,
                                  rndv_req->send.rndv_get.remote_address + offset,
                                  uct_rkey,
                                  &rndv_req->send.state.uct_comp);
        ucp_request_send_state_advance(rndv_req, &state,
                                       UCP_REQUEST_SEND_PROTO_RNDV_GET,
                                       status);
        if (rndv_req->send.state.dt.offset == rndv_req->send.length) {
            /* all chunks issued; if nothing is still in flight, run the
             * completion callback inline */
            if (rndv_req->send.state.uct_comp.count == 0) {
                rndv_req->send.state.uct_comp.func(&rndv_req->send.state.uct_comp, status);
            }
            return UCS_OK;
        } else if (!UCS_STATUS_IS_ERR(status)) {
            /* in case if not all chunks are transmitted - return in_progress
             * status */
            ucp_rndv_get_zcopy_next_lane(rndv_req);
            return UCS_INPROGRESS;
        } else {
            if (status == UCS_ERR_NO_RESOURCE) {
                if (lane != rndv_req->send.pending_lane) {
                    /* switch to new pending lane */
                    pending_add_res = ucp_request_pending_add(rndv_req, &status, 0);
                    if (!pending_add_res) {
                        /* failed to switch req to pending queue, try again */
                        continue;
                    }
                    ucs_assert(status == UCS_INPROGRESS);
                    return UCS_OK;
                }
            }
            return status;
        }
    }
}
576 
/* uct completion callback for GET zcopy chunks: completes the rendezvous
 * once the whole message has been fetched. */
UCS_PROFILE_FUNC_VOID(ucp_rndv_get_completion, (self, status),
                      uct_completion_t *self, ucs_status_t status)
{
    ucp_request_t *rndv_req = ucs_container_of(self, ucp_request_t,
                                               send.state.uct_comp);

    if (rndv_req->send.state.dt.offset == rndv_req->send.length) {
        ucp_rndv_complete_rma_get_zcopy(rndv_req, status);
    }
}
587 
/* uct completion callback for PUT zcopy chunks: sends ATP to the receiver
 * once the whole message has been put. */
static void ucp_rndv_put_completion(uct_completion_t *self, ucs_status_t status)
{
    ucp_request_t *sreq = ucs_container_of(self, ucp_request_t,
                                           send.state.uct_comp);

    if (sreq->send.state.dt.offset == sreq->send.length) {
        ucp_rndv_send_atp(sreq, sreq->send.rndv_put.remote_request);
    }
}
597 
/*
 * Build the bitmap of GET zcopy lanes usable for this request, together with
 * the per-lane index into the unpacked rkey. A lane qualifies when it either
 * does not need an rkey (and the memory types match), or the remote rkey
 * contains a key for the lane's destination MD. When more than one lane is
 * selected, lanes whose bandwidth is too low relative to the fastest lane
 * are dropped.
 */
static void ucp_rndv_req_init_get_zcopy_lane_map(ucp_request_t *rndv_req)
{
    ucp_ep_h ep                = rndv_req->send.ep;
    ucp_ep_config_t *ep_config = ucp_ep_config(ep);
    ucp_context_h context      = ep->worker->context;
    ucs_memory_type_t mem_type = rndv_req->send.mem_type;
    ucp_rkey_h rkey            = rndv_req->send.rndv_get.rkey;
    ucp_lane_map_t lane_map;
    ucp_lane_index_t lane, lane_idx;
    ucp_md_index_t md_index;
    uct_md_attr_t *md_attr;
    ucp_md_index_t dst_md_index;
    ucp_rsc_index_t rsc_index;
    uct_iface_attr_t *iface_attr;
    double max_lane_bw, lane_bw;
    int i;

    max_lane_bw = 0;
    lane_map    = 0;
    for (i = 0; i < UCP_MAX_LANES; i++) {
        lane = ep_config->tag.rndv.get_zcopy_lanes[i];
        if (lane == UCP_NULL_LANE) {
            break; /* no more lanes */
        }

        md_index   = ep_config->md_index[lane];
        md_attr    = &context->tl_mds[md_index].attr;
        rsc_index  = ep_config->key.lanes[lane].rsc_index;
        iface_attr = ucp_worker_iface_get_attr(ep->worker, rsc_index);
        lane_bw    = ucp_tl_iface_bandwidth(context, &iface_attr->bandwidth);

        if (ucs_unlikely((md_index != UCP_NULL_RESOURCE) &&
                         !(md_attr->cap.flags & UCT_MD_FLAG_NEED_RKEY))) {
            /* Lane does not need rkey, can use the lane with invalid rkey  */
            if (!rkey || ((mem_type == md_attr->cap.access_mem_type) &&
                          (mem_type == rkey->mem_type))) {
                rndv_req->send.rndv_get.rkey_index[i] = UCP_NULL_RESOURCE;
                lane_map                             |= UCS_BIT(i);
                max_lane_bw                           = ucs_max(max_lane_bw, lane_bw);
                continue;
            }
        }

        /* skip lanes whose MD cannot register the local memory type */
        if (ucs_unlikely((md_index != UCP_NULL_RESOURCE) &&
                         (!(md_attr->cap.reg_mem_types & UCS_BIT(mem_type))))) {
            continue;
        }

        dst_md_index = ep_config->key.lanes[lane].dst_md_index;
        if (rkey && ucs_likely(rkey->md_map & UCS_BIT(dst_md_index))) {
            /* Return first matching lane */
            rndv_req->send.rndv_get.rkey_index[i] = ucs_bitmap2idx(rkey->md_map,
                                                                   dst_md_index);
            lane_map                             |= UCS_BIT(i);
            max_lane_bw                           = ucs_max(max_lane_bw, lane_bw);
        }
    }

    if (ucs_popcount(lane_map) > 1) {
        /* remove lanes if bandwidth is too low comparing to the best lane */
        ucs_for_each_bit(lane_idx, lane_map) {
            ucs_assert(lane_idx < UCP_MAX_LANES);
            lane       = ep_config->tag.rndv.get_zcopy_lanes[lane_idx];
            rsc_index  = ep_config->key.lanes[lane].rsc_index;
            iface_attr = ucp_worker_iface_get_attr(ep->worker, rsc_index);
            lane_bw    = ucp_tl_iface_bandwidth(context, &iface_attr->bandwidth);

            if ((lane_bw/max_lane_bw) <
                (1. / context->config.ext.multi_lane_max_ratio)) {
                lane_map                                    &= ~UCS_BIT(lane_idx);
                rndv_req->send.rndv_get.rkey_index[lane_idx] = UCP_NULL_RESOURCE;
            }
        }
    }

    rndv_req->send.rndv_get.lanes_map_all   = lane_map;
    rndv_req->send.rndv_get.lanes_map_avail = lane_map;
    rndv_req->send.rndv_get.lanes_count     = ucs_popcount(lane_map);
}
677 
ucp_rndv_req_send_rma_get(ucp_request_t * rndv_req,ucp_request_t * rreq,const ucp_rndv_rts_hdr_t * rndv_rts_hdr)678 static ucs_status_t ucp_rndv_req_send_rma_get(ucp_request_t *rndv_req,
679                                               ucp_request_t *rreq,
680                                               const ucp_rndv_rts_hdr_t *rndv_rts_hdr)
681 {
682     ucp_ep_h ep = rndv_req->send.ep;
683     ucs_status_t status;
684     uct_rkey_t uct_rkey;
685 
686     ucp_trace_req(rndv_req, "start rma_get rreq %p", rreq);
687 
688     rndv_req->send.uct.func                = ucp_rndv_progress_rma_get_zcopy;
689     rndv_req->send.buffer                  = rreq->recv.buffer;
690     rndv_req->send.mem_type                = rreq->recv.mem_type;
691     rndv_req->send.datatype                = ucp_dt_make_contig(1);
692     rndv_req->send.length                  = rndv_rts_hdr->size;
693     rndv_req->send.rndv_get.remote_request = rndv_rts_hdr->sreq.reqptr;
694     rndv_req->send.rndv_get.remote_address = rndv_rts_hdr->address;
695     rndv_req->send.rndv_get.rreq           = rreq;
696     rndv_req->send.datatype                = rreq->recv.datatype;
697 
698     status = ucp_ep_rkey_unpack(ep, rndv_rts_hdr + 1,
699                                 &rndv_req->send.rndv_get.rkey);
700     if (status != UCS_OK) {
701         ucs_fatal("failed to unpack rendezvous remote key received from %s: %s",
702                   ucp_ep_peer_name(ep), ucs_status_string(status));
703     }
704 
705     ucp_request_send_state_init(rndv_req, ucp_dt_make_contig(1), 0);
706     ucp_request_send_state_reset(rndv_req, ucp_rndv_get_completion,
707                                  UCP_REQUEST_SEND_PROTO_RNDV_GET);
708 
709     ucp_rndv_req_init_get_zcopy_lane_map(rndv_req);
710 
711     rndv_req->send.lane = ucp_rndv_get_zcopy_get_lane(rndv_req, &uct_rkey);
712     if (rndv_req->send.lane == UCP_NULL_LANE) {
713         return UCS_ERR_UNREACHABLE;
714     }
715 
716     UCP_WORKER_STAT_RNDV(ep->worker, GET_ZCOPY, 1);
717     ucp_request_send(rndv_req, 0);
718 
719     return UCS_OK;
720 }
721 
/* Completion callback for a PUT from a host staging fragment into the user
 * receive buffer (receive-side pipeline). Releases the fragment resources
 * and, once every fragment has been placed, completes the user receive
 * request; for the GET-pipeline variant it also sends the final ATS. */
UCS_PROFILE_FUNC_VOID(ucp_rndv_recv_frag_put_completion, (self, status),
                      uct_completion_t *self, ucs_status_t status)
{
    ucp_request_t *freq     = ucs_container_of(self, ucp_request_t,
                                               send.state.uct_comp);
    ucp_request_t *req      = freq->send.rndv_put.sreq;
    ucp_request_t *rndv_req = (ucp_request_t*)freq->send.rndv_put.remote_request;

    ucs_trace_req("freq:%p: recv_frag_put done. rreq:%p ", freq, req);

    /* release memory descriptor (the staging bounce buffer) */
    ucs_mpool_put_inline((void *)freq->send.mdesc);

    /* rndv_req is NULL in case of put protocol */
    if (rndv_req != NULL) {
        /* pipeline recv get protocol: account this fragment's bytes */
        rndv_req->send.state.dt.offset += freq->send.length;

        /* send ATS for fragment get rndv completion once all fragment GETs
         * have been staged and PUT to the destination */
        if (rndv_req->send.length == rndv_req->send.state.dt.offset) {
            ucp_rkey_destroy(rndv_req->send.rndv_get.rkey);
            ucp_rndv_req_send_ats(rndv_req, req,
                                  rndv_req->send.rndv_get.remote_request, UCS_OK);
        }
    }

    /* complete the user receive request when the last fragment lands */
    req->recv.tag.remaining -= freq->send.length;
    if (req->recv.tag.remaining == 0) {
        ucp_request_complete_tag_recv(req, UCS_OK);
    }

    ucp_request_put(freq);
}
755 
756 static UCS_F_ALWAYS_INLINE void
ucp_rndv_init_mem_type_frag_req(ucp_worker_h worker,ucp_request_t * freq,int rndv_op,uct_completion_callback_t comp_cb,ucp_mem_desc_t * mdesc,ucs_memory_type_t mem_type,size_t length,uct_pending_callback_t uct_func)757 ucp_rndv_init_mem_type_frag_req(ucp_worker_h worker, ucp_request_t *freq, int rndv_op,
758                                 uct_completion_callback_t comp_cb, ucp_mem_desc_t *mdesc,
759                                 ucs_memory_type_t mem_type, size_t length,
760                                 uct_pending_callback_t uct_func)
761 {
762     ucp_ep_h mem_type_ep;
763     ucp_md_index_t md_index;
764     ucp_lane_index_t mem_type_rma_lane;
765 
766     ucp_request_send_state_init(freq, ucp_dt_make_contig(1), 0);
767     ucp_request_send_state_reset(freq, comp_cb, rndv_op);
768 
769     freq->send.buffer   = mdesc + 1;
770     freq->send.length   = length;
771     freq->send.datatype = ucp_dt_make_contig(1);
772     freq->send.mem_type = mem_type;
773     freq->send.mdesc    = mdesc;
774     freq->send.uct.func = uct_func;
775 
776     if (mem_type != UCS_MEMORY_TYPE_HOST) {
777         mem_type_ep       = worker->mem_type_ep[mem_type];
778         mem_type_rma_lane = ucp_ep_config(mem_type_ep)->key.rma_bw_lanes[0];
779         md_index          = ucp_ep_md_index(mem_type_ep, mem_type_rma_lane);
780         ucs_assert(mem_type_rma_lane != UCP_NULL_LANE);
781 
782         freq->send.lane                       = mem_type_rma_lane;
783         freq->send.ep                         = mem_type_ep;
784         freq->send.state.dt.dt.contig.memh[0] = ucp_memh2uct(mdesc->memh, md_index);
785         freq->send.state.dt.dt.contig.md_map  = UCS_BIT(md_index);
786     }
787 }
788 
789 static void
ucp_rndv_recv_frag_put_mem_type(ucp_request_t * rreq,ucp_request_t * rndv_req,ucp_request_t * freq,ucp_mem_desc_t * mdesc,size_t length,size_t offset)790 ucp_rndv_recv_frag_put_mem_type(ucp_request_t *rreq, ucp_request_t *rndv_req,
791                                 ucp_request_t *freq, ucp_mem_desc_t *mdesc,
792                                 size_t length, size_t offset)
793 {
794 
795     ucs_assert_always(!UCP_MEM_IS_ACCESSIBLE_FROM_CPU(rreq->recv.mem_type));
796 
797     /* PUT on memtype endpoint to stage from
798      * frag recv buffer to memtype recv buffer
799      */
800 
801     ucp_rndv_init_mem_type_frag_req(rreq->recv.worker, freq, UCP_REQUEST_SEND_PROTO_RNDV_PUT,
802                                     ucp_rndv_recv_frag_put_completion, mdesc, rreq->recv.mem_type,
803                                     length, ucp_rndv_progress_rma_put_zcopy);
804 
805     freq->send.rndv_put.sreq           = rreq;
806     freq->send.rndv_put.rkey           = NULL;
807     freq->send.rndv_put.remote_request = (uintptr_t)rndv_req;
808     freq->send.rndv_put.remote_address = (uintptr_t)rreq->recv.buffer + offset;
809 
810     ucp_request_send(freq, 0);
811 }
812 
813 static ucs_status_t
ucp_rndv_send_frag_get_mem_type(ucp_request_t * sreq,uintptr_t rreq_ptr,size_t length,uint64_t remote_address,ucs_memory_type_t remote_mem_type,ucp_rkey_h rkey,uint8_t * rkey_index,ucp_lane_map_t lanes_map,uct_completion_callback_t comp_cb)814 ucp_rndv_send_frag_get_mem_type(ucp_request_t *sreq, uintptr_t rreq_ptr,
815                                 size_t length, uint64_t remote_address,
816                                 ucs_memory_type_t remote_mem_type, ucp_rkey_h rkey,
817                                 uint8_t *rkey_index, ucp_lane_map_t lanes_map,
818                                 uct_completion_callback_t comp_cb)
819 {
820     ucp_worker_h worker = sreq->send.ep->worker;
821     ucp_request_t *freq;
822     ucp_mem_desc_t *mdesc;
823     ucp_lane_index_t i;
824 
825     /* GET fragment to stage buffer */
826 
827     freq = ucp_request_get(worker);
828     if (ucs_unlikely(freq == NULL)) {
829         ucs_error("failed to allocate fragment receive request");
830         return UCS_ERR_NO_MEMORY;
831     }
832 
833     mdesc = ucp_worker_mpool_get(&worker->rndv_frag_mp);
834     if (ucs_unlikely(mdesc == NULL)) {
835         ucs_error("failed to allocate fragment memory desc");
836         return UCS_ERR_NO_MEMORY;
837     }
838 
839     freq->send.ep = sreq->send.ep;
840 
841     ucp_rndv_init_mem_type_frag_req(worker, freq, UCP_REQUEST_SEND_PROTO_RNDV_GET,
842                                     comp_cb, mdesc, remote_mem_type, length,
843                                     ucp_rndv_progress_rma_get_zcopy);
844 
845     freq->send.rndv_get.rkey            = rkey;
846     freq->send.rndv_get.remote_address  = remote_address;
847     freq->send.rndv_get.remote_request  = rreq_ptr;
848     freq->send.rndv_get.rreq            = sreq;
849     freq->send.rndv_get.lanes_map_all   = lanes_map;
850     freq->send.rndv_get.lanes_map_avail = lanes_map;
851     freq->send.rndv_get.lanes_count     = ucs_popcount(lanes_map);
852 
853     for (i = 0; i < UCP_MAX_LANES; i++) {
854         freq->send.rndv_get.rkey_index[i] = rkey_index ? rkey_index[i]
855                                                        : UCP_NULL_RESOURCE;
856     }
857 
858 
859     return ucp_request_send(freq, 0);
860 }
861 
/* Completion callback for a fragment GET from the remote sender buffer into
 * the host staging buffer: chain the fragment into a PUT towards the user's
 * memtype receive buffer. */
UCS_PROFILE_FUNC_VOID(ucp_rndv_recv_frag_get_completion, (self, status),
                      uct_completion_t *self, ucs_status_t status)
{
    ucp_request_t *freq     = ucs_container_of(self, ucp_request_t,
                                               send.state.uct_comp);
    ucp_request_t *rndv_req = freq->send.rndv_get.rreq;
    ucp_request_t *rreq     = rndv_req->send.rndv_get.rreq;

    /* NOTE(review): the "rreq" field in this trace actually prints rndv_req */
    ucs_trace_req("freq:%p: recv_frag_get done. rreq:%p length:%ld offset:%ld",
                  freq, rndv_req, freq->send.length,
                  freq->send.rndv_get.remote_address - rndv_req->send.rndv_get.remote_address);

    /* fragment GET completed from remote to staging buffer, issue PUT from
     * staging buffer to recv buffer; the mem desc sits immediately before
     * the staging data (buffer == mdesc + 1), and the fragment offset is
     * recovered from the remote address delta */
    ucp_rndv_recv_frag_put_mem_type(rreq, rndv_req, freq,
                                    (ucp_mem_desc_t *)freq->send.buffer -1,
                                    freq->send.length, (freq->send.rndv_get.remote_address -
                                    rndv_req->send.rndv_get.remote_address));
}
881 
882 static ucs_status_t
ucp_rndv_recv_start_get_pipeline(ucp_worker_h worker,ucp_request_t * rndv_req,ucp_request_t * rreq,uintptr_t remote_request,const void * rkey_buffer,uint64_t remote_address,size_t size,size_t base_offset)883 ucp_rndv_recv_start_get_pipeline(ucp_worker_h worker, ucp_request_t *rndv_req,
884                                  ucp_request_t *rreq, uintptr_t remote_request,
885                                  const void *rkey_buffer, uint64_t remote_address,
886                                  size_t size, size_t base_offset)
887 {
888     ucp_ep_h ep             = rndv_req->send.ep;
889     ucp_ep_config_t *config = ucp_ep_config(ep);
890     ucp_context_h context   = worker->context;
891     ucs_status_t status;
892     size_t max_frag_size, offset, length;
893     size_t min_zcopy, max_zcopy;
894 
895     min_zcopy                              = config->tag.rndv.min_get_zcopy;
896     max_zcopy                              = config->tag.rndv.max_get_zcopy;
897     max_frag_size                          = ucs_min(context->config.ext.rndv_frag_size,
898                                                      max_zcopy);
899     rndv_req->send.rndv_get.remote_request = remote_request;
900     rndv_req->send.rndv_get.remote_address = remote_address - base_offset;
901     rndv_req->send.rndv_get.rreq           = rreq;
902     rndv_req->send.length                  = size;
903     rndv_req->send.state.dt.offset         = 0;
904     rndv_req->send.mem_type                = rreq->recv.mem_type;
905 
906     /* Protocol:
907      * Step 1: GET remote fragment into HOST fragment buffer
908      * Step 2: PUT from fragment buffer to MEM TYPE destination
909      * Step 3: Send ATS for RNDV request
910      */
911 
912     status = ucp_ep_rkey_unpack(rndv_req->send.ep, rkey_buffer,
913                                 &rndv_req->send.rndv_get.rkey);
914     if (ucs_unlikely(status != UCS_OK)) {
915         ucs_fatal("failed to unpack rendezvous remote key received from %s: %s",
916                   ucp_ep_peer_name(rndv_req->send.ep), ucs_status_string(status));
917     }
918 
919     ucp_rndv_req_init_get_zcopy_lane_map(rndv_req);
920 
921     offset = 0;
922     while (offset != size) {
923         length = ucp_rndv_adjust_zcopy_length(min_zcopy, max_frag_size, 0,
924                                               size, offset, size - offset);
925 
926         /* GET remote fragment into HOST fragment buffer */
927         ucp_rndv_send_frag_get_mem_type(rndv_req, remote_request, length,
928                                         remote_address + offset, UCS_MEMORY_TYPE_HOST,
929                                         rndv_req->send.rndv_get.rkey,
930                                         rndv_req->send.rndv_get.rkey_index,
931                                         rndv_req->send.rndv_get.lanes_map_all,
932                                         ucp_rndv_recv_frag_get_completion);
933 
934         offset += length;
935     }
936 
937     return UCS_OK;
938 }
939 
/* Receive-side PUT pipeline: split the rendezvous transfer into fragments
 * and send one RTR per fragment. The sender PUTs each fragment into a host
 * staging buffer here, and the completion path copies it onward to the user
 * (memtype) receive buffer. Allocation failures are fatal since the protocol
 * cannot be rolled back mid-stream. */
static void ucp_rndv_send_frag_rtr(ucp_worker_h worker, ucp_request_t *rndv_req,
                                   ucp_request_t *rreq,
                                   const ucp_rndv_rts_hdr_t *rndv_rts_hdr)
{
    size_t max_frag_size = worker->context->config.ext.rndv_frag_size;
    int i, num_frags;
    size_t frag_size;
    size_t offset;
    ucp_mem_desc_t *mdesc;
    ucp_request_t *freq;
    ucp_request_t *frndv_req;
    unsigned md_index;
    unsigned memh_index;

    ucp_trace_req(rreq, "using rndv pipeline protocol rndv_req %p", rndv_req);

    offset    = 0;
    num_frags = ucs_div_round_up(rndv_rts_hdr->size, max_frag_size);

    for (i = 0; i < num_frags; i++) {
        /* last fragment may be shorter than max_frag_size */
        frag_size = ucs_min(max_frag_size, (rndv_rts_hdr->size - offset));

        /* internal fragment recv request allocated on receiver side to receive
         *  put fragment from sender and to perform a put to recv buffer */
        freq = ucp_request_get(worker);
        if (freq == NULL) {
            ucs_fatal("failed to allocate fragment receive request");
        }

        /* internal rndv request to send RTR */
        frndv_req = ucp_request_get(worker);
        if (frndv_req == NULL) {
            ucs_fatal("failed to allocate fragment rendezvous reply");
        }

        /* allocate fragment recv buffer desc*/
        mdesc = ucp_worker_mpool_get(&worker->rndv_frag_mp);
        if (mdesc == NULL) {
            ucs_fatal("failed to allocate fragment memory buffer");
        }

        /* staging data area starts right after the mem desc (mdesc + 1) */
        freq->recv.buffer                 = mdesc + 1;
        freq->recv.datatype               = ucp_dt_make_contig(1);
        freq->recv.mem_type               = UCS_MEMORY_TYPE_HOST;
        freq->recv.length                 = frag_size;
        freq->recv.state.dt.contig.md_map = 0;
        freq->recv.frag.rreq              = rreq;
        freq->recv.frag.offset            = offset;
        freq->flags                      |= UCP_REQUEST_FLAG_RNDV_FRAG;

        /* expose the staging buffer's memory handles for the MDs that are
         * both registered in the mdesc and usable as RMA BW lanes */
        memh_index = 0;
        ucs_for_each_bit(md_index,
                         (ucp_ep_config(rndv_req->send.ep)->key.rma_bw_md_map &
                          mdesc->memh->md_map)) {
            freq->recv.state.dt.contig.memh[memh_index++] = ucp_memh2uct(mdesc->memh, md_index);
            freq->recv.state.dt.contig.md_map            |= UCS_BIT(md_index);
        }
        ucs_assert(memh_index <= UCP_MAX_OP_MDS);

        frndv_req->send.ep           = rndv_req->send.ep;
        frndv_req->send.pending_lane = UCP_NULL_LANE;

        /* ask the sender to PUT this fragment at the given offset */
        ucp_rndv_req_send_rtr(frndv_req, freq, rndv_rts_hdr->sreq.reqptr,
                              freq->recv.length, offset);
        offset += frag_size;
    }

    /* release original rndv reply request */
    ucp_request_put(rndv_req);
}
1010 
1011 static UCS_F_ALWAYS_INLINE int
ucp_rndv_is_rkey_ptr(const ucp_rndv_rts_hdr_t * rndv_rts_hdr,ucp_ep_h ep,ucs_memory_type_t recv_mem_type,ucp_rndv_mode_t rndv_mode)1012 ucp_rndv_is_rkey_ptr(const ucp_rndv_rts_hdr_t *rndv_rts_hdr, ucp_ep_h ep,
1013                      ucs_memory_type_t recv_mem_type, ucp_rndv_mode_t rndv_mode)
1014 {
1015     const ucp_ep_config_t *ep_config = ucp_ep_config(ep);
1016 
1017     return /* must have remote address */
1018            (rndv_rts_hdr->address != 0) &&
1019            /* remote key must be on a memory domain for which we support rkey_ptr */
1020            (ucp_rkey_packed_md_map(rndv_rts_hdr + 1) &
1021             ep_config->tag.rndv.rkey_ptr_dst_mds) &&
1022            /* rendezvous mode must not be forced to put/get */
1023            (rndv_mode == UCP_RNDV_MODE_AUTO) &&
1024            /* need local memory access for data unpack */
1025            UCP_MEM_IS_ACCESSIBLE_FROM_CPU(recv_mem_type);
1026 }
1027 
/* Worker progress callback for the rkey_ptr protocol: unpack the next
 * segment of the locally-mapped remote buffer into the receive request.
 * Processes one bounded segment per invocation so a single large transfer
 * does not monopolize the progress loop. */
static unsigned ucp_rndv_progress_rkey_ptr(void *arg)
{
    ucp_worker_h worker     = (ucp_worker_h)arg;
    /* only the queue head is progressed; later requests wait their turn */
    ucp_request_t *rndv_req = ucs_queue_head_elem_non_empty(&worker->rkey_ptr_reqs,
                                                            ucp_request_t,
                                                            send.rkey_ptr.queue_elem);
    ucp_request_t *rreq     = rndv_req->send.rkey_ptr.rreq;
    /* clamp the segment to the configured size and to the remaining bytes */
    size_t seg_size         = ucs_min(worker->context->config.ext.rkey_ptr_seg_size,
                                      rndv_req->send.length - rreq->recv.state.offset);
    ucs_status_t status;
    size_t offset, new_offset;
    int last;

    offset     = rreq->recv.state.offset;
    new_offset = offset + seg_size;
    last       = new_offset == rndv_req->send.length;
    status     = ucp_request_recv_data_unpack(rreq,
                                              rndv_req->send.buffer + offset,
                                              seg_size, offset, last);
    if (ucs_unlikely(status != UCS_OK) || last) {
        /* finished (or failed): complete the receive, drop the rkey and
         * notify the sender with ATS carrying the final status */
        ucs_queue_pull_non_empty(&worker->rkey_ptr_reqs);
        ucp_request_complete_tag_recv(rreq, status);
        ucp_rkey_destroy(rndv_req->send.rkey_ptr.rkey);
        ucp_rndv_req_send_ats(rndv_req, rreq,
                              rndv_req->send.rkey_ptr.remote_request, status);
        if (ucs_queue_is_empty(&worker->rkey_ptr_reqs)) {
            /* no more rkey_ptr work - stop getting progress callbacks */
            uct_worker_progress_unregister_safe(worker->uct,
                                                &worker->rkey_ptr_cb_id);
        }
    } else {
        rreq->recv.state.offset = new_offset;
    }

    return 1;
}
1063 
/* Execute the rkey_ptr rendezvous protocol: obtain a local pointer to the
 * sender's buffer via uct_rkey_ptr() and queue the request for segment-wise
 * unpacking on the worker progress loop. Falls back to completing the
 * receive with an error (plus ATS) if the pointer cannot be obtained. */
static void ucp_rndv_do_rkey_ptr(ucp_request_t *rndv_req, ucp_request_t *rreq,
                                 const ucp_rndv_rts_hdr_t *rndv_rts_hdr)
{
    ucp_ep_h ep                      = rndv_req->send.ep;
    const ucp_ep_config_t *ep_config = ucp_ep_config(ep);
    ucp_worker_h worker              = rreq->recv.worker;
    ucp_md_index_t dst_md_index      = 0;
    ucp_lane_index_t i, lane;
    ucs_status_t status;
    unsigned rkey_index;
    void *local_ptr;
    ucp_rkey_h rkey;

    ucp_trace_req(rndv_req, "start rkey_ptr rndv rreq %p", rreq);

    status = ucp_ep_rkey_unpack(ep, rndv_rts_hdr + 1, &rkey);
    if (status != UCS_OK) {
        ucs_fatal("failed to unpack rendezvous remote key received from %s: %s",
                  ucp_ep_peer_name(ep), ucs_status_string(status));
    }

    /* Find a lane which is capable of accessing the destination memory */
    lane = UCP_NULL_LANE;
    for (i = 0; i < ep_config->key.num_lanes; ++i) {
        dst_md_index = ep_config->key.lanes[i].dst_md_index;
        if (UCS_BIT(dst_md_index) & rkey->md_map) {
            lane = i;
            break;
        }
    }

    if (ucs_unlikely(lane == UCP_NULL_LANE)) {
        /* We should be able to find a lane, because ucp_rndv_is_rkey_ptr()
         * already checked that (rkey->md_map & ep_config->rkey_ptr_dst_mds) != 0
         */
        ucs_fatal("failed to find a lane to access remote memory domains 0x%lx",
                  rkey->md_map);
    }

    /* translate the MD index into this rkey's packed tl_rkey slot */
    rkey_index = ucs_bitmap2idx(rkey->md_map, dst_md_index);
    status     = uct_rkey_ptr(rkey->tl_rkey[rkey_index].cmpt,
                              &rkey->tl_rkey[rkey_index].rkey,
                              rndv_rts_hdr->address, &local_ptr);
    if (status != UCS_OK) {
        /* cannot map remote memory locally - fail the receive and tell the
         * sender (via ATS) so it can release its request */
        ucp_request_complete_tag_recv(rreq, status);
        ucp_rkey_destroy(rkey);
        ucp_rndv_req_send_ats(rndv_req, rreq, rndv_rts_hdr->sreq.reqptr, status);
        return;
    }

    rreq->recv.state.offset = 0;

    ucp_trace_req(rndv_req, "obtained a local pointer to remote buffer: %p",
                  local_ptr);
    rndv_req->send.buffer                  = local_ptr;
    rndv_req->send.length                  = rndv_rts_hdr->size;
    rndv_req->send.rkey_ptr.rkey           = rkey;
    rndv_req->send.rkey_ptr.remote_request = rndv_rts_hdr->sreq.reqptr;
    rndv_req->send.rkey_ptr.rreq           = rreq;

    UCP_WORKER_STAT_RNDV(ep->worker, RKEY_PTR, 1);

    /* actual unpacking is done incrementally by ucp_rndv_progress_rkey_ptr */
    ucs_queue_push(&worker->rkey_ptr_reqs, &rndv_req->send.rkey_ptr.queue_elem);
    uct_worker_progress_register_safe(worker->uct,
                                      ucp_rndv_progress_rkey_ptr,
                                      rreq->recv.worker,
                                      UCS_CALLBACKQ_FLAG_FAST,
                                      &worker->rkey_ptr_cb_id);
}
1133 
1134 static UCS_F_ALWAYS_INLINE int
ucp_rndv_test_zcopy_scheme_support(size_t length,size_t min_zcopy,size_t max_zcopy,int split)1135 ucp_rndv_test_zcopy_scheme_support(size_t length, size_t min_zcopy,
1136                                    size_t max_zcopy, int split)
1137 {
1138     return /* is the current message greater than the minimal GET/PUT Zcopy? */
1139            (length >= min_zcopy) &&
1140            /* is the current message less than the maximal GET/PUT Zcopy? */
1141            ((length <= max_zcopy) ||
1142             /* or can the message be split? */ split);
1143 }
1144 
/* An RTS has been matched to a posted receive: pick the best rendezvous
 * protocol (rkey_ptr, GET zcopy, pipelined staging, or RTR-based PUT/AM)
 * and kick it off. Runs with the worker async lock held. */
UCS_PROFILE_FUNC_VOID(ucp_rndv_matched, (worker, rreq, rndv_rts_hdr),
                      ucp_worker_h worker, ucp_request_t *rreq,
                      const ucp_rndv_rts_hdr_t *rndv_rts_hdr)
{
    ucp_rndv_mode_t rndv_mode;
    ucp_request_t *rndv_req;
    ucp_ep_h ep;
    ucp_ep_config_t *ep_config;
    ucs_status_t status;
    int is_get_zcopy_failed;

    UCS_ASYNC_BLOCK(&worker->async);

    UCS_PROFILE_REQUEST_EVENT(rreq, "rndv_match", 0);

    /* rreq is the receive request on the receiver's side */
    rreq->recv.tag.info.sender_tag = rndv_rts_hdr->super.tag;
    rreq->recv.tag.info.length     = rndv_rts_hdr->size;

    /* the internal send request allocated on receiver side (to perform a "get"
     * operation, send "ATS" and "RTR") */
    rndv_req = ucp_request_get(worker);
    if (rndv_req == NULL) {
        ucs_error("failed to allocate rendezvous reply");
        goto out;
    }

    rndv_req->send.ep           = ucp_worker_get_ep_by_ptr(worker,
                                                           rndv_rts_hdr->sreq.ep_ptr);
    rndv_req->flags             = 0;
    rndv_req->send.mdesc        = NULL;
    rndv_req->send.pending_lane = UCP_NULL_LANE;
    is_get_zcopy_failed         = 0;

    ucp_trace_req(rreq,
                  "rndv matched remote {address 0x%"PRIx64" size %zu sreq 0x%lx}"
                  " rndv_sreq %p", rndv_rts_hdr->address, rndv_rts_hdr->size,
                  rndv_rts_hdr->sreq.reqptr, rndv_req);

    /* truncated receive: acknowledge the sender (ATS) so it can release its
     * request, then fail the local receive */
    if (ucs_unlikely(rreq->recv.length < rndv_rts_hdr->size)) {
        ucp_trace_req(rndv_req,
                      "rndv truncated remote size %zu local size %zu rreq %p",
                      rndv_rts_hdr->size, rreq->recv.length, rreq);
        ucp_rndv_req_send_ats(rndv_req, rreq, rndv_rts_hdr->sreq.reqptr, UCS_OK);
        ucp_request_recv_generic_dt_finish(rreq);
        ucp_rndv_zcopy_recv_req_complete(rreq, UCS_ERR_MESSAGE_TRUNCATED);
        goto out;
    }

    /* if the receive side is not connected yet then the RTS was received on a stub ep */
    ep        = rndv_req->send.ep;
    ep_config = ucp_ep_config(ep);
    rndv_mode = worker->context->config.ext.rndv_mode;

    /* cheapest option: map the sender's buffer locally and just copy */
    if (ucp_rndv_is_rkey_ptr(rndv_rts_hdr, ep, rreq->recv.mem_type, rndv_mode)) {
        ucp_rndv_do_rkey_ptr(rndv_req, rreq, rndv_rts_hdr);
        goto out;
    }

    if (UCP_DT_IS_CONTIG(rreq->recv.datatype)) {
        if ((rndv_rts_hdr->address != 0) &&
            ucp_rndv_test_zcopy_scheme_support(rndv_rts_hdr->size,
                                               ep_config->tag.rndv.min_get_zcopy,
                                               ep_config->tag.rndv.max_get_zcopy,
                                               ep_config->tag.rndv.get_zcopy_split)) {
            /* try to fetch the data with a get_zcopy operation */
            status = ucp_rndv_req_send_rma_get(rndv_req, rreq, rndv_rts_hdr);
            if (status == UCS_OK) {
                goto out;
            }

            /* fallback to non get zcopy protocol */
            ucp_rkey_destroy(rndv_req->send.rndv_get.rkey);
            is_get_zcopy_failed = 1;
        }

        if (rndv_mode == UCP_RNDV_MODE_AUTO) {
            /* check if we need pipelined memtype staging */
            if (UCP_MEM_IS_CUDA(rreq->recv.mem_type) &&
                ucp_rndv_is_recv_pipeline_needed(rndv_req,
                                                 rndv_rts_hdr,
                                                 rreq->recv.mem_type,
                                                 is_get_zcopy_failed)) {
                ucp_rndv_recv_data_init(rreq, rndv_rts_hdr->size);
                if (ucp_rndv_is_put_pipeline_needed(rndv_rts_hdr->address,
                                                    rndv_rts_hdr->size,
                                                    ep_config->tag.rndv.min_get_zcopy,
                                                    ep_config->tag.rndv.max_get_zcopy,
                                                    is_get_zcopy_failed)) {
                    /* send FRAG RTR for sender to PUT the fragment. */
                    ucp_rndv_send_frag_rtr(worker, rndv_req, rreq, rndv_rts_hdr);
                } else {
                    /* sender address is present. do GET pipeline */
                    ucp_rndv_recv_start_get_pipeline(worker, rndv_req, rreq,
                                                     rndv_rts_hdr->sreq.reqptr,
                                                     rndv_rts_hdr + 1,
                                                     rndv_rts_hdr->address,
                                                     rndv_rts_hdr->size, 0);
                }
                goto out;
            }
        }

        if ((rndv_mode == UCP_RNDV_MODE_PUT_ZCOPY) ||
            UCP_MEM_IS_CUDA(rreq->recv.mem_type)) {
            /* put protocol is allowed - register receive buffer memory for rma */
            ucs_assert(rndv_rts_hdr->size <= rreq->recv.length);
            ucp_request_recv_buffer_reg(rreq, ep_config->key.rma_bw_md_map,
                                        rndv_rts_hdr->size);
        }
    }

    /* The sender didn't specify its address in the RTS, or the rndv mode was
     * configured to PUT, or GET rndv mode is unsupported - send an RTR and
     * the sender will send the data with active message or put_zcopy. */
    ucp_rndv_recv_data_init(rreq, rndv_rts_hdr->size);
    UCP_WORKER_STAT_RNDV(ep->worker, SEND_RTR, 1);
    ucp_rndv_req_send_rtr(rndv_req, rreq, rndv_rts_hdr->sreq.reqptr,
                          rndv_rts_hdr->size, 0ul);

out:
    UCS_ASYNC_UNBLOCK(&worker->async);
}
1268 
ucp_rndv_process_rts(void * arg,void * data,size_t length,unsigned tl_flags)1269 ucs_status_t ucp_rndv_process_rts(void *arg, void *data, size_t length,
1270                                   unsigned tl_flags)
1271 {
1272     ucp_worker_h worker                = arg;
1273     ucp_rndv_rts_hdr_t *rndv_rts_hdr   = data;
1274     ucp_recv_desc_t *rdesc;
1275     ucp_request_t *rreq;
1276     ucs_status_t status;
1277 
1278     rreq = ucp_tag_exp_search(&worker->tm, rndv_rts_hdr->super.tag);
1279     if (rreq != NULL) {
1280         ucp_rndv_matched(worker, rreq, rndv_rts_hdr);
1281 
1282         /* Cancel req in transport if it was offloaded, because it arrived
1283            as unexpected */
1284         ucp_tag_offload_try_cancel(worker, rreq, UCP_TAG_OFFLOAD_CANCEL_FORCE);
1285 
1286         UCP_WORKER_STAT_RNDV(worker, EXP, 1);
1287         status = UCS_OK;
1288     } else {
1289         status = ucp_recv_desc_init(worker, data, length, 0, tl_flags,
1290                                     sizeof(*rndv_rts_hdr),
1291                                     UCP_RECV_DESC_FLAG_RNDV, 0, &rdesc);
1292         if (!UCS_STATUS_IS_ERR(status)) {
1293             ucp_tag_unexp_recv(&worker->tm, rdesc, rndv_rts_hdr->super.tag);
1294         }
1295     }
1296 
1297     return status;
1298 }
1299 
/* Active-message handler for RNDV/RTS: thin profiling wrapper around
 * ucp_rndv_process_rts(). */
UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rts_handler,
                 (arg, data, length, tl_flags),
                 void *arg, void *data, size_t length, unsigned tl_flags)
{
    return ucp_rndv_process_rts(arg, data, length, tl_flags);
}
1306 
/* Active-message handler for ATS (ack-to-send): the receiver has all the
 * data, so complete the original send request with the carried status. */
UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_ats_handler,
                 (arg, data, length, flags),
                 void *arg, void *data, size_t length, unsigned flags)
{
    ucp_reply_hdr_t *rep_hdr = data;
    ucp_request_t *sreq = (ucp_request_t*) rep_hdr->reqptr;

    /* dereg the original send request and set it to complete */
    UCS_PROFILE_REQUEST_EVENT(sreq, "rndv_ats_recv", 0);
    if (sreq->flags & UCP_REQUEST_FLAG_OFFLOADED) {
        /* the request was posted to HW tag-matching - cancel it there first */
        ucp_tag_offload_cancel_rndv(sreq);
    }
    ucp_rndv_complete_send(sreq, rep_hdr->status);
    return UCS_OK;
}
1322 
ucp_rndv_pack_data(void * dest,void * arg)1323 static size_t ucp_rndv_pack_data(void *dest, void *arg)
1324 {
1325     ucp_rndv_data_hdr_t *hdr = dest;
1326     ucp_request_t *sreq = arg;
1327     size_t length, offset;
1328 
1329     offset        = sreq->send.state.dt.offset;
1330     hdr->rreq_ptr = sreq->send.msg_proto.tag.rreq_ptr;
1331     hdr->offset   = offset;
1332     length        = ucs_min(sreq->send.length - offset,
1333                             ucp_ep_get_max_bcopy(sreq->send.ep, sreq->send.lane) - sizeof(*hdr));
1334 
1335     return sizeof(*hdr) + ucp_dt_pack(sreq->send.ep->worker, sreq->send.datatype,
1336                                       sreq->send.mem_type, hdr + 1, sreq->send.buffer,
1337                                       &sreq->send.state.dt, length);
1338 }
1339 
/* Pending callback sending the rendezvous payload via active-message bcopy:
 * a single RNDV_DATA message when it fits within max_bcopy, otherwise a
 * multi-fragment sequence. Completes the send request once done. */
UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_progress_am_bcopy, (self),
                 uct_pending_req_t *self)
{
    ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
    ucp_ep_t *ep        = sreq->send.ep;
    ucs_status_t status;

    if (sreq->send.length <= ucp_ep_config(ep)->am.max_bcopy - sizeof(ucp_rndv_data_hdr_t)) {
        /* send a single bcopy message */
        status = ucp_do_am_bcopy_single(self, UCP_AM_ID_RNDV_DATA,
                                        ucp_rndv_pack_data);
    } else {
        /* same AM id for first/middle fragments; offset in the header
         * lets the receiver place them */
        status = ucp_do_am_bcopy_multi(self, UCP_AM_ID_RNDV_DATA,
                                       UCP_AM_ID_RNDV_DATA,
                                       ucp_rndv_pack_data,
                                       ucp_rndv_pack_data, 1);
    }
    if (status == UCS_OK) {
        ucp_rndv_complete_send(sreq, UCS_OK);
    } else if (status == UCP_STATUS_PENDING_SWITCH) {
        /* request was moved to another pending queue - not a failure */
        status = UCS_OK;
    }

    return status;
}
1365 
/*
 * Progress callback for the RNDV PUT_ZCOPY protocol: posts the next chunk of
 * the send buffer to the receiver's advertised remote address with
 * uct_ep_put_zcopy().  Re-invoked from the pending queue until the whole
 * request length has been posted.
 */
UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_progress_rma_put_zcopy, (self),
                 uct_pending_req_t *self)
{
    ucp_request_t *sreq     = ucs_container_of(self, ucp_request_t, send.uct);
    const size_t max_iovcnt = 1;
    ucp_ep_h ep             = sreq->send.ep;
    ucs_status_t status;
    size_t offset, ucp_mtu, align, remaining, length;
    uct_iface_attr_t *attrs;
    uct_iov_t iov[max_iovcnt];
    size_t iovcnt;
    ucp_dt_state_t state;

    /* Register the send buffer on the put lane, unless a staging memory
     * descriptor (mdesc) already carries a registration */
    if (!sreq->send.mdesc) {
        status = ucp_request_send_buffer_reg_lane(sreq, sreq->send.lane, 0);
        ucs_assert_always(status == UCS_OK);
    }

    attrs     = ucp_worker_iface_get_attr(ep->worker,
                                          ucp_ep_get_rsc_index(ep, sreq->send.lane));
    align     = attrs->cap.put.opt_zcopy_align;
    ucp_mtu   = attrs->cap.put.align_mtu;

    offset    = sreq->send.state.dt.offset;
    remaining = (uintptr_t)sreq->send.buffer % align;

    /* If the buffer start is misaligned, send a short first chunk so the
     * following chunks begin on an aligned boundary */
    if ((offset == 0) && (remaining > 0) && (sreq->send.length > ucp_mtu)) {
        length = ucp_mtu - remaining;
    } else {
        length = ucs_min(sreq->send.length - offset,
                         ucp_ep_config(ep)->tag.rndv.max_put_zcopy);
    }

    ucs_trace_data("req %p: offset %zu remainder %zu. read to %p len %zu",
                   sreq, offset, (uintptr_t)sreq->send.buffer % align,
                   UCS_PTR_BYTE_OFFSET(sreq->send.buffer, offset), length);

    /* Work on a local copy of the datatype state; the request state is
     * advanced below only according to the operation outcome */
    state = sreq->send.state.dt;
    ucp_dt_iov_copy_uct(ep->worker->context, iov, &iovcnt, max_iovcnt, &state,
                        sreq->send.buffer, ucp_dt_make_contig(1), length,
                        ucp_ep_md_index(ep, sreq->send.lane), sreq->send.mdesc);
    status = uct_ep_put_zcopy(ep->uct_eps[sreq->send.lane],
                              iov, iovcnt,
                              sreq->send.rndv_put.remote_address + offset,
                              sreq->send.rndv_put.uct_rkey,
                              &sreq->send.state.uct_comp);
    ucp_request_send_state_advance(sreq, &state,
                                   UCP_REQUEST_SEND_PROTO_RNDV_PUT,
                                   status);
    if (sreq->send.state.dt.offset == sreq->send.length) {
        /* All data posted; if no zcopy completions remain outstanding, invoke
         * the completion callback inline */
        if (sreq->send.state.uct_comp.count == 0) {
            sreq->send.state.uct_comp.func(&sreq->send.state.uct_comp, status);
        }
        return UCS_OK;
    } else if (!UCS_STATUS_IS_ERR(status)) {
        /* more data to send - keep the request on the pending queue */
        return UCS_INPROGRESS;
    } else {
        return status;
    }
}
1426 
/*
 * Complete an AM-zcopy rendezvous send: deregister the send buffer and report
 * the final status to the user.  Must only be called once every outstanding
 * zcopy fragment has completed (uct_comp.count reached zero).
 */
static void ucp_rndv_am_zcopy_send_req_complete(ucp_request_t *req,
                                                ucs_status_t status)
{
    ucs_assert(req->send.state.uct_comp.count == 0);
    ucp_request_send_buffer_dereg(req);
    ucp_request_complete_send(req, status);
}
1434 
/*
 * UCT completion callback for AM-zcopy rendezvous fragments.  Finishes the
 * send request once the datatype offset shows that all bytes were sent;
 * a failure before that point is fatal, since rendezvous has no error
 * recovery path.
 */
static void ucp_rndv_am_zcopy_completion(uct_completion_t *self,
                                         ucs_status_t status)
{
    ucp_request_t *sreq = ucs_container_of(self, ucp_request_t,
                                           send.state.uct_comp);

    if (sreq->send.state.dt.offset != sreq->send.length) {
        /* not done yet - only an error status is actionable here */
        if (status != UCS_OK) {
            ucs_fatal("error handling is unsupported with rendezvous protocol");
        }
        return;
    }

    ucp_rndv_am_zcopy_send_req_complete(sreq, status);
}
1446 
ucp_rndv_progress_am_zcopy_single(uct_pending_req_t * self)1447 static ucs_status_t ucp_rndv_progress_am_zcopy_single(uct_pending_req_t *self)
1448 {
1449     ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
1450     ucp_rndv_data_hdr_t hdr;
1451 
1452     hdr.rreq_ptr = sreq->send.msg_proto.tag.rreq_ptr;
1453     hdr.offset   = 0;
1454     return ucp_do_am_zcopy_single(self, UCP_AM_ID_RNDV_DATA, &hdr, sizeof(hdr),
1455                                   ucp_rndv_am_zcopy_send_req_complete);
1456 }
1457 
ucp_rndv_progress_am_zcopy_multi(uct_pending_req_t * self)1458 static ucs_status_t ucp_rndv_progress_am_zcopy_multi(uct_pending_req_t *self)
1459 {
1460     ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
1461     ucp_rndv_data_hdr_t hdr;
1462 
1463     hdr.rreq_ptr = sreq->send.msg_proto.tag.rreq_ptr;
1464     hdr.offset   = sreq->send.state.dt.offset;
1465     return ucp_do_am_zcopy_multi(self,
1466                                  UCP_AM_ID_RNDV_DATA,
1467                                  UCP_AM_ID_RNDV_DATA,
1468                                  &hdr, sizeof(hdr),
1469                                  &hdr, sizeof(hdr),
1470                                  ucp_rndv_am_zcopy_send_req_complete, 1);
1471 }
1472 
/*
 * Completion of one PUT fragment in the sender-side pipeline.  Releases the
 * fragment's staging descriptor, advances the parent request's progress
 * counter, and sends ATP to the receiver once the final fragment completes.
 */
UCS_PROFILE_FUNC_VOID(ucp_rndv_send_frag_put_completion, (self, status),
                      uct_completion_t *self, ucs_status_t status)
{
    ucp_request_t *freq = ucs_container_of(self, ucp_request_t, send.state.uct_comp);
    ucp_request_t *req  = freq->send.rndv_put.sreq;

    /* release memory descriptor */
    if (freq->send.mdesc) {
        ucs_mpool_put_inline((void *)freq->send.mdesc);
    }

    /* account the completed fragment in the parent fragmented-send request */
    req->send.state.dt.offset += freq->send.length;
    ucs_assert(req->send.state.dt.offset <= req->send.length);

    /* send ATP for last fragment of the rndv request */
    if (req->send.length == req->send.state.dt.offset) {
        ucp_rndv_send_frag_atp(req, req->send.rndv_put.remote_request);
    }

    ucp_request_put(freq);
}
1494 
/*
 * Pipeline step transition: the GET that staged a fragment into a host buffer
 * has completed, so convert the fragment request into a PUT towards the
 * receiver's remote fragment buffer and resubmit it.
 *
 * NOTE(review): rndv_get and rndv_put appear to alias (the remote address is
 * read from rndv_get while rndv_put fields are being filled in) - presumably
 * they share a union in the request, so the assignment order below matters;
 * confirm before reordering.
 */
UCS_PROFILE_FUNC_VOID(ucp_rndv_put_pipeline_frag_get_completion, (self, status),
                      uct_completion_t *self, ucs_status_t status)
{
    ucp_request_t *freq  = ucs_container_of(self, ucp_request_t, send.state.uct_comp);
    ucp_request_t *fsreq = freq->send.rndv_get.rreq;

    /* get completed on memtype endpoint to stage on host. send put request to receiver*/
    ucp_request_send_state_reset(freq, ucp_rndv_send_frag_put_completion,
                                 UCP_REQUEST_SEND_PROTO_RNDV_PUT);
    /* remote target address = parent's remote base + fragment offset within
     * the parent's send buffer */
    freq->send.rndv_put.remote_address   = fsreq->send.rndv_put.remote_address +
        (freq->send.rndv_get.remote_address - (uint64_t)fsreq->send.buffer);
    freq->send.ep                        = fsreq->send.ep;
    freq->send.uct.func                  = ucp_rndv_progress_rma_put_zcopy;
    freq->send.rndv_put.sreq             = fsreq;
    freq->send.rndv_put.rkey             = fsreq->send.rndv_put.rkey;
    freq->send.rndv_put.uct_rkey         = fsreq->send.rndv_put.uct_rkey;
    freq->send.lane                      = fsreq->send.lane;
    freq->send.state.dt.dt.contig.md_map = 0;

    ucp_request_send(freq, 0);
}
1516 
/*
 * Start the sender-side PUT pipeline for one RTR message: split the RTR range
 * into fragments and, for each fragment, either PUT it directly (host memory)
 * or first GET it into a host staging buffer and PUT from there.
 *
 * @param sreq          Original rendezvous send request.
 * @param rndv_rtr_hdr  RTR header describing the remote range to fill.
 *
 * @return UCS_ERR_UNSUPPORTED if the lane cannot register host memory (caller
 *         falls back to plain PUT_ZCOPY); UCS_ERR_NO_MEMORY on request
 *         allocation failure; UCS_OK otherwise.
 */
static ucs_status_t ucp_rndv_send_start_put_pipeline(ucp_request_t *sreq,
                                                     ucp_rndv_rtr_hdr_t *rndv_rtr_hdr)
{
    ucp_ep_h ep             = sreq->send.ep;
    ucp_ep_config_t *config = ucp_ep_config(ep);
    ucp_worker_h worker     = sreq->send.ep->worker;
    ucp_context_h context   = worker->context;
    const uct_md_attr_t *md_attr;
    ucp_request_t *freq;
    ucp_request_t *fsreq;
    ucp_md_index_t md_index;
    size_t max_frag_size, rndv_size, length;
    size_t offset, rndv_base_offset;
    size_t min_zcopy, max_zcopy;

    ucp_trace_req(sreq, "using put rndv pipeline protocol");

    /* Protocol:
     * Step 1: GET fragment from send buffer to HOST fragment buffer
     * Step 2: PUT from fragment HOST buffer to remote HOST fragment buffer
     * Step 3: send ATP for each fragment request
     */

    /* check if lane supports host memory, to stage sends through host memory */
    md_attr = ucp_ep_md_attr(sreq->send.ep, sreq->send.lane);
    if (!(md_attr->cap.reg_mem_types & UCS_BIT(UCS_MEMORY_TYPE_HOST))) {
        return UCS_ERR_UNSUPPORTED;
    }

    /* fragment sizing is bounded by both the put-zcopy limits and the
     * configured staging fragment size */
    min_zcopy        = config->tag.rndv.min_put_zcopy;
    max_zcopy        = config->tag.rndv.max_put_zcopy;
    rndv_size        = ucs_min(rndv_rtr_hdr->size, sreq->send.length);
    max_frag_size    = ucs_min(context->config.ext.rndv_frag_size, max_zcopy);
    rndv_base_offset = rndv_rtr_hdr->offset;

    /* initialize send req state on first fragment rndv request */
    if (rndv_base_offset == 0) {
         ucp_request_send_state_reset(sreq, NULL, UCP_REQUEST_SEND_PROTO_RNDV_PUT);
    }

    /* internal send request allocated on sender side to handle send fragments for RTR */
    fsreq = ucp_request_get(worker);
    if (fsreq == NULL) {
        ucs_fatal("failed to allocate fragment receive request");
    }

    /* fsreq tracks progress for this RTR's sub-range of the send buffer */
    ucp_request_send_state_init(fsreq, ucp_dt_make_contig(1), 0);
    fsreq->send.buffer                  = UCS_PTR_BYTE_OFFSET(sreq->send.buffer,
                                                              rndv_base_offset);
    fsreq->send.length                  = rndv_size;
    fsreq->send.mem_type                = sreq->send.mem_type;
    fsreq->send.ep                      = sreq->send.ep;
    fsreq->send.lane                    = sreq->send.lane;
    fsreq->send.rndv_put.rkey           = sreq->send.rndv_put.rkey;
    fsreq->send.rndv_put.uct_rkey       = sreq->send.rndv_put.uct_rkey;
    fsreq->send.rndv_put.remote_request = rndv_rtr_hdr->rreq_ptr;
    fsreq->send.rndv_put.remote_address = rndv_rtr_hdr->address;
    fsreq->send.rndv_put.sreq           = sreq;
    fsreq->send.state.dt.offset         = 0;

    /* carve the RTR range into fragments and submit one request each */
    offset = 0;
    while (offset != rndv_size) {
        length = ucp_rndv_adjust_zcopy_length(min_zcopy, max_frag_size, 0,
                                              rndv_size, offset, rndv_size - offset);

        if (UCP_MEM_IS_ACCESSIBLE_FROM_CPU(sreq->send.mem_type)) {
            /* sbuf is in host, directly do put */
            freq = ucp_request_get(worker);
            if (ucs_unlikely(freq == NULL)) {
                ucs_error("failed to allocate fragment receive request");
                return UCS_ERR_NO_MEMORY;
            }

            ucp_request_send_state_reset(freq, ucp_rndv_send_frag_put_completion,
                                         UCP_REQUEST_SEND_PROTO_RNDV_PUT);
            md_index                              = ucp_ep_md_index(sreq->send.ep,
                                                                    sreq->send.lane);
            freq->send.ep                         = fsreq->send.ep;
            freq->send.buffer                     = UCS_PTR_BYTE_OFFSET(fsreq->send.buffer,
                                                                        offset);
            freq->send.datatype                   = ucp_dt_make_contig(1);
            freq->send.mem_type                   = UCS_MEMORY_TYPE_HOST;
            /* reuse the parent's registration for the put lane's MD */
            freq->send.state.dt.dt.contig.memh[0] =
                        ucp_memh_map2uct(sreq->send.state.dt.dt.contig.memh,
                                         sreq->send.state.dt.dt.contig.md_map, md_index);
            freq->send.state.dt.dt.contig.md_map  = UCS_BIT(md_index);
            freq->send.length                     = length;
            freq->send.uct.func                   = ucp_rndv_progress_rma_put_zcopy;
            freq->send.rndv_put.sreq              = fsreq;
            freq->send.rndv_put.rkey              = fsreq->send.rndv_put.rkey;
            freq->send.rndv_put.uct_rkey          = fsreq->send.rndv_put.uct_rkey;
            freq->send.rndv_put.remote_address    = rndv_rtr_hdr->address + offset;
            freq->send.rndv_put.remote_request    = rndv_rtr_hdr->rreq_ptr;
            freq->send.lane                       = fsreq->send.lane;
            freq->send.mdesc                      = NULL;

            ucp_request_send(freq, 0);
        } else {
            /* non-host sbuf: GET the fragment into a host staging buffer
             * first; the PUT is issued from the get-completion callback */
            ucp_rndv_send_frag_get_mem_type(fsreq, 0, length,
                                            (uint64_t)UCS_PTR_BYTE_OFFSET(fsreq->send.buffer, offset),
                                            fsreq->send.mem_type, NULL, NULL, UCS_BIT(0),
                                            ucp_rndv_put_pipeline_frag_get_completion);
        }

        offset += length;
    }

    return UCS_OK;
}
1626 
1627 UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_atp_handler,
1628                  (arg, data, length, flags),
1629                  void *arg, void *data, size_t length, unsigned flags)
1630 {
1631     ucp_reply_hdr_t *rep_hdr = data;
1632     ucp_request_t *req       = (ucp_request_t*) rep_hdr->reqptr;
1633 
1634     if (req->flags & UCP_REQUEST_FLAG_RNDV_FRAG) {
1635         /* received ATP for frag RTR request */
1636         ucs_assert(req->recv.frag.rreq != NULL);
1637         UCS_PROFILE_REQUEST_EVENT(req, "rndv_frag_atp_recv", 0);
1638         ucp_rndv_recv_frag_put_mem_type(req->recv.frag.rreq, NULL, req,
1639                                         ((ucp_mem_desc_t*) req->recv.buffer - 1),
1640                                         req->recv.length, req->recv.frag.offset);
1641     } else {
1642         UCS_PROFILE_REQUEST_EVENT(req, "rndv_atp_recv", 0);
1643         ucp_rndv_zcopy_recv_req_complete(req, UCS_OK);
1644     }
1645 
1646     return UCS_OK;
1647 }
1648 
/*
 * Active-message handler for RNDV RTR (ready-to-receive).  Chooses the send
 * protocol in priority order: pipelined PUT (non-host memory or partial RTR),
 * plain PUT_ZCOPY, AM zcopy, and finally AM bcopy as the universal fallback.
 */
UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rtr_handler,
                 (arg, data, length, flags),
                 void *arg, void *data, size_t length, unsigned flags)
{
    ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = data;
    ucp_request_t *sreq              = (ucp_request_t*)rndv_rtr_hdr->sreq_ptr;
    ucp_ep_h ep                      = sreq->send.ep;
    ucp_ep_config_t *ep_config       = ucp_ep_config(ep);
    ucp_context_h context            = ep->worker->context;
    ucs_status_t status;
    int is_pipeline_rndv;

    ucp_trace_req(sreq, "received rtr address 0x%lx remote rreq 0x%lx",
                  rndv_rtr_hdr->address, rndv_rtr_hdr->rreq_ptr);
    UCS_PROFILE_REQUEST_EVENT(sreq, "rndv_rtr_recv", 0);

    if (sreq->flags & UCP_REQUEST_FLAG_OFFLOADED) {
        /* Do not deregister memory here, because am zcopy rndv may
         * need it registered (if am and tag is the same lane). */
        ucp_tag_offload_cancel_rndv(sreq);
    }

    /* RMA protocols (PUT) require a contiguous datatype and a remote address
     * advertised in the RTR */
    if (UCP_DT_IS_CONTIG(sreq->send.datatype) && rndv_rtr_hdr->address) {
        status = ucp_ep_rkey_unpack(ep, rndv_rtr_hdr + 1,
                                    &sreq->send.rndv_put.rkey);
        if (status != UCS_OK) {
            ucs_fatal("failed to unpack rendezvous remote key received from %s: %s",
                      ucp_ep_peer_name(ep), ucs_status_string(status));
        }

        /* pipeline when the send buffer is not host-accessible, or the RTR
         * covers only part of the send, unless PUT_ZCOPY was forced */
        is_pipeline_rndv = ((!UCP_MEM_IS_ACCESSIBLE_FROM_CPU(sreq->send.mem_type) ||
                             (sreq->send.length != rndv_rtr_hdr->size)) &&
                            (context->config.ext.rndv_mode != UCP_RNDV_MODE_PUT_ZCOPY));

        /* for pipeline, the lane must reach the REMOTE memory type (staging
         * handles the local one); otherwise match the local memory type */
        sreq->send.lane = ucp_rkey_find_rma_lane(ep->worker->context, ep_config,
                                                 (is_pipeline_rndv ?
                                                  sreq->send.rndv_put.rkey->mem_type :
                                                  sreq->send.mem_type),
                                                 ep_config->tag.rndv.put_zcopy_lanes,
                                                 sreq->send.rndv_put.rkey, 0,
                                                 &sreq->send.rndv_put.uct_rkey);
        if (sreq->send.lane != UCP_NULL_LANE) {
            /*
             * Try pipeline protocol for non-host memory, if PUT_ZCOPY protocol is
             * not explicitly required. If pipeline is UNSUPPORTED, fallback to
             * PUT_ZCOPY anyway.
             */
            if (is_pipeline_rndv) {
                status = ucp_rndv_send_start_put_pipeline(sreq, rndv_rtr_hdr);
                if (status != UCS_ERR_UNSUPPORTED) {
                    return status;
                }
                /* If we get here, it means that RNDV pipeline protocol is
                 * unsupported and we have to use PUT_ZCOPY RNDV scheme instead */
            }

            if ((context->config.ext.rndv_mode != UCP_RNDV_MODE_GET_ZCOPY) &&
                ucp_rndv_test_zcopy_scheme_support(sreq->send.length,
                                                   ep_config->tag.rndv.min_put_zcopy,
                                                   ep_config->tag.rndv.max_put_zcopy,
                                                   ep_config->tag.rndv.put_zcopy_split)) {
                ucp_request_send_state_reset(sreq, ucp_rndv_put_completion,
                                             UCP_REQUEST_SEND_PROTO_RNDV_PUT);
                sreq->send.uct.func                = ucp_rndv_progress_rma_put_zcopy;
                sreq->send.rndv_put.remote_request = rndv_rtr_hdr->rreq_ptr;
                sreq->send.rndv_put.remote_address = rndv_rtr_hdr->address;
                sreq->send.mdesc                   = NULL;
                goto out_send;
            } else {
                /* PUT was rejected - release the unpacked rkey before
                 * switching to the AM path */
                ucp_rkey_destroy(sreq->send.rndv_put.rkey);
            }
        } else {
            /* no usable RMA lane - release the rkey and fall back to AM */
            ucp_rkey_destroy(sreq->send.rndv_put.rkey);
        }
    }

    ucp_trace_req(sreq, "using rdnv_data protocol");

    /* switch to AM */
    sreq->send.msg_proto.tag.rreq_ptr = rndv_rtr_hdr->rreq_ptr;

    /* AM zcopy for large contiguous sends above the memtype zcopy threshold,
     * otherwise AM bcopy */
    if (UCP_DT_IS_CONTIG(sreq->send.datatype) &&
        (sreq->send.length >=
         ep_config->am.mem_type_zcopy_thresh[sreq->send.mem_type]))
    {
        status = ucp_request_send_buffer_reg_lane(sreq, ucp_ep_get_am_lane(ep), 0);
        ucs_assert_always(status == UCS_OK);

        ucp_request_send_state_reset(sreq, ucp_rndv_am_zcopy_completion,
                                     UCP_REQUEST_SEND_PROTO_ZCOPY_AM);

        if ((sreq->send.length + sizeof(ucp_rndv_data_hdr_t)) <=
            ep_config->am.max_zcopy) {
            sreq->send.uct.func = ucp_rndv_progress_am_zcopy_single;
        } else {
            sreq->send.uct.func              = ucp_rndv_progress_am_zcopy_multi;
            sreq->send.msg_proto.am_bw_index = 1;
        }
    } else {
        ucp_request_send_state_reset(sreq, NULL, UCP_REQUEST_SEND_PROTO_BCOPY_AM);
        sreq->send.uct.func              = ucp_rndv_progress_am_bcopy;
        sreq->send.msg_proto.am_bw_index = 1;
    }

out_send:
    ucp_request_send(sreq, 0);
    return UCS_OK;
}
1757 
1758 UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_data_handler,
1759                  (arg, data, length, flags),
1760                  void *arg, void *data, size_t length, unsigned flags)
1761 {
1762     ucp_rndv_data_hdr_t *rndv_data_hdr = data;
1763     ucp_request_t *rreq = (ucp_request_t*) rndv_data_hdr->rreq_ptr;
1764     size_t recv_len;
1765 
1766     ucs_assert(!(rreq->flags & UCP_REQUEST_FLAG_RNDV_FRAG));
1767 
1768     recv_len = length - sizeof(*rndv_data_hdr);
1769     UCS_PROFILE_REQUEST_EVENT(rreq, "rndv_data_recv", recv_len);
1770 
1771     (void)ucp_tag_request_process_recv_data(rreq, rndv_data_hdr + 1, recv_len,
1772                                             rndv_data_hdr->offset, 1, 0);
1773     return UCS_OK;
1774 }
1775 
/*
 * Append a textual dump of a packed remote key to a trace buffer.
 *
 * @param packed_rkey  Packed rkey (follows the protocol header on the wire).
 * @param buffer       Destination string buffer (written from offset 0).
 * @param max          Capacity of @buffer in bytes.
 */
static void ucp_rndv_dump_rkey(const void *packed_rkey, char *buffer, size_t max)
{
    size_t len;

    /* Callers pass "max - strlen(buffer)", which can be 0 when the trace
     * buffer is already full.  snprintf() with size 0 writes nothing, so the
     * strlen() below would read an unterminated buffer - bail out early. */
    if (max == 0) {
        return;
    }

    snprintf(buffer, max, " rkey ");
    len = strlen(buffer);

    ucp_rkey_dump_packed(packed_rkey, buffer + len, max - len);
}
1786 
/*
 * Trace-dump callback for rendezvous active messages: formats a one-line,
 * human-readable description of the message identified by @id into @buffer.
 * Unknown ids are silently ignored.  The format strings are part of the
 * trace output and must stay stable.
 */
static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type,
                          uint8_t id, const void *data, size_t length,
                          char *buffer, size_t max)
{

    /* all headers alias the same payload; only the one matching @id is used */
    const ucp_rndv_rts_hdr_t *rndv_rts_hdr = data;
    const ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = data;
    const ucp_rndv_data_hdr_t *rndv_data = data;
    const ucp_reply_hdr_t *rep_hdr = data;

    switch (id) {
    case UCP_AM_ID_RNDV_RTS:
        ucs_assert(rndv_rts_hdr->sreq.ep_ptr != 0);
        snprintf(buffer, max, "RNDV_RTS tag %"PRIx64" ep_ptr %lx sreq 0x%lx "
                 "address 0x%"PRIx64" size %zu", rndv_rts_hdr->super.tag,
                 rndv_rts_hdr->sreq.ep_ptr, rndv_rts_hdr->sreq.reqptr,
                 rndv_rts_hdr->address, rndv_rts_hdr->size);
        /* a non-zero address means a packed rkey follows the header */
        if (rndv_rts_hdr->address) {
            ucp_rndv_dump_rkey(rndv_rts_hdr + 1, buffer + strlen(buffer),
                               max - strlen(buffer));
        }
        break;
    case UCP_AM_ID_RNDV_ATS:
        snprintf(buffer, max, "RNDV_ATS sreq 0x%lx status '%s'",
                 rep_hdr->reqptr, ucs_status_string(rep_hdr->status));
        break;
    case UCP_AM_ID_RNDV_RTR:
        snprintf(buffer, max, "RNDV_RTR sreq 0x%lx rreq 0x%lx address 0x%lx",
                 rndv_rtr_hdr->sreq_ptr, rndv_rtr_hdr->rreq_ptr,
                 rndv_rtr_hdr->address);
        if (rndv_rtr_hdr->address) {
            ucp_rndv_dump_rkey(rndv_rtr_hdr + 1, buffer + strlen(buffer),
                               max - strlen(buffer));
        }
        break;
    case UCP_AM_ID_RNDV_DATA:
        snprintf(buffer, max, "RNDV_DATA rreq 0x%"PRIx64" offset %zu",
                 rndv_data->rreq_ptr, rndv_data->offset);
        break;
    case UCP_AM_ID_RNDV_ATP:
        snprintf(buffer, max, "RNDV_ATP sreq 0x%lx status '%s'",
                 rep_hdr->reqptr, ucs_status_string(rep_hdr->status));
        break;
    default:
        return;
    }
}
1834 
/* Register the rendezvous active-message handlers (gated on the TAG feature),
 * all sharing ucp_rndv_dump for trace output, plus their proxy entries. */
UCP_DEFINE_AM(UCP_FEATURE_TAG, UCP_AM_ID_RNDV_RTS, ucp_rndv_rts_handler,
              ucp_rndv_dump, 0);
UCP_DEFINE_AM(UCP_FEATURE_TAG, UCP_AM_ID_RNDV_ATS, ucp_rndv_ats_handler,
              ucp_rndv_dump, 0);
UCP_DEFINE_AM(UCP_FEATURE_TAG, UCP_AM_ID_RNDV_ATP, ucp_rndv_atp_handler,
              ucp_rndv_dump, 0);
UCP_DEFINE_AM(UCP_FEATURE_TAG, UCP_AM_ID_RNDV_RTR, ucp_rndv_rtr_handler,
              ucp_rndv_dump, 0);
UCP_DEFINE_AM(UCP_FEATURE_TAG, UCP_AM_ID_RNDV_DATA, ucp_rndv_data_handler,
              ucp_rndv_dump, 0);

UCP_DEFINE_AM_PROXY(UCP_AM_ID_RNDV_RTS);
UCP_DEFINE_AM_PROXY(UCP_AM_ID_RNDV_ATS);
UCP_DEFINE_AM_PROXY(UCP_AM_ID_RNDV_ATP);
UCP_DEFINE_AM_PROXY(UCP_AM_ID_RNDV_RTR);
UCP_DEFINE_AM_PROXY(UCP_AM_ID_RNDV_DATA);