1 /*  Copyright (C) 2021 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
2 
3     This program is free software: you can redistribute it and/or modify
4     it under the terms of the GNU General Public License as published by
5     the Free Software Foundation, either version 3 of the License, or
6     (at your option) any later version.
7 
8     This program is distributed in the hope that it will be useful,
9     but WITHOUT ANY WARRANTY; without even the implied warranty of
10     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11     GNU General Public License for more details.
12 
13     You should have received a copy of the GNU General Public License
14     along with this program.  If not, see <https://www.gnu.org/licenses/>.
15  */
16 
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <poll.h>
#include <stdint.h>
#include <stdlib.h>
#include <netinet/in.h>
#include <sys/types.h>   // OpenBSD
#include <netinet/tcp.h> // TCP_FASTOPEN
#include <sys/socket.h>
26 
27 #ifdef HAVE_SYS_UIO_H
28 #include <sys/uio.h>
29 #endif
30 
31 #include "utils/common/netio.h"
32 #include "utils/common/msg.h"
33 #include "utils/common/tls.h"
34 #include "libknot/libknot.h"
35 #include "contrib/sockaddr.h"
36 
srv_info_create(const char * name,const char * service)37 srv_info_t *srv_info_create(const char *name, const char *service)
38 {
39 	if (name == NULL || service == NULL) {
40 		DBG_NULL;
41 		return NULL;
42 	}
43 
44 	// Create output structure.
45 	srv_info_t *server = calloc(1, sizeof(srv_info_t));
46 
47 	// Check output.
48 	if (server == NULL) {
49 		return NULL;
50 	}
51 
52 	// Fill output.
53 	server->name = strdup(name);
54 	server->service = strdup(service);
55 
56 	if (server->name == NULL || server->service == NULL) {
57 		srv_info_free(server);
58 		return NULL;
59 	}
60 
61 	// Return result.
62 	return server;
63 }
64 
/*!
 * Releases a server descriptor together with its name and service strings.
 * A NULL argument is reported via DBG_NULL and otherwise ignored.
 */
void srv_info_free(srv_info_t *server)
{
	if (server != NULL) {
		free(server->name);
		free(server->service);
		free(server);
		return;
	}

	DBG_NULL;
}
76 
get_iptype(const ip_t ip)77 int get_iptype(const ip_t ip)
78 {
79 	switch (ip) {
80 	case IP_4:
81 		return AF_INET;
82 	case IP_6:
83 		return AF_INET6;
84 	default:
85 		return AF_UNSPEC;
86 	}
87 }
88 
get_socktype(const protocol_t proto,const uint16_t type)89 int get_socktype(const protocol_t proto, const uint16_t type)
90 {
91 	switch (proto) {
92 	case PROTO_TCP:
93 		return SOCK_STREAM;
94 	case PROTO_UDP:
95 		return SOCK_DGRAM;
96 	default:
97 		if (type == KNOT_RRTYPE_AXFR || type == KNOT_RRTYPE_IXFR) {
98 			return SOCK_STREAM;
99 		} else {
100 			return SOCK_DGRAM;
101 		}
102 	}
103 }
104 
/*!
 * Returns a static, human-readable label for a socket type:
 * "TCP" for SOCK_STREAM, "UDP" for SOCK_DGRAM, "UNKNOWN" otherwise.
 */
const char *get_sockname(const int socktype)
{
	const char *label = "UNKNOWN";

	if (socktype == SOCK_STREAM) {
		label = "TCP";
	} else if (socktype == SOCK_DGRAM) {
		label = "UDP";
	}

	return label;
}
116 
/*!
 * Resolves a server descriptor with getaddrinfo().
 *
 * \param server    Host name and service to resolve.
 * \param iptype    Address family hint (AF_INET, AF_INET6 or AF_UNSPEC).
 * \param socktype  Socket type hint.
 * \param info      Output address list, to be released with freeaddrinfo().
 *
 * \retval 0   on success.
 * \retval -1  on failure. An error is printed unless the failure only says
 *             that the name has no address in the requested family.
 */
