1 /*
2 * Copyright © 2019 Manuel Stoeckl
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining
5 * a copy of this software and associated documentation files (the
6 * "Software"), to deal in the Software without restriction, including
7 * without limitation the rights to use, copy, modify, merge, publish,
8 * distribute, sublicense, and/or sell copies of the Software, and to
9 * permit persons to whom the Software is furnished to do so, subject to
10 * the following conditions:
11 *
12 * The above copyright notice and this permission notice (including the
13 * next paragraph) shall be included in all copies or substantial
14 * portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
20 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
21 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
22 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 * SOFTWARE.
24 */
25
26 #include "util.h"
27
28 #include <errno.h>
29 #include <fcntl.h>
30 #include <inttypes.h>
31 #include <poll.h>
32 #include <stdarg.h>
33 #include <stdbool.h>
34 #include <stdio.h>
35 #include <stdlib.h>
36 #include <string.h>
37 #include <sys/socket.h>
38 #include <sys/time.h>
39 #include <sys/un.h>
40 #include <sys/wait.h>
41 #include <time.h>
42 #include <unistd.h>
43
/* Parse a canonical decimal representation of a 32-bit unsigned integer.
 * Accepts only ASCII digits with no sign, no whitespace and no leading zeros
 * (except for the string "0" itself). On success stores the value in *val and
 * returns 0; on any failure returns -1 and leaves *val untouched. */
int parse_uint32(const char *str, uint32_t *val)
{
	const bool empty = str[0] == '\0';
	const bool leading_zero = str[0] == '0' && str[1] != '\0';
	if (empty || leading_zero) {
		return -1;
	}

	uint64_t acc = 0;
	const char *p = str;
	while (*p != '\0') {
		const char c = *p++;
		if (!(c >= '0' && c <= '9')) {
			return -1;
		}
		/* acc <= UINT32_MAX here, so this cannot overflow 64 bits */
		acc = acc * 10 + (uint64_t)(c - '0');
		if (acc > UINT32_MAX) {
			return -1;
		}
	}
	*val = (uint32_t)acc;
	return 0;
}
64
/* An integer-to-string converter which is async-signal-safe, unlike sprintf.
 * The decimal digits of i are written into the tail of buf, which always fits
 * since UINT32_MAX has 10 digits (plus the terminating NUL). Returns a pointer
 * to the first digit inside buf. Zero is rendered as "0". */
static char *uint_to_str(uint32_t i, char buf[static 11])
{
	char *pos = &buf[10];
	*pos = '\0';
	/* do/while so that i == 0 still emits one digit */
	do {
		--pos;
		*pos = (char)((uint32_t)'0' + (i % 10));
		i /= 10;
	} while (i);
	return pos;
}
/* Concatenate a NULL-terminated list of strings into dest.
 *
 * dest_space is the total size of dest, including room for the terminating
 * NUL. The result is all-or-nothing: if the combined length does not fit, dest
 * is set to "" (when dest_space > 0) and 0 is returned. Otherwise returns the
 * length of the concatenated string, excluding the NUL.
 *
 * Only strlen/memcpy and va_* are used, so this stays usable from the SIGINT
 * handler. The sentinel must be a pointer, e.g. (char *)NULL. */
size_t multi_strcat(char *dest, size_t dest_space, ...)
{
	/* With no room even for the NUL terminator, dest must not be touched */
	if (dest_space == 0) {
		return 0;
	}

	size_t net_len = 0;
	va_list args;

	/* First pass: measure, so nothing is copied unless everything fits */
	va_start(args, dest_space);
	for (const char *str = va_arg(args, const char *); str;
			str = va_arg(args, const char *)) {
		net_len += strlen(str);
		if (net_len >= dest_space) {
			va_end(args);
			dest[0] = '\0';
			return 0;
		}
	}
	va_end(args);

	/* Second pass: copy; the length check above guarantees this fits */
	char *pos = dest;
	va_start(args, dest_space);
	for (const char *str = va_arg(args, const char *); str;
			str = va_arg(args, const char *)) {
		size_t len = strlen(str);
		memcpy(pos, str, len);
		pos += len;
	}
	va_end(args);
	*pos = '\0';
	return net_len;
}
110
/* Set by handle_sigint() on the first SIGINT; a second SIGINT aborts.
 * NOTE(review): this is written from a signal handler, where
 * volatile sig_atomic_t is the portable type; the declaration presumably has
 * to match an extern in util.h, so the type is left as-is here. */
bool shutdown_flag = false;
/* Bitmap of descriptors 0..255 that were open when set_initial_fds() ran:
 * descriptor fd is bit (fd % 64) of word (fd / 64). */
