1 /**
2 * Copyright (C) Mellanox Technologies Ltd. 2001-2018.  ALL RIGHTS RESERVED.
3 *
4 * See file LICENSE for terms.
5 */
6 
7 #ifdef HAVE_CONFIG_H
8 #  include "config.h"
9 #endif
10 
11 #include "ucp_listener.h"
12 #include "uct/base/uct_cm.h"
13 
14 #include <ucp/stream/stream.h>
15 #include <ucp/wireup/wireup_ep.h>
16 #include <ucp/wireup/wireup_cm.h>
17 #include <ucp/core/ucp_ep.h>
18 #include <ucp/core/ucp_ep.inl>
19 #include <ucs/debug/log.h>
20 #include <ucs/sys/sock.h>
21 
22 
ucp_listener_accept_cb_progress(void * arg)23 static unsigned ucp_listener_accept_cb_progress(void *arg)
24 {
25     ucp_ep_h       ep       = arg;
26     ucp_listener_h listener = ucp_ep_ext_gen(ep)->listener;
27 
28     /* NOTE: protect union */
29     ucs_assert(!(ep->flags & (UCP_EP_FLAG_ON_MATCH_CTX |
30                               UCP_EP_FLAG_FLUSH_STATE_VALID)));
31     ucs_assert(ep->flags   & UCP_EP_FLAG_LISTENER);
32 
33     ep->flags &= ~UCP_EP_FLAG_LISTENER;
34     ep->flags |= UCP_EP_FLAG_USED;
35     ucp_stream_ep_activate(ep);
36     ucp_ep_flush_state_reset(ep);
37 
38     /*
39      * listener is NULL if the EP was created with UCP_EP_PARAM_FIELD_EP_ADDR
40      * and we are here because long address requires wireup protocol
41      */
42     if (listener && listener->accept_cb) {
43         listener->accept_cb(ep, listener->arg);
44     }
45 
46     return 1;
47 }
48 
ucp_listener_accept_cb_remove_filter(const ucs_callbackq_elem_t * elem,void * arg)49 int ucp_listener_accept_cb_remove_filter(const ucs_callbackq_elem_t *elem,
50                                                 void *arg)
51 {
52     ucp_ep_h ep = elem->arg;
53 
54     return (elem->cb == ucp_listener_accept_cb_progress) && (ep == arg);
55 }
56 
/**
 * Schedule the accept handling of an endpoint as a one-shot progress callback
 * on the UCT worker of the endpoint's worker.
 *
 * @param [in] ep  Endpoint to be accepted.
 */
void ucp_listener_schedule_accept_cb(ucp_ep_h ep)
{
    /* The callback id is not kept: it is only needed as an in/out argument of
     * the registration call */
    uct_worker_cb_id_t cb_id = UCS_CALLBACKQ_ID_NULL;

    uct_worker_progress_register_safe(ep->worker->uct,
                                      ucp_listener_accept_cb_progress, ep,
                                      UCS_CALLBACKQ_FLAG_ONESHOT, &cb_id);
}
66 
ucp_listener_conn_request_progress(void * arg)67 static unsigned ucp_listener_conn_request_progress(void *arg)
68 {
69     ucp_conn_request_h conn_request = arg;
70     ucp_listener_h     listener     = conn_request->listener;
71     ucp_worker_h       worker       = listener->worker;
72     ucp_ep_h           ep;
73     ucs_status_t       status;
74 
75     ucs_trace_func("listener=%p", listener);
76 
77     if (listener->conn_cb) {
78         listener->conn_cb(conn_request, listener->arg);
79         return 1;
80     }
81 
82     UCS_ASYNC_BLOCK(&worker->async);
83     status = ucp_ep_create_server_accept(worker, conn_request, &ep);
84     if (status != UCS_OK) {
85         goto out;
86     }
87 
88     if (listener->accept_cb != NULL) {
89         if (ep->flags & UCP_EP_FLAG_LISTENER) {
90             ucs_assert(!(ep->flags & UCP_EP_FLAG_USED));
91             ucp_ep_ext_gen(ep)->listener = listener;
92         } else {
93             ep->flags |= UCP_EP_FLAG_USED;
94             listener->accept_cb(ep, listener->arg);
95         }
96     }
97 
98 out:
99     UCS_ASYNC_UNBLOCK(&worker->async);
100     return 1;
101 }
102 
ucp_listener_remove_filter(const ucs_callbackq_elem_t * elem,void * arg)103 static int ucp_listener_remove_filter(const ucs_callbackq_elem_t *elem,
104                                       void *arg)
105 {
106     ucp_listener_h *listener = elem->arg;
107 
108     return (elem->cb == ucp_listener_conn_request_progress) && (listener == arg);
109 }
110 
/**
 * Connection request callback of a sockaddr server interface.
 *
 * Packs the request and a copy of its private data into a newly allocated
 * ucp_conn_request_t and defers the actual handling to a one-shot progress
 * callback on the listener's worker. If the allocation fails, the request is
 * rejected on the transport interface.
 *
 * @param [in] tl_iface        Transport interface the request arrived on.
 * @param [in] arg             The listener (ucp_listener_h).
 * @param [in] uct_req         Transport-level connection request.
 * @param [in] conn_priv_data  Private data received with the request.
 * @param [in] length          Size of @a conn_priv_data, in bytes.
 */
