1 /**
2 * Copyright (C) Mellanox Technologies Ltd. 2001-2019. ALL RIGHTS RESERVED.
3 *
4 * See file LICENSE for terms.
5 */
6
7 #ifdef HAVE_CONFIG_H
8 # include "config.h"
9 #endif
10
11 #include "rndv.h"
12 #include "tag_match.inl"
13 #include "offload.h"
14
15 #include <ucp/proto/proto_am.inl>
16 #include <ucs/datastruct/queue.h>
17
/*
 * Check whether the receiver must stage this rendezvous through host-memory
 * fragments (recv-side pipeline) instead of doing RMA directly to the user
 * buffer.
 *
 * @param rndv_req             Receive-side rendezvous control request.
 * @param rndv_rts_hdr         RTS header; a packed rkey follows it in memory
 *                             when rndv_rts_hdr->address != 0.
 * @param mem_type             Memory type of the local receive buffer.
 * @param is_get_zcopy_failed  Nonzero if a direct get_zcopy attempt failed.
 *
 * @return Nonzero when the pipeline protocol should be used.
 */
static int ucp_rndv_is_recv_pipeline_needed(ucp_request_t *rndv_req,
                                            const ucp_rndv_rts_hdr_t *rndv_rts_hdr,
                                            ucs_memory_type_t mem_type,
                                            int is_get_zcopy_failed)
{
    const ucp_ep_config_t *ep_config = ucp_ep_config(rndv_req->send.ep);
    ucp_context_h context = rndv_req->send.ep->worker->context;
    int found = 0;
    ucp_md_index_t md_index;
    uct_md_attr_t *md_attr;
    uint64_t mem_types;
    int i;

    /* pipeline staging goes through host buffers, so it needs at least one
     * BW lane whose MD accesses host memory */
    for (i = 0;
         (i < UCP_MAX_LANES) &&
         (ep_config->key.rma_bw_lanes[i] != UCP_NULL_LANE); i++) {
        md_index = ep_config->md_index[ep_config->key.rma_bw_lanes[i]];
        if (context->tl_mds[md_index].attr.cap.access_mem_type == UCS_MEMORY_TYPE_HOST) {
            found = 1;
            break;
        }
    }

    /* no host bw lanes for pipeline staging */
    if (!found) {
        return 0;
    }

    /* direct GET already failed - staging is the remaining option */
    if (is_get_zcopy_failed) {
        return 1;
    }

    /* disqualify recv side pipeline if
     * a mem_type bw lane exist AND
     * lane can do RMA on remote mem_type
     */
    mem_types = UCS_BIT(mem_type);
    if (rndv_rts_hdr->address) {
        /* the packed rkey immediately follows the RTS header */
        mem_types |= UCS_BIT(ucp_rkey_packed_mem_type(rndv_rts_hdr + 1));
    }

    ucs_for_each_bit(md_index, ep_config->key.rma_bw_md_map) {
        md_attr = &context->tl_mds[md_index].attr;
        if (ucs_test_all_flags(md_attr->cap.reg_mem_types, mem_types)) {
            return 0;
        }
    }

    return 1;
}
68
/*
 * Tell whether the sender has to fall back to the PUT-pipeline protocol.
 * This is the case when the receiver did not expose its buffer for GET
 * (remote_address == 0), GET zero-copy is unavailable or already failed,
 * or the message is shorter than the minimal GET zero-copy size.
 */
static int ucp_rndv_is_put_pipeline_needed(uintptr_t remote_address,
                                           size_t length, size_t min_get_zcopy,
                                           size_t max_get_zcopy,
                                           int is_get_zcopy_failed)
{
    /* fallback to PUT pipeline if remote mem type is non-HOST memory OR
     * can't do GET ZCOPY */
    if ((remote_address == 0) || (max_get_zcopy == 0)) {
        return 1;
    }

    return (length < min_get_zcopy) || is_get_zcopy_failed;
}
79
/*
 * Pack callback for the RNDV_RTS active message: fills the RTS header and,
 * when the send buffer is contiguous and eligible for GET zero-copy, packs
 * the remote keys right after the header so the target can GET the data.
 *
 * @return Total packed size (header plus optional rkey blob).
 */
size_t ucp_tag_rndv_rts_pack(void *dest, void *arg)
{
    ucp_request_t *sreq = arg; /* send request */
    ucp_rndv_rts_hdr_t *rndv_rts_hdr = dest;
    ucp_worker_h worker = sreq->send.ep->worker;
    ssize_t packed_rkey_size;

    rndv_rts_hdr->super.tag = sreq->send.msg_proto.tag.tag;
    rndv_rts_hdr->sreq.reqptr = (uintptr_t)sreq;
    rndv_rts_hdr->sreq.ep_ptr = ucp_request_get_dest_ep_ptr(sreq);
    rndv_rts_hdr->size = sreq->send.length;

    /* Pack remote keys (which can be empty list) */
    if (UCP_DT_IS_CONTIG(sreq->send.datatype) &&
        ucp_rndv_is_get_zcopy(sreq, worker->context)) {
        /* pack rkey, ask target to do get_zcopy */
        rndv_rts_hdr->address = (uintptr_t)sreq->send.buffer;
        packed_rkey_size = ucp_rkey_pack_uct(worker->context,
                                             sreq->send.state.dt.dt.contig.md_map,
                                             sreq->send.state.dt.dt.contig.memh,
                                             sreq->send.mem_type,
                                             rndv_rts_hdr + 1);
        if (packed_rkey_size < 0) {
            /* the RTS is packed in-place into the AM buffer - no way to
             * recover here */
            ucs_fatal("failed to pack rendezvous remote key: %s",
                      ucs_status_string((ucs_status_t)packed_rkey_size));
        }

        ucs_assert(packed_rkey_size <=
                   ucp_ep_config(sreq->send.ep)->tag.rndv.rkey_size);
    } else {
        /* no address/rkey - receiver replies with RTR (PUT/AM fallback) */
        rndv_rts_hdr->address = 0;
        packed_rkey_size = 0;
    }

    return sizeof(*rndv_rts_hdr) + packed_rkey_size;
}
116
/*
 * Pending/progress callback that sends the RNDV_RTS as a single active
 * message; ucp_tag_rndv_rts_pack() fills the header and optional rkeys.
 */
UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_progress_rndv_rts, (self),
                 uct_pending_req_t *self)
{
    ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
    size_t packed_rkey_size;

    /* send the RTS. the pack_cb will pack all the necessary fields in the RTS */
    packed_rkey_size = ucp_ep_config(sreq->send.ep)->tag.rndv.rkey_size;
    return ucp_do_am_single(self, UCP_AM_ID_RNDV_RTS, ucp_tag_rndv_rts_pack,
                            sizeof(ucp_rndv_rts_hdr_t) + packed_rkey_size);
}
128
/*
 * Pack callback for the RNDV_RTR active message: fills the RTR header and,
 * for contiguous receive buffers, packs the receiver's rkeys after it so
 * the sender can PUT directly into the receive buffer.
 *
 * @return Total packed size; a negative ucs_status_t (through the size_t
 *         return type) if rkey packing fails.
 */
static size_t ucp_tag_rndv_rtr_pack(void *dest, void *arg)
{
    ucp_request_t *rndv_req = arg;
    ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = dest;
    ucp_request_t *rreq = rndv_req->send.rndv_rtr.rreq;
    ssize_t packed_rkey_size;

    rndv_rtr_hdr->sreq_ptr = rndv_req->send.rndv_rtr.remote_request;
    rndv_rtr_hdr->rreq_ptr = (uintptr_t)rreq; /* request of receiver side */

    /* Pack remote keys (which can be empty list) */
    if (UCP_DT_IS_CONTIG(rreq->recv.datatype)) {
        rndv_rtr_hdr->address = (uintptr_t)rreq->recv.buffer;
        rndv_rtr_hdr->size = rndv_req->send.rndv_rtr.length;
        rndv_rtr_hdr->offset = rndv_req->send.rndv_rtr.offset;

        packed_rkey_size = ucp_rkey_pack_uct(rndv_req->send.ep->worker->context,
                                             rreq->recv.state.dt.contig.md_map,
                                             rreq->recv.state.dt.contig.memh,
                                             rreq->recv.mem_type,
                                             rndv_rtr_hdr + 1);
        if (packed_rkey_size < 0) {
            return packed_rkey_size;
        }
    } else {
        /* non-contiguous receive: do not expose an address, the sender will
         * transfer the data instead of PUTting into the buffer */
        rndv_rtr_hdr->address = 0;
        rndv_rtr_hdr->size = 0;
        rndv_rtr_hdr->offset = 0;
        packed_rkey_size = 0;
    }

    return sizeof(*rndv_rtr_hdr) + packed_rkey_size;
}
162
/*
 * Pending/progress callback that sends the RNDV_RTR as a single active
 * message and releases the rendezvous control request on success.
 */
UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_progress_rndv_rtr, (self),
                 uct_pending_req_t *self)
{
    ucp_request_t *rndv_req = ucs_container_of(self, ucp_request_t, send.uct);
    size_t packed_rkey_size;
    ucs_status_t status;

    /* send the RTR. the pack_cb will pack all the necessary fields in the RTR */
    packed_rkey_size = ucp_ep_config(rndv_req->send.ep)->tag.rndv.rkey_size;
    status = ucp_do_am_single(self, UCP_AM_ID_RNDV_RTR, ucp_tag_rndv_rtr_pack,
                              sizeof(ucp_rndv_rtr_hdr_t) + packed_rkey_size);
    if (status == UCS_OK) {
        /* release rndv request */
        ucp_request_put(rndv_req);
    }

    return status;
}
181
/*
 * Register the send buffer on the rma_bw lanes when it is contiguous and
 * eligible for GET zero-copy; otherwise this is a no-op.
 *
 * @return UCS_OK on success or when no registration is needed.
 */
ucs_status_t ucp_tag_rndv_reg_send_buffer(ucp_request_t *sreq)
{
    ucp_ep_h ep = sreq->send.ep;
    ucp_md_map_t md_map;
    ucs_status_t status;

    if (UCP_DT_IS_CONTIG(sreq->send.datatype) &&
        ucp_rndv_is_get_zcopy(sreq, ep->worker->context)) {

        /* register a contiguous buffer for rma_get */
        md_map = ucp_ep_config(ep)->key.rma_bw_md_map;

        /* Pass UCT_MD_MEM_FLAG_HIDE_ERRORS flag, because registration may fail
         * if md does not support send memory type (e.g. CUDA memory). In this
         * case RTS will be sent with empty key, and sender will fallback to
         * PUT or pipeline protocols. */
        status = ucp_request_send_buffer_reg(sreq, md_map,
                                             UCT_MD_MEM_FLAG_HIDE_ERRORS);
        if (status != UCS_OK) {
            return status;
        }
    }

    return UCS_OK;
}
207
/*
 * Switch a tag send request to the rendezvous protocol: resolve the remote
 * endpoint pointer (carried in the RTS), then start either the tag-offload
 * (HW) rendezvous or the SW protocol by arming the RTS progress callback.
 */