static int get_addr(const srv_info_t *server,
                    const int        iptype,
                    const int        socktype,
                    struct addrinfo  **info)
{
	const struct addrinfo hints = {
		.ai_family = iptype,
		.ai_socktype = socktype,
	};

	const int err = getaddrinfo(server->name, server->service, &hints, info);
	if (err == 0) {
		return 0;
	}

#ifdef EAI_ADDRFAMILY
	// No address in the requested family: fail quietly.
	if (err == EAI_ADDRFAMILY) {
		return -1;
	}
#else
	// EAI_ADDRFAMILY isn't implemented in FreeBSD/macOS anymore; they (and
	// likely others) return EAI_NONAME instead, which can only be taken as
	// a family mismatch when a specific family was requested.
	if (err == EAI_NONAME && iptype != AF_UNSPEC) {
		return -1;
	}
#endif

	ERR("%s for %s@%s\n", gai_strerror(err), server->name, server->service);
	return -1;
}
149 
get_addr_str(const struct sockaddr_storage * ss,const int socktype,char ** dst)150 void get_addr_str(const struct sockaddr_storage *ss,
151                   const int                     socktype,
152                   char                          **dst)
153 {
154 	char addr_str[SOCKADDR_STRLEN] = {0};
155 
156 	// Get network address string and port number.
157 	sockaddr_tostr(addr_str, sizeof(addr_str), ss);
158 
159 	// Calculate needed buffer size
160 	const char *sock_name = get_sockname(socktype);
161 	size_t buflen = strlen(addr_str) + strlen(sock_name) + 3 /* () */;
162 
163 	// Free previous string if any and write result
164 	free(*dst);
165 	*dst = malloc(buflen);
166 	if (*dst != NULL) {
167 		int ret = snprintf(*dst, buflen, "%s(%s)", addr_str, sock_name);
168 		if (ret <= 0 || ret >= buflen) {
169 			**dst = '\0';
170 		}
171 	}
172 }
173 
/*!
 * Initializes a network context for talking to \a remote.
 *
 * Resolves the remote address list (its first entry becomes the current
 * target) and the optional local address, stores the connection parameters
 * and prepares the TLS context, plus the HTTPS context when built with
 * libnghttp2. No socket is opened here.
 *
 * \param local         Optional local address to bind to (may be NULL);
 *                      stored by reference.
 * \param remote        Remote server; stored by reference.
 * \param iptype        Address family (AF_INET, AF_INET6 or AF_UNSPEC).
 * \param socktype      Socket type (SOCK_STREAM or SOCK_DGRAM).
 * \param wait          Timeout in seconds for the network operations.
 * \param flags         Network flags (e.g. NET_FLAGS_FASTOPEN).
 * \param tls_params    Optional TLS parameters, used only if enabled.
 * \param https_params  Optional HTTPS parameters, used only with TLS enabled.
 * \param net           Context to initialize.
 *
 * \return KNOT_EOK, KNOT_EINVAL, KNOT_NET_EADDR or a TLS/HTTPS error code.
 *         On failure the context has already been cleaned with net_clean().
 */