static void ucp_listener_conn_request_callback(uct_iface_h tl_iface, void *arg,
                                               uct_conn_request_h uct_req,
                                               const void *conn_priv_data,
                                               size_t length)
{
    uct_worker_cb_id_t prog_id  = UCS_CALLBACKQ_ID_NULL;
    ucp_listener_h     listener = arg;
    ucp_conn_request_h conn_request;

    ucs_trace("listener %p: got connection request", listener);

    /* Defer wireup init and user's callback to be invoked from the main thread */
    conn_request = ucs_malloc(ucs_offsetof(ucp_conn_request_t, sa_data) +
                              length, "accept connection request");
    if (conn_request != NULL) {
        /* The client address is left zeroed on this path */
        memset(&conn_request->client_address, 0,
               sizeof(struct sockaddr_storage));
        memcpy(&conn_request->sa_data, conn_priv_data, length);
        conn_request->uct.iface = tl_iface;
        conn_request->uct_req   = uct_req;
        conn_request->listener  = listener;

        uct_worker_progress_register_safe(listener->worker->uct,
                                          ucp_listener_conn_request_progress,
                                          conn_request,
                                          UCS_CALLBACKQ_FLAG_ONESHOT, &prog_id);

        /* If the worker supports the UCP_FEATURE_WAKEUP feature, signal the
         * user so that he can wake-up on this event */
        ucp_worker_signal_internal(listener->worker);
        return;
    }

    ucs_error("failed to allocate connect request, "
              "rejecting connection request %p on TL iface %p, reason %s",
              uct_req, tl_iface, ucs_status_string(UCS_ERR_NO_MEMORY));
    uct_iface_reject(tl_iface, uct_req);
}
148 
ucp_conn_request_query(ucp_conn_request_h conn_request,ucp_conn_request_attr_t * attr)149 ucs_status_t ucp_conn_request_query(ucp_conn_request_h conn_request,
150                                     ucp_conn_request_attr_t *attr)
151 {
152     ucs_status_t status;
153 
154     if (attr->field_mask & UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR) {
155         if (conn_request->client_address.ss_family == 0) {
156             return UCS_ERR_UNSUPPORTED;
157         }
158 
159         status = ucs_sockaddr_copy((struct sockaddr *)&attr->client_address,
160                                    (struct sockaddr *)&conn_request->client_address);
161         if (status != UCS_OK) {
162             return status;
163         }
164     }
165 
166     return UCS_OK;
167 }
168 
ucp_listener_query(ucp_listener_h listener,ucp_listener_attr_t * attr)169 ucs_status_t ucp_listener_query(ucp_listener_h listener,
170                                 ucp_listener_attr_t *attr)
171 {
172     ucs_status_t status;
173 
174     if (attr->field_mask & UCP_LISTENER_ATTR_FIELD_SOCKADDR) {
175         status = ucs_sockaddr_copy((struct sockaddr *)&attr->sockaddr,
176                                    (struct sockaddr *)&listener->sockaddr);
177         if (status != UCS_OK) {
178             return status;
179         }
180     }
181 
182     return UCS_OK;
183 }
184 
/**
 * Destroy all the UCT listeners of a listener which works on top of
 * connection managers, and release the array which holds them.
 *
 * @param [in] listener  Listener whose UCT listeners are closed.
 */
