1 /*
2 ** Copyright (C) 2006-2020 by Carnegie Mellon University.
3 **
4 ** @OPENSOURCE_LICENSE_START@
5 ** See license information in ../../LICENSE.txt
6 ** @OPENSOURCE_LICENSE_END@
7 */
8 
9 /*
10 **  SiLK message functions
11 **
12 */
13 
14 
15 #include <silk/silk.h>
16 
17 RCSIDENT("$SiLK: skmsg.c ef14e54179be 2020-04-14 21:57:45Z mthomas $");
18 
19 #include "intdict.h"
20 #include "multiqueue.h"
21 #include "skmsg.h"
22 #include "libsendrcv.h"
23 #include <silk/skdeque.h>
24 #include <silk/sklog.h>
25 #include <silk/skstream.h>
26 #include <silk/skstringmap.h>
27 #include <silk/utils.h>
28 #include <poll.h>
29 #if SK_ENABLE_GNUTLS
30 #include <gnutls/gnutls.h>
31 #include <gnutls/x509.h>
32 #include <gnutls/pkcs12.h>
33 #endif  /* SK_ENABLE_GNUTLS */
34 
35 /* SENDRCV_DEBUG is defined in libsendrcv.h */
36 #if (SENDRCV_DEBUG) & DEBUG_SKMSG_MUTEX
37 #  define SKTHREAD_DEBUG_MUTEX 1
38 #endif
39 #include <silk/skthread.h>
40 
41 
/* Define Constants */

/* Maximum number of CA certs that can be in the CA cert file */
#define MAX_CA_CERTS 32

/* GnuTLS minimum version.  NOTE: the macro name is misspelled
 * ("VERISON"); it is kept unchanged because other code may refer to
 * it by this name. */
#define MIN_GNUTLS_VERISON      "2.12.0"

/* Maximum GnuTLS logging level accepted by gnutls_global_set_log_level */
#define MAX_TLS_DEBUG_LEVEL     99

/* Keepalive timeout for the control channel */
#define SKMSG_CONTROL_KEEPALIVE_TIMEOUT 60 /* seconds */

/* Time used by connections without keepalive times to determine
 * whether the connection is stagnant (see CONNECTION_STAGNANT). */
#define SKMSG_DEFAULT_STAGNANT_TIMEOUT (2 * SKMSG_CONTROL_KEEPALIVE_TIMEOUT)

/* Time used to determine how often to check to see if the connection
 * is still alive. */
#define SKMSG_MINIMUM_READ_SELECT_TIMEOUT SKMSG_CONTROL_KEEPALIVE_TIMEOUT

/* Read and write sides of control pipes */
#define READ 0
#define WRITE 1

/* Internal error codes; skmerr_strerror() converts them to text.  The
 * negative values are parenthesized so that each macro expands to a
 * single primary expression wherever it is used (CERT PRE02-C). */
#define SKMERR_MEMORY  (-1)
#define SKMERR_PIPE    (-2)
#define SKMERR_MUTEX   (-3)
#define SKMERR_PTHREAD (-4)
#define SKMERR_ERROR   (-5)
#define SKMERR_ERRNO   (-6)
#define SKMERR_CLOSED  (-7)
#define SKMERR_SHORT   (-8)
#define SKMERR_PARTIAL (-9)
#define SKMERR_EMPTY   (-10)
#define SKMERR_GNUTLS  (-11)

/* NOTE(review): presumably the backlog passed to listen() -- confirm
 * at the call site */
#define LISTENQ 5

/* Reserved channel IDs used for control traffic */
#define SKMSG_CTL_CHANNEL_ANNOUNCE  0xFFFE
#define SKMSG_CTL_CHANNEL_REPLY     0xFFFD
#define SKMSG_CTL_CHANNEL_KILL      0xFFFC
#define SKMSG_CTL_CHANNEL_KEEPALIVE 0xFFFA
#define SKMSG_WRITER_UNBLOCKER      0xFFFB

/* Smallest of the reserved control values defined above */
#define SKMSG_MINIMUM_SYSTEM_CTL_CHANNEL 0xFFFA

/* Default security level. */
#define TLS_SECURITY_DEFAULT    "medium"

/* TLS read timeout, in milliseconds */
#define TLS_POLL_TIMEOUT 1000

/* IO thread check timeout, in milliseconds */
#define SKMSG_IO_POLL_TIMEOUT 1000

/* Whether to use the custom tls_pull, tls_push.  May be overridden
 * at compile time by defining SK_TLS_USE_CUSTOM_PULL_PUSH. */
#ifndef SK_TLS_USE_CUSTOM_PULL_PUSH
#define SK_TLS_USE_CUSTOM_PULL_PUSH 0
#endif
103 
104 
/* Define Macros */

/*
 *  WRAP_ERRNO(x);
 *
 *    Execute the statement 'x' while preserving 'errno': its value is
 *    saved before 'x' runs and restored afterward.
 *
 *    This macro is used when the extra-level debugging statements
 *    write to the log, since we do not want the result of the log's
 *    printf() to trash the value of 'errno'.
 */
#define WRAP_ERRNO(x)                                           \
    do { int _saveerrno = errno; x; errno = _saveerrno; }       \
    while (0)

/*
 *  DEBUG_PRINT1(fmt) .. DEBUG_PRINT4(fmt, a, b, c)
 *
 *    Debug logging taking a format string plus up to three values.
 *    When the DEBUG_SKMSG_OTHER bit is clear in SENDRCV_DEBUG (which
 *    is defined in libsendrcv.h) these expand to nothing, so their
 *    arguments are not evaluated.  Otherwise they forward to the
 *    matching SKTHREAD_DEBUG_PRINTn() macro inside WRAP_ERRNO().
 */
#if 0 == ((SENDRCV_DEBUG) & DEBUG_SKMSG_OTHER)
#define DEBUG_PRINT1(x)
#define DEBUG_PRINT2(x, y)
#define DEBUG_PRINT3(x, y, z)
#define DEBUG_PRINT4(x, y, z, q)
#else
#define DEBUG_PRINT1(x)          WRAP_ERRNO(SKTHREAD_DEBUG_PRINT1(x))
#define DEBUG_PRINT2(x, y)       WRAP_ERRNO(SKTHREAD_DEBUG_PRINT2(x, y))
#define DEBUG_PRINT3(x, y, z)    WRAP_ERRNO(SKTHREAD_DEBUG_PRINT3(x, y, z))
#define DEBUG_PRINT4(x, y, z, q) WRAP_ERRNO(SKTHREAD_DEBUG_PRINT4(x, y, z, q))
#endif  /* (SENDRCV_DEBUG) & DEBUG_SKMSG_OTHER */


/*
 *  Logging messages for function entry/exit.
 *
 *    Use DEBUG_ENTER_FUNC at the beginning of every function.  Use
 *    RETURN(x) for functions that return a value, and RETURN_VOID for
 *    functions that have no return value.
 *
 *    SENDRCV_DEBUG is defined in libsendrcv.h.  DEBUG_SKMSG_FN turns
 *    on function entry/exit debug print macros.  When it is off,
 *    DEBUG_ENTER_FUNC expands to nothing and RETURN/RETURN_VOID are
 *    plain 'return' statements.  When it is on, the "Exiting" message
 *    is logged before 'x' is evaluated and returned.
 */
#if 0 == ((SENDRCV_DEBUG) & DEBUG_SKMSG_FN)
#  define DEBUG_ENTER_FUNC
#  define RETURN(x)         return x
#  define RETURN_VOID       return
#else
#  define DEBUG_ENTER_FUNC                                      \
    WRAP_ERRNO(SKTHREAD_DEBUG_PRINT2("Entering %s", __func__))

#  define RETURN(x)                                                     \
    do {                                                                \
        WRAP_ERRNO(SKTHREAD_DEBUG_PRINT2("Exiting %s", __func__));      \
        return x;                                                       \
    } while (0)

#  define RETURN_VOID                                                   \
    do {                                                                \
        WRAP_ERRNO(SKTHREAD_DEBUG_PRINT2("Exiting %s", __func__));      \
        return;                                                         \
    } while (0)

#endif  /* ((SENDRCV_DEBUG) & DEBUG_SKMSG_FN) */
161 
162 
/* MUTEX_* macros are defined in skthread.h */

/* Mutex accessor: address of the mutex stored in the queue's root
 * (sk_msg_root_t.mutex) */
#define QUEUE_MUTEX(q)      (&(q)->root->mutex)

/* Queue lock: lock the root mutex of queue 'q' */
#define QUEUE_LOCK(q)       MUTEX_LOCK(QUEUE_MUTEX(q))

/* Queue unlock: unlock the root mutex of queue 'q' */
#define QUEUE_UNLOCK(q)     MUTEX_UNLOCK(QUEUE_MUTEX(q))

/* Queue wait: wait on condition 'cond' using the root mutex of 'q' */
#define QUEUE_WAIT(cond, q) MUTEX_WAIT(cond, QUEUE_MUTEX(q))

/* Queue locked assert.  pthread_mutex_trylock() returns EBUSY when the
 * mutex is locked by any thread, so this only verifies that *some*
 * thread holds the lock, not necessarily the caller.  If the mutex is
 * unlocked, trylock acquires it and the assert fails.  Like every
 * assert(), this is compiled out when NDEBUG is defined. */
#define ASSERT_QUEUE_LOCK(q)                                    \
    assert(pthread_mutex_trylock(QUEUE_MUTEX(q)) == EBUSY)
180 
181 
/* Any place an XASSERT occurs could be replaced with better error
   handling at some point in time. */
/*
 *  XASSERT(x);
 *
 *    If expression 'x' is false, log a critical message giving the
 *    source file, line, and the text of 'x', then call skAbort().
 *    Unlike assert(), this check is NOT removed when NDEBUG is
 *    defined.
 */
#define XASSERT(x) do {                                                 \
    if (!(x)) {                                                         \
        CRITMSG("Unhandled error at " __FILE__ ":%" PRIu32 " \"" #x "\"", \
                __LINE__);                                              \
        skAbort();                                                      \
    }                                                                   \
    } while (0)
/*
 *  MEM_ASSERT(x);
 *
 *    If expression 'x' (normally an allocation result) is false, log a
 *    critical "Memory allocation error" message naming 'x' and the
 *    source location, then call abort().
 *    NOTE(review): this calls abort() directly while XASSERT() calls
 *    skAbort(); confirm whether the difference is intentional.
 */
#define MEM_ASSERT(x) do {                                              \
    if (!(x)) {                                                         \
        CRITMSG(("Memory allocation error creating \"" #x "\" at "      \
                 __FILE__ ":%" PRIu32), __LINE__);                      \
        abort();                                                        \
    }                                                                   \
    } while (0)
198 
/* Macros to deal with thread creation and destruction.  The thread
 * count and its condition variable live in the root's 'tinfo' member;
 * every macro below except THREAD_INFO_INIT/DESTROY asserts that the
 * queue mutex (QUEUE_LOCK) is held. */
#define QUEUE_TINFO(q) ((q)->root->tinfo)

/* Initialize the thread-tracking state in the root of 'q'.  The result
 * of pthread_cond_init() is not checked. */
#define THREAD_INFO_INIT(q)                             \
    do {                                                \
        pthread_cond_init(&QUEUE_TINFO(q).cond, NULL);  \
        QUEUE_TINFO(q).count = 0;                       \
    } while (0)

/* Destroy the condition variable set up by THREAD_INFO_INIT() */
#define THREAD_INFO_DESTROY(q) pthread_cond_destroy(&QUEUE_TINFO(q).cond)


/*
 *  THREAD_START(name, rv, q, loc, fn, arg);
 *
 *    With the queue mutex held, increment the thread count and call
 *    skthread_create(name, loc, fn, arg), storing its result in 'rv'.
 *    If creation fails ('rv' != 0) the count is decremented again.
 */
#define THREAD_START(name, rv, q, loc, fn, arg)                 \
    do {                                                        \
        DEBUG_PRINT1("THREAD_START");                           \
        ASSERT_QUEUE_LOCK(q);                                   \
        QUEUE_TINFO(q).count++;                                 \
        (rv) = skthread_create(name, (loc), (fn), (arg));       \
        if ((rv) != 0) {                                        \
            QUEUE_TINFO(q).count--;                             \
        }                                                       \
    } while (0)

/*
 *  THREAD_END(q);
 *
 *    With the queue mutex held, record that a thread counted by
 *    THREAD_START has finished: decrement the count and wake every
 *    waiter on the thread-tracking condition variable.  The count is
 *    a uint32_t, so it is printed with PRIu32.
 */
#define THREAD_END(q)                                           \
    do {                                                        \
        DEBUG_PRINT1("THREAD_END");                             \
        ASSERT_QUEUE_LOCK(q);                                   \
        assert(QUEUE_TINFO(q).count != 0);                      \
        QUEUE_TINFO(q).count--;                                 \
        DEBUG_PRINT2("THREAD END COUNT decremented to %" PRIu32, \
                     QUEUE_TINFO(q).count);                     \
        MUTEX_BROADCAST(&QUEUE_TINFO(q).cond);                  \
    } while (0)

/*
 *  THREAD_WAIT_END(q, state);
 *
 *    With the queue mutex held, wait on the thread-tracking condition
 *    variable until the expression 'state' equals SKM_THREAD_ENDED.
 *    'state' is re-evaluated after every wakeup.
 */
#define THREAD_WAIT_END(q, state)                               \
    do {                                                        \
        DEBUG_PRINT1("WAITING FOR THREAD_END");                 \
        ASSERT_QUEUE_LOCK(q);                                   \
        while ((state) != SKM_THREAD_ENDED) {                   \
            QUEUE_WAIT(&QUEUE_TINFO(q).cond, q);                \
        }                                                       \
        DEBUG_PRINT1("FINISHED WAITING FOR THREAD_END");        \
    } while (0)

/*
 *  THREAD_WAIT_ALL_END(q);
 *
 *    With the queue mutex held, wait until the thread count for the
 *    root of 'q' drops to zero.
 */
#define THREAD_WAIT_ALL_END(q)                                  \
    do {                                                        \
        DEBUG_PRINT1("WAITING FOR ALL THREAD_END");             \
        ASSERT_QUEUE_LOCK(q);                                   \
        while (QUEUE_TINFO(q).count != 0) {                     \
            DEBUG_PRINT2("THREAD ALL END WAIT COUNT == %" PRIu32, \
                         QUEUE_TINFO(q).count);                 \
            QUEUE_WAIT(&QUEUE_TINFO(q).cond, q);                \
        }                                                       \
        DEBUG_PRINT1("FINISHED WAITING FOR ALL THREAD_END");    \
    } while (0)
254 
255 
/*
 *  is_stagnant = CONNECTION_STAGNANT(conn, t);
 *
 *    Compare current time 't' with time the last message was received
 *    for connection 'conn' (conn->last_recv); return 1 if the time
 *    difference is larger than the timeout or 0 otherwise.  The
 *    timeout is twice conn->keepalive when keepalive is non-zero, and
 *    SKMSG_DEFAULT_STAGNANT_TIMEOUT otherwise.
 */
#define CONNECTION_STAGNANT(conn, t)                                    \
    (difftime(t, (conn)->last_recv) > ((conn)->keepalive                \
                                       ? (2.0 * (conn)->keepalive)      \
                                       : SKMSG_DEFAULT_STAGNANT_TIMEOUT))

/* Return a string naming an error event reported by poll() in the
 * revents mask 'ev_'.  Checks POLLHUP, then POLLERR, then POLLNVAL,
 * and yields "" if none of those bits is set. */
#define SK_POLL_EVENT_STR(ev_)                          \
    (((ev_) & POLLHUP) ? "POLLHUP"                      \
     : (((ev_) & POLLERR) ? "POLLERR"                   \
        : (((ev_) & POLLNVAL) ? "POLLNVAL" : "")))
273 
274 
/*
 *  ASSERT_RESULT(ar_func_args, ar_type, ar_expected);
 *
 *    ar_func_args  -- a function call, with its arguments, to evaluate
 *    ar_type       -- the type returned by ar_func_args
 *    ar_expected   -- the value ar_func_args is expected to return
 *
 *    'ar_func_args' is always evaluated exactly once, whether or not
 *    assertions are enabled, so it is safe to wrap calls with side
 *    effects.  With assertions enabled, the result is captured and
 *    assert()ed to equal 'ar_expected'.  With NDEBUG defined, the call
 *    is made and its result discarded.
 */
#ifndef NDEBUG
/* asserts are enabled; capture the result and check it */
#define ASSERT_RESULT(ar_func_args, ar_type, ar_expected)       \
    do {                                                        \
        ar_type ar_rv = (ar_func_args);                         \
        assert(ar_rv == (ar_expected));                         \
    } while(0)
#else  /* NDEBUG */
/* asserts are disabled; just make the call */
#define ASSERT_RESULT(ar_func_args, ar_type, ar_expected)  ar_func_args
#endif  /* #else of #ifndef NDEBUG */
297 
298 
299 
/*** Local types ***/

/* The type of message headers; each sk_msg_t begins with one */
typedef struct sk_msg_hdr_st {
    /* channel the message is associated with */
    skm_channel_t channel;
    /* message type */
    skm_type_t    type;
    /* presumably the payload length in bytes -- TODO confirm against
     * the code that fills it in */
    skm_len_t     size;
} sk_msg_hdr_t;

/* The type of messages */
/* typedef struct sk_msg_st sk_msg_t;  // from skmsg.h */
struct sk_msg_st {
    /* message header */
    sk_msg_hdr_t hdr;
    /* releases the segment data; takes a segment count and the
     * segment array (msg_simple_free() has this signature) */
    void       (*free_fn)(uint16_t, struct iovec *);
    /* single-pointer free function; NOTE(review): presumably the
     * caller-supplied free_fn given to send_message() -- confirm */
    void       (*simple_free)(void *);
    /* presumably the number of valid entries in 'segment' */
    uint16_t     segments;
    /* payload segments.  Declared with one element; presumably the
     * struct is over-allocated to hold 'segments' entries, so do not
     * change this declaration without updating the allocation sites */
    struct iovec segment[1];
};
318 
/* Buffer for reading an sk_msg_t; used to support partial reads.  Each
 * connection holds one (sk_msg_conn_queue_t.msg_read_buf). */
typedef struct sk_msg_read_buf_st {
    /* The message being read */
    sk_msg_t *msg;
    /* Location to read into */
    uint8_t  *loc;
    /* Number of bytes still to read */
    uint16_t  count;
} sk_msg_read_buf_t;

/* Buffer for writing an sk_msg_t; used to support partial writes */
typedef struct sk_msg_write_buf_st {
    /* The message being written */
    sk_msg_t   *msg;
    /* Number of bytes of the message remaining to write */
    ssize_t     msg_size;
    /* The index of the iov segment currently being sent */
    uint16_t    cur_seg;
    /* Number of bytes of the current segment that have been sent */
    uint16_t    seg_offset;
} sk_msg_write_buf_t;
340 
341 
/* Allow tracking of thread entrance and exits.  One instance lives in
 * the root (sk_msg_root_t.tinfo) and is maintained by the THREAD_*
 * macros while the queue mutex is held. */
typedef struct sk_thread_info_st {
    /* broadcast by THREAD_END; waited on by THREAD_WAIT_END and
     * THREAD_WAIT_ALL_END */
    pthread_cond_t  cond;
    /* threads counted by THREAD_START that have not yet called
     * THREAD_END */
    uint32_t        count;
} sk_thread_info_t;

/* Lifecycle state of a listener, reader, or writer thread.
 * THREAD_WAIT_END waits for such a state to reach SKM_THREAD_ENDED. */
typedef enum {
    SKM_THREAD_BEFORE,
    SKM_THREAD_RUNNING,
    SKM_THREAD_SHUTTING_DOWN,
    SKM_THREAD_ENDED
} sk_thread_state_t;
354 
/* The type of a message queue root: state shared by every
 * sk_msg_queue_t whose 'root' member points to it */
typedef struct sk_msg_root_st {
    /* Global mutex for message queue (see QUEUE_MUTEX / QUEUE_LOCK) */
    pthread_mutex_t     mutex;

    /* The next channel number to try to allocate */
    skm_channel_t       next_channel;

    /* Information about active threads (see QUEUE_TINFO) */
    sk_thread_info_t    tinfo;

    /* Map of channel-id to channels (searched by find_channel()) */
    int_dict_t         *channel;
    /* Map of read socket to connections */
    int_dict_t         *connection;
    /* Map of channel-ids to queues */
    int_dict_t         *groups;

    /* The information for binding and listening for connections */
    struct pollfd      *pfd;
    /* presumably the number of entries in 'pfd' */
    nfds_t              pfd_len;
    /* listener thread handle */
    pthread_t           listener;
    /* The listener state */
    sk_thread_state_t   listener_state;
    /* Listener condition variable */
    pthread_cond_t      listener_cond;

    /* NOTE(review): role not visible in this part of the file */
    sk_msg_queue_t     *shutdownqueue;

#if SK_ENABLE_GNUTLS
    /* Auth/Encryption credentials */
    gnutls_certificate_credentials_t cred;
#endif

    /* presumably set once shutdown of the root has begun -- confirm */
    unsigned            shuttingdown: 1;
    /* whether GnuTLS credentials have been set */
    unsigned            cred_set: 1;
    /* whether this connection uses TLS */
    unsigned            bind_tls: 1;
} sk_msg_root_t;
395 
396 
/* The type of a message queue */
/* typedef struct sk_msg_queue_st sk_msg_queue_t; */
struct sk_msg_queue_st {
    /* shared root state (mutex, channel maps, thread info) */
    sk_msg_root_t      *root;
    /* Map of channels */
    int_dict_t         *channel;
    /* Queue group for all channels */
    mq_multi_t         *group;
    /* NOTE(review): presumably signalled as this queue shuts down --
     * confirm */
    pthread_cond_t      shutdowncond;
    /* presumably set while this queue is shutting down -- confirm */
    unsigned            shuttingdown: 1;
};
408 
409 
/* States for connections and channels */
typedef enum {
    SKM_CREATED,
    SKM_CONNECTING,
    SKM_CONNECTED,
    SKM_CLOSED
} sk_msg_state_t;

/* Kind of send; passed as 'send_type' to send_message() and
 * send_message_internal() */
typedef enum {
    SKM_SEND_INTERNAL,
    SKM_SEND_REMOTE,
    SKM_SEND_CONTROL
} sk_send_type_t;

/* The type of connection to use */
typedef enum {
    CONN_TCP,
    CONN_TLS
} skm_conn_t;

/* TLS connections: presumably the TLS role of an endpoint (none,
 * client, or server) */
typedef enum {
    SKM_TLS_NONE,
    SKM_TLS_CLIENT,
    SKM_TLS_SERVER
} skm_tls_type_t;
437 
438 
/* Forward declarations; these structures reference each other */
typedef struct sk_msg_channel_queue_st sk_msg_channel_queue_t;
typedef struct sk_msg_conn_queue_st sk_msg_conn_queue_t;

/* Represents a connected socket or pipe */
/* typedef struct sk_msg_conn_queue_st sk_msg_conn_queue_t; */
struct sk_msg_conn_queue_st {
    /* Read socket */
    int                     rsocket;
    /* Write socket */
    int                     wsocket;

    /* Address of connection */
    struct sockaddr        *addr;
    /* Length of address of connection */
    socklen_t               addrlen;

    /* Transport type */
    skm_conn_t              transport;

    /* Channel map */
    int_dict_t             *channelmap;
    /* Channel refcount */
    uint16_t                refcount;

    /* Current state of connection */
    sk_msg_state_t          state;

