1 /* Copyright (C) 2021 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
2
3 This program is free software: you can redistribute it and/or modify
4 it under the terms of the GNU General Public License as published by
5 the Free Software Foundation, either version 3 of the License, or
6 (at your option) any later version.
7
8 This program is distributed in the hope that it will be useful,
9 but WITHOUT ANY WARRANTY; without even the implied warranty of
10 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 GNU General Public License for more details.
12
13 You should have received a copy of the GNU General Public License
14 along with this program. If not, see <https://www.gnu.org/licenses/>.
15 */
16
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <poll.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <netinet/in.h>
#include <sys/types.h> // OpenBSD
#include <netinet/tcp.h> // TCP_FASTOPEN
#include <sys/socket.h>

#ifdef HAVE_SYS_UIO_H
#include <sys/uio.h>
#endif

#include "utils/common/netio.h"
#include "utils/common/msg.h"
#include "utils/common/tls.h"
#include "libknot/libknot.h"
#include "contrib/sockaddr.h"
36
srv_info_create(const char * name,const char * service)37 srv_info_t *srv_info_create(const char *name, const char *service)
38 {
39 if (name == NULL || service == NULL) {
40 DBG_NULL;
41 return NULL;
42 }
43
44 // Create output structure.
45 srv_info_t *server = calloc(1, sizeof(srv_info_t));
46
47 // Check output.
48 if (server == NULL) {
49 return NULL;
50 }
51
52 // Fill output.
53 server->name = strdup(name);
54 server->service = strdup(service);
55
56 if (server->name == NULL || server->service == NULL) {
57 srv_info_free(server);
58 return NULL;
59 }
60
61 // Return result.
62 return server;
63 }
64
/*!
 * Releases a descriptor created by srv_info_create() together with its
 * name and service strings. A NULL argument is only reported via DBG_NULL.
 */
void srv_info_free(srv_info_t *server)
{
	if (server != NULL) {
		free(server->name);
		free(server->service);
		free(server);
		return;
	}

	DBG_NULL;
}
76
get_iptype(const ip_t ip)77 int get_iptype(const ip_t ip)
78 {
79 switch (ip) {
80 case IP_4:
81 return AF_INET;
82 case IP_6:
83 return AF_INET6;
84 default:
85 return AF_UNSPEC;
86 }
87 }
88
get_socktype(const protocol_t proto,const uint16_t type)89 int get_socktype(const protocol_t proto, const uint16_t type)
90 {
91 switch (proto) {
92 case PROTO_TCP:
93 return SOCK_STREAM;
94 case PROTO_UDP:
95 return SOCK_DGRAM;
96 default:
97 if (type == KNOT_RRTYPE_AXFR || type == KNOT_RRTYPE_IXFR) {
98 return SOCK_STREAM;
99 } else {
100 return SOCK_DGRAM;
101 }
102 }
103 }
104
/*!
 * Returns a static, human-readable transport label for a socket type:
 * "TCP", "UDP", or "UNKNOWN".
 */
const char *get_sockname(const int socktype)
{
	if (socktype == SOCK_STREAM) {
		return "TCP";
	}
	if (socktype == SOCK_DGRAM) {
		return "UDP";
	}
	return "UNKNOWN";
}
116
/*!
 * Resolves \a server into an addrinfo list restricted to \a iptype and
 * \a socktype.
 *
 * On success *info receives the list (release with freeaddrinfo()).
 * Address-family mismatches fail without a message. Presumably the caller
 * handles that case itself (TODO confirm). Every other resolver error is
 * printed.
 *
 * \return 0 on success, -1 on failure.
 */
static int get_addr(const srv_info_t *server,
                    const int iptype,
                    const int socktype,
                    struct addrinfo **info)
{
	// Only family and socket type are constrained; all other hints are zero.
	const struct addrinfo hints = {
		.ai_family = iptype,
		.ai_socktype = socktype,
	};

	const int rc = getaddrinfo(server->name, server->service, &hints, info);
	if (rc == 0) {
		return 0;
	}

#ifdef EAI_ADDRFAMILY
	const int silent = (rc == EAI_ADDRFAMILY);
#else
	// EAI_ADDRFAMILY isn't implemented in FreeBSD/macOS anymore; they (and
	// likely others) return EAI_NONAME instead. That code only means a
	// family mismatch when a specific family was requested.
	const int silent = (rc == EAI_NONAME && iptype != AF_UNSPEC);
#endif