static void ucp_listener_close_uct_listeners(ucp_listener_h listener)
{
    ucp_rsc_index_t rsc_idx;

    ucs_assert_always(ucp_worker_sockaddr_is_cm_proto(listener->worker));

    rsc_idx = 0;
    while (rsc_idx < listener->num_rscs) {
        uct_listener_destroy(listener->listeners[rsc_idx]);
        ++rsc_idx;
    }

    ucs_free(listener->listeners);

    /* Leave the listener with no UCT resources */
    listener->num_rscs  = 0;
    listener->listeners = NULL;
}
200 
/**
 * Close all the sockaddr worker interfaces of a listener which does not work
 * on top of connection managers, and release the array which holds them.
 *
 * @param [in] listener  Listener whose interfaces are closed.
 */
static void ucp_listener_close_ifaces(ucp_listener_h listener)
{
    ucp_worker_h worker;
    int i;

    ucs_assert_always(!ucp_worker_sockaddr_is_cm_proto(listener->worker));

    for (i = 0; i < listener->num_rscs; i++) {
        worker = listener->wifaces[i]->worker;
        ucs_assert_always(worker == listener->worker);
        /* remove pending slow-path progress in case it wasn't removed yet */
        ucs_callbackq_remove_if(&worker->uct->progress_q,
                                ucp_listener_remove_filter, listener);
        ucp_worker_iface_cleanup(listener->wifaces[i]);
    }

    ucs_free(listener->wifaces);

    /* Do not leave a dangling array pointer and a stale count behind, same as
     * it is done when closing UCT listeners */
    listener->wifaces  = NULL;
    listener->num_rscs = 0;
}
219 
220 static ucs_status_t
ucp_listen_on_cm(ucp_listener_h listener,const ucp_listener_params_t * params)221 ucp_listen_on_cm(ucp_listener_h listener, const ucp_listener_params_t *params)
222 {
223     ucp_worker_h          worker  = listener->worker;
224     const ucp_rsc_index_t num_cms = ucp_worker_num_cm_cmpts(worker);
225     struct sockaddr_storage addr_storage;
226     struct sockaddr       *addr;
227     uct_listener_h        *uct_listeners;
228     uct_listener_params_t uct_params;
229     uct_listener_attr_t   uct_attr;
230     uint16_t              port, uct_listen_port;
231     ucp_rsc_index_t       i;
232     char                  addr_str[UCS_SOCKADDR_STRING_LEN];
233     ucp_worker_cm_t       *ucp_cm;
234     ucs_status_t          status;
235 
236     addr = (struct sockaddr *)&addr_storage;
237     status = ucs_sockaddr_copy(addr, params->sockaddr.addr);
238     if (status != UCS_OK) {
239         return status;
240     }
241 
242     ucs_assert_always(num_cms > 0);
243 
244     uct_params.field_mask       = UCT_LISTENER_PARAM_FIELD_CONN_REQUEST_CB |
245                                   UCT_LISTENER_PARAM_FIELD_USER_DATA;
246     uct_params.conn_request_cb  = ucp_cm_server_conn_request_cb;
247     uct_params.user_data        = listener;
248 
249     listener->num_rscs          = 0;
250     uct_listeners               = ucs_calloc(num_cms, sizeof(*uct_listeners),
251                                              "uct_listeners_arr");
252     if (uct_listeners == NULL) {
253         ucs_error("Can't allocate memory for UCT listeners array");
254         return UCS_ERR_NO_MEMORY;
255     }
256 
257     listener->listeners = uct_listeners;
258 
259     for (i = 0; i < num_cms; ++i) {
260         ucp_cm = &worker->cms[i];
261         status = uct_listener_create(ucp_cm->cm, addr,
262                                      params->sockaddr.addrlen, &uct_params,
263                                      &uct_listeners[listener->num_rscs]);
264         if (status != UCS_OK) {
265             ucs_debug("failed to create UCT listener on CM %p (component %s) "
266                       "with address %s status %s", ucp_cm->cm,
267                       worker->context->tl_cmpts[ucp_cm->cmpt_idx].attr.name,
268                       ucs_sockaddr_str(params->sockaddr.addr, addr_str,
269                                        UCS_SOCKADDR_STRING_LEN),
270                       ucs_status_string(status));
271             continue;
272         }
273 
274         ++listener->num_rscs;
275 
276         status = ucs_sockaddr_get_port(addr, &port);
277         if (status != UCS_OK) {
278             goto err_destroy_listeners;
279         }
280 
281         uct_attr.field_mask = UCT_LISTENER_ATTR_FIELD_SOCKADDR;
282         status = uct_listener_query(uct_listeners[listener->num_rscs - 1],
283                                     &uct_attr);
284         if (status != UCS_OK) {
285             goto err_destroy_listeners;
286         }
287 
288         status = ucs_sockaddr_get_port((struct sockaddr *)&uct_attr.sockaddr,
289                                        &uct_listen_port);
290         if (status != UCS_OK) {
291             goto err_destroy_listeners;
292         }
293 
294         if (port != uct_listen_port) {
295             ucs_assert(port == 0);
296             status = ucs_sockaddr_set_port(addr, uct_listen_port);
297             if (status != UCS_OK) {
298                 goto err_destroy_listeners;
299             }
300         }
301     }
302 
303     if (listener->num_rscs > 0) {
304         status = ucs_sockaddr_copy((struct sockaddr *)&listener->sockaddr,
305                                    addr);
306         if (status != UCS_OK) {
307             goto err_destroy_listeners;
308         }
309     }
310 
311     /* return the status of the last call of uct_listener_create if no listener
312        was created */
313     return (listener->num_rscs > 0) ? UCS_OK : status;
314 
315 err_destroy_listeners:
316     ucp_listener_close_uct_listeners(listener);
317     return status;
318 }
319 
320 static ucs_status_t
ucp_listen_on_iface(ucp_listener_h listener,const ucp_listener_params_t * params)321 ucp_listen_on_iface(ucp_listener_h listener,
322                     const ucp_listener_params_t *params)
323 {
324     ucp_worker_h worker   = listener->worker;
325     ucp_context_h context = listener->worker->context;
326     int sockaddr_tls      = 0;
327     char saddr_str[UCS_SOCKADDR_STRING_LEN];
328     ucp_tl_resource_desc_t *resource;
329     uct_iface_params_t iface_params;
330     struct sockaddr_storage *listen_sock;
331     ucp_worker_iface_t **tmp;
332     ucp_rsc_index_t tl_id;
333     ucs_status_t status;
334     ucp_tl_md_t *tl_md;
335     uint16_t port;
336     int i;
337 
338     status = ucs_sockaddr_get_port(params->sockaddr.addr, &port);
339     if (status != UCS_OK) {
340        return status;
341     }
342 
343     /* Go through all the available resources and for each one, check if the given
344      * sockaddr is accessible from its md. Start listening on all the mds that
345      * satisfy this.
346      * If the given port is set to 0, i.e. use a random port, the first transport
347      * in the sockaddr priority list from the environment configuration will
348      * dictate the port to listen on for the other sockaddr transports in the list.
349      * */
350     for (i = 0; i < context->config.num_sockaddr_tls; i++) {
351         tl_id    = context->config.sockaddr_tl_ids[i];
352         resource = &context->tl_rscs[tl_id];
353         tl_md    = &context->tl_mds[resource->md_index];
354 
355         if (!uct_md_is_sockaddr_accessible(tl_md->md, &params->sockaddr,
356                                            UCT_SOCKADDR_ACC_LOCAL)) {
357             continue;
358         }
359 
360         tmp = ucs_realloc(listener->wifaces,
361                           sizeof(*tmp) * (sockaddr_tls + 1),
362                           "listener wifaces");
363         if (tmp == NULL) {
364             ucs_error("failed to allocate listener wifaces");
365             status = UCS_ERR_NO_MEMORY;
366             goto err_close_listener_wifaces;
367         }
368 
369         listener->wifaces = tmp;
370 
371         iface_params.field_mask                     = UCT_IFACE_PARAM_FIELD_OPEN_MODE |
372                                                       UCT_IFACE_PARAM_FIELD_SOCKADDR;
373         iface_params.open_mode                      = UCT_IFACE_OPEN_MODE_SOCKADDR_SERVER;
374         iface_params.mode.sockaddr.conn_request_cb  = ucp_listener_conn_request_callback;
375         iface_params.mode.sockaddr.conn_request_arg = listener;
376         iface_params.mode.sockaddr.listen_sockaddr  = params->sockaddr;
377         iface_params.mode.sockaddr.cb_flags         = UCT_CB_FLAG_ASYNC;
378 
379         if (port) {
380             /* Set the port for the next sockaddr iface. This port was either
381              * obtained from the user or generated by the first created sockaddr
382              * iface if the port from the user was equal to zero */
383             status = ucs_sockaddr_set_port(
384                         (struct sockaddr *)
385                         iface_params.mode.sockaddr.listen_sockaddr.addr, port);
386             if (status != UCS_OK) {
387                 ucs_error("failed to set port parameter (%d) for creating %s iface",
388                           port, resource->tl_rsc.tl_name);
389                 goto err_close_listener_wifaces;
390             }
391         }
392 
393         status = ucp_worker_iface_open(worker, tl_id, &iface_params,
394                                        &listener->wifaces[sockaddr_tls]);
395         if (status != UCS_OK) {
396             ucs_error("failed to open listener on %s on md %s",
397                       ucs_sockaddr_str(
398                             iface_params.mode.sockaddr.listen_sockaddr.addr,
399                             saddr_str, sizeof(saddr_str)),
400                             tl_md->rsc.md_name);
401             goto err_close_listener_wifaces;
402         }
403 
404         status = ucp_worker_iface_init(worker, tl_id,
405                                        listener->wifaces[sockaddr_tls]);
406         if ((status != UCS_OK) ||
407             ((context->config.features & UCP_FEATURE_WAKEUP) &&
408              !(listener->wifaces[sockaddr_tls]->attr.cap.flags &
409                UCT_IFACE_FLAG_CB_ASYNC))) {
410             ucp_worker_iface_cleanup(listener->wifaces[sockaddr_tls]);
411             goto err_close_listener_wifaces;
412         }
413 
414         listen_sock = &listener->wifaces[sockaddr_tls]->attr.listen_sockaddr;
415         status = ucs_sockaddr_get_port((struct sockaddr *)listen_sock, &port);
416         if (status != UCS_OK) {
417             goto err_close_listener_wifaces;
418         }
419 
420         sockaddr_tls++;
421         listener->num_rscs = sockaddr_tls;
422         ucs_trace("listener %p: accepting connections on %s on %s",
423                   listener, tl_md->rsc.md_name,
424                   ucs_sockaddr_str(iface_params.mode.sockaddr.listen_sockaddr.addr,
425                                    saddr_str, sizeof(saddr_str)));
426     }
427 
428     if (!sockaddr_tls) {
429         ucs_error("none of the available transports can listen for connections on %s",
430                   ucs_sockaddr_str(params->sockaddr.addr, saddr_str,
431                   sizeof(saddr_str)));
432         listener->num_rscs = 0;
433         status = UCS_ERR_UNREACHABLE;
434         goto err_close_listener_wifaces;
435     }
436 
437     listen_sock = &listener->wifaces[sockaddr_tls - 1]->attr.listen_sockaddr;
438     status = ucs_sockaddr_copy((struct sockaddr *)&listener->sockaddr,
439                                (struct sockaddr *)listen_sock);
440     if (status != UCS_OK) {
441         goto err_close_listener_wifaces;
442     }
443 
444     return UCS_OK;
445 
446 err_close_listener_wifaces:
447     ucp_listener_close_ifaces(listener);
448     return status;
449 }
450 
ucp_listener_create(ucp_worker_h worker,const ucp_listener_params_t * params,ucp_listener_h * listener_p)451 ucs_status_t ucp_listener_create(ucp_worker_h worker,
452                                  const ucp_listener_params_t *params,
453                                  ucp_listener_h *listener_p)
454 {
455     ucp_listener_h listener;
456     ucs_status_t   status;
457 
458     if (!(params->field_mask & UCP_LISTENER_PARAM_FIELD_SOCK_ADDR)) {
459         ucs_error("missing sockaddr for listener");
460         return UCS_ERR_INVALID_PARAM;
461     }
462 
463     UCP_CHECK_PARAM_NON_NULL(params->sockaddr.addr, status, return status);
464 
465     if (ucs_test_all_flags(params->field_mask,
466                            UCP_LISTENER_PARAM_FIELD_ACCEPT_HANDLER |
467                            UCP_LISTENER_PARAM_FIELD_CONN_HANDLER)) {
468         ucs_error("only one accept handler should be provided");
469         return UCS_ERR_INVALID_PARAM;
470     }
471 
472     listener = ucs_calloc(1, sizeof(*listener), "ucp_listener");
473     if (listener == NULL) {
474         ucs_error("cannot allocate memory for UCP listener");
475         return UCS_ERR_NO_MEMORY;
476     }
477 
478     UCS_ASYNC_BLOCK(&worker->async);
479 
480     listener->worker = worker;
481 
482     if (params->field_mask & UCP_LISTENER_PARAM_FIELD_ACCEPT_HANDLER) {
483         UCP_CHECK_PARAM_NON_NULL(params->accept_handler.cb, status,
484                                  goto err_free_listener);
485         listener->accept_cb = params->accept_handler.cb;
486         listener->arg       = params->accept_handler.arg;
487     } else if (params->field_mask & UCP_LISTENER_PARAM_FIELD_CONN_HANDLER) {
488         UCP_CHECK_PARAM_NON_NULL(params->conn_handler.cb, status,
489                                  goto err_free_listener);
490         listener->conn_cb   = params->conn_handler.cb;
491         listener->arg       = params->conn_handler.arg;
492     }
493 
494     if (ucp_worker_sockaddr_is_cm_proto(worker)) {
495         status = ucp_listen_on_cm(listener, params);
496     } else {
497         status = ucp_listen_on_iface(listener, params);
498     }
499 
500     if (status == UCS_OK) {
501         *listener_p = listener;
502         goto out;
503     }
504 
505 err_free_listener:
506     ucs_free(listener);
507 out:
508     UCS_ASYNC_UNBLOCK(&worker->async);
509     return status;
510 }
511 
/**
 * Stop listening and release a listener.
 *
 * @param [in] listener  Listener to destroy.
 */