    /* Outgoing write queue */
    skDeque_t               queue;
    /* Writer thread handle */
    pthread_t               writer;
    /* State */
    sk_thread_state_t       writer_state;
    /* Condition variable for thread */
    pthread_cond_t          writer_cond;

    /* The reader thread */
    pthread_t               reader;
    /* The reader state */
    sk_thread_state_t       reader_state;
    /* Reader condition variable */
    pthread_cond_t          reader_cond;

    /* Most recent error from read/write; paired with SKMERR_ERRNO and
     * SKMERR_GNUTLS return values (skmerr_strerror() reads it) */
    int                     last_errnum;

    /* Keepalive timeout; 0 makes CONNECTION_STAGNANT fall back to
     * SKMSG_DEFAULT_STAGNANT_TIMEOUT */
    uint16_t                keepalive;
    /* Time of last received data; compared by CONNECTION_STAGNANT */
    time_t                  last_recv;

    /* Buffer to support partial read of incoming messages */
    sk_msg_read_buf_t       msg_read_buf;

    /* Pre-connected initial channel */
    sk_msg_channel_queue_t *first_channel;

#if SK_ENABLE_GNUTLS
    gnutls_session_t        session;
    unsigned                use_tls : 1;
#endif  /* SK_ENABLE_GNUTLS */
};


/* Represents a channel */
/* typedef struct sk_msg_channel_queue_st sk_msg_channel_queue_t; */
struct sk_msg_channel_queue_st {
    /* channel's queue */
    mq_queue_t            *queue;
    /* local channel ID */
    skm_channel_t          channel;
    /* remote channel ID */
    skm_channel_t          rchannel;
    /* channel state */
    sk_msg_state_t         state;
    /* associated connection */
    sk_msg_conn_queue_t   *conn;
    /* group associated with this channel */
    sk_msg_queue_t        *group;
    /* pending condition variable */
    pthread_cond_t         pending;
    /* whether we are waiting for connection */
    unsigned               is_pending: 1;
    /* NOTE(review): presumably set while the channel is being
     * flushed -- confirm */
    unsigned               flushing: 1;
};
526 
/* Used for passing data to a writer thread: the queue and the
 * connection the thread serves */
typedef struct sk_queue_and_conn_st {
    sk_msg_queue_t         *q;
    sk_msg_conn_queue_t    *conn;
} sk_queue_and_conn_t;


/* Used to represent a local/remote channel pair */
typedef struct sk_channel_pair_st {
    /* local channel ID */
    skm_channel_t   lchannel;
    /* remote channel ID */
    skm_channel_t   rchannel;
} sk_channel_pair_t;
539 
540 
541 
/*** Local function prototypes ***/

/* Thread entry points (pthread start-routine signature); defined
 * later in the file */
static void *reader_thread(void *);
static void *writer_thread(void *);
static void *listener_thread(void *);

/* Tear down connection 'conn' belonging to queue 'q'.  NOTE(review):
 * return-value convention not visible here -- confirm at definition */
static int
destroy_connection(
    sk_msg_queue_t         *q,
    sk_msg_conn_queue_t    *conn);

/* Send 'length' bytes of 'message' with type 'type' on local channel
 * 'lchannel' of 'q'.  'no_copy' and 'free_fn' presumably control
 * whether the buffer is copied and how it is released -- confirm at
 * definition */
static int
send_message(
    sk_msg_queue_t     *q,
    skm_channel_t       lchannel,
    skm_type_t          type,
    void               *message,
    skm_len_t           length,
    sk_send_type_t      send_type,
    int                 no_copy,
    void              (*free_fn)(void *));

/* Send an already-built message 'msg' on channel 'chan' */
static int
send_message_internal(
    sk_msg_channel_queue_t *chan,
    sk_msg_t               *msg,
    sk_send_type_t          send_type);
569 
570 
571 /*** Local variables ***/
572 
573 #if SK_ENABLE_GNUTLS
574 #if GNUTLS_VERSION_NUMBER < 0x030300
575 #define GNUTLS_SEC_PARAM_MEDIUM GNUTLS_SEC_PARAM_NORMAL
576 #endif
577 
578 /* Mutex to control access to GnuTLS global state */
579 static pthread_mutex_t sk_msg_gnutls_mutex = PTHREAD_MUTEX_INITIALIZER;
580 static int sk_msg_gnutls_initialized = 0;
581 
582 /* Local handle to the name of environment variable holding the
583  * password for the PKCS12 file. */
584 static const char *password_env_name = NULL;
585 
586 /* Encryption and authentication files */
587 static const char *tls_ca_file     = NULL;
588 static const char *tls_cert_file   = NULL;
589 static const char *tls_key_file    = NULL;
590 static const char *tls_pkcs12_file = NULL;
591 static const char *tls_crl_file    = NULL;
592 
593 /* GnuTLS debugging log level */
594 static int tls_debug_level = 0;
595 
596 /* The priority string the user provides for GnuTLS priority, used to
597  * fill tls_priority_cache. */
598 static const char *tls_priority = NULL;
599 
600 /* GnuTLS priority cache passed to gnutls_priority_set(), based on
601  * tls_priority.  If NULL, the default priority is used. */
602 static gnutls_priority_t tls_priority_cache = NULL;
603 
604 /* The level the user provides for the GnuTLS security parameter
605  * level.  If not provided, defaults to TLS_SECURITY_DEFAULT. */
606 static const char *tls_security = NULL;
607 
608 /* The GnuTLS security parameter level based on 'tls_security'. */
609 static gnutls_sec_param_t tls_security_param;
610 
611 /* Potential values for tls_security */
612 static const sk_stringmap_entry_t tls_security_levels[] = {
613     {"low",             GNUTLS_SEC_PARAM_LOW,       NULL,   NULL},
614     {"medium",          GNUTLS_SEC_PARAM_MEDIUM,    NULL,   NULL},
615     {"high",            GNUTLS_SEC_PARAM_HIGH,      NULL,   NULL},
616     {"ultra",           GNUTLS_SEC_PARAM_ULTRA,     NULL,   NULL},
617     SK_STRINGMAP_SENTINEL
618 };
619 
620 typedef enum {
621     OPT_TLS_CA,
622     OPT_TLS_CERT,
623     OPT_TLS_KEY,
624     OPT_TLS_PKCS12,
625     OPT_TLS_SECURITY,
626     OPT_TLS_PRIORITY,
627     OPT_TLS_CRL,
628     OPT_TLS_DEBUG_LEVEL
629 } sk_tls_options_enum;
630 
631 /* Command line switches for encryption/authentication files */
632 static struct option sk_tls_options[] = {
633     {"tls-ca",          REQUIRED_ARG, 0, OPT_TLS_CA},
634     {"tls-cert",        REQUIRED_ARG, 0, OPT_TLS_CERT},
635     {"tls-key",         REQUIRED_ARG, 0, OPT_TLS_KEY},
636     {"tls-pkcs12",      REQUIRED_ARG, 0, OPT_TLS_PKCS12},
637     {"tls-security",    REQUIRED_ARG, 0, OPT_TLS_SECURITY},
638     {"tls-priority",    REQUIRED_ARG, 0, OPT_TLS_PRIORITY},
639     {"tls-crl",         REQUIRED_ARG, 0, OPT_TLS_CRL},
640     {"tls-debug-level", REQUIRED_ARG, 0, OPT_TLS_DEBUG_LEVEL},
641     {0, 0, 0, 0}        /* sentinel */
642 };
643 
644 /* Usage text for each command line switch */
645 static const char *sk_tls_options_help[] = {
646     ("Load the Certificate Authority from the file in PEM format\n"
647      "\tlocated at this complete path. Def. None. Either --tls-key\n"
648      "\tand --tls-key or --tls-pkcs12 must also be specified"),
649     ("Load the encryption cert from the file in PEM format\n"
650      "\tlocated at this complete path. Def. None.  Requires that --tls-ca\n"
651      "\tand --tls-key are also specified"),
652     ("Load the encryption key from the file in PEM format\n"
653      "\tlocated at this complete path. Def. None. Requires that --tls-ca\n"
654      "\tand --tls-cert are also specified"),
655     ("Load the encryption cert and key from the file in\n"
656      "\tPKCS#12 format located at this complete path. Def. None. Requires\n"
657      "\tthat --tls-ca is also specified"),
658     ("Specify the security level to use when the required\n"
659      "\tfile options are provided. Def. '" TLS_SECURITY_DEFAULT "'.\n"
660      "\tChoices:"),
661     ("Specify the priorities for ciphers, key exchange\n"
662      "\tmethods, message authentication codes, and compression methods.\n"
663      "\tSee the GnuTLS documentation of \"Priority Strings\" for the details\n"
664      "\tabout the format of the argument. Def. 'NORMAL'"),
665     ("Load the Certificate Revocation List from the file in PEM\n"
666      "\tformat located at this complete path. Def. None"),
667     ("Set TLS debugging level to the specified value.\n"
668      "\tDef. 0. Range: 0-"),
669     (char *)NULL
670 };
671 
672 
673 #endif /* SK_ENABLE_GNUTLS */
674 
675 /* Type of connections to use.  Set to TLS when the required files are
676  * specified on the command line. */
677 static skm_conn_t connection_type = CONN_TCP;
678 
679 
680 /* Utility functions */
681 
682 
683 static sk_msg_channel_queue_t *
find_channel(sk_msg_queue_t * q,skm_channel_t channel)684 find_channel(
685     sk_msg_queue_t     *q,
686     skm_channel_t       channel)
687 {
688     sk_msg_channel_queue_t **chan;
689 
690     DEBUG_ENTER_FUNC;
691 
692     chan = (sk_msg_channel_queue_t **)int_dict_get(q->root->channel,
693                                                    channel, NULL);
694 
695     RETURN(chan ? *chan : NULL);
696 }
697 
698 
/*
 *  msg_simple_free(n, iov);
 *
 *    Call free() on iov_base of each of the first 'n' entries of
 *    'iov', working from the last entry back to the first.  The iovec
 *    array itself is not released.  When 'n' is 0, 'iov' is not
 *    touched.
 */
static void
msg_simple_free(
    uint16_t            n,
    struct iovec       *iov)
{
    uint16_t remaining;

