1 /***************************************************************************
2 * *
3 * This program is free software; you can redistribute it and/or modify *
4 * it under the terms of the GNU General Public License as published by *
5 * the Free Software Foundation; either version 2 of the License, or *
6 * (at your option) any later version. *
7 * *
8 ***************************************************************************
9 * net/socket.cc
10 * (C) 2000-2008 Murat Deligonul
11 */
12
13 #include "autoconf.h"
14
15 #include <cstdlib>
16 #include <cstring>
17 #include <cstdio>
18 #include <sys/time.h>
19 #include <sys/types.h>
20 #include <unistd.h>
21 #include <fcntl.h>
22 #include "util/strings.h"
23 #include "util/tokenizer.h"
24 #include "io/error.h"
25 #include "net/error.h"
26 #include "net/socket.h"
27 #include "net/resolver.h"
28 #include "net/radaptor.h"
29 #include "debug.h"
30
31 namespace net {
32
33 using namespace util::strings;
34
35 /* static */ io::engine * socket::engine = NULL;
36 /* static */ radaptor * socket::radaptor = NULL;
37 /* static */ resolver * socket::resolver = NULL;
38
39 #ifdef HAVE_SSL
40 /* static */ SSL_CTX *socket::ssl_ctx = NULL;
41 /* static */ char *socket::tls_rand_file = NULL;
42 #endif
43
44 /******************************************************
45 * All the constructors throw:
46 * socket_exception -- if socket(), accept() fails
47 * -- if socket table is full
48 * -- if SSL initialization fails
49 ******************************************************/
50
51 /**
52 * Create a new socket from scratch
53 * family: PF_INET
54 * PF_INET6 -- both supported
55 * PF_UNSPEC (0) -- defer creation
56 * */
socket(int family,int options,size_t buffMin,size_t buffMax)57 socket::socket(int family, int options, size_t buffMin, size_t buffMax)
58 : r_callback(this),
59 ibuff(buffMin, buffMax),
60 obuff(buffMin, buffMax),
61 interface_data(NULL)
62 {
63 DEBUG("socket::socket(family: %d, ...)\n", family);
64 lookup_id = 0;
65 tmp_req = NULL;
66
67 if (family != PF_UNSPEC && family != 0) {
68 int f = this->open(family);
69 if (f < 0) {
70 throw socket_exception(net::strerror(f));
71 }
72 this->family = family;
73 state = OPEN;
74 set_nonblocking(f);
75 update_addr_info(0);
76 }
77 else {
78 // defer creation
79 this->family = PF_UNSPEC;
80 state = NEW;
81 }
82 this->options = options;
83
84 #ifdef HAVE_SSL
85 ssl = NULL;
86 #endif
87 }
88
89 /**
90 * Create a socket, by accept()'ing a connection from 'source'
91 * SSL accepts are handled automatically (if source->ssl != NULL)
92 */
socket(socket * source,size_t buffMin,size_t buffMax)93 socket::socket(socket * source, size_t buffMin, size_t buffMax)
94 : r_callback(this),
95 ibuff(buffMin, buffMax),
96 obuff(buffMin, buffMax),
97 interface_data(NULL)
98 {
99 DEBUG("socket::socket(%p, %zd, %zd) [%p]\n", source, buffMin, buffMax, this);
100 lookup_id = 0;
101 tmp_req = NULL;
102
103 #ifdef HAVE_SSL
104 ssl = NULL;
105 #endif
106
107 if (!source || !source->get_state() == LISTENING) {
108 throw socket_exception("Invalid source for accept()");
109 }
110
111 struct sockaddr_storage saddr;
112 socklen_t addrlen = sizeof saddr;
113 int f = accept(source->get_fd(), (sockaddr *) &saddr, &addrlen);
114 if (f < 0) {
115 throw socket_exception("accept() failed");
116 }
117
118 set_fd(f);
119 int r = engine->add(this);
120 if (r < 0) {
121 ::close(f);
122 throw socket_exception("Socket table is full");
123 }
124 set_nonblocking(f);
125
126 options = source->options;
127 family = source->family;
128
129 // SSL?
130 #ifdef HAVE_SSL
131 if (options & SOCK_SSL) {
132 options |= SOCK_SSL;
133 state = ACCEPTING;
134 }
135 else
136 #endif
137 state = CONNECTED;
138
139 update_addr_info(0);
140 update_addr_info(1);
141 }
142
143
144 /**
145 * Open a new socket and put it in the table.
146 * Return:
147 * -- > -1 -- success, and the new file descriptor
148 * -- < 0 -- failures:
149 * ERR_ALREADY_OPEN ditto
150 * ERR_UNABLE open() failed
151 * ERR_TABLE_FULL out of space in socket
152 */
open(int domain,int protocol)153 int socket::open(int domain, int protocol)
154 {
155 DEBUG("socket::open() [%p]\n", this);
156
157 if (get_fd() != -1) {
158 return ERR_ALREADY_OPEN;
159 }
160
161 int f = ::socket(domain, SOCK_STREAM, protocol);
162 if (f < 0) {
163 return ERR_UNABLE;
164 }
165
166 set_fd(f);
167 int r = engine->add(this);
168 if (r < 0) {
169 ::close(f);
170 set_fd(-1);
171 return ERR_TABLE_FULL;
172 }
173 family = domain;
174 state = OPEN;
175 assert(get_fd() != -1);
176 return f;
177 }
178
179 /**
180 * Make socket listen for connections.
181 * -- you must specify an interface, "0.0.0.0 or ::0" at the very
182 * least
183 * return:
184 * ERR_INTERFACE -- host invalid
185 * ERR_SYSCALL -- bind() or listen() failure
186 * 0 -- success
187 * other < 0 -- other socket error codes
188 *
189 *
190 * NOTES: on_readable() is called when connection is waiting
191 */
listen(const char * interface,unsigned short port,int backlog,int opt)192 int socket::listen(const char * interface, unsigned short port, int backlog, int opt)
193 {
194 DEBUG("socket::listen() [%p] (%s, %u, %d, %d)\n", this, interface, port, backlog,opt );
195
196 if (state == CLOSED) {
197 return ERR_SOCK_CLOSED;
198 }
199 if (state == LISTENING) {
200 return ERR_LISTENING;
201 }
202 if (state != NEW && state != OPEN) {
203 return ERR_PROGRESS;
204 }
205 // FIXME: ensure sock isn't in middle of connect or DNS resolve
206
207 struct addrinfo * ai;
208 if (resolver::resolve_address(family, AI_ADDRCONFIG, interface, port, &ai)) {
209 return ERR_INTERFACE;
210 }
211
212 int f = get_fd();
213 if (f < 0) {
214 f = open(ai->ai_family);
215 }
216 if (f < 0) {
217 freeaddrinfo(ai);
218 return -1;
219 }
220 set_fd(f);
221 set_nonblocking(f);
222
223 int parm = 1;
224 setsockopt(f, SOL_SOCKET, SO_REUSEADDR, (const char * ) &parm, sizeof(int));
225
226 if (::bind(f, ai->ai_addr, ai->ai_addrlen)
227 || ::listen(f, backlog)) {
228 // FIXME: close the socket here?
229 freeaddrinfo(ai);
230 return ERR_SYSCALL;
231 }
232 freeaddrinfo(ai);
233
234 state = LISTENING;
235 options = opt;
236
237 set_events(io::EVENT_READ);
238
239 update_addr_info(false);
240 return 0;
241 }
242
243 /**
244 * Attempt to listen on a port randomly selected from a range of
245 * ports.
246 * Possible returns:
247 * 1 -- Success
248 * ERR_DNS
249 * ERR_FAILURE
250 */
listen(const char * hostname,const char * port_range,int backlog,int opt)251 int socket::listen(const char * hostname, const char * port_range, int backlog, int opt)
252 {
253 using std::string;
254 std::vector<string> tokens;
255
256 if (!is_non_empty(port_range)) {
257 return this->listen(hostname, (unsigned short) 0, backlog, opt);
258 }
259
260 tokenize(port_range, ",", tokens);
261
262 for (std::vector<string>::const_iterator i = tokens.begin(), e = tokens.end();
263 i != e;
264 ++i) {
265 unsigned short lbound, ubound = 0;
266 const string& tok = *i;
267 string bounds[2];
268
269 tokenize(tok.c_str(), "-", &bounds[0], 2);
270
271 lbound = atoi(bounds[0].c_str());
272 if (!bounds[1].empty()) {
273 ubound = atoi(bounds[1].c_str());
274 }
275
276 DEBUG("bind_port(): Testing listen ports lower: %d upper: %d\n", lbound, ubound);
277 if (!ubound) {
278 ubound = lbound;
279 }
280 for (int i = lbound; i <= ubound; ++i) {
281 DEBUG("bind_port(): Now calling bind() on port %d\n", i);
282 int j = this->listen(hostname, (unsigned short) i, backlog, opt);
283 /* bail out if the hostname is bad, or
284 * if we succeeded */
285 if (j == ERR_INTERFACE || j == 1) {
286 return j;
287 }
288 /* on non-fatal errors, keep looping */
289 if (j != ERR_SYSCALL) {
290 return j;
291 }
292 }
293 }
294 return ERR_FAILURE;
295 }
296
297
298 /**
299 * Call connect() on this file descriptor to initiate
300 * a non-blocking connect()
301 *
302 * args:
303 * where -- where to connect
304 * interface -- what interface to connect to (IPv6 or IPv4)
305 * port
306 * options -- currently, only 0 or SSL is supported
307 *
308 * return:
309 * 0 -- connect now in progress
310 * ERR_PROGRESS -- already connecting
311 * ERR_ALREADY_OPEN -- already connected
312 *
313 * With the async. DNS lookup this is a little different. We don't actually
314 * call connect() until both the interface and target addresses have been
315 * resolved. Should these lookups fail, socket::on_connect_fail() is called
316 * with the appropriate error message.
317 *
318 * So really there is nothing too useful returned from this function, the most
319 * useful errors will be detected by on_connect_fail().
320 *
321 * NOTE: on_connect_fail() may not assume fd != -1 or other properties relating
322 * to a succesful call to socket()
323 *
324 */
async_connect(const char * where,const char * interface,unsigned short port,int options)325 int socket::async_connect(const char * where, const char * interface, unsigned short port,
326 int options)
327 {
328 DEBUG("socket::connect [%p] (%s, %s, %u, %i)\n", this, where, interface, port, options);
329
330 if (state == CLOSED) {
331 return ERR_SOCK_CLOSED;
332 }
333 if (state == CONNECTED) {
334 return ERR_ALREADY_OPEN;
335 }
336 if (state != NEW && state != OPEN) {
337 return ERR_PROGRESS;
338 }
339
340 /* Setup async DNS lookup for both
341 * interface hostname and target hostname */
342 assert(tmp_req == 0);
343 resolver::request * req1 = NULL, * req2 = NULL;
344
345 req2 = resolver->create_request(family, where, port, 0, &r_callback);
346 if (is_non_empty(interface)) {
347 /* We have two things to lookup */
348 req1 = resolver->create_request(family, interface, 0, 0, &r_callback);
349 tmp_req = req2;
350 lookup_id = resolver->async_lookup(req1);
351 }
352 else {
353 lookup_id = resolver->async_lookup(req2);
354 }
355
356 #ifdef HAVE_SSL
357 /* Set SSL flags if requested */
358 if (options & SOCK_SSL) {
359 this->options = SOCK_SSL;
360 }
361 else
362 #endif
363 this->options = 0;
364
365 state = CONNECT_RESOLVING;
366 return 0;
367 }
368
369 /**
370 * Close the file descriptor, flush buffers,
371 * minimize buffer memory usage and reset all socket
372 * state and option variables
373 */
close()374 int socket::close()
375 {
376 DEBUG("socket::close() [%p]\n", this);
377
378 if (get_fd() > -1) {
379 #ifdef HAVE_SSL
380 if (ssl) {
381 DEBUG("Shutting down SSL for %p...\n", this);
382 SSL_shutdown(ssl);
383 SSL_free(ssl);
384 ssl = NULL;
385 }
386 #endif
387 if (state == CONNECTED) {
388 flushO();
389 }
390 engine->release(this);
391 ::close(get_fd());
392 set_fd(-1);
393
394 update_addr_info(true);
395 update_addr_info(false);
396 }
397
398 ibuff.clear();
399 obuff.clear();
400 optimize_buffers();
401
402 /* Clean up resolver state */
403 delete tmp_req;
404 tmp_req = NULL;
405 if (lookup_id != 0) {
406 resolver->cancel_async_lookup(lookup_id);
407 }
408 lookup_id = 0;
409 if (interface_data != NULL) {
410 delete[] interface_data->first;
411 delete interface_data;
412 interface_data = NULL;
413 }
414
415 state = CLOSED;
416 family = PF_UNSPEC;
417 options = 0;
418
419 #ifdef HAVE_SSL
420 assert(get_fd() == -1 && ssl == NULL);
421 #endif
422 return 0;
423 }
424
425 /**
426 * Checking for socket close sucks. But we must do it ASAP
427 * to simplify work for derived class sockets.
428 * This does a recv test to see if the connection is still alive
429 */
check_for_close() const430 int socket::check_for_close() const
431 {
432 int e;
433 switch (recv_test(get_fd())) {
434 case 0: /* socket closed */
435 e = ERR_CLOSED;
436 break;
437 case -1:
438 e = ERR_SYSCALL;
439 break;
440 default:
441 return 0;
442 }
443 return e;
444 }
445
446 /** 'pollable' interface **/
event_callback(int ev)447 int socket::event_callback(int ev)
448 {
449 int ret = 0;
450 /**
451 * Socket readable.
452 */
453 if (ev & io::EVENT_READ) {
454 if (state != LISTENING) {
455 int e = check_for_close();
456 if (e < 0) {
457 if (state == CONNECTING) {
458 on_connect_fail(e);
459 }
460 else {
461 on_disconnect(e);
462 }
463 return -1;
464 }
465 }
466
467 #ifdef HAVE_SSL
468 /** handle special cases for **
469 ** non-blocking SSL setup **/
470 if (options & SOCK_SSL) {
471 if (state == CONNECTING) {
472 switch (switch_to_ssl()) {
473 case ERR_SSL:
474 on_connect_fail(ERR_SSL);
475 return -1;
476 case 0:
477 on_connect();
478 return 0;
479
480 case ERR_AGAIN:
481 break;
482 }
483 }
484 else if (state == ACCEPTING) {
485 switch (accept_to_ssl()) {
486 case ERR_SSL:
487 on_disconnect(ERR_SSL);
488 return -1;
489 case 0:
490 state = CONNECTED;
491 case ERR_AGAIN:
492 break;
493 }
494 }
495 else {
496 ret = on_readable();
497 }
498 }
499 else
500 #endif
501 if (state == CONNECTING) {
502 // usually socket will only be writeable on connect
503 // however, if data was immediately sent to it upon connection,
504 // it can be marked both readable and writeable
505 on_connect();
506 return 0;
507 }
508 else {
509 ret = on_readable();
510 }
511 }
512
513 /**
514 * Socket writeable.
515 */
516 if ((ev & io::EVENT_WRITE) && ret >= 0) {
517 if (state == CONNECTING) {
518 int e = check_for_close();
519 if (e < 0) {
520 on_connect_fail(e);
521 return -1;
522 }
523 #ifdef HAVE_SSL
524 if (options & SOCK_SSL) {
525 set_events(io::EVENT_READ);
526 switch_to_ssl();
527 }
528 else
529 #endif
530 {
531 on_connect();
532 return 0;
533 }
534 #ifdef HAVE_SSL
535 }
536 else if (state == ACCEPTING) {
537 if (options & SOCK_SSL) {
538 set_events(io::EVENT_READ);
539 accept_to_ssl();
540 }
541 #endif
542 }
543 else {
544 ret = on_writeable();
545 }
546 }
547
548 /**
549 * Weird error condition.
550 */
551 if ((ev & io::EVENT_ERROR) && ret >= 0) {
552 if (state == CONNECTING) {
553 on_connect_fail(ERR_HUP);
554 }
555 else {
556 on_disconnect(ERR_HUP);
557 }
558 return -1;
559 }
560 return ret;
561 }
562
563
564 /** resolver callback functions **/
async_lookup_finished(const resolver::result * req)565 int socket::async_lookup_finished(const resolver::result * req)
566 {
567 /* This means asnyc lookup *succeeded */
568 assert(state == CONNECT_RESOLVING
569 || state == LISTEN_RESOLVING);
570 assert(lookup_id == req->id);
571 DEBUG("socket::async_lookup_finished() [%p] for %s\n", this, req->name);
572
573 lookup_id = 0;
574 if (state == CONNECT_RESOLVING) {
575 /** We just got done looking up the interface **/
576 if (tmp_req != NULL) {
577 /** save the data; we don't bind() till connect()-time **/
578 const size_t len = req->ai->ai_addrlen;
579
580 interface_data = new std::pair<unsigned char *, size_t>( new unsigned char[len], len);
581 memcpy(interface_data->first, req->ai->ai_addr, interface_data->second);
582 lookup_id = resolver->async_lookup(tmp_req);
583 tmp_req = NULL;
584 return 0;
585 }
586 /* Got the target address, ready to connect */
587 tmp_req = NULL;
588
589 /* was the resulting socket family different than expected ? */
590 if (family != req->ai->ai_family) {
591 if (get_fd() != -1) {
592 int t = options; /* close() will clobber this */
593 close();
594 options = t;
595 }
596 int i = open(req->ai->ai_family);
597 if (i < 0) {
598 state = CONNECTING;
599 on_connect_fail(i);
600 return 1;
601 }
602 }
603 int f = get_fd();
604 set_nonblocking(f);
605 state = CONNECTING;
606
607 if (interface_data != NULL) {
608 int r = ::bind(f, (const sockaddr *) interface_data->first, interface_data->second);
609 delete[] interface_data->first;
610 delete interface_data;
611 interface_data = NULL;
612 if (r < 0) {
613 /* bind failed: */
614 on_connect_fail(ERR_INTERFACE);
615 return 1;
616 }
617 }
618
619 if (::connect(f, req->ai->ai_addr, req->ai->ai_addrlen) < 0) {
620 if (errno != EINPROGRESS) {
621 on_connect_fail(ERR_SYSCALL);
622 return 1;
623 }
624 }
625
626 update_addr_info(false);
627 update_addr_info(true, req->ai->ai_addr, req->ai->ai_addrlen);
628 on_connecting();
629 set_events(io::EVENT_READ | io::EVENT_WRITE);
630 }
631 return 0;
632 }
633
async_lookup_failed(const resolver::result * req)634 int socket::async_lookup_failed(const resolver::result * req)
635 {
636 /* This means asnyc lookup *FAILED* */
637 assert(state == CONNECT_RESOLVING || state == LISTEN_RESOLVING);
638 assert(lookup_id == req->id);
639
640 DEBUG("socket::async_lookup_failed() [%p] for %s\n", this, req->name);
641
642 lookup_id = 0;
643 if (state == CONNECT_RESOLVING) {
644 state = CONNECTING;
645 if (tmp_req != NULL) { /* if this was the interface */
646 delete tmp_req;
647 tmp_req = NULL;
648 on_connect_fail(ERR_INTERFACE);
649 }
650 else {
651 tmp_req = NULL;
652 on_connect_fail(ERR_DNS);
653 }
654 }
655 return 0;
656 }
657
658 /**
659 * Updates the 'peer' and 'local' data fields by calling
660 * getpeername()/getsockname().
661 */
update_addr_info(bool is_peer)662 int socket::update_addr_info(bool is_peer)
663 {
664 DEBUG("socket::update_addr_info(%d) [%p]\n", is_peer, this);
665
666 if (get_fd() < 0) {
667 if (is_peer) {
668 this->peer = ap_pair("[not connected]", 0);
669 }
670 else {
671 this->local = ap_pair("[not connected]", 0);
672 }
673 return -1;
674 }
675
676 struct sockaddr_storage saddr;
677 socklen_t len = sizeof(saddr);
678
679 if (is_peer) {
680 if (getpeername(get_fd(), (struct sockaddr *) &saddr, &len)) {
681 return -1;
682 }
683 }
684 else {
685 if (getsockname(get_fd(), (struct sockaddr *) &saddr, &len)) {
686 return -1;
687 }
688 }
689 return update_addr_info(is_peer, (struct sockaddr *) &saddr, len);
690 }
691
692 /**
693 * Updates the 'peer' and 'local' data fields by populating it from
694 * a given network address structure.
695 */
update_addr_info(bool is_peer,const sockaddr * source,size_t len)696 int socket::update_addr_info(bool is_peer, const sockaddr * source, size_t len)
697 {
698 assert(family != 0);
699 char addrbuff[MAX_ADDRSTRLEN+1] = "";
700 unsigned short port;
701 if (resolver::raw_to_ip(source, len, addrbuff, sizeof addrbuff, &port) != 0) {
702 return -1;
703 }
704
705 if (is_peer) {
706 this->peer = ap_pair(addrbuff, port);
707 }
708 else {
709 this->local = ap_pair(addrbuff, port);
710 }
711 return 0;
712 }
713
714 /** Default Event Handlers **/
on_disconnect(int)715 void socket::on_disconnect(int)
716 {
717 DEBUG("[%p] socket::on_disconnect()\n", this);
718 close();
719 }
720
on_connect_fail(int)721 void socket::on_connect_fail(int)
722 {
723 DEBUG("[%p] socket::on_connect_fail()\n", this);
724 close();
725 }
726
on_connecting()727 void socket::on_connecting()
728 {
729 }
730
on_connect()731 void socket::on_connect()
732 {
733 DEBUG("[%p] socket::on_connect()\n", this);
734 assert(state != CONNECTED);
735 state = CONNECTED;
736 update_addr_info(0);
737 update_addr_info(1);
738 }
739
printf(const char * format,...)740 int socket::printf(const char * format, ...)
741 {
742 va_list ap;
743 va_start(ap, format);
744 int r = printf_raw(format, ap);
745 va_end(ap);
746 if (r > 0) {
747 flushO();
748 }
749 return r;
750 }
751
752
753 /**
754 * Like above but does not flush data.
755 */
printfQ(const char * format,...)756 int socket::printfQ(const char * format, ...)
757 {
758 va_list ap;
759 va_start(ap, format);
760 int r = printf_raw(format, ap);
761 va_end(ap);
762 return r;
763 }
764
765 /**
766 * Updated 5/04: We will NOT queue incomplete messages.
767 * 9/07: Flushes buffer to make room, if necessary and possible.
768 */
printf_raw(const char * fmt,va_list ap)769 int socket::printf_raw(const char * fmt, va_list ap)
770 {
771 char * ptr = NULL;
772 int len = my_vasprintf(&ptr, fmt, ap);
773 if (len < 0) {
774 return ERR_MEM;
775 }
776 if (buffer_availableO() < size_t(len)) {
777 flushO();
778 }
779 len = queue(ptr, len);
780 delete[] ptr;
781 return len;
782 }
783
784 /**
785 * Return:
786 * -1: error, socket died
787 * 0: socket closed from eof
788 * 1: socket ok.
789 * 2: socket ok, stuff to read;
790 */
recv_test(int fd)791 /* static */ int socket::recv_test(int fd)
792 {
793 char dummy;
794 switch (recv(fd, &dummy, 1, MSG_PEEK))
795 {
796 case 0:
797 return 0;
798 case -1:
799 if (errno == EAGAIN) {
800 return 1;
801 }
802 return -1;
803 }
804 return 2;
805 }
806
807
808 /** SSL-specific code follows **/
809 #ifdef HAVE_SSL
init_ssl(const char * certfile)810 /* static */ int socket::init_ssl(const char * certfile)
811 {
812 if (ssl_ctx == NULL) {
813 DEBUG("Initializing SSL...\n");
814 SSL_load_error_strings();
815 OpenSSL_add_ssl_algorithms();
816 ssl_ctx = SSL_CTX_new(SSLv23_method());
817
818 if (!ssl_ctx) {
819 DEBUG("SSL_CTX_new() failed\n");
820 return -1;
821 }
822
823 if (seed_PRNG()) {
824 DEBUG("Wasn't able to properly seed the PRNG!\n");
825 SSL_CTX_free(ssl_ctx);
826 ssl_ctx=NULL;
827 return -1;
828 }
829
830 SSL_CTX_use_certificate_file(ssl_ctx, certfile,SSL_FILETYPE_PEM);
831 SSL_CTX_use_RSAPrivateKey_file(ssl_ctx,certfile,SSL_FILETYPE_PEM);
832 if (!SSL_CTX_check_private_key(ssl_ctx)) {
833 DEBUG("Error loading private key/certificate, set correct file in options...\n");
834 SSL_CTX_free(ssl_ctx);
835 ssl_ctx=NULL;
836 return -1;
837 }
838 }
839 return 0;
840 }
841
shutdown_ssl()842 /* static */ int socket::shutdown_ssl()
843 {
844 if (ssl_ctx) {
845 DEBUG("Freeing SSL context...");
846 SSL_CTX_free(ssl_ctx);
847 ssl_ctx = NULL;
848 }
849 if (tls_rand_file) {
850 RAND_write_file(tls_rand_file);
851 }
852 return 1;
853 }
854
855 /**
856 * Return values for SSL functions:
857 *
858 * 0 -- Success
859 * ERR_SSL -- fatal error
860 * ERR_AGAIN -- try again
861 */
switch_to_ssl()862 int socket::switch_to_ssl()
863 {
864 int err;
865 assert(get_fd() > -1);
866 assert(options & SOCK_SSL);
867 DEBUG("socket::switch_to_ssl() [%p] \n", this);
868 if (!ssl) {
869 ssl = SSL_new(ssl_ctx);
870 if (!ssl) {
871 DEBUG("socket::switch_to_SSL() [%p] -- SSL_new() failed\n", this);
872 return ERR_SSL; /* fatal */
873 }
874 SSL_set_fd(ssl, get_fd());
875 }
876
877 err = SSL_connect(ssl);
878
879 if (err < 1) {
880 err = SSL_get_error(ssl, err);
881 if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) {
882 DEBUG("Error while SSL_connect()\n");
883 DEBUG("SSL_ERROR: %d\n", SSL_get_error(ssl, err));
884 SSL_shutdown(ssl);
885 SSL_free(ssl);
886 ssl = NULL;
887 return ERR_SSL;
888 }
889 return ERR_AGAIN; /* try again */
890 }
891 return 0;
892 }
893
accept_to_ssl()894 int socket::accept_to_ssl()
895 {
896 assert(get_fd() > -1);
897 assert(options & SOCK_SSL);
898 int err;
899
900 DEBUG("socket::accept_to_ssl() [%p] \n", this);
901 if (!ssl) {
902 ssl = SSL_new(ssl_ctx);
903 if (!ssl) {
904 DEBUG("SSL_new() failed\n");
905 return ERR_SSL;
906 }
907 SSL_set_fd(ssl, get_fd());
908 }
909 err = SSL_accept(ssl);
910
911 if (err < 1) {
912 err = SSL_get_error(ssl, err);
913 if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) {
914 DEBUG("Error while SSL_accept()\n");
915 DEBUG("SSL_ERROR: %d\n", SSL_get_error(ssl, err));
916 SSL_shutdown(ssl);
917 SSL_free(ssl);
918 ssl = NULL;
919 return ERR_SSL;
920 }
921 return ERR_AGAIN; /* try again */
922 }
923 return 0;
924 }
925
926 /**
927 * Seed the OpenSSL PRNG.
928 * Nothing to do if /dev/urandom exists.
929 */
seed_PRNG()930 int socket::seed_PRNG()
931 {
932 char stackdata[1024];
933 static char rand_file[300];
934
935 #if OPENSSL_VERSION_NUMBER >= 0x00905100
936 if (RAND_status()) {
937 return 0; /* PRNG already well seeded */
938 }
939 #endif
940 /**
941 * If the device '/dev/urandom' is present, OpenSSL uses it by default.
942 * check if it's present, else we have to make random data ourselves.
943 */
944 if (access("/dev/urandom", R_OK) == 0) {
945 DEBUG("socket::seed_PRNG(): using /dev/urandom() by default\n");
946 return 0;
947 }
948 if (RAND_file_name(rand_file, sizeof(rand_file))) {
949 tls_rand_file = rand_file;
950 }
951 else {
952 DEBUG("socket::seed_PRNG(): unable to create random seed file\n");
953 return 1;
954 }
955 if (!RAND_load_file(rand_file, sizeof rand_file)) {
956 /* no .rnd file found, create new seed */
957 time_t c;
958 c = time(NULL);
959 RAND_seed(&c, sizeof(c));
960 c = getpid();
961 RAND_seed(&c, sizeof(c));
962 RAND_seed(stackdata, sizeof(stackdata));
963 }
964 #if OPENSSL_VERSION_NUMBER >= 0x00905100
965 if (!RAND_status()) {
966 return 2; /* PRNG still badly seeded */
967 }
968 #endif
969 return 0;
970 }
971 #endif
972
973 } /* namespace net */
974