1 /* This file is part of GNU Pies testsuite.
2    Copyright (C) 2019-2020 Sergey Poznyakoff
3 
4    GNU Pies is free software; you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation; either version 3, or (at your option)
7    any later version.
8 
9    GNU Pies is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13 
14    You should have received a copy of the GNU General Public License
15    along with GNU Pies.  If not, see <http://www.gnu.org/licenses/>. */
16 
#include <config.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <signal.h>
#include "libpies.h"
29 
/* Name this program was invoked under (main stores argv[0] here);
   used as the prefix of diagnostic messages.  */
char const *progname;

/* Write the help text to FP and terminate the process with exit
   code STATUS.  Never returns.  */
void
usage (FILE *fp, int status)
{
  fprintf (fp,
	   "usage: %s [-s SOCKET] COMMAND ARGS...\n"
	   "Test tool for accept and pass-fd pies components.\n"
	   "Listens on the file descriptor, either 0 or obtained from SOCKET.\n"
	   "For each connection, execs COMMAND ARGS as a separate process.\n",
	   progname);
  exit (status);
}
41 
42 static int
listen_socket(char const * socket_name)43 listen_socket (char const *socket_name)
44 {
45   struct sockaddr_un addr;
46   int sockfd;
47 
48   if (strlen (socket_name) > sizeof addr.sun_path)
49     {
50       fprintf (stderr, "%s: UNIX socket name too long\n", progname);
51       return -1;
52     }
53   addr.sun_family = AF_UNIX;
54   strcpy (addr.sun_path, socket_name);
55 
56   sockfd = socket (PF_UNIX, SOCK_STREAM, 0);
57   if (sockfd == -1)
58     {
59       perror ("socket");
60       exit (1);
61     }
62 
63   umask (0117);
64   if (bind (sockfd, (struct sockaddr *) &addr, sizeof (addr)) < 0)
65     {
66       perror ("bind");
67       exit (1);
68     }
69 
70   if (listen (sockfd, 8) < 0)
71     {
72       perror ("listen");
73       exit (1);
74     }
75   return sockfd;
76 }
77 
/* Receive one file descriptor sent over the connected UNIX socket FD,
   either as SCM_RIGHTS ancillary data or through msg_accrights, depending
   on what the build supports.  Returns the received descriptor, or -1 if
   recvmsg() failed, hit end of file, or no descriptor accompanied the data.
   Exits with status 77 when the build has no way to receive descriptors.  */
static int
read_fd (int fd)
{
  struct msghdr msg;
  struct iovec iov[1];
  char base[1];

  /* Zero the whole header so that members not set explicitly below
     (e.g. msg_flags) hold defined values.  */
  memset (&msg, 0, sizeof msg);

#if HAVE_STRUCT_MSGHDR_MSG_CONTROL
  /* The union makes the control buffer suitably aligned for a cmsghdr.  */
  union
  {
    struct cmsghdr cm;
    char control[CMSG_SPACE (sizeof (int))];
  } control_un;
  struct cmsghdr *cmptr;

  msg.msg_control = control_un.control;
  msg.msg_controllen = sizeof (control_un.control);
#elif HAVE_STRUCT_MSGHDR_MSG_ACCRIGHTS
  int newfd;

  msg.msg_accrights = (caddr_t) &newfd;
  msg.msg_accrightslen = sizeof (int);
#else
  fprintf (stderr, "no way to get fd\n");
  exit (77);
#endif

  msg.msg_name = NULL;
  msg.msg_namelen = 0;

  /* One-byte buffer for the ordinary data accompanying the descriptor.  */
  iov[0].iov_base = base;
  iov[0].iov_len = sizeof (base);

  msg.msg_iov = iov;
  msg.msg_iovlen = 1;

  if (recvmsg (fd, &msg, 0) <= 0)
    return -1;

#if HAVE_STRUCT_MSGHDR_MSG_CONTROL
  cmptr = CMSG_FIRSTHDR (&msg);
  if (cmptr != NULL
      && cmptr->cmsg_len == CMSG_LEN (sizeof (int))
      && cmptr->cmsg_level == SOL_SOCKET
      && cmptr->cmsg_type == SCM_RIGHTS)
    {
      int received;

      /* CMSG_DATA() need not be aligned for int: copy the payload out
	 instead of dereferencing a cast pointer (see cmsg(3)).  */
      memcpy (&received, CMSG_DATA (cmptr), sizeof received);
      return received;
    }
#elif HAVE_STRUCT_MSGHDR_MSG_ACCRIGHTS
  if (msg.msg_accrightslen == sizeof (int))
    return newfd;
#endif
  return -1;
}
129 
/* Wait for a single connection on the listening socket LFD, fetch the
   descriptor the peer passes over it, and drop the connection.
   Returns whatever read_fd produced (-1 if nothing was passed);
   a failing accept() terminates the program.  */