    for (remaining = n; remaining > 0; --remaining) {
        free(iov[remaining - 1].iov_base);
    }
}
709 
710 
711 static void
sk_destroy_report_message(void * vmsg)712 sk_destroy_report_message(
713     void               *vmsg)
714 {
715     sk_msg_t *msg = (sk_msg_t *)vmsg;
716 
717     DEBUG_ENTER_FUNC;
718 
719     DEBUG_PRINT3("Queue (destroy): chan=%#x type=%#x",
720                  msg->hdr.channel, msg->hdr.type);
721 
722     skMsgDestroy(msg);
723 
724     RETURN_VOID;
725 }
726 
727 
728 static void
set_nonblock(int fd)729 set_nonblock(
730     int                 fd)
731 {
732     int flags, rv;
733 
734     DEBUG_ENTER_FUNC;
735 
736     flags = fcntl(fd, F_GETFL, 0);
737     XASSERT(flags != -1);
738     flags |= O_NONBLOCK;
739     rv = fcntl(fd, F_SETFL, flags);
740     XASSERT(rv != -1);
741 
742     RETURN_VOID;
743 }
744 
745 
746 static const char *
skmerr_strerror(const sk_msg_conn_queue_t * conn,int retval)747 skmerr_strerror(
748     const sk_msg_conn_queue_t  *conn,
749     int                         retval)
750 {
751     static char buf[256];
752 
753     switch (retval) {
754       case SKMERR_MEMORY:
755         return "Memory allocation failure";
756       case SKMERR_PIPE:
757         return "Failed to create pipe";
758       case SKMERR_MUTEX:
759         return "Failed to initialize pthread mutex or condition variable";
760       case SKMERR_PTHREAD:
761         return "Error with pthread";
762       case SKMERR_ERROR:
763         return "Generic error";
764       case SKMERR_ERRNO:
765         if (conn) {
766             return strerror(conn->last_errnum);
767         }
768         return strerror(errno);
769       case SKMERR_CLOSED:
770         return "Connection is closed";
771       case SKMERR_SHORT:
772         return "Short read or write (fail)";
773       case SKMERR_PARTIAL:
774         return "Partial read or write (will retry)";
775       case SKMERR_EMPTY:
776         return "Empty read (will retry)";
777       case SKMERR_GNUTLS:
778 #if SK_ENABLE_GNUTLS
779         if (conn) {
780             return gnutls_strerror(conn->last_errnum);
781         }
782 #endif  /* SK_ENABLE_GNUTLS */
783         return "GnuTLS error";
784     }
785 
786     snprintf(buf, sizeof(buf), "Unknown SKMERR_ error code value %d", retval);
787     return buf;
788 }
789 
790 
791 
792 #if SK_ENABLE_GNUTLS
793 /*
794  *  file_exists = optionsFileCheck(opt_name, opt_arg);
795  *
796  *    Verify that the file in 'opt_arg' exists and that we have a full
797  *    path to the file.  Verify that the length is shorter than
798  *    PATH_MAX.  If so, return 0; otherwise, print an error that the
799  *    option named by 'opt_name' was bad and return -1.
800  */
801 static int
optionsFileCheck(const char * opt_name,const char * opt_arg)802 optionsFileCheck(
803     const char         *opt_name,
804     const char         *opt_arg)
805 {
806     if (!opt_arg || !opt_arg[0]) {
807         skAppPrintErr("Invalid %s: The argument empty", opt_name);
808         return -1;
809     }
810 
811     if (strlen(opt_arg)+1 >= PATH_MAX) {
812         skAppPrintErr("Invalid %s: Path is too long", opt_name);
813         return -1;
814     }
815 
816     if (!skFileExists(opt_arg)) {
817         skAppPrintErr(("Invalid %s:"
818                        " File '%s' does not exist or is not a regular file"),
819                       opt_name, opt_arg);
820         return -1;
821     }
822 
823     if (opt_arg[0] != '/') {
824         skAppPrintErr(("Invalid %s: Must use complete path"
825                        " ('%s' does not begin with slash)"),
826                       opt_name, opt_arg);
827         return -1;
828     }
829 
830     return 0;
831 }
832 
833 
834 /**
835  *    Handle a single command line argument.
836  */
837 static int
skMsgTlsOptionsHandler(clientData cData,int opt_index,char * opt_arg)838 skMsgTlsOptionsHandler(
839     clientData          cData,
840     int                 opt_index,
841     char               *opt_arg)
842 {
843     uint32_t tmp32;
844     int rv;
845 
846     SK_UNUSED_PARAM(cData);
847 
848 #define SET_FILE_OPTION(variable)                                       \
849     if (variable) {                                                     \
850         skAppPrintErr("Invalid %s: Switch used multiple times",         \
851                       sk_tls_options[opt_index].name);                  \
852     }                                                                   \
853     if (optionsFileCheck(sk_tls_options[opt_index].name, opt_arg)) {    \
854         return 1;                                                       \
855     }                                                                   \
856     variable = opt_arg
857 
858 
859     switch ((sk_tls_options_enum)opt_index) {
860       case OPT_TLS_CA:
861         SET_FILE_OPTION(tls_ca_file);
862         break;
863 
864       case OPT_TLS_CERT:
865         SET_FILE_OPTION(tls_cert_file);
866         break;
867 
868       case OPT_TLS_KEY:
869         SET_FILE_OPTION(tls_key_file);
870         break;
871 
872       case OPT_TLS_PKCS12:
873         SET_FILE_OPTION(tls_pkcs12_file);
874         break;
875 
876       case OPT_TLS_CRL:
877         SET_FILE_OPTION(tls_crl_file);
878         break;
879 
880       case OPT_TLS_PRIORITY:
881         if (tls_priority) {
882             skAppPrintErr("Invalid %s: Switch used multiple times",
883                           sk_tls_options[opt_index].name);
884         }
885         tls_priority = opt_arg;
886         break;
887 
888       case OPT_TLS_SECURITY:
889         if (tls_security) {
890             skAppPrintErr("Invalid %s: Switch used multiple times",
891                           sk_tls_options[opt_index].name);
892         }
893         tls_security = opt_arg;
894         break;
895 
896       case OPT_TLS_DEBUG_LEVEL:
897         rv = skStringParseUint32(&tmp32, opt_arg, 0, MAX_TLS_DEBUG_LEVEL);
898         if (rv) {
899             skAppPrintErr("Invalid %s '%s': %s",
900                           sk_tls_options[opt_index].name, opt_arg,
901                           skStringParseStrerror(rv));
902             return 1;
903         }
904         tls_debug_level = (int)tmp32;
905         break;
906     }
907 
908     return 0;
909 }
910 
911 
912 /* Register the switches and their handler function */
913 int
skMsgTlsOptionsRegister(const char * passwd_env_name)914 skMsgTlsOptionsRegister(
915     const char         *passwd_env_name)
916 {
917     password_env_name = passwd_env_name;
918     return skOptionsRegister(sk_tls_options, &skMsgTlsOptionsHandler, NULL);
919 }
920 
921 /* Print usage for the switches */
922 void
skMsgTlsOptionsUsage(FILE * fh)923 skMsgTlsOptionsUsage(
924     FILE               *fh)
925 {
926     unsigned int i, j;
927 
928     fprintf(fh, "\nTransport encryption switches:\n");
929     for (i = 0; sk_tls_options[i].name; ++i) {
930         fprintf(fh, "--%s %s. %s", sk_tls_options[i].name,
931                 SK_OPTION_HAS_ARG(sk_tls_options[i]),
932                 sk_tls_options_help[i]);
933         switch (sk_tls_options[i].val) {
934 #if GNUTLS_VERSION_NUMBER >= 0x030400
935           case OPT_TLS_PRIORITY:
936             {
937                 const unsigned int flag = GNUTLS_PRIORITY_LIST_INIT_KEYWORDS;
938                 const char *s;
939                 fprintf(fh, ".\n\tValues:");
940                 for (j = 0; (s = gnutls_priority_string_list(j, flag)); ++j) {
941                     if (*s) {
942                         fprintf(fh, "%s %s", ((0 == j) ? "" : ","), s);
943                     }
944                 }
945             }
946             break;
947 #endif  /* GNUTLS_VERSION_NUMBER >= 0x030400 */
948           case OPT_TLS_SECURITY:
949             for (j = 0; tls_security_levels[j].name; ++j) {
950                 fprintf(fh, "%s %s",
951                         ((0 == j) ? "" : ","), tls_security_levels[j].name);
952             }
953             break;
954           case OPT_TLS_DEBUG_LEVEL:
955             fprintf(fh, "%d", MAX_TLS_DEBUG_LEVEL);
956             break;
957           default:
958             break;
959         }
960         fprintf(fh, "\n");
961     }
962 }
963 
964 /* Verify the switches. */
965 int
skMsgTlsOptionsVerify(unsigned int * tls_available)966 skMsgTlsOptionsVerify(
967     unsigned int       *tls_available)
968 {
969     sk_stringmap_t *field_map = NULL;
970     sk_stringmap_entry_t *sm_entry;
971     sk_stringmap_status_t sm_err;
972     int rv = 0;
973 
974     if (NULL == tls_security) {
975         tls_security = TLS_SECURITY_DEFAULT;
976     }
977 
978     /* create a stringmap */
979     if ((sm_err = skStringMapCreate(&field_map))
980         || (sm_err = skStringMapAddEntries(field_map, -1, tls_security_levels)))
981     {
982         skAppPrintErr("Unable to create string map: %s",
983                       skStringMapStrerror(sm_err));
984         rv = 1;
985         goto END;
986     }
987 
988     /* attempt to match */
989     sm_err = skStringMapGetByName(field_map, tls_security, &sm_entry);
990     switch (sm_err) {
991       case SKSTRINGMAP_OK:
992         tls_security_param = (gnutls_sec_param_t)sm_entry->id;
993         break;
994       case SKSTRINGMAP_PARSE_AMBIGUOUS:
995         skAppPrintErr("Invalid %s: Field '%s' is ambiguous",
996                       sk_tls_options[OPT_TLS_SECURITY].name, tls_security);
997         rv = 1;
998         break;
999       case SKSTRINGMAP_PARSE_NO_MATCH:
1000         skAppPrintErr("Invalid %s: Field '%s' is not recognized",
1001                       sk_tls_options[OPT_TLS_SECURITY].name, tls_security);
1002         rv = 1;
1003         break;
1004       default:
1005         skAppPrintErr("Unexpected return value from string-map parser (%d)",
1006                       sm_err);
1007         rv = 1;
1008         break;
1009     }
1010 
1011     if (tls_ca_file || tls_cert_file || tls_key_file || tls_pkcs12_file) {
1012         if (!tls_ca_file) {
1013             skAppPrintErr("A certificate authority file must be specified"
1014                           " with --%s when using encryption",
1015                           sk_tls_options[OPT_TLS_CA].name);
1016             rv = 1;
1017         }
1018         if (0 == ((tls_cert_file && tls_key_file) ^ (!!tls_pkcs12_file))) {
1019             skAppPrintErr("When using encryption, you must specify --%s and "
1020                           "--%s, or just --%s",
1021                           sk_tls_options[OPT_TLS_CERT].name,
1022                           sk_tls_options[OPT_TLS_KEY].name,
1023                           sk_tls_options[OPT_TLS_PKCS12].name);
1024             rv = 1;
1025         }
1026     }
1027 
1028     if (tls_ca_file && 0 == rv) {
1029         connection_type = CONN_TLS;
1030     }
1031     if (tls_available) {
1032         *tls_available = (CONN_TLS == connection_type);
1033     }
1034 
1035   END:
1036     skStringMapDestroy(field_map);
1037     return rv;
1038 }
1039 
1040 #endif  /* SK_ENABLE_GNUTLS */
1041 
1042 
1043 /*** TCP functions ***/
1044 
1045 /*
1046  *    Use standard TCP functions to write (send) a message.
1047  *
1048  *    Either this function or tls_send() is used in the
1049  *    writer_thread() to write a message.  The 'transport' member of
1050  *    the sk_msg_conn_queue_t determines which function is used.
1051  *
1052  *    tcp_recv() is the matching function for receiving a message.
1053  */
1054 static int
tcp_send(sk_msg_conn_queue_t * conn,sk_msg_write_buf_t * write_buf)1055 tcp_send(
1056     sk_msg_conn_queue_t    *conn,
1057     sk_msg_write_buf_t     *write_buf)
1058 {
1059     uint8_t *seg_base = NULL;
1060     sk_msg_t *msg;
1061     int retval = 0;
1062     ssize_t rv;
1063 
1064     DEBUG_ENTER_FUNC;
1065 
1066     assert(write_buf);
1067     msg = write_buf->msg;
1068 
1069     assert(msg);
1070     assert(conn);
1071     assert(msg->segments);
1072     assert(msg->segment[0].iov_base == &msg->hdr);
1073     assert(msg->segment[0].iov_len  == sizeof(msg->hdr));
1074 
1075     DEBUG_PRINT3("Sending chan=%#x type=%#x",
1076                  ntohs(msg->hdr.channel), ntohs(msg->hdr.type));
1077 
1078     if (write_buf->seg_offset) {
1079         /* there is a partially written segment; move iov_base to
1080          * start of unwritten data and adjust length of segment  */
1081         assert(msg->segment[write_buf->cur_seg].iov_len
1082                > write_buf->seg_offset);
1083         seg_base = (uint8_t*)msg->segment[write_buf->cur_seg].iov_base;
1084         msg->segment[write_buf->cur_seg].iov_base
1085             = seg_base + write_buf->seg_offset;
1086         msg->segment[write_buf->cur_seg].iov_len -= write_buf->seg_offset;
1087     }
1088 
1089     /* Write the message */
1090     while ((rv = writev(conn->wsocket, msg->segment + write_buf->cur_seg,
1091                         msg->segments - write_buf->cur_seg))
1092            != write_buf->msg_size)
1093     {
1094         if (rv == -1) {
1095             if (errno == EINTR) {
1096                 continue;
1097             }
1098             if (errno == EAGAIN) {
1099                 DEBUG_PRINT1("send: writev returned EAGAIN");
1100                 retval = SKMERR_PARTIAL;
1101                 break;
1102             }
1103             if (errno == EPIPE || errno == ECONNRESET) {
1104                 DEBUG_PRINT3("send: Connection closed due to %d [%s]",
1105                              errno, strerror(errno));
1106                 retval = SKMERR_CLOSED;
1107                 break;
1108             }
1109             conn->last_errnum = errno;
1110             DEBUG_PRINT3("send: System error %d [%s]", errno, strerror(errno));
1111             retval = SKMERR_ERRNO;
1112             break;
1113         }
1114         if (rv == 0) {
1115             DEBUG_PRINT1("send: Connection closed due to write returning 0");
1116             retval = SKMERR_CLOSED;
1117             break;
1118         }
1119         assert(rv < write_buf->msg_size);
1120         DEBUG_PRINT3("send: Handling short write (%" SK_PRIdZ "/%" SK_PRIdZ ")",
1121                      rv, write_buf->msg_size);
1122         retval = SKMERR_PARTIAL;
1123         write_buf->msg_size -= rv;
1124         if (seg_base) {
1125             msg->segment[write_buf->cur_seg].iov_base = seg_base;
1126             seg_base = NULL;
1127             if (msg->segment[write_buf->cur_seg].iov_len > (size_t)rv) {
1128                 /* did not complete writing this segment */
1129                 msg->segment[write_buf->cur_seg].iov_len
1130                     += write_buf->seg_offset;
1131                 write_buf->seg_offset += rv;
1132                 break;
1133             }
1134             /* finished this segment */
1135             rv -= msg->segment[write_buf->cur_seg].iov_len;
1136             msg->segment[write_buf->cur_seg].iov_len += write_buf->seg_offset;
1137             ++write_buf->cur_seg;
1138         }
1139         while (rv > 0
1140                && msg->segment[write_buf->cur_seg].iov_len <= (size_t)rv)
1141         {
1142             rv -= msg->segment[write_buf->cur_seg].iov_len;
1143             ++write_buf->cur_seg;
1144         }
1145         /* partial write for this segment */
1146         write_buf->seg_offset = rv;
1147         break;
1148     }
1149 
1150     if (seg_base) {
1151         msg->segment[write_buf->cur_seg].iov_base = seg_base;
1152         msg->segment[write_buf->cur_seg].iov_len += write_buf->seg_offset;
1153     }
1154 
1155     RETURN(retval);
1156 }
1157 
1158 
1159 /*
1160  *    Use standard TCP functions to read (receive) a message.
1161  *
1162  *    Either this function or tls_recv() is used in the
1163  *    reader_thread() to read a message.  The 'transport' member of
1164  *    the sk_msg_conn_queue_t determines which function is used.
1165  *
1166  *    tcp_send() is the matching function for sending a message.
1167  */
1168 static int
tcp_recv(sk_msg_conn_queue_t * conn,sk_msg_t ** message)1169 tcp_recv(
1170     sk_msg_conn_queue_t    *conn,
1171     sk_msg_t              **message)
1172 {
1173     ssize_t             rv;
1174     sk_msg_read_buf_t  *buffer;
1175     int                 retval;
1176     int                 new_msg;
1177 
1178     DEBUG_ENTER_FUNC;
1179 
1180     assert(message);
1181     assert(conn);
1182 
1183     buffer = &conn->msg_read_buf;
1184     new_msg = (buffer->msg == NULL);
1185     if (new_msg) {
1186         /* Starting to read a new message. */
1187 
1188         sk_msg_hdr_t    *hdr;
1189         sk_msg_t        *msg;
1190         uint8_t         *hdr_buf;
1191         ssize_t          hdr_len;
1192 
1193         /* Create a message structure */
1194         buffer->msg = (sk_msg_t*)malloc(sizeof(sk_msg_t)
1195                                         + sizeof(struct iovec));
1196         msg = buffer->msg;
1197         MEM_ASSERT(msg != NULL);
1198         msg->segments = 1;
1199         msg->simple_free = NULL;
1200         msg->free_fn = msg_simple_free;
1201         hdr = &msg->hdr;
1202         msg->segment[0].iov_base = hdr;
1203         msg->segment[0].iov_len = sizeof(*hdr);
1204         memset(hdr, 0, sizeof(*hdr));
1205 
1206         /* Read a header */
1207         hdr_buf = (uint8_t*)hdr;
1208         hdr_len = (ssize_t)sizeof(*hdr);
1209         while ((rv = read(conn->rsocket, hdr_buf, hdr_len))
1210                != hdr_len)
1211         {
1212             /* Did not get all of the data we expected to get */
1213             if (rv > 0) {
1214                 /* Partial read, reduce number of expected bytes and
1215                  * try again. */
1216                 DEBUG_PRINT3("recv: Partial read of header; trying again"
1217                              " (%" SK_PRIdZ "/%" SK_PRIuZ ")",
1218                              rv, hdr_len);
1219                 hdr_buf += rv;
1220                 hdr_len -= rv;
1221                 continue;
1222             }
1223             if (rv == -1) {
1224                 /* Error.  Close connection unless interrupt. */
1225                 if (errno == EINTR) {
1226                     continue;
1227                 }
1228                 if (errno != EAGAIN) {
1229                     conn->last_errnum = errno;
1230                     DEBUG_PRINT3("recv: System error %d [%s]",
1231                                  errno, strerror(errno));
1232                     retval = SKMERR_ERRNO;
1233                 } else if (sizeof(*hdr) == (size_t)hdr_len) {
1234                     /* Handle EAGAIN on a completely unread header
1235                      * specially. This can happen if poll() says data
1236                      * is available when it actually is not, which can
1237                      * happen on Linux ("spurious readiness") */
1238                     DEBUG_PRINT1("recv: EAGAIN on unread header");
1239                     retval = SKMERR_EMPTY;
1240                 } else {
1241                     /* Handle EAGAIN after we have read part of the
1242                      * header as a short read error. */
1243                     DEBUG_PRINT3("recv: Short read (%" SK_PRIdZ
1244                                  "/%" SK_PRIuZ ") [EAGAIN]",
1245                                  (sizeof(*hdr) - hdr_len), sizeof(*hdr));
1246                     retval = SKMERR_SHORT;
1247                 }
1248             } else if (sizeof(*hdr) == (size_t)hdr_len) {
1249                 /* This read() returned 0, and we do not have any of
1250                  * the header; assume connection is closed. */
1251                 DEBUG_PRINT1("recv: Connection closed due to attempted"
1252                              " read of header returning 0");
1253                 retval = SKMERR_CLOSED;
1254             } else {
1255                 /* This read() returned 0, but we got part of the
1256                  * header on a previous read(); treat as error. */
1257                 DEBUG_PRINT3("recv: Short read (%" SK_PRIdZ "/%" SK_PRIuZ ")",
1258                              (sizeof(*hdr) - hdr_len), sizeof(*hdr));
1259                 retval = SKMERR_SHORT;
1260             }
1261             goto error;
1262         }
1263 
1264         /* Convert network byte order to host byte order */
1265         hdr->channel = ntohs(hdr->channel);
1266         hdr->type    = ntohs(hdr->type);
1267         hdr->size    = ntohs(hdr->size);
1268 
1269         DEBUG_PRINT4("Receiving chan=%#x type=%#x size=%d",
1270                      hdr->channel, hdr->type, hdr->size);
1271 
1272         if (hdr->size == 0) {
1273             /* If the size is zero, the message is complete */
1274             *message = msg;
1275             buffer->msg = NULL;
1276             RETURN(0);
1277         }
1278         /* Allocate space for the body of the message */
1279         msg->segment[1].iov_base = malloc(hdr->size);
1280         MEM_ASSERT(msg->segment[1].iov_base);
1281         msg->segment[1].iov_len = hdr->size;
1282         msg->segments++;
1283 
1284         /* Maintain state for re-entrant call */
1285         buffer->count = hdr->size;
1286         buffer->loc = (uint8_t*)msg->segment[1].iov_base;
1287 
1288         /* Fall through to read the rest of the message.  NOTE: Do not
1289          * fail if the next read() returns 0, since maybe the header
1290          * was the only thing available to read. */
1291     }
1292     /* else, the previous call to recv was a partial read of the
1293      * message. */
1294 
1295     assert(buffer->count);
1296     rv = read(conn->rsocket, buffer->loc, buffer->count);
1297     if (rv == -1) {
1298         if (errno == EINTR || errno == EAGAIN) {
1299             DEBUG_PRINT3("Failed to read %u bytes; return PARTIAL [%s]",
1300                          buffer->count, strerror(errno));
1301             RETURN(SKMERR_PARTIAL);
1302         }
1303         conn->last_errnum = errno;
1304         DEBUG_PRINT3("Failed to read %u bytes; return ERRNO [%s]",
1305                      buffer->count, strerror(errno));
1306         retval = SKMERR_ERRNO;
1307         goto error;
1308     } else if (rv == 0 && !new_msg) {
1309         /* Fail: poll() said data was available, and this first
1310          * attempt to read() data returned 0 bytes. */
1311         DEBUG_PRINT2("Failed to read %u bytes; return CLOSED [EOF]",
1312                      buffer->count);
1313         retval = SKMERR_CLOSED;
1314         goto error;
1315     }
1316 
1317     buffer->count -= rv;
1318     buffer->loc   += rv;
1319 
1320     if (buffer->count != 0) {
1321         /* We've read the available data but the message is not yet
1322          * complete; return to reader_thread() and call poll()
1323          * again */
1324         DEBUG_PRINT2("PARTIAL message, %u bytes remaining", buffer->count);
1325         RETURN(SKMERR_PARTIAL);
1326     }
1327     /* else, message is complete */
1328 
1329     *message = buffer->msg;
1330     buffer->msg = NULL;
1331     RETURN(0);
1332 
1333   error:
1334     if (buffer->msg) {
1335         skMsgDestroy(buffer->msg);
1336         buffer->msg = NULL;
1337     }
1338     RETURN(retval);
1339 }
1340 
1341 
1342 #if SK_ENABLE_GNUTLS
1343 
1344 /*** TLS functions ***/
1345 
1346 /*
1347  *    Use GnuTLS to write (send) a message.
1348  *
1349  *    Either this function or tcp_send() is used in the
1350  *    writer_thread() to write a message.  The 'transport' member of
1351  *    the sk_msg_conn_queue_t determines which function is used.
1352  *
1353  *    tls_recv() is the matching function for receiving a message.
1354  *
1355  *    The call to gnutls_record_send() in this function invokes the
1356  *    tls_push() callback defined below.
1357  */
1358 static int
tls_send(sk_msg_conn_queue_t * conn,sk_msg_write_buf_t * write_buf)1359 tls_send(
1360     sk_msg_conn_queue_t    *conn,
1361     sk_msg_write_buf_t     *write_buf)
1362 {
1363     sk_msg_t *msg;
1364     ssize_t rv;
1365     uint16_t i;
1366 
1367     DEBUG_ENTER_FUNC;
1368 
1369     assert(write_buf);
1370     msg = write_buf->msg;
1371 
1372     assert(msg);
1373     assert(conn);
1374     assert(conn->use_tls);
1375     assert(msg->segments);
1376     assert(msg->segment[0].iov_base == &msg->hdr);
1377     assert(msg->segment[0].iov_len  == sizeof(msg->hdr));
1378 
1379     DEBUG_PRINT3("Sending chan=%#x type=%#x",
1380                  ntohs(msg->hdr.channel), ntohs(msg->hdr.type));
1381 
1382     for (i = 0; i < msg->segments; i++) {
1383         size_t remaining = msg->segment[i].iov_len;
1384         uint8_t *loc = (uint8_t*)msg->segment[i].iov_base;
1385 
1386         /* Write the message */
1387         while (remaining) {
1388             DEBUG_PRINT2("calling gnutls_record_send (%d)",
1389                          (skm_len_t)msg->segment[i].iov_len);
1390             rv = gnutls_record_send(conn->session, loc, remaining);
1391             DEBUG_PRINT2("gnutls_record_send -> %" SK_PRIdZ, rv);
1392             if (rv < 0) {
1393                 if (rv == GNUTLS_E_INTERRUPTED || rv == GNUTLS_E_AGAIN) {
1394                     continue;
1395                 }
1396                 if (rv == GNUTLS_E_PUSH_ERROR) {
1397                     if (errno == EPIPE || errno == ECONNRESET) {
1398                         RETURN(SKMERR_CLOSED);
1399                     }
1400                     conn->last_errnum = errno;
1401                     RETURN(SKMERR_ERRNO);
1402                 }
1403                 conn->last_errnum = rv;
1404                 RETURN(SKMERR_GNUTLS);
1405             } else if (rv == 0) {
1406                 DEBUG_PRINT1("send: Connection closed"
1407                              " due to write returning 0");
1408                 RETURN(SKMERR_CLOSED);
1409             } else {
1410                 remaining -= rv;
1411                 loc += rv;
1412             }
1413         }
1414     }
1415 
1416     RETURN(0);
1417 }
1418 
1419 
1420 /*
1421  *    Use GnuTLS to read (receive) a message.
1422  *
1423  *    Either this function or tcp_recv() is used in the
1424  *    reader_thread() to read a message.  The 'transport' member of
1425  *    the sk_msg_conn_queue_t determines which function is used.
1426  *
1427  *    tls_send() is the matching function for sending a message.
1428  *
1429  *    The call to gnutls_record_recv() in this function invokes the
1430  *    tls_pull() callback defined below.
1431  */
1432 static int
tls_recv(sk_msg_conn_queue_t * conn,sk_msg_t ** message)1433 tls_recv(
1434     sk_msg_conn_queue_t    *conn,
1435     sk_msg_t              **message)
1436 {
1437     ssize_t             rv;
1438     sk_msg_read_buf_t  *buffer;
1439     int                 retval;
1440     int                 new_msg;
1441 
1442     DEBUG_ENTER_FUNC;
1443 
1444     assert(message);
1445     assert(conn);
1446 
1447     buffer = &conn->msg_read_buf;
1448     new_msg = (buffer->msg == NULL);
1449     if (new_msg) {
1450         /* Starting to read a new message */
1451 
1452         sk_msg_hdr_t    *hdr;
1453         sk_msg_t        *msg;
1454         uint8_t         *hdr_buf;
1455         ssize_t          hdr_len;
1456 
1457         /* Create a message structure */
1458         buffer->msg = (sk_msg_t*)malloc(sizeof(sk_msg_t)
1459                                         + sizeof(struct iovec));
1460         msg = buffer->msg;
1461         MEM_ASSERT(msg != NULL);
1462         msg->segments = 1;
1463         msg->simple_free = NULL;
1464         msg->free_fn = msg_simple_free;
1465         hdr = &msg->hdr;
1466         msg->segment[0].iov_base = hdr;
1467         msg->segment[0].iov_len = sizeof(*hdr);
1468         memset(hdr, 0, sizeof(*hdr));
1469 
1470         /* Read a header */
1471         hdr_buf = (uint8_t*)hdr;
1472         hdr_len = (ssize_t)sizeof(*hdr);
1473         DEBUG_PRINT2("calling gnutls_record_recv (%" SK_PRIdZ ")", hdr_len);
1474         while ((rv = gnutls_record_recv(conn->session, hdr_buf, hdr_len))
1475                != hdr_len)
1476         {
1477             DEBUG_PRINT2("gnutls_record_recv -> %" SK_PRIdZ, rv);
1478             if (rv > 0) {
1479                 /* Partial read, reduce number of expected bytes and
1480                  * try again. */
1481                 DEBUG_PRINT3("recv: Partial read of header; trying again"
1482                              " (%" SK_PRIdZ "/%" SK_PRIdZ ")",
1483                              rv, hdr_len);
1484                 hdr_buf += rv;
1485                 hdr_len -= rv;
1486                 DEBUG_PRINT2("calling gnutls_record_recv (%" SK_PRIdZ ")",
1487                              hdr_len);
1488                 continue;
1489             }
1490             if (rv < 0) {
1491                 if (rv == GNUTLS_E_INTERRUPTED || rv == GNUTLS_E_AGAIN) {
1492                     DEBUG_PRINT2("calling gnutls_record_recv (%" SK_PRIdZ ")",
1493                                  hdr_len);
1494                     continue;
1495                 }
1496                 if (rv == GNUTLS_E_PULL_ERROR) {
1497                     conn->last_errnum = errno;
1498                     retval = SKMERR_ERRNO;
1499                 } else {
1500                     /* it seems that GnuTLS 2.x returns
1501                      * GNUTLS_E_UNEXPECTED_PACKET_LENGTH when read()
1502                      * in the tls_pull() function returns 0.  Perhaps
1503                      * we ought to treat that as an ordinary close to
1504                      * avoid an odd error msg in the log file. */
1505                     conn->last_errnum = rv;
1506                     retval = SKMERR_GNUTLS;
1507                 }
1508             } else if (sizeof(*hdr) == (size_t)hdr_len) {
1509                 /* This read() returned 0, and we do not have any of
1510                  * the header; assume connection is closed. */
1511                 DEBUG_PRINT1("recv: Connection closed due to attempted"
1512                              " read of header returning 0");
1513                 retval = SKMERR_CLOSED;
1514             } else {
1515                 /* This read() returned 0, but we got part of the
1516                  * header on a previous read(); treat as error. */
1517                 DEBUG_PRINT3("recv: Short read (%" SK_PRIdZ "/%" SK_PRIuZ ")",
1518                              (sizeof(*hdr) - hdr_len), sizeof(*hdr));
1519                 retval = SKMERR_SHORT;
1520             }
1521             goto error;
1522         }
1523         DEBUG_PRINT2("gnutls_record_recv -> %" SK_PRIdZ, rv);
1524 
1525         /* Convert network byte order to host byte order */
1526         hdr->channel = ntohs(hdr->channel);
1527         hdr->type    = ntohs(hdr->type);
1528         hdr->size    = ntohs(hdr->size);
1529 
1530         DEBUG_PRINT4("Receiving chan=%#x type=%#x size=%d",
1531                      hdr->channel, hdr->type, hdr->size);
1532 
1533         if (hdr->size == 0) {
1534             /* If the size is zero, the message is complete */
1535             *message = msg;
1536             buffer->msg = NULL;
1537             RETURN(0);
1538         }
1539         /* Allocate space for the body of the message */
1540         msg->segment[1].iov_base = malloc(hdr->size);
1541         MEM_ASSERT(msg->segment[1].iov_base);
1542         msg->segment[1].iov_len = hdr->size;
1543         msg->segments++;
1544 
1545         /* Maintain state for re-entrant call */
1546         buffer->count = hdr->size;
1547         buffer->loc = (uint8_t*)msg->segment[1].iov_base;
1548 
1549         /* Fall through to read the rest of the message.  NOTE: Do not
1550          * fail if the next gnutls_record_recv() returns 0, since
1551          * maybe the header was the only thing available. */
1552     }
1553     /* else, the previous call to recv was a partial read of the
1554      * message. */
1555 
1556     assert(buffer->count);
1557     DEBUG_PRINT2("calling gnutls_record_recv (%d)", buffer->count);
1558     while ((rv = gnutls_record_recv(conn->session, buffer->loc, buffer->count))
1559            < 0)
1560     {
1561         DEBUG_PRINT2("gnutls_record_recv -> %" SK_PRIdZ, rv);
1562         if (rv == GNUTLS_E_INTERRUPTED || rv == GNUTLS_E_AGAIN) {
1563             DEBUG_PRINT2("calling gnutls_record_recv (%d)", buffer->count);
1564             continue;
1565         }
1566         if (rv == GNUTLS_E_PULL_ERROR) {
1567             conn->last_errnum = errno;
1568             retval = SKMERR_ERRNO;
1569             DEBUG_PRINT3("read failure: %d [%s]", errno, strerror(errno));
1570         } else {
1571             conn->last_errnum = rv;
1572             retval = SKMERR_GNUTLS;
1573             DEBUG_PRINT2("read failure: [%s]", gnutls_strerror(rv));
1574         }
1575         goto error;
1576     }
1577     DEBUG_PRINT2("gnutls_record_recv -> %" SK_PRIdZ, rv);
1578 
1579     if (rv == 0 && !new_msg) {
1580         /* Fail: poll() said data was available, but this first
1581          * attempt to use gnutls_record_recv() returned 0 bytes. */
1582         DEBUG_PRINT1("read ended: [EOF]");
1583         retval = SKMERR_CLOSED;
1584         goto error;
1585     }
1586 
1587     buffer->count -= rv;
1588     buffer->loc   += rv;
1589 
1590     if (buffer->count != 0) {
1591         /* we've read the available data but the message is not yet
1592          * complete; return to reader_thread() and call poll()
1593          * again */
1594         RETURN(SKMERR_PARTIAL);
1595     }
1596     /* else, message is complete */
1597 
1598     *message = buffer->msg;
1599     buffer->msg = NULL;
1600     RETURN(0);
1601 
1602   error:
1603     if (buffer->msg) {
1604         skMsgDestroy(buffer->msg);
1605         buffer->msg = NULL;
1606     }
1607     RETURN(retval);
1608 }
1609 
1610 #endif /* SK_ENABLE_GNUTLS */
1611 
1612 
1613 #if SK_ENABLE_GNUTLS
1614 /*
1615  *    Since we cannot be certain that GnuTLS was built with pthread
1616  *    support (hello redhat), define our own functions that are copies
1617  *    of the functions gnutls_system_mutex_init(), etc, with
1618  *    pthread-support found in gnutls-3.x/lib/system/threads.c.
1619  */
1620 
1621 /*    Initialize a mutex. */
1622 static int
skMsgGnuTLSMutexInit(void ** priv)1623 skMsgGnuTLSMutexInit(
1624     void             **priv)
1625 {
1626     pthread_mutex_t *lock = (pthread_mutex_t *)malloc(sizeof(pthread_mutex_t));
1627 
1628     *priv = lock;
1629     if (NULL == lock) {
1630         return GNUTLS_E_MEMORY_ERROR;
1631     }
1632     if (pthread_mutex_init(lock, NULL)) {
1633         free(lock);
1634         *priv = NULL;
1635         return GNUTLS_E_LOCKING_ERROR;
1636     }
1637     return 0;
1638 }
1639 
/*
 *    Mutex-destruction callback handed to gnutls_global_set_mutex().
 *
 *    Destroys and frees the pthread mutex stored in '*priv', then
 *    sets '*priv' to NULL.  Always returns 0.
 */
static int
skMsgGnuTLSMutexDeinit(
    void             **priv)
{
    pthread_mutex_t *mutex = (pthread_mutex_t *)*priv;

    pthread_mutex_destroy(mutex);
    free(mutex);
    *priv = NULL;
    return 0;
}
1650 
1651 /*    Lock a mutex. */
1652 static int
skMsgGnuTLSMutexLock(void ** priv)1653 skMsgGnuTLSMutexLock(
1654     void             **priv)
1655 {
1656     if (pthread_mutex_lock((pthread_mutex_t *)*priv)) {
1657         return GNUTLS_E_LOCKING_ERROR;
1658     }
1659     return 0;
1660 }
1661 
1662 /*    Unlock a mutex. */
1663 static int
skMsgGnuTLSMutexUnlock(void ** priv)1664 skMsgGnuTLSMutexUnlock(
1665     void             **priv)
1666 {
1667     if (pthread_mutex_unlock((pthread_mutex_t *)*priv)) {
1668         return GNUTLS_E_LOCKING_ERROR;
1669     }
1670     return 0;
1671 }
1672 #endif  /* SK_ENABLE_GNUTLS */
1673 
1674 
1675 /***********************************************************************/
1676 
1677 
1678 #if SK_ENABLE_GNUTLS
1679 
/*
 *    Callback for GnuTLS debugging messages.
 *
 *    Writes 'msg' to the log at INFO level, tagged with the GnuTLS
 *    log 'level' that produced it.
 */
