1 /***************************************************************************
2  *                                                                         *
3  *   This program is free software; you can redistribute it and/or modify  *
4  *   it under the terms of the GNU General Public License as published by  *
5  *   the Free Software Foundation; either version 2 of the License, or     *
6  *   (at your option) any later version.                                   *
7  *                                                                         *
8  ***************************************************************************
9  * net/socket.cc
10  * (C) 2000-2008 Murat Deligonul
11  */
12 
#include "autoconf.h"

#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <sys/time.h>
#include <sys/types.h>
#include <unistd.h>
#include <fcntl.h>
#include "util/strings.h"
#include "util/tokenizer.h"
#include "io/error.h"
#include "net/error.h"
#include "net/socket.h"
#include "net/resolver.h"
#include "net/radaptor.h"
#include "debug.h"
30 
31 namespace net {
32 
33 using namespace util::strings;
34 
35 /* static */ io::engine * socket::engine 	= NULL;
36 /* static */ radaptor * socket::radaptor 	= NULL;
37 /* static */ resolver * socket::resolver 	= NULL;
38 
39 #ifdef HAVE_SSL
40 /* static */ SSL_CTX *socket::ssl_ctx 		= NULL;
41 /* static */ char   *socket::tls_rand_file 	= NULL;
42 #endif
43 
/******************************************************
 * The constructors may throw socket_exception:
 *      -- if socket() or accept() fails
 *      -- if the socket table is full
 * No SSL work is done in the constructors: the SSL
 * handshake of an accepted socket is deferred to
 * event_callback().
 ******************************************************/
50 
51 /**
52   * Create a new socket from scratch
53   * family:     PF_INET
54   *             PF_INET6      -- both supported
55   *             PF_UNSPEC (0) -- defer creation
56   * */
socket(int family,int options,size_t buffMin,size_t buffMax)57 socket::socket(int family, int options, size_t buffMin, size_t buffMax)
58 	: r_callback(this),
59 	  ibuff(buffMin, buffMax),
60 	  obuff(buffMin, buffMax),
61 	  interface_data(NULL)
62 {
63 	DEBUG("socket::socket(family: %d, ...)\n", family);
64 	lookup_id = 0;
65 	tmp_req = NULL;
66 
67 	if (family != PF_UNSPEC && family != 0)	{
68 		int f = this->open(family);
69 		if (f < 0) {
70 			throw socket_exception(net::strerror(f));
71 		}
72 		this->family = family;
73 		state = OPEN;
74 		set_nonblocking(f);
75 		update_addr_info(0);
76 	}
77 	else {
78 		// defer creation
79 		this->family = PF_UNSPEC;
80 		state = NEW;
81 	}
82 	this->options = options;
83 
84 #ifdef HAVE_SSL
85 	ssl = NULL;
86 #endif
87 }
88 
89 /**
90  * Create a socket, by accept()'ing a connection from 'source'
91  * SSL accepts are handled automatically (if source->ssl != NULL)
92  */
socket(socket * source,size_t buffMin,size_t buffMax)93 socket::socket(socket * source, size_t buffMin, size_t buffMax)
94 	: r_callback(this),
95 		ibuff(buffMin, buffMax),
96 		obuff(buffMin, buffMax),
97 		interface_data(NULL)
98 {
99    	DEBUG("socket::socket(%p, %zd, %zd) [%p]\n", source, buffMin, buffMax, this);
100 	lookup_id = 0;
101 	tmp_req = NULL;
102 
103 #ifdef HAVE_SSL
104 	ssl = NULL;
105 #endif
106 
107 	if (!source || !source->get_state() == LISTENING) {
108 		throw socket_exception("Invalid source for accept()");
109 	}
110 
111 	struct sockaddr_storage saddr;
112 	socklen_t addrlen = sizeof saddr;
113 	int f = accept(source->get_fd(), (sockaddr *) &saddr, &addrlen);
114 	if (f < 0) {
115 		throw socket_exception("accept() failed");
116 	}
117 
118 	set_fd(f);
119 	int r = engine->add(this);
120 	if (r < 0) {
121 		::close(f);
122 		throw socket_exception("Socket table is full");
123     	}
124 	set_nonblocking(f);
125 
126 	options = source->options;
127 	family = source->family;
128 
129 	// SSL?
130 #ifdef HAVE_SSL
131 	if (options & SOCK_SSL) {
132 		options |= SOCK_SSL;
133 		state = ACCEPTING;
134 	}
135 	else
136 #endif
137 	state = CONNECTED;
138 
139 	update_addr_info(0);
140 	update_addr_info(1);
141 }
142 
143 
144 /**
145   * Open a new socket and put it in the table.
146   * Return:
147   *     --  > -1 -- success, and the new file descriptor
148   *     --  < 0  -- failures:
149   *	 	ERR_ALREADY_OPEN	 ditto
150   *		ERR_UNABLE		 open() failed
151   *		ERR_TABLE_FULL		 out of space in socket
152   */
open(int domain,int protocol)153 int socket::open(int domain, int protocol)
154 {
155 	DEBUG("socket::open() [%p]\n", this);
156 
157 	if (get_fd() != -1) {
158 		return ERR_ALREADY_OPEN;
159 	}
160 
161 	int f = ::socket(domain, SOCK_STREAM, protocol);
162 	if (f < 0) {
163 		return ERR_UNABLE;
164 	}
165 
166 	set_fd(f);
167 	int r = engine->add(this);
168 	if (r < 0) {
169 		::close(f);
170 		set_fd(-1);
171 		return ERR_TABLE_FULL;
172     	}
173 	family = domain;
174 	state = OPEN;
175 	assert(get_fd() != -1);
176 	return f;
177 }
178 
179 /**
180  * Make socket listen for connections.
181  * -- you must specify an interface, "0.0.0.0 or ::0" at the very
182  *    least
183  * return:
184  *      ERR_INTERFACE -- host invalid
185  *      ERR_SYSCALL   -- bind() or listen() failure
186  *       0                 -- success
187  *       other < 0	   -- other socket error codes
188  *
189  *
190  * NOTES: on_readable() is called when connection is waiting
191  */
listen(const char * interface,unsigned short port,int backlog,int opt)192 int socket::listen(const char * interface, unsigned short port, int backlog, int opt)
193 {
194 	DEBUG("socket::listen() [%p] (%s, %u, %d, %d)\n", this, interface, port, backlog,opt );
195 
196 	if (state == CLOSED) {
197 		return ERR_SOCK_CLOSED;
198 	}
199 	if (state == LISTENING) {
200 		return ERR_LISTENING;
201 	}
202 	if (state != NEW && state != OPEN) {
203 		return ERR_PROGRESS;
204 	}
205 	// FIXME: ensure sock isn't in middle of connect or DNS resolve
206 
207 	struct addrinfo * ai;
208 	if (resolver::resolve_address(family, AI_ADDRCONFIG, interface, port, &ai)) {
209 		return ERR_INTERFACE;
210 	}
211 
212 	int f = get_fd();
213 	if (f < 0) {
214 		f = open(ai->ai_family);
215 	}
216 	if (f < 0) {
217 		freeaddrinfo(ai);
218 		return -1;
219 	}
220 	set_fd(f);
221 	set_nonblocking(f);
222 
223 	int parm = 1;
224 	setsockopt(f, SOL_SOCKET, SO_REUSEADDR, (const char * ) &parm, sizeof(int));
225 
226 	if (::bind(f, ai->ai_addr, ai->ai_addrlen)
227 		|| ::listen(f, backlog)) {
228 		// FIXME: close the socket here?
229 		freeaddrinfo(ai);
230 		return ERR_SYSCALL;
231 	}
232 	freeaddrinfo(ai);
233 
234 	state = LISTENING;
235 	options = opt;
236 
237 	set_events(io::EVENT_READ);
238 
239 	update_addr_info(false);
240 	return 0;
241 }
242 
243 /**
244  * Attempt to listen on a port randomly selected from a range of
245  * ports.
246  * Possible returns:
247  *      1           -- Success
248  *      ERR_DNS
249  *      ERR_FAILURE
250  */
listen(const char * hostname,const char * port_range,int backlog,int opt)251 int socket::listen(const char * hostname, const char * port_range, int backlog, int opt)
252 {
253 	using std::string;
254 	std::vector<string> tokens;
255 
256 	if (!is_non_empty(port_range)) {
257 		return this->listen(hostname, (unsigned short) 0, backlog, opt);
258 	}
259 
260 	tokenize(port_range, ",", tokens);
261 
262 	for (std::vector<string>::const_iterator i = tokens.begin(), e = tokens.end();
263 		i != e;
264 		++i) {
265 		unsigned short lbound, ubound = 0;
266 		const string& tok = *i;
267 		string bounds[2];
268 
269 		tokenize(tok.c_str(), "-", &bounds[0], 2);
270 
271 		lbound = atoi(bounds[0].c_str());
272 		if (!bounds[1].empty()) {
273 			ubound = atoi(bounds[1].c_str());
274 		}
275 
276 		DEBUG("bind_port(): Testing listen ports lower: %d upper: %d\n", lbound, ubound);
277 		if (!ubound) {
278 			ubound = lbound;
279 		}
280 		for (int i = lbound; i <= ubound; ++i) {
281 			DEBUG("bind_port(): Now calling bind() on port %d\n", i);
282 			int j =  this->listen(hostname, (unsigned short) i, backlog, opt);
283 			/* bail out if the hostname is bad, or
284 			 * if we succeeded */
285 			if (j == ERR_INTERFACE || j == 1) {
286 				return j;
287 			}
288 			/* on non-fatal errors, keep looping */
289 			if (j != ERR_SYSCALL) {
290 				return j;
291 			}
292 		}
293 	}
294 	return ERR_FAILURE;
295 }
296 
297 
298 /**
299  * Call connect() on this file descriptor to initiate
300  * a non-blocking connect()
301  *
302  * args:
303  *      where       -- where to connect
304  *      interface   -- what interface to connect to (IPv6 or IPv4)
305  *      port
306  *      options     -- currently, only 0 or SSL is supported
307  *
308  * return:
309  *      0                       -- connect now in progress
310  *	ERR_PROGRESS 		-- already connecting
311  *	ERR_ALREADY_OPEN 	-- already connected
312  *
313  * With the async. DNS lookup this is a little different.  We don't actually
314  * call connect() until both the interface and target addresses have been
315  * resolved.  Should these lookups fail, socket::on_connect_fail() is called
316  * with the appropriate error message.
317  *
318  * So really there is nothing too useful returned from this function, the most
319  * useful errors will be detected by on_connect_fail().
320  *
321  * NOTE: on_connect_fail() may not assume fd != -1 or other properties relating
322  *		to a succesful call to socket()
323  *
324  */
async_connect(const char * where,const char * interface,unsigned short port,int options)325 int socket::async_connect(const char * where, const char * interface, unsigned short port,
326                         int options)
327 {
328 	 DEBUG("socket::connect [%p] (%s, %s, %u, %i)\n", this, where, interface, port, options);
329 
330 	 if (state == CLOSED) {
331 		 return ERR_SOCK_CLOSED;
332 	 }
333 	 if (state == CONNECTED) {
334 		 return ERR_ALREADY_OPEN;
335 	 }
336 	 if (state != NEW && state != OPEN) {
337 		 return ERR_PROGRESS;
338 	 }
339 
340 	 /* Setup async DNS lookup for both
341 	  * interface hostname and target hostname */
342 	 assert(tmp_req == 0);
343 	 resolver::request * req1 = NULL, * req2 = NULL;
344 
345 	 req2 = resolver->create_request(family, where, port, 0, &r_callback);
346 	 if (is_non_empty(interface)) {
347 		 /* We have two things to lookup */
348 		 req1 = resolver->create_request(family, interface, 0, 0, &r_callback);
349 		 tmp_req = req2;
350 		 lookup_id = resolver->async_lookup(req1);
351 	 }
352 	 else {
353 		 lookup_id = resolver->async_lookup(req2);
354 	 }
355 
356 #ifdef HAVE_SSL
357 	 /* Set SSL flags if requested */
358 	 if (options & SOCK_SSL) {
359 		 this->options = SOCK_SSL;
360 	 }
361 	 else
362 #endif
363 		 this->options = 0;
364 
365 	 state = CONNECT_RESOLVING;
366 	 return 0;
367 }
368 
369 /**
370  * Close the file descriptor, flush buffers,
371  * minimize buffer memory usage and reset all socket
372  * state and option variables
373  */
close()374 int socket::close()
375 {
376 	DEBUG("socket::close() [%p]\n", this);
377 
378 	if (get_fd() > -1) {
379 #ifdef HAVE_SSL
380 		if (ssl) {
381 			DEBUG("Shutting down SSL for %p...\n", this);
382 			SSL_shutdown(ssl);
383 			SSL_free(ssl);
384 			ssl = NULL;
385   		}
386 #endif
387 		if (state == CONNECTED) {
388 			flushO();
389 		}
390 		engine->release(this);
391 		::close(get_fd());
392 		set_fd(-1);
393 
394 		update_addr_info(true);
395 		update_addr_info(false);
396 	}
397 
398 	ibuff.clear();
399 	obuff.clear();
400 	optimize_buffers();
401 
402 	/* Clean up resolver state */
403 	delete tmp_req;
404 	tmp_req = NULL;
405 	if (lookup_id != 0) {
406 		resolver->cancel_async_lookup(lookup_id);
407 	}
408 	lookup_id = 0;
409 	if (interface_data != NULL) {
410 		delete[] interface_data->first;
411 		delete interface_data;
412 		interface_data = NULL;
413 	}
414 
415 	state = CLOSED;
416 	family = PF_UNSPEC;
417 	options = 0;
418 
419 #ifdef HAVE_SSL
420 	assert(get_fd() == -1 && ssl == NULL);
421 #endif
422 	return 0;
423 }
424 
425 /**
426  * Checking for socket close sucks. But we must do it ASAP
427  * to simplify work for derived class sockets.
428  * This does a recv test to see if the connection is still alive
429  */
check_for_close() const430 int socket::check_for_close() const
431 {
432 	int e;
433 	switch (recv_test(get_fd())) {
434 	case 0:		/* socket closed */
435 		e = ERR_CLOSED;
436 		break;
437 	case -1:
438 		e = ERR_SYSCALL;
439 		break;
440 	default:
441 		return 0;
442 	}
443 	return e;
444 }
445 
446 /** 'pollable' interface **/
event_callback(int ev)447 int socket::event_callback(int ev)
448 {
449 	int ret = 0;
450 	/**
451 	  * Socket readable.
452 	  */
453 	if (ev & io::EVENT_READ) {
454 		if (state != LISTENING) {
455 			int e = check_for_close();
456 			if (e < 0) {
457 				if (state == CONNECTING) {
458 					on_connect_fail(e);
459 				}
460 				else {
461 					on_disconnect(e);
462 				}
463 				return -1;
464 			}
465 		}
466 
467 #ifdef HAVE_SSL
468 		/** handle special cases for **
469 		 ** non-blocking SSL setup **/
470 		if (options & SOCK_SSL)	{
471 			if (state == CONNECTING) {
472 				switch (switch_to_ssl()) {
473 				case ERR_SSL:
474 					on_connect_fail(ERR_SSL);
475 					return -1;
476 				case 0:
477 					on_connect();
478 					return 0;
479 
480 				case ERR_AGAIN:
481 					break;
482 				}
483 			}
484 			else if (state == ACCEPTING) {
485 				switch (accept_to_ssl()) {
486 				case ERR_SSL:
487 					on_disconnect(ERR_SSL);
488 					return -1;
489 				case 0:
490 					state = CONNECTED;
491 				case ERR_AGAIN:
492 					break;
493 				}
494 			}
495 			else {
496 				ret = on_readable();
497 			}
498 		}
499 		else
500 #endif
501 			if (state == CONNECTING) {
502 				// usually socket will only be writeable on connect
503 				// however, if data was immediately sent to it upon connection,
504 				// it can be marked both readable and writeable
505 				on_connect();
506 				return 0;
507 			}
508 			else {
509 				ret = on_readable();
510 			}
511 	}
512 
513 	/**
514 	  * Socket writeable.
515 	  */
516 	if ((ev & io::EVENT_WRITE) && ret >= 0) {
517 		if (state == CONNECTING) {
518 			int e = check_for_close();
519 			if (e < 0) {
520 				on_connect_fail(e);
521 				return -1;
522 			}
523 #ifdef HAVE_SSL
524 			if (options & SOCK_SSL)	{
525 				set_events(io::EVENT_READ);
526 				switch_to_ssl();
527 			}
528 			else
529 #endif
530 			{
531 				on_connect();
532 				return 0;
533 			}
534 #ifdef HAVE_SSL
535 		}
536 		else if (state == ACCEPTING) {
537 			if (options & SOCK_SSL)	{
538 				set_events(io::EVENT_READ);
539 				accept_to_ssl();
540 			}
541 #endif
542 		}
543 		else {
544 			ret = on_writeable();
545 		}
546 	}
547 
548 	/**
549 	  * Weird error condition.
550 	  */
551 	if ((ev & io::EVENT_ERROR) && ret >= 0) {
552 		if (state == CONNECTING) {
553 			on_connect_fail(ERR_HUP);
554 		}
555 		else {
556 			on_disconnect(ERR_HUP);
557 		}
558 		return -1;
559 	}
560 	return ret;
561 }
562 
563 
564 /** resolver callback functions **/
async_lookup_finished(const resolver::result * req)565 int socket::async_lookup_finished(const resolver::result * req)
566 {
567 	/* This means asnyc lookup *succeeded */
568 	assert(state == CONNECT_RESOLVING
569 		|| state == LISTEN_RESOLVING);
570 	assert(lookup_id == req->id);
571 	DEBUG("socket::async_lookup_finished() [%p] for %s\n", this, req->name);
572 
573 	lookup_id = 0;
574 	if (state == CONNECT_RESOLVING) {
575 		/** We just got done looking up the interface **/
576 		if (tmp_req != NULL) {
577 			/** save the data; we don't bind() till connect()-time **/
578 			const size_t len =  req->ai->ai_addrlen;
579 
580 			interface_data = new std::pair<unsigned char *, size_t>( new unsigned char[len], len);
581 			memcpy(interface_data->first, req->ai->ai_addr, interface_data->second);
582 			lookup_id = resolver->async_lookup(tmp_req);
583 			tmp_req = NULL;
584 			return 0;
585 		}
586 		/* Got the target address, ready to connect */
587 		tmp_req = NULL;
588 
589 		/* was the resulting socket family different than expected ? */
590 		if (family != req->ai->ai_family) {
591 			if (get_fd() != -1) {
592 				int t = options;		/* close() will clobber this */
593 				close();
594 				options = t;
595 			}
596 			int i = open(req->ai->ai_family);
597 			if (i < 0) {
598 				state = CONNECTING;
599 				on_connect_fail(i);
600 				return 1;
601 			}
602 		}
603 		int f = get_fd();
604 		set_nonblocking(f);
605 		state = CONNECTING;
606 
607 		if (interface_data != NULL) {
608 			int r = ::bind(f, (const sockaddr *) interface_data->first, interface_data->second);
609 			delete[] interface_data->first;
610 			delete interface_data;
611 			interface_data = NULL;
612 			if (r < 0) {
613 				/* bind failed: */
614 				on_connect_fail(ERR_INTERFACE);
615 				return 1;
616 			}
617 		}
618 
619             	if (::connect(f, req->ai->ai_addr, req->ai->ai_addrlen) < 0) {
620         		if (errno != EINPROGRESS) {
621         			on_connect_fail(ERR_SYSCALL);
622 				return 1;
623 			}
624 		}
625 
626 		update_addr_info(false);
627 		update_addr_info(true, req->ai->ai_addr, req->ai->ai_addrlen);
628 		on_connecting();
629 		set_events(io::EVENT_READ | io::EVENT_WRITE);
630         }
631 	return 0;
632 }
633 
async_lookup_failed(const resolver::result * req)634 int socket::async_lookup_failed(const resolver::result * req)
635 {
636 	/* This means asnyc lookup *FAILED* */
637 	assert(state == CONNECT_RESOLVING || state == LISTEN_RESOLVING);
638 	assert(lookup_id == req->id);
639 
640 	DEBUG("socket::async_lookup_failed() [%p] for %s\n", this, req->name);
641 
642 	lookup_id = 0;
643 	if (state == CONNECT_RESOLVING) {
644 		state = CONNECTING;
645 		if (tmp_req != NULL) {  /* if this was the interface */
646 			delete tmp_req;
647 			tmp_req = NULL;
648 			on_connect_fail(ERR_INTERFACE);
649 		}
650 		else {
651 			tmp_req = NULL;
652 			on_connect_fail(ERR_DNS);
653 		}
654 	}
655 	return 0;
656 }
657 
658 /**
659  * Updates the 'peer' and 'local' data fields by calling
660  * getpeername()/getsockname().
661  */
update_addr_info(bool is_peer)662 int socket::update_addr_info(bool is_peer)
663 {
664 	DEBUG("socket::update_addr_info(%d) [%p]\n", is_peer, this);
665 
666 	if (get_fd() < 0) {
667 		if (is_peer) {
668 			this->peer = ap_pair("[not connected]", 0);
669 		}
670 		else {
671 			this->local = ap_pair("[not connected]", 0);
672 		}
673 		return -1;
674 	}
675 
676 	struct sockaddr_storage saddr;
677 	socklen_t len = sizeof(saddr);
678 
679 	if (is_peer) {
680 		if (getpeername(get_fd(), (struct sockaddr *) &saddr, &len)) {
681 			return -1;
682 		}
683 	}
684 	else {
685 		if (getsockname(get_fd(), (struct sockaddr *) &saddr, &len)) {
686 			return -1;
687 		}
688 	}
689 	return update_addr_info(is_peer, (struct sockaddr *) &saddr, len);
690 }
691 
692 /**
693  * Updates the 'peer' and 'local' data fields by populating it from
694  * a given network address structure.
695  */
update_addr_info(bool is_peer,const sockaddr * source,size_t len)696 int socket::update_addr_info(bool is_peer, const sockaddr * source, size_t len)
697 {
698 	assert(family != 0);
699 	char addrbuff[MAX_ADDRSTRLEN+1] = "";
700 	unsigned short port;
701 	if (resolver::raw_to_ip(source, len, addrbuff, sizeof addrbuff, &port) != 0) {
702 		return -1;
703 	}
704 
705 	if (is_peer) {
706 		this->peer = ap_pair(addrbuff, port);
707 	}
708 	else {
709 		this->local = ap_pair(addrbuff, port);
710 	}
711 	return 0;
712 }
713 
714 /** Default Event Handlers **/
on_disconnect(int)715 void socket::on_disconnect(int)
716 {
717 	DEBUG("[%p] socket::on_disconnect()\n", this);
718 	close();
719 }
720 
on_connect_fail(int)721 void socket::on_connect_fail(int)
722 {
723 	DEBUG("[%p] socket::on_connect_fail()\n", this);
724 	close();
725 }
726 
on_connecting()727 void socket::on_connecting()
728 {
729 }
730 
on_connect()731 void socket::on_connect()
732 {
733 	DEBUG("[%p] socket::on_connect()\n", this);
734 	assert(state != CONNECTED);
735 	state = CONNECTED;
736 	update_addr_info(0);
737 	update_addr_info(1);
738 }
739 
printf(const char * format,...)740 int socket::printf(const char * format, ...)
741 {
742 	va_list ap;
743 	va_start(ap, format);
744 	int r = printf_raw(format, ap);
745 	va_end(ap);
746 	if (r > 0) {
747 		flushO();
748 	}
749 	return r;
750 }
751 
752 
753 /**
754  * Like above but does not flush data.
755  */
printfQ(const char * format,...)756 int socket::printfQ(const char * format, ...)
757 {
758 	va_list ap;
759 	va_start(ap, format);
760 	int r = printf_raw(format, ap);
761 	va_end(ap);
762 	return r;
763 }
764 
765 /**
766  * Updated 5/04: We will NOT queue incomplete messages.
767  * 	   9/07: Flushes buffer to make room, if necessary and possible.
768  */
printf_raw(const char * fmt,va_list ap)769 int socket::printf_raw(const char * fmt, va_list ap)
770 {
771 	char * ptr = NULL;
772 	int len = my_vasprintf(&ptr, fmt, ap);
773 	if (len < 0) {
774 		return ERR_MEM;
775 	}
776 	if (buffer_availableO() < size_t(len)) {
777 		flushO();
778 	}
779 	len = queue(ptr, len);
780 	delete[] ptr;
781 	return len;
782 }
783 
784 /**
785  * Return:
786  * -1:      error, socket died
787  *  0:      socket closed from eof
788  *  1:      socket ok.
789  *  2:      socket ok, stuff to read;
790  */
recv_test(int fd)791 /* static */ int socket::recv_test(int fd)
792 {
793 	char dummy;
794 	switch (recv(fd, &dummy, 1, MSG_PEEK))
795 	{
796 	case 0:
797 		return 0;
798 	case -1:
799 		if (errno == EAGAIN) {
800 			return 1;
801 		}
802 		return -1;
803 	}
804 	return 2;
805 }
806 
807 
808 /** SSL-specific code follows **/
809 #ifdef HAVE_SSL
init_ssl(const char * certfile)810 /* static */ int socket::init_ssl(const char * certfile)
811 {
812 	if (ssl_ctx == NULL) {
813 		DEBUG("Initializing SSL...\n");
814 		SSL_load_error_strings();
815 		OpenSSL_add_ssl_algorithms();
816 		ssl_ctx = SSL_CTX_new(SSLv23_method());
817 
818 		if (!ssl_ctx)  {
819 			DEBUG("SSL_CTX_new() failed\n");
820 			return -1;
821 		}
822 
823 		if (seed_PRNG()) {
824 			DEBUG("Wasn't able to properly seed the PRNG!\n");
825 			SSL_CTX_free(ssl_ctx);
826             		ssl_ctx=NULL;
827 			return -1;
828 		}
829 
830 		SSL_CTX_use_certificate_file(ssl_ctx, certfile,SSL_FILETYPE_PEM);
831 		SSL_CTX_use_RSAPrivateKey_file(ssl_ctx,certfile,SSL_FILETYPE_PEM);
832 		if (!SSL_CTX_check_private_key(ssl_ctx)) {
833 			DEBUG("Error loading private key/certificate, set correct file in options...\n");
834 			SSL_CTX_free(ssl_ctx);
835 			ssl_ctx=NULL;
836 			return -1;
837 		}
838         }
839     	return 0;
840 }
841 
shutdown_ssl()842 /* static */ int socket::shutdown_ssl()
843 {
844 	if (ssl_ctx) {
845 		DEBUG("Freeing SSL context...");
846 		SSL_CTX_free(ssl_ctx);
847 		ssl_ctx = NULL;
848 	}
849 	if (tls_rand_file) {
850 		RAND_write_file(tls_rand_file);
851 	}
852 	return 1;
853 }
854 
855 /**
856  * Return values for SSL functions:
857  *
858  * 	0		-- Success
859  * 	ERR_SSL  	-- fatal error
860  * 	ERR_AGAIN 	-- try again
861  */
switch_to_ssl()862 int socket::switch_to_ssl()
863 {
864 	int err;
865 	assert(get_fd() > -1);
866 	assert(options & SOCK_SSL);
867 	DEBUG("socket::switch_to_ssl() [%p] \n", this);
868 	if (!ssl) {
869 		ssl = SSL_new(ssl_ctx);
870 		if (!ssl) {
871 			DEBUG("socket::switch_to_SSL() [%p] -- SSL_new() failed\n", this);
872 			return ERR_SSL; /* fatal */
873 		}
874 		SSL_set_fd(ssl, get_fd());
875 	}
876 
877 	err = SSL_connect(ssl);
878 
879 	if (err < 1) {
880 		err = SSL_get_error(ssl, err);
881 		if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) {
882 			DEBUG("Error while SSL_connect()\n");
883 			DEBUG("SSL_ERROR: %d\n", SSL_get_error(ssl, err));
884 			SSL_shutdown(ssl);
885 			SSL_free(ssl);
886 			ssl = NULL;
887 			return ERR_SSL;
888 		}
889 		return ERR_AGAIN;   /* try again */
890 	}
891 	return 0;
892 }
893 
accept_to_ssl()894 int socket::accept_to_ssl()
895 {
896 	assert(get_fd() > -1);
897 	assert(options & SOCK_SSL);
898 	int err;
899 
900 	DEBUG("socket::accept_to_ssl() [%p] \n", this);
901 	if (!ssl) {
902 		ssl = SSL_new(ssl_ctx);
903 		if (!ssl) {
904 			DEBUG("SSL_new() failed\n");
905 			return ERR_SSL;
906 		}
907 		SSL_set_fd(ssl, get_fd());
908 	}
909 	err = SSL_accept(ssl);
910 
911 	if (err < 1) {
912 		err = SSL_get_error(ssl, err);
913 		if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) {
914 			DEBUG("Error while SSL_accept()\n");
915 			DEBUG("SSL_ERROR: %d\n", SSL_get_error(ssl, err));
916 			SSL_shutdown(ssl);
917 			SSL_free(ssl);
918 			ssl = NULL;
919 			return ERR_SSL;
920 	        }
921         	return ERR_AGAIN;   /* try again */
922     	}
923 	return 0;
924 }
925 
926 /**
927   * Seed the OpenSSL PRNG.
928   * Nothing to do if /dev/urandom exists.
929   */
seed_PRNG()930 int socket::seed_PRNG()
931 {
932 	char stackdata[1024];
933 	static char rand_file[300];
934 
935 #if OPENSSL_VERSION_NUMBER >= 0x00905100
936 	if (RAND_status()) {
937 		return 0;     /* PRNG already well seeded */
938 	}
939 #endif
940 	/**
941 	 * If the device '/dev/urandom' is present, OpenSSL uses it by default.
942 	 * check if it's present, else we have to make random data ourselves.
943 	 */
944 	if (access("/dev/urandom", R_OK) == 0) {
945 		DEBUG("socket::seed_PRNG(): using /dev/urandom() by default\n");
946 		return 0;
947 	}
948 	if (RAND_file_name(rand_file, sizeof(rand_file))) {
949 		tls_rand_file = rand_file;
950 	}
951 	else {
952 		DEBUG("socket::seed_PRNG(): unable to create random seed file\n");
953 		return 1;
954 	}
955 	if (!RAND_load_file(rand_file, sizeof rand_file)) {
956 		/* no .rnd file found, create new seed */
957 		time_t c;
958 		c = time(NULL);
959 		RAND_seed(&c, sizeof(c));
960 		c = getpid();
961 		RAND_seed(&c, sizeof(c));
962 		RAND_seed(stackdata, sizeof(stackdata));
963 	}
964 #if OPENSSL_VERSION_NUMBER >= 0x00905100
965 	if (!RAND_status()) {
966 		return 2;   /* PRNG still badly seeded */
967 	}
968 #endif
969 	return 0;
970 }
971 #endif
972 
973 } /* namespace net */
974