	if (!silent) {
		ERR("%s for %s@%s\n", gai_strerror(rc), server->name, server->service);
	}
	return -1;
}
149
get_addr_str(const struct sockaddr_storage * ss,const int socktype,char ** dst)150 void get_addr_str(const struct sockaddr_storage *ss,
151 const int socktype,
152 char **dst)
153 {
154 char addr_str[SOCKADDR_STRLEN] = {0};
155
156 // Get network address string and port number.
157 sockaddr_tostr(addr_str, sizeof(addr_str), ss);
158
159 // Calculate needed buffer size
160 const char *sock_name = get_sockname(socktype);
161 size_t buflen = strlen(addr_str) + strlen(sock_name) + 3 /* () */;
162
163 // Free previous string if any and write result
164 free(*dst);
165 *dst = malloc(buflen);
166 if (*dst != NULL) {
167 int ret = snprintf(*dst, buflen, "%s(%s)", addr_str, sock_name);
168 if (ret <= 0 || ret >= buflen) {
169 **dst = '\0';
170 }
171 }
172 }
173
/*!
 * Initializes \a net for talking to \a remote.
 *
 * The function resolves the remote address, and the optional local bind
 * address, using the given family and socket type. It then stores the
 * connection parameters. When TLS is enabled it also prepares the TLS
 * context, and, if built with nghttp2 and requested, the HTTPS context.
 * No socket is opened here.
 *
 * On any failure \a net is cleaned with net_clean() before returning.
 *
 * \return KNOT_EOK, KNOT_EINVAL, KNOT_NET_EADDR, or a TLS/HTTPS init error.
 */
int net_init(const srv_info_t *local,
             const srv_info_t *remote,
             const int iptype,
             const int socktype,
             const int wait,
             const net_flags_t flags,
             const tls_params_t *tls_params,
             const https_params_t *https_params,
             net_t *net)
{
	if (remote == NULL || net == NULL) {
		DBG_NULL;
		return KNOT_EINVAL;
	}

	// Start from an all-zero structure with no socket.
	memset(net, 0, sizeof(*net));
	net->sockfd = -1;

	// The first resolved remote address becomes the current target.
	if (get_addr(remote, iptype, socktype, &net->remote_info) != 0) {
		net_clean(net);
		return KNOT_NET_EADDR;
	}
	net->srv = net->remote_info;

	if (local != NULL &&
	    get_addr(local, iptype, socktype, &net->local_info) != 0) {
		net_clean(net);
		return KNOT_NET_EADDR;
	}

	net->iptype = iptype;
	net->socktype = socktype;
	net->wait = wait;
	net->local = local;
	net->remote = remote;
	net->flags = flags;

	if (tls_params == NULL || !tls_params->enable) {
		return KNOT_EOK;
	}

	// TLS context (uses net->wait, so it must be stored above).
	int ret = tls_ctx_init(&net->tls, tls_params, net->wait);
#ifdef LIBNGHTTP2
	// HTTPS is layered on top of the TLS context.
	if (ret == KNOT_EOK && https_params != NULL && https_params->enable) {
		ret = https_ctx_init(&net->https, &net->tls, https_params);
	}
#endif //LIBNGHTTP2
	if (ret != KNOT_EOK) {
		net_clean(net);
	}
	return ret;
}
241
/*!
 * Prepares \a sockfd for a TCP Fast Open connection to \a srv.
 *
 * The work is platform specific:
 * - FreeBSD (with TCP_FASTOPEN): sets the socket option.
 * - macOS: issues a lazy connectx(); the connection is performed when the
 *   first data are sent.
 * - Linux: does nothing, because connect() happens implicitly with
 *   sendto()/sendmsg().
 * - Other platforms: fails with ENOTSUP.
 *
 * \return 0 on success, -1 with errno set on failure.
 */
