1 /**
2 * Copyright (C) Mellanox Technologies Ltd. 2001-2014.  ALL RIGHTS RESERVED.
3 * See file LICENSE for terms.
4 */
5 
6 #ifndef UCP_TEST_H_
7 #define UCP_TEST_H_
8 
9 #include <ucp/api/ucp.h>
10 #include <ucs/time/time.h>
11 #include <common/mem_buffer.h>
12 
13 /* ucp version compile time test */
14 #if (UCP_API_VERSION != UCP_VERSION(UCP_API_MAJOR,UCP_API_MINOR))
15 #error possible bug in UCP version
16 #endif
17 
18 #include <common/test.h>
19 
20 #include <queue>
21 
22 #if _OPENMP
23 #include "omp.h"
24 #endif
25 
/*
 * MT_TEST_NUM_THREADS expands to omp_get_max_threads() when the tests are built
 * with both OpenMP and ENABLE_MT, and to the constant 4 otherwise (presumably
 * the thread count used by multi-threaded tests).
 */
#if (_OPENMP && ENABLE_MT)
#  define MT_TEST_NUM_THREADS omp_get_max_threads()
#else
#  define MT_TEST_NUM_THREADS 4
#endif
31 
32 
namespace ucp {

/* 32-bit test constant; only declared here, its definition lives in a source
 * file (presumably a sentinel/pattern value shared by UCP tests) */
extern const uint32_t MAGIC;

} // namespace ucp
36 
37 
/*
 * Parameters of one instantiation of a parameterized UCP test
 * (the value type of ucp_test's ::testing::TestWithParam).
 */
struct ucp_test_param {
    ucp_params_t              ctx_params;  /* UCP context parameters */
    std::vector<std::string>  transports;  /* transport names; presumably parsed
                                              from the comma-separated _tls
                                              string - confirm in ucp_test.cc */
    int                       variant;     /* test-specific variant number
                                              (ucp_test::DEFAULT_PARAM_VARIANT is 0) */
    int                       thread_type; /* one of ucp_test_base::SINGLE_THREAD,
                                              MULTI_THREAD_CONTEXT, MULTI_THREAD_WORKER */
};

