1 /**
2 * Copyright (C) Mellanox Technologies Ltd. 2001-2017.  ALL RIGHTS RESERVED.
3 *
4 * See file LICENSE for terms.
5 */
6 
7 #include "ucp_test.h"
8 #include <common/mem_buffer.h>
9 
10 extern "C" {
11 #include <uct/api/uct.h>
12 #include <ucp/core/ucp_context.h>
13 #include <ucp/core/ucp_mm.h>
14 }
15 
16 
/*
 * Instantiate the value-parameterized test case _test_case under the prefix
 * _name for a single memory type _mem_type. The parameter list is produced by
 * _test_case::enum_test_params(); for the fixtures in this file that list is
 * empty when _mem_type is not reported by mem_buffer::supported_mem_types().
 * The macro supplies the trailing ';', so invocations are written without one.
 */
#define UCP_INSTANTIATE_TEST_CASE_MEMTYPE(_test_case, _name, _mem_type) \
    INSTANTIATE_TEST_CASE_P(_name, _test_case, \
                            testing::ValuesIn(_test_case::enum_test_params( \
                                              _test_case::get_ctx_params(), \
                                              #_test_case, _mem_type)));

/*
 * Instantiate _test_case once per memory type: host, CUDA, CUDA-managed,
 * ROCm and ROCm-managed, each under its own instantiation prefix.
 */
#define UCP_INSTANTIATE_TEST_CASE_MEMTYPES(_test_case) \
    UCP_INSTANTIATE_TEST_CASE_MEMTYPE(_test_case, host,         UCS_MEMORY_TYPE_HOST) \
    UCP_INSTANTIATE_TEST_CASE_MEMTYPE(_test_case, cuda,         UCS_MEMORY_TYPE_CUDA) \
    UCP_INSTANTIATE_TEST_CASE_MEMTYPE(_test_case, cuda_managed, UCS_MEMORY_TYPE_CUDA_MANAGED) \
    UCP_INSTANTIATE_TEST_CASE_MEMTYPE(_test_case, rocm,         UCS_MEMORY_TYPE_ROCM) \
    UCP_INSTANTIATE_TEST_CASE_MEMTYPE(_test_case, rocm_managed, UCS_MEMORY_TYPE_ROCM_MANAGED)
29 
30 class test_ucp_mem_type : public ucp_test {
31 public:
get_ctx_params()32     static ucp_params_t get_ctx_params() {
33         ucp_params_t params = ucp_test::get_ctx_params();
34         params.features |= UCP_FEATURE_TAG;
35         return params;
36     }
37 
38     static std::vector<ucp_test_param>
enum_test_params(const ucp_params_t & ctx_params,const std::string & test_case_name,ucs_memory_type_t mem_type)39     enum_test_params(const ucp_params_t& ctx_params,
40                      const std::string& test_case_name, ucs_memory_type_t mem_type)
41     {
42         std::vector<ucp_test_param> result;
43 
44         std::vector<ucs_memory_type_t> mem_types =
45                         mem_buffer::supported_mem_types();
46         if (std::find(mem_types.begin(), mem_types.end(), mem_type) !=
47             mem_types.end()) {
48             generate_test_params_variant(ctx_params, "all", test_case_name,
49                                          "all", mem_type, result);
50         }
51 
52         return result;
53     }
54 
55 protected:
mem_type() const56     ucs_memory_type_t mem_type() const {
57         return static_cast<ucs_memory_type_t>(GetParam().variant);
58     }
59 };
60 
/*
 * Allocate a small buffer of the fixture's memory type and check that
 * ucp_memory_type_detect() on the sender's context reports that same type.
 */
UCS_TEST_P(test_ucp_mem_type, detect) {
    const ucs_memory_type_t expected_type = mem_type();
    const size_t buf_size                 = 256;
    mem_buffer buffer(buf_size, expected_type);

    EXPECT_EQ(expected_type,
              ucp_memory_type_detect(sender().ucph(), buffer.ptr(), buf_size));
}

UCP_INSTANTIATE_TEST_CASE_MEMTYPES(test_ucp_mem_type)
74 
75 class test_ucp_mem_type_alloc_before_init : public test_ucp_mem_type {
76 public:
get_ctx_params()77     static ucp_params_t get_ctx_params() {
78         ucp_params_t params = ucp_test::get_ctx_params();
79         params.features    |= UCP_FEATURE_TAG;
80         return params;
81     }
82 
test_ucp_mem_type_alloc_before_init()83     test_ucp_mem_type_alloc_before_init() {
84         m_size = 10000;
85     }
86 
init()87     virtual void init() {
88         m_send_buffer.reset(new mem_buffer(m_size, mem_type()));
89         m_recv_buffer.reset(new mem_buffer(m_size, mem_type()));
90         test_ucp_mem_type::init();
91     }
92 
cleanup()93     virtual void cleanup() {
94         test_ucp_mem_type::cleanup();
95         m_send_buffer.reset();
96         m_recv_buffer.reset();
97     }
98 
99     static const uint64_t SEED = 0x1111111111111111lu;
100 protected:
101     size_t                     m_size;
102     ucs::auto_ptr<mem_buffer>  m_send_buffer, m_recv_buffer;
103 };
104 
/*
 * Transfer the buffers that init() allocated before the base-class init().
 * First check that both sender and receiver contexts detect the fixture's
 * memory type for them, then send the SEED pattern several times, re-filling
 * the receive buffer with a seed-0 pattern before each round (presumably
 * different from the SEED pattern, so stale data cannot pass the check).
 */
UCS_TEST_P(test_ucp_mem_type_alloc_before_init, xfer) {
    const int num_rounds = 3;

    sender().connect(&receiver(), get_ep_params());

    void *send_ptr = m_send_buffer->ptr();
    void *recv_ptr = m_recv_buffer->ptr();

    EXPECT_EQ(mem_type(), ucp_memory_type_detect(sender().ucph(),
                                                 send_ptr, m_size));
    EXPECT_EQ(mem_type(), ucp_memory_type_detect(receiver().ucph(),
                                                 recv_ptr, m_size));

    mem_buffer::pattern_fill(send_ptr, m_size, SEED, mem_type());

    for (int round = 0; round < num_rounds; ++round) {
        mem_buffer::pattern_fill(recv_ptr, m_size, 0, mem_type());

        void *send_req = ucp_tag_send_nb(sender().ep(), send_ptr, m_size,
                                         ucp_dt_make_contig(1), 1,
                                         (ucp_send_callback_t)ucs_empty_function);
        void *recv_req = ucp_tag_recv_nb(receiver().worker(), recv_ptr,
                                         m_size, ucp_dt_make_contig(1), 1, 1,
                                         (ucp_tag_recv_callback_t)ucs_empty_function);
        wait(send_req);
        wait(recv_req);

        mem_buffer::pattern_check(recv_ptr, m_size, SEED, mem_type());
    }
}

UCP_INSTANTIATE_TEST_CASE_MEMTYPES(test_ucp_mem_type_alloc_before_init)
132