ucs_status_t ucp_tag_send_start_rndv(ucp_request_t *sreq)
{
    ucp_ep_h ep = sreq->send.ep;
    ucs_status_t status;

    ucp_trace_req(sreq, "start_rndv to %s buffer %p length %zu",
                  ucp_ep_peer_name(ep), sreq->send.buffer,
                  sreq->send.length);
    UCS_PROFILE_REQUEST_EVENT(sreq, "start_rndv", sreq->send.length);

    /* the RTS carries a destination ep pointer - make sure it is resolved */
    status = ucp_ep_resolve_dest_ep_ptr(ep, sreq->send.lane);
    if (status != UCS_OK) {
        return status;
    }

    if (ucp_ep_is_tag_offload_enabled(ucp_ep_config(ep))) {
        status = ucp_tag_offload_start_rndv(sreq);
    } else {
        ucs_assert(sreq->send.lane == ucp_ep_get_am_lane(ep));
        sreq->send.uct.func = ucp_proto_progress_rndv_rts;
        status = ucp_tag_rndv_reg_send_buffer(sreq);
    }

    return status;
}
233
234 static UCS_F_ALWAYS_INLINE size_t
ucp_rndv_adjust_zcopy_length(size_t min_zcopy,size_t max_zcopy,size_t align,size_t send_length,size_t offset,size_t length)235 ucp_rndv_adjust_zcopy_length(size_t min_zcopy, size_t max_zcopy, size_t align,
236 size_t send_length, size_t offset, size_t length)
237 {
238 size_t result_length, tail;
239
240 ucs_assert(length > 0);
241
242 /* ensure that the current length is over min_zcopy */
243 result_length = ucs_max(length, min_zcopy);
244
245 /* ensure that the current length is less than max_zcopy */
246 result_length = ucs_min(result_length, max_zcopy);
247
248 /* ensure that tail (rest of message) is over min_zcopy */
249 ucs_assertv(send_length >= (offset + result_length),
250 "send_length=%zu, offset=%zu, length=%zu",
251 send_length, offset, result_length);
252 tail = send_length - (offset + result_length);
253 if (ucs_unlikely((tail != 0) && (tail < min_zcopy))) {
254 /* ok, tail is less zcopy minimal & could not be processed as
255 * standalone operation */
256 /* check if we have room to increase current part and not
257 * step over max_zcopy */
258 if (result_length < (max_zcopy - tail)) {
259 /* if we can increase length by min_zcopy - let's do it to
260 * avoid small tail (we have limitation on minimal get zcopy) */
261 result_length += tail;
262 } else {
263 /* reduce current length by align or min_zcopy value
264 * to process it on next round */
265 ucs_assert(result_length > ucs_max(min_zcopy, align));
266 result_length -= ucs_max(min_zcopy, align);
267 }
268 }
269
270 ucs_assertv(result_length >= min_zcopy, "length=%zu, min_zcopy=%zu",
271 result_length, min_zcopy);
272 ucs_assertv(((send_length - (offset + result_length)) == 0) ||
273 ((send_length - (offset + result_length)) >= min_zcopy),
274 "send_length=%zu, offset=%zu, length=%zu, min_zcopy=%zu",
275 send_length, offset, result_length, min_zcopy);
276
277 return result_length;
278 }
279
/* Finalize a sender-side request: release generic-datatype state and buffer
 * registration, then complete the send request with the given status. */
static void ucp_rndv_complete_send(ucp_request_t *sreq, ucs_status_t status)
{
    ucp_request_send_generic_dt_finish(sreq);
    ucp_request_send_buffer_dereg(sreq);
    ucp_request_complete_send(sreq, status);
}
286
/*
 * Send an ATS (ack-to-sender) telling the remote side the rendezvous is
 * finished; rndv_req is released by the completion callback (ucp_request_put)
 * once the ATS has been sent.
 */
static void ucp_rndv_req_send_ats(ucp_request_t *rndv_req, ucp_request_t *rreq,
                                  uintptr_t remote_request, ucs_status_t status)
{
    ucp_trace_req(rndv_req, "send ats remote_request 0x%lx", remote_request);
    UCS_PROFILE_REQUEST_EVENT(rreq, "send_ats", 0);

    rndv_req->send.lane = ucp_ep_get_am_lane(rndv_req->send.ep);
    rndv_req->send.uct.func = ucp_proto_progress_am_single;
    rndv_req->send.proto.am_id = UCP_AM_ID_RNDV_ATS;
    rndv_req->send.proto.status = status;
    rndv_req->send.proto.remote_request = remote_request;
    rndv_req->send.proto.comp_cb = ucp_request_put;

    ucp_request_send(rndv_req, 0);
}
302
/* Completion of the sender-side PUT zero-copy rendezvous: deregister the
 * send buffer and complete the original send request. */
UCS_PROFILE_FUNC_VOID(ucp_rndv_complete_rma_put_zcopy, (sreq),
                      ucp_request_t *sreq)
{
    ucp_trace_req(sreq, "rndv_put completed");
    UCS_PROFILE_REQUEST_EVENT(sreq, "complete_rndv_put", 0);

    ucp_request_send_buffer_dereg(sreq);
    ucp_request_complete_send(sreq, UCS_OK);
}
312
/*
 * Send an ATP (ack-to-peer) to the receiver after the entire message has
 * been PUT (asserted below). The rkey is destroyed first because the ATP
 * protocol fields overlay the rndv_put data in the request.
 */
static void ucp_rndv_send_atp(ucp_request_t *sreq, uintptr_t remote_request)
{
    ucs_assertv(sreq->send.state.dt.offset == sreq->send.length,
                "sreq=%p offset=%zu length=%zu", sreq,
                sreq->send.state.dt.offset, sreq->send.length);

    ucp_trace_req(sreq, "send atp remote_request 0x%lx", remote_request);
    UCS_PROFILE_REQUEST_EVENT(sreq, "send_atp", 0);

    /* destroy rkey before it gets overridden by ATP protocol data */
    ucp_rkey_destroy(sreq->send.rndv_put.rkey);

    sreq->send.lane = ucp_ep_get_am_lane(sreq->send.ep);
    sreq->send.uct.func = ucp_proto_progress_am_single;
    sreq->send.proto.am_id = UCP_AM_ID_RNDV_ATP;
    sreq->send.proto.status = UCS_OK;
    sreq->send.proto.remote_request = remote_request;
    sreq->send.proto.comp_cb = ucp_rndv_complete_rma_put_zcopy;

    ucp_request_send(sreq, 0);
}
334
/*
 * Completion of one PUT fragment on the sender side: fold the fragment
 * length into the original send request and complete it once all fragments
 * have been acknowledged.
 */
UCS_PROFILE_FUNC_VOID(ucp_rndv_complete_frag_rma_put_zcopy, (fsreq),
                      ucp_request_t *fsreq)
{
    ucp_request_t *sreq = fsreq->send.proto.sreq;

    sreq->send.state.dt.offset += fsreq->send.length;

    /* delete fragments send request */
    ucp_request_put(fsreq);

    /* complete send request after put completions of all fragments */
    if (sreq->send.state.dt.offset == sreq->send.length) {
        ucp_rndv_complete_rma_put_zcopy(sreq);
    }
}
350
/*
 * Send a per-fragment ATP for the PUT-pipeline protocol; the completion
 * callback aggregates fragment completions into the original send request
 * (stashed in send.proto.sreq before the rndv_put fields are overlaid).
 */
static void ucp_rndv_send_frag_atp(ucp_request_t *fsreq, uintptr_t remote_request)
{
    ucp_trace_req(fsreq, "send frag atp remote_request 0x%lx", remote_request);
    UCS_PROFILE_REQUEST_EVENT(fsreq, "send_frag_atp", 0);

    /* destroy rkey before it gets overridden by ATP protocol data */
    ucp_rkey_destroy(fsreq->send.rndv_put.rkey);

    fsreq->send.lane = ucp_ep_get_am_lane(fsreq->send.ep);
    fsreq->send.uct.func = ucp_proto_progress_am_single;
    fsreq->send.proto.sreq = fsreq->send.rndv_put.sreq;
    fsreq->send.proto.am_id = UCP_AM_ID_RNDV_ATP;
    fsreq->send.proto.status = UCS_OK;
    fsreq->send.proto.remote_request = remote_request;
    fsreq->send.proto.comp_cb = ucp_rndv_complete_frag_rma_put_zcopy;

    ucp_request_send(fsreq, 0);
}
369
/* Deregister the receive buffer and complete the user's tag receive request. */
static void ucp_rndv_zcopy_recv_req_complete(ucp_request_t *req, ucs_status_t status)
{
    ucp_request_recv_buffer_dereg(req);
    ucp_request_complete_tag_recv(req, status);
}
375
/*
 * Finalize the receive-side GET zero-copy rendezvous once all data has
 * arrived (asserted): release the rkey and registration, acknowledge the
 * sender with ATS on success, and complete the user's receive request.
 */
static void ucp_rndv_complete_rma_get_zcopy(ucp_request_t *rndv_req,
                                            ucs_status_t status)
{
    ucp_request_t *rreq = rndv_req->send.rndv_get.rreq;

    ucs_assertv(rndv_req->send.state.dt.offset == rndv_req->send.length,
                "rndv_req=%p offset=%zu length=%zu", rndv_req,
                rndv_req->send.state.dt.offset, rndv_req->send.length);

    ucp_trace_req(rndv_req, "rndv_get completed with status %s",
                  ucs_status_string(status));
    UCS_PROFILE_REQUEST_EVENT(rreq, "complete_rndv_get", 0);

    ucp_rkey_destroy(rndv_req->send.rndv_get.rkey);
    ucp_request_send_buffer_dereg(rndv_req);

    if (status == UCS_OK) {
        /* ATS lets the sender release its buffer; rndv_req is released by
         * the ATS completion callback */
        ucp_rndv_req_send_ats(rndv_req, rreq,
                              rndv_req->send.rndv_get.remote_request, UCS_OK);
    } else {
        /* if completing RNDV with the error, just release RNDV request */
        ucp_request_put(rndv_req);
    }

    ucp_rndv_zcopy_recv_req_complete(rreq, status);
}
402
/* Initialize receive-side bookkeeping for protocols that deliver data in
 * fragments: 'remaining' counts the bytes not yet received. */
static void ucp_rndv_recv_data_init(ucp_request_t *rreq, size_t size)
{
    rreq->status = UCS_OK;
    rreq->recv.tag.remaining = size;
}
408
/*
 * Arm and send an RTR (ready-to-receive) message asking the sender to
 * deliver 'recv_length' bytes at byte 'offset' of the receive buffer;
 * rndv_req is released by the RTR progress callback after it is sent.
 */
static void ucp_rndv_req_send_rtr(ucp_request_t *rndv_req, ucp_request_t *rreq,
                                  uintptr_t sender_reqptr, size_t recv_length,
                                  size_t offset)
{
    ucp_trace_req(rndv_req, "send rtr remote sreq 0x%lx rreq %p", sender_reqptr,
                  rreq);

    rndv_req->send.lane = ucp_ep_get_am_lane(rndv_req->send.ep);
    rndv_req->send.uct.func = ucp_proto_progress_rndv_rtr;
    rndv_req->send.rndv_rtr.remote_request = sender_reqptr;
    rndv_req->send.rndv_rtr.rreq = rreq;
    rndv_req->send.rndv_rtr.length = recv_length;
    rndv_req->send.rndv_rtr.offset = offset;

    ucp_request_send(rndv_req, 0);
}
425
/*
 * Pick the next GET zero-copy lane from the round-robin availability bitmap
 * and output the UCT rkey to use with it.
 *
 * @return The lane index, or UCP_NULL_LANE when the lane map is empty
 *         (no lane can reach the remote memory).
 */
static ucp_lane_index_t
ucp_rndv_get_zcopy_get_lane(ucp_request_t *rndv_req, uct_rkey_t *uct_rkey)
{
    ucp_lane_index_t lane_idx;
    ucp_ep_config_t *ep_config;
    ucp_rkey_h rkey;
    uint8_t rkey_index;

    if (ucs_unlikely(!rndv_req->send.rndv_get.lanes_map_all)) {
        return UCP_NULL_LANE;
    }

    /* lowest set bit of the availability map selects the next lane */
    lane_idx = ucs_ffs64_safe(rndv_req->send.rndv_get.lanes_map_avail);
    ucs_assert(lane_idx < UCP_MAX_LANES);
    rkey = rndv_req->send.rndv_get.rkey;
    rkey_index = rndv_req->send.rndv_get.rkey_index[lane_idx];
    /* lanes whose MD needs no remote key get UCT_INVALID_RKEY */
    *uct_rkey = (rkey_index != UCP_NULL_RESOURCE) ?
                rkey->tl_rkey[rkey_index].rkey.rkey : UCT_INVALID_RKEY;
    ep_config = ucp_ep_config(rndv_req->send.ep);
    return ep_config->tag.rndv.get_zcopy_lanes[lane_idx];
}
447
/*
 * Advance the round-robin bitmap of GET zero-copy lanes: drop the lowest
 * set bit (the lane that was just used) and, when the whole round has been
 * consumed, restart from the full lane set.
 */
static void ucp_rndv_get_zcopy_next_lane(ucp_request_t *rndv_req)
{
    uint64_t avail_map = rndv_req->send.rndv_get.lanes_map_avail;

    avail_map &= avail_map - 1; /* clear the lowest set bit */
    if (avail_map == 0) {
        avail_map = rndv_req->send.rndv_get.lanes_map_all;
    }

    rndv_req->send.rndv_get.lanes_map_avail = avail_map;
}
455
/*
 * Progress callback for the receive-side GET zero-copy rendezvous: issues
 * one GET fragment per invocation, spreading fragments across the configured
 * rma_bw lanes. Falls back to the RTR (active-message / PUT) protocol when
 * no lane can reach the remote memory.
 *
 * @return UCS_OK when the last fragment was posted (or after the fallback),
 *         UCS_INPROGRESS when more fragments remain, or an error status.
 */
UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_progress_rma_get_zcopy, (self),
                 uct_pending_req_t *self)
{
    ucp_request_t *rndv_req = ucs_container_of(self, ucp_request_t, send.uct);
    ucp_ep_h ep = rndv_req->send.ep;
    ucp_ep_config_t *config = ucp_ep_config(ep);
    const size_t max_iovcnt = 1;
    uct_iface_attr_t* attrs;
    ucs_status_t status;
    size_t offset, length, ucp_mtu, remaining, align, chunk;
    uct_iov_t iov[max_iovcnt];
    size_t iovcnt;
    ucp_rsc_index_t rsc_index;
    ucp_dt_state_t state;
    uct_rkey_t uct_rkey;
    size_t min_zcopy;
    size_t max_zcopy;
    int pending_add_res;
    ucp_lane_index_t lane;

    /* Figure out which lane to use for get operation */
    rndv_req->send.lane = lane = ucp_rndv_get_zcopy_get_lane(rndv_req, &uct_rkey);

    if (lane == UCP_NULL_LANE) {
        /* If can't perform get_zcopy - switch to active-message.
         * NOTE: we do not register memory and do not send our keys. */
        ucp_trace_req(rndv_req, "remote memory unreachable, switch to rtr");
        ucp_rkey_destroy(rndv_req->send.rndv_get.rkey);
        ucp_rndv_recv_data_init(rndv_req->send.rndv_get.rreq,
                                rndv_req->send.length);
        /* Update statistics counters from get_zcopy to rtr */
        UCP_WORKER_STAT_RNDV(ep->worker, GET_ZCOPY, -1);
        UCP_WORKER_STAT_RNDV(ep->worker, SEND_RTR, +1);
        ucp_rndv_req_send_rtr(rndv_req, rndv_req->send.rndv_get.rreq,
                              rndv_req->send.rndv_get.remote_request,
                              rndv_req->send.length, 0ul);
        return UCS_OK;
    }

    ucs_assert_always(rndv_req->send.rndv_get.lanes_count > 0);

    /* register the receive buffer for this lane unless data is staged in a
     * pre-registered memory descriptor */
    if (!rndv_req->send.mdesc) {
        status = ucp_send_request_add_reg_lane(rndv_req, lane);
        ucs_assert_always(status == UCS_OK);
    }

    rsc_index = ucp_ep_get_rsc_index(ep, lane);
    attrs = ucp_worker_iface_get_attr(ep->worker, rsc_index);
    align = attrs->cap.get.opt_zcopy_align;
    ucp_mtu = attrs->cap.get.align_mtu;
    min_zcopy = config->tag.rndv.min_get_zcopy;
    max_zcopy = config->tag.rndv.max_get_zcopy;

    offset = rndv_req->send.state.dt.offset;
    remaining = (uintptr_t)rndv_req->send.buffer % align;

    /* first fragment of an unaligned buffer: trim to the MTU boundary so
     * subsequent fragments start aligned */
    if ((offset == 0) && (remaining > 0) && (rndv_req->send.length > ucp_mtu)) {
        length = ucp_mtu - remaining;
    } else {
        /* split the message across lanes, scaled by per-lane bandwidth */
        chunk = ucs_align_up((size_t)(rndv_req->send.length /
                                      rndv_req->send.rndv_get.lanes_count
                                      * config->tag.rndv.scale[lane]),
                             align);
        length = ucs_min(chunk, rndv_req->send.length - offset);
    }

    length = ucp_rndv_adjust_zcopy_length(min_zcopy, max_zcopy, align,
                                          rndv_req->send.length, offset,
                                          length);

    ucs_trace_data("req %p: offset %zu remainder %zu rma-get to %p len %zu lane %d",
                   rndv_req, offset, remaining,
                   UCS_PTR_BYTE_OFFSET(rndv_req->send.buffer, offset),
                   length, lane);

    state = rndv_req->send.state.dt;
    /* TODO: is this correct? memh array may skip MD's where
     * registration is not supported. for now SHM may avoid registration,
     * but it will work on single lane */
    ucp_dt_iov_copy_uct(ep->worker->context, iov, &iovcnt, max_iovcnt, &state,
                        rndv_req->send.buffer, ucp_dt_make_contig(1), length,
                        ucp_ep_md_index(ep, lane),
                        rndv_req->send.mdesc);

    for (;;) {
        status = uct_ep_get_zcopy(ep->uct_eps[lane],
                                  iov, iovcnt,
                                  rndv_req->send.rndv_get.remote_address + offset,
                                  uct_rkey,
                                  &rndv_req->send.state.uct_comp);
        ucp_request_send_state_advance(rndv_req, &state,
                                       UCP_REQUEST_SEND_PROTO_RNDV_GET,
                                       status);
        if (rndv_req->send.state.dt.offset == rndv_req->send.length) {
            /* all fragments posted; if no completion is outstanding, invoke
             * the completion callback right away */
            if (rndv_req->send.state.uct_comp.count == 0) {
                rndv_req->send.state.uct_comp.func(&rndv_req->send.state.uct_comp, status);
            }
            return UCS_OK;
        } else if (!UCS_STATUS_IS_ERR(status)) {
            /* in case if not all chunks are transmitted - return in_progress
             * status */
            ucp_rndv_get_zcopy_next_lane(rndv_req);
            return UCS_INPROGRESS;
        } else {
            if (status == UCS_ERR_NO_RESOURCE) {
                if (lane != rndv_req->send.pending_lane) {
                    /* switch to new pending lane */
                    pending_add_res = ucp_request_pending_add(rndv_req, &status, 0);
                    if (!pending_add_res) {
                        /* failed to switch req to pending queue, try again */
                        continue;
                    }
                    ucs_assert(status == UCS_INPROGRESS);
                    return UCS_OK;
                }
            }
            return status;
        }
    }
}
576
/* UCT completion callback for receive-side GET fragments: only the
 * completion arriving after the final fragment was posted finalizes the
 * rendezvous request. */
UCS_PROFILE_FUNC_VOID(ucp_rndv_get_completion, (self, status),
                      uct_completion_t *self, ucs_status_t status)
{
    ucp_request_t *rndv_req = ucs_container_of(self, ucp_request_t,
                                               send.state.uct_comp);

    if (rndv_req->send.state.dt.offset == rndv_req->send.length) {
        ucp_rndv_complete_rma_get_zcopy(rndv_req, status);
    }
}
587
/* UCT completion callback for sender-side PUT fragments: send the ATP once
 * the whole message has been put to the receiver. */
static void ucp_rndv_put_completion(uct_completion_t *self, ucs_status_t status)
{
    ucp_request_t *sreq = ucs_container_of(self, ucp_request_t,
                                           send.state.uct_comp);

    if (sreq->send.state.dt.offset == sreq->send.length) {
        ucp_rndv_send_atp(sreq, sreq->send.rndv_put.remote_request);
    }
}
597
/*
 * Build the bitmap of GET zero-copy lanes usable for this rendezvous, along
 * with the per-lane rkey indices. A lane qualifies when its MD either needs
 * no rkey (and the memory types match), or can register the local memory
 * type and has a matching key in the remote rkey. Lanes whose bandwidth is
 * much lower than the fastest selected lane are then pruned.
 */
static void ucp_rndv_req_init_get_zcopy_lane_map(ucp_request_t *rndv_req)
{
    ucp_ep_h ep = rndv_req->send.ep;
    ucp_ep_config_t *ep_config = ucp_ep_config(ep);
    ucp_context_h context = ep->worker->context;
    ucs_memory_type_t mem_type = rndv_req->send.mem_type;
    ucp_rkey_h rkey = rndv_req->send.rndv_get.rkey;
    ucp_lane_map_t lane_map;
    ucp_lane_index_t lane, lane_idx;
    ucp_md_index_t md_index;
    uct_md_attr_t *md_attr;
    ucp_md_index_t dst_md_index;
    ucp_rsc_index_t rsc_index;
    uct_iface_attr_t *iface_attr;
    double max_lane_bw, lane_bw;
    int i;

    max_lane_bw = 0;
    lane_map = 0;
    for (i = 0; i < UCP_MAX_LANES; i++) {
        lane = ep_config->tag.rndv.get_zcopy_lanes[i];
        if (lane == UCP_NULL_LANE) {
            break; /* no more lanes */
        }

        md_index = ep_config->md_index[lane];
        md_attr = &context->tl_mds[md_index].attr;
        rsc_index = ep_config->key.lanes[lane].rsc_index;
        iface_attr = ucp_worker_iface_get_attr(ep->worker, rsc_index);
        lane_bw = ucp_tl_iface_bandwidth(context, &iface_attr->bandwidth);

        if (ucs_unlikely((md_index != UCP_NULL_RESOURCE) &&
                         !(md_attr->cap.flags & UCT_MD_FLAG_NEED_RKEY))) {
            /* Lane does not need rkey, can use the lane with invalid rkey */
            if (!rkey || ((mem_type == md_attr->cap.access_mem_type) &&
                          (mem_type == rkey->mem_type))) {
                rndv_req->send.rndv_get.rkey_index[i] = UCP_NULL_RESOURCE;
                lane_map |= UCS_BIT(i);
                max_lane_bw = ucs_max(max_lane_bw, lane_bw);
                continue;
            }
        }

        if (ucs_unlikely((md_index != UCP_NULL_RESOURCE) &&
                         (!(md_attr->cap.reg_mem_types & UCS_BIT(mem_type))))) {
            /* lane cannot register the local memory type - skip it */
            continue;
        }

        dst_md_index = ep_config->key.lanes[lane].dst_md_index;
        if (rkey && ucs_likely(rkey->md_map & UCS_BIT(dst_md_index))) {
            /* Return first matching lane */
            rndv_req->send.rndv_get.rkey_index[i] = ucs_bitmap2idx(rkey->md_map,
                                                                   dst_md_index);
            lane_map |= UCS_BIT(i);
            max_lane_bw = ucs_max(max_lane_bw, lane_bw);
        }
    }

    if (ucs_popcount(lane_map) > 1) {
        /* remove lanes if bandwidth is too low comparing to the best lane */
        ucs_for_each_bit(lane_idx, lane_map) {
            ucs_assert(lane_idx < UCP_MAX_LANES);
            lane = ep_config->tag.rndv.get_zcopy_lanes[lane_idx];
            rsc_index = ep_config->key.lanes[lane].rsc_index;
            iface_attr = ucp_worker_iface_get_attr(ep->worker, rsc_index);
            lane_bw = ucp_tl_iface_bandwidth(context, &iface_attr->bandwidth);

            if ((lane_bw/max_lane_bw) <
                (1. / context->config.ext.multi_lane_max_ratio)) {
                lane_map &= ~UCS_BIT(lane_idx);
                rndv_req->send.rndv_get.rkey_index[lane_idx] = UCP_NULL_RESOURCE;
            }
        }
    }

    rndv_req->send.rndv_get.lanes_map_all = lane_map;
    rndv_req->send.rndv_get.lanes_map_avail = lane_map;
    rndv_req->send.rndv_get.lanes_count = ucs_popcount(lane_map);
}
677
ucp_rndv_req_send_rma_get(ucp_request_t * rndv_req,ucp_request_t * rreq,const ucp_rndv_rts_hdr_t * rndv_rts_hdr)678 static ucs_status_t ucp_rndv_req_send_rma_get(ucp_request_t *rndv_req,
679 ucp_request_t *rreq,
680 const ucp_rndv_rts_hdr_t *rndv_rts_hdr)
681 {
682 ucp_ep_h ep = rndv_req->send.ep;
683 ucs_status_t status;
684 uct_rkey_t uct_rkey;
685
686 ucp_trace_req(rndv_req, "start rma_get rreq %p", rreq);
687
688 rndv_req->send.uct.func = ucp_rndv_progress_rma_get_zcopy;
689 rndv_req->send.buffer = rreq->recv.buffer;
690 rndv_req->send.mem_type = rreq->recv.mem_type;
691 rndv_req->send.datatype = ucp_dt_make_contig(1);
692 rndv_req->send.length = rndv_rts_hdr->size;
693 rndv_req->send.rndv_get.remote_request = rndv_rts_hdr->sreq.reqptr;
694 rndv_req->send.rndv_get.remote_address = rndv_rts_hdr->address;
695 rndv_req->send.rndv_get.rreq = rreq;
696 rndv_req->send.datatype = rreq->recv.datatype;
697
698 status = ucp_ep_rkey_unpack(ep, rndv_rts_hdr + 1,
699 &rndv_req->send.rndv_get.rkey);
700 if (status != UCS_OK) {
701 ucs_fatal("failed to unpack rendezvous remote key received from %s: %s",
702 ucp_ep_peer_name(ep), ucs_status_string(status));
703 }
704
705 ucp_request_send_state_init(rndv_req, ucp_dt_make_contig(1), 0);
706 ucp_request_send_state_reset(rndv_req, ucp_rndv_get_completion,
707 UCP_REQUEST_SEND_PROTO_RNDV_GET);
708
709 ucp_rndv_req_init_get_zcopy_lane_map(rndv_req);
710
711 rndv_req->send.lane = ucp_rndv_get_zcopy_get_lane(rndv_req, &uct_rkey);
712 if (rndv_req->send.lane == UCP_NULL_LANE) {
713 return UCS_ERR_UNREACHABLE;
714 }
715
716 UCP_WORKER_STAT_RNDV(ep->worker, GET_ZCOPY, 1);
717 ucp_request_send(rndv_req, 0);
718
719 return UCS_OK;
720 }
721
/*
 * Completion of a staging PUT on the receive side of a pipeline: return the
 * staging fragment to its pool, update get-pipeline bookkeeping (sending
 * ATS once all remote data has been fetched), and complete the receive
 * request when every fragment has landed in the user buffer.
 */
