1 /**
2 * Copyright (C) Mellanox Technologies Ltd. 2017. ALL RIGHTS RESERVED.
3 *
4 * See file LICENSE for terms.
5 */
6
7 #include <list>
8 #include <numeric>
9 #include <set>
10 #include <vector>
11
12 #include "ucp_datatype.h"
13 #include "ucp_test.h"
14
15
16 class test_ucp_stream_base : public ucp_test {
17 public:
get_ctx_params()18 static ucp_params_t get_ctx_params() {
19 ucp_params_t params = ucp_test::get_ctx_params();
20 params.field_mask |= UCP_PARAM_FIELD_FEATURES;
21 params.features = UCP_FEATURE_STREAM;
22 return params;
23 }
24
ucp_send_cb(void * request,ucs_status_t status)25 static void ucp_send_cb(void *request, ucs_status_t status) {}
ucp_recv_cb(void * request,ucs_status_t status,size_t length)26 static void ucp_recv_cb(void *request, ucs_status_t status, size_t length) {}
27
28 size_t wait_stream_recv(void *request);
29
30 protected:
31 ucs_status_ptr_t stream_send_nb(const ucp::data_type_desc_t& dt_desc);
32 };
33
wait_stream_recv(void * request)34 size_t test_ucp_stream_base::wait_stream_recv(void *request)
35 {
36 ucs_time_t deadline = ucs::get_deadline();
37 ucs_status_t status;
38 size_t length;
39 do {
40 progress();
41 status = ucp_stream_recv_request_test(request, &length);
42 } while ((status == UCS_INPROGRESS) && (ucs_get_time() < deadline));
43 ASSERT_UCS_OK(status);
44 ucp_request_free(request);
45
46 return length;
47 }
48
49 ucs_status_ptr_t
stream_send_nb(const ucp::data_type_desc_t & dt_desc)50 test_ucp_stream_base::stream_send_nb(const ucp::data_type_desc_t& dt_desc)
51 {
52 return ucp_stream_send_nb(sender().ep(), dt_desc.buf(), dt_desc.count(),
53 dt_desc.dt(), ucp_send_cb, 0);
54 }
55
56 class test_ucp_stream_onesided : public test_ucp_stream_base {
57 public:
get_ep_params()58 ucp_ep_params_t get_ep_params() {
59 ucp_ep_params_t params = test_ucp_stream_base::get_ep_params();
60 params.field_mask |= UCP_EP_PARAM_FIELD_FLAGS;
61 params.flags |= UCP_EP_PARAMS_FLAGS_NO_LOOPBACK;
62 return params;
63 }
64 };
65
/* A receive posted on an endpoint that never gets any data is expected to be
 * reported as canceled once the endpoint is disconnected */
UCS_TEST_P(test_ucp_stream_onesided, recv_not_connected_ep_cleanup) {
    /* only the receiver side creates an endpoint */
    receiver().connect(&sender(), get_ep_params());

    const ucp_datatype_t dt_u64 = ucp_dt_make_contig(sizeof(uint64_t));
    uint64_t             rbuf   = 0;
    size_t               rlen;
    void *recv_req = ucp_stream_recv_nb(receiver().ep(), &rbuf, 1, dt_u64,
                                        ucp_recv_cb, &rlen,
                                        UCP_STREAM_RECV_FLAG_WAITALL);

    /* nothing was sent, so the request must stay pending */
    EXPECT_TRUE(UCS_PTR_IS_PTR(recv_req));
    EXPECT_EQ(UCS_INPROGRESS, ucp_request_check_status(recv_req));

    disconnect(receiver());

    EXPECT_EQ(UCS_ERR_CANCELED, ucp_request_check_status(recv_req));
    ucp_request_free(recv_req);
}
81
/* After one successful send/receive round trip, a second receive with no
 * matching send is expected to be canceled when both sides disconnect */
UCS_TEST_P(test_ucp_stream_onesided, recv_connected_ep_cleanup) {
    skip_loopback();
    sender().connect(&receiver(), get_ep_params());
    receiver().connect(&sender(), get_ep_params());

    const ucp_datatype_t dt_u64   = ucp_dt_make_contig(sizeof(uint64_t));
    uint64_t             tx_value = ucs::rand();
    uint64_t             rx_value = 0;
    size_t               rx_length;

    /* first round: the receive must complete with the sent value */
    ucp::data_type_desc_t tx_desc(dt_u64, &tx_value, sizeof(tx_value));
    void *send_req = stream_send_nb(tx_desc);

    void *recv_req = ucp_stream_recv_nb(receiver().ep(), &rx_value, 1, dt_u64,
                                        ucp_recv_cb, &rx_length,
                                        UCP_STREAM_RECV_FLAG_WAITALL);

    EXPECT_EQ(sizeof(tx_value), wait_stream_recv(recv_req));
    EXPECT_EQ(tx_value, rx_value);
    wait(send_req);

    /* second round: nothing is sent, so the receive stays pending until the
     * endpoints are torn down */
    recv_req = ucp_stream_recv_nb(receiver().ep(), &rx_value, 1, dt_u64,
                                  ucp_recv_cb, &rx_length,
                                  UCP_STREAM_RECV_FLAG_WAITALL);
    EXPECT_TRUE(UCS_PTR_IS_PTR(recv_req));
    EXPECT_EQ(UCS_INPROGRESS, ucp_request_check_status(recv_req));

    disconnect(sender());
    disconnect(receiver());

    EXPECT_EQ(UCS_ERR_CANCELED, ucp_request_check_status(recv_req));
    ucp_request_free(recv_req);
}
112
/* Data sent before the receiver has created its endpoint must not be reported
 * by ucp_stream_worker_poll(); once the receiver endpoint exists, the
 * endpoint becomes ready and the data can be received */
