1 /**
2 * Copyright (C) Mellanox Technologies Ltd. 2017.  ALL RIGHTS RESERVED.
3 *
4 * See file LICENSE for terms.
5 */
6 
7 #include <list>
8 #include <numeric>
9 #include <set>
10 #include <vector>
11 
12 #include "ucp_datatype.h"
13 #include "ucp_test.h"
14 
15 
16 class test_ucp_stream_base : public ucp_test {
17 public:
get_ctx_params()18     static ucp_params_t get_ctx_params() {
19         ucp_params_t params = ucp_test::get_ctx_params();
20         params.field_mask  |= UCP_PARAM_FIELD_FEATURES;
21         params.features     = UCP_FEATURE_STREAM;
22         return params;
23     }
24 
ucp_send_cb(void * request,ucs_status_t status)25     static void ucp_send_cb(void *request, ucs_status_t status) {}
ucp_recv_cb(void * request,ucs_status_t status,size_t length)26     static void ucp_recv_cb(void *request, ucs_status_t status, size_t length) {}
27 
28     size_t wait_stream_recv(void *request);
29 
30 protected:
31     ucs_status_ptr_t stream_send_nb(const ucp::data_type_desc_t& dt_desc);
32 };
33 
wait_stream_recv(void * request)34 size_t test_ucp_stream_base::wait_stream_recv(void *request)
35 {
36     ucs_time_t deadline = ucs::get_deadline();
37     ucs_status_t status;
38     size_t       length;
39     do {
40         progress();
41         status = ucp_stream_recv_request_test(request, &length);
42     } while ((status == UCS_INPROGRESS) && (ucs_get_time() < deadline));
43     ASSERT_UCS_OK(status);
44     ucp_request_free(request);
45 
46     return length;
47 }
48 
49 ucs_status_ptr_t
stream_send_nb(const ucp::data_type_desc_t & dt_desc)50 test_ucp_stream_base::stream_send_nb(const ucp::data_type_desc_t& dt_desc)
51 {
52     return ucp_stream_send_nb(sender().ep(), dt_desc.buf(), dt_desc.count(),
53                               dt_desc.dt(), ucp_send_cb, 0);
54 }
55 
56 class test_ucp_stream_onesided : public test_ucp_stream_base {
57 public:
get_ep_params()58     ucp_ep_params_t get_ep_params() {
59         ucp_ep_params_t params = test_ucp_stream_base::get_ep_params();
60         params.field_mask |= UCP_EP_PARAM_FIELD_FLAGS;
61         params.flags      |= UCP_EP_PARAMS_FLAGS_NO_LOOPBACK;
62         return params;
63     }
64 };
65 
/* Post a WAITALL receive on a receiver-only endpoint (the sender never
 * connects back, so no data can arrive) and check that disconnecting the
 * receiver cancels the pending request. */
UCS_TEST_P(test_ucp_stream_onesided, recv_not_connected_ep_cleanup) {
    receiver().connect(&sender(), get_ep_params());

    uint64_t recv_data = 0;
    size_t length;
    void *rreq = ucp_stream_recv_nb(receiver().ep(), &recv_data, 1,
                                    ucp_dt_make_contig(sizeof(uint64_t)),
                                    ucp_recv_cb, &length,
                                    UCP_STREAM_RECV_FLAG_WAITALL);
    /* the request handle is dereferenced below, so a NULL/error result must
     * abort the test instead of crashing it */
    ASSERT_TRUE(UCS_PTR_IS_PTR(rreq));
    EXPECT_EQ(UCS_INPROGRESS, ucp_request_check_status(rreq));
    disconnect(receiver());
    EXPECT_EQ(UCS_ERR_CANCELED, ucp_request_check_status(rreq));
    ucp_request_free(rreq);
}
81 
/* Exchange one value over a connected endpoint pair, then post a receive
 * that cannot be satisfied and check that disconnecting both sides cancels
 * it. */
UCS_TEST_P(test_ucp_stream_onesided, recv_connected_ep_cleanup) {
    skip_loopback();
    sender().connect(&receiver(), get_ep_params());
    receiver().connect(&sender(), get_ep_params());

    uint64_t send_data = ucs::rand();
    uint64_t recv_data = 0;
    ucp_datatype_t dt  = ucp_dt_make_contig(sizeof(uint64_t));

    ucp::data_type_desc_t send_dt_desc(dt, &send_data, sizeof(send_data));
    void *sreq = stream_send_nb(send_dt_desc);

    size_t recvd_length = 0;
    void *rreq = ucp_stream_recv_nb(receiver().ep(), &recv_data, 1, dt,
                                    ucp_recv_cb, &recvd_length,
                                    UCP_STREAM_RECV_FLAG_WAITALL);
    /* the receive may also complete immediately (NULL), in which case
     * recvd_length was filled in directly by ucp_stream_recv_nb() */
    ASSERT_UCS_PTR_OK(rreq);
    if (rreq != NULL) {
        recvd_length = wait_stream_recv(rreq);
    }

    EXPECT_EQ(sizeof(send_data), recvd_length);
    EXPECT_EQ(send_data, recv_data);
    wait(sreq);

    /* nothing else is sent, so this receive must stay pending until the
     * endpoints are destroyed; its handle is dereferenced below */
    rreq = ucp_stream_recv_nb(receiver().ep(), &recv_data, 1, dt, ucp_recv_cb,
                              &recvd_length, UCP_STREAM_RECV_FLAG_WAITALL);
    ASSERT_TRUE(UCS_PTR_IS_PTR(rreq));
    EXPECT_EQ(UCS_INPROGRESS, ucp_request_check_status(rreq));
    disconnect(sender());
    disconnect(receiver());
    EXPECT_EQ(UCS_ERR_CANCELED, ucp_request_check_status(rreq));
    ucp_request_free(rreq);
}
112 
/* Data sent before the receiver creates its endpoint must not be reported by
 * ucp_stream_worker_poll(); once the receiver endpoint exists, it must be
 * reported (with its user_data) and the data must be receivable. */
