1 /** 2 * Copyright (C) Mellanox Technologies Ltd. 2001-2014. ALL RIGHTS RESERVED. 3 * See file LICENSE for terms. 4 */ 5 6 #ifndef UCP_TEST_H_ 7 #define UCP_TEST_H_ 8 9 #include <ucp/api/ucp.h> 10 #include <ucs/time/time.h> 11 #include <common/mem_buffer.h> 12 13 /* ucp version compile time test */ 14 #if (UCP_API_VERSION != UCP_VERSION(UCP_API_MAJOR,UCP_API_MINOR)) 15 #error possible bug in UCP version 16 #endif 17 18 #include <common/test.h> 19 20 #include <queue> 21 22 #if _OPENMP 23 #include "omp.h" 24 #endif 25 26 #if _OPENMP && ENABLE_MT 27 #define MT_TEST_NUM_THREADS omp_get_max_threads() 28 #else 29 #define MT_TEST_NUM_THREADS 4 30 #endif 31 32 33 namespace ucp { 34 extern const uint32_t MAGIC; 35 } 36 37 38 struct ucp_test_param { 39 ucp_params_t ctx_params; 40 std::vector<std::string> transports; 41 int variant; 42 int thread_type; 43 }; 44 45 class ucp_test; // forward declaration 46 47 class ucp_test_base : public ucs::test_base { 48 public: 49 enum { 50 SINGLE_THREAD = 42, 51 MULTI_THREAD_CONTEXT, /* workers are single-threaded, context is mt-shared */ 52 MULTI_THREAD_WORKER /* workers are multi-threaded, cotnext is mt-single */ 53 }; 54 55 class entity { 56 typedef std::vector<ucs::handle<ucp_ep_h, entity*> > ep_vec_t; 57 typedef std::vector<std::pair<ucs::handle<ucp_worker_h>, 58 ep_vec_t> > worker_vec_t; 59 typedef std::deque<void*> close_ep_reqs_t; 60 61 public: 62 typedef enum { 63 LISTEN_CB_EP, /* User's callback accepts ucp_ep_h */ 64 LISTEN_CB_CONN, /* User's callback accepts ucp_conn_request_h */ 65 LISTEN_CB_REJECT /* User's callback rejects ucp_conn_request_h */ 66 } listen_cb_type_t; 67 68 entity(const ucp_test_param& test_param, ucp_config_t* ucp_config, 69 const ucp_worker_params_t& worker_params, 70 const ucp_test_base* test_owner); 71 72 ~entity(); 73 74 void connect(const entity* other, const ucp_ep_params_t& ep_params, 75 int ep_idx = 0, int do_set_ep = 1); 76 77 bool verify_client_address(struct sockaddr_storage *client_address); 78 79 ucp_ep_h accept(ucp_worker_h worker, ucp_conn_request_h conn_request); 80 81 void* modify_ep(const ucp_ep_params_t& ep_params, int worker_idx = 0, 82 int ep_idx = 0); 83 84 void* flush_ep_nb(int worker_index = 0, int ep_index = 0) const; 85 86 void* flush_worker_nb(int worker_index = 0) const; 87 88 void fence(int worker_index = 0) const; 89 90 void* disconnect_nb(int worker_index = 0, int ep_index = 0, 91 enum ucp_ep_close_mode mode = UCP_EP_CLOSE_MODE_FLUSH); 92 93 void close_ep_req_free(void *close_req); 94 95 void close_all_eps(const ucp_test &test, int wirker_idx, 96 enum ucp_ep_close_mode mode = UCP_EP_CLOSE_MODE_FLUSH); 97 98 void destroy_worker(int worker_index = 0); 99 100 ucs_status_t listen(listen_cb_type_t cb_type, 101 const struct sockaddr *saddr, socklen_t addrlen, 102 const ucp_ep_params_t& ep_params, 103 int worker_index = 0); 104 105 ucp_ep_h ep(int worker_index = 0, int ep_index = 0) const; 106 107 ucp_ep_h revoke_ep(int worker_index = 0, int ep_index = 0) const; 108 109 ucp_worker_h worker(int worker_index = 0) const; 110 111 ucp_context_h ucph() const; 112 113 ucp_listener_h listenerh() const; 114 115 unsigned progress(int worker_index = 0); 116 117 int get_num_workers() const; 118 119 int get_num_eps(int worker_index = 0) const; 120 121 void add_err(ucs_status_t status); 122 123 const size_t &get_err_num_rejected() const; 124 125 const size_t &get_err_num() const; 126 127 void warn_existing_eps() const; 128 129 double set_ib_ud_timeout(double timeout_sec); 130 131 void cleanup(); 132 133 static void ep_destructor(ucp_ep_h ep, entity *e); 134 135 protected: 136 ucs::handle<ucp_context_h> m_ucph; 137 worker_vec_t m_workers; 138 ucs::handle<ucp_listener_h> m_listener; 139 std::queue<ucp_conn_request_h> m_conn_reqs; 140 close_ep_reqs_t m_close_ep_reqs; 141 size_t m_err_cntr; 142 size_t m_rejected_cntr; 143 ucs::handle<ucp_ep_params_t*> m_server_ep_params; 144 145 private: 146 static void empty_send_completion(void *r, ucs_status_t status); 147 static void accept_ep_cb(ucp_ep_h ep, void *arg); 148 static void accept_conn_cb(ucp_conn_request_h conn_req, void *arg); 149 static void reject_conn_cb(ucp_conn_request_h conn_req, void *arg); 150 151 void set_ep(ucp_ep_h ep, int worker_index, int ep_index); 152 }; 153 154 static bool is_request_completed(void *req); 155 }; 156 157 /** 158 * UCP test 159 */ 160 class ucp_test : public ucp_test_base, 161 public ::testing::TestWithParam<ucp_test_param>, 162 public ucs::entities_storage<ucp_test_base::entity> { 163 164 friend class ucp_test_base::entity; 165 166 public: 167 enum { 168 DEFAULT_PARAM_VARIANT = 0 169 }; 170 171 UCS_TEST_BASE_IMPL; 172 173 ucp_test(); 174 virtual ~ucp_test(); 175 176 ucp_config_t* m_ucp_config; 177 178 static std::vector<ucp_test_param> 179 enum_test_params(const ucp_params_t& ctx_params, 180 const std::string& name, 181 const std::string& test_case_name, 182 const std::string& tls); 183 184 static ucp_params_t get_ctx_params(); 185 virtual ucp_worker_params_t get_worker_params(); 186 virtual ucp_ep_params_t get_ep_params(); 187 188 static void 189 generate_test_params_variant(const ucp_params_t& ctx_params, 190 const std::string& name, 191 const std::string& test_case_name, 192 const std::string& tls, 193 int variant, 194 std::vector<ucp_test_param>& test_params, 195 int thread_type = SINGLE_THREAD); 196 197 virtual void modify_config(const std::string& name, const std::string& value, 198 bool optional = false); 199 void stats_activate(); 200 void stats_restore(); 201 202 private: 203 static void set_ucp_config(ucp_config_t *config, 204 const ucp_test_param& test_param); 205 static bool check_test_param(const std::string& name, 206 const std::string& test_case_name, 207 const ucp_test_param& test_param); 208 209 protected: 210 virtual void init(); 211 bool is_self() const; 212 virtual void cleanup(); 213 virtual bool has_transport(const std::string& tl_name) const; 214 bool has_any_transport(const std::vector<std::string>& tl_names) const; 215 entity* create_entity(bool add_in_front = false); 216 entity* create_entity(bool add_in_front, const ucp_test_param& test_param); 217 unsigned progress(int worker_index = 0) const; 218 void short_progress_loop(int worker_index = 0) const; 219 void flush_ep(const entity &e, int worker_index = 0, int ep_index = 0); 220 void flush_worker(const entity &e, int worker_index = 0); 221 void disconnect(entity& entity); 222 void wait(void *req, int worker_index = 0); 223 void set_ucp_config(ucp_config_t *config); 224 int max_connections(); 225 err_handler_cb(void * arg,ucp_ep_h ep,ucs_status_t status)226 static void err_handler_cb(void *arg, ucp_ep_h ep, ucs_status_t status) { 227 entity *e = reinterpret_cast<entity*>(arg); 228 e->add_err(status); 229 } 230 231 template <typename T> 232 void wait_for_flag(volatile T *flag, double timeout = 10.0) { 233 ucs_time_t loop_end_limit = ucs_get_time() + ucs_time_from_sec(timeout); 234 while ((ucs_get_time() < loop_end_limit) && (!(*flag))) { 235 short_progress_loop(); 236 } 237 } 238 239 static const ucp_datatype_t DATATYPE; 240 static const ucp_datatype_t DATATYPE_IOV; 241 242 protected: 243 class mapped_buffer : public mem_buffer { 244 public: 245 mapped_buffer(size_t size, const entity& entity, int flags = 0, 246 ucs_memory_type_t mem_type = UCS_MEMORY_TYPE_HOST); 247 virtual ~mapped_buffer(); 248 249 ucs::handle<ucp_rkey_h> rkey(const entity& entity) const; 250 251 ucp_mem_h memh() const; 252 253 private: 254 const entity& m_entity; 255 ucp_mem_h m_memh; 256 void* m_rkey_buffer; 257 }; 258 }; 259 260 261 std::ostream& operator<<(std::ostream& os, const ucp_test_param& test_param); 262 263 /** 264 * Instantiate the parameterized test case a combination of transports. 265 * 266 * @param _test_case Test case class, derived from ucp_test. 267 * @param _name Instantiation name. 268 * @param ... Transport names. 269 */ 270 #define UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, _name, _tls) \ 271 INSTANTIATE_TEST_CASE_P(_name, _test_case, \ 272 testing::ValuesIn(_test_case::enum_test_params(_test_case::get_ctx_params(), \ 273 #_name, \ 274 #_test_case, \ 275 _tls))); 276 277 278 /** 279 * Instantiate the parameterized test case for all transport combinations. 280 * 281 * @param _test_case Test case class, derived from ucp_test. 282 */ 283 #define UCP_INSTANTIATE_TEST_CASE(_test_case) \ 284 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, dcx, "dc_x") \ 285 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ud, "ud_v") \ 286 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, udx, "ud_x") \ 287 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rc, "rc_v") \ 288 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rcx, "rc_x") \ 289 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, shm_ib, "shm,ib") \ 290 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ugni, "ugni") \ 291 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, self, "self") \ 292 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, tcp, "tcp") 293 294 295 /** 296 * The list of GPU copy TLs 297 */ 298 #define UCP_TEST_GPU_COPY_TLS "cuda_copy,rocm_copy" 299 300 301 /** 302 * Instantiate the parameterized test case for all transport combinations 303 * with GPU memory awareness 304 * 305 * @param _test_case Test case class, derived from ucp_test. 306 */ 307 #define UCP_INSTANTIATE_TEST_CASE_GPU_AWARE(_test_case) \ 308 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, dcx, "dc_x," UCP_TEST_GPU_COPY_TLS) \ 309 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ud, "ud_v," UCP_TEST_GPU_COPY_TLS) \ 310 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, udx, "ud_x," UCP_TEST_GPU_COPY_TLS) \ 311 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rc, "rc_v," UCP_TEST_GPU_COPY_TLS) \ 312 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rcx, "rc_x," UCP_TEST_GPU_COPY_TLS) \ 313 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, shm_ib, "shm,ib," UCP_TEST_GPU_COPY_TLS) \ 314 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, shm_ib_ipc, "shm,ib,cuda_ipc,rocm_ipc," \ 315 UCP_TEST_GPU_COPY_TLS) \ 316 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ugni, "ugni," UCP_TEST_GPU_COPY_TLS) \ 317 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, tcp, "tcp," UCP_TEST_GPU_COPY_TLS) 318 319 #endif 320