static void
skMsgGnuTLSDebugLog(
    int                 level,
    const char         *msg)
{
    const char *text = msg;

    INFOMSG("GnuTLS[%d] %s", level, text);
}
1688 
1689 #if GNUTLS_VERSION_NUMBER >= 0x030000
1690 /* Callback function used for audit messages */
1691 static void
skMsgGnuTLSAuditLog(gnutls_session_t session,const char * msg)1692 skMsgGnuTLSAuditLog(
1693     gnutls_session_t    session,
1694     const char         *msg)
1695 {
1696     SK_UNUSED_PARAM(session);
1697     NOTICEMSG("GnuTLS audit: %s", msg);
1698 }
1699 #endif  /* GNUTLS_VERSION_NUMBER >= 0x030000 */
1700 
1701 #if GNUTLS_VERSION_NUMBER < 0x030406
1702 /* Callback function to verify the peer's certificate during the TLS
1703  * handshake. */
1704 static int
skMsgGnuTLSVerifyPeer(gnutls_session_t session)1705 skMsgGnuTLSVerifyPeer(
1706     gnutls_session_t    session)
1707 {
1708     int rv;
1709     unsigned int status;
1710     char reasonbuf[256];
1711     const char *reason;
1712 
1713     DEBUG_ENTER_FUNC;
1714 
1715     status = 0;
1716     rv = gnutls_certificate_verify_peers2(session, &status);
1717     if (0 == rv && 0 == status) {
1718         RETURN(0);
1719     }
1720     if (rv < 0) {
1721         NOTICEMSG("Failed to verify peer's certificate: %s",
1722                   gnutls_strerror(rv));
1723         RETURN(GNUTLS_E_CERTIFICATE_ERROR);
1724     }
1725 
1726     if (status & GNUTLS_CERT_REVOKED) {
1727         reason = "Certificate is revoked by its authority.";
1728     } else if (status & GNUTLS_CERT_SIGNER_NOT_FOUND) {
1729         reason = "Certificate's issuer is not known.";
1730     } else if (status & GNUTLS_CERT_SIGNER_NOT_CA) {
1731         reason = "Certificate's signer is not a CA.";
1732     } else if (status & GNUTLS_CERT_INSECURE_ALGORITHM) {
1733         reason = "Certificate is signed using an insecure algorithm";
1734     } else if (status & GNUTLS_CERT_NOT_ACTIVATED) {
1735         reason = "Certificate is not yet activated.";
1736     } else if (status & GNUTLS_CERT_EXPIRED) {
1737         reason = "Certificate has expired.";
1738     } else if (status & GNUTLS_CERT_INVALID) {
1739         reason = ("Certificate is not signed by a known authority"
1740                   " or the signature is invalid");
1741     } else {
1742         snprintf(reasonbuf, sizeof(reasonbuf), "Other reason [%#x]",
1743                  status);
1744         reason = reasonbuf;
1745     }
1746 
1747     NOTICEMSG("Certificate verification failed: %s", reason);
1748     RETURN(GNUTLS_E_CERTIFICATE_ERROR);
1749 }
1750 #endif  /* GNUTLS_VERSION_NUMBER < 0x030406 */
1751 
1752 /*
1753  *    If authentication and encryption files were provided, read and
1754  *    load them into the message queue's credentials.
1755  *
1756  *    Do global initialization of GnuTLS if it has not occurred yet.
1757  */
1758 static int
skMsgQueueInitializeGnuTLS(sk_msg_queue_t * queue)1759 skMsgQueueInitializeGnuTLS(
1760     sk_msg_queue_t     *queue)
1761 {
1762 #if GNUTLS_VERSION_NUMBER < 0x030506
1763     static gnutls_dh_params_t dh_params;
1764     unsigned int bits;
1765 #endif  /* GNUTLS_VERSION_NUMBER < 0x030506 */
1766     int rv = 0;
1767 
1768     DEBUG_ENTER_FUNC;
1769 
1770     assert(queue);
1771     assert(queue->root);
1772 
1773     pthread_mutex_lock(&sk_msg_gnutls_mutex);
1774     if (NULL == tls_ca_file) {
1775         assert(NULL == tls_cert_file);
1776         assert(NULL == tls_key_file);
1777         assert(NULL == tls_pkcs12_file);
1778         DEBUGMSG("Skipping GnuTLS initialization since not in use");
1779         goto END;
1780     }
1781     assert((tls_cert_file && tls_key_file) ^ (!!tls_pkcs12_file));
1782 
1783     if (NULL == gnutls_check_version(MIN_GNUTLS_VERISON)) {
1784         CRITMSG("GnuTLS version is less than required %s", MIN_GNUTLS_VERISON);
1785         goto END;
1786     }
1787 
1788     if (!sk_msg_gnutls_initialized) {
1789         gnutls_global_set_mutex(skMsgGnuTLSMutexInit, skMsgGnuTLSMutexDeinit,
1790                                 skMsgGnuTLSMutexLock, skMsgGnuTLSMutexUnlock);
1791         rv = gnutls_global_init();
1792         if (rv < 0) {
1793             WARNINGMSG("Unable to initialize gnutls: %s", gnutls_strerror(rv));
1794             goto END;
1795         }
1796         if (tls_debug_level) {
1797             gnutls_global_set_log_function(skMsgGnuTLSDebugLog);
1798             gnutls_global_set_log_level(tls_debug_level);
1799         }
1800         if (tls_priority) {
1801             const char *err_pos = NULL;
1802             rv = gnutls_priority_init(&tls_priority_cache, tls_priority,
1803                                       &err_pos);
1804             if (rv != GNUTLS_E_SUCCESS) {
1805                 if (rv != GNUTLS_E_INVALID_REQUEST || NULL == err_pos) {
1806                     CRITMSG("Unable to initialize gnutls priority: %s",
1807                             gnutls_strerror(rv));
1808                 } else {
1809                     CRITMSG("Invalid %s: Error at '%s'",
1810                             sk_tls_options[OPT_TLS_PRIORITY].name, err_pos);
1811                 }
1812                 goto END;
1813             }
1814         }
1815 #if GNUTLS_VERSION_NUMBER >= 0x030000
1816         gnutls_global_set_audit_log_function(skMsgGnuTLSAuditLog);
1817 #endif
1818 #if GNUTLS_VERSION_NUMBER < 0x030506
1819         /* Generate Diffie-Hellman parameters */
1820         bits = gnutls_sec_param_to_pk_bits(GNUTLS_PK_DH, tls_security_param);
1821         if (0 == bits) {
1822             CRITMSG(("Programmer error: Unable to determine number of bits"
1823                      " for algorithm %u at security level %u in"
1824                      " GnuTLS version %s"),
1825                     GNUTLS_PK_DH, tls_security_param,
1826                     gnutls_check_version(NULL));
1827             rv = -1;
1828             goto END;
1829         }
1830         rv = gnutls_dh_params_init(&dh_params);
1831         if (rv < 0) {
1832             CRITMSG("Unable to initialize Diffie-Hellman parameters: %s",
1833                     gnutls_strerror(rv));
1834             goto END;
1835         }
1836         INFOMSG("Generating Diffie-Hellman parameters...");
1837         INFOMSG("This could take some time...");
1838         rv = gnutls_dh_params_generate2(dh_params, bits);
1839         if (rv < 0) {
1840             CRITMSG("Unable to generate Diffie-Hellman parameters: %s",
1841                     gnutls_strerror(rv));
1842             goto END;
1843         }
1844         INFOMSG("Finished generating Diffie-Hellman parameters");
1845 #endif  /* GNUTLS_VERSION_NUMBER < 0x030506 */
1846         sk_msg_gnutls_initialized = 1;
1847     }
1848 
1849     /* allocate credentials and set Diffie-Hellman parameters */
1850     if (!queue->root->cred_set) {
1851         rv = gnutls_certificate_allocate_credentials(&queue->root->cred);
1852         if (rv < 0) {
1853             WARNINGMSG("Unable to allocate credentials: %s",
1854                        gnutls_strerror(rv));
1855             goto END;
1856         }
1857 #if GNUTLS_VERSION_NUMBER < 0x030506
1858         gnutls_certificate_set_dh_params(queue->root->cred, dh_params);
1859 #else
1860         rv = (gnutls_certificate_set_known_dh_params(
1861                   queue->root->cred, GNUTLS_SEC_PARAM_MEDIUM));
1862         if (rv < 0) {
1863             WARNINGMSG(
1864                 "Unable to set credentials' Diffie-Hellman parameters: %s",
1865                 gnutls_strerror(rv));
1866             goto END;
1867         }
1868 #endif  /* GNUTLS_VERSION_NUMBER < 0x030506 */
1869     }
1870 
1871     /* add the trusted CAs from 'tls_ca_file'; the file should
1872      * be in PEM format. */
1873     rv = (gnutls_certificate_set_x509_trust_file(
1874               queue->root->cred, tls_ca_file, GNUTLS_X509_FMT_PEM));
1875     if (rv < 0) {
1876         CRITMSG("Invalid Certificate Authority file '%s': %s",
1877                 tls_ca_file, gnutls_strerror(rv));
1878         goto END;
1879     }
1880 
1881     /* add the revocation file from 'tls_crl_file' if provided; the
1882      * file should be in PEM format. */
1883     if (tls_crl_file) {
1884         rv = (gnutls_certificate_set_x509_crl_file(
1885                   queue->root->cred, tls_crl_file, GNUTLS_X509_FMT_PEM));
1886         if (rv < 0) {
1887             CRITMSG("Invalid Certificate Revocation List file '%s': %s",
1888                     tls_crl_file, gnutls_strerror(rv));
1889             goto END;
1890         }
1891     }
1892 
1893 #if  GNUTLS_VERSION_NUMBER < 0x030406
1894     /* set the callback function used to verify the peers'
1895      * certificates */
1896     gnutls_certificate_set_verify_function(
1897         queue->root->cred, &skMsgGnuTLSVerifyPeer);
1898 #endif  /* GNUTLS_VERSION_NUMBER < 0x030406 */
1899 
1900     /* add either the PKCS12 file or the certificate and key files */
1901     if (tls_pkcs12_file) {
1902         const char *password = getenv(password_env_name);
1903         rv = (gnutls_certificate_set_x509_simple_pkcs12_file(
1904                   queue->root->cred, tls_pkcs12_file,
1905                   GNUTLS_X509_FMT_DER, password));
1906         if (rv < 0) {
1907             CRITMSG("Invalid encryption cert file '%s': %s",
1908                     tls_pkcs12_file, gnutls_strerror(rv));
1909             goto END;
1910         }
1911     } else {
1912         /* set a certificate/private-key pair in the credential from
1913          * separate certificate and key files, each in PEM format. */
1914         rv = (gnutls_certificate_set_x509_key_file(
1915                   queue->root->cred, tls_cert_file, tls_key_file,
1916                   GNUTLS_X509_FMT_PEM));
1917         if (rv < 0) {
1918             CRITMSG("Invalid encryption cert or key file '%s', '%s': %s",
1919                     tls_cert_file, tls_key_file, gnutls_strerror(rv));
1920             goto END;
1921         }
1922     }
1923 
1924     if (!queue->root->cred_set) {
1925         queue->root->cred_set = 1;
1926         ++sk_msg_gnutls_initialized;
1927     }
1928 
1929     rv = 0;
1930 
1931   END:
1932     if (0 != rv) {
1933         rv = -1;
1934         if (queue->root->cred_set) {
1935             gnutls_certificate_free_credentials(queue->root->cred);
1936             queue->root->cred_set = 0;
1937         }
1938     }
1939     pthread_mutex_unlock(&sk_msg_gnutls_mutex);
1940     RETURN(rv);
1941 }
1942 
1943 
1944 void
skMsgGnuTLSTeardown(void)1945 skMsgGnuTLSTeardown(
1946     void)
1947 {
1948     DEBUG_ENTER_FUNC;
1949 
1950     pthread_mutex_lock(&sk_msg_gnutls_mutex);
1951 
1952     if (tls_priority_cache) {
1953         gnutls_priority_deinit(tls_priority_cache);
1954         tls_priority_cache = NULL;
1955     }
1956     if (sk_msg_gnutls_initialized) {
1957         gnutls_global_deinit();
1958         sk_msg_gnutls_initialized = 0;
1959     }
1960 
1961     pthread_mutex_unlock(&sk_msg_gnutls_mutex);
1962     RETURN_VOID;
1963 }
1964 
1965 #endif /* SK_ENABLE_GNUTLS */
1966 
1967 
1968 /* Create a channel within a message queue. */
1969 static sk_msg_channel_queue_t *
create_channel(sk_msg_queue_t * q)1970 create_channel(
1971     sk_msg_queue_t     *q)
1972 {
1973     sk_msg_channel_queue_t *chan;
1974     int rv;
1975 
1976     DEBUG_ENTER_FUNC;
1977 
1978     assert(q);
1979     ASSERT_QUEUE_LOCK(q);
1980 
1981     /* Allocate space for a new channel */
1982     chan = (sk_msg_channel_queue_t *)calloc(1, sizeof(*chan));
1983     MEM_ASSERT(chan != NULL);
1984 
1985     chan->queue = mqCreateQueue(q->group);
1986     MEM_ASSERT(chan->queue != NULL);
1987 
1988     /* Assign a local channel number and add the channel to the
1989        message queue */
1990     do {
1991         chan->channel = q->root->next_channel++;
1992         rv = int_dict_set(q->root->channel, chan->channel, &chan);
1993     } while (rv == 1);
1994     MEM_ASSERT(rv == 0);
1995 
1996     /* Channel is created, rchannel is unset */
1997     chan->state = SKM_CREATED;
1998     chan->rchannel = SKMSG_CHANNEL_CONTROL;
1999     rv = pthread_cond_init(&chan->pending, NULL);
2000     XASSERT(rv == 0);
2001     chan->is_pending = 0;
2002 
2003     rv = int_dict_set(q->root->groups, chan->channel, &q);
2004     MEM_ASSERT(rv == 0);
2005     rv = int_dict_set(q->channel, chan->channel, &chan);
2006     MEM_ASSERT(rv == 0);
2007     chan->group = q;
2008 
2009     DEBUG_PRINT2("create_channel() = %#x", chan->channel);
2010 
2011     RETURN(chan);
2012 }
2013 
2014 
2015 /* Attach a channel to a connection object. */
2016 static int
set_channel_connecting(sk_msg_queue_t * q,sk_msg_channel_queue_t * chan,sk_msg_conn_queue_t * conn)2017 set_channel_connecting(
2018     sk_msg_queue_t         *q,
2019     sk_msg_channel_queue_t *chan,
2020     sk_msg_conn_queue_t    *conn)
2021 {
2022     int rv;
2023 
2024     DEBUG_ENTER_FUNC;
2025     SK_UNUSED_PARAM(q);
2026 
2027     assert(q);
2028     assert(chan);
2029     assert(conn);
2030     ASSERT_QUEUE_LOCK(q);
2031     assert(chan->state == SKM_CREATED);
2032     assert(conn->state != SKM_CLOSED);
2033 
2034     DEBUG_PRINT2("set_channel_connecting(%#x)", chan->channel);
2035 
2036     /* Set the channel's communication stream, set it to
2037        half-connected, and up the refcount on the connection
2038        object. */
2039     chan->conn = conn;
2040     chan->state = SKM_CONNECTING;
2041 
2042     /* Add an entry in the connections's channel map for the
2043        channel. */
2044     rv = int_dict_set(conn->channelmap, chan->channel, &chan);
2045     MEM_ASSERT(rv != -1);
2046     assert(rv == 0);
2047 
2048     conn->state = SKM_CONNECTED;
2049     conn->refcount++;
2050 
2051     RETURN(0);
2052 }
2053 
2054 
2055 /* Return 1 if the connection was destroyed, zero otherwise */
2056 static int
set_channel_closed(sk_msg_queue_t * q,sk_msg_channel_queue_t * chan,int no_destroy)2057 set_channel_closed(
2058     sk_msg_queue_t         *q,
2059     sk_msg_channel_queue_t *chan,
2060     int                     no_destroy)
2061 {
2062     sk_msg_conn_queue_t *conn;
2063 
2064     DEBUG_ENTER_FUNC;
2065 
2066     assert(q);
2067     assert(chan);
2068     ASSERT_QUEUE_LOCK(q);
2069     assert(chan->conn);
2070     if (chan->state == SKM_CLOSED) {
2071         RETURN(0);
2072     }
2073     assert(chan->conn->refcount > 0);
2074     assert(chan->state == SKM_CONNECTING ||
2075            chan->state == SKM_CONNECTED);
2076 
2077     DEBUG_PRINT2("set_channel_closed(%#x)", chan->channel);
2078 
2079     conn = chan->conn;
2080 
2081     if (chan->state == SKM_CONNECTED &&
2082         chan->channel != SKMSG_CHANNEL_CONTROL)
2083     {
2084         skm_channel_t lchannel = htons(chan->channel);
2085         DEBUG_PRINT1("Sending SKMSG_CTL_CHANNEL_DIED (Internal)");
2086         send_message(q, SKMSG_CHANNEL_CONTROL,
2087                      SKMSG_CTL_CHANNEL_DIED,
2088                      &lchannel, sizeof(lchannel), SKM_SEND_INTERNAL,
2089                      0, NULL);
2090     }
2091 
2092     ASSERT_RESULT(int_dict_del(conn->channelmap, chan->channel),
2093                   int, 0);
2094 
2095     chan->state = SKM_CLOSED;
2096     conn->refcount--;
2097 
2098     /* Notify people waiting on this channel to complete connecting
2099        that it is dead. */
2100     MUTEX_BROADCAST(&chan->pending);
2101 
2102     if (conn->refcount == 0 && !no_destroy) {
2103         RETURN(destroy_connection(q, conn));
2104     }
2105 
2106     RETURN(0);
2107 }
2108 
2109 
2110 static int
set_channel_connected(sk_msg_queue_t * q,sk_msg_channel_queue_t * chan,skm_channel_t rchannel)2111 set_channel_connected(
2112     sk_msg_queue_t         *q,
2113     sk_msg_channel_queue_t *chan,
2114     skm_channel_t           rchannel)
2115 {
2116     DEBUG_ENTER_FUNC;
2117     SK_UNUSED_PARAM(q);
2118 
2119     assert(q);
2120     assert(chan);
2121     ASSERT_QUEUE_LOCK(q);
2122     assert(chan->state == SKM_CONNECTING);
2123 
2124     DEBUG_PRINT2("set_channel_connected(%#x)", chan->channel);
2125 
2126     chan->rchannel = rchannel;
2127 
2128     chan->state = SKM_CONNECTED;
2129     RETURN(0);
2130 }
2131 
2132 
2133 static void
destroy_channel(sk_msg_queue_t * q,sk_msg_channel_queue_t * chan)2134 destroy_channel(
2135     sk_msg_queue_t         *q,
2136     sk_msg_channel_queue_t *chan)
2137 {
2138     DEBUG_ENTER_FUNC;
2139 
2140     assert(q);
2141     assert(chan);
2142     ASSERT_QUEUE_LOCK(q);
2143 
2144     DEBUG_PRINT2("destroy_channel(%#x)", chan->channel);
2145 
2146     if (chan->state == SKM_CONNECTED &&
2147         chan->channel != SKMSG_CHANNEL_CONTROL)
2148     {
2149         skm_channel_t rchannel = htons(chan->rchannel);
2150         DEBUG_PRINT1("Sending SKMSG_CTL_CHANNEL_KILL (Ext-control)");
2151         send_message(q, chan->channel, SKMSG_CTL_CHANNEL_KILL,
2152                      &rchannel, sizeof(rchannel), SKM_SEND_CONTROL,
2153                      0, NULL);
2154     }
2155     if (chan->state == SKM_CONNECTED || chan->state == SKM_CONNECTING) {
2156         set_channel_closed(q, chan, 0);
2157     }
2158 
2159     assert(chan->state == SKM_CLOSED);
2160 
2161     ASSERT_RESULT(int_dict_del(q->root->channel, chan->channel), int, 0);
2162 
2163     ASSERT_RESULT(int_dict_del(q->root->groups, chan->channel), int, 0);
2164     ASSERT_RESULT(int_dict_del(chan->group->channel, chan->channel), int, 0);
2165 
2166     ASSERT_RESULT(pthread_cond_destroy(&chan->pending), int, 0);
2167 
2168     /* Disable adding to the queue (it will be destroyed when the
2169        group is destroyed) */
2170     mqQueueDisable(chan->queue, MQ_ADD);
2171 
2172     free(chan);
2173 
2174     RETURN_VOID;
2175 }
2176 
2177 
2178 #if SK_ENABLE_GNUTLS
2179 #if SK_TLS_USE_CUSTOM_PULL_PUSH
2180 /*
2181  *    A callback function used by GnuTLS for receiving (reading) data.
2182  *
2183  *    This callback is invoked by gnutls_record_recv(), which is
2184  *    called by the tls_recv() function, which is called by the
2185  *    reader_thread().
2186  */
2187 static ssize_t
tls_pull(gnutls_transport_ptr_t fd,void * buf,size_t len)2188 tls_pull(
2189     gnutls_transport_ptr_t  fd,
2190     void                   *buf,
2191     size_t                  len)
2192 {
2193     struct pollfd pfd;
2194     ssize_t       rv;
2195 
2196     pfd.fd = (intptr_t)fd;
2197     pfd.events = POLLIN;
2198 
2199     rv = poll(&pfd, 1, TLS_POLL_TIMEOUT);
2200     if (0 == rv) {
2201         /* poll() timed out.  According to GnuTLS docs, this function
2202          * should act like recv(2) in that case and set errno to
2203          * EAGAIN and return -1 */
2204         errno = EAGAIN;
2205         rv = -1;
2206     } else if (1 == rv) {
2207         rv = read(pfd.fd, buf, len);
2208 
2209 #if (SENDRCV_DEBUG) & DEBUG_RWTRANSFER_PROTOCOL
2210         if (pfd.revents & (POLLERR | POLLNVAL | POLLHUP)) {
2211             /* Log but ignore poll error events for now (allow the
2212              * read to fail).  */
2213             DEBUG_PRINT2("Poll returned %s", SK_POLL_EVENT_STR(pfd.revents));
2214         }
2215         if (-1 == rv) {
2216             DEBUG_PRINT2("Returning -1 from tls_pull (read()) [errno = %d]",
2217                          errno);
2218         } else if (0 == rv) {
2219             DEBUG_PRINT1("Returning 0 from tls_pull (read())");
2220         }
2221     } else {
2222         if (-1 == rv) {
2223             DEBUG_PRINT2("Returning -1 from tls_pull (poll()) [errno = %d]",
2224                          errno);
2225         } else {
2226             DEBUG_PRINT2("Returning %" SK_PRIdZ " from tls_pull (poll())", rv);
2227         }
2228 #endif  /*  (SENDRCV_DEBUG) & DEBUG_RWTRANSFER_PROTOCOL */
2229     }
2230     return rv;
2231 }
2232 
2233 /*
2234  *    A callback function used by GnuTLS for sending (writing) data.
2235  *
2236  *    This callback is invoked by gnutls_record_send(), which is
2237  *    called by our tls_send() function, which is called by the
2238  *    writer_thread().
2239  */
2240 static ssize_t
tls_push(gnutls_transport_ptr_t fd,const void * buf,size_t len)2241 tls_push(
2242     gnutls_transport_ptr_t  fd,
2243     const void             *buf,
2244     size_t                  len)
2245 {
2246     struct pollfd pfd;
2247     ssize_t       rv;
2248 
2249     pfd.fd = (intptr_t)fd;
2250     pfd.events = POLLOUT;
2251 
2252     rv = poll(&pfd, 1, TLS_POLL_TIMEOUT);
2253     if (1 == rv) {
2254         rv = write(pfd.fd, buf, len);
2255 
2256 #if (SENDRCV_DEBUG) & DEBUG_RWTRANSFER_PROTOCOL
2257         if (pfd.revents & (POLLERR | POLLNVAL | POLLHUP)) {
2258             /* Log but ignore poll error events for now (allow the
2259              * write to fail).  */
2260             DEBUG_PRINT2("Poll returned %s", SK_POLL_EVENT_STR(pfd.revents));
2261         }
2262         if (-1 == rv) {
2263             DEBUG_PRINT2("Returning -1 from tls_push (write()) [errno = %d]",
2264                          errno);
2265         } else if (0 == rv) {
2266             DEBUG_PRINT1("Returning 0 from tls_push (write())");
2267         }
2268     } else {
2269         if (-1 == rv) {
2270             DEBUG_PRINT2("Returning -1 from tls_push (poll()) [errno = %d]",
2271                          errno);
2272         } else {
2273             DEBUG_PRINT2("Returning %" SK_PRIdZ " from tls_push (poll())", rv);
2274         }
2275 #endif  /* (SENDRCV_DEBUG) & DEBUG_RWTRANSFER_PROTOCOL */
2276     }
2277     /* What should happen if result of poll() is 0 (timed out)?
2278      * Currently we just return 0, as if there was a zero write.
2279      * Possibly we should return -1 with an errno of EAGAIN. */
2280     return rv;
2281 }
2282 #endif  /* SK_TLS_USE_CUSTOM_PULL_PUSH */
2283 
2284 static int
setup_tls(sk_msg_queue_t * q,sk_msg_conn_queue_t * conn,int rsocket,int wsocket,skm_tls_type_t tls)2285 setup_tls(
2286     sk_msg_queue_t         *q,
2287     sk_msg_conn_queue_t    *conn,
2288     int                     rsocket,
2289     int                     wsocket,
2290     skm_tls_type_t          tls)
2291 {
2292     int rv;
2293 
2294     DEBUG_ENTER_FUNC;
2295 
2296     assert(q);
2297     assert(conn);
2298     assert(q->root->cred_set);
2299 
2300     /* initialize the TLS session object depending on whether it is
2301      * the client or the server */
2302     switch (tls) {
2303       case SKM_TLS_CLIENT:
2304 #if GNUTLS_VERSION_NUMBER < 0x030000
2305         rv = gnutls_init(&conn->session, GNUTLS_CLIENT);
2306 #else
2307         rv = gnutls_init(&conn->session, GNUTLS_CLIENT | GNUTLS_NONBLOCK);
2308 #endif  /* #else of GNUTLS_VERSION_NUMBER < 0x030000 */
2309         break;
2310       case SKM_TLS_SERVER:
2311 #if GNUTLS_VERSION_NUMBER < 0x030000
2312         rv = gnutls_init(&conn->session, GNUTLS_SERVER);
2313 #else
2314         rv = gnutls_init(&conn->session, GNUTLS_SERVER | GNUTLS_NONBLOCK);
2315 #endif  /* #else of GNUTLS_VERSION_NUMBER < 0x030000 */
2316         break;
2317       default:
2318         skAbortBadCase(tls);
2319     }
2320     if (rv < 0) {
2321         ERRMSG("Unable to initialize TLS in the session: %s",
2322                gnutls_strerror(rv));
2323         RETURN(-1);
2324     }
2325 
2326     /* tell the session to use the public/private keys loaded earlier */
2327     rv = gnutls_credentials_set(conn->session, GNUTLS_CRD_CERTIFICATE,
2328                                 q->root->cred);
2329     if (rv < 0) {
2330         ERRMSG("Unable to set TLS credentials in the session: %s",
2331                gnutls_strerror(rv));
2332         RETURN(-1);
2333     }
2334 
2335     /* set the priority */
2336     if (tls_priority_cache) {
2337         rv = gnutls_priority_set(conn->session, tls_priority_cache);
2338     } else {
2339         /* use "NORMAL" priority */
2340         rv = gnutls_set_default_priority(conn->session);
2341     }
2342     if (rv < 0) {
2343         ERRMSG("Unable to initialize TLS priority in the session: %s",
2344                gnutls_strerror(rv));
2345         RETURN(-1);
2346     }
2347 
2348     if (rsocket != wsocket) {
2349         WARNINGMSG("Unexpected found read socket %d != write socket %d",
2350                    rsocket, wsocket);
2351         gnutls_transport_set_ptr2(conn->session,
2352                                   (gnutls_transport_ptr_t)(intptr_t)rsocket,
2353                                   (gnutls_transport_ptr_t)(intptr_t)wsocket);
2354     } else {
2355         gnutls_transport_set_ptr(conn->session,
2356                                  (gnutls_transport_ptr_t)(intptr_t)rsocket);
2357     }
2358 
2359 #if SK_TLS_USE_CUSTOM_PULL_PUSH
2360     /* tell TLS to use our read and write functions (tls_pull,
2361      * tls_push) instead of the defaults.  The call to set_ptr2()
2362      * controls what will be passed to tls_pull() and tls_push(). */
2363     gnutls_transport_set_pull_function(conn->session, tls_pull);
2364     gnutls_transport_set_push_function(conn->session, tls_push);
2365 #endif  /* SK_TLS_USE_CUSTOM_PULL_PUSH */
2366 #if GNUTLS_VERSION_NUMBER < 0x030000
2367     /* Used to set the low water value in order for select() to check
2368      * if there are data pending to be read from the socket buffer. */
2369     gnutls_transport_set_lowat(conn->session, 0);
2370 #endif  /* GNUTLS_VERSION_NUMBER < 0x030000 */
2371     set_nonblock(rsocket);
2372 
2373     /* force the client to send its certificate to the server */
2374     if (tls == SKM_TLS_SERVER) {
2375         gnutls_certificate_server_set_request(conn->session,
2376                                               GNUTLS_CERT_REQUIRE);
2377     }
2378 
2379 #if GNUTLS_VERSION_NUMBER >= 0x030406
2380     gnutls_session_set_verify_cert(conn->session, NULL, 0);
2381 #endif
2382 
2383     DEBUG_PRINT1("Attempting TLS handshake");
2384     while ((rv = gnutls_handshake(conn->session)) < 0
2385            && gnutls_error_is_fatal(rv) == 0)
2386     {
2387         DEBUG_PRINT2("Re-attempting TLS handshake on non-fatal error %s",
2388                      gnutls_strerror(rv));
2389     }
2390     if (rv < 0) {
2391         if (rv == GNUTLS_E_PUSH_ERROR) {
2392             NOTICEMSG("Remote side disconnected during TLS handshake."
2393                       " Certificate may have been rejected.");
2394         } else if (rv == GNUTLS_E_CERTIFICATE_ERROR) {
2395             NOTICEMSG("TLS handshake failed while verifying the certificate");
2396         }
2397 #if GNUTLS_VERSION_NUMBER >= 0x030406
2398         else if (rv == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR) {
2399             int type;
2400             unsigned int status;
2401             gnutls_datum_t out;
2402 
2403             /* check certificate verification status */
2404             type = gnutls_certificate_type_get(conn->session);
2405             status = gnutls_session_get_verify_cert_status(conn->session);
2406             if (gnutls_certificate_verification_status_print(
2407                     status, (gnutls_certificate_type_t)type, &out, 0) == 0)
2408             {
2409                 NOTICEMSG("Failed to verify peer's certificate: %s", out.data);
2410                 gnutls_free(out.data);
2411             } else {
2412                 NOTICEMSG("Failed to verify peer's certificate"
2413                           " during TLS handshake");
2414             }
2415         }
2416 #endif  /* GNUTLS_VERSION_NUMBER >= 0x030406 */
2417         else {
2418             NOTICEMSG(("TLS handshake failed"
2419                        " (peer may have rejected the certificate): %s"),
2420                       gnutls_strerror(rv));
2421         }
2422 
2423         gnutls_deinit(conn->session);
2424         RETURN(-1);
2425     }
2426     DEBUG_PRINT1("TLS handshake succeeded");
2427 
2428     conn->use_tls = 1;
2429 
2430     RETURN(0);
2431 }
2432 #endif /* SK_ENABLE_GNUTLS */
2433 
2434 
2435 static int
create_connection(sk_msg_queue_t * q,int rsocket,int wsocket,struct sockaddr * addr,socklen_t addrlen,sk_msg_conn_queue_t ** rconn,skm_tls_type_t tls)2436 create_connection(
2437     sk_msg_queue_t         *q,
2438     int                     rsocket,
2439     int                     wsocket,
2440     struct sockaddr        *addr,
2441     socklen_t               addrlen,
2442     sk_msg_conn_queue_t   **rconn,
2443     skm_tls_type_t          tls)
2444 {
2445     sk_msg_conn_queue_t *conn;
2446     sk_queue_and_conn_t *qac;
2447     int rv;
2448 
2449     DEBUG_ENTER_FUNC;
2450 
2451     assert(q);
2452     assert(rconn);
2453     ASSERT_QUEUE_LOCK(q);
2454 
2455     DEBUG_PRINT3("create_connection() = %d, %d", rsocket, wsocket);
2456 
2457     /* Allocate space for the connection */
2458     conn = (sk_msg_conn_queue_t *)calloc(1, sizeof(*conn));
2459     MEM_ASSERT(conn != NULL);
2460 
2461 #if SK_ENABLE_GNUTLS
2462     if (tls != SKM_TLS_NONE) {
2463         rv = setup_tls(q, conn, rsocket, wsocket, tls);
2464         if (rv != 0) {
2465             free(conn);
2466             RETURN(-1);
2467         }
2468     }
2469 #endif /* SK_ENABLE_GNUTLS */
2470 
2471     /* Set the read and write sockets */
2472     conn->rsocket = rsocket;
2473     conn->wsocket = wsocket;
2474 
2475     /* And the address */
2476     conn->addr = addr;
2477     conn->addrlen = addrlen;
2478 
2479     conn->transport = (tls == SKM_TLS_NONE) ? CONN_TCP : CONN_TLS;
2480 
2481     /* Set up the channel queue and refcount */
2482     conn->channelmap = int_dict_create(sizeof(sk_msg_channel_queue_t *));
2483     MEM_ASSERT(conn->channelmap != NULL);
2484     conn->refcount = 0;
2485 
2486     /* Set the connections initial state */
2487     conn->state = SKM_CREATED;
2488 
2489     /* Set up the write queue */
2490     conn->queue = skDequeCreate();
2491     XASSERT(conn->queue != NULL);
2492 
2493     /* Initialize the writer thread start state */
2494     pthread_cond_init(&conn->writer_cond, NULL);
2495     conn->writer_state = SKM_THREAD_BEFORE;
2496 
2497     /* Initialize the reader thread start state */
2498     pthread_cond_init(&conn->reader_cond, NULL);
2499     conn->reader_state = SKM_THREAD_BEFORE;
2500 
2501     /* Set up and start the writer thread */
2502     qac = (sk_queue_and_conn_t *)malloc(sizeof(*qac));
2503     MEM_ASSERT(qac != NULL);
2504     qac->q = q;
2505     qac->conn = conn;
2506     THREAD_START("skmsg_writer", rv, q, &conn->writer, writer_thread, qac);
2507     XASSERT(rv == 0);
2508 
2509     /* Start the reader thread, */
2510     qac = (sk_queue_and_conn_t *)malloc(sizeof(*qac));
2511     MEM_ASSERT(qac != NULL);
2512     qac->q = q;
2513     qac->conn = conn;
2514     THREAD_START("skmsg_reader", rv, q, &conn->reader, reader_thread, qac);
2515     XASSERT(rv == 0);
2516 
2517     *rconn = conn;
2518     RETURN(0);
2519 }
2520 
2521 static void
start_connection(sk_msg_queue_t * q,sk_msg_conn_queue_t * conn)2522 start_connection(
2523     sk_msg_queue_t         *q,
2524     sk_msg_conn_queue_t    *conn)
2525 {
2526     DEBUG_ENTER_FUNC;
2527     SK_UNUSED_PARAM(q);
2528 
2529     assert(q);
2530     assert(conn);
2531     ASSERT_QUEUE_LOCK(q);
2532 
2533     assert(conn->reader_state == SKM_THREAD_BEFORE);
2534     assert(conn->writer_state == SKM_THREAD_BEFORE);
2535     conn->reader_state = SKM_THREAD_RUNNING;
2536     conn->writer_state = SKM_THREAD_RUNNING;
2537     MUTEX_BROADCAST(&conn->reader_cond);
2538     MUTEX_BROADCAST(&conn->writer_cond);
2539 
2540     RETURN_VOID;
2541 }
2542 
2543 static void
unblock_connection(sk_msg_queue_t * q,sk_msg_conn_queue_t * conn)2544 unblock_connection(
2545     sk_msg_queue_t         *q,
2546     sk_msg_conn_queue_t    *conn)
2547 {
2548     static sk_msg_t unblocker = {{SKMSG_CHANNEL_CONTROL,
2549                                   SKMSG_WRITER_UNBLOCKER, 0},
2550                                  NULL, NULL, 1, {{NULL, 0}}};
2551     skDQErr_t err;
2552 
2553     DEBUG_ENTER_FUNC;
2554     SK_UNUSED_PARAM(q);
2555 
2556     assert(q);
2557     assert(conn);
2558     ASSERT_QUEUE_LOCK(q);
2559 
2560     /* Add a special messaage to the writers queue to guarantee it
2561        will unblock */
2562     DEBUG_PRINT1("Sending SKMSG_WRITER_UNBLOCKER message");
2563     err = skDequePushBack(conn->queue, &unblocker);
2564     XASSERT(err == SKDQ_SUCCESS);
2565 
2566     RETURN_VOID;
2567 }
2568 
2569 
/*
 *    Stop and destroy the connection 'conn' that belongs to queue 'q'.
 *    The caller must hold the queue lock.
 *
 *    A return value of 0 means the connection object still exists:
 *    'conn' was already in the SKM_CLOSED state, which means another
 *    thread is destroying that connection right now, and nothing is
 *    done here.
 *
 *    Otherwise the connection is marked closed, the messages still
 *    waiting in its deque are discarded, its channels are closed, its
 *    reader and writer threads are waited for (a caller that is one of
 *    those threads detaches itself instead), the TLS session (if any)
 *    is ended, the socket(s) are closed, and all memory owned by the
 *    connection is released.  A return value of 1 means the connection
 *    has been destroyed; 'conn' must not be used after that.
 */