UCS_TEST_P(test_ucp_stream_onesided, send_recv_no_ep) {

    /* connect from sender side only and send */
    sender().connect(&receiver(), get_ep_params());
    uint64_t send_data = ucs::rand();
    ucp::data_type_desc_t dt_desc(ucp_dt_make_contig(sizeof(uint64_t)),
                                  &send_data, sizeof(send_data));
    void *sreq = stream_send_nb(dt_desc);
    wait(sreq);

    /* must not receive data before ep is created on receiver side */
    static const size_t max_eps = 10;
    ucp_stream_poll_ep_t poll_eps[max_eps];
    ssize_t count = ucp_stream_worker_poll(receiver().worker(), poll_eps,
                                           max_eps, 0);
    EXPECT_EQ(0l, count) << "ucp_stream_worker_poll returned ep too early";

    /* create receiver side ep, tagged with random user_data so that the
     * poll result can be matched against it */
    ucp_ep_params_t recv_ep_param = get_ep_params();
    recv_ep_param.field_mask |= UCP_EP_PARAM_FIELD_USER_DATA;
    recv_ep_param.user_data   = reinterpret_cast<void*>(static_cast<uintptr_t>(ucs::rand()));
    receiver().connect(&sender(), recv_ep_param);

    /* expect ep to be ready */
    ucs_time_t deadline = ucs_get_time() +
                          (ucs_time_from_sec(10.0) * ucs::test_time_multiplier());
    do {
        progress();
        count = ucp_stream_worker_poll(receiver().worker(), poll_eps, max_eps, 0);
    } while ((count == 0) && (ucs_get_time() < deadline));

    /* poll_eps[0] is only initialized if an ep was reported, so stop here
     * rather than reading garbage below */
    ASSERT_EQ(1l, count);
    EXPECT_EQ(recv_ep_param.user_data, poll_eps[0].user_data);
    EXPECT_EQ(receiver().ep(0), poll_eps[0].ep);

    /* expect data to be received */
    uint64_t recv_data = 0;
    size_t recv_length = 0;
    void *rreq = ucp_stream_recv_nb(receiver().ep(), &recv_data, 1,
                                    ucp_dt_make_contig(sizeof(uint64_t)),
                                    ucp_recv_cb, &recv_length, 0);
    ASSERT_UCS_PTR_OK(rreq);
    if (rreq != NULL) {
        recv_length = wait_stream_recv(rreq);
    }

    EXPECT_EQ(sizeof(uint64_t), recv_length);
    EXPECT_EQ(send_data, recv_data);
}
161 
/* Instantiate the parameterized test_ucp_stream_onesided test cases; the set
 * of parameters is defined by the UCP_INSTANTIATE_TEST_CASE macro (not
 * visible in this file section). */