uint64_t inherited_fds[4] = {0};
handle_sigint(int sig)113 void handle_sigint(int sig)
114 {
115 (void)sig;
116 char buf[48];
117 char tmp[11];
118 const char *pidstr = uint_to_str((uint32_t)getpid(), tmp);
119 const char *trailing = shutdown_flag ? "), second interrupt, aborting\n"
120 : ")\n";
121 size_t len = multi_strcat(
122 buf, sizeof(buf), "SIGINT(", pidstr, trailing, NULL);
123 (void)write(STDERR_FILENO, buf, len);
124
125 if (!shutdown_flag) {
126 shutdown_flag = true;
127 } else {
128 abort();
129 }
130 }
131
/* Add O_NONBLOCK to the file status flags of fd, keeping the other flags.
 * Returns -1 on failure (errno set by fcntl), otherwise the F_SETFL result. */
int set_nonblocking(int fd)
{
	const int status_flags = fcntl(fd, F_GETFL, 0);
	return (status_flags == -1)
			? -1
			: fcntl(fd, F_SETFL, status_flags | O_NONBLOCK);
}
140
/* Set the close-on-exec descriptor flag on fd, keeping other fd flags.
 * Returns -1 on failure (errno set by fcntl), otherwise the F_SETFD result. */
int set_cloexec(int fd)
{
	int flags = fcntl(fd, F_GETFD, 0);
	if (flags == -1) {
		return -1;
	}
	/* F_GETFD/F_SETFD work on descriptor flags, whose close-on-exec bit is
	 * FD_CLOEXEC. O_CLOEXEC is an open()/file-status flag with a different
	 * value, and F_SETFD ignores it. */
	return fcntl(fd, F_SETFD, flags | FD_CLOEXEC);
}
149
/* Create a non-blocking AF_UNIX stream socket, bind it to the path in
 * socket_addr (sun_family is forced to AF_UNIX) and start listening with a
 * backlog of nmaxclients. Returns the listening fd, or -1 after logging the
 * error. If listen() fails, the socket file created by bind() is unlinked. */
int setup_nb_socket(const struct sockaddr_un *socket_addr, int nmaxclients)
{
	struct sockaddr_un addr = *socket_addr;
	addr.sun_family = AF_UNIX;

	int listener = socket(AF_UNIX, SOCK_STREAM, 0);
	if (listener == -1) {
		wp_error("Error creating socket: %s", strerror(errno));
		return -1;
	}

	if (set_nonblocking(listener) == -1) {
		wp_error("Error making socket nonblocking: %s",
				strerror(errno));
		goto fail_close;
	}
	if (bind(listener, (struct sockaddr *)&addr, sizeof(addr)) == -1) {
		wp_error("Error binding socket at %s: %s", addr.sun_path,
				strerror(errno));
		goto fail_close;
	}
	if (listen(listener, nmaxclients) == -1) {
		wp_error("Error listening to socket at %s: %s", addr.sun_path,
				strerror(errno));
		checked_close(listener);
		/* bind() succeeded, so the socket path exists; remove it */
		unlink(addr.sun_path);
		return -1;
	}
	return listener;

fail_close:
	checked_close(listener);
	return -1;
}
181
/* Open an AF_UNIX stream socket and connect it to the path in socket_addr
 * (sun_family is forced to AF_UNIX). Returns the connected (blocking) fd, or
 * -1 after logging the error. */