UCS_TEST_P(test_ucp_stream_onesided, send_recv_no_ep) {

    /* connect from sender side only and send */
    sender().connect(&receiver(), get_ep_params());
    uint64_t send_data = ucs::rand();
    ucp::data_type_desc_t dt_desc(ucp_dt_make_contig(sizeof(uint64_t)),
                                  &send_data, sizeof(send_data));
    void *sreq = stream_send_nb(dt_desc);
    wait(sreq);

    /* must not receive data before ep is created on receiver side */
    static const size_t max_eps = 10;
    ucp_stream_poll_ep_t poll_eps[max_eps];
    ssize_t count = ucp_stream_worker_poll(receiver().worker(), poll_eps,
                                           max_eps, 0);
    EXPECT_EQ(0l, count) << "ucp_stream_worker_poll returned ep too early";

    /* create receiver side ep, tagged with random user data so the polled
     * entry can be matched against it below */
    ucp_ep_params_t recv_ep_param = get_ep_params();
    recv_ep_param.field_mask |= UCP_EP_PARAM_FIELD_USER_DATA;
    recv_ep_param.user_data   =
            reinterpret_cast<void*>(static_cast<uintptr_t>(ucs::rand()));
    receiver().connect(&sender(), recv_ep_param);

    /* expect ep to be ready */
    ucs_time_t deadline = ucs_get_time() +
                          (ucs_time_from_sec(10.0) *
                           ucs::test_time_multiplier());
    do {
        progress();
        count = ucp_stream_worker_poll(receiver().worker(), poll_eps, max_eps,
                                       0);
    } while ((count == 0) && (ucs_get_time() < deadline));

    /* fatal assertion: poll_eps[0] is only initialized when the poll
     * actually returned an entry, so do not read it otherwise */
    ASSERT_EQ(1l, count);
    EXPECT_EQ(recv_ep_param.user_data, poll_eps[0].user_data);
    EXPECT_EQ(receiver().ep(0), poll_eps[0].ep);

    /* expect data to be received */
    uint64_t recv_data   = 0;
    size_t   recv_length = 0;
    void *rreq = ucp_stream_recv_nb(receiver().ep(), &recv_data, 1,
                                    ucp_dt_make_contig(sizeof(uint64_t)),
                                    ucp_recv_cb, &recv_length, 0);
    ASSERT_UCS_PTR_OK(rreq);
    if (rreq != NULL) {
        /* not completed immediately: the length comes from the request */
        recv_length = wait_stream_recv(rreq);
    }

    EXPECT_EQ(sizeof(uint64_t), recv_length);
    EXPECT_EQ(send_data, recv_data);
}
161
162 UCP_INSTANTIATE_TEST_CASE(test_ucp_stream_onesided)
163
164 class test_ucp_stream : public test_ucp_stream_base
165 {
166 public:
init()167 virtual void init() {
168 ucp_test::init();
169
170 sender().connect(&receiver(), get_ep_params());
171 if (!is_loopback()) {
172 receiver().connect(&sender(), get_ep_params());
173 }
174 }
175
176 protected:
177 void do_send_recv_data_test(ucp_datatype_t datatype);
178 template <typename T, unsigned recv_flags>
179 void do_send_recv_test(ucp_datatype_t datatype);
180 template <typename T, unsigned recv_flags>
181 void do_send_exp_recv_test(ucp_datatype_t datatype);
182 void do_send_recv_data_recv_test(ucp_datatype_t datatype);
183
184 /* for self-validation of generic datatype
185 * NOTE: it's tested only with byte array data since it's recv completion
186 * granularity without UCP_RECV_FLAG_WAITALL flag */
187 std::vector<uint8_t> context;
188 };
189
do_send_recv_data_test(ucp_datatype_t datatype)190 void test_ucp_stream::do_send_recv_data_test(ucp_datatype_t datatype)
191 {
192 size_t ssize = 0; /* total send size in bytes */
193 std::vector<char> sbuf(16 * UCS_MBYTE, 's');
194 std::vector<char> check_pattern;
195 ucs_status_ptr_t sstatus;
196
197 /* send all msg sizes*/
198 for (size_t i = 3; i < sbuf.size();
199 i *= (2 * ucs::test_time_multiplier())) {
200 if (UCP_DT_IS_GENERIC(datatype)) {
201 for (size_t j = 0; j < i; ++j) {
202 check_pattern.push_back(char(j));
203 }
204 } else {
205 ucs::fill_random(sbuf, i);
206 check_pattern.insert(check_pattern.end(), sbuf.begin(),
207 sbuf.begin() + i);
208 }
209 ucp::data_type_desc_t dt_desc(datatype, sbuf.data(), i);
210 sstatus = stream_send_nb(dt_desc);
211 EXPECT_FALSE(UCS_PTR_IS_ERR(sstatus));
212 wait(sstatus);
213 ssize += i;
214 }
215
216 std::vector<char> rbuf(ssize, 'r');
217 size_t roffset = 0;
218 ucs_status_ptr_t rdata;
219 size_t length;
220 do {
221 progress();
222 rdata = ucp_stream_recv_data_nb(receiver().ep(), &length);
223 if (rdata == NULL) {
224 continue;
225 }
226
227 memcpy(&rbuf[roffset], rdata, length);
228 roffset += length;
229 ucp_stream_data_release(receiver().ep(), rdata);
230 } while (roffset < ssize);
231
232 EXPECT_EQ(roffset, ssize);
233 EXPECT_EQ(check_pattern, rbuf);
234 }
235
236 template <typename T, unsigned recv_flags>
do_send_recv_test(ucp_datatype_t datatype)237 void test_ucp_stream::do_send_recv_test(ucp_datatype_t datatype)
238 {
239 const size_t dt_elem_size = UCP_DT_IS_CONTIG(datatype) ?
240 ucp_contig_dt_elem_size(datatype) : 1;
241 size_t ssize = 0; /* total send size */
242 std::vector<char> sbuf(16 * UCS_MBYTE, 's');
243 ucs_status_ptr_t sstatus;
244 std::vector<char> check_pattern;
245
246 /* send all msg sizes in bytes*/
247 for (size_t i = 3; i < sbuf.size(); i *= 2) {
248 ucp_datatype_t dt;
249 if (UCP_DT_IS_GENERIC(datatype)) {
250 dt = datatype;
251 for (size_t j = 0; j < i; ++j) {
252 context.push_back(uint8_t(j));
253 }
254 } else {
255 dt = DATATYPE;
256 ucs::fill_random(sbuf, i);
257 check_pattern.insert(check_pattern.end(), sbuf.begin(),
258 sbuf.begin() + i);
259 }
260 ucp::data_type_desc_t dt_desc(dt, sbuf.data(), i);
261 sstatus = stream_send_nb(dt_desc);
262 EXPECT_FALSE(UCS_PTR_IS_ERR(sstatus));
263 wait(sstatus);
264 ssize += i;
265 }
266
267 size_t align_tail = UCP_DT_IS_GENERIC(datatype) ? 0 :
268 (dt_elem_size - ssize % dt_elem_size);
269 if (align_tail != 0) {
270 ucs::fill_random(sbuf, align_tail);
271 check_pattern.insert(check_pattern.end(), sbuf.begin(), sbuf.begin() + align_tail);
272 ucp::data_type_desc_t dt_desc(ucp_dt_make_contig(align_tail),
273 sbuf.data(), align_tail);
274 sstatus = stream_send_nb(dt_desc);
275 EXPECT_FALSE(UCS_PTR_IS_ERR(sstatus));
276 wait(sstatus);
277 ssize += align_tail;
278 }
279
280 EXPECT_EQ(size_t(0), (ssize % dt_elem_size));
281
282 std::vector<T> rbuf(ssize / dt_elem_size, 'r');
283 size_t roffset = 0;
284 size_t counter = 0;
285 do {
286 ucp::data_type_desc_t dt_desc(datatype, &rbuf[roffset / dt_elem_size],
287 ssize - roffset);
288
289 size_t length;
290 void *rreq = ucp_stream_recv_nb(receiver().ep(), dt_desc.buf(),
291 dt_desc.count(), dt_desc.dt(),
292 ucp_recv_cb, &length, recv_flags);
293 ASSERT_TRUE(!UCS_PTR_IS_ERR(rreq));
294 if (UCS_PTR_IS_PTR(rreq)) {
295 length = wait_stream_recv(rreq);
296 }
297 EXPECT_EQ(size_t(0), length % dt_elem_size);
298 roffset += length;
299 counter++;
300 } while (roffset < ssize);
301
302 /* waitall flag requires completion by single request */
303 if (recv_flags & UCP_STREAM_RECV_FLAG_WAITALL) {
304 EXPECT_EQ(size_t(1), counter);
305 }
306
307 EXPECT_EQ(roffset, ssize);
308 if (!UCP_DT_IS_GENERIC(datatype)) {
309 const T *check_ptr = reinterpret_cast<const T *>(check_pattern.data());
310 const size_t check_size = check_pattern.size() / dt_elem_size;
311 EXPECT_EQ(std::vector<T>(check_ptr, check_ptr + check_size), rbuf);
312 }
313 }
314
315 template <typename T, unsigned recv_flags>
do_send_exp_recv_test(ucp_datatype_t datatype)316 void test_ucp_stream::do_send_exp_recv_test(ucp_datatype_t datatype)
317 {
318 const size_t dt_elem_size = UCP_DT_IS_CONTIG(datatype) ?
319 ucp_contig_dt_elem_size(datatype) : 1;
320 const size_t msg_size = dt_elem_size * UCS_MBYTE;
321 const size_t n_msgs = 10;
322
323 std::vector<std::vector<T> > rbufs(n_msgs,
324 std::vector<T>(msg_size / dt_elem_size, 'r'));
325 std::vector<ucp::data_type_desc_t> dt_rdescs(n_msgs);
326 std::vector<void *> rreqs;
327
328 /* post recvs */
329 for (size_t i = 0; i < n_msgs; ++i) {
330 ucp::data_type_desc_t &rdesc = dt_rdescs[i].make(datatype, &rbufs[i][0],
331 msg_size);
332 size_t length;
333
334 void *rreq = ucp_stream_recv_nb(receiver().ep(), rdesc.buf(),
335 rdesc.count(), rdesc.dt(), ucp_recv_cb,
336 &length, recv_flags);
337 EXPECT_TRUE(UCS_PTR_IS_PTR(rreq));
338 rreqs.push_back(rreq);
339 }
340
341 std::vector<char> sbuf(msg_size, 's');
342 size_t scount = 0; /* total send size */
343 ucp::data_type_desc_t dt_desc(datatype, sbuf.data(), sbuf.size());
344
345 /* send all msgs */
346 for (size_t i = 0; i < n_msgs; ++i) {
347 void *sreq = stream_send_nb(dt_desc);
348 EXPECT_FALSE(UCS_PTR_IS_ERR(sreq));
349 wait(sreq);
350 scount += sbuf.size();
351 }
352
353 size_t rcount = 0;
354 for (size_t i = 0; i < rreqs.size(); ++i) {
355 size_t length = wait_stream_recv(rreqs[i]);
356 EXPECT_EQ(size_t(0), length % dt_elem_size);
357 rcount += length;
358 }
359
360 size_t counter = 0;
361 while (rcount < scount) {
362 size_t length = std::numeric_limits<size_t>::max();
363 ucs_status_ptr_t rreq;
364 rreq = ucp_stream_recv_nb(receiver().ep(), dt_rdescs[0].buf(),
365 dt_rdescs[0].count(), dt_rdescs[0].dt(),
366 ucp_recv_cb, &length, 0);
367 if (UCS_PTR_IS_PTR(rreq)) {
368 length = wait_stream_recv(rreq);
369 }
370 ASSERT_GT(length, 0ul);
371 ASSERT_LE(length, msg_size);
372 EXPECT_EQ(size_t(0), length % dt_elem_size);
373 rcount += length;
374 counter++;
375 }
376 EXPECT_EQ(scount, rcount);
377
378 /* waitall flag requires completion by single request */
379 if (recv_flags & UCP_STREAM_RECV_FLAG_WAITALL) {
380 EXPECT_EQ(size_t(0), counter);
381 }
382
383 /* double check, no data should be here */
384 while (progress());
385
386 size_t s;
387 void *p;
388 while ((p = ucp_stream_recv_data_nb(receiver().ep(), &s)) != NULL) {
389 rcount += s;
390 ucp_stream_data_release(receiver().ep(), p);
391 progress();
392 }
393 EXPECT_EQ(scount, rcount);
394 }
395
do_send_recv_data_recv_test(ucp_datatype_t datatype)396 void test_ucp_stream::do_send_recv_data_recv_test(ucp_datatype_t datatype)
397 {
398 const size_t dt_elem_size = UCP_DT_IS_CONTIG(datatype) ?
399 ucp_contig_dt_elem_size(datatype) : 1;
400 size_t ssize = 0; /* total send size */
401 size_t roffset = 0;
402 size_t send_i = dt_elem_size;
403 size_t recv_i = 0;
404 std::vector<char> sbuf(16 * UCS_MBYTE, 's');
405 ucs_status_ptr_t sstatus;
406 std::vector<char> check_pattern;
407 std::vector<char> rbuf;
408 ucs_status_ptr_t rdata;
409 size_t length;
410
411 do {
412 if (send_i < sbuf.size()) {
413 rbuf.resize(rbuf.size() + send_i, 'r');
414 ucs::fill_random(sbuf, send_i);
415 check_pattern.insert(check_pattern.end(), sbuf.begin(),
416 sbuf.begin() + send_i);
417 ucp::data_type_desc_t dt_desc(datatype, sbuf.data(), send_i);
418 sstatus = stream_send_nb(dt_desc);
419 EXPECT_FALSE(UCS_PTR_IS_ERR(sstatus));
420 wait(sstatus);
421 ssize += send_i;
422 send_i *= 2;
423 }
424
425 progress();
426
427 if ((++recv_i % 2) || ((ssize - roffset) < dt_elem_size)) {
428 rdata = ucp_stream_recv_data_nb(receiver().ep(), &length);
429 if (rdata == NULL) {
430 continue;
431 }
432
433 memcpy(&rbuf[roffset], rdata, length);
434 ucp_stream_data_release(receiver().ep(), rdata);
435 } else {
436 ucp::data_type_desc_t dt_desc(datatype, &rbuf[roffset], ssize - roffset);
437 void *rreq = ucp_stream_recv_nb(receiver().ep(), dt_desc.buf(),
438 dt_desc.count(), dt_desc.dt(),
439 ucp_recv_cb, &length, 0);
440 ASSERT_TRUE(!UCS_PTR_IS_ERR(rreq));
441 if (UCS_PTR_IS_PTR(rreq)) {
442 length = wait_stream_recv(rreq);
443 }
444 }
445 roffset += length;
446 } while (roffset < ssize);
447
448 EXPECT_EQ(roffset, ssize);
449 EXPECT_EQ(check_pattern, rbuf);
450 }
451
/* Send with DATATYPE and drain the receiver through
 * do_send_recv_data_test() */