UCP_INSTANTIATE_TEST_CASE(test_ucp_stream_onesided)
163 
164 class test_ucp_stream : public test_ucp_stream_base
165 {
166 public:
init()167     virtual void init() {
168         ucp_test::init();
169 
170         sender().connect(&receiver(), get_ep_params());
171         if (!is_loopback()) {
172             receiver().connect(&sender(), get_ep_params());
173         }
174     }
175 
176 protected:
177     void do_send_recv_data_test(ucp_datatype_t datatype);
178     template <typename T, unsigned recv_flags>
179     void do_send_recv_test(ucp_datatype_t datatype);
180     template <typename T, unsigned recv_flags>
181     void do_send_exp_recv_test(ucp_datatype_t datatype);
182     void do_send_recv_data_recv_test(ucp_datatype_t datatype);
183 
184     /* for self-validation of generic datatype
185      * NOTE: it's tested only with byte array data since it's recv completion
186      *       granularity without UCP_RECV_FLAG_WAITALL flag */
187     std::vector<uint8_t> context;
188 };
189 
do_send_recv_data_test(ucp_datatype_t datatype)190 void test_ucp_stream::do_send_recv_data_test(ucp_datatype_t datatype)
191 {
192     size_t            ssize = 0; /* total send size in bytes */
193     std::vector<char> sbuf(16 * UCS_MBYTE, 's');
194     std::vector<char> check_pattern;
195     ucs_status_ptr_t  sstatus;
196 
197     /* send all msg sizes*/
198     for (size_t i = 3; i < sbuf.size();
199          i *= (2 * ucs::test_time_multiplier())) {
200         if (UCP_DT_IS_GENERIC(datatype)) {
201             for (size_t j = 0; j < i; ++j) {
202                 check_pattern.push_back(char(j));
203             }
204         } else {
205             ucs::fill_random(sbuf, i);
206             check_pattern.insert(check_pattern.end(), sbuf.begin(),
207                                  sbuf.begin() + i);
208         }
209         ucp::data_type_desc_t dt_desc(datatype, sbuf.data(), i);
210         sstatus = stream_send_nb(dt_desc);
211         EXPECT_FALSE(UCS_PTR_IS_ERR(sstatus));
212         wait(sstatus);
213         ssize += i;
214     }
215 
216     std::vector<char> rbuf(ssize, 'r');
217     size_t            roffset = 0;
218     ucs_status_ptr_t  rdata;
219     size_t length;
220     do {
221         progress();
222         rdata = ucp_stream_recv_data_nb(receiver().ep(), &length);
223         if (rdata == NULL) {
224             continue;
225         }
226 
227         memcpy(&rbuf[roffset], rdata, length);
228         roffset += length;
229         ucp_stream_data_release(receiver().ep(), rdata);
230     } while (roffset < ssize);
231 
232     EXPECT_EQ(roffset, ssize);
233     EXPECT_EQ(check_pattern, rbuf);
234 }
235 
236 template <typename T, unsigned recv_flags>
do_send_recv_test(ucp_datatype_t datatype)237 void test_ucp_stream::do_send_recv_test(ucp_datatype_t datatype)
238 {
239     const size_t      dt_elem_size = UCP_DT_IS_CONTIG(datatype) ?
240                                      ucp_contig_dt_elem_size(datatype) : 1;
241     size_t            ssize        = 0; /* total send size */
242     std::vector<char> sbuf(16 * UCS_MBYTE, 's');
243     ucs_status_ptr_t  sstatus;
244     std::vector<char> check_pattern;
245 
246     /* send all msg sizes in bytes*/
247     for (size_t i = 3; i < sbuf.size(); i *= 2) {
248         ucp_datatype_t dt;
249         if (UCP_DT_IS_GENERIC(datatype)) {
250             dt = datatype;
251             for (size_t j = 0; j < i; ++j) {
252                 context.push_back(uint8_t(j));
253             }
254         } else {
255             dt = DATATYPE;
256             ucs::fill_random(sbuf, i);
257             check_pattern.insert(check_pattern.end(), sbuf.begin(),
258                                  sbuf.begin() + i);
259         }
260         ucp::data_type_desc_t dt_desc(dt, sbuf.data(), i);
261         sstatus = stream_send_nb(dt_desc);
262         EXPECT_FALSE(UCS_PTR_IS_ERR(sstatus));
263         wait(sstatus);
264         ssize += i;
265     }
266 
267     size_t align_tail = UCP_DT_IS_GENERIC(datatype) ? 0 :
268                         (dt_elem_size - ssize % dt_elem_size);
269     if (align_tail != 0) {
270         ucs::fill_random(sbuf, align_tail);
271         check_pattern.insert(check_pattern.end(), sbuf.begin(), sbuf.begin() + align_tail);
272         ucp::data_type_desc_t dt_desc(ucp_dt_make_contig(align_tail),
273                                       sbuf.data(), align_tail);
274         sstatus = stream_send_nb(dt_desc);
275         EXPECT_FALSE(UCS_PTR_IS_ERR(sstatus));
276         wait(sstatus);
277         ssize += align_tail;
278     }
279 
280     EXPECT_EQ(size_t(0), (ssize % dt_elem_size));
281 
282     std::vector<T> rbuf(ssize / dt_elem_size, 'r');
283     size_t         roffset = 0;
284     size_t         counter = 0;
285     do {
286         ucp::data_type_desc_t dt_desc(datatype, &rbuf[roffset / dt_elem_size],
287                                       ssize - roffset);
288 
289         size_t length;
290         void   *rreq = ucp_stream_recv_nb(receiver().ep(), dt_desc.buf(),
291                                           dt_desc.count(), dt_desc.dt(),
292                                           ucp_recv_cb, &length, recv_flags);
293         ASSERT_TRUE(!UCS_PTR_IS_ERR(rreq));
294         if (UCS_PTR_IS_PTR(rreq)) {
295             length = wait_stream_recv(rreq);
296         }
297         EXPECT_EQ(size_t(0), length % dt_elem_size);
298         roffset += length;
299         counter++;
300     } while (roffset < ssize);
301 
302     /* waitall flag requires completion by single request */
303     if (recv_flags & UCP_STREAM_RECV_FLAG_WAITALL) {
304         EXPECT_EQ(size_t(1), counter);
305     }
306 
307     EXPECT_EQ(roffset, ssize);
308     if (!UCP_DT_IS_GENERIC(datatype)) {
309         const T     *check_ptr  = reinterpret_cast<const T *>(check_pattern.data());
310         const size_t check_size = check_pattern.size() / dt_elem_size;
311         EXPECT_EQ(std::vector<T>(check_ptr, check_ptr + check_size), rbuf);
312     }
313 }
314 
315 template <typename T, unsigned recv_flags>
do_send_exp_recv_test(ucp_datatype_t datatype)316 void test_ucp_stream::do_send_exp_recv_test(ucp_datatype_t datatype)
317 {
318     const size_t dt_elem_size = UCP_DT_IS_CONTIG(datatype) ?
319                                 ucp_contig_dt_elem_size(datatype) : 1;
320     const size_t msg_size = dt_elem_size * UCS_MBYTE;
321     const size_t n_msgs   = 10;
322 
323     std::vector<std::vector<T> > rbufs(n_msgs,
324                                        std::vector<T>(msg_size / dt_elem_size, 'r'));
325     std::vector<ucp::data_type_desc_t> dt_rdescs(n_msgs);
326     std::vector<void *> rreqs;
327 
328     /* post recvs */
329     for (size_t i = 0; i < n_msgs; ++i) {
330         ucp::data_type_desc_t &rdesc = dt_rdescs[i].make(datatype, &rbufs[i][0],
331                                                          msg_size);
332         size_t length;
333 
334         void *rreq = ucp_stream_recv_nb(receiver().ep(), rdesc.buf(),
335                                         rdesc.count(), rdesc.dt(), ucp_recv_cb,
336                                         &length, recv_flags);
337         EXPECT_TRUE(UCS_PTR_IS_PTR(rreq));
338         rreqs.push_back(rreq);
339     }
340 
341     std::vector<char>     sbuf(msg_size, 's');
342     size_t                scount = 0; /* total send size */
343     ucp::data_type_desc_t dt_desc(datatype, sbuf.data(), sbuf.size());
344 
345     /* send all msgs */
346     for (size_t i = 0; i < n_msgs; ++i) {
347         void *sreq = stream_send_nb(dt_desc);
348         EXPECT_FALSE(UCS_PTR_IS_ERR(sreq));
349         wait(sreq);
350         scount += sbuf.size();
351     }
352 
353     size_t rcount = 0;
354     for (size_t i = 0; i < rreqs.size(); ++i) {
355         size_t length = wait_stream_recv(rreqs[i]);
356         EXPECT_EQ(size_t(0), length % dt_elem_size);
357         rcount += length;
358     }
359 
360     size_t counter = 0;
361     while (rcount < scount) {
362         size_t           length = std::numeric_limits<size_t>::max();
363         ucs_status_ptr_t rreq;
364         rreq = ucp_stream_recv_nb(receiver().ep(), dt_rdescs[0].buf(),
365                                   dt_rdescs[0].count(), dt_rdescs[0].dt(),
366                                   ucp_recv_cb, &length, 0);
367         if (UCS_PTR_IS_PTR(rreq)) {
368             length = wait_stream_recv(rreq);
369         }
370         ASSERT_GT(length, 0ul);
371         ASSERT_LE(length, msg_size);
372         EXPECT_EQ(size_t(0), length % dt_elem_size);
373         rcount += length;
374         counter++;
375     }
376     EXPECT_EQ(scount, rcount);
377 
378     /* waitall flag requires completion by single request */
379     if (recv_flags & UCP_STREAM_RECV_FLAG_WAITALL) {
380         EXPECT_EQ(size_t(0), counter);
381     }
382 
383     /* double check, no data should be here */
384     while (progress());
385 
386     size_t s;
387     void   *p;
388     while ((p = ucp_stream_recv_data_nb(receiver().ep(), &s)) != NULL) {
389         rcount += s;
390         ucp_stream_data_release(receiver().ep(), p);
391         progress();
392     }
393     EXPECT_EQ(scount, rcount);
394 }
395 
do_send_recv_data_recv_test(ucp_datatype_t datatype)396 void test_ucp_stream::do_send_recv_data_recv_test(ucp_datatype_t datatype)
397 {
398     const size_t dt_elem_size = UCP_DT_IS_CONTIG(datatype) ?
399                                 ucp_contig_dt_elem_size(datatype) : 1;
400     size_t            ssize   = 0; /* total send size */
401     size_t            roffset = 0;
402     size_t            send_i  = dt_elem_size;
403     size_t            recv_i  = 0;
404     std::vector<char> sbuf(16 * UCS_MBYTE, 's');
405     ucs_status_ptr_t  sstatus;
406     std::vector<char> check_pattern;
407     std::vector<char> rbuf;
408     ucs_status_ptr_t  rdata;
409     size_t            length;
410 
411     do {
412         if (send_i < sbuf.size()) {
413             rbuf.resize(rbuf.size() + send_i, 'r');
414             ucs::fill_random(sbuf, send_i);
415             check_pattern.insert(check_pattern.end(), sbuf.begin(),
416                                  sbuf.begin() + send_i);
417             ucp::data_type_desc_t dt_desc(datatype, sbuf.data(), send_i);
418             sstatus = stream_send_nb(dt_desc);
419             EXPECT_FALSE(UCS_PTR_IS_ERR(sstatus));
420             wait(sstatus);
421             ssize += send_i;
422             send_i *= 2;
423         }
424 
425         progress();
426 
427         if ((++recv_i % 2) || ((ssize - roffset) < dt_elem_size)) {
428             rdata = ucp_stream_recv_data_nb(receiver().ep(), &length);
429             if (rdata == NULL) {
430                 continue;
431             }
432 
433             memcpy(&rbuf[roffset], rdata, length);
434             ucp_stream_data_release(receiver().ep(), rdata);
435         } else {
436             ucp::data_type_desc_t dt_desc(datatype, &rbuf[roffset], ssize - roffset);
437             void *rreq = ucp_stream_recv_nb(receiver().ep(), dt_desc.buf(),
438                                             dt_desc.count(), dt_desc.dt(),
439                                             ucp_recv_cb, &length, 0);
440             ASSERT_TRUE(!UCS_PTR_IS_ERR(rreq));
441             if (UCS_PTR_IS_PTR(rreq)) {
442                 length = wait_stream_recv(rreq);
443             }
444         }
445         roffset += length;
446     } while (roffset < ssize);
447 
448     EXPECT_EQ(roffset, ssize);
449     EXPECT_EQ(check_pattern, rbuf);
450 }
451 
/* Byte-contig send, drained with ucp_stream_recv_data_nb() */
UCS_TEST_P(test_ucp_stream, send_recv_data) {
    const ucp_datatype_t send_dt = DATATYPE;

    do_send_recv_data_test(send_dt);
}