static int fastopen_connect(int sockfd, const struct addrinfo *srv)
{
#if defined( __FreeBSD__) && defined(TCP_FASTOPEN)
	(void)srv;
	const int on = 1;
	return setsockopt(sockfd, IPPROTO_TCP, TCP_FASTOPEN, &on, sizeof(on));
#elif defined(__APPLE__)
	struct sa_endpoints endpoints = {
		.sae_dstaddr = srv->ai_addr,
		.sae_dstaddrlen = srv->ai_addrlen,
	};
	const int mode = CONNECT_DATA_IDEMPOTENT | CONNECT_RESUME_ON_READ_WRITE;

	return connectx(sockfd, &endpoints, SAE_ASSOCID_ANY, mode, NULL, 0, NULL, NULL);
#elif defined(__linux__)
	(void)sockfd;
	(void)srv;
	return 0;
#else
	(void)sockfd;
	(void)srv;
	errno = ENOTSUP;
	return -1;
#endif
}
266
/*!
 * Sends \a msg on a TCP Fast Open socket.
 *
 * On Linux, if the MSG_FASTOPEN send reports EINPROGRESS, the function waits
 * up to \a timeout seconds for the socket to become writable. It then
 * retries with a plain sendmsg(). On FreeBSD (with TCP_FASTOPEN) and macOS
 * it is a plain sendmsg(). Elsewhere it fails with ENOTSUP.
 *
 * \return Bytes sent, or -1 with errno set (ETIMEDOUT on wait timeout).
 */