int net_init(const srv_info_t     *local,
             const srv_info_t     *remote,
             const int            iptype,
             const int            socktype,
             const int            wait,
             const net_flags_t    flags,
             const tls_params_t   *tls_params,
             const https_params_t *https_params,
             net_t                *net)
{
	if (remote == NULL || net == NULL) {
		DBG_NULL;
		return KNOT_EINVAL;
	}

	// Start from a zeroed state so net_clean() is safe on every failure path.
	memset(net, 0, sizeof(*net));
	net->sockfd = -1;

	int ret = KNOT_NET_EADDR;

	if (get_addr(remote, iptype, socktype, &net->remote_info) != 0) {
		goto failed;
	}
	net->srv = net->remote_info;

	if (local != NULL &&
	    get_addr(local, iptype, socktype, &net->local_info) != 0) {
		goto failed;
	}

	net->iptype = iptype;
	net->socktype = socktype;
	net->wait = wait;
	net->local = local;
	net->remote = remote;
	net->flags = flags;

	if (tls_params != NULL && tls_params->enable) {
		ret = tls_ctx_init(&net->tls, tls_params, net->wait);
		if (ret != KNOT_EOK) {
			goto failed;
		}

#ifdef LIBNGHTTP2
		// HTTPS runs on top of the TLS context prepared above.
		if (https_params != NULL && https_params->enable) {
			ret = https_ctx_init(&net->https, &net->tls, https_params);
			if (ret != KNOT_EOK) {
				goto failed;
			}
		}
#endif //LIBNGHTTP2
	}

	return KNOT_EOK;

failed:
	net_clean(net);
	return ret;
}
241 
242 /*!
243  * Connect with TCP Fast Open.
244  */
static int fastopen_connect(int sockfd, const struct addrinfo *srv)
{
	// Returns 0 on success or -1 with errno set. Only the macOS variant uses
	// 'srv'; on the other platforms the destination address reaches the
	// kernel later, via the msghdr that net_send() passes to fastopen_send().
#if defined( __FreeBSD__) && defined(TCP_FASTOPEN)
	// FreeBSD: only enable TCP Fast Open on the socket here; the connection
	// is initiated by the subsequent sendmsg() in fastopen_send().
	const int enable = 1;
	return setsockopt(sockfd, IPPROTO_TCP, TCP_FASTOPEN, &enable, sizeof(enable));
#elif defined(__APPLE__)
	// connection is performed lazily when first data are sent
	struct sa_endpoints ep = {0};
	ep.sae_dstaddr = srv->ai_addr;
	ep.sae_dstaddrlen = srv->ai_addrlen;
	// NOTE(review): per connectx(2), these flags mark the initial data as
	// idempotent and defer the actual connect until the first read/write —
	// confirm against the target macOS SDK documentation.
	int flags =  CONNECT_DATA_IDEMPOTENT|CONNECT_RESUME_ON_READ_WRITE;

	return connectx(sockfd, &ep, SAE_ASSOCID_ANY, flags, NULL, 0, NULL, NULL);
#elif defined(__linux__)
	// connect() will be called implicitly with sendto(), sendmsg()
	// (fastopen_send() passes MSG_FASTOPEN), so there is nothing to do here.
	return 0;
#else
	// No TCP Fast Open support known for this platform.
	errno = ENOTSUP;
	return -1;
#endif
}
266 
267 /*!
268  * Sends data with TCP Fast Open.
269  */
static int fastopen_send(int sockfd, const struct msghdr *msg, int timeout)
{
	// Returns the sendmsg() result (bytes sent, or -1 with errno set).
	// 'timeout' is in seconds (converted to milliseconds for poll()) and is
	// only used by the Linux variant.
#if (defined(__FreeBSD__) && defined(TCP_FASTOPEN))|| defined(__APPLE__)
	// The socket was prepared by fastopen_connect(); a plain sendmsg() with
	// the destination in msg->msg_name sends the data.
	return sendmsg(sockfd, msg, 0);
#elif defined(__linux__)
	// MSG_FASTOPEN asks the kernel to connect to msg->msg_name and send the
	// data as part of the connection setup.
	int ret = sendmsg(sockfd, msg, MSG_FASTOPEN);
	if (ret == -1 && errno == EINPROGRESS) {
		// The handshake is still in progress and the call reported nothing
		// as sent: wait until the socket is writable, then resend the whole
		// message without MSG_FASTOPEN.
		struct pollfd pfd = {
			.fd = sockfd,
			.events = POLLOUT,
			.revents = 0,
		};
		// Any poll() failure (not just expiry) is reported as ETIMEDOUT.
		if (poll(&pfd, 1, 1000 * timeout) != 1) {
			errno = ETIMEDOUT;
			return -1;
		}
		ret = sendmsg(sockfd, msg, 0);
	}
	return ret;
#else
	// No TCP Fast Open support known for this platform.
	errno = ENOTSUP;
	return -1;
#endif
}
294 
/*!
 * Creates a socket for the current remote address (net->srv) and connects it.
 *
 * The socket is made non-blocking and is bound either to net->local_info or,
 * if none is set, to a wildcard address of the remote family (so the source
 * port is randomized). For stream sockets a TCP connection (optionally with
 * TCP Fast Open) is established, followed by the TLS or HTTPS handshake when
 * TLS was configured in net_init().
 *
 * \param net  Network context prepared by net_init().
 *
 * \retval KNOT_EOK  on success; the socket is stored in net->sockfd.
 * \return KNOT_EINVAL, KNOT_NET_ESOCKET, KNOT_NET_ECONNECT or a TLS/HTTPS
 *         error code on failure. The new socket is closed on every failure.
 */