UCS_TEST_P(test_ucp_stream, send_recv_data) {
    const ucp_datatype_t send_dt = DATATYPE;

    do_send_recv_data_test(send_dt);
}
455
/* Send with DATATYPE_IOV and drain the receiver through
 * do_send_recv_data_test() */
UCS_TEST_P(test_ucp_stream, send_iov_recv_data) {
    const ucp_datatype_t send_dt = DATATYPE_IOV;

    do_send_recv_data_test(send_dt);
}
459
/* Send with a generic datatype built from ucp::test_dt_uint8_ops (NULL
 * context) and drain the receiver through do_send_recv_data_test() */
UCS_TEST_P(test_ucp_stream, send_generic_recv_data) {
    ucp_datatype_t     generic_dt;
    const ucs_status_t create_status =
            ucp_dt_create_generic(&ucp::test_dt_uint8_ops, NULL, &generic_dt);

    ASSERT_UCS_OK(create_status);

    do_send_recv_data_test(generic_dt);
    ucp_dt_destroy(generic_dt);
}
469
/* Send, then receive as contiguous uint8_t elements, first without flags and
 * then with UCP_STREAM_RECV_FLAG_WAITALL */
UCS_TEST_P(test_ucp_stream, send_recv_8) {
    const ucp_datatype_t dt_u8 = ucp_dt_make_contig(sizeof(uint8_t));

    do_send_recv_test<uint8_t, 0>(dt_u8);
    do_send_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt_u8);
}
476
/* Send, then receive as contiguous uint16_t elements, first without flags
 * and then with UCP_STREAM_RECV_FLAG_WAITALL */
