xref: /openbsd/regress/sys/kern/unfdpass/unfdpass.c (revision 3bef86f7)
1 /*	$OpenBSD: unfdpass.c,v 1.23 2023/03/08 04:43:06 guenther Exp $	*/
2 /*	$NetBSD: unfdpass.c,v 1.3 1998/06/24 23:51:30 thorpej Exp $	*/
3 
4 /*-
5  * Copyright (c) 1998 The NetBSD Foundation, Inc.
6  * All rights reserved.
7  *
8  * This code is derived from software contributed to The NetBSD Foundation
9  * by Jason R. Thorpe of the Numerical Aerospace Simulation Facility,
10  * NASA Ames Research Center.
11  *
12  * Redistribution and use in source and binary forms, with or without
13  * modification, are permitted provided that the following conditions
14  * are met:
15  * 1. Redistributions of source code must retain the above copyright
16  *    notice, this list of conditions and the following disclaimer.
17  * 2. Redistributions in binary form must reproduce the above copyright
18  *    notice, this list of conditions and the following disclaimer in the
19  *    documentation and/or other materials provided with the distribution.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
23  * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
24  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
25  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
26  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
27  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
28  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
29  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31  * POSSIBILITY OF SUCH DAMAGE.
32  */
33 
34 /*
35  * Test passing of file descriptors over Unix domain sockets and socketpairs.
36  */
37 
38 #include <sys/socket.h>
39 #include <sys/time.h>
40 #include <sys/wait.h>
41 #include <sys/un.h>
42 #include <err.h>
43 #include <errno.h>
44 #include <fcntl.h>
45 #include <signal.h>
46 #include <stdio.h>
47 #include <stdlib.h>
48 #include <string.h>
49 #include <unistd.h>
50 
/* Filesystem name of the Unix-domain socket used when -p is not given. */
#define	SOCK_NAME	"test-sock"

