xref: /dragonfly/test/unix/passdesc/passdesc.c (revision ed36d35d)
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <sys/wait.h>

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
13 
/* Scratch file whose descriptor is passed between the two processes. */
#define TEST_FILENAME	"/tmp/passdesc"

/* Length in bytes of the regular-data payload carried by each message. */
static int	test_buflen = 0;
/* Heap buffer of test_buflen bytes used as the payload for send and receive. */
static void	*test_buf = NULL;
18 
19 static void
20 test_send_desc(int s, int fd)
21 {
22 	struct msghdr msg;
23 	struct iovec iov;
24 	union {
25 		struct cmsghdr cm;
26 		uint8_t data[CMSG_SPACE(sizeof(int))];
27 	} ctrl;
28 	struct cmsghdr *cm;
29 	int n;
30 
31 	iov.iov_base = test_buf;
32 	iov.iov_len = test_buflen;
33 
34 	memset(&msg, 0, sizeof(msg));
35 	msg.msg_iov = &iov;
36 	msg.msg_iovlen = 1;
37 	msg.msg_control = ctrl.data;
38 	msg.msg_controllen = sizeof(ctrl.data);
39 
40 	memset(&ctrl, 0, sizeof(ctrl));
41 	cm = CMSG_FIRSTHDR(&msg);
42 	cm->cmsg_len = CMSG_LEN(sizeof(int));
43 	cm->cmsg_level = SOL_SOCKET;
44 	cm->cmsg_type = SCM_RIGHTS;
45 	*((int *)CMSG_DATA(cm)) = fd;
46 
47 	n = sendmsg(s, &msg, 0);
48 	if (n < 0)
49 		err(1, "sendmsg failed");
50 	else if (n != test_buflen)
51 		errx(1, "sendmsg sent %d", n);
52 	close(fd);
53 }
54 
55 static void
56 test_recv_desc(int s)
57 {
58 	struct msghdr msg;
59 	struct iovec iov;
60 	union {
61 		struct cmsghdr cm;
62 		uint8_t data[CMSG_SPACE(sizeof(int))];
63 	} ctrl;
64 	struct cmsghdr *cm;
65 	int n, fd;
66 	char data[16];
67 
68 	iov.iov_base = test_buf;
69 	iov.iov_len = test_buflen;
70 
71 	memset(&msg, 0, sizeof(msg));
72 	msg.msg_iov = &iov;
73 	msg.msg_iovlen = 1;
74 	msg.msg_control = ctrl.data;
75 	msg.msg_controllen = sizeof(ctrl.data);
76 
77 	n = recvmsg(s, &msg, MSG_WAITALL);
78 	if (n < 0)
79 		err(1, "recvmsg failed");
80 	else if (n != test_buflen)
81 		errx(1, "recvmsg received %d", n);
82 
83 	cm = CMSG_FIRSTHDR(&msg);
84 	if (cm == NULL)
85 		errx(1, "no cmsg");
86 	if (cm->cmsg_len != CMSG_LEN(sizeof(int)))
87 		errx(1, "cmsg len mismatch");
88 	if (cm->cmsg_level != SOL_SOCKET)
89 		errx(1, "cmsg level mismatch");
90 	if (cm->cmsg_type != SCM_RIGHTS)
91 		errx(1, "cmsg type mismatch");
92 
93 	fd = *((int *)CMSG_DATA(cm));
94 
95 	n = read(fd, data, sizeof(data) - 1);
96 	if (n < 0)
97 		err(1, "read failed");
98 	data[n] = '\0';
99 
100 	fprintf(stderr, "fd content: %s\n", data);
101 }
102 
/*
 * Print the command-line synopsis for program `cmd' on stderr and
 * terminate with exit status 1.  Does not return.
 */
static void
usage(const char *cmd)
{
	static const char synopsis[] = "[-d] [-s] [-p payload_len]";

	(void)fprintf(stderr, "%s %s\n", cmd, synopsis);
	exit(1);
}
109 
110 int
111 main(int argc, char *argv[])
112 {
113 	int s[2], fd, status, n, discard, skipfd;
114 	int opt;
115 	off_t ofs;
116 
117 	discard = 0;
118 	skipfd = 0;
119 	while ((opt = getopt(argc, argv, "dp:s")) != -1) {
120 		switch (opt) {
121 		case 'd':
122 			discard = 1;
123 			break;
124 
125 		case 'p':
126 			test_buflen = strtoul(optarg, NULL, 10);
127 			break;
128 
129 		case 's':
130 			skipfd = 1;
131 			break;
132 
133 		default:
134 			usage(argv[0]);
135 		}
136 	}
137 
138 	if (test_buflen <= 0)
139 		test_buflen = sizeof(int);
140 	test_buf = malloc(test_buflen);
141 	if (test_buf == NULL)
142 		err(1, "malloc %d failed", test_buflen);
143 
144 	if (socketpair(AF_LOCAL, SOCK_STREAM, 0, s) < 0)
145 		err(1, "socketpair(LOCAL, STREAM) failed");
146 
147 	if (fork() == 0) {
148 		close(s[0]);
149 		if (!discard && !skipfd) {
150 			test_recv_desc(s[1]);
151 		} else if (skipfd) {
152 			int buf;
153 
154 			fprintf(stderr, "skipfd\n");
155 			n = read(s[1], &buf, sizeof(buf));
156 			if (n < 0)
157 				err(1, "read failed");
158 		} else {
159 			fprintf(stderr, "discard msg\n");
160 			sleep(5);
161 		}
162 		exit(0);
163 	}
164 	close(s[1]);
165 
166 	fd = open(TEST_FILENAME, O_RDWR | O_TRUNC | O_CREAT,
167 	    S_IWUSR | S_IRUSR | S_IRGRP | S_IROTH);
168 	if (fd < 0)
169 		err(1, "open " TEST_FILENAME " failed");
170 
171 	n = write(fd, TEST_FILENAME, strlen(TEST_FILENAME));
172 	if (n < 0)
173 		err(1, "write failed");
174 	else if (n != strlen(TEST_FILENAME))
175 		errx(1, "write %d", n);
176 
177 	ofs = lseek(fd, 0, SEEK_SET);
178 	if (ofs < 0)
179 		err(1, "lseek failed");
180 	else if (ofs != 0)
181 		errx(1, "lseek offset %jd", (intmax_t)ofs);
182 
183 	test_send_desc(s[0], fd);
184 
185 	wait(&status);
186 	exit(0);
187 }
188