UCS_TEST_P(test_ucp_stream, send_recv_16) {
    const ucp_datatype_t dt_u16 = ucp_dt_make_contig(sizeof(uint16_t));

    do_send_recv_test<uint16_t, 0>(dt_u16);
    do_send_recv_test<uint16_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt_u16);
}
483
/* Send, then receive as contiguous uint32_t elements, first without flags
 * and then with UCP_STREAM_RECV_FLAG_WAITALL */
UCS_TEST_P(test_ucp_stream, send_recv_32) {
    const ucp_datatype_t dt_u32 = ucp_dt_make_contig(sizeof(uint32_t));

    do_send_recv_test<uint32_t, 0>(dt_u32);
    do_send_recv_test<uint32_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt_u32);
}
490
/* Send, then receive as contiguous uint64_t elements, first without flags
 * and then with UCP_STREAM_RECV_FLAG_WAITALL */
UCS_TEST_P(test_ucp_stream, send_recv_64) {
    const ucp_datatype_t dt_u64 = ucp_dt_make_contig(sizeof(uint64_t));

    do_send_recv_test<uint64_t, 0>(dt_u64);
    do_send_recv_test<uint64_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt_u64);
}
497
/* Send, then receive with DATATYPE_IOV into a uint8_t buffer, first without
 * flags and then with UCP_STREAM_RECV_FLAG_WAITALL */
UCS_TEST_P(test_ucp_stream, send_recv_iov) {
    const ucp_datatype_t dt_iov = DATATYPE_IOV;

    do_send_recv_test<uint8_t, 0>(dt_iov);
    do_send_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt_iov);
}
502
/* Send and receive with a generic datatype built from
 * ucp::test_dt_uint8_ops, whose context is the fixture's `context` vector;
 * only the UCP_STREAM_RECV_FLAG_WAITALL variant is run */
UCS_TEST_P(test_ucp_stream, send_recv_generic) {
    ucp_datatype_t     generic_dt;
    const ucs_status_t create_status =
            ucp_dt_create_generic(&ucp::test_dt_uint8_ops, &context,
                                  &generic_dt);

    ASSERT_UCS_OK(create_status);

    do_send_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(generic_dt);
    ucp_dt_destroy(generic_dt);
}
512
/* Pre-posted receives of contiguous uint8_t elements, first without flags
 * and then with UCP_STREAM_RECV_FLAG_WAITALL */