UCS_PROFILE_FUNC_VOID(ucp_rndv_recv_frag_put_completion, (self, status),
                      uct_completion_t *self, ucs_status_t status)
{
    ucp_request_t *freq = ucs_container_of(self, ucp_request_t,
                                           send.state.uct_comp);
    ucp_request_t *req = freq->send.rndv_put.sreq;
    ucp_request_t *rndv_req = (ucp_request_t*)freq->send.rndv_put.remote_request;

    ucs_trace_req("freq:%p: recv_frag_put done. rreq:%p ", freq, req);

    /* release memory descriptor */
    ucs_mpool_put_inline((void *)freq->send.mdesc);

    /* rndv_req is NULL in case of put protocol */
    if (rndv_req != NULL) {
        /* pipeline recv get protocol */
        rndv_req->send.state.dt.offset += freq->send.length;

        /* send ATS for fragment get rndv completion */
        if (rndv_req->send.length == rndv_req->send.state.dt.offset) {
            ucp_rkey_destroy(rndv_req->send.rndv_get.rkey);
            ucp_rndv_req_send_ats(rndv_req, req,
                                  rndv_req->send.rndv_get.remote_request, UCS_OK);
        }
    }

    req->recv.tag.remaining -= freq->send.length;
    if (req->recv.tag.remaining == 0) {
        ucp_request_complete_tag_recv(req, UCS_OK);
    }

    ucp_request_put(freq);
}
755
/*
 * Common initialization of a staging-fragment request: point it at the
 * staging buffer (which lives right after its memory descriptor) and, for
 * non-host memory, route it through the worker's memtype endpoint using the
 * descriptor's registration.
 */
static UCS_F_ALWAYS_INLINE void
ucp_rndv_init_mem_type_frag_req(ucp_worker_h worker, ucp_request_t *freq, int rndv_op,
                                uct_completion_callback_t comp_cb, ucp_mem_desc_t *mdesc,
                                ucs_memory_type_t mem_type, size_t length,
                                uct_pending_callback_t uct_func)
{
    ucp_ep_h mem_type_ep;
    ucp_md_index_t md_index;
    ucp_lane_index_t mem_type_rma_lane;

    ucp_request_send_state_init(freq, ucp_dt_make_contig(1), 0);
    ucp_request_send_state_reset(freq, comp_cb, rndv_op);

    /* the fragment payload is located right after its descriptor */
    freq->send.buffer = mdesc + 1;
    freq->send.length = length;
    freq->send.datatype = ucp_dt_make_contig(1);
    freq->send.mem_type = mem_type;
    freq->send.mdesc = mdesc;
    freq->send.uct.func = uct_func;

    if (mem_type != UCS_MEMORY_TYPE_HOST) {
        /* non-host memory is reached through the loopback memtype endpoint */
        mem_type_ep = worker->mem_type_ep[mem_type];
        mem_type_rma_lane = ucp_ep_config(mem_type_ep)->key.rma_bw_lanes[0];
        md_index = ucp_ep_md_index(mem_type_ep, mem_type_rma_lane);
        ucs_assert(mem_type_rma_lane != UCP_NULL_LANE);

        freq->send.lane = mem_type_rma_lane;
        freq->send.ep = mem_type_ep;
        freq->send.state.dt.dt.contig.memh[0] = ucp_memh2uct(mdesc->memh, md_index);
        freq->send.state.dt.dt.contig.md_map = UCS_BIT(md_index);
    }
}
788
/*
 * Stage one fragment from a host staging buffer into a non-host (e.g. GPU)
 * receive buffer via a PUT on the memtype endpoint.
 *
 * rndv_req may be NULL for the sender-driven PUT protocol; otherwise it is
 * the get-pipeline control request (consumed by
 * ucp_rndv_recv_frag_put_completion through remote_request).
 */
static void
ucp_rndv_recv_frag_put_mem_type(ucp_request_t *rreq, ucp_request_t *rndv_req,
                                ucp_request_t *freq, ucp_mem_desc_t *mdesc,
                                size_t length, size_t offset)
{

    ucs_assert_always(!UCP_MEM_IS_ACCESSIBLE_FROM_CPU(rreq->recv.mem_type));

    /* PUT on memtype endpoint to stage from
     * frag recv buffer to memtype recv buffer
     */

    ucp_rndv_init_mem_type_frag_req(rreq->recv.worker, freq, UCP_REQUEST_SEND_PROTO_RNDV_PUT,
                                    ucp_rndv_recv_frag_put_completion, mdesc, rreq->recv.mem_type,
                                    length, ucp_rndv_progress_rma_put_zcopy);

    freq->send.rndv_put.sreq = rreq;
    freq->send.rndv_put.rkey = NULL;
    freq->send.rndv_put.remote_request = (uintptr_t)rndv_req;
    freq->send.rndv_put.remote_address = (uintptr_t)rreq->recv.buffer + offset;

    ucp_request_send(freq, 0);
}
812
813 static ucs_status_t
ucp_rndv_send_frag_get_mem_type(ucp_request_t * sreq,uintptr_t rreq_ptr,size_t length,uint64_t remote_address,ucs_memory_type_t remote_mem_type,ucp_rkey_h rkey,uint8_t * rkey_index,ucp_lane_map_t lanes_map,uct_completion_callback_t comp_cb)814 ucp_rndv_send_frag_get_mem_type(ucp_request_t *sreq, uintptr_t rreq_ptr,
815 size_t length, uint64_t remote_address,
816 ucs_memory_type_t remote_mem_type, ucp_rkey_h rkey,
817 uint8_t *rkey_index, ucp_lane_map_t lanes_map,
818 uct_completion_callback_t comp_cb)
819 {
820 ucp_worker_h worker = sreq->send.ep->worker;
821 ucp_request_t *freq;
822 ucp_mem_desc_t *mdesc;
823 ucp_lane_index_t i;
824
825 /* GET fragment to stage buffer */
826
827 freq = ucp_request_get(worker);
828 if (ucs_unlikely(freq == NULL)) {
829 ucs_error("failed to allocate fragment receive request");
830 return UCS_ERR_NO_MEMORY;
831 }
832
833 mdesc = ucp_worker_mpool_get(&worker->rndv_frag_mp);
834 if (ucs_unlikely(mdesc == NULL)) {
835 ucs_error("failed to allocate fragment memory desc");
836 return UCS_ERR_NO_MEMORY;
837 }
838
839 freq->send.ep = sreq->send.ep;
840
841 ucp_rndv_init_mem_type_frag_req(worker, freq, UCP_REQUEST_SEND_PROTO_RNDV_GET,
842 comp_cb, mdesc, remote_mem_type, length,
843 ucp_rndv_progress_rma_get_zcopy);
844
845 freq->send.rndv_get.rkey = rkey;
846 freq->send.rndv_get.remote_address = remote_address;
847 freq->send.rndv_get.remote_request = rreq_ptr;
848 freq->send.rndv_get.rreq = sreq;
849 freq->send.rndv_get.lanes_map_all = lanes_map;
850 freq->send.rndv_get.lanes_map_avail = lanes_map;
851 freq->send.rndv_get.lanes_count = ucs_popcount(lanes_map);
852
853 for (i = 0; i < UCP_MAX_LANES; i++) {
854 freq->send.rndv_get.rkey_index[i] = rkey_index ? rkey_index[i]
855 : UCP_NULL_RESOURCE;
856 }
857
858
859 return ucp_request_send(freq, 0);
860 }
861
/*
 * Completion of a staging GET in the receive pipeline: the fragment now
 * resides in the host staging buffer, so issue a PUT from it into the
 * (non-host) user receive buffer. The fragment offset within the message is
 * recovered from the remote address delta.
 */