class ucp_test; // forward declaration (needed by entity::close_all_eps)
46 
47 class ucp_test_base : public ucs::test_base {
48 public:
49     enum {
50         SINGLE_THREAD = 42,
51         MULTI_THREAD_CONTEXT, /* workers are single-threaded, context is mt-shared */
52         MULTI_THREAD_WORKER   /* workers are multi-threaded, cotnext is mt-single */
53     };
54 
55     class entity {
56         typedef std::vector<ucs::handle<ucp_ep_h, entity*> > ep_vec_t;
57         typedef std::vector<std::pair<ucs::handle<ucp_worker_h>,
58                                       ep_vec_t> > worker_vec_t;
59         typedef std::deque<void*> close_ep_reqs_t;
60 
61     public:
62         typedef enum {
63             LISTEN_CB_EP,       /* User's callback accepts ucp_ep_h */
64             LISTEN_CB_CONN,     /* User's callback accepts ucp_conn_request_h */
65             LISTEN_CB_REJECT    /* User's callback rejects ucp_conn_request_h */
66         } listen_cb_type_t;
67 
68         entity(const ucp_test_param& test_param, ucp_config_t* ucp_config,
69                const ucp_worker_params_t& worker_params,
70                const ucp_test_base* test_owner);
71 
72         ~entity();
73 
74         void connect(const entity* other, const ucp_ep_params_t& ep_params,
75                      int ep_idx = 0, int do_set_ep = 1);
76 
77         bool verify_client_address(struct sockaddr_storage *client_address);
78 
79         ucp_ep_h accept(ucp_worker_h worker, ucp_conn_request_h conn_request);
80 
81         void* modify_ep(const ucp_ep_params_t& ep_params, int worker_idx = 0,
82                        int ep_idx = 0);
83 
84         void* flush_ep_nb(int worker_index = 0, int ep_index = 0) const;
85 
86         void* flush_worker_nb(int worker_index = 0) const;
87 
88         void fence(int worker_index = 0) const;
89 
90         void* disconnect_nb(int worker_index = 0, int ep_index = 0,
91                             enum ucp_ep_close_mode mode = UCP_EP_CLOSE_MODE_FLUSH);
92 
93         void close_ep_req_free(void *close_req);
94 
95         void close_all_eps(const ucp_test &test, int wirker_idx,
96                            enum ucp_ep_close_mode mode = UCP_EP_CLOSE_MODE_FLUSH);
97 
98         void destroy_worker(int worker_index = 0);
99 
100         ucs_status_t listen(listen_cb_type_t cb_type,
101                             const struct sockaddr *saddr, socklen_t addrlen,
102                             const ucp_ep_params_t& ep_params,
103                             int worker_index = 0);
104 
105         ucp_ep_h ep(int worker_index = 0, int ep_index = 0) const;
106 
107         ucp_ep_h revoke_ep(int worker_index = 0, int ep_index = 0) const;
108 
109         ucp_worker_h worker(int worker_index = 0) const;
110 
111         ucp_context_h ucph() const;
112 
113         ucp_listener_h listenerh() const;
114 
115         unsigned progress(int worker_index = 0);
116 
117         int get_num_workers() const;
118 
119         int get_num_eps(int worker_index = 0) const;
120 
121         void add_err(ucs_status_t status);
122 
123         const size_t &get_err_num_rejected() const;
124 
125         const size_t &get_err_num() const;
126 
127         void warn_existing_eps() const;
128 
129         double set_ib_ud_timeout(double timeout_sec);
130 
131         void cleanup();
132 
133         static void ep_destructor(ucp_ep_h ep, entity *e);
134 
135     protected:
136         ucs::handle<ucp_context_h>      m_ucph;
137         worker_vec_t                    m_workers;
138         ucs::handle<ucp_listener_h>     m_listener;
139         std::queue<ucp_conn_request_h>  m_conn_reqs;
140         close_ep_reqs_t                 m_close_ep_reqs;
141         size_t                          m_err_cntr;
142         size_t                          m_rejected_cntr;
143         ucs::handle<ucp_ep_params_t*>   m_server_ep_params;
144 
145     private:
146         static void empty_send_completion(void *r, ucs_status_t status);
147         static void accept_ep_cb(ucp_ep_h ep, void *arg);
148         static void accept_conn_cb(ucp_conn_request_h conn_req, void *arg);
149         static void reject_conn_cb(ucp_conn_request_h conn_req, void *arg);
150 
151         void set_ep(ucp_ep_h ep, int worker_index, int ep_index);
152     };
153 
154     static bool is_request_completed(void *req);
155 };
156 
157 /**
158  * UCP test
159  */
160 class ucp_test : public ucp_test_base,
161                  public ::testing::TestWithParam<ucp_test_param>,
162                  public ucs::entities_storage<ucp_test_base::entity> {
163 
164     friend class ucp_test_base::entity;
165 
166 public:
167     enum {
168         DEFAULT_PARAM_VARIANT = 0
169     };
170 
171     UCS_TEST_BASE_IMPL;
172 
173     ucp_test();
174     virtual ~ucp_test();
175 
176     ucp_config_t* m_ucp_config;
177 
178     static std::vector<ucp_test_param>
179     enum_test_params(const ucp_params_t& ctx_params,
180                      const std::string& name,
181                      const std::string& test_case_name,
182                      const std::string& tls);
183 
184     static ucp_params_t get_ctx_params();
185     virtual ucp_worker_params_t get_worker_params();
186     virtual ucp_ep_params_t get_ep_params();
187 
188     static void
189     generate_test_params_variant(const ucp_params_t& ctx_params,
190                                  const std::string& name,
191                                  const std::string& test_case_name,
192                                  const std::string& tls,
193                                  int variant,
194                                  std::vector<ucp_test_param>& test_params,
195                                  int thread_type = SINGLE_THREAD);
196 
197     virtual void modify_config(const std::string& name, const std::string& value,
198                                bool optional = false);
199     void stats_activate();
200     void stats_restore();
201 
202 private:
203     static void set_ucp_config(ucp_config_t *config,
204                                const ucp_test_param& test_param);
205     static bool check_test_param(const std::string& name,
206                                  const std::string& test_case_name,
207                                  const ucp_test_param& test_param);
208 
209 protected:
210     virtual void init();
211     bool is_self() const;
212     virtual void cleanup();
213     virtual bool has_transport(const std::string& tl_name) const;
214     bool has_any_transport(const std::vector<std::string>& tl_names) const;
215     entity* create_entity(bool add_in_front = false);
216     entity* create_entity(bool add_in_front, const ucp_test_param& test_param);
217     unsigned progress(int worker_index = 0) const;
218     void short_progress_loop(int worker_index = 0) const;
219     void flush_ep(const entity &e, int worker_index = 0, int ep_index = 0);
220     void flush_worker(const entity &e, int worker_index = 0);
221     void disconnect(entity& entity);
222     void wait(void *req, int worker_index = 0);
223     void set_ucp_config(ucp_config_t *config);
224     int max_connections();
225 
err_handler_cb(void * arg,ucp_ep_h ep,ucs_status_t status)226     static void err_handler_cb(void *arg, ucp_ep_h ep, ucs_status_t status) {
227         entity *e = reinterpret_cast<entity*>(arg);
228         e->add_err(status);
229     }
230 
231     template <typename T>
232     void wait_for_flag(volatile T *flag, double timeout = 10.0) {
233         ucs_time_t loop_end_limit = ucs_get_time() + ucs_time_from_sec(timeout);
234         while ((ucs_get_time() < loop_end_limit) && (!(*flag))) {
235             short_progress_loop();
236         }
237     }
238 
239     static const ucp_datatype_t DATATYPE;
240     static const ucp_datatype_t DATATYPE_IOV;
241 
242 protected:
243     class mapped_buffer : public mem_buffer {
244     public:
245         mapped_buffer(size_t size, const entity& entity, int flags = 0,
246                       ucs_memory_type_t mem_type = UCS_MEMORY_TYPE_HOST);
247         virtual ~mapped_buffer();
248 
249         ucs::handle<ucp_rkey_h> rkey(const entity& entity) const;
250 
251         ucp_mem_h memh() const;
252 
253     private:
254         const entity& m_entity;
255         ucp_mem_h     m_memh;
256         void*         m_rkey_buffer;
257     };
258 };
259 
260 
261 std::ostream& operator<<(std::ostream& os, const ucp_test_param& test_param);
262 
/**
 * Instantiate the parameterized test case for a combination of transports.
 *
 * @param _test_case   Test case class, derived from ucp_test.
 * @param _name        Instantiation name.
 * @param _tls         Comma-separated list of transport names, given as a
 *                     string literal (e.g. "shm,ib").
 */