static int
destroy_connection(
    sk_msg_queue_t         *q,
    sk_msg_conn_queue_t    *conn)
{
    sk_msg_channel_queue_t *chan;
    void *cont;
    sk_msg_t *msg;
    skDQErr_t err;
    pthread_t self;

    DEBUG_ENTER_FUNC;

    assert(q);
    assert(conn);
    ASSERT_QUEUE_LOCK(q);

    DEBUG_PRINT3("destroy_connection() = %d, %d",
                 conn->rsocket, conn->wsocket);

    /* Check to see if this connection is already being shut down */
    if (conn->state == SKM_CLOSED) {
        RETURN(0);
    }

    /* Okay, start closing.  The reader and writer loops test these
     * state fields, so setting them tells both threads to stop. */
    conn->state = SKM_CLOSED;
    conn->writer_state = SKM_THREAD_SHUTTING_DOWN;
    conn->reader_state = SKM_THREAD_SHUTTING_DOWN;
    /* NOTE(review): presumably wakes a writer that is blocked waiting
     * on conn->queue so that it sees the new state -- confirm against
     * unblock_connection() */
    unblock_connection(q, conn);

    /* Empty the queue: messages that were never sent are discarded */
    while ((err = skDequePopBackNB(conn->queue, (void**)&msg))
           == SKDQ_SUCCESS)
    {
        skMsgDestroy(msg);
    }
    assert(err == SKDQ_EMPTY);

    /* Shut down the queue; a writer waiting in the deque gets
     * SKDQ_UNBLOCKED and leaves its loop */
    ASSERT_RESULT(skDequeUnblock(conn->queue), skDQErr_t, SKDQ_SUCCESS);

    /* Mark all channels using this connection as closed.  The channel
     * that accept_connection() pre-created and that no channel
     * announcement has claimed yet is still SKM_CREATED; destroy it
     * directly. */
    if (conn->first_channel) {
        assert(conn->first_channel->state == SKM_CREATED);
        conn->first_channel->state = SKM_CLOSED;
        destroy_channel(q, conn->first_channel);
        conn->first_channel = NULL;
    }
    /* Close every connecting/connected channel in the channel map.  The
     * key is copied out of 'chan' before set_channel_closed() is called
     * and is then used to locate the next entry. */
    cont = int_dict_get_first(conn->channelmap, NULL, &chan);
    while (cont != NULL) {
        intkey_t channel = chan->channel;
        if ((chan->state == SKM_CONNECTING ||
             chan->state == SKM_CONNECTED))
        {
            set_channel_closed(q, chan, 1);
        }

        cont = int_dict_get_next(conn->channelmap, &channel, &chan);
    }
    /* NOTE(review): refcount presumably counts the channels still
     * attached to this connection; all must be gone by now */
    assert(conn->refcount == 0);

    /* End the threads.  A thread cannot join itself, so when the caller
     * is the reader or writer thread it is skipped here and detached
     * below so that its resources are released when it exits. */
    self = pthread_self();
    if (!pthread_equal(self, conn->writer)) {
        /* Wait for the writer thread to end */
        THREAD_WAIT_END(q, conn->writer_state);
        pthread_join(conn->writer, NULL);
    }
    if (!pthread_equal(self, conn->reader)) {
        /* Wait for the reader  */
        THREAD_WAIT_END(q, conn->reader_state);
        pthread_join(conn->reader, NULL);
    }
    if (pthread_equal(self, conn->reader)
        || pthread_equal(self, conn->writer))
    {
        DEBUG_PRINT1("Detaching self");
        pthread_detach(self);
    }

    /* Destroy the channelmap */
    int_dict_destroy(conn->channelmap);

#if SK_ENABLE_GNUTLS
    /* End the connection: send the TLS close_notify alert, retrying
     * while GnuTLS reports the call would block or was interrupted.
     * NOTE(review): with GNUTLS_SHUT_RDWR gnutls_bye() also waits for
     * the peer's close_notify; if the socket is non-blocking this loop
     * may spin on GNUTLS_E_AGAIN while the queue lock is held --
     * consider GNUTLS_SHUT_WR. */
    if (conn->use_tls) {
        int rv;
        do {
            rv = gnutls_bye(conn->session, GNUTLS_SHUT_RDWR);
            DEBUG_PRINT2("gnutls_bye() -> %d", rv);
        } while (rv == GNUTLS_E_AGAIN || rv == GNUTLS_E_INTERRUPTED);
    }
#endif /* SK_ENABLE_GNUTLS */

    /* Close the socket(s); reading and writing may share one
     * descriptor (accept_connection() passes the same fd twice) */
    close(conn->rsocket);
    if (conn->rsocket != conn->wsocket) {
        close(conn->wsocket);
    }

    /* Destroy the queue */
    ASSERT_RESULT(skDequeDestroy(conn->queue), skDQErr_t, SKDQ_SUCCESS);

#if SK_ENABLE_GNUTLS
    /* Destroy the session */
    if (conn->use_tls) {
        gnutls_deinit(conn->session);
    }
#endif /* SK_ENABLE_GNUTLS */

    /* Destroy the condition variables */
    ASSERT_RESULT(pthread_cond_destroy(&conn->writer_cond), int, 0);
    ASSERT_RESULT(pthread_cond_destroy(&conn->reader_cond), int, 0);

    /* Destroy the address */
    if (conn->addr != NULL) {
        free(conn->addr);
    }

    /* Remove any incomplete buffers: a message the reader had only
     * partly received */
    if (conn->msg_read_buf.msg) {
        skMsgDestroy(conn->msg_read_buf.msg);
    }

    /* Finally, free the connection object */
    free(conn);

    RETURN(1);
}
2704 
2705 
2706 static int
accept_connection(sk_msg_queue_t * q,int listen_sock)2707 accept_connection(
2708     sk_msg_queue_t     *q,
2709     int                 listen_sock)
2710 {
2711     int fd;
2712     sk_msg_conn_queue_t *conn;
2713     int rv;
2714     sk_sockaddr_t addr;
2715     struct sockaddr *addr_copy;
2716     socklen_t addrlen = sizeof(addr);
2717     char addr_buf[128];
2718 
2719     DEBUG_ENTER_FUNC;
2720 
2721     assert(q);
2722     ASSERT_QUEUE_LOCK(q);
2723     assert(q->root->listener_state == SKM_THREAD_RUNNING);
2724 
2725     while ((fd = accept(listen_sock, &addr.sa, &addrlen)) == -1) {
2726         if (errno == EAGAIN || errno == EWOULDBLOCK) {
2727             DEBUG_PRINT1("Properly handling EAGAIN/EWOULDBLOCK");
2728             RETURN(1);
2729         }
2730         if (errno == EINTR) {
2731             DEBUGMSG("accept() [%s]", strerror(errno));
2732             continue;
2733         }
2734         if (errno == EBADF) {
2735             DEBUGMSG("accept() [%s]", strerror(errno));
2736             RETURN(-1);
2737         }
2738         CRITMSG("Unexpected accept() error: %s", strerror(errno));
2739         XASSERT(0);
2740         skAbort();
2741     }
2742     skSockaddrString(addr_buf, sizeof(addr_buf), &addr);
2743     DEBUGMSG("Accepted connection from %s", addr_buf);
2744 
2745     /* Create the queue and both references */
2746     addr_copy = (struct sockaddr *)malloc(addrlen);
2747     if (addr_copy != NULL) {
2748         memcpy(addr_copy, &addr, addrlen);
2749     }
2750     rv = create_connection(q, fd, fd, addr_copy, addrlen, &conn,
2751                            q->root->bind_tls ? SKM_TLS_SERVER : SKM_TLS_NONE);
2752     if (rv != 0) {
2753         NOTICEMSG("Unable to initialize connection with %s", addr_buf);
2754         close(fd);
2755         free(addr_copy);
2756         RETURN(-1);
2757     }
2758 
2759     conn->first_channel = create_channel(q);
2760 
2761     start_connection(q, conn);
2762 
2763     RETURN(0);
2764 }
2765 
2766 static int
handle_system_control_message(sk_msg_queue_t * q,sk_msg_conn_queue_t * conn,sk_msg_t * msg)2767 handle_system_control_message(
2768     sk_msg_queue_t         *q,
2769     sk_msg_conn_queue_t    *conn,
2770     sk_msg_t               *msg)
2771 {
2772     int rv;
2773     int retval = 0;
2774 
2775     DEBUG_ENTER_FUNC;
2776 
2777     assert(q);
2778     assert(msg);
2779     ASSERT_QUEUE_LOCK(q);
2780 
2781     switch (msg->hdr.type) {
2782       case SKMSG_CTL_CHANNEL_ANNOUNCE:
2783         /* Handle the announcement of a connection */
2784         {
2785             skm_channel_t rchannel;
2786             skm_channel_t lchannel;
2787             sk_msg_channel_queue_t *chan;
2788             sk_channel_pair_t pair;
2789             sk_new_channel_info_t info;
2790 
2791             DEBUG_PRINT1("Handling SKMSG_CTL_CHANNEL_ANNOUNCE");
2792             assert(msg->hdr.size == sizeof(rchannel));
2793             assert(msg->segments == 2);
2794 
2795             /* Decode the remote channel */
2796             rchannel = SKMSG_CTL_MSG_GET_CHANNEL(msg);
2797 
2798             /* Create a local channel */
2799             if (conn->first_channel) {
2800                 chan = conn->first_channel;
2801                 conn->first_channel = NULL;
2802             } else {
2803                 chan = create_channel(q);
2804             }
2805             lchannel = chan->channel;
2806 
2807             /* Attach the channel to the connection */
2808             ASSERT_RESULT(set_channel_connecting(q, chan, conn), int, 0);
2809 
2810             /* Set the remote channel */
2811             ASSERT_RESULT(set_channel_connected(q, chan, rchannel), int, 0);
2812 
2813             /* Respond to the announcement with the channel pair */
2814             pair.rchannel = htons(rchannel);
2815             pair.lchannel = htons(lchannel);
2816             DEBUG_PRINT1("Sending SKMSG_CTL_CHANNEL_REPLY (Ext-control)");
2817             rv = send_message(q, lchannel,
2818                               SKMSG_CTL_CHANNEL_REPLY,
2819                               &pair, sizeof(pair), SKM_SEND_CONTROL,
2820                               0, NULL);
2821             if (rv != 0) {
2822                 DEBUG_PRINT1("Sending SKMSG_CTL_CHANNEL_REPLY failed");
2823                 retval = -11;
2824                 break;
2825             }
2826 
2827             /* Announce the new channel internally */
2828             info.channel = pair.lchannel;
2829             if (conn->addr != NULL) {
2830                 memcpy(&info.addr, conn->addr, conn->addrlen);
2831                 info.known = 1;
2832             } else {
2833                 info.known = 0;
2834             }
2835             DEBUG_PRINT1("Sending SKMSG_CTL_NEW_CONNECTION (Internal)");
2836             rv = send_message(q, SKMSG_CHANNEL_CONTROL,
2837                               SKMSG_CTL_NEW_CONNECTION,
2838                               &info, sizeof(info),
2839                               SKM_SEND_INTERNAL, 0, NULL);
2840             XASSERT(rv == 0);
2841         }
2842         break;
2843 
2844       case SKMSG_CTL_CHANNEL_REPLY:
2845         /* Handle the reply to a channel announcement */
2846         {
2847             sk_channel_pair_t *pair;
2848             sk_msg_channel_queue_t *chan;
2849             skm_channel_t rchannel;
2850             skm_channel_t lchannel;
2851 
2852             DEBUG_PRINT1("Handling SKMSG_CTL_CHANNEL_REPLY");
2853             assert(sizeof(*pair) == msg->hdr.size);
2854             assert(2 == msg->segments);
2855 
2856             /* Decode the channels: Reversed directionality is on
2857                purpose. */
2858             pair = (sk_channel_pair_t *)msg->segment[1].iov_base;
2859             rchannel = ntohs(pair->lchannel);
2860             lchannel = ntohs(pair->rchannel);
2861 
2862             /* Get the channel object */
2863             chan = find_channel(q, lchannel);
2864             XASSERT(chan != NULL);
2865 
2866             /* Set the remote channel */
2867             ASSERT_RESULT(set_channel_connected(q, chan, rchannel), int, 0);
2868 
2869             chan->conn->state = SKM_CONNECTED;
2870 
2871             /* Complete the connection */
2872             assert(chan->state != SKM_CONNECTING);
2873             assert(chan->is_pending);
2874             MUTEX_BROADCAST(&chan->pending);
2875         }
2876         break;
2877 
2878       case SKMSG_CTL_CHANNEL_KILL:
2879         /* Handle the death of a remote channel */
2880         {
2881             skm_channel_t channel;
2882             sk_msg_channel_queue_t *chan;
2883 
2884             DEBUG_PRINT1("Handling SKMSG_CTL_CHANNEL_KILL");
2885 
2886             assert(msg->hdr.size == sizeof(channel));
2887             assert(msg->segments == 2);
2888 
2889             /* Decode the channel. */
2890             channel = SKMSG_CTL_MSG_GET_CHANNEL(msg);
2891 
2892             /* Get the channel object. */
2893             chan = find_channel(q, channel);
2894             XASSERT(chan != NULL);
2895 
2896             /* Close the channel. */
2897             retval = set_channel_closed(q, chan, 0);
2898         }
2899         break;
2900 
2901       case SKMSG_CTL_CHANNEL_KEEPALIVE:
2902         DEBUG_PRINT1("Handling SKMSG_CTL_CHANNEL_KEEPALIVE");
2903         assert(msg->hdr.size == 0);
2904         /* Do nothing on KEEPALIVE*/
2905         break;
2906 
2907       default:
2908         skAbortBadCase(msg->hdr.type);
2909     }
2910 
2911     skMsgDestroy(msg);
2912 
2913     RETURN(retval);
2914 }
2915 
2916 
2917 /*
2918  *    THREAD ENTRY POINT
2919  *
2920  *    Entry point for the "skmsg_listener" thread, started from
2921  *    skMsgQueueBind()
2922  */
2923 static void *
listener_thread(void * vq)2924 listener_thread(
2925     void               *vq)
2926 {
2927     sk_msg_queue_t *q = (sk_msg_queue_t *)vq;
2928     int count;
2929     nfds_t len, valid;
2930     struct pollfd *pfd;
2931 
2932     DEBUG_ENTER_FUNC;
2933 
2934     DEBUG_PRINT1("Started listener_thread");
2935 
2936     assert(q);
2937 
2938     QUEUE_LOCK(q);
2939 
2940     pfd = q->root->pfd;
2941     len = q->root->pfd_len;
2942     valid = len;
2943     q->root->listener_state = SKM_THREAD_RUNNING;
2944     MUTEX_BROADCAST(&q->root->listener_cond);
2945     QUEUE_UNLOCK(q);
2946 
2947     while (valid && q->root->listener_state == SKM_THREAD_RUNNING) {
2948         unsigned i;
2949         int rv;
2950 
2951         /* Call poll */
2952         count = poll(pfd, len, SKMSG_IO_POLL_TIMEOUT);
2953         if (count == -1) {
2954             if (errno == EINTR || errno == EBADF) {
2955                 DEBUG_PRINT2("Ignoring expected poll() error: %s",
2956                              strerror(errno));
2957                 continue;
2958             }
2959             CRITMSG("Unexpected poll() error: %s", strerror(errno));
2960             skAbort();
2961         }
2962 
2963         /* Check the file descriptor results */
2964         for (i = 0; i < len; i++) {
2965             if (pfd[i].fd < 0) {
2966                 continue;
2967             }
2968             if (pfd[i].revents & (POLLERR | POLLHUP | POLLNVAL)) {
2969                 DEBUG_PRINT3("Poll returned %d, but revents was %d",
2970                              count, pfd[i].revents);
2971                 pfd[i].fd = -1;
2972                 valid--;
2973             } else if (pfd[i].revents & POLLIN) {
2974                 DEBUG_PRINT1("Accepting connection: trying");
2975                 QUEUE_LOCK(q);
2976                 if (q->root->listener_state != SKM_THREAD_RUNNING) {
2977                     QUEUE_UNLOCK(q);
2978                     DEBUG_PRINT1("Accepting connection: thread is ending");
2979                     break;
2980                 }
2981                 rv = accept_connection(q, pfd[i].fd);
2982                 QUEUE_UNLOCK(q);
2983                 if (rv == 0) {
2984                     DEBUG_PRINT1("Accepting connection: succeeded");
2985                 } else {
2986                     DEBUG_PRINT1("Accepting connection: failed");
2987                 }
2988             }
2989         }
2990     }
2991 
2992     QUEUE_LOCK(q);
2993     q->root->listener_state = SKM_THREAD_ENDED;
2994     THREAD_END(q);
2995     QUEUE_UNLOCK(q);
2996 
2997     DEBUG_PRINT1("STOPPED listener_thread");
2998 
2999     RETURN(NULL);
3000 }
3001 
3002 
3003 /*
3004  *    THREAD ENTRY POINT
3005  *
3006  *    Entry point for the "skmsg_reader" thread, started from
3007  *    create_connection().  Argument is a temporary structure that
3008  *    contains pointers to the queue and the connection.
3009  */
3010 static void *
reader_thread(void * vconn)3011 reader_thread(
3012     void               *vconn)
3013 {
3014     sk_queue_and_conn_t *both = (sk_queue_and_conn_t *)vconn;
3015     sk_msg_conn_queue_t *conn = both->conn;
3016     sk_msg_queue_t *q         = both->q;
3017     sk_msg_channel_queue_t *chan;
3018     int destroyed = 0;
3019     struct pollfd pfd;
3020     sk_sockaddr_t addr;
3021     char addr_buf[128] = "<unknown>";
3022 
3023     DEBUG_ENTER_FUNC;
3024 
3025     DEBUG_PRINT1("STARTED reader_thread");
3026 
3027     assert(conn);
3028     assert(q);
3029 
3030     free(both);
3031     both = NULL;
3032 
3033     /* Wait for a signal to start */
3034     QUEUE_LOCK(q);
3035     while (conn->reader_state == SKM_THREAD_BEFORE) {
3036         QUEUE_WAIT(&conn->reader_cond, q);
3037     }
3038     QUEUE_UNLOCK(q);
3039 
3040     /* Get the peer's address for error reporting */
3041     if (conn->addr != NULL) {
3042         memcpy(&addr.sa, conn->addr, conn->addrlen);
3043         skSockaddrString(addr_buf, sizeof(addr_buf), &addr);
3044     }
3045 
3046     /* Add the current time to the connection */
3047     conn->last_recv = time(NULL);
3048 
3049     /* Set up the poll descriptor */
3050     pfd.fd = conn->rsocket;
3051     pfd.events = POLLIN;
3052 
3053     while (!destroyed && conn->state != SKM_CLOSED
3054            && conn->reader_state == SKM_THREAD_RUNNING)
3055     {
3056         int           rv;
3057         sk_msg_t     *message = NULL;
3058 
3059 #if SK_ENABLE_GNUTLS
3060         /* If the transport is TLS and there are still pending bytes
3061          * in the session, don't bother polling; we have data. */
3062         if (conn->transport != CONN_TLS
3063             || !gnutls_record_check_pending(conn->session))
3064 #endif  /* SK_ENABLE_GNUTLS */
3065         {
3066             /* Poll for new data on the socket */
3067             rv = poll(&pfd, 1, SKMSG_IO_POLL_TIMEOUT);
3068             if (rv == -1) {
3069                 if (errno == EINTR || errno == EBADF) {
3070                     DEBUG_PRINT2("Ignoring expected poll(POLLIN) error: %s",
3071                                  strerror(errno));
3072                     continue;
3073                 }
3074                 CRITMSG("Unexpected poll(POLLIN) error for %s: %s",
3075                         addr_buf, strerror(rv));
3076                 skAbort();
3077             }
3078             if (rv == 0) {
3079                 if (CONNECTION_STAGNANT(conn, time(NULL))) {
3080                     /* It's been too long since we have heard something.
3081                      * Assume the connection died. */
3082                     INFOMSG(("Destroying connection to %s due to"
3083                              " %.0f seconds of inactivity"),
3084                             addr_buf, difftime(time(NULL), conn->last_recv));
3085                     QUEUE_LOCK(q);
3086                     destroyed = destroy_connection(q, conn);
3087                     QUEUE_UNLOCK(q);
3088                     break;
3089                 }
3090 #if SENDRCV_DEBUG & DEBUG_SKMSG_POLL_TIMEOUT
3091                 WRAP_ERRNO(SKTHREAD_DEBUG_PRINT3(
3092                                "Timeout on poll(%d, POLLIN) for %s",
3093                                pfd.fd, addr_buf));
3094 #endif  /* DEBUG_SKMSG_POLL_TIMEOUT */
3095                 continue;
3096             }
3097             if (pfd.revents & POLLNVAL) {
3098                 /* This is handled the same as EBADF from poll().  The
3099                  * standard states that we should get POLLNVAL for a
3100                  * bad file descriptor.  Under Linux, we might get a
3101                  * EBADF errno from the poll instead. */
3102                 DEBUG_PRINT1("poll(POLLIN) returned POLLNVAL");
3103                 continue;
3104             }
3105             if (pfd.revents & (POLLHUP | POLLERR)) {
3106                 /* Handle a disconnect or device error  */
3107                 INFOMSG("Closing connection to %s due to a disconnect (%s)",
3108                         addr_buf, SK_POLL_EVENT_STR(pfd.revents));
3109                 QUEUE_LOCK(q);
3110                 destroyed = destroy_connection(q, conn);
3111                 QUEUE_UNLOCK(q);
3112                 break;
3113             }
3114         }
3115 #if SK_ENABLE_GNUTLS && ((SENDRCV_DEBUG) & DEBUG_SKMSG_OTHER)
3116         else {
3117             DEBUG_PRINT2("Skipping poll(); %" SK_PRIuZ " bytes are pending",
3118                          gnutls_record_check_pending(conn->session));
3119         }
3120 #endif  /* DEBUG_SKMSG_OTHER */
3121 
3122         /* Update time for last received data; used by CONNECTION_STAGNANT */
3123         conn->last_recv = time(NULL);
3124 
3125         /* Read a message */
3126         DEBUG_PRINT1("Calling recv");
3127 #if SK_ENABLE_GNUTLS
3128         if (CONN_TLS == conn->transport) {
3129             rv = tls_recv(conn, &message);
3130         } else
3131 #endif  /* SK_ENABLE_GNUTLS */
3132         {
3133             rv = tcp_recv(conn, &message);
3134         }
3135         if (rv == SKMERR_PARTIAL || rv == SKMERR_EMPTY) {
3136             assert(NULL == message);
3137             /* the recv() was successful, but only part of the message
3138              * was available to read, or nothing was read at all */
3139             continue;
3140         }
3141         if (rv != 0) {
3142             /* Treat the connection as closed */
3143             INFOMSG("Closing connection to %s due to failed read: %s",
3144                     addr_buf, skmerr_strerror(conn, rv));
3145             QUEUE_LOCK(q);
3146             destroyed = destroy_connection(q, conn);
3147             QUEUE_UNLOCK(q);
3148             break;
3149         }
3150 
3151         assert(message);
3152 
3153         /* Handle control messages */
3154         if (message->hdr.channel == SKMSG_CHANNEL_CONTROL &&
3155             message->hdr.type >= SKMSG_MINIMUM_SYSTEM_CTL_CHANNEL)
3156         {
3157             QUEUE_LOCK(q);
3158             rv = handle_system_control_message(q, conn, message);
3159             QUEUE_UNLOCK(q);
3160             if (rv == 1) {
3161                 destroyed = 1;
3162             }
3163             continue;
3164         }
3165 
3166         /* Handle ordinary messages */
3167         chan = find_channel(q, message->hdr.channel);
3168         if (chan == NULL) {
3169             skMsgDestroy(message);
3170         } else {
3171             /* Put the message on the queue */
3172             DEBUG_PRINT3("Enqueue: chan=%#x type=%#x",
3173                          message->hdr.channel, message->hdr.type);
3174             DEBUG_PRINT2("From reader: %p", (void *)message);
3175             rv = mqQueueAdd(chan->queue, message);
3176             if (rv != 0) {
3177                 XASSERT(conn->state == SKM_CLOSED ||
3178                         conn->reader_state != SKM_THREAD_RUNNING);
3179                 skMsgDestroy(message);
3180             }
3181         }
3182     }
3183 
3184     QUEUE_LOCK(q);
3185     if (!destroyed) {
3186         conn->reader_state = SKM_THREAD_ENDED;
3187     }
3188     THREAD_END(q);
3189     QUEUE_UNLOCK(q);
3190 
3191     DEBUG_PRINT1("STOPPED reader_thread");
3192 
3193     RETURN(NULL);
3194 }
3195 
3196 
3197 /*
3198  *    THREAD ENTRY POINT
3199  *
3200  *    Entry point for the "skmsg_writer" thread, started from
3201  *    create_connection().  Argument is a temporary structure that
3202  *    contains pointers to the queue and the connection.
3203  */
3204 static void *
writer_thread(void * vconn)3205 writer_thread(
3206     void               *vconn)
3207 {
3208     sk_queue_and_conn_t *both      = (sk_queue_and_conn_t *)vconn;
3209     sk_msg_conn_queue_t *conn      = both->conn;
3210     sk_msg_queue_t      *q         = both->q;
3211     int                  have_msg  = 0;
3212     int                  destroyed = 0;
3213     sk_msg_write_buf_t   write_buf;
3214     struct pollfd pfd;
3215     sk_sockaddr_t addr;
3216     char addr_buf[128] = "<unknown>";
3217 
3218     DEBUG_ENTER_FUNC;
3219 
3220     DEBUG_PRINT1("STARTED writer_thread");
3221 
3222     assert(conn);
3223     assert(q);
3224 
3225     free(both);
3226 
3227     memset(&write_buf, 0, sizeof(write_buf));
3228 
3229     /* Wait for a signal to start */
3230     QUEUE_LOCK(q);
3231     while (conn->writer_state == SKM_THREAD_BEFORE) {
3232         QUEUE_WAIT(&conn->writer_cond, q);
3233     }
3234     QUEUE_UNLOCK(q);
3235 
3236     /* Get the peer's address for error reporting */
3237     if (conn->addr != NULL) {
3238         memcpy(&addr.sa, conn->addr, conn->addrlen);
3239         skSockaddrString(addr_buf, sizeof(addr_buf), &addr);
3240     }
3241 
3242     /* Set up the poll descriptor */
3243     pfd.fd = conn->wsocket;
3244     pfd.events = POLLOUT;
3245 
3246     while (conn->writer_state == SKM_THREAD_RUNNING) {
3247         int           rv;
3248         skDQErr_t     err;
3249         int           block = (conn->state != SKM_CLOSED);
3250 
3251         /* If not currently sending a message, get a new message to
3252          * send from the queue  */
3253         if (!have_msg) {
3254             if (!block) {
3255                 /* If the connection is closed, don't block in the deque */
3256                 err = skDequePopBackNB(conn->queue, (void**)&write_buf.msg);
3257             } else if (conn->keepalive == 0) {
3258                 err = skDequePopBack(conn->queue, (void**)&write_buf.msg);
3259             } else {
3260                 err = skDequePopBackTimed(conn->queue, (void**)&write_buf.msg,
3261                                           conn->keepalive);
3262                 if (err == SKDQ_TIMEDOUT) {
3263                     /* Create a keepalive message */
3264                     write_buf.msg = (sk_msg_t*)calloc(1, sizeof(sk_msg_t));
3265                     MEM_ASSERT(write_buf.msg);
3266                     write_buf.msg->segments = 1;
3267                     write_buf.msg->segment[0].iov_base = &write_buf.msg->hdr;
3268                     write_buf.msg->segment[0].iov_len
3269                         = sizeof(write_buf.msg->hdr);
3270                     write_buf.msg->hdr.channel = SKMSG_CHANNEL_CONTROL;
3271                     write_buf.msg->hdr.type = SKMSG_CTL_CHANNEL_KEEPALIVE;
3272                     /* Pretend it came from the queue */
3273                     err = SKDQ_SUCCESS;
3274                     DEBUG_PRINT1("Sending SKMSG_CTL_CHANNEL_KEEPALIVE");
3275                 }
3276             }
3277             if (err != SKDQ_SUCCESS) {
3278                 assert(err == SKDQ_UNBLOCKED || err == SKDQ_DESTROYED
3279                        || err == SKDQ_EMPTY);
3280                 break;
3281             }
3282             if (write_buf.msg->hdr.channel == SKMSG_CHANNEL_CONTROL
3283                 && write_buf.msg->hdr.type == SKMSG_WRITER_UNBLOCKER)
3284             {
3285                 /* Do not destroy message, as this is a special static
3286                  * message. */
3287                 write_buf.msg = NULL;
3288                 DEBUG_PRINT1("Handling SKMSG_WRITER_UNBLOCKER message");
3289                 continue;
3290             }
3291             have_msg = 1;
3292             write_buf.msg_size
3293                 = sizeof(write_buf.msg->hdr) + write_buf.msg->hdr.size;
3294             write_buf.cur_seg = 0;
3295             write_buf.seg_offset = 0;
3296 
3297             /* Convert data to network byte order */
3298             write_buf.msg->hdr.channel = htons(write_buf.msg->hdr.channel);
3299             write_buf.msg->hdr.type    = htons(write_buf.msg->hdr.type);
3300             write_buf.msg->hdr.size    = htons(write_buf.msg->hdr.size);
3301         }
3302 
3303         /* Wait for socket to be available for writing */
3304         rv = poll(&pfd, 1, SKMSG_IO_POLL_TIMEOUT);
3305         if (rv == -1) {
3306             if (errno == EINTR || errno == EBADF) {
3307                 DEBUG_PRINT2("Ignoring expected poll(POLLOUT) error: %s",
3308                              strerror(errno));
3309                 continue;
3310             }
3311             CRITMSG("Unexpected poll(POLLOUT) error for %s: %s",
3312                     addr_buf, strerror(errno));
3313             skAbort();
3314         }
3315         if (rv == 0) {
3316 #if SENDRCV_DEBUG & DEBUG_SKMSG_POLL_TIMEOUT
3317             WRAP_ERRNO(SKTHREAD_DEBUG_PRINT3(
3318                            "Timeout on poll(%d, POLLOUT) for %s",
3319                            pfd.fd, addr_buf));
3320 #endif  /* DEBUG_SKMSG_POLL_TIMEOUT */
3321             continue;
3322         }
3323         if (pfd.revents & (POLLHUP | POLLERR | POLLNVAL)) {
3324             /* Handle a disconnect or device error  */
3325             INFOMSG("Closing connection to %s due to a disconnect (%s)",
3326                     addr_buf, SK_POLL_EVENT_STR(pfd.revents));
3327             QUEUE_LOCK(q);
3328             destroyed = destroy_connection(q, conn);
3329             QUEUE_UNLOCK(q);
3330             break;
3331         }
3332 
3333 #if SK_ENABLE_GNUTLS
3334         if (CONN_TLS == conn->transport) {
3335             rv = tls_send(conn, &write_buf);
3336         } else
3337 #endif  /* SK_ENABLE_GNUTLS */
3338         {
3339             rv = tcp_send(conn, &write_buf);
3340         }
3341         if (rv == SKMERR_PARTIAL) {
3342             /* partial write */
3343             continue;
3344         }
3345         have_msg = 0;
3346         skMsgDestroy(write_buf.msg);
3347         write_buf.msg = NULL;
3348         if (rv != 0) {
3349             /* Treat the connection as closed */
3350             INFOMSG("Closing connection to %s due to failed write: %s",
3351                     addr_buf, skmerr_strerror(conn, rv));
3352             QUEUE_LOCK(q);
3353             destroyed = destroy_connection(q, conn);
3354             QUEUE_UNLOCK(q);
3355             break;
3356         }
3357     }
3358 
3359     if (write_buf.msg) {
3360         skMsgDestroy(write_buf.msg);
3361     }
3362 
3363     QUEUE_LOCK(q);
3364     if (!destroyed) {
3365         conn->writer_state = SKM_THREAD_ENDED;
3366     }
3367     THREAD_END(q);
3368     QUEUE_UNLOCK(q);
3369 
3370     DEBUG_PRINT1("STOPPED writer_thread");
3371 
3372     RETURN(NULL);
3373 }
3374 
3375 
3376 int
skMsgQueueCreate(sk_msg_queue_t ** queue)3377 skMsgQueueCreate(
3378     sk_msg_queue_t    **queue)
3379 {
3380     sk_msg_queue_t *q;
3381     int retval = 0;
3382     int fd[2];
3383     int rv;
3384     sk_msg_conn_queue_t *conn;
3385     sk_msg_channel_queue_t *chan;
3386 
3387     DEBUG_ENTER_FUNC;
3388 
3389     q = (sk_msg_queue_t*)calloc(1, sizeof(sk_msg_queue_t));
3390     if (q == NULL) {
3391         RETURN(SKMERR_MEMORY);
3392     }
3393 
3394     q->root = (sk_msg_root_t*)calloc(1, sizeof(sk_msg_root_t));
3395     if (q->root == NULL) {
3396         free(q);
3397         RETURN(SKMERR_MEMORY);
3398     }
3399 
3400     THREAD_INFO_INIT(q);
3401 
3402     q->root->channel = int_dict_create(sizeof(sk_msg_channel_queue_t *));
3403     if (q->root->channel == NULL) {
3404         retval = SKMERR_MEMORY;
3405         goto error;
3406     }
3407     q->root->groups = int_dict_create(sizeof(sk_msg_queue_t *));
3408     if (q->root->groups == NULL) {
3409         retval = SKMERR_MEMORY;
3410         goto error;
3411     }
3412     q->channel = int_dict_create(sizeof(sk_msg_channel_queue_t *));
3413     if (q->channel == NULL) {
3414         retval = SKMERR_MEMORY;
3415         goto error;
3416     }
3417 
3418     rv = pthread_mutex_init(QUEUE_MUTEX(q), NULL);
3419     if (rv != 0) {
3420         retval = SKMERR_MUTEX;
3421         goto error;
3422     }
3423 
3424     rv = pthread_cond_init(&q->shutdowncond, NULL);
3425     if (rv != 0) {
3426         retval = SKMERR_MUTEX;
3427         goto error;
3428     }
3429 
3430     q->group = mqCreateFair(sk_destroy_report_message);
3431     if (q->group == NULL) {
3432         goto error;
3433     }
3434 
3435     rv = pipe(fd);
3436     if (rv == -1) {
3437         retval = SKMERR_PIPE;
3438         goto error;
3439     }
3440 
3441 #if SK_ENABLE_GNUTLS
3442     rv = skMsgQueueInitializeGnuTLS(q);
3443     if (rv != 0) {
3444         retval = SKMERR_GNUTLS;
3445         goto error;
3446     }
3447 #endif  /* SK_ENABLE_GNUTLS */
3448 
3449     /* Initialize the listener thread start state */
3450     pthread_cond_init(&q->root->listener_cond, NULL);
3451     q->root->listener_state = SKM_THREAD_BEFORE;
3452 
3453     /* Lock the mutex to satisfy preconditions for mucking with
3454        connections and channels. */
3455     QUEUE_LOCK(q);
3456 
3457     /* Create an internal connection for the control channel. */
3458     rv = create_connection(q, fd[READ], fd[WRITE], NULL, 0, &conn,
3459                            SKM_TLS_NONE);
3460     conn->keepalive = SKMSG_CONTROL_KEEPALIVE_TIMEOUT;
3461     unblock_connection(q, conn);
3462     XASSERT(rv == 0);
3463 
3464     /* Create a channel for the control channel. */
3465     q->root->next_channel = SKMSG_CHANNEL_CONTROL;
3466     chan = create_channel(q);
3467 
3468     /* Start the connection */
3469     start_connection(q, conn);
3470 
3471     /* Attach the internal connection to the control channel. */
3472     ASSERT_RESULT(set_channel_connecting(q, chan, conn), int, 0);
3473 
3474     /* And let's completely connect it. */
3475     ASSERT_RESULT(set_channel_connected(q, chan, SKMSG_CHANNEL_CONTROL),
3476                   int, 0);
3477     conn->state = SKM_CONNECTED;
3478 
3479     /* Unlock the mutex */
3480     QUEUE_UNLOCK(q);
3481 
3482     *queue = q;
3483     RETURN(0);
3484 
3485   error:
3486     skMsgQueueDestroy(q);
3487     RETURN(retval);
3488 }
3489 
3490 static void
sk_msg_queue_shutdown(sk_msg_queue_t * q)3491 sk_msg_queue_shutdown(
3492     sk_msg_queue_t     *q)
3493 {
3494     sk_msg_channel_queue_t *chan;
3495     void *cont;
3496 
3497     DEBUG_ENTER_FUNC;
3498 
3499     assert(q);
3500     ASSERT_QUEUE_LOCK(q);
3501 
3502     if (q->shuttingdown) {
3503         RETURN_VOID;
3504     }
3505 
3506     q->shuttingdown = 1;
3507 
3508     /* Shut down a queue by shutting down all its channels. */
3509     cont = int_dict_get_first(q->channel, NULL, &chan);
3510     while (cont != NULL) {
3511         intkey_t key = chan->channel;
3512         if (chan->state == SKM_CONNECTED || chan->state == SKM_CONNECTING) {
3513             set_channel_closed(q, chan, 0);
3514         }
3515         cont = int_dict_get_next(q->channel, &key, &chan);
3516     }
3517 
3518     /* And then shutting down the multiqueue */
3519     mqShutdown(q->group);
3520 
3521     q->shuttingdown = 0;
3522 
3523     MUTEX_BROADCAST(&q->shutdowncond);
3524 
3525     RETURN_VOID;
3526 }
3527 
3528 
3529 void
skMsgQueueShutdown(sk_msg_queue_t * q)3530 skMsgQueueShutdown(
3531     sk_msg_queue_t     *q)
3532 {
3533     DEBUG_ENTER_FUNC;
3534 
3535     assert(q);
3536 
3537     QUEUE_LOCK(q);
3538 
3539     sk_msg_queue_shutdown(q);
3540 
3541     QUEUE_UNLOCK(q);
3542     RETURN_VOID;
3543 }
3544 
3545 
/*
 *    Shut down every queue that shares q's root.
 *
 *    The steps are:
 *    - shut down the owning queue (group) of every channel in the root;
 *    - close the listening sockets;
 *    - wait via THREAD_WAIT_ALL_END;
 *    - join the listener thread (if one was started by a bind);
 *    - broadcast on q->shutdowncond.
 *
 *    A NULL 'q' is ignored.  If a root-wide shutdown is already in
 *    progress, returns immediately.
 */