UCS_PROFILE_FUNC_VOID(ucp_rndv_recv_frag_get_completion, (self, status),
                      uct_completion_t *self, ucs_status_t status)
{
    ucp_request_t *freq = ucs_container_of(self, ucp_request_t,
                                           send.state.uct_comp);
    ucp_request_t *rndv_req = freq->send.rndv_get.rreq;
    ucp_request_t *rreq = rndv_req->send.rndv_get.rreq;

    ucs_trace_req("freq:%p: recv_frag_get done. rreq:%p length:%ld offset:%ld",
                  freq, rndv_req, freq->send.length,
                  freq->send.rndv_get.remote_address - rndv_req->send.rndv_get.remote_address);

    /* fragment GET completed from remote to staging buffer, issue PUT from
     * staging buffer to recv buffer */
    ucp_rndv_recv_frag_put_mem_type(rreq, rndv_req, freq,
                                    (ucp_mem_desc_t *)freq->send.buffer -1,
                                    freq->send.length, (freq->send.rndv_get.remote_address -
                                    rndv_req->send.rndv_get.remote_address));
}
881
882 static ucs_status_t
ucp_rndv_recv_start_get_pipeline(ucp_worker_h worker,ucp_request_t * rndv_req,ucp_request_t * rreq,uintptr_t remote_request,const void * rkey_buffer,uint64_t remote_address,size_t size,size_t base_offset)883 ucp_rndv_recv_start_get_pipeline(ucp_worker_h worker, ucp_request_t *rndv_req,
884 ucp_request_t *rreq, uintptr_t remote_request,
885 const void *rkey_buffer, uint64_t remote_address,
886 size_t size, size_t base_offset)
887 {
888 ucp_ep_h ep = rndv_req->send.ep;
889 ucp_ep_config_t *config = ucp_ep_config(ep);
890 ucp_context_h context = worker->context;
891 ucs_status_t status;
892 size_t max_frag_size, offset, length;
893 size_t min_zcopy, max_zcopy;
894
895 min_zcopy = config->tag.rndv.min_get_zcopy;
896 max_zcopy = config->tag.rndv.max_get_zcopy;
897 max_frag_size = ucs_min(context->config.ext.rndv_frag_size,
898 max_zcopy);
899 rndv_req->send.rndv_get.remote_request = remote_request;
900 rndv_req->send.rndv_get.remote_address = remote_address - base_offset;
901 rndv_req->send.rndv_get.rreq = rreq;
902 rndv_req->send.length = size;
903 rndv_req->send.state.dt.offset = 0;
904 rndv_req->send.mem_type = rreq->recv.mem_type;
905
906 /* Protocol:
907 * Step 1: GET remote fragment into HOST fragment buffer
908 * Step 2: PUT from fragment buffer to MEM TYPE destination
909 * Step 3: Send ATS for RNDV request
910 */
911
912 status = ucp_ep_rkey_unpack(rndv_req->send.ep, rkey_buffer,
913 &rndv_req->send.rndv_get.rkey);
914 if (ucs_unlikely(status != UCS_OK)) {
915 ucs_fatal("failed to unpack rendezvous remote key received from %s: %s",
916 ucp_ep_peer_name(rndv_req->send.ep), ucs_status_string(status));
917 }
918
919 ucp_rndv_req_init_get_zcopy_lane_map(rndv_req);
920
921 offset = 0;
922 while (offset != size) {
923 length = ucp_rndv_adjust_zcopy_length(min_zcopy, max_frag_size, 0,
924 size, offset, size - offset);
925
926 /* GET remote fragment into HOST fragment buffer */
927 ucp_rndv_send_frag_get_mem_type(rndv_req, remote_request, length,
928 remote_address + offset, UCS_MEMORY_TYPE_HOST,
929 rndv_req->send.rndv_get.rkey,
930 rndv_req->send.rndv_get.rkey_index,
931 rndv_req->send.rndv_get.lanes_map_all,
932 ucp_rndv_recv_frag_get_completion);
933
934 offset += length;
935 }
936
937 return UCS_OK;
938 }
939
ucp_rndv_send_frag_rtr(ucp_worker_h worker,ucp_request_t * rndv_req,ucp_request_t * rreq,const ucp_rndv_rts_hdr_t * rndv_rts_hdr)940 static void ucp_rndv_send_frag_rtr(ucp_worker_h worker, ucp_request_t *rndv_req,
941 ucp_request_t *rreq,
942 const ucp_rndv_rts_hdr_t *rndv_rts_hdr)
943 {
944 size_t max_frag_size = worker->context->config.ext.rndv_frag_size;
945 int i, num_frags;
946 size_t frag_size;
947 size_t offset;
948 ucp_mem_desc_t *mdesc;
949 ucp_request_t *freq;
950 ucp_request_t *frndv_req;
951 unsigned md_index;
952 unsigned memh_index;
953
954 ucp_trace_req(rreq, "using rndv pipeline protocol rndv_req %p", rndv_req);
955
956 offset = 0;
957 num_frags = ucs_div_round_up(rndv_rts_hdr->size, max_frag_size);
958
959 for (i = 0; i < num_frags; i++) {
960 frag_size = ucs_min(max_frag_size, (rndv_rts_hdr->size - offset));
961
962 /* internal fragment recv request allocated on receiver side to receive
963 * put fragment from sender and to perform a put to recv buffer */
964 freq = ucp_request_get(worker);
965 if (freq == NULL) {
966 ucs_fatal("failed to allocate fragment receive request");
967 }
968
969 /* internal rndv request to send RTR */
970 frndv_req = ucp_request_get(worker);
971 if (frndv_req == NULL) {
972 ucs_fatal("failed to allocate fragment rendezvous reply");
973 }
974
975 /* allocate fragment recv buffer desc*/
976 mdesc = ucp_worker_mpool_get(&worker->rndv_frag_mp);
977 if (mdesc == NULL) {
978 ucs_fatal("failed to allocate fragment memory buffer");
979 }
980
981 freq->recv.buffer = mdesc + 1;
982 freq->recv.datatype = ucp_dt_make_contig(1);
983 freq->recv.mem_type = UCS_MEMORY_TYPE_HOST;
984 freq->recv.length = frag_size;
985 freq->recv.state.dt.contig.md_map = 0;
986 freq->recv.frag.rreq = rreq;
987 freq->recv.frag.offset = offset;
988 freq->flags |= UCP_REQUEST_FLAG_RNDV_FRAG;
989
990 memh_index = 0;
991 ucs_for_each_bit(md_index,
992 (ucp_ep_config(rndv_req->send.ep)->key.rma_bw_md_map &
993 mdesc->memh->md_map)) {
994 freq->recv.state.dt.contig.memh[memh_index++] = ucp_memh2uct(mdesc->memh, md_index);
995 freq->recv.state.dt.contig.md_map |= UCS_BIT(md_index);
996 }
997 ucs_assert(memh_index <= UCP_MAX_OP_MDS);
998
999 frndv_req->send.ep = rndv_req->send.ep;
1000 frndv_req->send.pending_lane = UCP_NULL_LANE;
1001
1002 ucp_rndv_req_send_rtr(frndv_req, freq, rndv_rts_hdr->sreq.reqptr,
1003 freq->recv.length, offset);
1004 offset += frag_size;
1005 }
1006
1007 /* release original rndv reply request */
1008 ucp_request_put(rndv_req);
1009 }
1010
1011 static UCS_F_ALWAYS_INLINE int
ucp_rndv_is_rkey_ptr(const ucp_rndv_rts_hdr_t * rndv_rts_hdr,ucp_ep_h ep,ucs_memory_type_t recv_mem_type,ucp_rndv_mode_t rndv_mode)1012 ucp_rndv_is_rkey_ptr(const ucp_rndv_rts_hdr_t *rndv_rts_hdr, ucp_ep_h ep,
1013 ucs_memory_type_t recv_mem_type, ucp_rndv_mode_t rndv_mode)
1014 {
1015 const ucp_ep_config_t *ep_config = ucp_ep_config(ep);
1016
1017 return /* must have remote address */
1018 (rndv_rts_hdr->address != 0) &&
1019 /* remote key must be on a memory domain for which we support rkey_ptr */
1020 (ucp_rkey_packed_md_map(rndv_rts_hdr + 1) &
1021 ep_config->tag.rndv.rkey_ptr_dst_mds) &&
1022 /* rendezvous mode must not be forced to put/get */
1023 (rndv_mode == UCP_RNDV_MODE_AUTO) &&
1024 /* need local memory access for data unpack */
1025 UCP_MEM_IS_ACCESSIBLE_FROM_CPU(recv_mem_type);
1026 }
1027
ucp_rndv_progress_rkey_ptr(void * arg)1028 static unsigned ucp_rndv_progress_rkey_ptr(void *arg)
1029 {
1030 ucp_worker_h worker = (ucp_worker_h)arg;
1031 ucp_request_t *rndv_req = ucs_queue_head_elem_non_empty(&worker->rkey_ptr_reqs,
1032 ucp_request_t,
1033 send.rkey_ptr.queue_elem);
1034 ucp_request_t *rreq = rndv_req->send.rkey_ptr.rreq;
1035 size_t seg_size = ucs_min(worker->context->config.ext.rkey_ptr_seg_size,
1036 rndv_req->send.length - rreq->recv.state.offset);
1037 ucs_status_t status;
1038 size_t offset, new_offset;
1039 int last;
1040
1041 offset = rreq->recv.state.offset;
1042 new_offset = offset + seg_size;
1043 last = new_offset == rndv_req->send.length;
1044 status = ucp_request_recv_data_unpack(rreq,
1045 rndv_req->send.buffer + offset,
1046 seg_size, offset, last);
1047 if (ucs_unlikely(status != UCS_OK) || last) {
1048 ucs_queue_pull_non_empty(&worker->rkey_ptr_reqs);
1049 ucp_request_complete_tag_recv(rreq, status);
1050 ucp_rkey_destroy(rndv_req->send.rkey_ptr.rkey);
1051 ucp_rndv_req_send_ats(rndv_req, rreq,
1052 rndv_req->send.rkey_ptr.remote_request, status);
1053 if (ucs_queue_is_empty(&worker->rkey_ptr_reqs)) {
1054 uct_worker_progress_unregister_safe(worker->uct,
1055 &worker->rkey_ptr_cb_id);
1056 }
1057 } else {
1058 rreq->recv.state.offset = new_offset;
1059 }
1060
1061 return 1;
1062 }
1063
ucp_rndv_do_rkey_ptr(ucp_request_t * rndv_req,ucp_request_t * rreq,const ucp_rndv_rts_hdr_t * rndv_rts_hdr)1064 static void ucp_rndv_do_rkey_ptr(ucp_request_t *rndv_req, ucp_request_t *rreq,
1065 const ucp_rndv_rts_hdr_t *rndv_rts_hdr)
1066 {
1067 ucp_ep_h ep = rndv_req->send.ep;
1068 const ucp_ep_config_t *ep_config = ucp_ep_config(ep);
1069 ucp_worker_h worker = rreq->recv.worker;
1070 ucp_md_index_t dst_md_index = 0;
1071 ucp_lane_index_t i, lane;
1072 ucs_status_t status;
1073 unsigned rkey_index;
1074 void *local_ptr;
1075 ucp_rkey_h rkey;
1076
1077 ucp_trace_req(rndv_req, "start rkey_ptr rndv rreq %p", rreq);
1078
1079 status = ucp_ep_rkey_unpack(ep, rndv_rts_hdr + 1, &rkey);
1080 if (status != UCS_OK) {
1081 ucs_fatal("failed to unpack rendezvous remote key received from %s: %s",
1082 ucp_ep_peer_name(ep), ucs_status_string(status));
1083 }
1084
1085 /* Find a lane which is capable of accessing the destination memory */
1086 lane = UCP_NULL_LANE;
1087 for (i = 0; i < ep_config->key.num_lanes; ++i) {
1088 dst_md_index = ep_config->key.lanes[i].dst_md_index;
1089 if (UCS_BIT(dst_md_index) & rkey->md_map) {
1090 lane = i;
1091 break;
1092 }
1093 }
1094
1095 if (ucs_unlikely(lane == UCP_NULL_LANE)) {
1096 /* We should be able to find a lane, because ucp_rndv_is_rkey_ptr()
1097 * already checked that (rkey->md_map & ep_config->rkey_ptr_dst_mds) != 0
1098 */
1099 ucs_fatal("failed to find a lane to access remote memory domains 0x%lx",
1100 rkey->md_map);
1101 }
1102
1103 rkey_index = ucs_bitmap2idx(rkey->md_map, dst_md_index);
1104 status = uct_rkey_ptr(rkey->tl_rkey[rkey_index].cmpt,
1105 &rkey->tl_rkey[rkey_index].rkey,
1106 rndv_rts_hdr->address, &local_ptr);
1107 if (status != UCS_OK) {
1108 ucp_request_complete_tag_recv(rreq, status);
1109 ucp_rkey_destroy(rkey);
1110 ucp_rndv_req_send_ats(rndv_req, rreq, rndv_rts_hdr->sreq.reqptr, status);
1111 return;
1112 }
1113
1114 rreq->recv.state.offset = 0;
1115
1116 ucp_trace_req(rndv_req, "obtained a local pointer to remote buffer: %p",
1117 local_ptr);
1118 rndv_req->send.buffer = local_ptr;
1119 rndv_req->send.length = rndv_rts_hdr->size;
1120 rndv_req->send.rkey_ptr.rkey = rkey;
1121 rndv_req->send.rkey_ptr.remote_request = rndv_rts_hdr->sreq.reqptr;
1122 rndv_req->send.rkey_ptr.rreq = rreq;
1123
1124 UCP_WORKER_STAT_RNDV(ep->worker, RKEY_PTR, 1);
1125
1126 ucs_queue_push(&worker->rkey_ptr_reqs, &rndv_req->send.rkey_ptr.queue_elem);
1127 uct_worker_progress_register_safe(worker->uct,
1128 ucp_rndv_progress_rkey_ptr,
1129 rreq->recv.worker,
1130 UCS_CALLBACKQ_FLAG_FAST,
1131 &worker->rkey_ptr_cb_id);
1132 }
1133
1134 static UCS_F_ALWAYS_INLINE int
ucp_rndv_test_zcopy_scheme_support(size_t length,size_t min_zcopy,size_t max_zcopy,int split)1135 ucp_rndv_test_zcopy_scheme_support(size_t length, size_t min_zcopy,
1136 size_t max_zcopy, int split)
1137 {
1138 return /* is the current message greater than the minimal GET/PUT Zcopy? */
1139 (length >= min_zcopy) &&
1140 /* is the current message less than the maximal GET/PUT Zcopy? */
1141 ((length <= max_zcopy) ||
1142 /* or can the message be split? */ split);
1143 }
1144
/* An RTS has been matched to a receive request: pick and start a rendezvous
 * protocol. Selection order: truncation check (-> ATS), RKEY_PTR, GET zcopy,
 * memtype pipeline staging (PUT-fragment RTRs or GET pipeline), and finally
 * a plain RTR so the sender pushes the data (AM or put_zcopy). */