#define UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, _name, _tls) \
    INSTANTIATE_TEST_CASE_P( \
        _name, _test_case, \
        testing::ValuesIn( \
            _test_case::enum_test_params(_test_case::get_ctx_params(), \
                                         #_name, #_test_case, _tls)));
276 
277 
/**
 * Instantiate the parameterized test case once per transport set in a fixed
 * list: DC, UD (verbs and accelerated), RC (verbs and accelerated),
 * shared memory + IB, uGNI, self and TCP.
 *
 * @param _test_case  Test case class, derived from ucp_test.
 */
#define UCP_INSTANTIATE_TEST_CASE(_test_case) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, dcx, "dc_x") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ud, "ud_v") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, udx, "ud_x") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rc, "rc_v") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rcx, "rc_x") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, shm_ib, "shm,ib") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ugni, "ugni") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, self, "self") \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, tcp, "tcp")
293 
294 
/**
 * GPU copy transports (CUDA and ROCm), as one comma-separated string literal
 * meant to be appended to another transport-list literal.
 */
#define UCP_TEST_GPU_COPY_TLS "cuda_copy," "rocm_copy"
299 
300 
/**
 * Instantiate the parameterized test case with GPU memory awareness: the
 * transport sets of UCP_INSTANTIATE_TEST_CASE (except "self"), each extended
 * with UCP_TEST_GPU_COPY_TLS, plus a "shm_ib_ipc" instantiation that also
 * enables the CUDA/ROCm IPC transports.
 *
 * @param _test_case  Test case class, derived from ucp_test.
 */
#define UCP_INSTANTIATE_TEST_CASE_GPU_AWARE(_test_case) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, dcx, \
                                  "dc_x," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ud, \
                                  "ud_v," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, udx, \
                                  "ud_x," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rc, \
                                  "rc_v," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rcx, \
                                  "rc_x," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, shm_ib, \
                                  "shm,ib," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, shm_ib_ipc, \
                                  "shm,ib,cuda_ipc,rocm_ipc," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ugni, \
                                  "ugni," UCP_TEST_GPU_COPY_TLS) \
    UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, tcp, \
                                  "tcp," UCP_TEST_GPU_COPY_TLS)
318 
319 #endif
320