int net_connect(net_t *net)
{
	if (net == NULL || net->srv == NULL) {
		DBG_NULL;
		return KNOT_EINVAL;
	}

	// Set remote information string.
	get_addr_str((struct sockaddr_storage *)net->srv->ai_addr,
	             net->socktype, &net->remote_str);

	// Create socket.
	int sockfd = socket(net->srv->ai_family, net->socktype, 0);
	if (sockfd == -1) {
		WARN("can't create socket for %s\n", net->remote_str);
		return KNOT_NET_ESOCKET;
	}

	// Initialize poll descriptor structure.
	struct pollfd pfd = {
		.fd = sockfd,
		.events = POLLOUT,
		.revents = 0,
	};

	// Set non-blocking socket, keeping the other file status flags.
	int fd_flags = fcntl(sockfd, F_GETFL);
	if (fd_flags == -1 ||
	    fcntl(sockfd, F_SETFL, fd_flags | O_NONBLOCK) == -1) {
		WARN("can't set non-blocking socket for %s\n", net->remote_str);
		close(sockfd);
		return KNOT_NET_ESOCKET;
	}

	// Bind address to socket if specified.
	if (net->local_info != NULL) {
		if (bind(sockfd, net->local_info->ai_addr,
		         net->local_info->ai_addrlen) == -1) {
			// local_info may have been recorded by net_set_local_info()
			// without an explicit local server, so net->local can be NULL.
			const char *local_name = (net->local != NULL) ?
			                         net->local->name : net->local_str;
			WARN("can't assign address %s\n",
			     local_name != NULL ? local_name : "");
			close(sockfd);
			return KNOT_NET_ESOCKET;
		}
	} else {
		// Ensure source port is always randomized (even for TCP).
		struct sockaddr_storage local = { .ss_family = net->srv->ai_family };
		(void)bind(sockfd, (struct sockaddr *)&local, sockaddr_len(&local));
	}

	if (net->socktype == SOCK_STREAM) {
		int  cs, err, ret = 0;
		socklen_t err_len = sizeof(err);
		bool fastopen = net->flags & NET_FLAGS_FASTOPEN;

		// Establish a TCP connection here, except for TLS with Fast Open,
		// where the TLS/HTTPS connect below is given the fastopen flag and
		// the remote address instead.
		if (net->tls.params == NULL || !fastopen) {
			if (fastopen) {
				ret = fastopen_connect(sockfd, net->srv);
			} else {
				ret = connect(sockfd, net->srv->ai_addr, net->srv->ai_addrlen);
			}
			if (ret != 0 && errno != EINPROGRESS) {
				WARN("can't connect to %s\n", net->remote_str);
				close(sockfd);
				return KNOT_NET_ECONNECT;
			}

			// Check for connection timeout (Fast Open connects on first send).
			if (!fastopen && poll(&pfd, 1, 1000 * net->wait) != 1) {
				WARN("connection timeout for %s\n", net->remote_str);
				close(sockfd);
				return KNOT_NET_ECONNECT;
			}

			// Check if NB socket is writeable.
			cs = getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &err, &err_len);
			if (cs < 0 || err != 0) {
				WARN("can't connect to %s\n", net->remote_str);
				close(sockfd);
				return KNOT_NET_ECONNECT;
			}
		}

		if (net->tls.params != NULL) {
#ifdef LIBNGHTTP2
			if (net->https.params.enable) {
				// Establish HTTPS connection. The server name is the SNI,
				// else the configured hostname, else the remote name when
				// the address is not IPv6 and the name differs from the
				// address text before '@'.
				char *remote = NULL;
				if (net->tls.params->sni != NULL) {
					remote = net->tls.params->sni;
				} else if (net->tls.params->hostname != NULL) {
					remote = net->tls.params->hostname;
				} else if (net->remote_str != NULL &&
				           strchr(net->remote_str, ':') == NULL) {
					char *at = strchr(net->remote_str, '@');
					if (at != NULL && strncmp(net->remote->name, net->remote_str,
					                          at - net->remote_str)) {
						remote = net->remote->name;
					}
				}
				ret = https_ctx_connect(&net->https, sockfd, remote, fastopen,
				                        (struct sockaddr_storage *)net->srv->ai_addr);
			} else {
#endif //LIBNGHTTP2
				// Establish TLS connection.
				ret = tls_ctx_connect(&net->tls, sockfd, net->tls.params->sni, fastopen,
				                      (struct sockaddr_storage *)net->srv->ai_addr, &dot_alpn);
#ifdef LIBNGHTTP2
			}
#endif //LIBNGHTTP2
			if (ret != KNOT_EOK) {
				close(sockfd);
				return ret;
			}
		}
	}

	// Store socket descriptor.
	net->sockfd = sockfd;

	return KNOT_EOK;
}
411 
/*!
 * Records the local address that the connected socket is actually bound to.
 *
 * Queries getsockname() on net->sockfd, replaces net->local_info with a
 * single calloc()ed block holding the result, and refreshes net->local_str.
 *
 * NOTE(review): net_connect() binds later sockets to net->local_info when
 * it is set, so after this call a reconnect reuses this exact address and
 * port — confirm that is intended.
 * NOTE(review): if net->local is set, net_clean() releases local_info with
 * freeaddrinfo(), but the block installed here comes from calloc(); confirm
 * that this is safe on all target libcs (or track the allocator explicitly).
 *
 * \retval KNOT_EOK          on success.
 * \retval KNOT_EINVAL       NULL context or no current remote address.
 * \retval KNOT_ENOMEM       allocation failure.
 * \retval KNOT_NET_ESOCKET  getsockname() failed.
 */