UCS_TEST_P(test_ucp_stream, send_exp_recv_8) {
    const ucp_datatype_t dt_u8 = ucp_dt_make_contig(sizeof(uint8_t));

    do_send_exp_recv_test<uint8_t, 0>(dt_u8);
    do_send_exp_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt_u8);
}
519
/* Pre-posted receives of contiguous uint16_t elements, first without flags
 * and then with UCP_STREAM_RECV_FLAG_WAITALL */
UCS_TEST_P(test_ucp_stream, send_exp_recv_16) {
    const ucp_datatype_t dt_u16 = ucp_dt_make_contig(sizeof(uint16_t));

    do_send_exp_recv_test<uint16_t, 0>(dt_u16);
    do_send_exp_recv_test<uint16_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt_u16);
}
526
/* Pre-posted receives of contiguous uint32_t elements, first without flags
 * and then with UCP_STREAM_RECV_FLAG_WAITALL */
UCS_TEST_P(test_ucp_stream, send_exp_recv_32) {
    const ucp_datatype_t dt_u32 = ucp_dt_make_contig(sizeof(uint32_t));

    do_send_exp_recv_test<uint32_t, 0>(dt_u32);
    do_send_exp_recv_test<uint32_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt_u32);
}
533
/* Pre-posted receives of contiguous uint64_t elements, first without flags
 * and then with UCP_STREAM_RECV_FLAG_WAITALL */
UCS_TEST_P(test_ucp_stream, send_exp_recv_64) {
    const ucp_datatype_t dt_u64 = ucp_dt_make_contig(sizeof(uint64_t));

    do_send_exp_recv_test<uint64_t, 0>(dt_u64);
    do_send_exp_recv_test<uint64_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt_u64);
}
540
/* Pre-posted receives with DATATYPE_IOV into uint8_t buffers, first without
 * flags and then with UCP_STREAM_RECV_FLAG_WAITALL */
UCS_TEST_P(test_ucp_stream, send_exp_recv_iov) {
    const ucp_datatype_t dt_iov = DATATYPE_IOV;

    do_send_exp_recv_test<uint8_t, 0>(dt_iov);
    do_send_exp_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(dt_iov);
}
545
/* Mixed recv_data/recv test with contiguous uint8_t elements */
UCS_TEST_P(test_ucp_stream, send_recv_data_recv_8) {
    const ucp_datatype_t dt_u8 = ucp_dt_make_contig(sizeof(uint8_t));

    do_send_recv_data_recv_test(dt_u8);
}
549
/* Mixed recv_data/recv test with contiguous uint16_t elements */
UCS_TEST_P(test_ucp_stream, send_recv_data_recv_16) {
    const ucp_datatype_t dt_u16 = ucp_dt_make_contig(sizeof(uint16_t));

    do_send_recv_data_recv_test(dt_u16);
}
553
/* Mixed recv_data/recv test with contiguous uint32_t elements */
UCS_TEST_P(test_ucp_stream, send_recv_data_recv_32) {
    const ucp_datatype_t dt_u32 = ucp_dt_make_contig(sizeof(uint32_t));

    do_send_recv_data_recv_test(dt_u32);
}
557
/* Mixed recv_data/recv test with contiguous uint64_t elements */
UCS_TEST_P(test_ucp_stream, send_recv_data_recv_64) {
    const ucp_datatype_t dt_u64 = ucp_dt_make_contig(sizeof(uint64_t));

    do_send_recv_data_recv_test(dt_u64);
}
561
/* Mixed recv_data/recv test with DATATYPE_IOV */
UCS_TEST_P(test_ucp_stream, send_recv_data_recv_iov) {
    const ucp_datatype_t dt_iov = DATATYPE_IOV;

    do_send_recv_data_recv_test(dt_iov);
}
565
/* Send IOV lists in which every second entry is empty (so the list ends with
 * a zero-length entry) for a range of total sizes, and check that all the
 * non-empty bytes arrive at the receiver */