int connect_to_socket(const struct sockaddr_un *socket_addr)
{
	struct sockaddr_un target = *socket_addr;
	target.sun_family = AF_UNIX;

	const int conn = socket(AF_UNIX, SOCK_STREAM, 0);
	if (conn == -1) {
		wp_error("Error creating socket: %s", strerror(errno));
		return -1;
	}

	if (connect(conn, (struct sockaddr *)&target, sizeof(target)) != -1) {
		return conn;
	}
	wp_error("Error connecting to socket (%s): %s", target.sun_path,
			strerror(errno));
	checked_close(conn);
	return -1;
}
202
set_initial_fds(void)203 void set_initial_fds(void)
204 {
205 struct pollfd checklist[256];
206 for (int i = 0; i < 256; i++) {
207 checklist[i].fd = i;
208 checklist[i].events = 0;
209 checklist[i].revents = 0;
210 }
211 if (poll(checklist, 256, 0) == -1) {
212 wp_error("fd-checking poll failed: %s", strerror(errno));
213 return;
214 }
215 for (int i = 0; i < 256; i++) {
216 if (!(checklist[i].revents & POLLNVAL)) {
217 inherited_fds[i / 64] |= (1uLL << (i % 64));
218 }
219 }
220 }
221
check_unclosed_fds(void)222 void check_unclosed_fds(void)
223 {
224 /* Verify that all file descriptors have been closed. Since most
225 * instances have <<256 file descriptors open at a given time, it is
226 * safe to only check up to that level */
227 struct pollfd checklist[256];
228 for (int i = 0; i < 256; i++) {
229 checklist[i].fd = i;
230 checklist[i].events = 0;
231 checklist[i].revents = 0;
232 }
233 if (poll(checklist, 256, 0) == -1) {
234 wp_error("fd-checking poll failed: %s", strerror(errno));
235 return;
236 }
237 for (int i = 0; i < 256; i++) {
238 bool initial_fd = (inherited_fds[i / 64] &
239 (1uLL << (i % 64))) != 0;
240 if (initial_fd) {
241 if (checklist[i].revents & POLLNVAL) {
242 wp_error("Unexpected closed fd %d", i);
243 }
244 } else {
245 if (checklist[i].revents & POLLNVAL) {
246 continue;
247 }
248 #ifdef __linux__
249 char fd_path[64];
250 char link[256];
251 sprintf(fd_path, "/proc/self/fd/%d", i);
252 ssize_t len = readlink(fd_path, link, sizeof(link) - 1);
253 if (len == -1) {
254 wp_error("Failed to readlink /proc/self/fd/%d for unexpected open fd %d",
255 i, i);
256 } else {
257 link[len] = '\0';
258 if (!strcmp(link, "/var/lib/sss/mc/passwd")) {
259 wp_debug("Known issue, leaked fd %d to /var/lib/sss/mc/passwd",
260 i);
261 } else {
262 wp_error("Unexpected open fd %d: %s", i,
263 link);
264 }
265 }
266 #else
267 wp_error("Unexpected open fd %d", i);
268 #endif
269 }
270 }
271 }
272
/* Serialize an error message into dest as a 5-word header followed by the
 * NUL-terminated message, zero-padded to a multiple of 4 bytes.
 *
 * Header words, in host byte order: 1, (total_size << 16), 1, error_code,
 * message length including the NUL. This appears to be a Wayland wire message
 * for object 1, opcode 0 (presumably wl_display.error — confirm).
 *
 * Returns the total number of bytes written, or 0 if the result does not fit
 * in dest_space or its size cannot be stored in the 16-bit size field. */
size_t print_display_error(char *dest, size_t dest_space, uint32_t error_code,
		const char *message)
{
	/* The fixed header alone is 5 words (20 bytes) */
	if (dest_space < 20) {
		return 0;
	}
	/* Message length including the terminating NUL */
	size_t msg_len = strlen(message) + 1;
	/* The total size lives in the upper 16 bits of header[1], so nothing
	 * above 0xffff bytes can be encoded. Rejecting long messages first
	 * also keeps the rounding below from overflowing. */
	if (msg_len > 0xffff) {
		return 0;
	}
	/* Header plus message rounded up to a multiple of 4 bytes */
	size_t net_len = 4 * ((msg_len + 0x3) / 4) + 20;
	if (net_len > dest_space || net_len > 0xffff) {
		return 0;
	}
	uint32_t header[5] = {0x1, (uint32_t)net_len << 16, 0x1, error_code,
			(uint32_t)msg_len};
	memcpy(dest, header, sizeof(header));
	memcpy(dest + sizeof(header), message, msg_len);
	if (msg_len % 4 != 0) {
		/* Zero the padding bytes up to the next word boundary */
		size_t trailing = 4 - msg_len % 4;
		memset(dest + sizeof(header) + msg_len, 0, trailing);
	}
	return net_len;
}
295
print_wrapped_error(char * dest,size_t dest_space,const char * message)296 size_t print_wrapped_error(char *dest, size_t dest_space, const char *message)
297 {
298 size_t msg_len = print_display_error(
299 dest + 4, dest_space - 4, 3, message);
300 if (msg_len == 0) {
301 return 0;
302 }
303 uint32_t header = transfer_header(msg_len + 4, WMSG_PROTOCOL);
304 memcpy(dest, &header, sizeof(header));
305 return msg_len + 4;
306 }
307
/* Send the descriptor fd over the unix socket `socket` as SCM_RIGHTS
 * ancillary data, together with a single dummy data byte (value 1).
 * Returns sendmsg()'s result: bytes sent (1) on success, -1 on error. */
