1 /*
2 * Copyright (C) 2000-2016 Free Software Foundation, Inc.
3 * Copyright (C) 2015-2016 Red Hat, Inc.
4 *
5 * This file is part of GnuTLS.
6 *
7 * GnuTLS is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * GnuTLS is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with this program. If not, see <https://www.gnu.org/licenses/>.
19 */
20
21 #include <config.h>
22
23 #if HAVE_SYS_SOCKET_H
24 #include <sys/socket.h>
25 #elif HAVE_WS2TCPIP_H
26 #include <ws2tcpip.h>
27 #endif
28 #include <netdb.h>
29 #include <string.h>
30 #include <errno.h>
31 #include <sys/select.h>
32 #include <sys/types.h>
33 #include <stdio.h>
34 #include <stdlib.h>
35 #include <unistd.h>
36 #include <arpa/inet.h>
37 #include <socket.h>
38 #include <c-ctype.h>
39 #include "sockets.h"
40 #include "common.h"
41
#ifdef _WIN32
/* On Windows, the endservent() calls made in this file expand to nothing. */
# undef endservent
# define endservent()
#endif

/* Size of the buffer socket_open2() uses to receive the numeric host
 * address produced by getnameinfo(). */
#define MAX_BUF 4096
48
49 /* Functions to manipulate sockets
50 */
51
52 ssize_t
socket_recv(const socket_st * socket,void * buffer,int buffer_size)53 socket_recv(const socket_st * socket, void *buffer, int buffer_size)
54 {
55 int ret;
56
57 if (socket->secure) {
58 do {
59 ret =
60 gnutls_record_recv(socket->session, buffer,
61 buffer_size);
62 if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED)
63 gnutls_heartbeat_pong(socket->session, 0);
64 }
65 while (ret == GNUTLS_E_INTERRUPTED
66 || ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED);
67
68 } else
69 do {
70 ret = recv(socket->fd, buffer, buffer_size, 0);
71 }
72 while (ret == -1 && errno == EINTR);
73
74 return ret;
75 }
76
77 ssize_t
socket_recv_timeout(const socket_st * socket,void * buffer,int buffer_size,unsigned ms)78 socket_recv_timeout(const socket_st * socket, void *buffer, int buffer_size, unsigned ms)
79 {
80 int ret;
81
82 if (socket->secure)
83 gnutls_record_set_timeout(socket->session, ms);
84 ret = socket_recv(socket, buffer, buffer_size);
85
86 if (socket->secure)
87 gnutls_record_set_timeout(socket->session, 0);
88
89 return ret;
90 }
91
92 ssize_t
socket_send(const socket_st * socket,const void * buffer,int buffer_size)93 socket_send(const socket_st * socket, const void *buffer, int buffer_size)
94 {
95 return socket_send_range(socket, buffer, buffer_size, NULL);
96 }
97
98
99 ssize_t
socket_send_range(const socket_st * socket,const void * buffer,int buffer_size,gnutls_range_st * range)100 socket_send_range(const socket_st * socket, const void *buffer,
101 int buffer_size, gnutls_range_st * range)
102 {
103 int ret;
104
105 if (socket->secure)
106 do {
107 if (range == NULL)
108 ret =
109 gnutls_record_send(socket->session,
110 buffer,
111 buffer_size);
112 else
113 ret =
114 gnutls_record_send_range(socket->
115 session,
116 buffer,
117 buffer_size,
118 range);
119 }
120 while (ret == GNUTLS_E_AGAIN
121 || ret == GNUTLS_E_INTERRUPTED);
122 else
123 do {
124 ret = send(socket->fd, buffer, buffer_size, 0);
125 }
126 while (ret == -1 && errno == EINTR);
127
128 if (ret > 0 && ret != buffer_size && socket->verbose)
129 fprintf(stderr,
130 "*** Only sent %d bytes instead of %d.\n", ret,
131 buffer_size);
132
133 return ret;
134 }
135
136 static
send_line(socket_st * socket,const char * txt)137 ssize_t send_line(socket_st * socket, const char *txt)
138 {
139 int len = strlen(txt);
140 int ret;
141
142 if (socket->verbose)
143 fprintf(stderr, "starttls: sending: %s\n", txt);
144
145 ret = send(socket->fd, txt, len, 0);
146
147 if (ret == -1) {
148 fprintf(stderr, "error sending \"%s\"\n", txt);
149 exit(2);
150 }
151
152 return ret;
153 }
154
155 static
wait_for_text(socket_st * socket,const char * txt,unsigned txt_size)156 ssize_t wait_for_text(socket_st * socket, const char *txt, unsigned txt_size)
157 {
158 char buf[1024];
159 char *pbuf, *p;
160 int ret;
161 fd_set read_fds;
162 struct timeval tv;
163 size_t left, got;
164
165 if (txt_size > sizeof(buf))
166 abort();
167
168 if (socket->verbose && txt != NULL)
169 fprintf(stderr, "starttls: waiting for: \"%.*s\"\n", txt_size, txt);
170
171 pbuf = buf;
172 left = sizeof(buf)-1;
173 got = 0;
174
175 do {
176 FD_ZERO(&read_fds);
177 FD_SET(socket->fd, &read_fds);
178 tv.tv_sec = 10;
179 tv.tv_usec = 0;
180 ret = select(socket->fd + 1, &read_fds, NULL, NULL, &tv);
181 if (ret > 0)
182 ret = recv(socket->fd, pbuf, left, 0);
183 if (ret == -1) {
184 fprintf(stderr, "error receiving '%s': %s\n", txt, strerror(errno));
185 exit(2);
186 } else if (ret == 0) {
187 fprintf(stderr, "error receiving '%s': Timeout\n", txt);
188 exit(2);
189 }
190 pbuf[ret] = 0;
191
192 if (txt == NULL)
193 break;
194
195 if (socket->verbose)
196 fprintf(stderr, "starttls: received: %s\n", pbuf);
197
198 pbuf += ret;
199 left -= ret;
200 got += ret;
201
202
203 /* check for text after a newline in buffer */
204 if (got > txt_size) {
205 p = memmem(buf, got, txt, txt_size);
206 if (p != NULL && p != buf) {
207 p--;
208 if (*p == '\n' || *p == '\r' || (*txt == '<' && *p == '>')) // XMPP is not line oriented, uses XML format
209 break;
210 }
211 }
212 } while(got < txt_size || strncmp(buf, txt, txt_size) != 0);
213
214 return got;
215 }
216
217 static void
socket_starttls(socket_st * socket)218 socket_starttls(socket_st * socket)
219 {
220 char buf[512];
221
222 if (socket->secure)
223 return;
224
225 if (socket->app_proto == NULL || strcasecmp(socket->app_proto, "https") == 0)
226 return;
227
228 if (strcasecmp(socket->app_proto, "smtp") == 0 || strcasecmp(socket->app_proto, "submission") == 0) {
229 if (socket->verbose)
230 log_msg(stdout, "Negotiating SMTP STARTTLS\n");
231
232 wait_for_text(socket, "220 ", 4);
233 snprintf(buf, sizeof(buf), "EHLO %s\r\n", socket->hostname);
234 send_line(socket, buf);
235 wait_for_text(socket, "250 ", 4);
236 send_line(socket, "STARTTLS\r\n");
237 wait_for_text(socket, "220 ", 4);
238 } else if (strcasecmp(socket->app_proto, "imap") == 0 || strcasecmp(socket->app_proto, "imap2") == 0) {
239 if (socket->verbose)
240 log_msg(stdout, "Negotiating IMAP STARTTLS\n");
241
242 send_line(socket, "a CAPABILITY\r\n");
243 wait_for_text(socket, "a OK", 4);
244 send_line(socket, "a STARTTLS\r\n");
245 wait_for_text(socket, "a OK", 4);
246 } else if (strcasecmp(socket->app_proto, "xmpp") == 0) {
247 if (socket->verbose)
248 log_msg(stdout, "Negotiating XMPP STARTTLS\n");
249
250 snprintf(buf, sizeof(buf), "<stream:stream xmlns:stream='http://etherx.jabber.org/streams' xmlns='jabber:client' to='%s' version='1.0'>\n", socket->hostname);
251 send_line(socket, buf);
252 wait_for_text(socket, "<?", 2);
253 send_line(socket, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>");
254 wait_for_text(socket, "<proceed", 8);
255 } else if (strcasecmp(socket->app_proto, "ldap") == 0) {
256 if (socket->verbose)
257 log_msg(stdout, "Negotiating LDAP STARTTLS\n");
258 #define LDAP_STR "\x30\x1d\x02\x01\x01\x77\x18\x80\x16\x31\x2e\x33\x2e\x36\x2e\x31\x2e\x34\x2e\x31\x2e\x31\x34\x36\x36\x2e\x32\x30\x30\x33\x37"
259 send(socket->fd, LDAP_STR, sizeof(LDAP_STR)-1, 0);
260 wait_for_text(socket, NULL, 0);
261 } else if (strcasecmp(socket->app_proto, "ftp") == 0 || strcasecmp(socket->app_proto, "ftps") == 0) {
262 if (socket->verbose)
263 log_msg(stdout, "Negotiating FTP STARTTLS\n");
264
265 send_line(socket, "FEAT\r\n");
266 wait_for_text(socket, "211 ", 4);
267 send_line(socket, "AUTH TLS\r\n");
268 wait_for_text(socket, "234", 3);
269 } else if (strcasecmp(socket->app_proto, "lmtp") == 0) {
270 if (socket->verbose)
271 log_msg(stdout, "Negotiating LMTP STARTTLS\n");
272
273 wait_for_text(socket, "220 ", 4);
274 snprintf(buf, sizeof(buf), "LHLO %s\r\n", socket->hostname);
275 send_line(socket, buf);
276 wait_for_text(socket, "250 ", 4);
277 send_line(socket, "STARTTLS\r\n");
278 wait_for_text(socket, "220 ", 4);
279 } else if (strcasecmp(socket->app_proto, "pop3") == 0) {
280 if (socket->verbose)
281 log_msg(stdout, "Negotiating POP3 STARTTLS\n");
282
283 wait_for_text(socket, "+OK", 3);
284 send_line(socket, "STLS\r\n");
285 wait_for_text(socket, "+OK", 3);
286 } else if (strcasecmp(socket->app_proto, "nntp") == 0) {
287 if (socket->verbose)
288 log_msg(stdout, "Negotiating NNTP STARTTLS\n");
289
290 wait_for_text(socket, "200 ", 4);
291 send_line(socket, "STARTTLS\r\n");
292 wait_for_text(socket, "382 ", 4);
293 } else if (strcasecmp(socket->app_proto, "sieve") == 0) {
294 if (socket->verbose)
295 log_msg(stdout, "Negotiating Sieve STARTTLS\n");
296
297 wait_for_text(socket, "OK ", 3);
298 send_line(socket, "STARTTLS\r\n");
299 wait_for_text(socket, "OK ", 3);
300 } else if (strcasecmp(socket->app_proto, "postgres") == 0 || strcasecmp(socket->app_proto, "postgresql") == 0) {
301 if (socket->verbose)
302 log_msg(stdout, "Negotiating PostgreSQL STARTTLS\n");
303
304 #define POSTGRES_STR "\x00\x00\x00\x08\x04\xD2\x16\x2F"
305 send(socket->fd, POSTGRES_STR, sizeof(POSTGRES_STR)-1, 0);
306 wait_for_text(socket, NULL, 0);
307 } else {
308 if (!c_isdigit(socket->app_proto[0])) {
309 static int warned = 0;
310 if (warned == 0) {
311 fprintf(stderr, "unknown protocol '%s'\n", socket->app_proto);
312 warned = 1;
313 }
314 }
315 }
316
317 return;
318 }
319
/* Map an application protocol name to the name looked up in the services
 * database: "xmpp" (in any case) becomes "xmpp-server"; any other name is
 * left unchanged.  app_proto must be an assignable const char * lvalue and
 * is evaluated more than once.  The do/while (0) wrapper makes the macro a
 * single statement, so it is safe inside an unbraced if/else.
 */
#define CANON_SERVICE(app_proto) \
	do { \
		if (strcasecmp((app_proto), "xmpp") == 0) \
			(app_proto) = "xmpp-server"; \
	} while (0)

/* Return the port number of the service used for app_proto, as found in the
 * services database ("xmpp" is looked up as "xmpp-server").  Falls back to
 * 443 when the service is unknown.
 */
int
starttls_proto_to_port(const char *app_proto)
{
	struct servent *entry;

	CANON_SERVICE(app_proto);

	entry = getservbyname(app_proto, NULL);
	if (entry == NULL) {
		endservent();
		return 443;
	}

	return ntohs(entry->s_port);
}
340
/* Return the services-database name for app_proto ("xmpp" is looked up as
 * "xmpp-server"), or the string "443" when the service is unknown.  The
 * returned name is the s_name of the entry returned by getservbyname().
 */
const char *starttls_proto_to_service(const char *app_proto)
{
	struct servent *entry;

	CANON_SERVICE(app_proto);

	entry = getservbyname(app_proto, NULL);
	if (entry == NULL) {
		endservent();
		return "443";
	}

	return entry->s_name;
}
355
/* Tear down a connection opened with socket_open2().
 *
 * If the socket is TLS-secured and polite is non-zero, a TLS close_notify
 * is sent first (gnutls_bye with GNUTLS_SHUT_WR, retried on
 * interruption).  Then the session is deinitialized, the resolved address
 * list and the strings owned by socket are freed, the fd is shut down and
 * closed, and the trace files are closed.
 *
 * Every released field is reset (NULL / -1 / 0), so a repeated call does not
 * double-free memory or double-close descriptors and files.
 */
void socket_bye(socket_st * socket, unsigned polite)
{
	int ret;

	if (socket->secure && socket->session && polite) {
		do
			ret = gnutls_bye(socket->session, GNUTLS_SHUT_WR);
		while (ret == GNUTLS_E_INTERRUPTED
		       || ret == GNUTLS_E_AGAIN);
		if (socket->verbose && ret < 0)
			fprintf(stderr, "*** gnutls_bye() error: %s\n",
				gnutls_strerror(ret));
	}

	if (socket->session) {
		gnutls_deinit(socket->session);
		socket->session = NULL;
	}

	/* POSIX does not define freeaddrinfo(NULL) */
	if (socket->addr_info)
		freeaddrinfo(socket->addr_info);
	socket->addr_info = socket->ptr = NULL;
	socket->connect_addrlen = 0;

	free(socket->ip);
	socket->ip = NULL;
	free(socket->hostname);
	socket->hostname = NULL;
	free(socket->service);
	socket->service = NULL;

	if (socket->fd >= 0) {
		shutdown(socket->fd, SHUT_RDWR); /* no more receptions */
		close(socket->fd);
	}
	socket->fd = -1;

	gnutls_free(socket->rdata.data);
	socket->rdata.data = NULL;

	if (socket->server_trace) {
		fclose(socket->server_trace);
		socket->server_trace = NULL;
	}
	if (socket->client_trace) {
		fclose(socket->client_trace);
		socket->client_trace = NULL;
	}

	socket->secure = 0;
}
399
/* Handle host:port format.
 *
 * The host specification in hostname is rewritten in place:
 *  - "[addr]" or "[addr]:port", where addr is an IPv6 literal (RFC 3986
 *    IP-literal syntax), is unwrapped to "addr";
 *  - a bare IPv6 literal such as "::1" is left untouched;
 *  - otherwise hostname is truncated at the first ':', and a single
 *    trailing '.' of a fully qualified name is removed.
 * When a port part is present and service is non-NULL with a non-zero
 * service_size, the text after the ':' is copied into service (truncated to
 * service_size - 1 characters).  In all other cases service is not modified.
 */
void canonicalize_host(char *hostname, char *service, unsigned service_size)
{
	char *p;
	unsigned char buf[64];

	/* bracketed IPv6 literal, optionally followed by ":port" */
	if (hostname[0] == '[' && (p = strchr(hostname, ']')) != NULL) {
		*p = 0;
		if (inet_pton(AF_INET6, hostname + 1, buf) == 1 &&
		    (p[1] == 0 || p[1] == ':')) {
			if (p[1] == ':' && service && service_size)
				snprintf(service, service_size, "%s", p + 2);
			/* shift "addr\0" left over the '[' */
			memmove(hostname, hostname + 1, (size_t) (p - hostname));
			return;
		}
		*p = ']';	/* not a bracketed IPv6 literal: undo */
	}

	if ((p = strchr(hostname, ':'))) {
		if (inet_pton(AF_INET6, hostname, buf) == 1)
			return;

		*p = 0;

		if (service && service_size)
			snprintf(service, service_size, "%s", p+1);
	} else
		p = hostname + strlen(hostname);

	if (p > hostname && p[-1] == '.')
		p[-1] = 0; // remove trailing dot on FQDN
}
422
423 static ssize_t
wrap_pull(gnutls_transport_ptr_t ptr,void * data,size_t len)424 wrap_pull(gnutls_transport_ptr_t ptr, void *data, size_t len)
425 {
426 socket_st *hd = ptr;
427 ssize_t r;
428
429 r = recv(hd->fd, data, len, 0);
430 if (r > 0 && hd->server_trace) {
431 fwrite(data, 1, r, hd->server_trace);
432 }
433 return r;
434 }
435
436 static ssize_t
wrap_push(gnutls_transport_ptr_t ptr,const void * data,size_t len)437 wrap_push(gnutls_transport_ptr_t ptr, const void *data, size_t len)
438 {
439 socket_st *hd = ptr;
440
441 if (hd->client_trace) {
442 fwrite(data, 1, len, hd->client_trace);
443 }
444
445 return send(hd->fd, data, len, 0);
446 }
447
448 /* inline is used to avoid a gcc warning if used in mini-eagain */
wrap_pull_timeout_func(gnutls_transport_ptr_t ptr,unsigned int ms)449 inline static int wrap_pull_timeout_func(gnutls_transport_ptr_t ptr,
450 unsigned int ms)
451 {
452 socket_st *hd = ptr;
453
454 return gnutls_system_recv_timeout((gnutls_transport_ptr_t)(long)hd->fd, ms);
455 }
456
457
458 void
socket_open2(socket_st * hd,const char * hostname,const char * service,const char * app_proto,int flags,const char * msg,gnutls_datum_t * rdata,gnutls_datum_t * edata,FILE * server_trace,FILE * client_trace)459 socket_open2(socket_st * hd, const char *hostname, const char *service,
460 const char *app_proto, int flags, const char *msg, gnutls_datum_t *rdata, gnutls_datum_t *edata,
461 FILE *server_trace, FILE *client_trace)
462 {
463 struct addrinfo hints, *res, *ptr;
464 int sd, err = 0;
465 int udp = flags & SOCKET_FLAG_UDP;
466 int ret;
467 int fastopen = flags & SOCKET_FLAG_FASTOPEN;
468 char buffer[MAX_BUF + 1];
469 char portname[16] = { 0 };
470 gnutls_datum_t idna;
471 char *a_hostname;
472
473 memset(hd, 0, sizeof(*hd));
474
475 if (flags & SOCKET_FLAG_VERBOSE)
476 hd->verbose = 1;
477
478 if (rdata) {
479 hd->rdata.data = rdata->data;
480 hd->rdata.size = rdata->size;
481 }
482
483 if (edata) {
484 hd->edata.data = edata->data;
485 hd->edata.size = edata->size;
486 }
487
488 ret = gnutls_idna_map(hostname, strlen(hostname), &idna, 0);
489 if (ret < 0) {
490 fprintf(stderr, "Cannot convert %s to IDNA: %s\n", hostname, gnutls_strerror(ret));
491 exit(1);
492 }
493
494 hd->hostname = strdup(hostname);
495 a_hostname = (char*)idna.data;
496
497 if (msg != NULL)
498 log_msg(stdout, "Resolving '%s:%s'...\n", a_hostname, service);
499
500 /* get server name */
501 memset(&hints, 0, sizeof(hints));
502 hints.ai_socktype = udp ? SOCK_DGRAM : SOCK_STREAM;
503 if ((err = getaddrinfo(a_hostname, service, &hints, &res))) {
504 fprintf(stderr, "Cannot resolve %s:%s: %s\n", hostname,
505 service, gai_strerror(err));
506 exit(1);
507 }
508
509 sd = -1;
510 for (ptr = res; ptr != NULL; ptr = ptr->ai_next) {
511 sd = socket(ptr->ai_family, ptr->ai_socktype,
512 ptr->ai_protocol);
513 if (sd == -1)
514 continue;
515
516 if ((err =
517 getnameinfo(ptr->ai_addr, ptr->ai_addrlen, buffer,
518 MAX_BUF, portname, sizeof(portname),
519 NI_NUMERICHOST | NI_NUMERICSERV)) != 0) {
520 fprintf(stderr, "getnameinfo(): %s\n",
521 gai_strerror(err));
522 continue;
523 }
524
525 if (hints.ai_socktype == SOCK_DGRAM) {
526 #if defined(IP_DONTFRAG)
527 int yes = 1;
528 if (setsockopt(sd, IPPROTO_IP, IP_DONTFRAG,
529 (const void *) &yes,
530 sizeof(yes)) < 0)
531 perror("setsockopt(IP_DF) failed");
532 #elif defined(IP_MTU_DISCOVER)
533 int yes = IP_PMTUDISC_DO;
534 if (setsockopt(sd, IPPROTO_IP, IP_MTU_DISCOVER,
535 (const void *) &yes,
536 sizeof(yes)) < 0)
537 perror("setsockopt(IP_DF) failed");
538 #endif
539 }
540
541 if (fastopen && ptr->ai_socktype == SOCK_STREAM
542 && (ptr->ai_family == AF_INET || ptr->ai_family == AF_INET6)) {
543 memcpy(&hd->connect_addr, ptr->ai_addr, ptr->ai_addrlen);
544 hd->connect_addrlen = ptr->ai_addrlen;
545
546 if (msg)
547 log_msg(stdout, "%s '%s:%s' (TFO)...\n", msg, buffer, portname);
548
549 } else {
550 if (msg)
551 log_msg(stdout, "%s '%s:%s'...\n", msg, buffer, portname);
552
553 if ((err = connect(sd, ptr->ai_addr, ptr->ai_addrlen)) < 0)
554 continue;
555 }
556
557 hd->fd = sd;
558 if (flags & SOCKET_FLAG_STARTTLS) {
559 hd->app_proto = app_proto;
560 socket_starttls(hd);
561 hd->app_proto = NULL;
562 }
563
564 if (!(flags & SOCKET_FLAG_SKIP_INIT)) {
565 hd->session = init_tls_session(hostname);
566 if (hd->session == NULL) {
567 fprintf(stderr, "error initializing session\n");
568 exit(1);
569 }
570 }
571
572 if (hd->session) {
573 if (hd->edata.data) {
574 ret = gnutls_record_send_early_data(hd->session, hd->edata.data, hd->edata.size);
575 if (ret < 0) {
576 fprintf(stderr, "error sending early data\n");
577 exit(1);
578 }
579 }
580 if (hd->rdata.data) {
581 gnutls_session_set_data(hd->session, hd->rdata.data, hd->rdata.size);
582 }
583
584 if (server_trace)
585 hd->server_trace = server_trace;
586
587 if (client_trace)
588 hd->client_trace = client_trace;
589
590 gnutls_transport_set_push_function(hd->session, wrap_push);
591 gnutls_transport_set_pull_function(hd->session, wrap_pull);
592 gnutls_transport_set_pull_timeout_function(hd->session, wrap_pull_timeout_func);
593 gnutls_transport_set_ptr(hd->session, hd);
594 }
595
596 if (!(flags & SOCKET_FLAG_RAW) && !(flags & SOCKET_FLAG_SKIP_INIT)) {
597 err = do_handshake(hd);
598 if (err == GNUTLS_E_PUSH_ERROR) { /* failed connecting */
599 gnutls_deinit(hd->session);
600 hd->session = NULL;
601 continue;
602 }
603 else if (err < 0) {
604 if (!(flags & SOCKET_FLAG_DONT_PRINT_ERRORS))
605 fprintf(stderr, "*** handshake has failed: %s\n", gnutls_strerror(err));
606 exit(1);
607 }
608 }
609
610 break;
611 }
612
613 if (err != 0) {
614 int e = errno;
615 fprintf(stderr, "Could not connect to %s:%s: %s\n",
616 buffer, portname, strerror(e));
617 exit(1);
618 }
619
620 if (sd == -1) {
621 fprintf(stderr, "Could not find a supported socket\n");
622 exit(1);
623 }
624
625 if ((flags & SOCKET_FLAG_RAW) || (flags & SOCKET_FLAG_SKIP_INIT))
626 hd->secure = 0;
627 else
628 hd->secure = 1;
629
630 hd->fd = sd;
631 hd->ip = strdup(buffer);
632 hd->service = strdup(portname);
633 hd->ptr = ptr;
634 hd->addr_info = res;
635 gnutls_free(hd->rdata.data);
636 hd->rdata.data = NULL;
637 gnutls_free(hd->edata.data);
638 hd->edata.data = NULL;
639 gnutls_free(idna.data);
640 return;
641 }
642
/* converts a textual service or port to
 * a service.
 *
 * If sport starts with a digit and its leading number is a valid port in
 * 1..65535, getservbyport() is consulted for protocol proto and the
 * service name found is returned.  In every other case sport itself is
 * returned: a non-numeric string, 0, an out-of-range number, or a port
 * with no database entry (which also prints a warning).  A returned
 * service name comes from the servent entry returned by getservbyport().
 */
const char *port_to_service(const char *sport, const char *proto)
{
	unsigned long port;
	struct servent *sr;

	if (!c_isdigit(sport[0]))
		return sport;

	/* strtoul() rather than atoi(): no undefined behaviour on overflow
	 * (it saturates at ULONG_MAX), and out-of-range values are rejected
	 * here instead of being silently truncated by htons(). */
	port = strtoul(sport, NULL, 10);
	if (port == 0 || port > 65535)
		return sport;

	sr = getservbyport(htons((unsigned short) port), proto);
	if (sr == NULL) {
		fprintf(stderr,
			"Warning: getservbyport(%s) failed. Using port number as service.\n", sport);
		return sport;
	}

	return sr->s_name;
}
669
/* Convert a port number or service name to a port number (host order).
 *
 * A string whose leading number is a valid port in 1..65535 is returned
 * as that number; as with atoi(), trailing characters are ignored.
 * Anything else is looked up with getservbyname() for protocol proto.  An
 * unknown service is fatal: a message is printed and the process exits
 * with status 1.
 */
int service_to_port(const char *service, const char *proto)
{
	long port;
	char *end;
	struct servent *sr;

	/* strtol() rather than atoi(): detects overflow (no undefined
	 * behaviour) and lets us reject negative or > 65535 values, which
	 * are not valid port numbers. */
	errno = 0;
	port = strtol(service, &end, 10);
	if (end != service && errno == 0 && port > 0 && port <= 65535)
		return (int) port;

	sr = getservbyname(service, proto);
	if (sr == NULL) {
		fprintf(stderr, "Warning: getservbyname() failed for '%s/%s'.\n", service, proto);
		exit(1);
	}

	return ntohs(sr->s_port);
}
687