void
skMsgQueueShutdownAll(
    sk_msg_queue_t     *q)
{
    sk_msg_channel_queue_t *chan;
    void *cont;

    DEBUG_ENTER_FUNC;

    if (!q) {
        RETURN_VOID;
    }
    QUEUE_LOCK(q);

    /* Another root-wide shutdown is already underway */
    if (q->root->shuttingdown) {
        QUEUE_UNLOCK(q);
        RETURN_VOID;
    }

    /* Mark the root as shutting down and remember which queue started
     * it; skMsgQueueDestroy() on this queue waits while both are set */
    q->root->shuttingdown = 1;
    q->root->shutdownqueue = q;

    q->root->listener_state = SKM_THREAD_SHUTTING_DOWN;

    /* Shut down all channels.  NOTE(review): several channels may share
     * one group, so sk_msg_queue_shutdown() can run more than once on
     * the same queue; its shuttingdown flag is cleared on exit and does
     * not prevent the repeat.  Presumably harmless (already-closed
     * channels are skipped); confirm mqShutdown() is idempotent. */
    cont = int_dict_get_first(q->root->channel, NULL, &chan);
    while (cont != NULL) {
        intkey_t key = chan->channel;
        sk_msg_queue_shutdown(chan->group);
        cont = int_dict_get_next(q->root->channel, &key, &chan);
    }

    /* Close the listening sockets (if bound), presumably so the
     * listener thread's poll() wakes up and the thread can exit */
    if (q->root->pfd) {
        nfds_t i;

        for (i = 0; i < q->root->pfd_len; i++) {
            if (q->root->pfd[i].fd >= 0) {
                close(q->root->pfd[i].fd);
                q->root->pfd[i].fd = -1;
            }
        }
    }

    /* Macro defined elsewhere; presumably waits for the threads tracked
     * for this queue to finish -- confirm */
    THREAD_WAIT_ALL_END(q);

    /* A non-NULL pfd means skMsgQueueBind() started a listener thread.
     * NOTE(review): the join happens with the queue lock held. */
    if (q->root->pfd) {
        pthread_join(q->root->listener, NULL);
        free(q->root->pfd);
        q->root->pfd = NULL;
    }

    q->root->shuttingdown = 0;

    /* Wake anyone in skMsgQueueDestroy() waiting for this shutdown */
    MUTEX_BROADCAST(&q->shutdowncond);

    QUEUE_UNLOCK(q);

    RETURN_VOID;
}
3605 
3606 
/*
 *    Release the root shared by all queues, together with the last
 *    queue 'q'.
 *
 *    Must be called with the queue lock held and after every channel
 *    has been destroyed.  This function releases the lock itself.
 *    Both q->root and q are freed, so the caller must not touch 'q'
 *    afterwards.
 *
 *    NOTE(review): q->shutdowncond is not destroyed here before
 *    free(q) -- confirm whether that is intentional.
 */