/* IOV send, drained with ucp_stream_recv_data_nb() */
UCS_TEST_P(test_ucp_stream, send_iov_recv_data) {
    const ucp_datatype_t send_dt = DATATYPE_IOV;

    do_send_recv_data_test(send_dt);
}

/* Generic-datatype send created with a NULL context (unlike
 * send_recv_generic, which passes &context) */
UCS_TEST_P(test_ucp_stream, send_generic_recv_data) {
    ucp_datatype_t     generic_dt;
    const ucs_status_t status = ucp_dt_create_generic(&ucp::test_dt_uint8_ops,
                                                      NULL, &generic_dt);

    ASSERT_UCS_OK(status);
    do_send_recv_data_test(generic_dt);
    ucp_dt_destroy(generic_dt);
}
469 
/* Each contig element size is exercised twice: without receive flags and
 * with UCP_STREAM_RECV_FLAG_WAITALL. */
UCS_TEST_P(test_ucp_stream, send_recv_8) {
    const ucp_datatype_t elem_dt = ucp_dt_make_contig(sizeof(uint8_t));

    do_send_recv_test<uint8_t, 0>(elem_dt);
    do_send_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(elem_dt);
}

UCS_TEST_P(test_ucp_stream, send_recv_16) {
    const ucp_datatype_t elem_dt = ucp_dt_make_contig(sizeof(uint16_t));

    do_send_recv_test<uint16_t, 0>(elem_dt);
    do_send_recv_test<uint16_t, UCP_STREAM_RECV_FLAG_WAITALL>(elem_dt);
}

UCS_TEST_P(test_ucp_stream, send_recv_32) {
    const ucp_datatype_t elem_dt = ucp_dt_make_contig(sizeof(uint32_t));

    do_send_recv_test<uint32_t, 0>(elem_dt);
    do_send_recv_test<uint32_t, UCP_STREAM_RECV_FLAG_WAITALL>(elem_dt);
}

UCS_TEST_P(test_ucp_stream, send_recv_64) {
    const ucp_datatype_t elem_dt = ucp_dt_make_contig(sizeof(uint64_t));

    do_send_recv_test<uint64_t, 0>(elem_dt);
    do_send_recv_test<uint64_t, UCP_STREAM_RECV_FLAG_WAITALL>(elem_dt);
}

UCS_TEST_P(test_ucp_stream, send_recv_iov) {
    const ucp_datatype_t iov_dt = DATATYPE_IOV;

    do_send_recv_test<uint8_t, 0>(iov_dt);
    do_send_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(iov_dt);
}

/* Generic datatype, self-validated through the `context` member; only the
 * WAITALL variant is run. */
UCS_TEST_P(test_ucp_stream, send_recv_generic) {
    ucp_datatype_t     generic_dt;
    const ucs_status_t status = ucp_dt_create_generic(&ucp::test_dt_uint8_ops,
                                                      &context, &generic_dt);

    ASSERT_UCS_OK(status);
    do_send_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(generic_dt);
    ucp_dt_destroy(generic_dt);
}
512 
/* Preposted-receive tests; each datatype is exercised without receive flags
 * and with UCP_STREAM_RECV_FLAG_WAITALL. */
UCS_TEST_P(test_ucp_stream, send_exp_recv_8) {
    const ucp_datatype_t elem_dt = ucp_dt_make_contig(sizeof(uint8_t));

    do_send_exp_recv_test<uint8_t, 0>(elem_dt);
    do_send_exp_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(elem_dt);
}

