xref: /netbsd/regress/sys/kern/unfdpass/unfdpass.c (revision 6550d01e)
1 /*	$NetBSD: unfdpass.c,v 1.10 2008/04/28 20:23:07 martin Exp $	*/
2 
3 /*-
4  * Copyright (c) 1998 The NetBSD Foundation, Inc.
5  * All rights reserved.
6  *
7  * This code is derived from software contributed to The NetBSD Foundation
8  * by Jason R. Thorpe of the Numerical Aerospace Simulation Facility,
9  * NASA Ames Research Center.
10  *
11  * Redistribution and use in source and binary forms, with or without
12  * modification, are permitted provided that the following conditions
13  * are met:
14  * 1. Redistributions of source code must retain the above copyright
15  *    notice, this list of conditions and the following disclaimer.
16  * 2. Redistributions in binary form must reproduce the above copyright
17  *    notice, this list of conditions and the following disclaimer in the
18  *    documentation and/or other materials provided with the distribution.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
21  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
22  * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
23  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
24  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
25  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
26  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
27  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
28  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
29  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
30  * POSSIBILITY OF SUCH DAMAGE.
31  */
32 
33 /*
34  * Test passing of file descriptors and credentials over Unix domain sockets.
35  */
36 
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/uio.h>
#include <sys/un.h>
#include <sys/wait.h>

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
52 
/* Path (relative to the cwd) of the Unix-domain socket the receiver binds. */
#define	SOCK_NAME	"test-sock"

void	child(void);
void	catch_sigchld(int);
void	usage(char *progname);

/* Size of the scratch buffer used to write and read back each test file. */
#define	FILE_SIZE	128

/*
 * Bytes of ordinary data sent alongside the control messages.  A negative
 * value compiles out the iovec setup guarded by "#if MSG_SIZE >= 0", so
 * the messages carry control data only.
 */
#define	MSG_SIZE	(-1)

/* Number of descriptors carried by each SCM_RIGHTS message. */
#define	NFILES		24

/* Payload sizes of the SCM_RIGHTS and SCM_CREDS control messages. */
#define	FDCM_DATASIZE	(sizeof(int) * NFILES)
#define	CRCM_DATASIZE	(SOCKCREDSIZE(NGROUPS))

/* Control buffer large enough for one descriptor and one credential cmsg. */
#define	MESSAGE_SIZE	(CMSG_SPACE(FDCM_DATASIZE) +			\
			 CMSG_SPACE(CRCM_DATASIZE))