static int fastopen_send(int sockfd, const struct msghdr *msg, int timeout)
{
#if (defined(__FreeBSD__) && defined(TCP_FASTOPEN))|| defined(__APPLE__)
	(void)timeout;
	return sendmsg(sockfd, msg, 0);
#elif defined(__linux__)
	int sent = sendmsg(sockfd, msg, MSG_FASTOPEN);
	if (sent != -1 || errno != EINPROGRESS) {
		return sent;
	}

	// Handshake still in progress: wait for writability, then send normally.
	struct pollfd writable = {
		.fd = sockfd,
		.events = POLLOUT,
		.revents = 0,
	};
	if (poll(&writable, 1, 1000 * timeout) != 1) {
		errno = ETIMEDOUT;
		return -1;
	}
	return sendmsg(sockfd, msg, 0);
#else
	(void)sockfd;
	(void)msg;
	(void)timeout;
	errno = ENOTSUP;
	return -1;
#endif
}
294
net_connect(net_t * net)295 int net_connect(net_t *net)
296 {
297 if (net == NULL || net->srv == NULL) {
298 DBG_NULL;
299 return KNOT_EINVAL;
300 }
301
302 // Set remote information string.
303 get_addr_str((struct sockaddr_storage *)net->srv->ai_addr,
304 net->socktype, &net->remote_str);
305
306 // Create socket.
307 int sockfd = socket(net->srv->ai_family, net->socktype, 0);
308 if (sockfd == -1) {
309 WARN("can't create socket for %s\n", net->remote_str);
310 return KNOT_NET_ESOCKET;
311 }
312
313 // Initialize poll descriptor structure.
314 struct pollfd pfd = {
315 .fd = sockfd,
316 .events = POLLOUT,
317 .revents = 0,
318 };
319
320 // Set non-blocking socket.
321 if (fcntl(sockfd, F_SETFL, O_NONBLOCK) == -1) {
322 WARN("can't set non-blocking socket for %s\n", net->remote_str);
323 return KNOT_NET_ESOCKET;
324 }
325
326 // Bind address to socket if specified.
327 if (net->local_info != NULL) {
328 if (bind(sockfd, net->local_info->ai_addr,
329 net->local_info->ai_addrlen) == -1) {
330 WARN("can't assign address %s\n", net->local->name);
331 return KNOT_NET_ESOCKET;
332 }
333 } else {
334 // Ensure source port is always randomized (even for TCP).
335 struct sockaddr_storage local = { .ss_family = net->srv->ai_family };
336 (void)bind(sockfd, (struct sockaddr *)&local, sockaddr_len(&local));
337 }
338
339 if (net->socktype == SOCK_STREAM) {
340 int cs, err, ret = 0;
341 socklen_t err_len = sizeof(err);
342 bool fastopen = net->flags & NET_FLAGS_FASTOPEN;
343
344 // Establish a connection.
345 if (net->tls.params == NULL || !fastopen) {
346 if (fastopen) {
347 ret = fastopen_connect(sockfd, net->srv);
348 } else {
349 ret = connect(sockfd, net->srv->ai_addr, net->srv->ai_addrlen);
350 }
351 if (ret != 0 && errno != EINPROGRESS) {
352 WARN("can't connect to %s\n", net->remote_str);
353 close(sockfd);
354 return KNOT_NET_ECONNECT;
355 }
356
357 // Check for connection timeout.
358 if (!fastopen && poll(&pfd, 1, 1000 * net->wait) != 1) {
359 WARN("connection timeout for %s\n", net->remote_str);
360 close(sockfd);
361 return KNOT_NET_ECONNECT;
362 }
363
364 // Check if NB socket is writeable.
365 cs = getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &err, &err_len);
366 if (cs < 0 || err != 0) {
367 WARN("can't connect to %s\n", net->remote_str);
368 close(sockfd);
369 return KNOT_NET_ECONNECT;
370 }
371 }
372
373 if (net->tls.params != NULL) {
374 #ifdef LIBNGHTTP2
375 if (net->https.params.enable) {
376 // Establish HTTPS connection.
377 char *remote = NULL;
378 if (net->tls.params->sni != NULL) {
379 remote = net->tls.params->sni;
380 } else if (net->tls.params->hostname != NULL) {
381 remote = net->tls.params->hostname;
382 } else if (strchr(net->remote_str, ':') == NULL) {
383 char *at = strchr(net->remote_str, '@');
384 if (at != NULL && strncmp(net->remote->name, net->remote_str,
385 at - net->remote_str)) {
386 remote = net->remote->name;
387 }
388 }
389 ret = https_ctx_connect(&net->https, sockfd, remote, fastopen,
390 (struct sockaddr_storage *)net->srv->ai_addr);
391 } else {
392 #endif //LIBNGHTTP2
393 // Establish TLS connection.
394 ret = tls_ctx_connect(&net->tls, sockfd, net->tls.params->sni, fastopen,
395 (struct sockaddr_storage *)net->srv->ai_addr, &dot_alpn);
396 #ifdef LIBNGHTTP2
397 }
398 #endif //LIBNGHTTP2
399 if (ret != KNOT_EOK) {
400 close(sockfd);
401 return ret;
402 }
403 }
404 }
405
406 // Store socket descriptor.
407 net->sockfd = sockfd;
408
409 return KNOT_EOK;
410 }
411
/*!
 * Replaces net->local_info with the address the connected socket is
 * actually bound to, as reported by getsockname(). It also refreshes
 * net->local_str.
 *
 * The new addrinfo and its address share one calloc() block, so a single
 * free() releases both.
 *
 * NOTE(review): net_clean() and this function release local_info with
 * freeaddrinfo() whenever net->local != NULL. That includes a block
 * allocated here, which is only safe on libcs whose freeaddrinfo() is a
 * plain free() for a single entry. Confirm this function is not called when
 * a local address was configured.
 *
 * \return KNOT_EOK, KNOT_EINVAL, KNOT_ENOMEM, or KNOT_NET_ESOCKET.
 */