static int
get_fd (int lfd)
{
  int conn;
  int passed;

  conn = accept (lfd, NULL, NULL);
  if (conn < 0)
    {
      perror ("accept");
      exit (1);
    }

  passed = read_fd (conn);
  close (conn);
  return passed;
}
144 
/* SIGCHLD handler: reap every child that has already terminated, then
   re-install itself, since signal() may reset the disposition on
   delivery.  */
static void
sigchld (int sig)
{
  /* waitpid() can set errno (e.g. to ECHILD); keep the value the
     interrupted code may still be examining.  */
  int saved_errno = errno;

  /* With WNOHANG, waitpid() returns 0 while children exist but none has
     exited yet.  Continue only while a child was actually reaped:
     looping on 0 would spin forever inside the handler.  */
  while (waitpid ((pid_t) -1, NULL, WNOHANG) > 0)
    ;
  signal (sig, sigchld);
  errno = saved_errno;
}
154 
/* Handler for termination signals: forward SIG to every process in our
   process group (main never changes it, so forked commands are members
   unless they move themselves), then exit with status 0.  */
static void
sigquit (int sig)
{
  /* Ignore SIG here before broadcasting it.  Where signal() has System V
     semantics the disposition is already back to SIG_DFL, and kill()
     would otherwise terminate this process by the signal instead of
     letting it exit with status 0.  */
  signal (sig, SIG_IGN);
  kill (0, sig);
  /* _exit() is async-signal-safe; exit() is not.  */
  _exit (0);
}
161 
162 int
main(int argc,char ** argv)163 main (int argc, char **argv)
164 {
165   int c;
166   int fd;
167   char *socket_name = NULL;
168 
169   progname = argv[0];
170 
171   while ((c = getopt (argc, argv, "hs:")) != EOF)
172     {
173       switch (c)
174 	{
175 	case 'h':
176 	  usage (stdout, 0);
177 	  break;
178 
179 	case 's':
180 	  socket_name = optarg;
181 	  break;
182 
183 	default:
184 	  exit (64);
185 	}
186     }
187 
188   argc -= optind;
189   argv += optind;
190 
191   if (argc == 0)
192     usage (stderr, 64);
193 
194   if (socket_name)
195     {
196       int sfd = listen_socket (socket_name);
197       fd = get_fd (sfd);
198       close (sfd);
199     }
200   else
201     fd = 0;
202 
203   signal (SIGCHLD, sigchld);
204   signal (SIGTERM, sigquit);
205   signal (SIGHUP, sigquit);
206   signal (SIGINT, sigquit);
207   signal (SIGQUIT, sigquit);
208 
209   while (1)
210     {
211       int cfd = accept (fd, NULL, NULL);
212       if (cfd == -1)
213 	{
214 	  perror ("accept");
215 	  exit (1);
216 	}
217 
218       pid_t pid = fork ();
219       if (pid == 0)
220 	{
221 	  int i;
222 
223 	  for (i = getmaxfd (); i >= 0; i--)
224 	    if (i != cfd)
225 	      close (i);
226 
227 	  if (cfd != 0)
228 	    dup2 (cfd, 0);
229 	  if (cfd != 1)
230 	    dup2 (cfd, 1);
231 	  if (cfd != 2)
232 	    dup2 (cfd, 2);
233 	  if (cfd > 2)
234 	    close (cfd);
235 
236 	  execvp (argv[0], argv);
237 	  exit (127);
238 	}
239       if (pid == -1)
240 	{
241 	  perror ("fork");
242 	}
243       close (cfd);
244     }
245   return 0;
246 }
247