void ucp_listener_destroy(ucp_listener_h listener)
{
    ucs_trace("listener %p: destroying", listener);

    /* Release the underlying resources according to the way the listener was
     * started: sockaddr interfaces or UCT listeners on connection managers */
    if (!ucp_worker_sockaddr_is_cm_proto(listener->worker)) {
        ucp_listener_close_ifaces(listener);
    } else {
        ucp_listener_close_uct_listeners(listener);
    }

    ucs_free(listener);
}
524 
ucp_listener_reject(ucp_listener_h listener,ucp_conn_request_h conn_request)525 ucs_status_t ucp_listener_reject(ucp_listener_h listener,
526                                  ucp_conn_request_h conn_request)
527 {
528     ucp_worker_h worker = listener->worker;
529 
530     UCS_ASYNC_BLOCK(&worker->async);
531 
532     if (ucp_worker_sockaddr_is_cm_proto(worker)) {
533         uct_listener_reject(conn_request->uct.listener, conn_request->uct_req);
534         ucs_free(conn_request->remote_dev_addr);
535     } else {
536         uct_iface_reject(conn_request->uct.iface, conn_request->uct_req);
537     }
538 
539     UCS_ASYNC_UNBLOCK(&worker->async);
540 
541     ucs_free(conn_request);
542 
543     return UCS_OK;
544 }
545