int net_set_local_info(net_t *net)
{
	if (net == NULL || net->srv == NULL) {
		DBG_NULL;
		return KNOT_EINVAL;
	}

	socklen_t local_addr_len = sizeof(struct sockaddr_storage);

	struct addrinfo *new_info = calloc(1, sizeof(*new_info) + local_addr_len);
	if (new_info == NULL) {
		return KNOT_ENOMEM;
	}

	new_info->ai_addr = (struct sockaddr *)(new_info + 1);
	new_info->ai_family = net->srv->ai_family;
	new_info->ai_socktype = net->srv->ai_socktype;
	new_info->ai_protocol = net->srv->ai_protocol;

	if (getsockname(net->sockfd, new_info->ai_addr, &local_addr_len) == -1) {
		WARN("can't get local address\n");
		free(new_info);
		return KNOT_NET_ESOCKET;
	}

	// getsockname() shrinks local_addr_len to the real address size. Keep
	// ai_addrlen exact, because some systems (e.g. BSD) reject bind() with
	// an oversized length when this address is reused.
	new_info->ai_addrlen = local_addr_len;

	if (net->local_info != NULL) {
		if (net->local == NULL) {
			free(net->local_info);
		} else {
			freeaddrinfo(net->local_info);
		}
	}

	net->local_info = new_info;

	get_addr_str((struct sockaddr_storage *)net->local_info->ai_addr,
	             net->socktype, &net->local_str);

	return KNOT_EOK;
}
453
net_send(const net_t * net,const uint8_t * buf,const size_t buf_len)454 int net_send(const net_t *net, const uint8_t *buf, const size_t buf_len)
455 {
456 if (net == NULL || buf == NULL) {
457 DBG_NULL;
458 return KNOT_EINVAL;
459 }
460
461 // Send data over UDP.
462 if (net->socktype == SOCK_DGRAM) {
463 if (sendto(net->sockfd, buf, buf_len, 0, net->srv->ai_addr,
464 net->srv->ai_addrlen) != (ssize_t)buf_len) {
465 WARN("can't send query to %s\n", net->remote_str);
466 return KNOT_NET_ESEND;
467 }
468 #ifdef LIBNGHTTP2
469 // Send data over HTTPS
470 } else if (net->https.params.enable) {
471 int ret = https_send_dns_query((https_ctx_t *)&net->https, buf, buf_len);
472 if (ret != KNOT_EOK) {
473 WARN("can't send query to %s\n", net->remote_str);
474 return KNOT_NET_ESEND;
475 }
476 #endif //LIBNGHTTP2
477 // Send data over TLS.
478 } else if (net->tls.params != NULL) {
479 int ret = tls_ctx_send((tls_ctx_t *)&net->tls, buf, buf_len);
480 if (ret != KNOT_EOK) {
481 WARN("can't send query to %s\n", net->remote_str);
482 return KNOT_NET_ESEND;
483 }
484 // Send data over TCP.
485 } else {
486 bool fastopen = net->flags & NET_FLAGS_FASTOPEN;
487
488 // Leading packet length bytes.
489 uint16_t pktsize = htons(buf_len);
490
491 struct iovec iov[2];
492 iov[0].iov_base = &pktsize;
493 iov[0].iov_len = sizeof(pktsize);
494 iov[1].iov_base = (uint8_t *)buf;
495 iov[1].iov_len = buf_len;
496
497 // Compute packet total length.
498 ssize_t total = iov[0].iov_len + iov[1].iov_len;
499
500 struct msghdr msg = {0};
501 msg.msg_iov = iov;
502 msg.msg_iovlen = sizeof(iov) / sizeof(*iov);
503 msg.msg_name = net->srv->ai_addr;
504 msg.msg_namelen = net->srv->ai_addrlen;
505
506 int ret = 0;
507 if (fastopen) {
508 ret = fastopen_send(net->sockfd, &msg, net->wait);
509 } else {
510 ret = sendmsg(net->sockfd, &msg, 0);
511 }
512 if (ret != total) {
513 WARN("can't send query to %s\n", net->remote_str);
514 return KNOT_NET_ESEND;
515 }
516 }
517
518 return KNOT_EOK;
519 }
520
net_receive(const net_t * net,uint8_t * buf,const size_t buf_len)521 int net_receive(const net_t *net, uint8_t *buf, const size_t buf_len)
522 {
523 if (net == NULL || buf == NULL) {
524 DBG_NULL;
525 return KNOT_EINVAL;
526 }
527
528 // Initialize poll descriptor structure.
529 struct pollfd pfd = {
530 .fd = net->sockfd,
531 .events = POLLIN,
532 .revents = 0,
533 };
534
535 // Receive data over UDP.
536 if (net->socktype == SOCK_DGRAM) {
537 struct sockaddr_storage from;
538 memset(&from, '\0', sizeof(from));
539
540 // Receive replies unless correct reply or timeout.
541 while (true) {
542 socklen_t from_len = sizeof(from);
543
544 // Wait for datagram data.
545 if (poll(&pfd, 1, 1000 * net->wait) != 1) {
546 WARN("response timeout for %s\n",
547 net->remote_str);
548 return KNOT_NET_ETIMEOUT;
549 }
550
551 // Receive whole UDP datagram.
552 ssize_t ret = recvfrom(net->sockfd, buf, buf_len, 0,
553 (struct sockaddr *)&from, &from_len);
554 if (ret <= 0) {
555 WARN("can't receive reply from %s\n",
556 net->remote_str);
557 return KNOT_NET_ERECV;
558 }
559
560 // Compare reply address with the remote one.
561 if (from_len > sizeof(from) ||
562 memcmp(&from, net->srv->ai_addr, from_len) != 0) {
563 char *src = NULL;
564 get_addr_str(&from, net->socktype, &src);
565 WARN("unexpected reply source %s\n", src);
566 free(src);
567 continue;
568 }
569
570 return ret;
571 }
572 #ifdef LIBNGHTTP2
573 // Receive data over HTTPS.
574 } else if (net->https.params.enable) {
575 return https_recv_dns_response((https_ctx_t *)&net->https, buf, buf_len);
576 #endif //LIBNGHTTP2
577 // Receive data over TLS.
578 } else if (net->tls.params != NULL) {
579 int ret = tls_ctx_receive((tls_ctx_t *)&net->tls, buf, buf_len);
580 if (ret < 0) {
581 WARN("can't receive reply from %s\n", net->remote_str);
582 return KNOT_NET_ERECV;
583 }
584
585 return ret;
586 // Receive data over TCP.
587 } else {
588 uint32_t total = 0;
589
590 uint16_t msg_len = 0;
591 // Receive TCP message header.
592 while (total < sizeof(msg_len)) {
593 if (poll(&pfd, 1, 1000 * net->wait) != 1) {
594 WARN("response timeout for %s\n",
595 net->remote_str);
596 return KNOT_NET_ETIMEOUT;
597 }
598
599 // Receive piece of message.
600 ssize_t ret = recv(net->sockfd, (uint8_t *)&msg_len + total,
601 sizeof(msg_len) - total, 0);
602 if (ret <= 0) {
603 WARN("can't receive reply from %s\n",
604 net->remote_str);
605 return KNOT_NET_ERECV;
606 }
607 total += ret;
608 }
609
610 // Convert number to host format.
611 msg_len = ntohs(msg_len);
612 if (msg_len > buf_len) {
613 return KNOT_ESPACE;
614 }
615
616 total = 0;
617
618 // Receive whole answer message by parts.
619 while (total < msg_len) {
620 if (poll(&pfd, 1, 1000 * net->wait) != 1) {
621 WARN("response timeout for %s\n",
622 net->remote_str);
623 return KNOT_NET_ETIMEOUT;
624 }
625
626 // Receive piece of message.
627 ssize_t ret = recv(net->sockfd, buf + total, msg_len - total, 0);
628 if (ret <= 0) {
629 WARN("can't receive reply from %s\n",
630 net->remote_str);
631 return KNOT_NET_ERECV;
632 }
633 total += ret;
634 }
635
636 return total;
637 }
638
639 return KNOT_NET_ERECV;
640 }
641
/*!
 * Closes the TLS session state and the socket. net->sockfd is reset to -1;
 * resolved addresses and strings are kept (see net_clean()).
 */