UCS_TEST_P(test_ucp_stream, send_exp_recv_16) {
    const ucp_datatype_t elem_dt = ucp_dt_make_contig(sizeof(uint16_t));

    do_send_exp_recv_test<uint16_t, 0>(elem_dt);
    do_send_exp_recv_test<uint16_t, UCP_STREAM_RECV_FLAG_WAITALL>(elem_dt);
}

UCS_TEST_P(test_ucp_stream, send_exp_recv_32) {
    const ucp_datatype_t elem_dt = ucp_dt_make_contig(sizeof(uint32_t));

    do_send_exp_recv_test<uint32_t, 0>(elem_dt);
    do_send_exp_recv_test<uint32_t, UCP_STREAM_RECV_FLAG_WAITALL>(elem_dt);
}

UCS_TEST_P(test_ucp_stream, send_exp_recv_64) {
    const ucp_datatype_t elem_dt = ucp_dt_make_contig(sizeof(uint64_t));

    do_send_exp_recv_test<uint64_t, 0>(elem_dt);
    do_send_exp_recv_test<uint64_t, UCP_STREAM_RECV_FLAG_WAITALL>(elem_dt);
}

UCS_TEST_P(test_ucp_stream, send_exp_recv_iov) {
    const ucp_datatype_t iov_dt = DATATYPE_IOV;

    do_send_exp_recv_test<uint8_t, 0>(iov_dt);
    do_send_exp_recv_test<uint8_t, UCP_STREAM_RECV_FLAG_WAITALL>(iov_dt);
}
545 
/* Mixed ucp_stream_recv_data_nb()/ucp_stream_recv_nb() receive tests */
UCS_TEST_P(test_ucp_stream, send_recv_data_recv_8) {
    const ucp_datatype_t elem_dt = ucp_dt_make_contig(sizeof(uint8_t));

    do_send_recv_data_recv_test(elem_dt);
}

UCS_TEST_P(test_ucp_stream, send_recv_data_recv_16) {
    const ucp_datatype_t elem_dt = ucp_dt_make_contig(sizeof(uint16_t));

    do_send_recv_data_recv_test(elem_dt);
}

UCS_TEST_P(test_ucp_stream, send_recv_data_recv_32) {
    const ucp_datatype_t elem_dt = ucp_dt_make_contig(sizeof(uint32_t));

    do_send_recv_data_recv_test(elem_dt);
}

UCS_TEST_P(test_ucp_stream, send_recv_data_recv_64) {
    const ucp_datatype_t elem_dt = ucp_dt_make_contig(sizeof(uint64_t));

    do_send_recv_data_recv_test(elem_dt);
}

UCS_TEST_P(test_ucp_stream, send_recv_data_recv_iov) {
    const ucp_datatype_t iov_dt = DATATYPE_IOV;

    do_send_recv_data_recv_test(iov_dt);
}
565 
/* For every total size in [min_size, max_size), send an IOV whose odd entries
 * are empty (so the vector ends with an empty entry) and check that the
 * receiver gets exactly the non-empty payload. */
UCS_TEST_P(test_ucp_stream, send_zero_ending_iov_recv_data) {
    const size_t min_size         = UCS_KBYTE;
    const size_t max_size         = min_size * 64;
    const size_t iov_num          = 8; /* must be divisible by 4 without a
                                        * remainder, caught on mlx5 based TLs
                                        * where max_iov = 3 for zcopy multi
                                        * protocol, where every posting includes:
                                        * 1 header + 2 nonempty IOVs */
    const size_t iov_num_nonempty = iov_num / 2;

    std::vector<uint8_t> buf(max_size * 2);
    ucs::fill_random(buf, buf.size());
    std::vector<ucp_dt_iov_t> v(iov_num);
    uint8_t *base = buf.data(); /* buf is never resized, so this stays valid */

    for (size_t size = min_size; size < max_size; ++size) {
        size_t slen = 0;
        for (size_t j = 0; j < iov_num; ++j) {
            if ((j % 2) == 0) {
                v[j].buffer = &(base[j * size / iov_num_nonempty]);
                v[j].length = size / iov_num_nonempty;
                slen       += v[j].length;
            } else {
                v[j].buffer = NULL;
                v[j].length = 0;
            }
        }

        void *sreq = ucp_stream_send_nb(sender().ep(), &v[0], iov_num,
                                        DATATYPE_IOV, ucp_send_cb, 0);
        /* a failed send would leave the receive loop below spinning forever */
        ASSERT_FALSE(UCS_PTR_IS_ERR(sreq));

        size_t rlen = 0;
        while (rlen < slen) {
            progress();
            size_t length;
            void *rdata = ucp_stream_recv_data_nb(receiver().ep(), &length);
            /* never count or release an error pointer */
            ASSERT_FALSE(UCS_PTR_IS_ERR(rdata));
            if (rdata != NULL) {
                rlen += length;
                ucp_stream_data_release(receiver().ep(), rdata);
            }
        }
        wait(sreq);
    }
}
611 
/* Instantiate the parameterized test_ucp_stream test cases; the set of
 * parameters is defined by the UCP_INSTANTIATE_TEST_CASE macro (not visible
 * in this file section). */