UCS_TEST_P(test_ucp_stream, send_zero_ending_iov_recv_data) {
    const size_t min_size         = UCS_KBYTE;
    const size_t max_size         = min_size * 64;
    const size_t iov_num          = 8; /* must be divisible by 4 without a
                                        * remainder, caught on mlx5 based TLs
                                        * where max_iov = 3 for zcopy multi
                                        * protocol, where every posting
                                        * includes:
                                        * 1 header + 2 nonempty IOVs */
    const size_t iov_num_nonempty = iov_num / 2;

    std::vector<uint8_t> buf(max_size * 2);
    ucs::fill_random(buf, buf.size());
    std::vector<ucp_dt_iov_t> v(iov_num);

    for (size_t size = min_size; size < max_size; ++size) {
        const size_t chunk_len = size / iov_num_nonempty;
        size_t       slen      = 0;

        /* even entries carry data, odd entries are empty */
        for (size_t j = 0; j < iov_num; ++j) {
            if ((j % 2) != 0) {
                v[j].buffer = NULL;
                v[j].length = 0;
                continue;
            }

            v[j].buffer = &buf[j * size / iov_num_nonempty];
            v[j].length = chunk_len;
            slen       += v[j].length;
        }

        void *sreq = ucp_stream_send_nb(sender().ep(), &v[0], iov_num,
                                        DATATYPE_IOV, ucp_send_cb, 0);

        /* drain the receiver until every sent byte has been seen */
        for (size_t rlen = 0; rlen < slen;) {
            progress();

            size_t length;
            void *rdata = ucp_stream_recv_data_nb(receiver().ep(), &length);
            EXPECT_FALSE(UCS_PTR_IS_ERR(rdata));
            if (rdata == NULL) {
                continue;
            }

            rlen += length;
            ucp_stream_data_release(receiver().ep(), rdata);
        }

        wait(sreq);
    }
}
611
612 UCP_INSTANTIATE_TEST_CASE(test_ucp_stream)
613
614 class test_ucp_stream_many2one : public test_ucp_stream_base {
615 protected:
616 struct request_wrapper_t {
request_wrapper_ttest_ucp_stream_many2one::request_wrapper_t617 request_wrapper_t(void *request, ucp::data_type_desc_t *dt_desc)
618 : m_req(request), m_dt_desc(dt_desc) {}
619
620 void *m_req;
621 ucp::data_type_desc_t *m_dt_desc;
622 };
623
624 public:
test_ucp_stream_many2one()625 test_ucp_stream_many2one() : m_receiver_idx(3), m_nsenders(3) {
626 m_recv_data.resize(m_nsenders);
627 }
628
get_ctx_params()629 static ucp_params_t get_ctx_params() {
630 return test_ucp_stream::get_ctx_params();
631 }
632
633 virtual void init();
ucp_send_cb(void * request,ucs_status_t status)634 static void ucp_send_cb(void *request, ucs_status_t status) {}
ucp_recv_cb(void * request,ucs_status_t status,size_t length)635 static void ucp_recv_cb(void *request, ucs_status_t status, size_t length) {}
636
637 void do_send_worker_poll_test(ucp_datatype_t dt);
638 void do_send_recv_test(ucp_datatype_t dt);
639
640 protected:
641 static void erase_completed_reqs(std::vector<request_wrapper_t> &reqs);
642 ucs_status_ptr_t stream_send_nb(size_t sender_idx,
643 const ucp::data_type_desc_t& dt_desc);
644 size_t send_all_nb(ucp_datatype_t datatype, size_t n_iter,
645 std::vector<request_wrapper_t> &sreqs);
646 size_t send_all(ucp_datatype_t datatype, size_t n_iter);
647 void check_no_data();
648 std::set<ucp_ep_h> check_no_data(entity &e);
649 void check_recv_data(size_t n_iter, ucp_datatype_t dt);
650
651 std::vector<std::string> m_msgs;
652 std::vector<std::vector<char> > m_recv_data;
653 const size_t m_receiver_idx;
654 const size_t m_nsenders;
655 };
656
init()657 void test_ucp_stream_many2one::init()
658 {
659 if (is_self()) {
660 UCS_TEST_SKIP_R("self");
661 }
662
663 /* Skip entities creation */
664 test_base::init();
665
666 for (size_t i = 0; i < m_nsenders + 1; ++i) {
667 create_entity();
668 }
669
670 for (size_t i = 0; i < m_nsenders; ++i) {
671 e(i).connect(&e(m_receiver_idx), get_ep_params(), i);
672
673 ucp_ep_params_t recv_ep_param = get_ep_params();
674 recv_ep_param.field_mask |= UCP_EP_PARAM_FIELD_USER_DATA;
675 recv_ep_param.user_data = (void *)uintptr_t(i);
676 e(m_receiver_idx).connect(&e(i), recv_ep_param, i);
677 }
678
679 for (size_t i = 0; i < m_nsenders; ++i) {
680 m_msgs.push_back(std::string("sender_") + ucs::to_string(i));
681 }
682 }
683
do_send_worker_poll_test(ucp_datatype_t dt)684 void test_ucp_stream_many2one::do_send_worker_poll_test(ucp_datatype_t dt)
685 {
686 const size_t niter = 2018;
687 std::vector<request_wrapper_t> sreqs;
688 size_t total_len;
689
690 total_len = send_all_nb(dt, niter, sreqs);
691
692 /* Recv and progress all data */
693 do {
694 ssize_t count;
695 do {
696 const size_t max_eps = 10;
697 ucp_stream_poll_ep_t poll_eps[max_eps];
698 progress();
699 count = ucp_stream_worker_poll(e(m_receiver_idx).worker(),
700 poll_eps, max_eps, 0);
701 EXPECT_LE(0, count);
702
703 for (ssize_t i = 0; i < count; ++i) {
704 char *rdata;
705 size_t length;
706 while ((rdata = (char *)ucp_stream_recv_data_nb(poll_eps[i].ep,
707 &length)) != NULL) {
708 ASSERT_FALSE(UCS_PTR_IS_ERR(rdata));
709 size_t senser_idx = uintptr_t(poll_eps[i].user_data);
710 std::vector<char> &dst = m_recv_data[senser_idx];
711 dst.insert(dst.end(), rdata, rdata + length);
712 total_len -= length;
713 ucp_stream_data_release(poll_eps[i].ep, rdata);
714 }
715 }
716 } while (count > 0);
717
718 erase_completed_reqs(sreqs);
719 } while (!sreqs.empty() || (total_len != 0));
720
721 check_no_data();
722 check_recv_data(niter, dt);
723 }
724
do_send_recv_test(ucp_datatype_t dt)725 void test_ucp_stream_many2one::do_send_recv_test(ucp_datatype_t dt)
726 {
727 const size_t niter = 2018;
728 std::vector<size_t> roffsets(m_nsenders, 0);
729 std::vector<ucp::data_type_desc_t> dt_rdescs(m_nsenders);
730 std::vector<std::pair<size_t, request_wrapper_t> > rreqs;
731 std::vector<request_wrapper_t> sreqs;
732 size_t total_sdata;
733
734 ASSERT_FALSE(m_msgs.empty());
735
736 /* Do preposts */
737 for (size_t i = 0; i < m_nsenders; ++i) {
738 m_recv_data[i].resize(m_msgs[i].length() * niter + 1);
739 ucp::data_type_desc_t &rdesc = dt_rdescs[i].make(dt,
740 &m_recv_data[i][roffsets[i]],
741 m_recv_data[i].size());
742 size_t length;
743 void *rreq = ucp_stream_recv_nb(e(m_receiver_idx).ep(0, i),
744 rdesc.buf(), rdesc.count(), rdesc.dt(),
745 ucp_recv_cb, &length, 0);
746 EXPECT_TRUE(UCS_PTR_IS_PTR(rreq));
747 rreqs.push_back(std::make_pair(i, request_wrapper_t(rreq, &rdesc)));
748 }
749
750 total_sdata = send_all_nb(dt, niter, sreqs);
751
752 /* Recv and progress all the rest of data */
753 do {
754 ssize_t count;
755 /* wait rreqs */
756 for (size_t i = 0; i < rreqs.size(); ++i) {
757 roffsets[rreqs[i].first] += wait_stream_recv(rreqs[i].second.m_req);
758 }
759 rreqs.clear();
760 progress();
761
762 const size_t max_eps = 10;
763 ucp_stream_poll_ep_t poll_eps[max_eps];
764 count = ucp_stream_worker_poll(e(m_receiver_idx).worker(),
765 poll_eps, max_eps, 0);
766 EXPECT_LE(0, count);
767 EXPECT_LE(size_t(count), m_nsenders);
768
769 for (ssize_t i = 0; i < count; ++i) {
770 bool again = true;
771 while (again) {
772 size_t sender_idx = uintptr_t(poll_eps[i].user_data);
773 size_t &roffset = roffsets[sender_idx];
774 ucp::data_type_desc_t &dt_desc =
775 dt_rdescs[sender_idx].forward_to(roffset);
776 EXPECT_TRUE(dt_desc.is_valid());
777 size_t length;
778 void *rreq = ucp_stream_recv_nb(poll_eps[i].ep,
779 dt_desc.buf(),
780 dt_desc.count(),
781 dt_desc.dt(),
782 ucp_recv_cb, &length, 0);
783 EXPECT_FALSE(UCS_PTR_IS_ERR(rreq));
784 if (rreq == NULL) {
785 EXPECT_LT(size_t(0), length);
786 roffset += length;
787 if (ssize_t(length) < dt_desc.buf_length()) {
788 continue; /* Need to drain the EP */
789 }
790 } else {
791 rreqs.push_back(std::make_pair(sender_idx,
792 request_wrapper_t(rreq,
793 &dt_desc)));
794 }
795 again = false;
796 }
797 }
798
799 erase_completed_reqs(sreqs);
800 } while (!rreqs.empty() || !sreqs.empty() ||
801 (total_sdata > std::accumulate(roffsets.begin(),
802 roffsets.end(), 0ul)));
803
804 EXPECT_EQ(total_sdata, std::accumulate(roffsets.begin(),
805 roffsets.end(), 0ul));
806 check_no_data();
807 check_recv_data(niter, dt);
808 }
809
810 ucs_status_ptr_t
stream_send_nb(size_t sender_idx,const ucp::data_type_desc_t & dt_desc)811 test_ucp_stream_many2one::stream_send_nb(size_t sender_idx,
812 const ucp::data_type_desc_t& dt_desc)
813 {
814 return ucp_stream_send_nb(m_entities.at(sender_idx).ep(), dt_desc.buf(),
815 dt_desc.count(), dt_desc.dt(), ucp_send_cb, 0);
816 }
817
818 size_t
send_all_nb(ucp_datatype_t datatype,size_t n_iter,std::vector<request_wrapper_t> & sreqs)819 test_ucp_stream_many2one::send_all_nb(ucp_datatype_t datatype, size_t n_iter,
820 std::vector<request_wrapper_t> &sreqs)
821 {
822 size_t total = 0;
823 /* Send many times in round robin */
824 for (size_t i = 0; i < n_iter; ++i) {
825 for (size_t sender_idx = 0; sender_idx < m_nsenders; ++sender_idx) {
826 const void *buf = m_msgs[sender_idx].c_str();
827 size_t len = m_msgs[sender_idx].length();
828 if (i == (n_iter - 1)) {
829 ++len;
830 }
831
832 ucp::data_type_desc_t *dt_desc = new ucp::data_type_desc_t(datatype,
833 buf,
834 len);
835 void *sreq = stream_send_nb(sender_idx, *dt_desc);
836 total += len;
837 if (UCS_PTR_IS_PTR(sreq)) {
838 sreqs.push_back(request_wrapper_t(sreq, dt_desc));
839 } else {
840 EXPECT_FALSE(UCS_PTR_IS_ERR(sreq));
841 delete dt_desc;
842 }
843 }
844 }
845
846 return total;
847 }
848
849 size_t
send_all(ucp_datatype_t datatype,size_t n_iter)850 test_ucp_stream_many2one::send_all(ucp_datatype_t datatype, size_t n_iter)
851 {
852 std::vector<request_wrapper_t> sreqs;
853 size_t total;
854
855 total = send_all_nb(datatype, n_iter, sreqs);
856 while (!sreqs.empty()) {
857 progress();
858 erase_completed_reqs(sreqs);
859 }
860
861 return total;
862 }
863
check_no_data()864 void test_ucp_stream_many2one::check_no_data()
865 {
866 std::set<ucp_ep_h> check;
867
868 for (size_t i = 0; i <= m_receiver_idx; ++i) {
869 std::set<ucp_ep_h> check_e = check_no_data(e(i));
870 check.insert(check_e.begin(), check_e.end());
871 }
872
873 EXPECT_EQ(size_t(0), check.size());
874 }
875
check_no_data(entity & e)876 std::set<ucp_ep_h> test_ucp_stream_many2one::check_no_data(entity &e)
877 {
878 const size_t max_eps = 10;
879 ucp_stream_poll_ep_t poll_eps[max_eps];
880 std::set<ucp_ep_h> ret;
881 std::list<ucp_ep_h> check_list;
882
883 while (progress());
884
885 ssize_t count = ucp_stream_worker_poll(m_entities.at(m_receiver_idx).worker(),
886 poll_eps, max_eps, 0);
887 EXPECT_GE(count, ssize_t(0));
888
889 for (ssize_t i = 0; i < count; ++i) {
890 ret.insert(poll_eps[i].ep);
891 }
892
893 for (int i = 0; i < e.get_num_workers(); ++i) {
894 for (int j = 0; j < e.get_num_eps(); ++j) {
895 check_list.push_back(e.ep(i, j));
896 }
897 }
898
899 std::list<ucp_ep_h>::const_iterator check_it = check_list.begin();
900 while (check_it != check_list.end()) {
901 EXPECT_EQ(ret.end(), ret.find(*check_it));
902 ++check_it;
903 }
904
905 return ret;
906 }
907
check_recv_data(size_t n_iter,ucp_datatype_t dt)908 void test_ucp_stream_many2one::check_recv_data(size_t n_iter, ucp_datatype_t dt)
909 {
910 for (size_t i = 0; i < m_nsenders; ++i) {
911 std::string test = std::string("sender_") + ucs::to_string(i);
912 const std::string str(&m_recv_data[i].front());
913 if (UCP_DT_IS_GENERIC(dt)) {
914 std::vector<char> test_gen;
915 for (size_t j = 0; j < test.length(); ++j) {
916 test_gen.push_back(char(j));
917 }
918 test_gen.push_back('\0');
919 test = std::string(test_gen.data());
920 }
921
922 size_t next = 0;
923 for (size_t j = 0; j < n_iter; ++j) {
924 size_t match = str.find(test, next);
925 EXPECT_NE(std::string::npos, match) << "failed on sender " << i
926 << " iteration " << j;
927 if (match == std::string::npos) {
928 break;
929 }
930 EXPECT_EQ(next, match);
931 next += test.length();
932 }
933 EXPECT_EQ(next, str.length()); /* nothing more */
934 }
935 }
936
937 void
erase_completed_reqs(std::vector<request_wrapper_t> & reqs)938 test_ucp_stream_many2one::erase_completed_reqs(std::vector<request_wrapper_t> &reqs)
939 {
940 std::vector<request_wrapper_t>::iterator i = reqs.begin();
941
942 while (i != reqs.end()) {
943 ucs_status_t status = ucp_request_check_status(i->m_req);
944 if (status != UCS_INPROGRESS) {
945 EXPECT_EQ(UCS_OK, status);
946 ucp_request_free(i->m_req);
947 delete i->m_dt_desc;
948 i = reqs.erase(i);
949 } else {
950 ++i;
951 }
952 }
953 }
954
/* Send data, destroy the connection between sender 0 and the receiver
 * without receiving anything, check that the destroyed endpoint is not
 * reported as having data, then reconnect and send again */