int net_set_local_info(net_t *net)
{
	if (net == NULL || net->srv == NULL) {
		DBG_NULL;
		return KNOT_EINVAL;
	}

	socklen_t local_addr_len = sizeof(struct sockaddr_storage);

	// One allocation: the addrinfo header directly followed by the address.
	struct addrinfo *new_info = calloc(1, sizeof(*new_info) + local_addr_len);
	if (new_info == NULL) {
		return KNOT_ENOMEM;
	}

	new_info->ai_addr = (struct sockaddr *)(new_info + 1);
	new_info->ai_family = net->srv->ai_family;
	new_info->ai_socktype = net->srv->ai_socktype;
	new_info->ai_protocol = net->srv->ai_protocol;

	if (getsockname(net->sockfd, new_info->ai_addr, &local_addr_len) == -1) {
		WARN("can't get local address\n");
		free(new_info);
		return KNOT_NET_ESOCKET;
	}

	// Store the real address length reported by getsockname(), not the
	// buffer size, so a later bind() with this address gets a valid length.
	new_info->ai_addrlen = local_addr_len;

	if (net->local_info != NULL) {
		if (net->local == NULL) {
			free(net->local_info);
		} else {
			freeaddrinfo(net->local_info);
		}
	}

	net->local_info = new_info;

	get_addr_str((struct sockaddr_storage *)net->local_info->ai_addr,
	             net->socktype, &net->local_str);

	return KNOT_EOK;
}
453 
net_send(const net_t * net,const uint8_t * buf,const size_t buf_len)454 int net_send(const net_t *net, const uint8_t *buf, const size_t buf_len)
455 {
456 	if (net == NULL || buf == NULL) {
457 		DBG_NULL;
458 		return KNOT_EINVAL;
459 	}
460 
461 	// Send data over UDP.
462 	if (net->socktype == SOCK_DGRAM) {
463 		if (sendto(net->sockfd, buf, buf_len, 0, net->srv->ai_addr,
464 		           net->srv->ai_addrlen) != (ssize_t)buf_len) {
465 			WARN("can't send query to %s\n", net->remote_str);
466 			return KNOT_NET_ESEND;
467 		}
468 #ifdef LIBNGHTTP2
469 	// Send data over HTTPS
470 	} else if (net->https.params.enable) {
471 		int ret = https_send_dns_query((https_ctx_t *)&net->https, buf, buf_len);
472 		if (ret != KNOT_EOK) {
473 			WARN("can't send query to %s\n", net->remote_str);
474 			return KNOT_NET_ESEND;
475 		}
476 #endif //LIBNGHTTP2
477 	// Send data over TLS.
478 	} else if (net->tls.params != NULL) {
479 		int ret = tls_ctx_send((tls_ctx_t *)&net->tls, buf, buf_len);
480 		if (ret != KNOT_EOK) {
481 			WARN("can't send query to %s\n", net->remote_str);
482 			return KNOT_NET_ESEND;
483 		}
484 	// Send data over TCP.
485 	} else {
486 		bool fastopen = net->flags & NET_FLAGS_FASTOPEN;
487 
488 		// Leading packet length bytes.
489 		uint16_t pktsize = htons(buf_len);
490 
491 		struct iovec iov[2];
492 		iov[0].iov_base = &pktsize;
493 		iov[0].iov_len = sizeof(pktsize);
494 		iov[1].iov_base = (uint8_t *)buf;
495 		iov[1].iov_len = buf_len;
496 
497 		// Compute packet total length.
498 		ssize_t total = iov[0].iov_len + iov[1].iov_len;
499 
500 		struct msghdr msg = {0};
501 		msg.msg_iov = iov;
502 		msg.msg_iovlen = sizeof(iov) / sizeof(*iov);
503 		msg.msg_name = net->srv->ai_addr;
504 		msg.msg_namelen = net->srv->ai_addrlen;
505 
506 		int ret = 0;
507 		if (fastopen) {
508 			ret = fastopen_send(net->sockfd, &msg, net->wait);
509 		} else {
510 			ret = sendmsg(net->sockfd, &msg, 0);
511 		}
512 		if (ret != total) {
513 			WARN("can't send query to %s\n", net->remote_str);
514 			return KNOT_NET_ESEND;
515 		}
516 	}
517 
518 	return KNOT_EOK;
519 }
520 
net_receive(const net_t * net,uint8_t * buf,const size_t buf_len)521 int net_receive(const net_t *net, uint8_t *buf, const size_t buf_len)
522 {
523 	if (net == NULL || buf == NULL) {
524 		DBG_NULL;
525 		return KNOT_EINVAL;
526 	}
527 
528 	// Initialize poll descriptor structure.
529 	struct pollfd pfd = {
530 		.fd = net->sockfd,
531 		.events = POLLIN,
532 		.revents = 0,
533 	};
534 
535 	// Receive data over UDP.
536 	if (net->socktype == SOCK_DGRAM) {
537 		struct sockaddr_storage from;
538 		memset(&from, '\0', sizeof(from));
539 
540 		// Receive replies unless correct reply or timeout.
541 		while (true) {
542 			socklen_t from_len = sizeof(from);
543 
544 			// Wait for datagram data.
545 			if (poll(&pfd, 1, 1000 * net->wait) != 1) {
546 				WARN("response timeout for %s\n",
547 				     net->remote_str);
548 				return KNOT_NET_ETIMEOUT;
549 			}
550 
551 			// Receive whole UDP datagram.
552 			ssize_t ret = recvfrom(net->sockfd, buf, buf_len, 0,
553 			                       (struct sockaddr *)&from, &from_len);
554 			if (ret <= 0) {
555 				WARN("can't receive reply from %s\n",
556 				     net->remote_str);
557 				return KNOT_NET_ERECV;
558 			}
559 
560 			// Compare reply address with the remote one.
561 			if (from_len > sizeof(from) ||
562 			    memcmp(&from, net->srv->ai_addr, from_len) != 0) {
563 				char *src = NULL;
564 				get_addr_str(&from, net->socktype, &src);
565 				WARN("unexpected reply source %s\n", src);
566 				free(src);
567 				continue;
568 			}
569 
570 			return ret;
571 		}
572 #ifdef LIBNGHTTP2
573 	// Receive data over HTTPS.
574 	} else if (net->https.params.enable) {
575 		return https_recv_dns_response((https_ctx_t *)&net->https, buf, buf_len);
576 #endif //LIBNGHTTP2
577 	// Receive data over TLS.
578 	} else if (net->tls.params != NULL) {
579 		int ret = tls_ctx_receive((tls_ctx_t *)&net->tls, buf, buf_len);
580 		if (ret < 0) {
581 			WARN("can't receive reply from %s\n", net->remote_str);
582 			return KNOT_NET_ERECV;
583 		}
584 
585 		return ret;
586 	// Receive data over TCP.
587 	} else {
588 		uint32_t total = 0;
589 
590 		uint16_t msg_len = 0;
591 		// Receive TCP message header.
592 		while (total < sizeof(msg_len)) {
593 			if (poll(&pfd, 1, 1000 * net->wait) != 1) {
594 				WARN("response timeout for %s\n",
595 				     net->remote_str);
596 				return KNOT_NET_ETIMEOUT;
597 			}
598 
599 			// Receive piece of message.
600 			ssize_t ret = recv(net->sockfd, (uint8_t *)&msg_len + total,
601 				           sizeof(msg_len) - total, 0);
602 			if (ret <= 0) {
603 				WARN("can't receive reply from %s\n",
604 				     net->remote_str);
605 				return KNOT_NET_ERECV;
606 			}
607 			total += ret;
608 		}
609 
610 		// Convert number to host format.
611 		msg_len = ntohs(msg_len);
612 		if (msg_len > buf_len) {
613 			return KNOT_ESPACE;
614 		}
615 
616 		total = 0;
617 
618 		// Receive whole answer message by parts.
619 		while (total < msg_len) {
620 			if (poll(&pfd, 1, 1000 * net->wait) != 1) {
621 				WARN("response timeout for %s\n",
622 				     net->remote_str);
623 				return KNOT_NET_ETIMEOUT;
624 			}
625 
626 			// Receive piece of message.
627 			ssize_t ret = recv(net->sockfd, buf + total, msg_len - total, 0);
628 			if (ret <= 0) {
629 				WARN("can't receive reply from %s\n",
630 				     net->remote_str);
631 				return KNOT_NET_ERECV;
632 			}
633 			total += ret;
634 		}
635 
636 		return total;
637 	}
638 
639 	return KNOT_NET_ERECV;
640 }
641 
/*!
 * Closes the connection: calls tls_ctx_close() on the TLS context, then
 * closes the socket and marks it as invalid (-1). Resolved addresses,
 * strings and TLS/HTTPS configuration are left for net_clean().
 */