UCP_INSTANTIATE_TEST_CASE(test_ucp_stream)
613 
614 class test_ucp_stream_many2one : public test_ucp_stream_base {
615 protected:
616     struct request_wrapper_t {
request_wrapper_ttest_ucp_stream_many2one::request_wrapper_t617         request_wrapper_t(void *request, ucp::data_type_desc_t *dt_desc)
618             : m_req(request), m_dt_desc(dt_desc) {}
619 
620         void                  *m_req;
621         ucp::data_type_desc_t *m_dt_desc;
622     };
623 
624 public:
test_ucp_stream_many2one()625     test_ucp_stream_many2one() : m_receiver_idx(3), m_nsenders(3) {
626         m_recv_data.resize(m_nsenders);
627     }
628 
get_ctx_params()629     static ucp_params_t get_ctx_params() {
630         return test_ucp_stream::get_ctx_params();
631     }
632 
633     virtual void init();
ucp_send_cb(void * request,ucs_status_t status)634     static void ucp_send_cb(void *request, ucs_status_t status) {}
ucp_recv_cb(void * request,ucs_status_t status,size_t length)635     static void ucp_recv_cb(void *request, ucs_status_t status, size_t length) {}
636 
637     void do_send_worker_poll_test(ucp_datatype_t dt);
638     void do_send_recv_test(ucp_datatype_t dt);
639 
640 protected:
641     static void erase_completed_reqs(std::vector<request_wrapper_t> &reqs);
642     ucs_status_ptr_t stream_send_nb(size_t sender_idx,
643                                     const ucp::data_type_desc_t& dt_desc);
644     size_t send_all_nb(ucp_datatype_t datatype, size_t n_iter,
645                        std::vector<request_wrapper_t> &sreqs);
646     size_t send_all(ucp_datatype_t datatype, size_t n_iter);
647     void check_no_data();
648     std::set<ucp_ep_h> check_no_data(entity &e);
649     void check_recv_data(size_t n_iter, ucp_datatype_t dt);
650 
651     std::vector<std::string>        m_msgs;
652     std::vector<std::vector<char> > m_recv_data;
653     const size_t                    m_receiver_idx;
654     const size_t                    m_nsenders;
655 };
656 
init()657 void test_ucp_stream_many2one::init()
658 {
659     if (is_self()) {
660         UCS_TEST_SKIP_R("self");
661     }
662 
663     /* Skip entities creation */
664     test_base::init();
665 
666     for (size_t i = 0; i < m_nsenders + 1; ++i) {
667         create_entity();
668     }
669 
670     for (size_t i = 0; i < m_nsenders; ++i) {
671         e(i).connect(&e(m_receiver_idx), get_ep_params(), i);
672 
673         ucp_ep_params_t recv_ep_param = get_ep_params();
674         recv_ep_param.field_mask |= UCP_EP_PARAM_FIELD_USER_DATA;
675         recv_ep_param.user_data   = (void *)uintptr_t(i);
676         e(m_receiver_idx).connect(&e(i), recv_ep_param, i);
677     }
678 
679     for (size_t i = 0; i < m_nsenders; ++i) {
680         m_msgs.push_back(std::string("sender_") + ucs::to_string(i));
681     }
682 }
683 
do_send_worker_poll_test(ucp_datatype_t dt)684 void test_ucp_stream_many2one::do_send_worker_poll_test(ucp_datatype_t dt)
685 {
686     const size_t                   niter = 2018;
687     std::vector<request_wrapper_t> sreqs;
688     size_t                         total_len;
689 
690     total_len = send_all_nb(dt, niter, sreqs);
691 
692     /* Recv and progress all data */
693     do {
694         ssize_t count;
695         do {
696             const size_t max_eps = 10;
697             ucp_stream_poll_ep_t poll_eps[max_eps];
698             progress();
699             count = ucp_stream_worker_poll(e(m_receiver_idx).worker(),
700                                            poll_eps, max_eps, 0);
701             EXPECT_LE(0, count);
702 
703             for (ssize_t i = 0; i < count; ++i) {
704                 char   *rdata;
705                 size_t length;
706                 while ((rdata = (char *)ucp_stream_recv_data_nb(poll_eps[i].ep,
707                                                                 &length)) != NULL) {
708                     ASSERT_FALSE(UCS_PTR_IS_ERR(rdata));
709                     size_t senser_idx = uintptr_t(poll_eps[i].user_data);
710                     std::vector<char> &dst = m_recv_data[senser_idx];
711                     dst.insert(dst.end(), rdata, rdata + length);
712                     total_len -= length;
713                     ucp_stream_data_release(poll_eps[i].ep, rdata);
714                 }
715             }
716         } while (count > 0);
717 
718         erase_completed_reqs(sreqs);
719     } while (!sreqs.empty() || (total_len != 0));
720 
721     check_no_data();
722     check_recv_data(niter, dt);
723 }
724 
do_send_recv_test(ucp_datatype_t dt)725 void test_ucp_stream_many2one::do_send_recv_test(ucp_datatype_t dt)
726 {
727     const size_t                                       niter = 2018;
728     std::vector<size_t>                                roffsets(m_nsenders, 0);
729     std::vector<ucp::data_type_desc_t>                 dt_rdescs(m_nsenders);
730     std::vector<std::pair<size_t, request_wrapper_t> > rreqs;
731     std::vector<request_wrapper_t>                     sreqs;
732     size_t                                             total_sdata;
733 
734     ASSERT_FALSE(m_msgs.empty());
735 
736     /* Do preposts */
737     for (size_t i = 0; i < m_nsenders; ++i) {
738         m_recv_data[i].resize(m_msgs[i].length() * niter + 1);
739         ucp::data_type_desc_t &rdesc = dt_rdescs[i].make(dt,
740                                                          &m_recv_data[i][roffsets[i]],
741                                                          m_recv_data[i].size());
742         size_t length;
743         void *rreq = ucp_stream_recv_nb(e(m_receiver_idx).ep(0, i),
744                                         rdesc.buf(), rdesc.count(), rdesc.dt(),
745                                         ucp_recv_cb, &length, 0);
746         EXPECT_TRUE(UCS_PTR_IS_PTR(rreq));
747         rreqs.push_back(std::make_pair(i, request_wrapper_t(rreq, &rdesc)));
748     }
749 
750     total_sdata = send_all_nb(dt, niter, sreqs);
751 
752     /* Recv and progress all the rest of data */
753     do {
754         ssize_t count;
755         /* wait rreqs */
756         for (size_t i = 0; i < rreqs.size(); ++i) {
757             roffsets[rreqs[i].first] += wait_stream_recv(rreqs[i].second.m_req);
758         }
759         rreqs.clear();
760         progress();
761 
762         const size_t max_eps = 10;
763         ucp_stream_poll_ep_t poll_eps[max_eps];
764         count = ucp_stream_worker_poll(e(m_receiver_idx).worker(),
765                                        poll_eps, max_eps, 0);
766         EXPECT_LE(0, count);
767         EXPECT_LE(size_t(count), m_nsenders);
768 
769         for (ssize_t i = 0; i < count; ++i) {
770             bool again = true;
771             while (again) {
772                 size_t sender_idx = uintptr_t(poll_eps[i].user_data);
773                 size_t &roffset   = roffsets[sender_idx];
774                 ucp::data_type_desc_t &dt_desc =
775                     dt_rdescs[sender_idx].forward_to(roffset);
776                 EXPECT_TRUE(dt_desc.is_valid());
777                 size_t length;
778                 void *rreq = ucp_stream_recv_nb(poll_eps[i].ep,
779                                                 dt_desc.buf(),
780                                                 dt_desc.count(),
781                                                 dt_desc.dt(),
782                                                 ucp_recv_cb, &length, 0);
783                 EXPECT_FALSE(UCS_PTR_IS_ERR(rreq));
784                 if (rreq == NULL) {
785                     EXPECT_LT(size_t(0), length);
786                     roffset += length;
787                     if (ssize_t(length) < dt_desc.buf_length()) {
788                         continue; /* Need to drain the EP */
789                     }
790                 } else {
791                     rreqs.push_back(std::make_pair(sender_idx,
792                                                    request_wrapper_t(rreq,
793                                                                      &dt_desc)));
794                 }
795                 again = false;
796             }
797         }
798 
799         erase_completed_reqs(sreqs);
800     } while (!rreqs.empty() || !sreqs.empty() ||
801              (total_sdata > std::accumulate(roffsets.begin(),
802                                             roffsets.end(), 0ul)));
803 
804     EXPECT_EQ(total_sdata, std::accumulate(roffsets.begin(),
805                                            roffsets.end(), 0ul));
806     check_no_data();
807     check_recv_data(niter, dt);
808 }
809 
810 ucs_status_ptr_t
stream_send_nb(size_t sender_idx,const ucp::data_type_desc_t & dt_desc)811 test_ucp_stream_many2one::stream_send_nb(size_t sender_idx,
812                                          const ucp::data_type_desc_t& dt_desc)
813 {
814     return ucp_stream_send_nb(m_entities.at(sender_idx).ep(), dt_desc.buf(),
815                               dt_desc.count(), dt_desc.dt(), ucp_send_cb, 0);
816 }
817 
818 size_t
send_all_nb(ucp_datatype_t datatype,size_t n_iter,std::vector<request_wrapper_t> & sreqs)819 test_ucp_stream_many2one::send_all_nb(ucp_datatype_t datatype, size_t n_iter,
820                                       std::vector<request_wrapper_t> &sreqs)
821 {
822     size_t total = 0;
823     /* Send many times in round robin */
824     for (size_t i = 0; i < n_iter; ++i) {
825         for (size_t sender_idx = 0; sender_idx < m_nsenders; ++sender_idx) {
826             const void  *buf = m_msgs[sender_idx].c_str();
827             size_t      len  = m_msgs[sender_idx].length();
828             if (i == (n_iter - 1)) {
829                 ++len;
830             }
831 
832             ucp::data_type_desc_t *dt_desc = new ucp::data_type_desc_t(datatype,
833                                                                        buf,
834                                                                        len);
835             void *sreq = stream_send_nb(sender_idx, *dt_desc);
836             total += len;
837             if (UCS_PTR_IS_PTR(sreq)) {
838                 sreqs.push_back(request_wrapper_t(sreq, dt_desc));
839             } else {
840                 EXPECT_FALSE(UCS_PTR_IS_ERR(sreq));
841                 delete dt_desc;
842             }
843         }
844     }
845 
846     return total;
847 }
848 
849 size_t
send_all(ucp_datatype_t datatype,size_t n_iter)850 test_ucp_stream_many2one::send_all(ucp_datatype_t datatype, size_t n_iter)
851 {
852     std::vector<request_wrapper_t> sreqs;
853     size_t                         total;
854 
855     total = send_all_nb(datatype, n_iter, sreqs);
856     while (!sreqs.empty()) {
857         progress();
858         erase_completed_reqs(sreqs);
859     }
860 
861     return total;
862 }
863 
check_no_data()864 void test_ucp_stream_many2one::check_no_data()
865 {
866     std::set<ucp_ep_h> check;
867 
868     for (size_t i = 0; i <= m_receiver_idx; ++i) {
869         std::set<ucp_ep_h> check_e = check_no_data(e(i));
870         check.insert(check_e.begin(), check_e.end());
871     }
872 
873     EXPECT_EQ(size_t(0), check.size());
874 }
875 
check_no_data(entity & e)876 std::set<ucp_ep_h> test_ucp_stream_many2one::check_no_data(entity &e)
877 {
878     const size_t         max_eps = 10;
879     ucp_stream_poll_ep_t poll_eps[max_eps];
880     std::set<ucp_ep_h>   ret;
881     std::list<ucp_ep_h>  check_list;
882 
883     while (progress());
884 
885     ssize_t count = ucp_stream_worker_poll(m_entities.at(m_receiver_idx).worker(),
886                                            poll_eps, max_eps, 0);
887     EXPECT_GE(count, ssize_t(0));
888 
889     for (ssize_t i = 0; i < count; ++i) {
890         ret.insert(poll_eps[i].ep);
891     }
892 
893     for (int i = 0; i < e.get_num_workers(); ++i) {
894         for (int j = 0; j < e.get_num_eps(); ++j) {
895             check_list.push_back(e.ep(i, j));
896         }
897     }
898 
899     std::list<ucp_ep_h>::const_iterator check_it = check_list.begin();
900     while (check_it != check_list.end()) {
901         EXPECT_EQ(ret.end(), ret.find(*check_it));
902         ++check_it;
903     }
904 
905     return ret;
906 }
907 
check_recv_data(size_t n_iter,ucp_datatype_t dt)908 void test_ucp_stream_many2one::check_recv_data(size_t n_iter, ucp_datatype_t dt)
909 {
910     for (size_t i = 0; i < m_nsenders; ++i) {
911         std::string test = std::string("sender_") + ucs::to_string(i);
912         const std::string str(&m_recv_data[i].front());
913         if (UCP_DT_IS_GENERIC(dt)) {
914             std::vector<char> test_gen;
915             for (size_t j = 0; j < test.length(); ++j) {
916                 test_gen.push_back(char(j));
917             }
918             test_gen.push_back('\0');
919             test = std::string(test_gen.data());
920         }
921 
922         size_t            next = 0;
923         for (size_t j = 0; j < n_iter; ++j) {
924             size_t match = str.find(test, next);
925             EXPECT_NE(std::string::npos, match) << "failed on sender " << i
926                                                 << " iteration " << j;
927             if (match == std::string::npos) {
928                 break;
929             }
930             EXPECT_EQ(next, match);
931             next += test.length();
932         }
933         EXPECT_EQ(next, str.length()); /* nothing more */
934     }
935 }
936 
937 void
erase_completed_reqs(std::vector<request_wrapper_t> & reqs)938 test_ucp_stream_many2one::erase_completed_reqs(std::vector<request_wrapper_t> &reqs)
939 {
940     std::vector<request_wrapper_t>::iterator i = reqs.begin();
941 
942     while (i != reqs.end()) {
943         ucs_status_t status = ucp_request_check_status(i->m_req);
944         if (status != UCS_INPROGRESS) {
945             EXPECT_EQ(UCS_OK, status);
946             ucp_request_free(i->m_req);
947             delete i->m_dt_desc;
948             i = reqs.erase(i);
949         } else {
950             ++i;
951         }
952     }
953 }
954 
/* Send data from all senders without receiving it, disconnect sender 0, and
 * check that the receiver reports no ready data for sender 0's endpoints;
 * then reconnect and make sure sending over the new connection still works. */