UCS_TEST_P(test_ucp_stream_many2one, drop_data) {
    send_all(DATATYPE, 10);

    ASSERT_EQ(m_receiver_idx, m_nsenders);
    for (size_t i = 0; i <= m_receiver_idx; ++i) {
        flush_worker(e(i));
    }

    /* destroy 1 connection: each endpoint is destroyed together with the
     * entity that owns it (NOTE(review): presumably ep_destructor() drives
     * that entity while the endpoint closes - confirm) */
    entity::ep_destructor(m_entities.at(0).ep(),
                          &m_entities.at(0));
    entity::ep_destructor(m_entities.at(m_receiver_idx).ep(),
                          &m_entities.at(m_receiver_idx));
    m_entities.at(0).revoke_ep();
    m_entities.at(m_receiver_idx).revoke_ep(0, 0);

    /* wait for 1-st byte on the last EP to be sure the network packets have
       been arrived */
    uint8_t  check;
    size_t   check_length;
    ucp_ep_h last_ep = m_entities.at(m_receiver_idx).ep(0, m_nsenders - 1);
    void *check_req  = ucp_stream_recv_nb(last_ep, &check, 1, DATATYPE,
                                          ucp_recv_cb, &check_length, 0);
    EXPECT_FALSE(UCS_PTR_IS_ERR(check_req));
    if (UCS_PTR_IS_PTR(check_req)) {
        wait_stream_recv(check_req);
    }

    /* data from disconnected EP should be dropped */
    std::set<ucp_ep_h> others = check_no_data(m_entities.at(0));
    /* since ordering between EPs is not guaranteed, some data may be still in
     * the network or buffered by transport */
    EXPECT_LE(others.size(), m_nsenders - 1);

    /* reconnect */
    m_entities.at(0).connect(&m_entities.at(m_receiver_idx), get_ep_params(), 0);
    ucp_ep_params_t recv_ep_param = get_ep_params();
    recv_ep_param.field_mask |= UCP_EP_PARAM_FIELD_USER_DATA;
    recv_ep_param.user_data   = (void *)uintptr_t(0xdeadbeef);
    e(m_receiver_idx).connect(&e(0), recv_ep_param, 0);

    /* send again */
    send_all(DATATYPE, 10);

    for (size_t i = 0; i <= m_receiver_idx; ++i) {
        flush_worker(e(i));
    }

    /* Need to poll out all incoming data from transport layer, see PR #2048 */
    while (progress() > 0);
}
1006
/* Worker-poll based receive with DATATYPE */
UCS_TEST_P(test_ucp_stream_many2one, send_worker_poll) {
    const ucp_datatype_t send_dt = DATATYPE;

    do_send_worker_poll_test(send_dt);
}
1010
/* Worker-poll based receive with DATATYPE_IOV */
UCS_TEST_P(test_ucp_stream_many2one, send_worker_poll_iov) {
    const ucp_datatype_t send_dt = DATATYPE_IOV;

    do_send_worker_poll_test(send_dt);
}
1014
/* Worker-poll based receive with a generic datatype built from
 * ucp::test_dt_uint8_ops (NULL context) */