void net_close(net_t *net)
{
	if (net != NULL) {
		tls_ctx_close(&net->tls);
		close(net->sockfd);
		net->sockfd = -1;
	} else {
		DBG_NULL;
	}
}
653 
/*!
 * Releases everything net_init() and later calls attached to the context:
 * address strings, local and remote address lists, and the HTTPS (when
 * built with libnghttp2) and TLS contexts. The socket itself is not closed.
 */
void net_clean(net_t *net)
{
	if (net == NULL) {
		DBG_NULL;
		return;
	}

	free(net->remote_str);
	net->remote_str = NULL;
	free(net->local_str);
	net->local_str = NULL;

	// local_info comes from getaddrinfo() when an explicit local server was
	// given, otherwise from the single calloc() block of net_set_local_info().
	// NOTE(review): net_set_local_info() can also replace a getaddrinfo() list
	// while net->local is set; that block then reaches freeaddrinfo() here.
	if (net->local_info != NULL) {
		if (net->local != NULL) {
			freeaddrinfo(net->local_info);
		} else {
			free(net->local_info);
		}
		net->local_info = NULL;
	}

	if (net->remote_info != NULL) {
		freeaddrinfo(net->remote_info);
		net->remote_info = NULL;
	}

#ifdef LIBNGHTTP2
	https_ctx_deinit(&net->https);
#endif
	tls_ctx_deinit(&net->tls);
}
685