/* Receiver: sets up the test, forks the sender, checks what arrives. */
int	main(int argc, char *argv[]);
/* Sender: runs in the forked child, passes descriptors, never returns. */
void	child(int sock, int type, int oflag);
/* SIGCHLD handler that reaps the sender. */
void	catch_sigchld(int sig);
56 
57 int
58 main(int argc, char *argv[])
59 {
60 	struct msghdr msg;
61 	int sock, pfd[2], fd, i;
62 	int listensock = -1;
63 	char fname[16], buf[64];
64 	struct cmsghdr *cmp;
65 	int *files = NULL;
66 	struct sockaddr_un sun, csun;
67 	int csunlen;
68 	pid_t pid;
69 	union {
70 		struct cmsghdr hdr;
71 		char buf[CMSG_SPACE(sizeof(int) * 3)];
72 	} cmsgbuf;
73 	int pflag, oflag, rflag;
74 	int type = SOCK_STREAM;
75 	extern char *__progname;
76 
77 	pflag = 0;
78 	oflag = 0;
79 	rflag = 0;
80 	while ((i = getopt(argc, argv, "opqr")) != -1) {
81 		switch (i) {
82 		case 'o':
83 			oflag = 1;
84 			break;
85 		case 'p':
86 			pflag = 1;
87 			break;
88 		case 'q':
89 			type = SOCK_SEQPACKET;
90 			break;
91 		case 'r':
92 			rflag = 1;
93 			break;
94 		default:
95 			fprintf(stderr, "usage: %s [-opqr]\n", __progname);
96 			exit(1);
97 		}
98 	}
99 
100 	/*
101 	 * Create the test files.
102 	 */
103 	for (i = 0; i < 5; i++) {
104 		(void) snprintf(fname, sizeof fname, "file%d", i + 1);
105 		if ((fd = open(fname, O_WRONLY|O_CREAT|O_TRUNC, 0666)) == -1)
106 			err(1, "open %s", fname);
107 		(void) snprintf(buf, sizeof buf, "This is file %d.\n", i + 1);
108 		if (write(fd, buf, strlen(buf)) != (ssize_t) strlen(buf))
109 			err(1, "write %s", fname);
110 		(void) close(fd);
111 	}
112 
113 	if (pflag) {
114 		/*
115 		 * Create the socketpair
116 		 */
117 		if (socketpair(PF_LOCAL, type, 0, pfd) == -1)
118 			err(1, "socketpair");
119 	} else {
120 		/*
121 		 * Create the listen socket.
122 		 */
123 		if ((listensock = socket(PF_LOCAL, type, 0)) == -1)
124 			err(1, "socket");
125 
126 		(void) unlink(SOCK_NAME);
127 		(void) memset(&sun, 0, sizeof(sun));
128 		sun.sun_family = AF_LOCAL;
129 		(void) strlcpy(sun.sun_path, SOCK_NAME, sizeof sun.sun_path);
130 
131 		if (bind(listensock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
132 			err(1, "bind");
133 
134 		if (listen(listensock, 1) == -1)
135 			err(1, "listen");
136 		pfd[0] = pfd[1] = -1;
137 	}
138 
139 	/*
140 	 * Create the sender.
141 	 */
142 	(void) signal(SIGCHLD, catch_sigchld);
143 	pid = fork();
144 	switch (pid) {
145 	case -1:
146 		err(1, "fork");
147 		/* NOTREACHED */
148 
149 	case 0:
150 		if (pfd[0] != -1)
151 			close(pfd[0]);
152 		child(pfd[1], type, oflag);
153 		/* NOTREACHED */
154 	}
155 
156 	if (pfd[0] != -1) {
157 		close(pfd[1]);
158 		sock = pfd[0];
159 	} else {
160 		/*
161 		 * Wait for the sender to connect.
162 		 */
163 		if ((sock = accept(listensock, (struct sockaddr *)&csun,
164 		    &csunlen)) == -1)
165 		err(1, "accept");
166 	}
167 
168 	/*
169 	 * Give sender a chance to run.  We will get going again
170 	 * once the SIGCHLD arrives.
171 	 */
172 	(void) sleep(10);
173 
174 	if (rflag) {
175 		if (read(sock, buf, sizeof(buf)) < 0)
176 			err(1, "read");
177 		printf("read successfully returned\n");
178 		exit(0);
179 	}
180 
181 	/*
182 	 * Grab the descriptors passed to us.
183 	 */
184 	memset(&msg, 0, sizeof(msg));
185 	msg.msg_control = &cmsgbuf.buf;
186 	msg.msg_controllen = sizeof(cmsgbuf.buf);
187 
188 	if (recvmsg(sock, &msg, 0) < 0) {
189 		if (errno == EMSGSIZE) {
190 			printf("recvmsg returned EMSGSIZE\n");
191 			exit(0);
192 		} else
193 			err(1, "recvmsg");
194 	}
195 
196 	(void) close(sock);
197 
198 	if (msg.msg_controllen == 0)
199 		errx(1, "no control messages received");
200 
201 	if (msg.msg_flags & MSG_CTRUNC)
202 		errx(1, "lost control message data");
203 
204 	for (cmp = CMSG_FIRSTHDR(&msg); cmp != NULL;
205 	    cmp = CMSG_NXTHDR(&msg, cmp)) {
206 		if (cmp->cmsg_level != SOL_SOCKET)
207 			errx(1, "bad control message level %d",
208 			    cmp->cmsg_level);
209 
210 		switch (cmp->cmsg_type) {
211 		case SCM_RIGHTS:
212 			if (cmp->cmsg_len != CMSG_LEN(sizeof(int) * 3))
213 				errx(1, "bad fd control message length %d",
214 				    cmp->cmsg_len);
215 
216 			files = (int *)CMSG_DATA(cmp);
217 			break;
218 
219 		default:
220 			errx(1, "unexpected control message");
221 			/* NOTREACHED */
222 		}
223 	}
224 
225 	/*
226 	 * Read the files and print their contents.
227 	 */
228 	if (files == NULL)
229 		warnx("didn't get fd control message");
230 	else {
231 		for (i = 0; i < 3; i++) {
232 			(void) memset(buf, 0, sizeof(buf));
233 			if (read(files[i], buf, sizeof(buf)) <= 0)
234 				err(1, "read file %d (%d)", i + 1, files[i]);
235 			printf("%s", buf);
236 		}
237 	}
238 
239 	/*
240 	 * All done!
241 	 */
242 	exit(0);
243 }
244 
245 void
246 catch_sigchld(sig)
247 	int sig;
248 {
249 	int save_errno = errno;
250 	int status;
251 
252 	(void) wait(&status);
253 	errno = save_errno;
254 }
255 
256 void
257 child(int sock, int type, int oflag)
258 {
259 	struct msghdr msg;
260 	char fname[16];
261 	struct cmsghdr *cmp;
262 	int i, fd, nfds = 3;
263 	struct sockaddr_un sun;
264 	size_t len;
265 	char *cmsgbuf;
266 	int *files;
267 
268 	/*
269 	 * Create socket if needed and connect to the receiver.
270 	 */
271 	if (sock == -1) {
272 		if ((sock = socket(PF_LOCAL, type, 0)) == -1)
273 			err(1, "child socket");
274 
275 		(void) memset(&sun, 0, sizeof(sun));
276 		sun.sun_family = AF_LOCAL;
277 		(void) strlcpy(sun.sun_path, SOCK_NAME, sizeof sun.sun_path);
278 
279 		if (connect(sock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
280 			err(1, "child connect");
281 	}
282 
283 	if (oflag)
284 		nfds = 5;
285 	len = CMSG_SPACE(sizeof(int) * nfds);
286 	if ((cmsgbuf = malloc(len)) == NULL)
287 		err(1, "child");
288 
289 	(void) memset(&msg, 0, sizeof(msg));
290 	msg.msg_control = cmsgbuf;
291 	msg.msg_controllen = len;
292 
293 	cmp = CMSG_FIRSTHDR(&msg);
294 	cmp->cmsg_len = CMSG_LEN((sizeof(int) * nfds));
295 	cmp->cmsg_level = SOL_SOCKET;
296 	cmp->cmsg_type = SCM_RIGHTS;
297 
298 	/*
299 	 * Open the files again, and pass them to the parent over the socket.
300 	 */
301 	files = (int *)CMSG_DATA(cmp);
302 	for (i = 0; i < nfds; i++) {
303 		(void) snprintf(fname, sizeof fname, "file%d", i + 1);
304 		if ((fd = open(fname, O_RDONLY)) == -1)
305 			err(1, "child open %s", fname);
306 		files[i] = fd;
307 	}
308 
309 	if (sendmsg(sock, &msg, 0))
310 		err(1, "child sendmsg");
311 
312 	/*
313 	 * All done!
314 	 */
315 	exit(0);
316 }
317