1 /*
2  * Copyright © 2019 Manuel Stoeckl
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining
5  * a copy of this software and associated documentation files (the
6  * "Software"), to deal in the Software without restriction, including
7  * without limitation the rights to use, copy, modify, merge, publish,
8  * distribute, sublicense, and/or sell copies of the Software, and to
9  * permit persons to whom the Software is furnished to do so, subject to
10  * the following conditions:
11  *
12  * The above copyright notice and this permission notice (including the
13  * next paragraph) shall be included in all copies or substantial
14  * portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19  * NONINFRINGEMENT.  IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
20  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
21  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
22  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23  * SOFTWARE.
24  */
25 
26 #include "util.h"
27 
28 #include <errno.h>
29 #include <fcntl.h>
30 #include <inttypes.h>
31 #include <poll.h>
32 #include <stdarg.h>
33 #include <stdbool.h>
34 #include <stdio.h>
35 #include <stdlib.h>
36 #include <string.h>
37 #include <sys/socket.h>
38 #include <sys/time.h>
39 #include <sys/un.h>
40 #include <sys/wait.h>
41 #include <time.h>
42 #include <unistd.h>
43 
parse_uint32(const char * str,uint32_t * val)44 int parse_uint32(const char *str, uint32_t *val)
45 {
46 	if (!str[0] || (str[0] == '0' && str[1])) {
47 		return -1;
48 	}
49 	uint64_t v = 0;
50 	for (const char *cursor = str; *cursor; cursor++) {
51 		if (*cursor < '0' || *cursor > '9') {
52 			return -1;
53 		}
54 		uint64_t s = (uint64_t)(*cursor - '0');
55 		v *= 10;
56 		v += s;
57 		if (v >= (1uLL << 32)) {
58 			return -1;
59 		}
60 	}
61 	*val = (uint32_t)v;
62 	return 0;
63 }
64 
/* An integer-to-string converter which is async-signal-safe, unlike sprintf.
 *
 * Writes the decimal digits of `i` right-aligned into `buf`, with the NUL
 * terminator at buf[10], and returns a pointer to the first digit inside
 * `buf`. 11 bytes hold UINT32_MAX (10 digits) plus the terminator. Zero is
 * rendered as "0". */
static char *uint_to_str(uint32_t i, char buf[static 11])
{
	char *pos = &buf[10];
	*pos = '\0';
	/* do/while so that i == 0 still emits one '0' digit rather than
	 * returning an empty string */
	do {
		--pos;
		*pos = (char)((i % 10) + (uint32_t)'0');
		i /= 10;
	} while (i);
	return pos;
}
multi_strcat(char * dest,size_t dest_space,...)77 size_t multi_strcat(char *dest, size_t dest_space, ...)
78 {
79 	size_t net_len = 0;
80 	va_list args;
81 	va_start(args, dest_space);
82 	while (true) {
83 		const char *str = va_arg(args, const char *);
84 		if (!str) {
85 			break;
86 		}
87 		net_len += strlen(str);
88 		if (net_len >= dest_space) {
89 			va_end(args);
90 			dest[0] = '\0';
91 			return 0;
92 		}
93 	}
94 	va_end(args);
95 	va_start(args, dest_space);
96 	char *pos = dest;
97 	while (true) {
98 		const char *str = va_arg(args, const char *);
99 		if (!str) {
100 			break;
101 		}
102 		size_t len = strlen(str);
103 		memcpy(pos, str, len);
104 		pos += len;
105 	}
106 	va_end(args);
107 	*pos = '\0';
108 	return net_len;
109 }
110 
/* Set by handle_sigint on the first SIGINT; a second SIGINT aborts. */
bool shutdown_flag = false;
/* Bitmask of descriptors 0..255 found open by set_initial_fds(): fd N is
 * bit (N % 64) of word (N / 64). All bits start cleared. */