static void
skMsgQueueDestroyAll(
    sk_msg_queue_t     *q)
{
    DEBUG_ENTER_FUNC;

    assert(q);
    ASSERT_QUEUE_LOCK(q);

    /* Verify that all channels have been destroyed */
    assert(int_dict_get_first(q->root->channel, NULL, NULL) == NULL);

    int_dict_destroy(q->root->channel);
    int_dict_destroy(q->root->groups);

    /* The lock must be released before its mutex is destroyed below */
    QUEUE_UNLOCK(q);

    THREAD_INFO_DESTROY(q);

#if SK_ENABLE_GNUTLS
    if (q->root->cred_set) {
        gnutls_certificate_free_credentials(q->root->cred);
    }
#endif  /* SK_ENABLE_GNUTLS */

    ASSERT_RESULT(pthread_cond_destroy(&q->root->listener_cond), int, 0);
    ASSERT_RESULT(pthread_mutex_destroy(QUEUE_MUTEX(q)), int, 0);

    free(q->root);
    free(q);

    RETURN_VOID;
}
3640 
3641 
/*
 *    Destroy queue 'q': all of its channels, its multiqueue and its
 *    channel dictionary.
 *
 *    First waits until neither a per-queue shutdown of 'q' nor a
 *    root-wide shutdown started from 'q' is in progress.  If no groups
 *    remain in the root afterwards, skMsgQueueDestroyAll() releases
 *    the root as well; that function also releases the lock.
 *    Otherwise only 'q' is freed and the root mutex is unlocked.
 */
void
skMsgQueueDestroy(
    sk_msg_queue_t     *q)
{
    sk_msg_root_t *root;
    sk_msg_channel_queue_t *chan;
    void *cont;

    DEBUG_ENTER_FUNC;

    assert(q);

    QUEUE_LOCK(q);

    /* Saved because 'q' is freed before the final unlock */
    root = q->root;

    /* Do not tear down while a shutdown involving this queue runs */
    while (q->shuttingdown
           || (root->shuttingdown && root->shutdownqueue == q))
    {
        QUEUE_WAIT(&q->shutdowncond, q);
    }

    /* Destroy the channels; the key is saved so iteration can resume
     * after the current entry is destroyed */
    cont = int_dict_get_first(q->channel, NULL, &chan);
    while (cont != NULL) {
        intkey_t channel = chan->channel;
        destroy_channel(q, chan);
        cont = int_dict_get_next(q->channel, &channel, &chan);
    }

    mqShutdown(q->group);
    mqDestroy(q->group);

    int_dict_destroy(q->channel);

    /* Last group gone: release the root too (this also unlocks) */
    if (int_dict_get_first(q->root->groups, NULL, NULL) == NULL) {
        skMsgQueueDestroyAll(q);
        RETURN_VOID;
    }

    /* NOTE(review): q->shutdowncond is not destroyed before free(q) */
    free(q);

    /* Unlock through the saved root pointer; presumably this is the
     * same mutex QUEUE_LOCK(q) acquired -- confirm */
    MUTEX_UNLOCK(&root->mutex);

    RETURN_VOID;
}
3688 
/*
 *    Bind listening sockets for every address in 'listen_addrs' and
 *    start the listener thread for q's root.
 *
 *    Addresses that cannot get a socket or cannot be bound are skipped
 *    (their pollfd slot holds -1).
 *
 *    Returns -1 in two cases:
 *    - no address could be bound;
 *    - the listener has already left the SKM_THREAD_BEFORE state; all
 *      sockets opened here are then closed.
 *
 *    Otherwise returns 0, once the listener thread reports
 *    SKM_THREAD_RUNNING.
 */
int
skMsgQueueBind(
    sk_msg_queue_t             *q,
    const sk_sockaddr_array_t  *listen_addrs)
{
    static int on = 1;
    uint32_t i, n;
    struct pollfd *pfd;
    int rv;

    DEBUG_ENTER_FUNC;

    assert(q);
    assert(listen_addrs);
    assert(skSockaddrArrayGetSize(listen_addrs) > 0);

    /* One pollfd slot per candidate address (zero-filled by calloc) */
    pfd = (struct pollfd*)calloc(skSockaddrArrayGetSize(listen_addrs),
                                 sizeof(struct pollfd));
    MEM_ASSERT(pfd);

    /* n counts successfully bound addresses */
    n = 0;
    DEBUGMSG(("Attempting to bind %" PRIu32 " addresses for %s"),
             skSockaddrArrayGetSize(listen_addrs),
             skSockaddrArrayGetHostPortPair(listen_addrs));
    for (i = 0; i < skSockaddrArrayGetSize(listen_addrs); i++) {
        char addr_string[PATH_MAX];
        int sock;
        sk_sockaddr_t *addr = skSockaddrArrayGet(listen_addrs, i);

        skSockaddrString(addr_string, sizeof(addr_string), addr);

        /* Bind a socket to the address*/
        sock = socket(addr->sa.sa_family, SOCK_STREAM, 0);
        if (sock == -1) {
            DEBUGMSG("Skipping %s: Unable to create stream socket: %s",
                     addr_string, strerror(errno));
            pfd[i].fd = -1;
            continue;
        }
        rv = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR,
                        &on, sizeof(on));
        XASSERT(rv != -1);
        rv = bind(sock, &addr->sa, skSockaddrGetLen(addr));
        if (rv == 0) {
            DEBUGMSG("Succeeded binding to %s", addr_string);
            rv = listen(sock, LISTENQ);
            XASSERT(rv != -1);
            set_nonblock(sock);
            pfd[i].events = POLLIN;
            n++;
        } else {
            DEBUGMSG("Skipping %s: Unable to bind: %s",
                     addr_string, strerror(errno));
            close(sock);
            sock = -1;
        }
        /* -1 marks an unused slot for the listener and for cleanup */
        pfd[i].fd = sock;
    }
    if (n == 0) {
        ERRMSG("Failed to bind any addresses for %s",
               skSockaddrArrayGetHostPortPair(listen_addrs));
        free(pfd);
        RETURN(-1);
    }

    DEBUGMSG(("Bound %" PRIu32 "/%" PRIu32 " addresses for %s"),
             (uint32_t)n, skSockaddrArrayGetSize(listen_addrs),
             skSockaddrArrayGetHostPortPair(listen_addrs));

    QUEUE_LOCK(q);

    /* Only one bind per root: refuse if the listener already left the
     * initial state, and release everything opened above */
    if (q->root->listener_state != SKM_THREAD_BEFORE) {
        QUEUE_UNLOCK(q);
        for (i = 0; i < skSockaddrArrayGetSize(listen_addrs); i++) {
            if (pfd[i].fd >= 0) {
                close(pfd[i].fd);
            }
        }
        free(pfd);
        RETURN(-1);
    }

    /* Set the listen sock for the queue. */
    assert(q->root->pfd == NULL);
    q->root->pfd = pfd;
    q->root->pfd_len = skSockaddrArrayGetSize(listen_addrs);
    q->root->bind_tls = (connection_type == CONN_TLS);

    THREAD_START("skmsg_listener", rv, q, &q->root->listener,
                 listener_thread, q);
    XASSERT(rv == 0);

    /* Block until the listener thread moves out of the BEFORE state */
    while (q->root->listener_state == SKM_THREAD_BEFORE) {
        QUEUE_WAIT(&q->root->listener_cond, q);
    }
    assert(q->root->listener_state == SKM_THREAD_RUNNING);

    QUEUE_UNLOCK(q);

    RETURN(0);
}
3790 
3791 
/*
 *    Open a TCP connection to 'addr' and create a new channel on it.
 *
 *    The new channel is announced to the remote queue, and the call
 *    blocks until the channel leaves SKM_CONNECTING or its pending
 *    flag is cleared.  The connection uses TLS as a client when the
 *    global connection_type is CONN_TLS.
 *
 *    On success, stores the local channel id in '*channel' and
 *    returns 0.  Returns -1 if any of these fail: socket creation,
 *    connect, connection creation, or the announcement; also returns
 *    -1 if the queue is shutting down or the channel ends up
 *    SKM_CLOSED.
 */