int send_one_fd(int socket, int fd)
{
	/* Control buffer for one int; the union forces cmsghdr alignment */
	union {
		char buf[CMSG_SPACE(sizeof(int))];
		struct cmsghdr align;
	} uc;
	memset(uc.buf, 0, sizeof(uc.buf));

	uint8_t dummy_data = 1;
	struct iovec the_iovec;
	the_iovec.iov_base = &dummy_data;
	the_iovec.iov_len = 1;

	/* Zero everything first so that any implementation-specific or padding
	 * fields of msghdr are not left uninitialized */
	struct msghdr msg;
	memset(&msg, 0, sizeof(msg));
	msg.msg_name = NULL;
	msg.msg_namelen = 0;
	msg.msg_iov = &the_iovec;
	msg.msg_iovlen = 1;
	msg.msg_flags = 0;
	msg.msg_control = uc.buf;
	msg.msg_controllen = sizeof(uc.buf);

	struct cmsghdr *frst = CMSG_FIRSTHDR(&msg);
	frst->cmsg_level = SOL_SOCKET;
	frst->cmsg_type = SCM_RIGHTS;
	frst->cmsg_len = CMSG_LEN(sizeof(int));
	/* memcpy rather than an (int *) cast: CMSG_DATA is not guaranteed to
	 * be suitably aligned or type-compatible for a direct int store */
	memcpy(CMSG_DATA(frst), &fd, sizeof(fd));

	return (int)sendmsg(socket, &msg, 0);
}
336
wait_for_pid_and_clean(pid_t * target_pid,int * status,int options,struct conn_map * map)337 bool wait_for_pid_and_clean(pid_t *target_pid, int *status, int options,
338 struct conn_map *map)
339 {
340 bool found = false;
341 while (1) {
342 int stat;
343 pid_t r = waitpid((pid_t)-1, &stat, options);
344 if (r == 0 || (r == -1 && (errno == ECHILD ||
345 errno == EINTR))) {
346 // Valid exit reasons, not an error
347 errno = 0;
348 return found;
349 } else if (r == -1) {
350 wp_error("waitpid failed: %s", strerror(errno));
351 return found;
352 }
353
354 wp_debug("Child process %d has died", r);
355 if (map) {
356 /* Clean out all entries matching that pid */
357 int iw = 0;
358 for (int ir = 0; ir < map->count; ir++) {
359 map->data[iw] = map->data[ir];
360 if (map->data[ir].pid != r) {
361 iw++;
362 } else {
363 checked_close(map->data[ir].linkfd);
364 }
365 }
366 map->count = iw;
367 }
368
369 if (r == *target_pid) {
370 *target_pid = 0;
371 *status = stat;
372 found = true;
373 }
374 }
375 }
376
/* Ensure the array *data (capacity *space elements of obj_size bytes) can hold
 * at least `count` elements. Capacity grows by doubling (starting from 1).
 * Returns 0 on success, updating *data and *space when reallocated; returns
 * -1 for unsupported counts or allocation failure, leaving *data and *space
 * unchanged. */
int buf_ensure_size(int count, size_t obj_size, int *space, void **data)
{
	int x = *space;
	if (count <= x) {
		return 0;
	}
	/* Keeping count below INT32_MAX / 2 guarantees the doubling loop
	 * below cannot overflow int */
	if (count >= INT32_MAX / 2 || count <= 0) {
		return -1;
	}
	if (x < 1) {
		x = 1;
	}
	while (x < count) {
		x *= 2;
	}
	/* With a 32-bit size_t, x * obj_size can wrap around; realloc would
	 * then hand back a buffer smaller than the capacity we record */
	if (obj_size != 0 && (size_t)x > SIZE_MAX / obj_size) {
		return -1;
	}
	void *new_data = realloc(*data, (size_t)x * obj_size);
	if (!new_data) {
		return -1;
	}
	*data = new_data;
	*space = x;
	return 0;
}
400
/* Printable names of the message types, used by wmsg_type_to_str() and
 * wmsg_type_is_known(). The table is indexed directly by the enum value
 * (wmsg_type_to_str returns wmsg_types[tp]), so the order of the entries must
 * match the declaration order of enum wmsg_type.
 * NOTE(review): that enum is presumably declared in util.h; keep the two in
 * sync when adding message types. */