UCS_TEST_P(test_ucp_stream_many2one, send_worker_poll_generic) {
    ucp_datatype_t     generic_dt;
    const ucs_status_t create_status =
            ucp_dt_create_generic(&ucp::test_dt_uint8_ops, NULL, &generic_dt);

    ASSERT_UCS_OK(create_status);

    do_send_worker_poll_test(generic_dt);
    ucp_dt_destroy(generic_dt);
}
1024
/* ucp_stream_recv_nb() based many-to-one receive with DATATYPE */
UCS_TEST_P(test_ucp_stream_many2one, send_recv_nb) {
    const ucp_datatype_t test_dt = DATATYPE;

    do_send_recv_test(test_dt);
}
1028
/* ucp_stream_recv_nb() based many-to-one receive with DATATYPE_IOV */
UCS_TEST_P(test_ucp_stream_many2one, send_recv_nb_iov) {
    const ucp_datatype_t test_dt = DATATYPE_IOV;

    do_send_recv_test(test_dt);
}
1032
/* ucp_stream_recv_nb() based many-to-one receive with a generic datatype
 * built from ucp::test_dt_uint8_ops (NULL context) */
UCS_TEST_P(test_ucp_stream_many2one, send_recv_nb_generic) {
    ucp_datatype_t     generic_dt;
    const ucs_status_t create_status =
            ucp_dt_create_generic(&ucp::test_dt_uint8_ops, NULL, &generic_dt);

    ASSERT_UCS_OK(create_status);

    do_send_recv_test(generic_dt);
    ucp_dt_destroy(generic_dt);
}
1042
1043 UCP_INSTANTIATE_TEST_CASE(test_ucp_stream_many2one)
1044