/* Option flags, set by main() from the command line. */
static int chroot_rcvr = 0;	/* -r: receiver does chroot(".") before accept */
static int pass_dir = 0;	/* -d/-D: last passed descriptor is a directory */
static int pass_root_dir = 0;	/* -D: that directory is "/" instead of "." */
static int exit_early = 0;	/* -e: receiver exits right after fork */
static int exit_later = 0;	/* -E: receiver exits after accept, before recv */
static int pass_sock = 0;	/* -S: sender passes its own connected socket */
static int make_pretzel = 0;	/* -p (repeatable): socketpair nesting rounds */
77 
78 /* ARGSUSED */
79 int
80 main(argc, argv)
81 	int argc;
82 	char *argv[];
83 {
84 #if MSG_SIZE >= 0
85 	struct iovec iov;
86 #endif
87 	char *progname=argv[0];
88 	struct msghdr msg;
89 	int listensock, sock, fd, i;
90 	char fname[16], buf[FILE_SIZE];
91 	struct cmsghdr *cmp;
92 	void *message;
93 	int *files = NULL;
94 	struct sockcred *sc = NULL;
95 	struct sockaddr_un sun, csun;
96 	socklen_t csunlen;
97 	pid_t pid;
98 	int ch;
99 
100 	message = malloc(CMSG_SPACE(MESSAGE_SIZE));
101 	if (message == NULL)
102 		err(1, "unable to malloc message buffer");
103 	memset(message, 0, CMSG_SPACE(MESSAGE_SIZE));
104 
105 	while ((ch = getopt(argc, argv, "DESdepr")) != -1) {
106 		switch(ch) {
107 
108 		case 'e':
109 			exit_early++; /* test early GC */
110 			break;
111 
112 		case 'E':
113 			exit_later++; /* test later GC */
114 			break;
115 
116 		case 'd':
117 			pass_dir++;
118 			break;
119 
120 		case 'D':
121 			pass_dir++;
122 			pass_root_dir++;
123 			break;
124 
125 		case 'S':
126 			pass_sock++;
127 			break;
128 
129 		case 'r':
130 			chroot_rcvr++;
131 			break;
132 
133 		case 'p':
134 			make_pretzel++;
135 			break;
136 
137 		case '?':
138 		default:
139 			usage(progname);
140 		}
141 	}
142 
143 
144 	/*
145 	 * Create the test files.
146 	 */
147 	for (i = 0; i < NFILES; i++) {
148 		(void) sprintf(fname, "file%d", i + 1);
149 		if ((fd = open(fname, O_WRONLY|O_CREAT|O_TRUNC, 0666)) == -1)
150 			err(1, "open %s", fname);
151 		(void) sprintf(buf, "This is file %d.\n", i + 1);
152 		if (write(fd, buf, strlen(buf)) != strlen(buf))
153 			err(1, "write %s", fname);
154 		(void) close(fd);
155 	}
156 
157 	/*
158 	 * Create the listen socket.
159 	 */
160 	if ((listensock = socket(PF_LOCAL, SOCK_STREAM, 0)) == -1)
161 		err(1, "socket");
162 
163 	(void) unlink(SOCK_NAME);
164 	(void) memset(&sun, 0, sizeof(sun));
165 	sun.sun_family = AF_LOCAL;
166 	(void) strcpy(sun.sun_path, SOCK_NAME);
167 	sun.sun_len = SUN_LEN(&sun);
168 
169 	i = 1;
170 	if (setsockopt(listensock, 0, LOCAL_CREDS, &i, sizeof(i)) == -1)
171 		err(1, "setsockopt");
172 
173 	if (bind(listensock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
174 		err(1, "bind");
175 
176 	if (listen(listensock, 1) == -1)
177 		err(1, "listen");
178 
179 	/*
180 	 * Create the sender.
181 	 */
182 	(void) signal(SIGCHLD, catch_sigchld);
183 	pid = fork();
184 	switch (pid) {
185 	case -1:
186 		err(1, "fork");
187 		/* NOTREACHED */
188 
189 	case 0:
190 		child();
191 		/* NOTREACHED */
192 	}
193 
194 	if (exit_early)
195 		exit(0);
196 
197 	if (chroot_rcvr &&
198 	    ((chroot(".") < 0)))
199 		err(1, "chroot");
200 
201 	/*
202 	 * Wait for the sender to connect.
203 	 */
204 	csunlen = sizeof(csun);
205 	if ((sock = accept(listensock, (struct sockaddr *)&csun,
206 	    &csunlen)) == -1)
207 		err(1, "accept");
208 
209 	/*
210 	 * Give sender a chance to run.  We will get going again
211 	 * once the SIGCHLD arrives.
212 	 */
213 	(void) sleep(10);
214 
215 	if (exit_later)
216 		exit(0);
217 
218 	/*
219 	 * Grab the descriptors and credentials passed to us.
220 	 */
221 
222 	/* Expect 2 messages; descriptors and creds. */
223 	do {
224 		(void) memset(&msg, 0, sizeof(msg));
225 		msg.msg_control = message;
226 		msg.msg_controllen = MESSAGE_SIZE;
227 #if MSG_SIZE >= 0
228 		iov.iov_base = buf;
229 		iov.iov_len = MSG_SIZE;
230 		msg.msg_iov = &iov;
231 		msg.msg_iovlen = 1;
232 #endif
233 
234 		if (recvmsg(sock, &msg, 0) == -1)
235 			err(1, "recvmsg");
236 
237 		(void) close(sock);
238 		sock = -1;
239 
240 		if (msg.msg_controllen == 0)
241 			errx(1, "no control messages received");
242 
243 		if (msg.msg_flags & MSG_CTRUNC)
244 			errx(1, "lost control message data");
245 
246 		for (cmp = CMSG_FIRSTHDR(&msg); cmp != NULL;
247 		     cmp = CMSG_NXTHDR(&msg, cmp)) {
248 			if (cmp->cmsg_level != SOL_SOCKET)
249 				errx(1, "bad control message level %d",
250 				    cmp->cmsg_level);
251 
252 			switch (cmp->cmsg_type) {
253 			case SCM_RIGHTS:
254 				if (cmp->cmsg_len != CMSG_LEN(FDCM_DATASIZE))
255 					errx(1, "bad fd control message "
256 					    "length %d", cmp->cmsg_len);
257 
258 				files = (int *)CMSG_DATA(cmp);
259 				break;
260 
261 			case SCM_CREDS:
262 				if (cmp->cmsg_len < CMSG_LEN(SOCKCREDSIZE(1)))
263 					errx(1, "bad cred control message "
264 					    "length %d", cmp->cmsg_len);
265 
266 				sc = (struct sockcred *)CMSG_DATA(cmp);
267 				break;
268 
269 			default:
270 				errx(1, "unexpected control message");
271 				/* NOTREACHED */
272 			}
273 		}
274 
275 		/*
276 		 * Read the files and print their contents.
277 		 */
278 		if (files == NULL)
279 			warnx("didn't get fd control message");
280 		else {
281 			for (i = 0; i < NFILES; i++) {
282 				struct stat st;
283 				(void) memset(buf, 0, sizeof(buf));
284 				fstat(files[i], &st);
285 				if (S_ISDIR(st.st_mode)) {
286 					printf("file %d is a directory\n", i+1);
287 				} else if (S_ISSOCK(st.st_mode)) {
288 					printf("file %d is a socket\n", i+1);
289 					sock = files[i];
290 				} else {
291 					int c;
292 					c = read (files[i], buf, sizeof(buf));
293 					if (c < 0)
294 						err(1, "read file %d", i + 1);
295 					else if (c == 0)
296 						printf("[eof on %d]\n", i + 1);
297 					else
298 						printf("%s", buf);
299 				}
300 			}
301 		}
302 		/*
303 		 * Double-check credentials.
304 		 */
305 		if (sc == NULL)
306 			warnx("didn't get cred control message");
307 		else {
308 			if (sc->sc_uid == getuid() &&
309 			    sc->sc_euid == geteuid() &&
310 			    sc->sc_gid == getgid() &&
311 			    sc->sc_egid == getegid())
312 				printf("Credentials match.\n");
313 			else
314 				printf("Credentials do NOT match.\n");
315 		}
316 	} while (sock != -1);
317 
318 	/*
319 	 * All done!
320 	 */
321 	exit(0);
322 }
323 
/*
 * Print a usage summary listing every supported option to stderr and
 * exit with status 1.
 */
void
usage(char *progname)
{
	fprintf(stderr, "usage: %s [-deprDES]\n", progname);
	exit(1);
}
331 
/*
 * SIGCHLD handler: reap the exited sender process.
 *
 * errno is saved and restored.  The handler can run between a failing
 * system call in main() and the err() call that reports it, and wait()
 * would otherwise overwrite that errno (e.g. with ECHILD).
 */
void
catch_sigchld(int sig)
{
	int save_errno = errno;
	int status;

	(void)sig;
	(void)wait(&status);
	errno = save_errno;
}
340 
341 void
342 child()
343 {
344 #if MSG_SIZE >= 0
345 	struct iovec iov;
346 #endif
347 	struct msghdr msg;
348 	char fname[16];
349 	struct cmsghdr *cmp;
350 	void *fdcm;
351 	int i, fd, sock, nfd, *files;
352 	struct sockaddr_un sun;
353 	int spair[2];
354 
355 	fdcm = malloc(CMSG_SPACE(FDCM_DATASIZE));
356 	if (fdcm == NULL)
357 		err(1, "unable to malloc fd control message");
358 	memset(fdcm, 0, CMSG_SPACE(FDCM_DATASIZE));
359 
360 	cmp = fdcm;
361 	files = (int *)CMSG_DATA(fdcm);
362 
363 	/*
364 	 * Create socket and connect to the receiver.
365 	 */
366 	if ((sock = socket(PF_LOCAL, SOCK_STREAM, 0)) == -1)
367 		errx(1, "child socket");
368 
369 	(void) memset(&sun, 0, sizeof(sun));
370 	sun.sun_family = AF_LOCAL;
371 	(void) strcpy(sun.sun_path, SOCK_NAME);
372 	sun.sun_len = SUN_LEN(&sun);
373 
374 	if (connect(sock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
375 		err(1, "child connect");
376 
377 	nfd = NFILES;
378 	i = 0;
379 
380 	if (pass_sock) {
381 		files[i++] = sock;
382 	}
383 
384 	if (pass_dir)
385 		nfd--;
386 
387 	/*
388 	 * Open the files again, and pass them to the child
389 	 * over the socket.
390 	 */
391 
392 	for (; i < nfd; i++) {
393 		(void) sprintf(fname, "file%d", i + 1);
394 		if ((fd = open(fname, O_RDONLY, 0666)) == -1)
395 			err(1, "child open %s", fname);
396 		files[i] = fd;
397 	}
398 
399 	if (pass_dir) {
400 		char *dirname = pass_root_dir ? "/" : ".";
401 
402 
403 		if ((fd = open(dirname, O_RDONLY, 0)) == -1) {
404 			err(1, "child open directory %s", dirname);
405 		}
406 		files[i] = fd;
407 	}
408 
409 	(void) memset(&msg, 0, sizeof(msg));
410 	msg.msg_control = fdcm;
411 	msg.msg_controllen = CMSG_LEN(FDCM_DATASIZE);
412 #if MSG_SIZE >= 0
413 	iov.iov_base = buf;
414 	iov.iov_len = MSG_SIZE;
415 	msg.msg_iov = &iov;
416 	msg.msg_iovlen = 1;
417 #endif
418 
419 	cmp = CMSG_FIRSTHDR(&msg);
420 	cmp->cmsg_len = CMSG_LEN(FDCM_DATASIZE);
421 	cmp->cmsg_level = SOL_SOCKET;
422 	cmp->cmsg_type = SCM_RIGHTS;
423 
424 	while (make_pretzel > 0) {
425 		if (socketpair(PF_LOCAL, SOCK_STREAM, 0, spair) < 0)
426 			err(1, "socketpair");
427 
428 		printf("send pretzel\n");
429 		if (sendmsg(spair[0], &msg, 0) < 0)
430 			err(1, "child prezel sendmsg");
431 
432 		close(files[0]);
433 		close(files[1]);
434 		files[0] = spair[0];
435 		files[1] = spair[1];
436 		make_pretzel--;
437 	}
438 
439 	if (sendmsg(sock, &msg, 0) == -1)
440 		err(1, "child sendmsg");
441 
442 	/*
443 	 * All done!
444 	 */
445 	exit(0);
446 }
447