int
skMsgQueueConnect(
    sk_msg_queue_t     *q,
    struct sockaddr    *addr,
    socklen_t           addrlen,
    skm_channel_t      *channel)
{
    int rv;
    int sock;
    sk_msg_conn_queue_t *conn;
    sk_msg_channel_queue_t *chan;
    int retval;
    skm_channel_t lchannel;
    struct sockaddr *copy;

    DEBUG_ENTER_FUNC;

    assert(q);
    sock = socket(addr->sa_family, SOCK_STREAM, 0);
    if (sock == -1) {
        RETURN(-1);
    }

    /* Connect to the remote side (done before taking the queue lock) */
    rv = connect(sock, addr, addrlen);
    if (rv == -1) {
        DEBUGMSG("Failed to connect: %s", strerror(errno));
        close(sock);
        RETURN(-1);
    }

    /* Create a channel and connection, and bind them */
    QUEUE_LOCK(q);
    if (q->shuttingdown) {
        close(sock);
        QUEUE_UNLOCK(q);
        RETURN(-1);
    }
    /* Keep a heap copy of the peer address for the connection; if the
     * allocation fails, NULL is passed on instead */
    copy = (struct sockaddr *)malloc(addrlen);
    if (copy != NULL) {
        memcpy(copy, addr, addrlen);
    }
    rv = create_connection(q, sock, sock, copy, addrlen, &conn,
                           ((connection_type == CONN_TLS)
                            ? SKM_TLS_CLIENT : SKM_TLS_NONE));
    if (rv == -1) {
        close(sock);
        free(copy);
        QUEUE_UNLOCK(q);
        RETURN(-1);
    }

    chan = create_channel(q);
    start_connection(q, conn);
    rv = set_channel_connecting(q, chan, conn);
    XASSERT(rv == 0);

    /* Announce the channel id to the remote queue */
    lchannel = htons(chan->channel);
    DEBUG_PRINT1("Sending SKMSG_CTL_CHANNEL_ANNOUNCE (Ext-control)");
    rv = send_message(q, chan->channel,
                      SKMSG_CTL_CHANNEL_ANNOUNCE,
                      &lchannel, sizeof(lchannel), SKM_SEND_CONTROL, 0, NULL);
    if (rv != 0) {
        /* NOTE(review): if destroy_connection() already closes the
         * connection's socket, the close(sock) below is a double close.
         * 'chan' is also not destroyed on this path.  Confirm both
         * against destroy_connection(). */
        DEBUG_PRINT1("Sending SKMSG_CTL_CHANNEL_ANNOUNCE failed");
        destroy_connection(q, conn);
        close(sock);
        QUEUE_UNLOCK(q);
        RETURN(-1);
    }

    /* Wait for a reply: block until the pending flag is cleared or the
     * channel leaves SKM_CONNECTING */
    chan->is_pending = 1;
    while (chan->is_pending && chan->state == SKM_CONNECTING) {
        QUEUE_WAIT(&chan->pending, q);
    }
    chan->is_pending = 0;

    if (chan->state == SKM_CLOSED) {
        destroy_channel(q, chan);
        retval = -1;
    } else {
        retval = 0;
        *channel = chan->channel;
    }

    QUEUE_UNLOCK(q);

    RETURN(retval);
}
3882 
3883 int
skMsgChannelNew(sk_msg_queue_t * q,skm_channel_t channel,skm_channel_t * new_channel)3884 skMsgChannelNew(
3885     sk_msg_queue_t     *q,
3886     skm_channel_t       channel,
3887     skm_channel_t      *new_channel)
3888 {
3889     sk_msg_channel_queue_t *chan;
3890     sk_msg_channel_queue_t *newchan;
3891     skm_channel_t lchannel;
3892     int retval;
3893     int rv;
3894 
3895     DEBUG_ENTER_FUNC;
3896 
3897     assert(q);
3898     assert(new_channel);
3899 
3900     QUEUE_LOCK(q);
3901 
3902     if (q->shuttingdown) {
3903         QUEUE_UNLOCK(q);
3904         RETURN(-1);
3905     }
3906 
3907     chan = find_channel(q, channel);
3908     XASSERT(chan != NULL);
3909     XASSERT(chan->state == SKM_CONNECTED);
3910     assert(chan->conn != NULL);
3911 
3912     /* Create a channel and connection, and bind it to the connection */
3913     newchan = create_channel(q);
3914     ASSERT_RESULT(set_channel_connecting(q, newchan, chan->conn), int, 0);
3915 
3916     lchannel = htons(newchan->channel);
3917 
3918     /* Announce the channel id to the remote queue */
3919     DEBUG_PRINT1("Sending SKMSG_CTL_CHANNEL_ANNOUNCE (Ext-control))");
3920     rv = send_message(q, newchan->channel,
3921                       SKMSG_CTL_CHANNEL_ANNOUNCE,
3922                       &lchannel, sizeof(lchannel), SKM_SEND_CONTROL, 0, NULL);
3923     if (rv != 0) {
3924         destroy_channel(q, newchan);
3925         QUEUE_UNLOCK(q);
3926         RETURN(-1);
3927     }
3928 
3929     /* Wait for a response */
3930     newchan->is_pending = 1;
3931     while (newchan->is_pending && newchan->state == SKM_CONNECTING) {
3932         QUEUE_WAIT(&newchan->pending, q);
3933     }
3934     newchan->is_pending = 0;
3935 
3936     if (newchan->state == SKM_CLOSED) {
3937         retval = -1;
3938         destroy_channel(q, newchan);
3939     } else {
3940         retval = 0;
3941         *new_channel = newchan->channel;
3942     }
3943 
3944     QUEUE_UNLOCK(q);
3945 
3946     RETURN(retval);
3947 }
3948 
3949 
3950 int
skMsgChannelSplit(sk_msg_queue_t * q,skm_channel_t channel,sk_msg_queue_t ** new_queue)3951 skMsgChannelSplit(
3952     sk_msg_queue_t     *q,
3953     skm_channel_t       channel,
3954     sk_msg_queue_t    **new_queue)
3955 {
3956     sk_msg_queue_t *new_q;
3957     int rv;
3958 
3959     DEBUG_ENTER_FUNC;
3960 
3961     assert(q);
3962 
3963     new_q = (sk_msg_queue_t *)calloc(1, sizeof(*new_q));
3964     if (new_q == NULL) {
3965         return -1;
3966     }
3967 
3968     rv = pthread_cond_init(&new_q->shutdowncond, NULL);
3969     if (rv != 0) {
3970         free(new_q);
3971         return -1;
3972     }
3973 
3974     new_q->channel = int_dict_create(sizeof(sk_msg_channel_queue_t *));
3975     if (new_q->channel == NULL) {
3976         free(new_q);
3977         return -1;
3978     }
3979 
3980     new_q->group = mqCreateFair(sk_destroy_report_message);
3981     if (new_q->group == NULL) {
3982         int_dict_destroy(new_q->channel);
3983         free(new_q);
3984         return -1;
3985     }
3986 
3987     new_q->root = q->root;
3988 
3989     rv = skMsgChannelMove(channel, new_q);
3990     if (rv != 0) {
3991         skMsgQueueDestroy(new_q);
3992     } else {
3993         *new_queue = new_q;
3994     }
3995 
3996     return rv;
3997 }
3998 
3999 
4000 int
skMsgChannelMove(skm_channel_t channel,sk_msg_queue_t * q)4001 skMsgChannelMove(
4002     skm_channel_t       channel,
4003     sk_msg_queue_t     *q)
4004 {
4005     sk_msg_channel_queue_t *chan;
4006     int retval = 0;
4007 
4008     DEBUG_ENTER_FUNC;
4009 
4010     assert(q);
4011 
4012     QUEUE_LOCK(q);
4013 
4014     chan = find_channel(q, channel);
4015     if (chan == NULL) {
4016         retval = -1;
4017         goto end;
4018     }
4019 
4020     ASSERT_RESULT(mqQueueMove(q->group, chan->queue), int, 0);
4021     ASSERT_RESULT(int_dict_del(chan->group->channel, channel), int, 0);
4022     ASSERT_RESULT(int_dict_set(q->channel, channel, &chan), int, 0);
4023     ASSERT_RESULT(int_dict_set(q->root->groups, channel, &q), int, 0);
4024 
4025     chan->group = q;
4026 
4027   end:
4028     QUEUE_UNLOCK(q);
4029 
4030     RETURN(retval);
4031 }
4032 
4033 
4034 int
skMsgChannelKill(sk_msg_queue_t * q,skm_channel_t channel)4035 skMsgChannelKill(
4036     sk_msg_queue_t     *q,
4037     skm_channel_t       channel)
4038 {
4039     sk_msg_channel_queue_t *chan;
4040 
4041     DEBUG_ENTER_FUNC;
4042 
4043     assert(q);
4044 
4045     QUEUE_LOCK(q);
4046 
4047     if (!q->shuttingdown) {
4048         chan = find_channel(q, channel);
4049         XASSERT(chan != NULL);
4050 
4051         destroy_channel(q, chan);
4052     }
4053 
4054     QUEUE_UNLOCK(q);
4055 
4056     RETURN(0);
4057 }
4058 
4059 
4060 static int
send_message(sk_msg_queue_t * q,skm_channel_t lchannel,skm_type_t type,void * message,skm_len_t length,sk_send_type_t send_type,int no_copy,void (* free_fn)(void *))4061 send_message(
4062     sk_msg_queue_t     *q,
4063     skm_channel_t       lchannel,
4064     skm_type_t          type,
4065     void               *message,
4066     skm_len_t           length,
4067     sk_send_type_t      send_type,
4068     int                 no_copy,
4069     void              (*free_fn)(void *))
4070 {
4071     sk_msg_channel_queue_t *chan;
4072     sk_msg_t *msg;
4073     int rv;
4074 
4075     DEBUG_ENTER_FUNC;
4076 
4077     assert(q);
4078     assert(message || length == 0);
4079     assert((NULL == free_fn) ? (0 == no_copy) : (1 == no_copy));
4080     ASSERT_QUEUE_LOCK(q);
4081 
4082     if (int_dict_get(q->root->channel, lchannel, &chan) == NULL) {
4083         if (no_copy) {
4084             free_fn(message);
4085         }
4086         RETURN(-1);
4087     }
4088 
4089     if (chan->state == SKM_CLOSED && send_type != SKM_SEND_INTERNAL) {
4090         if (no_copy) {
4091             free_fn(message);
4092         }
4093         RETURN(0);
4094     }
4095 
4096     msg = (sk_msg_t*)malloc(sizeof(sk_msg_t)
4097                             + (length ? sizeof(struct iovec) : 0));
4098     MEM_ASSERT(msg);
4099 
4100     msg->free_fn = msg_simple_free;
4101     msg->simple_free = NULL;
4102 
4103     msg->segment[0].iov_base = &msg->hdr;
4104     msg->segment[0].iov_len = sizeof(msg->hdr);
4105 
4106     if (!length) {
4107         msg->segments = 1;
4108     } else {
4109         msg->segments = 2;
4110         msg->segment[1].iov_len = length;
4111         if (no_copy) {
4112             msg->simple_free = free_fn;
4113             msg->segment[1].iov_base = message;
4114         } else {
4115             msg->segment[1].iov_base = malloc(length);
4116             if (msg->segment[1].iov_base == NULL) {
4117                 free(msg);
4118                 RETURN(-1);
4119             }
4120             memcpy(msg->segment[1].iov_base, message, length);
4121         }
4122     }
4123 
4124     msg->hdr.type = type;
4125     msg->hdr.size = length;
4126 
4127     rv = send_message_internal(chan, msg, send_type);
4128     if (rv != 0) {
4129         skMsgDestroy(msg);
4130         RETURN(-1);
4131     }
4132 
4133     RETURN(0);
4134 }
4135 
4136 
4137 static int
send_message_internal(sk_msg_channel_queue_t * chan,sk_msg_t * msg,sk_send_type_t send_type)4138 send_message_internal(
4139     sk_msg_channel_queue_t *chan,
4140     sk_msg_t               *msg,
4141     sk_send_type_t          send_type)
4142 {
4143     skDQErr_t err;
4144     int rv;
4145 
4146     DEBUG_ENTER_FUNC;
4147 
4148     assert(chan);
4149     assert(msg);
4150 
4151     switch (send_type) {
4152       case SKM_SEND_INTERNAL:
4153         msg->hdr.channel = chan->channel;
4154         DEBUG_PRINT3("Enqueue: chan=%#x type=%#x",
4155                      msg->hdr.channel, msg->hdr.type);
4156         rv = mqQueueAdd(chan->queue, msg);
4157         if (rv != 0) {
4158             RETURN(-1);
4159         }
4160         break;
4161       case SKM_SEND_REMOTE:
4162         msg->hdr.channel = chan->rchannel;
4163         err = skDequePushFront(chan->conn->queue, msg);
4164         if (err != SKDQ_SUCCESS) {
4165             RETURN(-1);
4166         }
4167         break;
4168       case SKM_SEND_CONTROL:
4169         msg->hdr.channel = SKMSG_CHANNEL_CONTROL;
4170         err = skDequePushFront(chan->conn->queue, msg);
4171         if (err != SKDQ_SUCCESS) {
4172             RETURN(-1);
4173         }
4174         break;
4175       default:
4176         skAbortBadCase(send_type);
4177     }
4178 
4179     RETURN(0);
4180 }
4181 
4182 
4183 int
skMsgQueueSendMessage(sk_msg_queue_t * q,skm_channel_t channel,skm_type_t type,const void * message,skm_len_t length)4184 skMsgQueueSendMessage(
4185     sk_msg_queue_t     *q,
4186     skm_channel_t       channel,
4187     skm_type_t          type,
4188     const void         *message,
4189     skm_len_t           length)
4190 {
4191     int rv;
4192 
4193     DEBUG_ENTER_FUNC;
4194 
4195     QUEUE_LOCK(q);
4196     rv = send_message(q, channel, type, (void *)message,
4197                       length, SKM_SEND_REMOTE, 0, NULL);
4198     QUEUE_UNLOCK(q);
4199     RETURN(rv);
4200 }
4201 
4202 
4203 int
skMsgQueueInjectMessage(sk_msg_queue_t * q,skm_channel_t channel,skm_type_t type,const void * message,skm_len_t length)4204 skMsgQueueInjectMessage(
4205     sk_msg_queue_t     *q,
4206     skm_channel_t       channel,
4207     skm_type_t          type,
4208     const void         *message,
4209     skm_len_t           length)
4210 {
4211     int rv;
4212 
4213     DEBUG_ENTER_FUNC;
4214 
4215     QUEUE_LOCK(q);
4216     rv = send_message(q, channel, type, (void *)message,
4217                       length, SKM_SEND_INTERNAL, 0, NULL);
4218     QUEUE_UNLOCK(q);
4219     RETURN(rv);
4220 }
4221 
4222 int
skMsgQueueSendMessageNoCopy(sk_msg_queue_t * q,skm_channel_t channel,skm_type_t type,void * message,skm_len_t length,void (* free_fn)(void *))4223 skMsgQueueSendMessageNoCopy(
4224     sk_msg_queue_t     *q,
4225     skm_channel_t       channel,
4226     skm_type_t          type,
4227     void               *message,
4228     skm_len_t           length,
4229     void              (*free_fn)(void *))
4230 {
4231     int rv;
4232 
4233     DEBUG_ENTER_FUNC;
4234 
4235     QUEUE_LOCK(q);
4236     rv = send_message(q, channel, type, message, length, SKM_SEND_REMOTE,
4237                       1, free_fn);
4238     QUEUE_UNLOCK(q);
4239     RETURN(rv);
4240 }
4241 
4242 int
skMsgQueueScatterSendMessageNoCopy(sk_msg_queue_t * q,skm_channel_t channel,skm_type_t type,uint16_t num_segments,struct iovec * segments,void (* free_fn)(uint16_t,struct iovec *))4243 skMsgQueueScatterSendMessageNoCopy(
4244     sk_msg_queue_t     *q,
4245     skm_channel_t       channel,
4246     skm_type_t          type,
4247     uint16_t            num_segments,
4248     struct iovec       *segments,
4249     void              (*free_fn)(uint16_t, struct iovec *))
4250 {
4251     sk_msg_channel_queue_t *chan;
4252     sk_msg_t               *msg;
4253     size_t                  size;
4254     uint16_t                i;
4255     int                     rv;
4256 
4257     DEBUG_ENTER_FUNC;
4258 
4259     assert(q);
4260     assert((num_segments && segments) || (!num_segments && !segments));
4261 
4262     QUEUE_LOCK(q);
4263 
4264     if (int_dict_get(q->root->channel, channel, &chan) == NULL) {
4265         free_fn(num_segments, segments);
4266         rv = -1;
4267         goto end;
4268     }
4269 
4270     if (chan->state == SKM_CLOSED) {
4271         free_fn(num_segments, segments);
4272         rv = 0;
4273         goto end;
4274     }
4275 
4276     msg = (sk_msg_t*)malloc(sizeof(sk_msg_t)
4277                             + sizeof(struct iovec) * num_segments);
4278     MEM_ASSERT(msg);
4279 
4280     msg->free_fn = free_fn;
4281     msg->simple_free = NULL;
4282 
4283     msg->segment[0].iov_base = &msg->hdr;
4284     msg->segment[0].iov_len = sizeof(msg->hdr);
4285     msg->segments = 1;
4286 
4287     msg->hdr.type = type;
4288     size = 0;
4289 
4290     /* add all segments to msg before checking the size; this ensures
4291      * they all get freed if the overall message size is too large */
4292     for (i = 0; i < num_segments; i++) {
4293         msg->segment[i + 1] = segments[i];
4294         size += segments[i].iov_len;
4295         ++msg->segments;
4296     }
4297     if (size > UINT16_MAX) {
4298         memset(&msg->hdr, 0, sizeof(msg->hdr));
4299         skMsgDestroy(msg);
4300         rv = -1;
4301         goto end;
4302     }
4303 
4304     msg->hdr.size = size;
4305 
4306     rv = send_message_internal(chan, msg, SKM_SEND_REMOTE);
4307     if (rv != 0) {
4308         skMsgDestroy(msg);
4309     }
4310 
4311   end:
4312     QUEUE_UNLOCK(q);
4313 
4314     RETURN(rv);
4315 }
4316 
4317 
4318 int
skMsgQueueInjectMessageNoCopy(sk_msg_queue_t * q,skm_channel_t channel,skm_type_t type,void * message,skm_len_t length,void (* free_fn)(void *))4319 skMsgQueueInjectMessageNoCopy(
4320     sk_msg_queue_t     *q,
4321     skm_channel_t       channel,
4322     skm_type_t          type,
4323     void               *message,
4324     skm_len_t           length,
4325     void              (*free_fn)(void *))
4326 {
4327     int rv;
4328 
4329     DEBUG_ENTER_FUNC;
4330 
4331     QUEUE_LOCK(q);
4332     rv = send_message(q, channel, type, message, length, SKM_SEND_INTERNAL,
4333                       1, free_fn);
4334     QUEUE_UNLOCK(q);
4335     RETURN(rv);
4336 }
4337 
4338 
4339 int
skMsgQueueGetMessage(sk_msg_queue_t * q,sk_msg_t ** message)4340 skMsgQueueGetMessage(
4341     sk_msg_queue_t     *q,
4342     sk_msg_t          **message)
4343 {
4344     sk_msg_t *msg;
4345     sk_msg_channel_queue_t *chan;
4346     int rv;
4347 
4348     DEBUG_ENTER_FUNC;
4349 
4350     assert(q);
4351     assert(message);
4352 
4353     do {
4354         rv = mqGet(q->group, (void **)&msg);
4355         if (rv != 0) {
4356             RETURN(-1);
4357         }
4358         DEBUG_PRINT2("From GetMessage: %p", (void *)msg);
4359         DEBUG_PRINT4("Dequeue: chan=%#x type=%#x size=%d",
4360                      msg->hdr.channel, msg->hdr.type, msg->hdr.size);
4361 
4362         chan = find_channel(q, msg->hdr.channel);
4363     } while (chan == NULL);
4364 
4365     *message = msg;
4366 
4367     RETURN(0);
4368 }
4369 
4370 
4371 int
skMsgQueueGetMessageFromChannel(sk_msg_queue_t * q,skm_channel_t channel,sk_msg_t ** message)4372 skMsgQueueGetMessageFromChannel(
4373     sk_msg_queue_t     *q,
4374     skm_channel_t       channel,
4375     sk_msg_t          **message)
4376 {
4377     sk_msg_t *msg;
4378     sk_msg_channel_queue_t *chan;
4379     int rv;
4380 
4381     DEBUG_ENTER_FUNC;
4382 
4383     assert(q);
4384     assert(message);
4385 
4386     chan = find_channel(q, channel);
4387 
4388     if (chan == NULL) {
4389         RETURN(-1);
4390     }
4391 
4392     rv = mqQueueGet(chan->queue, (void **)&msg);
4393     if (rv != 0) {
4394         RETURN(-1);
4395     }
4396     DEBUG_PRINT4("Dequeue: chan=%#x type=%#x size=%d",
4397                  msg->hdr.channel, msg->hdr.type, msg->hdr.size);
4398 
4399     assert(msg->hdr.channel == channel);
4400 
4401     chan = find_channel(q, msg->hdr.channel);
4402 
4403     if (chan == NULL) {
4404         RETURN(-1);
4405     }
4406 
4407     *message = msg;
4408 
4409     RETURN(0);
4410 }
4411 
4412 int
skMsgGetRemoteChannelID(sk_msg_queue_t * q,skm_channel_t lchannel,skm_channel_t * rchannel)4413 skMsgGetRemoteChannelID(
4414     sk_msg_queue_t     *q,
4415     skm_channel_t       lchannel,
4416     skm_channel_t      *rchannel)
4417 {
4418     sk_msg_channel_queue_t *chan;
4419     int                     retval;
4420 
4421     DEBUG_ENTER_FUNC;
4422 
4423     assert(q);
4424 
4425     chan = find_channel(q, lchannel);
4426     if (chan == NULL) {
4427         retval = -1;
4428     } else {
4429         *rchannel = chan->rchannel;
4430         retval = 0;
4431     }
4432 
4433     return retval;
4434 }
4435 
4436 
4437 int
skMsgSetKeepalive(sk_msg_queue_t * q,skm_channel_t channel,uint16_t keepalive)4438 skMsgSetKeepalive(
4439     sk_msg_queue_t     *q,
4440     skm_channel_t       channel,
4441     uint16_t            keepalive)
4442 {
4443     sk_msg_channel_queue_t *chan;
4444     int                     retval;
4445 
4446     DEBUG_ENTER_FUNC;
4447 
4448     assert(q);
4449 
4450     QUEUE_LOCK(q);
4451 
4452     chan = find_channel(q, channel);
4453     if (chan == NULL || chan->state != SKM_CONNECTED) {
4454         retval = -1;
4455     } else {
4456         chan->conn->keepalive = keepalive;
4457         unblock_connection(q, chan->conn);
4458         retval = 0;
4459     }
4460     QUEUE_UNLOCK(q);
4461 
4462     return retval;
4463 }
4464 
4465 
4466 int
skMsgGetConnectionInformation(sk_msg_queue_t * q,skm_channel_t channel,char * buffer,size_t buffer_size)4467 skMsgGetConnectionInformation(
4468     sk_msg_queue_t     *q,
4469     skm_channel_t       channel,
4470     char               *buffer,
4471     size_t              buffer_size)
4472 {
4473     sk_msg_channel_queue_t *chan;
4474     sk_msg_conn_queue_t    *conn;
4475     int                     rv;
4476 
4477     DEBUG_ENTER_FUNC;
4478 
4479     assert(q);
4480     assert(buffer);
4481 
4482     QUEUE_LOCK(q);
4483     chan = find_channel(q, channel);
4484     if (chan == NULL) {
4485         QUEUE_UNLOCK(q);
4486         RETURN(-1);
4487     }
4488     conn = chan->conn;
4489     if (conn == NULL) {
4490         QUEUE_UNLOCK(q);
4491         RETURN(-1);
4492     }
4493 
4494 #if SK_ENABLE_GNUTLS
4495     if (conn->use_tls) {
4496         const char *protocol;
4497         const char *encryption;
4498         const char *key_exchange;
4499         const char *mac;
4500 
4501         protocol = gnutls_protocol_get_name(
4502             gnutls_protocol_get_version(conn->session));
4503         encryption = gnutls_cipher_get_name(gnutls_cipher_get(conn->session));
4504         key_exchange = gnutls_kx_get_name(gnutls_kx_get(conn->session));
4505         mac = gnutls_mac_get_name(gnutls_mac_get(conn->session));
4506         /* compression = gnutls_compression_get_name(
4507          *   gnutls_compression_get(conn->session)); */
4508         QUEUE_UNLOCK(q);
4509 
4510         rv = snprintf(buffer, buffer_size, "TCP, %s, %s, %s, %s",
4511                       protocol, encryption, key_exchange, mac);
4512         RETURN(rv);
4513     }
4514 #endif  /* SK_ENABLE_GNUTLS */
4515 
4516     QUEUE_UNLOCK(q);
4517     rv = snprintf(buffer, buffer_size, "TCP");
4518     RETURN(rv);
4519 }
4520 
4521 
4522 int
skMsgGetLocalPort(sk_msg_queue_t * q,skm_channel_t channel,uint16_t * port)4523 skMsgGetLocalPort(
4524     sk_msg_queue_t     *q,
4525     skm_channel_t       channel,
4526     uint16_t           *port)
4527 {
4528     sk_msg_channel_queue_t *chan;
4529     sk_msg_conn_queue_t    *conn;
4530     sk_sockaddr_t           addr;
4531     socklen_t               addrlen;
4532     int rv;
4533 
4534     DEBUG_ENTER_FUNC;
4535 
4536     assert(q);
4537     assert(port);
4538 
4539     rv = -1;
4540 
4541     QUEUE_LOCK(q);
4542     chan = find_channel(q, channel);
4543     if (chan == NULL) {
4544         goto END;
4545     }
4546     conn = chan->conn;
4547     if (conn == NULL) {
4548         goto END;
4549     }
4550 
4551     addrlen = sizeof(addr);
4552     if (-1 == getsockname(conn->rsocket, &addr.sa, &addrlen)) {
4553         goto END;
4554     }
4555 
4556     *port = skSockaddrGetPort(&addr);
4557     rv = 0;
4558 
4559   END:
4560     QUEUE_UNLOCK(q);
4561     RETURN(rv);
4562 }
4563 
4564 
4565 void
skMsgDestroy(sk_msg_t * msg)4566 skMsgDestroy(
4567     sk_msg_t           *msg)
4568 {
4569     DEBUG_ENTER_FUNC;
4570 
4571     assert(msg);
4572 
4573     if (msg->segments == 2 && msg->simple_free) {
4574         msg->simple_free(msg->segment[1].iov_base);
4575     } else if (msg->segments > 1 && msg->free_fn) {
4576         msg->free_fn(msg->segments - 1, &msg->segment[1]);
4577     }
4578     /* We use a static message to unblock the writer queue.  It should
4579      * not be freed. */
4580     if (msg->hdr.channel != SKMSG_CHANNEL_CONTROL ||
4581         msg->hdr.type != SKMSG_WRITER_UNBLOCKER)
4582     {
4583         free(msg);
4584     }
4585 
4586     RETURN_VOID;
4587 }
4588 
4589 
4590 skm_channel_t
skMsgChannel(const sk_msg_t * msg)4591 skMsgChannel(
4592     const sk_msg_t     *msg)
4593 {
4594     DEBUG_ENTER_FUNC;
4595 
4596     assert(msg);
4597     RETURN(msg->hdr.channel);
4598 }
4599 
4600 
4601 skm_type_t
skMsgType(const sk_msg_t * msg)4602 skMsgType(
4603     const sk_msg_t     *msg)
4604 {
4605     DEBUG_ENTER_FUNC;
4606 
4607     assert(msg);
4608     RETURN(msg->hdr.type);
4609 }
4610 
4611 
4612 skm_len_t
skMsgLength(const sk_msg_t * msg)4613 skMsgLength(
4614     const sk_msg_t     *msg)
4615 {
4616     DEBUG_ENTER_FUNC;
4617 
4618     assert(msg);
4619     RETURN(msg->hdr.size);
4620 }
4621 
4622 
4623 const void *
skMsgMessage(const sk_msg_t * msg)4624 skMsgMessage(
4625     const sk_msg_t     *msg)
4626 {
4627     DEBUG_ENTER_FUNC;
4628 
4629     assert(msg);
4630     if (msg->segments == 0) {
4631         RETURN(NULL);
4632     } else {
4633         RETURN(msg->segment[1].iov_base);
4634     }
4635 }
4636 
4637 
4638 #if !SK_ENABLE_GNUTLS
4639 /* no-op function used when gnutls is not available */
4640 
/*
 *    No-op: without GnuTLS there is no TLS state to release.
 */
void
skMsgGnuTLSTeardown(
    void)
{
    return;
}
4646 
/*
 *    Without GnuTLS there are no TLS switches to register; always
 *    succeeds.
 */
int
skMsgTlsOptionsRegister(
    const char         *passwd_env_name)
{
    SK_UNUSED_PARAM(passwd_env_name);

    /* Nothing registered, nothing can fail */
    return 0;
}
4654 
/*
 *    Without GnuTLS there are no TLS switches to describe; prints
 *    nothing.
 */
void
skMsgTlsOptionsUsage(
    FILE               *fh)
{
    SK_UNUSED_PARAM(fh);
    return;
}
4661 
4662 int
skMsgTlsOptionsVerify(unsigned int * tls_available)4663 skMsgTlsOptionsVerify(
4664     unsigned int       *tls_available)
4665 {
4666     if (tls_available) {
4667         *tls_available = 0;
4668     }
4669     return 0;
4670 }
4671 
4672 #endif /* !SK_ENABLE_GNUTLS */
4673 
4674 /*
4675 ** Local Variables:
4676 ** mode:c
4677 ** indent-tabs-mode:nil
4678 ** c-basic-offset:4
4679 ** End:
4680 */
4681