static const char *const wmsg_types[] = {
	"WMSG_PROTOCOL",
	"WMSG_INJECT_RIDS",
	"WMSG_OPEN_FILE",
	"WMSG_EXTEND_FILE",
	"WMSG_OPEN_DMABUF",
	"WMSG_BUFFER_FILL",
	"WMSG_BUFFER_DIFF",
	"WMSG_OPEN_IR_PIPE",
	"WMSG_OPEN_IW_PIPE",
	"WMSG_OPEN_RW_PIPE",
	"WMSG_PIPE_TRANSFER",
	"WMSG_PIPE_SHUTDOWN_R",
	"WMSG_PIPE_SHUTDOWN_W",
	"WMSG_OPEN_DMAVID_SRC",
	"WMSG_OPEN_DMAVID_DST",
	"WMSG_SEND_DMAVID_PACKET",
	"WMSG_ACK_NBLOCKS",
	"WMSG_RESTART",
	"WMSG_CLOSE",
	"WMSG_OPEN_DMAVID_SRC_V2",
	"WMSG_OPEN_DMAVID_DST_V2",
};
wmsg_type_to_str(enum wmsg_type tp)424 const char *wmsg_type_to_str(enum wmsg_type tp)
425 {
426 if (tp >= sizeof(wmsg_types) / sizeof(wmsg_types[0])) {
427 return "???";
428 }
429 return wmsg_types[tp];
430 }
wmsg_type_is_known(enum wmsg_type tp)431 bool wmsg_type_is_known(enum wmsg_type tp)
432 {
433 return (size_t)tp < (sizeof(wmsg_types) / sizeof(wmsg_types[0]));
434 }
435
transfer_ensure_size(struct transfer_queue * transfers,int count)436 int transfer_ensure_size(struct transfer_queue *transfers, int count)
437 {
438 int sz = transfers->size;
439 if (buf_ensure_size(count, sizeof(*transfers->vecs), &sz,
440 (void **)&transfers->vecs) == -1) {
441 return -1;
442 }
443 sz = transfers->size;
444 if (buf_ensure_size(count, sizeof(*transfers->meta), &sz,
445 (void **)&transfers->meta) == -1) {
446 return -1;
447 }
448 transfers->size = sz;
449 return 0;
450 }
451
transfer_add(struct transfer_queue * w,size_t size,void * data)452 int transfer_add(struct transfer_queue *w, size_t size, void *data)
453 {
454 if (size == 0) {
455 return 0;
456 }
457 if (transfer_ensure_size(w, w->end + 1) == -1) {
458 return -1;
459 }
460
461 w->vecs[w->end].iov_len = size;
462 w->vecs[w->end].iov_base = data;
463 w->meta[w->end].msgno = w->last_msgno;
464 w->meta[w->end].static_alloc = false;
465 w->end++;
466 w->last_msgno++;
467 return 0;
468 }
469
transfer_async_add(struct thread_msg_recv_buf * q,void * data,size_t sz)470 void transfer_async_add(struct thread_msg_recv_buf *q, void *data, size_t sz)
471 {
472 struct iovec vec;
473 vec.iov_len = sz;
474 vec.iov_base = data;
475 pthread_mutex_lock(&q->lock);
476 q->data[q->zone_end++] = vec;
477 pthread_mutex_unlock(&q->lock);
478 }
479
transfer_load_async(struct transfer_queue * w)480 int transfer_load_async(struct transfer_queue *w)
481 {
482 pthread_mutex_lock(&w->async_recv_queue.lock);
483 int zstart = w->async_recv_queue.zone_start;
484 int zend = w->async_recv_queue.zone_end;
485 w->async_recv_queue.zone_start = zend;
486 pthread_mutex_unlock(&w->async_recv_queue.lock);
487
488 for (int i = zstart; i < zend; i++) {
489 struct iovec v = w->async_recv_queue.data[i];
490 memset(&w->async_recv_queue.data[i], 0, sizeof(struct iovec));
491 if (v.iov_len == 0 || v.iov_base == NULL) {
492 wp_error("Unexpected empty message");
493 continue;
494 }
495 /* Only fill/diff messages are received async, so msgno
496 * is always incremented */
497 if (transfer_add(w, v.iov_len, v.iov_base) == -1) {
498 wp_error("Failed to add message to transfer queue");
499 pthread_mutex_unlock(&w->async_recv_queue.lock);
500 return -1;
501 }
502 }
503 return 0;
504 }
505
cleanup_transfer_queue(struct transfer_queue * td)506 void cleanup_transfer_queue(struct transfer_queue *td)
507 {
508 for (int i = td->async_recv_queue.zone_start;
509 i < td->async_recv_queue.zone_end; i++) {
510 free(td->async_recv_queue.data[i].iov_base);
511 }
512 pthread_mutex_destroy(&td->async_recv_queue.lock);
513 free(td->async_recv_queue.data);
514 for (int i = 0; i < td->end; i++) {
515 if (!td->meta[i].static_alloc) {
516 free(td->vecs[i].iov_base);
517 }
518 }
519 free(td->vecs);
520 free(td->meta);
521 }
522