UCS_PROFILE_FUNC_VOID(ucp_rndv_matched, (worker, rreq, rndv_rts_hdr),
                      ucp_worker_h worker, ucp_request_t *rreq,
                      const ucp_rndv_rts_hdr_t *rndv_rts_hdr)
{
    ucp_rndv_mode_t rndv_mode;
    ucp_request_t *rndv_req;
    ucp_ep_h ep;
    ucp_ep_config_t *ep_config;
    ucs_status_t status;
    int is_get_zcopy_failed;

    UCS_ASYNC_BLOCK(&worker->async);

    UCS_PROFILE_REQUEST_EVENT(rreq, "rndv_match", 0);

    /* rreq is the receive request on the receiver's side */
    rreq->recv.tag.info.sender_tag = rndv_rts_hdr->super.tag;
    rreq->recv.tag.info.length = rndv_rts_hdr->size;

    /* the internal send request allocated on receiver side (to perform a "get"
     * operation, send "ATS" and "RTR") */
    rndv_req = ucp_request_get(worker);
    if (rndv_req == NULL) {
        ucs_error("failed to allocate rendezvous reply");
        goto out;
    }

    rndv_req->send.ep = ucp_worker_get_ep_by_ptr(worker,
                                                 rndv_rts_hdr->sreq.ep_ptr);
    rndv_req->flags = 0;
    rndv_req->send.mdesc = NULL;
    rndv_req->send.pending_lane = UCP_NULL_LANE;
    is_get_zcopy_failed = 0;

    ucp_trace_req(rreq,
                  "rndv matched remote {address 0x%"PRIx64" size %zu sreq 0x%lx}"
                  " rndv_sreq %p", rndv_rts_hdr->address, rndv_rts_hdr->size,
                  rndv_rts_hdr->sreq.reqptr, rndv_req);

    /* sender's message does not fit the posted buffer: drop the data with an
     * ATS and complete the receive as truncated */
    if (ucs_unlikely(rreq->recv.length < rndv_rts_hdr->size)) {
        ucp_trace_req(rndv_req,
                      "rndv truncated remote size %zu local size %zu rreq %p",
                      rndv_rts_hdr->size, rreq->recv.length, rreq);
        ucp_rndv_req_send_ats(rndv_req, rreq, rndv_rts_hdr->sreq.reqptr, UCS_OK);
        ucp_request_recv_generic_dt_finish(rreq);
        ucp_rndv_zcopy_recv_req_complete(rreq, UCS_ERR_MESSAGE_TRUNCATED);
        goto out;
    }

    /* if the receive side is not connected yet then the RTS was received on a stub ep */
    ep = rndv_req->send.ep;
    ep_config = ucp_ep_config(ep);
    rndv_mode = worker->context->config.ext.rndv_mode;

    /* best case: map the sender's buffer locally and unpack in place */
    if (ucp_rndv_is_rkey_ptr(rndv_rts_hdr, ep, rreq->recv.mem_type, rndv_mode)) {
        ucp_rndv_do_rkey_ptr(rndv_req, rreq, rndv_rts_hdr);
        goto out;
    }

    if (UCP_DT_IS_CONTIG(rreq->recv.datatype)) {
        if ((rndv_rts_hdr->address != 0) &&
            ucp_rndv_test_zcopy_scheme_support(rndv_rts_hdr->size,
                                               ep_config->tag.rndv.min_get_zcopy,
                                               ep_config->tag.rndv.max_get_zcopy,
                                               ep_config->tag.rndv.get_zcopy_split)) {
            /* try to fetch the data with a get_zcopy operation */
            status = ucp_rndv_req_send_rma_get(rndv_req, rreq, rndv_rts_hdr);
            if (status == UCS_OK) {
                goto out;
            }

            /* fallback to non get zcopy protocol */
            ucp_rkey_destroy(rndv_req->send.rndv_get.rkey);
            is_get_zcopy_failed = 1;
        }

        if (rndv_mode == UCP_RNDV_MODE_AUTO) {
            /* check if we need pipelined memtype staging */
            if (UCP_MEM_IS_CUDA(rreq->recv.mem_type) &&
                ucp_rndv_is_recv_pipeline_needed(rndv_req,
                                                 rndv_rts_hdr,
                                                 rreq->recv.mem_type,
                                                 is_get_zcopy_failed)) {
                ucp_rndv_recv_data_init(rreq, rndv_rts_hdr->size);
                if (ucp_rndv_is_put_pipeline_needed(rndv_rts_hdr->address,
                                                    rndv_rts_hdr->size,
                                                    ep_config->tag.rndv.min_get_zcopy,
                                                    ep_config->tag.rndv.max_get_zcopy,
                                                    is_get_zcopy_failed)) {
                    /* send FRAG RTR for sender to PUT the fragment. */
                    ucp_rndv_send_frag_rtr(worker, rndv_req, rreq, rndv_rts_hdr);
                } else {
                    /* sender address is present. do GET pipeline */
                    ucp_rndv_recv_start_get_pipeline(worker, rndv_req, rreq,
                                                     rndv_rts_hdr->sreq.reqptr,
                                                     rndv_rts_hdr + 1,
                                                     rndv_rts_hdr->address,
                                                     rndv_rts_hdr->size, 0);
                }
                goto out;
            }
        }

        if ((rndv_mode == UCP_RNDV_MODE_PUT_ZCOPY) ||
            UCP_MEM_IS_CUDA(rreq->recv.mem_type)) {
            /* put protocol is allowed - register receive buffer memory for rma */
            ucs_assert(rndv_rts_hdr->size <= rreq->recv.length);
            ucp_request_recv_buffer_reg(rreq, ep_config->key.rma_bw_md_map,
                                        rndv_rts_hdr->size);
        }
    }

    /* The sender didn't specify its address in the RTS, or the rndv mode was
     * configured to PUT, or GET rndv mode is unsupported - send an RTR and
     * the sender will send the data with active message or put_zcopy. */
    ucp_rndv_recv_data_init(rreq, rndv_rts_hdr->size);
    UCP_WORKER_STAT_RNDV(ep->worker, SEND_RTR, 1);
    ucp_rndv_req_send_rtr(rndv_req, rreq, rndv_rts_hdr->sreq.reqptr,
                          rndv_rts_hdr->size, 0ul);

out:
    UCS_ASYNC_UNBLOCK(&worker->async);
}
1268
ucp_rndv_process_rts(void * arg,void * data,size_t length,unsigned tl_flags)1269 ucs_status_t ucp_rndv_process_rts(void *arg, void *data, size_t length,
1270 unsigned tl_flags)
1271 {
1272 ucp_worker_h worker = arg;
1273 ucp_rndv_rts_hdr_t *rndv_rts_hdr = data;
1274 ucp_recv_desc_t *rdesc;
1275 ucp_request_t *rreq;
1276 ucs_status_t status;
1277
1278 rreq = ucp_tag_exp_search(&worker->tm, rndv_rts_hdr->super.tag);
1279 if (rreq != NULL) {
1280 ucp_rndv_matched(worker, rreq, rndv_rts_hdr);
1281
1282 /* Cancel req in transport if it was offloaded, because it arrived
1283 as unexpected */
1284 ucp_tag_offload_try_cancel(worker, rreq, UCP_TAG_OFFLOAD_CANCEL_FORCE);
1285
1286 UCP_WORKER_STAT_RNDV(worker, EXP, 1);
1287 status = UCS_OK;
1288 } else {
1289 status = ucp_recv_desc_init(worker, data, length, 0, tl_flags,
1290 sizeof(*rndv_rts_hdr),
1291 UCP_RECV_DESC_FLAG_RNDV, 0, &rdesc);
1292 if (!UCS_STATUS_IS_ERR(status)) {
1293 ucp_tag_unexp_recv(&worker->tm, rdesc, rndv_rts_hdr->super.tag);
1294 }
1295 }
1296
1297 return status;
1298 }
1299
/* Active-message handler for RNDV RTS: thin profiled wrapper that delegates
 * to ucp_rndv_process_rts(). */
UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rts_handler,
                 (arg, data, length, tl_flags),
                 void *arg, void *data, size_t length, unsigned tl_flags)
{
    return ucp_rndv_process_rts(arg, data, length, tl_flags);
}
1306
/* Active-message handler for RNDV ATS (ack-to-send): the receiver has
 * consumed (or dropped) the data, so complete the original send request. */
UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_ats_handler,
                 (arg, data, length, flags),
                 void *arg, void *data, size_t length, unsigned flags)
{
    ucp_reply_hdr_t *rep_hdr = data;
    ucp_request_t *sreq = (ucp_request_t*) rep_hdr->reqptr;

    /* dereg the original send request and set it to complete */
    UCS_PROFILE_REQUEST_EVENT(sreq, "rndv_ats_recv", 0);
    if (sreq->flags & UCP_REQUEST_FLAG_OFFLOADED) {
        /* the request was posted to HW tag matching - take it back */
        ucp_tag_offload_cancel_rndv(sreq);
    }
    ucp_rndv_complete_send(sreq, rep_hdr->status);
    return UCS_OK;
}
1322
ucp_rndv_pack_data(void * dest,void * arg)1323 static size_t ucp_rndv_pack_data(void *dest, void *arg)
1324 {
1325 ucp_rndv_data_hdr_t *hdr = dest;
1326 ucp_request_t *sreq = arg;
1327 size_t length, offset;
1328
1329 offset = sreq->send.state.dt.offset;
1330 hdr->rreq_ptr = sreq->send.msg_proto.tag.rreq_ptr;
1331 hdr->offset = offset;
1332 length = ucs_min(sreq->send.length - offset,
1333 ucp_ep_get_max_bcopy(sreq->send.ep, sreq->send.lane) - sizeof(*hdr));
1334
1335 return sizeof(*hdr) + ucp_dt_pack(sreq->send.ep->worker, sreq->send.datatype,
1336 sreq->send.mem_type, hdr + 1, sreq->send.buffer,
1337 &sreq->send.state.dt, length);
1338 }
1339
/* Pending-send progress function for the RNDV-over-AM bcopy fallback: send
 * the payload as one or multiple RNDV_DATA messages and complete the send
 * request when everything is out. */
UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_progress_am_bcopy, (self),
                 uct_pending_req_t *self)
{
    ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
    ucp_ep_t *ep = sreq->send.ep;
    ucs_status_t status;

    if (sreq->send.length <= ucp_ep_config(ep)->am.max_bcopy - sizeof(ucp_rndv_data_hdr_t)) {
        /* send a single bcopy message */
        status = ucp_do_am_bcopy_single(self, UCP_AM_ID_RNDV_DATA,
                                        ucp_rndv_pack_data);
    } else {
        /* fragment into multiple RNDV_DATA messages (same AM id for first
         * and middle fragments) */
        status = ucp_do_am_bcopy_multi(self, UCP_AM_ID_RNDV_DATA,
                                       UCP_AM_ID_RNDV_DATA,
                                       ucp_rndv_pack_data,
                                       ucp_rndv_pack_data, 1);
    }
    if (status == UCS_OK) {
        ucp_rndv_complete_send(sreq, UCS_OK);
    } else if (status == UCP_STATUS_PENDING_SWITCH) {
        /* request moved to another pending queue - report success to the
         * caller so it is not rescheduled here */
        status = UCS_OK;
    }

    return status;
}
1365
/* Pending-send progress function for RNDV PUT zcopy: issue one PUT per call,
 * aligning the first chunk to the transport's optimal zcopy alignment, until
 * the whole send length is posted. Returns UCS_OK when fully posted (invoking
 * the completion inline if there are no outstanding UCT operations),
 * UCS_INPROGRESS to be called again, or an error status. */
UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_progress_rma_put_zcopy, (self),
                 uct_pending_req_t *self)
{
    ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
    const size_t max_iovcnt = 1;
    ucp_ep_h ep = sreq->send.ep;
    ucs_status_t status;
    size_t offset, ucp_mtu, align, remaining, length;
    uct_iface_attr_t *attrs;
    uct_iov_t iov[max_iovcnt];
    size_t iovcnt;
    ucp_dt_state_t state;

    /* register the send buffer on first entry (no staging mdesc present) */
    if (!sreq->send.mdesc) {
        status = ucp_request_send_buffer_reg_lane(sreq, sreq->send.lane, 0);
        ucs_assert_always(status == UCS_OK);
    }

    attrs = ucp_worker_iface_get_attr(ep->worker,
                                      ucp_ep_get_rsc_index(ep, sreq->send.lane));
    align = attrs->cap.put.opt_zcopy_align;
    ucp_mtu = attrs->cap.put.align_mtu;

    offset = sreq->send.state.dt.offset;
    remaining = (uintptr_t)sreq->send.buffer % align;

    if ((offset == 0) && (remaining > 0) && (sreq->send.length > ucp_mtu)) {
        /* first chunk: shorten it so subsequent chunks start aligned */
        length = ucp_mtu - remaining;
    } else {
        length = ucs_min(sreq->send.length - offset,
                         ucp_ep_config(ep)->tag.rndv.max_put_zcopy);
    }

    ucs_trace_data("req %p: offset %zu remainder %zu. read to %p len %zu",
                   sreq, offset, (uintptr_t)sreq->send.buffer % align,
                   UCS_PTR_BYTE_OFFSET(sreq->send.buffer, offset), length);

    /* advance a copy of the dt state; the real state is updated by
     * ucp_request_send_state_advance() according to the operation status */
    state = sreq->send.state.dt;
    ucp_dt_iov_copy_uct(ep->worker->context, iov, &iovcnt, max_iovcnt, &state,
                        sreq->send.buffer, ucp_dt_make_contig(1), length,
                        ucp_ep_md_index(ep, sreq->send.lane), sreq->send.mdesc);
    status = uct_ep_put_zcopy(ep->uct_eps[sreq->send.lane],
                              iov, iovcnt,
                              sreq->send.rndv_put.remote_address + offset,
                              sreq->send.rndv_put.uct_rkey,
                              &sreq->send.state.uct_comp);
    ucp_request_send_state_advance(sreq, &state,
                                   UCP_REQUEST_SEND_PROTO_RNDV_PUT,
                                   status);
    if (sreq->send.state.dt.offset == sreq->send.length) {
        /* all chunks posted; if nothing is outstanding, complete inline */
        if (sreq->send.state.uct_comp.count == 0) {
            sreq->send.state.uct_comp.func(&sreq->send.state.uct_comp, status);
        }
        return UCS_OK;
    } else if (!UCS_STATUS_IS_ERR(status)) {
        return UCS_INPROGRESS;
    } else {
        return status;
    }
}
1426
ucp_rndv_am_zcopy_send_req_complete(ucp_request_t * req,ucs_status_t status)1427 static void ucp_rndv_am_zcopy_send_req_complete(ucp_request_t *req,
1428 ucs_status_t status)
1429 {
1430 ucs_assert(req->send.state.uct_comp.count == 0);
1431 ucp_request_send_buffer_dereg(req);
1432 ucp_request_complete_send(req, status);
1433 }
1434
ucp_rndv_am_zcopy_completion(uct_completion_t * self,ucs_status_t status)1435 static void ucp_rndv_am_zcopy_completion(uct_completion_t *self,
1436 ucs_status_t status)
1437 {
1438 ucp_request_t *sreq = ucs_container_of(self, ucp_request_t,
1439 send.state.uct_comp);
1440 if (sreq->send.state.dt.offset == sreq->send.length) {
1441 ucp_rndv_am_zcopy_send_req_complete(sreq, status);
1442 } else if (status != UCS_OK) {
1443 ucs_fatal("error handling is unsupported with rendezvous protocol");
1444 }
1445 }
1446
ucp_rndv_progress_am_zcopy_single(uct_pending_req_t * self)1447 static ucs_status_t ucp_rndv_progress_am_zcopy_single(uct_pending_req_t *self)
1448 {
1449 ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
1450 ucp_rndv_data_hdr_t hdr;
1451
1452 hdr.rreq_ptr = sreq->send.msg_proto.tag.rreq_ptr;
1453 hdr.offset = 0;
1454 return ucp_do_am_zcopy_single(self, UCP_AM_ID_RNDV_DATA, &hdr, sizeof(hdr),
1455 ucp_rndv_am_zcopy_send_req_complete);
1456 }
1457
ucp_rndv_progress_am_zcopy_multi(uct_pending_req_t * self)1458 static ucs_status_t ucp_rndv_progress_am_zcopy_multi(uct_pending_req_t *self)
1459 {
1460 ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
1461 ucp_rndv_data_hdr_t hdr;
1462
1463 hdr.rreq_ptr = sreq->send.msg_proto.tag.rreq_ptr;
1464 hdr.offset = sreq->send.state.dt.offset;
1465 return ucp_do_am_zcopy_multi(self,
1466 UCP_AM_ID_RNDV_DATA,
1467 UCP_AM_ID_RNDV_DATA,
1468 &hdr, sizeof(hdr),
1469 &hdr, sizeof(hdr),
1470 ucp_rndv_am_zcopy_send_req_complete, 1);
1471 }
1472
/* Completion of one fragment PUT in the pipeline: return the staging buffer
 * to the mpool, account the fragment's length on the parent request, and
 * send ATP once all fragments of this rndv request have been PUT. */
UCS_PROFILE_FUNC_VOID(ucp_rndv_send_frag_put_completion, (self, status),
                      uct_completion_t *self, ucs_status_t status)
{
    ucp_request_t *freq = ucs_container_of(self, ucp_request_t, send.state.uct_comp);
    ucp_request_t *req = freq->send.rndv_put.sreq;

    /* release memory descriptor */
    if (freq->send.mdesc) {
        ucs_mpool_put_inline((void *)freq->send.mdesc);
    }

    req->send.state.dt.offset += freq->send.length;
    ucs_assert(req->send.state.dt.offset <= req->send.length);

    /* send ATP for last fragment of the rndv request */
    if (req->send.length == req->send.state.dt.offset) {
        ucp_rndv_send_frag_atp(req, req->send.rndv_put.remote_request);
    }

    ucp_request_put(freq);
}
1494
/* Completion of a memtype GET staging one fragment to host: repurpose the
 * same fragment request as a PUT of the staged data to the receiver. */
UCS_PROFILE_FUNC_VOID(ucp_rndv_put_pipeline_frag_get_completion, (self, status),
                      uct_completion_t *self, ucs_status_t status)
{
    ucp_request_t *freq = ucs_container_of(self, ucp_request_t, send.state.uct_comp);
    ucp_request_t *fsreq = freq->send.rndv_get.rreq;

    /* get completed on memtype endpoint to stage on host. send put request to receiver*/
    ucp_request_send_state_reset(freq, ucp_rndv_send_frag_put_completion,
                                 UCP_REQUEST_SEND_PROTO_RNDV_PUT);
    /* remote target offset mirrors this fragment's offset within the
     * parent's send buffer */
    freq->send.rndv_put.remote_address = fsreq->send.rndv_put.remote_address +
        (freq->send.rndv_get.remote_address - (uint64_t)fsreq->send.buffer);
    freq->send.ep = fsreq->send.ep;
    freq->send.uct.func = ucp_rndv_progress_rma_put_zcopy;
    freq->send.rndv_put.sreq = fsreq;
    freq->send.rndv_put.rkey = fsreq->send.rndv_put.rkey;
    freq->send.rndv_put.uct_rkey = fsreq->send.rndv_put.uct_rkey;
    freq->send.lane = fsreq->send.lane;
    freq->send.state.dt.dt.contig.md_map = 0;

    ucp_request_send(freq, 0);
}
1516
ucp_rndv_send_start_put_pipeline(ucp_request_t * sreq,ucp_rndv_rtr_hdr_t * rndv_rtr_hdr)1517 static ucs_status_t ucp_rndv_send_start_put_pipeline(ucp_request_t *sreq,
1518 ucp_rndv_rtr_hdr_t *rndv_rtr_hdr)
1519 {
1520 ucp_ep_h ep = sreq->send.ep;
1521 ucp_ep_config_t *config = ucp_ep_config(ep);
1522 ucp_worker_h worker = sreq->send.ep->worker;
1523 ucp_context_h context = worker->context;
1524 const uct_md_attr_t *md_attr;
1525 ucp_request_t *freq;
1526 ucp_request_t *fsreq;
1527 ucp_md_index_t md_index;
1528 size_t max_frag_size, rndv_size, length;
1529 size_t offset, rndv_base_offset;
1530 size_t min_zcopy, max_zcopy;
1531
1532 ucp_trace_req(sreq, "using put rndv pipeline protocol");
1533
1534 /* Protocol:
1535 * Step 1: GET fragment from send buffer to HOST fragment buffer
1536 * Step 2: PUT from fragment HOST buffer to remote HOST fragment buffer
1537 * Step 3: send ATP for each fragment request
1538 */
1539
1540 /* check if lane supports host memory, to stage sends through host memory */
1541 md_attr = ucp_ep_md_attr(sreq->send.ep, sreq->send.lane);
1542 if (!(md_attr->cap.reg_mem_types & UCS_BIT(UCS_MEMORY_TYPE_HOST))) {
1543 return UCS_ERR_UNSUPPORTED;
1544 }
1545
1546 min_zcopy = config->tag.rndv.min_put_zcopy;
1547 max_zcopy = config->tag.rndv.max_put_zcopy;
1548 rndv_size = ucs_min(rndv_rtr_hdr->size, sreq->send.length);
1549 max_frag_size = ucs_min(context->config.ext.rndv_frag_size, max_zcopy);
1550 rndv_base_offset = rndv_rtr_hdr->offset;
1551
1552 /* initialize send req state on first fragment rndv request */
1553 if (rndv_base_offset == 0) {
1554 ucp_request_send_state_reset(sreq, NULL, UCP_REQUEST_SEND_PROTO_RNDV_PUT);
1555 }
1556
1557 /* internal send request allocated on sender side to handle send fragments for RTR */
1558 fsreq = ucp_request_get(worker);
1559 if (fsreq == NULL) {
1560 ucs_fatal("failed to allocate fragment receive request");
1561 }
1562
1563 ucp_request_send_state_init(fsreq, ucp_dt_make_contig(1), 0);
1564 fsreq->send.buffer = UCS_PTR_BYTE_OFFSET(sreq->send.buffer,
1565 rndv_base_offset);
1566 fsreq->send.length = rndv_size;
1567 fsreq->send.mem_type = sreq->send.mem_type;
1568 fsreq->send.ep = sreq->send.ep;
1569 fsreq->send.lane = sreq->send.lane;
1570 fsreq->send.rndv_put.rkey = sreq->send.rndv_put.rkey;
1571 fsreq->send.rndv_put.uct_rkey = sreq->send.rndv_put.uct_rkey;
1572 fsreq->send.rndv_put.remote_request = rndv_rtr_hdr->rreq_ptr;
1573 fsreq->send.rndv_put.remote_address = rndv_rtr_hdr->address;
1574 fsreq->send.rndv_put.sreq = sreq;
1575 fsreq->send.state.dt.offset = 0;
1576
1577 offset = 0;
1578 while (offset != rndv_size) {
1579 length = ucp_rndv_adjust_zcopy_length(min_zcopy, max_frag_size, 0,
1580 rndv_size, offset, rndv_size - offset);
1581
1582 if (UCP_MEM_IS_ACCESSIBLE_FROM_CPU(sreq->send.mem_type)) {
1583 /* sbuf is in host, directly do put */
1584 freq = ucp_request_get(worker);
1585 if (ucs_unlikely(freq == NULL)) {
1586 ucs_error("failed to allocate fragment receive request");
1587 return UCS_ERR_NO_MEMORY;
1588 }
1589
1590 ucp_request_send_state_reset(freq, ucp_rndv_send_frag_put_completion,
1591 UCP_REQUEST_SEND_PROTO_RNDV_PUT);
1592 md_index = ucp_ep_md_index(sreq->send.ep,
1593 sreq->send.lane);
1594 freq->send.ep = fsreq->send.ep;
1595 freq->send.buffer = UCS_PTR_BYTE_OFFSET(fsreq->send.buffer,
1596 offset);
1597 freq->send.datatype = ucp_dt_make_contig(1);
1598 freq->send.mem_type = UCS_MEMORY_TYPE_HOST;
1599 freq->send.state.dt.dt.contig.memh[0] =
1600 ucp_memh_map2uct(sreq->send.state.dt.dt.contig.memh,
1601 sreq->send.state.dt.dt.contig.md_map, md_index);
1602 freq->send.state.dt.dt.contig.md_map = UCS_BIT(md_index);
1603 freq->send.length = length;
1604 freq->send.uct.func = ucp_rndv_progress_rma_put_zcopy;
1605 freq->send.rndv_put.sreq = fsreq;
1606 freq->send.rndv_put.rkey = fsreq->send.rndv_put.rkey;
1607 freq->send.rndv_put.uct_rkey = fsreq->send.rndv_put.uct_rkey;
1608 freq->send.rndv_put.remote_address = rndv_rtr_hdr->address + offset;
1609 freq->send.rndv_put.remote_request = rndv_rtr_hdr->rreq_ptr;
1610 freq->send.lane = fsreq->send.lane;
1611 freq->send.mdesc = NULL;
1612
1613 ucp_request_send(freq, 0);
1614 } else {
1615 ucp_rndv_send_frag_get_mem_type(fsreq, 0, length,
1616 (uint64_t)UCS_PTR_BYTE_OFFSET(fsreq->send.buffer, offset),
1617 fsreq->send.mem_type, NULL, NULL, UCS_BIT(0),
1618 ucp_rndv_put_pipeline_frag_get_completion);
1619 }
1620
1621 offset += length;
1622 }
1623
1624 return UCS_OK;
1625 }
1626
/* Active-message handler for RNDV ATP (ack-to-put): for a pipeline fragment
 * request, the staged data has arrived in the host fragment buffer, so PUT
 * it to the user's receive buffer; otherwise the whole transfer is done and
 * the receive request completes. */
UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_atp_handler,
                 (arg, data, length, flags),
                 void *arg, void *data, size_t length, unsigned flags)
{
    ucp_reply_hdr_t *rep_hdr = data;
    ucp_request_t *req = (ucp_request_t*) rep_hdr->reqptr;

    if (req->flags & UCP_REQUEST_FLAG_RNDV_FRAG) {
        /* received ATP for frag RTR request */
        ucs_assert(req->recv.frag.rreq != NULL);
        UCS_PROFILE_REQUEST_EVENT(req, "rndv_frag_atp_recv", 0);
        /* the mpool descriptor precedes the fragment's recv buffer */
        ucp_rndv_recv_frag_put_mem_type(req->recv.frag.rreq, NULL, req,
                                        ((ucp_mem_desc_t*) req->recv.buffer - 1),
                                        req->recv.length, req->recv.frag.offset);
    } else {
        UCS_PROFILE_REQUEST_EVENT(req, "rndv_atp_recv", 0);
        ucp_rndv_zcopy_recv_req_complete(req, UCS_OK);
    }

    return UCS_OK;
}
1648
1649 UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rtr_handler,
1650 (arg, data, length, flags),
1651 void *arg, void *data, size_t length, unsigned flags)
1652 {
1653 ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = data;
1654 ucp_request_t *sreq = (ucp_request_t*)rndv_rtr_hdr->sreq_ptr;
1655 ucp_ep_h ep = sreq->send.ep;
1656 ucp_ep_config_t *ep_config = ucp_ep_config(ep);
1657 ucp_context_h context = ep->worker->context;
1658 ucs_status_t status;
1659 int is_pipeline_rndv;
1660
1661 ucp_trace_req(sreq, "received rtr address 0x%lx remote rreq 0x%lx",
1662 rndv_rtr_hdr->address, rndv_rtr_hdr->rreq_ptr);
1663 UCS_PROFILE_REQUEST_EVENT(sreq, "rndv_rtr_recv", 0);
1664
1665 if (sreq->flags & UCP_REQUEST_FLAG_OFFLOADED) {
1666 /* Do not deregister memory here, because am zcopy rndv may
1667 * need it registered (if am and tag is the same lane). */
1668 ucp_tag_offload_cancel_rndv(sreq);
1669 }
1670
1671 if (UCP_DT_IS_CONTIG(sreq->send.datatype) && rndv_rtr_hdr->address) {
1672 status = ucp_ep_rkey_unpack(ep, rndv_rtr_hdr + 1,
1673 &sreq->send.rndv_put.rkey);
1674 if (status != UCS_OK) {
1675 ucs_fatal("failed to unpack rendezvous remote key received from %s: %s",
1676 ucp_ep_peer_name(ep), ucs_status_string(status));
1677 }
1678
1679 is_pipeline_rndv = ((!UCP_MEM_IS_ACCESSIBLE_FROM_CPU(sreq->send.mem_type) ||
1680 (sreq->send.length != rndv_rtr_hdr->size)) &&
1681 (context->config.ext.rndv_mode != UCP_RNDV_MODE_PUT_ZCOPY));
1682
1683 sreq->send.lane = ucp_rkey_find_rma_lane(ep->worker->context, ep_config,
1684 (is_pipeline_rndv ?
1685 sreq->send.rndv_put.rkey->mem_type :
1686 sreq->send.mem_type),
1687 ep_config->tag.rndv.put_zcopy_lanes,
1688 sreq->send.rndv_put.rkey, 0,
1689 &sreq->send.rndv_put.uct_rkey);
1690 if (sreq->send.lane != UCP_NULL_LANE) {
1691 /*
1692 * Try pipeline protocol for non-host memory, if PUT_ZCOPY protocol is
1693 * not explicitly required. If pipeline is UNSUPPORTED, fallback to
1694 * PUT_ZCOPY anyway.
1695 */
1696 if (is_pipeline_rndv) {
1697 status = ucp_rndv_send_start_put_pipeline(sreq, rndv_rtr_hdr);
1698 if (status != UCS_ERR_UNSUPPORTED) {
1699 return status;
1700 }
1701 /* If we get here, it means that RNDV pipeline protocol is
1702 * unsupported and we have to use PUT_ZCOPY RNDV scheme instead */
1703 }
1704
1705 if ((context->config.ext.rndv_mode != UCP_RNDV_MODE_GET_ZCOPY) &&
1706 ucp_rndv_test_zcopy_scheme_support(sreq->send.length,
1707 ep_config->tag.rndv.min_put_zcopy,
1708 ep_config->tag.rndv.max_put_zcopy,
1709 ep_config->tag.rndv.put_zcopy_split)) {
1710 ucp_request_send_state_reset(sreq, ucp_rndv_put_completion,
1711 UCP_REQUEST_SEND_PROTO_RNDV_PUT);
1712 sreq->send.uct.func = ucp_rndv_progress_rma_put_zcopy;
1713 sreq->send.rndv_put.remote_request = rndv_rtr_hdr->rreq_ptr;
1714 sreq->send.rndv_put.remote_address = rndv_rtr_hdr->address;
1715 sreq->send.mdesc = NULL;
1716 goto out_send;
1717 } else {
1718 ucp_rkey_destroy(sreq->send.rndv_put.rkey);
1719 }
1720 } else {
1721 ucp_rkey_destroy(sreq->send.rndv_put.rkey);
1722 }
1723 }
1724
1725 ucp_trace_req(sreq, "using rdnv_data protocol");
1726
1727 /* switch to AM */
1728 sreq->send.msg_proto.tag.rreq_ptr = rndv_rtr_hdr->rreq_ptr;
1729
1730 if (UCP_DT_IS_CONTIG(sreq->send.datatype) &&
1731 (sreq->send.length >=
1732 ep_config->am.mem_type_zcopy_thresh[sreq->send.mem_type]))
1733 {
1734 status = ucp_request_send_buffer_reg_lane(sreq, ucp_ep_get_am_lane(ep), 0);
1735 ucs_assert_always(status == UCS_OK);
1736
1737 ucp_request_send_state_reset(sreq, ucp_rndv_am_zcopy_completion,
1738 UCP_REQUEST_SEND_PROTO_ZCOPY_AM);
1739
1740 if ((sreq->send.length + sizeof(ucp_rndv_data_hdr_t)) <=
1741 ep_config->am.max_zcopy) {
1742 sreq->send.uct.func = ucp_rndv_progress_am_zcopy_single;
1743 } else {
1744 sreq->send.uct.func = ucp_rndv_progress_am_zcopy_multi;
1745 sreq->send.msg_proto.am_bw_index = 1;
1746 }
1747 } else {
1748 ucp_request_send_state_reset(sreq, NULL, UCP_REQUEST_SEND_PROTO_BCOPY_AM);
1749 sreq->send.uct.func = ucp_rndv_progress_am_bcopy;
1750 sreq->send.msg_proto.am_bw_index = 1;
1751 }
1752
1753 out_send:
1754 ucp_request_send(sreq, 0);
1755 return UCS_OK;
1756 }
1757
1758 UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_data_handler,
1759 (arg, data, length, flags),
1760 void *arg, void *data, size_t length, unsigned flags)
1761 {
1762 ucp_rndv_data_hdr_t *rndv_data_hdr = data;
1763 ucp_request_t *rreq = (ucp_request_t*) rndv_data_hdr->rreq_ptr;
1764 size_t recv_len;
1765
1766 ucs_assert(!(rreq->flags & UCP_REQUEST_FLAG_RNDV_FRAG));
1767
1768 recv_len = length - sizeof(*rndv_data_hdr);
1769 UCS_PROFILE_REQUEST_EVENT(rreq, "rndv_data_recv", recv_len);
1770
1771 (void)ucp_tag_request_process_recv_data(rreq, rndv_data_hdr + 1, recv_len,
1772 rndv_data_hdr->offset, 1, 0);
1773 return UCS_OK;
1774 }
1775
ucp_rndv_dump_rkey(const void * packed_rkey,char * buffer,size_t max)1776 static void ucp_rndv_dump_rkey(const void *packed_rkey, char *buffer, size_t max)
1777 {
1778 char *p = buffer;
1779 char *endp = buffer + max;
1780
1781 snprintf(p, endp - p, " rkey ");
1782 p += strlen(p);
1783
1784 ucp_rkey_dump_packed(packed_rkey, p, endp - p);
1785 }
1786
ucp_rndv_dump(ucp_worker_h worker,uct_am_trace_type_t type,uint8_t id,const void * data,size_t length,char * buffer,size_t max)1787 static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type,
1788 uint8_t id, const void *data, size_t length,
1789 char *buffer, size_t max)
1790 {
1791
1792 const ucp_rndv_rts_hdr_t *rndv_rts_hdr = data;
1793 const ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = data;
1794 const ucp_rndv_data_hdr_t *rndv_data = data;
1795 const ucp_reply_hdr_t *rep_hdr = data;
1796
1797 switch (id) {
1798 case UCP_AM_ID_RNDV_RTS:
1799 ucs_assert(rndv_rts_hdr->sreq.ep_ptr != 0);
1800 snprintf(buffer, max, "RNDV_RTS tag %"PRIx64" ep_ptr %lx sreq 0x%lx "
1801 "address 0x%"PRIx64" size %zu", rndv_rts_hdr->super.tag,
1802 rndv_rts_hdr->sreq.ep_ptr, rndv_rts_hdr->sreq.reqptr,
1803 rndv_rts_hdr->address, rndv_rts_hdr->size);
1804 if (rndv_rts_hdr->address) {
1805 ucp_rndv_dump_rkey(rndv_rts_hdr + 1, buffer + strlen(buffer),
1806 max - strlen(buffer));
1807 }
1808 break;
1809 case UCP_AM_ID_RNDV_ATS:
1810 snprintf(buffer, max, "RNDV_ATS sreq 0x%lx status '%s'",
1811 rep_hdr->reqptr, ucs_status_string(rep_hdr->status));
1812 break;
1813 case UCP_AM_ID_RNDV_RTR:
1814 snprintf(buffer, max, "RNDV_RTR sreq 0x%lx rreq 0x%lx address 0x%lx",
1815 rndv_rtr_hdr->sreq_ptr, rndv_rtr_hdr->rreq_ptr,
1816 rndv_rtr_hdr->address);
1817 if (rndv_rtr_hdr->address) {
1818 ucp_rndv_dump_rkey(rndv_rtr_hdr + 1, buffer + strlen(buffer),
1819 max - strlen(buffer));
1820 }
1821 break;
1822 case UCP_AM_ID_RNDV_DATA:
1823 snprintf(buffer, max, "RNDV_DATA rreq 0x%"PRIx64" offset %zu",
1824 rndv_data->rreq_ptr, rndv_data->offset);
1825 break;
1826 case UCP_AM_ID_RNDV_ATP:
1827 snprintf(buffer, max, "RNDV_ATP sreq 0x%lx status '%s'",
1828 rep_hdr->reqptr, ucs_status_string(rep_hdr->status));
1829 break;
1830 default:
1831 return;
1832 }
1833 }
1834
/* Register the rendezvous active-message handlers (all gated on the TAG
 * feature), with ucp_rndv_dump as the trace callback for each id. */
UCP_DEFINE_AM(UCP_FEATURE_TAG, UCP_AM_ID_RNDV_RTS, ucp_rndv_rts_handler,
              ucp_rndv_dump, 0);
UCP_DEFINE_AM(UCP_FEATURE_TAG, UCP_AM_ID_RNDV_ATS, ucp_rndv_ats_handler,
              ucp_rndv_dump, 0);
UCP_DEFINE_AM(UCP_FEATURE_TAG, UCP_AM_ID_RNDV_ATP, ucp_rndv_atp_handler,
              ucp_rndv_dump, 0);
UCP_DEFINE_AM(UCP_FEATURE_TAG, UCP_AM_ID_RNDV_RTR, ucp_rndv_rtr_handler,
              ucp_rndv_dump, 0);
UCP_DEFINE_AM(UCP_FEATURE_TAG, UCP_AM_ID_RNDV_DATA, ucp_rndv_data_handler,
              ucp_rndv_dump, 0);

/* Proxy registrations for the same AM ids */
UCP_DEFINE_AM_PROXY(UCP_AM_ID_RNDV_RTS);
UCP_DEFINE_AM_PROXY(UCP_AM_ID_RNDV_ATS);
UCP_DEFINE_AM_PROXY(UCP_AM_ID_RNDV_ATP);
UCP_DEFINE_AM_PROXY(UCP_AM_ID_RNDV_RTR);
UCP_DEFINE_AM_PROXY(UCP_AM_ID_RNDV_DATA);
1851