uint64_t inherited_fds[4] = {0};
handle_sigint(int sig)113 void handle_sigint(int sig)
114 {
115 	(void)sig;
116 	char buf[48];
117 	char tmp[11];
118 	const char *pidstr = uint_to_str((uint32_t)getpid(), tmp);
119 	const char *trailing = shutdown_flag ? "), second interrupt, aborting\n"
120 					     : ")\n";
121 	size_t len = multi_strcat(
122 			buf, sizeof(buf), "SIGINT(", pidstr, trailing, NULL);
123 	(void)write(STDERR_FILENO, buf, len);
124 
125 	if (!shutdown_flag) {
126 		shutdown_flag = true;
127 	} else {
128 		abort();
129 	}
130 }
131 
set_nonblocking(int fd)132 int set_nonblocking(int fd)
133 {
134 	int flags = fcntl(fd, F_GETFL, 0);
135 	if (flags == -1) {
136 		return -1;
137 	}
138 	return fcntl(fd, F_SETFL, flags | O_NONBLOCK);
139 }
140 
set_cloexec(int fd)141 int set_cloexec(int fd)
142 {
143 	int flags = fcntl(fd, F_GETFD, 0);
144 	if (flags == -1) {
145 		return -1;
146 	}
147 	return fcntl(fd, F_SETFD, flags | O_CLOEXEC);
148 }
149 
setup_nb_socket(const struct sockaddr_un * socket_addr,int nmaxclients)150 int setup_nb_socket(const struct sockaddr_un *socket_addr, int nmaxclients)
151 {
152 	struct sockaddr_un saddr = *socket_addr;
153 	saddr.sun_family = AF_UNIX;
154 
155 	int sock = socket(AF_UNIX, SOCK_STREAM, 0);
156 	if (sock == -1) {
157 		wp_error("Error creating socket: %s", strerror(errno));
158 		return -1;
159 	}
160 	if (set_nonblocking(sock) == -1) {
161 		wp_error("Error making socket nonblocking: %s",
162 				strerror(errno));
163 		checked_close(sock);
164 		return -1;
165 	}
166 	if (bind(sock, (struct sockaddr *)&saddr, sizeof(saddr)) == -1) {
167 		wp_error("Error binding socket at %s: %s", saddr.sun_path,
168 				strerror(errno));
169 		checked_close(sock);
170 		return -1;
171 	}
172 	if (listen(sock, nmaxclients) == -1) {
173 		wp_error("Error listening to socket at %s: %s", saddr.sun_path,
174 				strerror(errno));
175 		checked_close(sock);
176 		unlink(saddr.sun_path);
177 		return -1;
178 	}
179 	return sock;
180 }
181 
connect_to_socket(const struct sockaddr_un * socket_addr)182 int connect_to_socket(const struct sockaddr_un *socket_addr)
183 {
184 	struct sockaddr_un saddr = *socket_addr;
185 	int chanfd;
186 	saddr.sun_family = AF_UNIX;
187 
188 	chanfd = socket(AF_UNIX, SOCK_STREAM, 0);
189 	if (chanfd == -1) {
190 		wp_error("Error creating socket: %s", strerror(errno));
191 		return -1;
192 	}
193 
194 	if (connect(chanfd, (struct sockaddr *)&saddr, sizeof(saddr)) == -1) {
195 		wp_error("Error connecting to socket (%s): %s", saddr.sun_path,
196 				strerror(errno));
197 		checked_close(chanfd);
198 		return -1;
199 	}
200 	return chanfd;
201 }
202 
set_initial_fds(void)203 void set_initial_fds(void)
204 {
205 	struct pollfd checklist[256];
206 	for (int i = 0; i < 256; i++) {
207 		checklist[i].fd = i;
208 		checklist[i].events = 0;
209 		checklist[i].revents = 0;
210 	}
211 	if (poll(checklist, 256, 0) == -1) {
212 		wp_error("fd-checking poll failed: %s", strerror(errno));
213 		return;
214 	}
215 	for (int i = 0; i < 256; i++) {
216 		if (!(checklist[i].revents & POLLNVAL)) {
217 			inherited_fds[i / 64] |= (1uLL << (i % 64));
218 		}
219 	}
220 }
221 
check_unclosed_fds(void)222 void check_unclosed_fds(void)
223 {
224 	/* Verify that all file descriptors have been closed. Since most
225 	 * instances have <<256 file descriptors open at a given time, it is
226 	 * safe to only check up to that level */
227 	struct pollfd checklist[256];
228 	for (int i = 0; i < 256; i++) {
229 		checklist[i].fd = i;
230 		checklist[i].events = 0;
231 		checklist[i].revents = 0;
232 	}
233 	if (poll(checklist, 256, 0) == -1) {
234 		wp_error("fd-checking poll failed: %s", strerror(errno));
235 		return;
236 	}
237 	for (int i = 0; i < 256; i++) {
238 		bool initial_fd = (inherited_fds[i / 64] &
239 						  (1uLL << (i % 64))) != 0;
240 		if (initial_fd) {
241 			if (checklist[i].revents & POLLNVAL) {
242 				wp_error("Unexpected closed fd %d", i);
243 			}
244 		} else {
245 			if (checklist[i].revents & POLLNVAL) {
246 				continue;
247 			}
248 #ifdef __linux__
249 			char fd_path[64];
250 			char link[256];
251 			sprintf(fd_path, "/proc/self/fd/%d", i);
252 			ssize_t len = readlink(fd_path, link, sizeof(link) - 1);
253 			if (len == -1) {
254 				wp_error("Failed to readlink /proc/self/fd/%d for unexpected open fd %d",
255 						i, i);
256 			} else {
257 				link[len] = '\0';
258 				if (!strcmp(link, "/var/lib/sss/mc/passwd")) {
259 					wp_debug("Known issue, leaked fd %d to /var/lib/sss/mc/passwd",
260 							i);
261 				} else {
262 					wp_error("Unexpected open fd %d: %s", i,
263 							link);
264 				}
265 			}
266 #else
267 			wp_error("Unexpected open fd %d", i);
268 #endif
269 		}
270 	}
271 }
272 
print_display_error(char * dest,size_t dest_space,uint32_t error_code,const char * message)273 size_t print_display_error(char *dest, size_t dest_space, uint32_t error_code,
274 		const char *message)
275 {
276 	if (dest_space < 20) {
277 		return 0;
278 	}
279 	size_t msg_len = strlen(message) + 1;
280 	size_t net_len = 4 * ((msg_len + 0x3) / 4) + 20;
281 	if (net_len > dest_space) {
282 		return 0;
283 	}
284 	uint32_t header[5] = {0x1, (uint32_t)net_len << 16, 0x1, error_code,
285 			(uint32_t)msg_len};
286 	memcpy(dest, header, sizeof(header));
287 	memcpy(dest + sizeof(header), message, msg_len);
288 	if (msg_len % 4 != 0) {
289 		size_t trailing = 4 - msg_len % 4;
290 		uint8_t zeros[4] = {0, 0, 0, 0};
291 		memcpy(dest + sizeof(header) + msg_len, zeros, trailing);
292 	}
293 	return net_len;
294 }
295 
print_wrapped_error(char * dest,size_t dest_space,const char * message)296 size_t print_wrapped_error(char *dest, size_t dest_space, const char *message)
297 {
298 	size_t msg_len = print_display_error(
299 			dest + 4, dest_space - 4, 3, message);
300 	if (msg_len == 0) {
301 		return 0;
302 	}
303 	uint32_t header = transfer_header(msg_len + 4, WMSG_PROTOCOL);
304 	memcpy(dest, &header, sizeof(header));
305 	return msg_len + 4;
306 }
307 
send_one_fd(int socket,int fd)308 int send_one_fd(int socket, int fd)
309 {
310 	union {
311 		char buf[CMSG_SPACE(sizeof(int))];
312 		struct cmsghdr align;
313 	} uc;
314 	memset(uc.buf, 0, sizeof(uc.buf));
315 	struct cmsghdr *frst = (struct cmsghdr *)(uc.buf);
316 	frst->cmsg_level = SOL_SOCKET;
317 	frst->cmsg_type = SCM_RIGHTS;
318 	*((int *)CMSG_DATA(frst)) = fd;
319 	frst->cmsg_len = CMSG_LEN(sizeof(int));
320 
321 	struct iovec the_iovec;
322 	the_iovec.iov_len = 1;
323 	uint8_t dummy_data = 1;
324 	the_iovec.iov_base = &dummy_data;
325 	struct msghdr msg;
326 	msg.msg_name = NULL;
327 	msg.msg_namelen = 0;
328 	msg.msg_iov = &the_iovec;
329 	msg.msg_iovlen = 1;
330 	msg.msg_flags = 0;
331 	msg.msg_control = uc.buf;
332 	msg.msg_controllen = CMSG_SPACE(sizeof(int));
333 
334 	return (int)sendmsg(socket, &msg, 0);
335 }
336 
wait_for_pid_and_clean(pid_t * target_pid,int * status,int options,struct conn_map * map)337 bool wait_for_pid_and_clean(pid_t *target_pid, int *status, int options,
338 		struct conn_map *map)
339 {
340 	bool found = false;
341 	while (1) {
342 		int stat;
343 		pid_t r = waitpid((pid_t)-1, &stat, options);
344 		if (r == 0 || (r == -1 && (errno == ECHILD ||
345 							  errno == EINTR))) {
346 			// Valid exit reasons, not an error
347 			errno = 0;
348 			return found;
349 		} else if (r == -1) {
350 			wp_error("waitpid failed: %s", strerror(errno));
351 			return found;
352 		}
353 
354 		wp_debug("Child process %d has died", r);
355 		if (map) {
356 			/* Clean out all entries matching that pid */
357 			int iw = 0;
358 			for (int ir = 0; ir < map->count; ir++) {
359 				map->data[iw] = map->data[ir];
360 				if (map->data[ir].pid != r) {
361 					iw++;
362 				} else {
363 					checked_close(map->data[ir].linkfd);
364 				}
365 			}
366 			map->count = iw;
367 		}
368 
369 		if (r == *target_pid) {
370 			*target_pid = 0;
371 			*status = stat;
372 			found = true;
373 		}
374 	}
375 }
376 
buf_ensure_size(int count,size_t obj_size,int * space,void ** data)377 int buf_ensure_size(int count, size_t obj_size, int *space, void **data)
378 {
379 	int x = *space;
380 	if (count <= x) {
381 		return 0;
382 	}
383 	if (count >= INT32_MAX / 2 || count <= 0) {
384 		return -1;
385 	}
386 	if (x < 1) {
387 		x = 1;
388 	}
389 	while (x < count) {
390 		x *= 2;
391 	}
392 	void *new_data = realloc(*data, (size_t)x * obj_size);
393 	if (!new_data) {
394 		return -1;
395 	}
396 	*data = new_data;
397 	*space = x;
398 	return 0;
399 }
400 
/* Human-readable names of message types, indexed by enum wmsg_type value
 * (wmsg_type_to_str returns wmsg_types[tp], and wmsg_type_is_known treats
 * any index inside this table as known). Entry order must therefore match
 * the numeric values of enum wmsg_type. NOTE(review): the enum is declared
 * elsewhere (presumably util.h); keep the two in sync when adding types. */