UCS_TEST_P(test_ucp_stream_many2one, drop_data) {
    /* 10 rounds from every sender; nothing is received here, so the data is
     * left unconsumed on the receiver side. */
    send_all(DATATYPE, 10);

    /* the receiver is the entity right after the last sender */
    ASSERT_EQ(m_receiver_idx, m_nsenders);
    for (size_t i = 0; i <= m_receiver_idx; ++i) {
        flush_worker(e(i));
    }

    /* Destroy 1 connection: sender 0's endpoint and the receiver's default
     * endpoint (presumably ep(0, 0), i.e. the one facing sender 0).
     * NOTE(review): the receiver-side endpoint is destroyed with
     * &m_entities.at(0) as the entity argument rather than the receiver
     * entity -- confirm this is intended. */
    entity::ep_destructor(m_entities.at(0).ep(),
                          &m_entities.at(0));
    entity::ep_destructor(m_entities.at(m_receiver_idx).ep(),
                          &m_entities.at(0));
    m_entities.at(0).revoke_ep();
    m_entities.at(m_receiver_idx).revoke_ep(0, 0);

    /* Wait for the 1st byte on the last EP to make sure the network packets
       have arrived. */
    uint8_t check;
    size_t  check_length;
    ucp_ep_h last_ep = m_entities.at(m_receiver_idx).ep(0, m_nsenders - 1);
    void *check_req  = ucp_stream_recv_nb(last_ep, &check, 1, DATATYPE,
                                          ucp_recv_cb, &check_length, 0);
    EXPECT_FALSE(UCS_PTR_IS_ERR(check_req));
    if (UCS_PTR_IS_PTR(check_req)) {
        /* not completed immediately: wait for it (and free the request) */
        wait_stream_recv(check_req);
    }

    /* data from disconnected EP should be dropped */
    std::set<ucp_ep_h> others = check_no_data(m_entities.at(0));
    /* since ordering between EPs is not guaranteed, some data may be still in
     * the network or buffered by transport */
    EXPECT_LE(others.size(), m_nsenders - 1);

    /* Reconnect sender 0 <-> receiver; the new receiver-side EP is tagged
     * with user_data 0xdeadbeef. */
    m_entities.at(0).connect(&m_entities.at(m_receiver_idx), get_ep_params(), 0);
    ucp_ep_params_t recv_ep_param = get_ep_params();
    recv_ep_param.field_mask |= UCP_EP_PARAM_FIELD_USER_DATA;
    recv_ep_param.user_data   = (void *)uintptr_t(0xdeadbeef);
    e(m_receiver_idx).connect(&e(0), recv_ep_param, 0);

    /* send again, including over the re-established connection */
    send_all(DATATYPE, 10);

    for (size_t i = 0; i <= m_receiver_idx; ++i) {
        flush_worker(e(i));
    }

    /* Need to poll out all incoming data from transport layer, see PR #2048 */
    while (progress() > 0);
}
1006 
/* Many-to-one worker-poll scenario using the test's default DATATYPE. */
UCS_TEST_P(test_ucp_stream_many2one, send_worker_poll) {
    const ucp_datatype_t datatype = DATATYPE;
    do_send_worker_poll_test(datatype);
}
1010 
/* Many-to-one worker-poll scenario using the IOV datatype. */
UCS_TEST_P(test_ucp_stream_many2one, send_worker_poll_iov) {
    const ucp_datatype_t datatype = DATATYPE_IOV;
    do_send_worker_poll_test(datatype);
}
1014 
/* Many-to-one worker-poll scenario using a generic datatype created from
 * ucp::test_dt_uint8_ops; the datatype is destroyed after the run. */