void net_close(net_t *net)
{
	if (net != NULL) {
		tls_ctx_close(&net->tls);
		close(net->sockfd);
		net->sockfd = -1;
		return;
	}

	DBG_NULL;
}
653
/*!
 * Releases everything owned by \a net: the endpoint strings, the local and
 * remote address lists, and the HTTPS (if built) and TLS contexts.
 * Freed pointers are reset to NULL. The socket itself is not closed here.
 */
void net_clean(net_t *net)
{
	if (net == NULL) {
		DBG_NULL;
		return;
	}

	free(net->local_str);
	net->local_str = NULL;
	free(net->remote_str);
	net->remote_str = NULL;

	// Without a configured local endpoint, local_info can only be the
	// plain calloc() block built by net_set_local_info(). Otherwise it is
	// treated as a getaddrinfo() list.
	if (net->local_info != NULL) {
		if (net->local != NULL) {
			freeaddrinfo(net->local_info);
		} else {
			free(net->local_info);
		}
		net->local_info = NULL;
	}

	if (net->remote_info != NULL) {
		freeaddrinfo(net->remote_info);
		net->remote_info = NULL;
	}

#ifdef LIBNGHTTP2
	https_ctx_deinit(&net->https);
#endif
	tls_ctx_deinit(&net->tls);
}
685