static const char *const wmsg_types[] = {
		"WMSG_PROTOCOL",
		"WMSG_INJECT_RIDS",
		"WMSG_OPEN_FILE",
		"WMSG_EXTEND_FILE",
		"WMSG_OPEN_DMABUF",
		"WMSG_BUFFER_FILL",
		"WMSG_BUFFER_DIFF",
		"WMSG_OPEN_IR_PIPE",
		"WMSG_OPEN_IW_PIPE",
		"WMSG_OPEN_RW_PIPE",
		"WMSG_PIPE_TRANSFER",
		"WMSG_PIPE_SHUTDOWN_R",
		"WMSG_PIPE_SHUTDOWN_W",
		"WMSG_OPEN_DMAVID_SRC",
		"WMSG_OPEN_DMAVID_DST",
		"WMSG_SEND_DMAVID_PACKET",
		"WMSG_ACK_NBLOCKS",
		"WMSG_RESTART",
		"WMSG_CLOSE",
		"WMSG_OPEN_DMAVID_SRC_V2",
		"WMSG_OPEN_DMAVID_DST_V2",
};
wmsg_type_to_str(enum wmsg_type tp)424 const char *wmsg_type_to_str(enum wmsg_type tp)
425 {
426 	if (tp >= sizeof(wmsg_types) / sizeof(wmsg_types[0])) {
427 		return "???";
428 	}
429 	return wmsg_types[tp];
430 }
wmsg_type_is_known(enum wmsg_type tp)431 bool wmsg_type_is_known(enum wmsg_type tp)
432 {
433 	return (size_t)tp < (sizeof(wmsg_types) / sizeof(wmsg_types[0]));
434 }
435 
transfer_ensure_size(struct transfer_queue * transfers,int count)436 int transfer_ensure_size(struct transfer_queue *transfers, int count)
437 {
438 	int sz = transfers->size;
439 	if (buf_ensure_size(count, sizeof(*transfers->vecs), &sz,
440 			    (void **)&transfers->vecs) == -1) {
441 		return -1;
442 	}
443 	sz = transfers->size;
444 	if (buf_ensure_size(count, sizeof(*transfers->meta), &sz,
445 			    (void **)&transfers->meta) == -1) {
446 		return -1;
447 	}
448 	transfers->size = sz;
449 	return 0;
450 }
451 
transfer_add(struct transfer_queue * w,size_t size,void * data)452 int transfer_add(struct transfer_queue *w, size_t size, void *data)
453 {
454 	if (size == 0) {
455 		return 0;
456 	}
457 	if (transfer_ensure_size(w, w->end + 1) == -1) {
458 		return -1;
459 	}
460 
461 	w->vecs[w->end].iov_len = size;
462 	w->vecs[w->end].iov_base = data;
463 	w->meta[w->end].msgno = w->last_msgno;
464 	w->meta[w->end].static_alloc = false;
465 	w->end++;
466 	w->last_msgno++;
467 	return 0;
468 }
469 
transfer_async_add(struct thread_msg_recv_buf * q,void * data,size_t sz)470 void transfer_async_add(struct thread_msg_recv_buf *q, void *data, size_t sz)
471 {
472 	struct iovec vec;
473 	vec.iov_len = sz;
474 	vec.iov_base = data;
475 	pthread_mutex_lock(&q->lock);
476 	q->data[q->zone_end++] = vec;
477 	pthread_mutex_unlock(&q->lock);
478 }
479 
transfer_load_async(struct transfer_queue * w)480 int transfer_load_async(struct transfer_queue *w)
481 {
482 	pthread_mutex_lock(&w->async_recv_queue.lock);
483 	int zstart = w->async_recv_queue.zone_start;
484 	int zend = w->async_recv_queue.zone_end;
485 	w->async_recv_queue.zone_start = zend;
486 	pthread_mutex_unlock(&w->async_recv_queue.lock);
487 
488 	for (int i = zstart; i < zend; i++) {
489 		struct iovec v = w->async_recv_queue.data[i];
490 		memset(&w->async_recv_queue.data[i], 0, sizeof(struct iovec));
491 		if (v.iov_len == 0 || v.iov_base == NULL) {
492 			wp_error("Unexpected empty message");
493 			continue;
494 		}
495 		/* Only fill/diff messages are received async, so msgno
496 		 * is always incremented */
497 		if (transfer_add(w, v.iov_len, v.iov_base) == -1) {
498 			wp_error("Failed to add message to transfer queue");
499 			pthread_mutex_unlock(&w->async_recv_queue.lock);
500 			return -1;
501 		}
502 	}
503 	return 0;
504 }
505 
cleanup_transfer_queue(struct transfer_queue * td)506 void cleanup_transfer_queue(struct transfer_queue *td)
507 {
508 	for (int i = td->async_recv_queue.zone_start;
509 			i < td->async_recv_queue.zone_end; i++) {
510 		free(td->async_recv_queue.data[i].iov_base);
511 	}
512 	pthread_mutex_destroy(&td->async_recv_queue.lock);
513 	free(td->async_recv_queue.data);
514 	for (int i = 0; i < td->end; i++) {
515 		if (!td->meta[i].static_alloc) {
516 			free(td->vecs[i].iov_base);
517 		}
518 	}
519 	free(td->vecs);
520 	free(td->meta);
521 }
522