UCS_TEST_P(test_ucp_stream_many2one, send_worker_poll_generic) {
    ucp_datatype_t     generic_dt;
    const ucs_status_t status = ucp_dt_create_generic(&ucp::test_dt_uint8_ops,
                                                      NULL, &generic_dt);
    ASSERT_UCS_OK(status);

    do_send_worker_poll_test(generic_dt);
    ucp_dt_destroy(generic_dt);
}
1024 
/* Many-to-one send/receive scenario using the test's default DATATYPE. */
UCS_TEST_P(test_ucp_stream_many2one, send_recv_nb) {
    const ucp_datatype_t datatype = DATATYPE;
    do_send_recv_test(datatype);
}
1028 
/* Many-to-one send/receive scenario using the IOV datatype. */
UCS_TEST_P(test_ucp_stream_many2one, send_recv_nb_iov) {
    const ucp_datatype_t datatype = DATATYPE_IOV;
    do_send_recv_test(datatype);
}
1032 
/* Many-to-one send/receive scenario using a generic datatype created from
 * ucp::test_dt_uint8_ops; the datatype is destroyed after the run. */
UCS_TEST_P(test_ucp_stream_many2one, send_recv_nb_generic) {
    ucp_datatype_t     generic_dt;
    const ucs_status_t status = ucp_dt_create_generic(&ucp::test_dt_uint8_ops,
                                                      NULL, &generic_dt);
    ASSERT_UCS_OK(status);

    do_send_recv_test(generic_dt);
    ucp_dt_destroy(generic_dt);
}
1042 
/* Register the UCS_TEST_P cases of test_ucp_stream_many2one with the UCP test
 * framework; the set of parameter variants is decided by the macro. */
UCP_INSTANTIATE_TEST_CASE(test_ucp_stream_many2one)
1044