xref: /netbsd/sys/external/bsd/libnv/dist/msgio.c (revision 4ea48eba)
1 /*	$NetBSD: msgio.c,v 1.2 2018/09/08 14:02:15 christos Exp $	*/
2 
3 /*-
4  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
5  *
6  * Copyright (c) 2013 The FreeBSD Foundation
7  * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org>
8  * All rights reserved.
9  *
10  * This software was developed by Pawel Jakub Dawidek under sponsorship from
11  * the FreeBSD Foundation.
12  *
13  * Redistribution and use in source and binary forms, with or without
14  * modification, are permitted provided that the following conditions
15  * are met:
16  * 1. Redistributions of source code must retain the above copyright
17  *    notice, this list of conditions and the following disclaimer.
18  * 2. Redistributions in binary form must reproduce the above copyright
19  *    notice, this list of conditions and the following disclaimer in the
20  *    documentation and/or other materials provided with the distribution.
21  *
22  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
23  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
25  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
26  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
27  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
28  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
29  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
30  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
31  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
32  * SUCH DAMAGE.
33  */
34 
35 #include <sys/cdefs.h>
36 #ifdef __FreeBSD__
37 __FBSDID("$FreeBSD: head/lib/libnv/msgio.c 326219 2017-11-26 02:00:33Z pfg $");
38 #else
39 __RCSID("$NetBSD: msgio.c,v 1.2 2018/09/08 14:02:15 christos Exp $");
40 #endif
41 
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/select.h>

#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
52 #include <unistd.h>
53 
54 #ifdef HAVE_PJDLOG
55 #include <pjdlog.h>
56 #endif
57 
58 #include "common_impl.h"
59 #include "msgio.h"
60 
61 #ifndef	HAVE_PJDLOG
62 #include <assert.h>
63 #define	PJDLOG_ASSERT(...)		assert(__VA_ARGS__)
64 #define	PJDLOG_RASSERT(expr, ...)	assert(expr)
65 #define	PJDLOG_ABORT(...)		abort()
66 #endif
67 
68 #ifdef __linux__
69 /* Linux: arbitrary size, but must be lower than SCM_MAX_FD. */
70 #define	PKG_MAX_SIZE	((64U - 1) * CMSG_SPACE(sizeof(int)))
71 #else
72 #define	PKG_MAX_SIZE	(MCLBYTES / CMSG_SPACE(sizeof(int)) - 1)
73 #endif
74 
75 static int
msghdr_add_fd(struct cmsghdr * cmsg,int fd)76 msghdr_add_fd(struct cmsghdr *cmsg, int fd)
77 {
78 
79 	PJDLOG_ASSERT(fd >= 0);
80 
81 	if (!fd_is_valid(fd)) {
82 		errno = EBADF;
83 		return (-1);
84 	}
85 
86 	cmsg->cmsg_level = SOL_SOCKET;
87 	cmsg->cmsg_type = SCM_RIGHTS;
88 	cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
89 	bcopy(&fd, CMSG_DATA(cmsg), sizeof(fd));
90 
91 	return (0);
92 }
93 
94 static int
msghdr_get_fd(struct cmsghdr * cmsg)95 msghdr_get_fd(struct cmsghdr *cmsg)
96 {
97 	int fd;
98 
99 	if (cmsg == NULL || cmsg->cmsg_level != SOL_SOCKET ||
100 	    cmsg->cmsg_type != SCM_RIGHTS ||
101 	    cmsg->cmsg_len != CMSG_LEN(sizeof(fd))) {
102 		errno = EINVAL;
103 		return (-1);
104 	}
105 
106 	bcopy(CMSG_DATA(cmsg), &fd, sizeof(fd));
107 #ifndef MSG_CMSG_CLOEXEC
108 	/*
109 	 * If the MSG_CMSG_CLOEXEC flag is not available we cannot set the
110 	 * close-on-exec flag atomically, but we still want to set it for
111 	 * consistency.
112 	 */
113 	(void) fcntl(fd, F_SETFD, FD_CLOEXEC);
114 #endif
115 
116 	return (fd);
117 }
118 
119 static void
fd_wait(int fd,bool doread)120 fd_wait(int fd, bool doread)
121 {
122 	fd_set fds;
123 
124 	PJDLOG_ASSERT(fd >= 0);
125 
126 	FD_ZERO(&fds);
127 	FD_SET(fd, &fds);
128 	(void)select(fd + 1, doread ? &fds : NULL, doread ? NULL : &fds,
129 	    NULL, NULL);
130 }
131 
132 static int
msg_recv(int sock,struct msghdr * msg)133 msg_recv(int sock, struct msghdr *msg)
134 {
135 	int flags;
136 
137 	PJDLOG_ASSERT(sock >= 0);
138 
139 #ifdef MSG_CMSG_CLOEXEC
140 	flags = MSG_CMSG_CLOEXEC;
141 #else
142 	flags = 0;
143 #endif
144 
145 	for (;;) {
146 		fd_wait(sock, true);
147 		if (recvmsg(sock, msg, flags) == -1) {
148 			if (errno == EINTR)
149 				continue;
150 			return (-1);
151 		}
152 		break;
153 	}
154 
155 	return (0);
156 }
157 
158 static int
msg_send(int sock,const struct msghdr * msg)159 msg_send(int sock, const struct msghdr *msg)
160 {
161 
162 	PJDLOG_ASSERT(sock >= 0);
163 
164 	for (;;) {
165 		fd_wait(sock, false);
166 		if (sendmsg(sock, msg, 0) == -1) {
167 			if (errno == EINTR)
168 				continue;
169 			return (-1);
170 		}
171 		break;
172 	}
173 
174 	return (0);
175 }
176 
177 #ifdef __FreeBSD__
178 int
cred_send(int sock)179 cred_send(int sock)
180 {
181 	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
182 	struct msghdr msg;
183 	struct cmsghdr *cmsg;
184 	struct iovec iov;
185 	uint8_t dummy;
186 
187 	bzero(credbuf, sizeof(credbuf));
188 	bzero(&msg, sizeof(msg));
189 	bzero(&iov, sizeof(iov));
190 
191 	/*
192 	 * XXX: We send one byte along with the control message, because
193 	 *      setting msg_iov to NULL only works if this is the first
194 	 *      packet send over the socket. Once we send some data we
195 	 *      won't be able to send credentials anymore. This is most
196 	 *      likely a kernel bug.
197 	 */
198 	dummy = 0;
199 	iov.iov_base = &dummy;
200 	iov.iov_len = sizeof(dummy);
201 
202 	msg.msg_iov = &iov;
203 	msg.msg_iovlen = 1;
204 	msg.msg_control = credbuf;
205 	msg.msg_controllen = sizeof(credbuf);
206 
207 	cmsg = CMSG_FIRSTHDR(&msg);
208 	cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred));
209 	cmsg->cmsg_level = SOL_SOCKET;
210 	cmsg->cmsg_type = SCM_CREDS;
211 
212 	if (msg_send(sock, &msg) == -1)
213 		return (-1);
214 
215 	return (0);
216 }
217 
218 int
cred_recv(int sock,struct cmsgcred * cred)219 cred_recv(int sock, struct cmsgcred *cred)
220 {
221 	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
222 	struct msghdr msg;
223 	struct cmsghdr *cmsg;
224 	struct iovec iov;
225 	uint8_t dummy;
226 
227 	bzero(credbuf, sizeof(credbuf));
228 	bzero(&msg, sizeof(msg));
229 	bzero(&iov, sizeof(iov));
230 
231 	iov.iov_base = &dummy;
232 	iov.iov_len = sizeof(dummy);
233 
234 	msg.msg_iov = &iov;
235 	msg.msg_iovlen = 1;
236 	msg.msg_control = credbuf;
237 	msg.msg_controllen = sizeof(credbuf);
238 
239 	if (msg_recv(sock, &msg) == -1)
240 		return (-1);
241 
242 	cmsg = CMSG_FIRSTHDR(&msg);
243 	if (cmsg == NULL ||
244 	    cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) ||
245 	    cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) {
246 		errno = EINVAL;
247 		return (-1);
248 	}
249 	bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred));
250 
251 	return (0);
252 }
253 #endif
254 
255 static int
fd_package_send(int sock,const int * fds,size_t nfds)256 fd_package_send(int sock, const int *fds, size_t nfds)
257 {
258 	struct msghdr msg;
259 	struct cmsghdr *cmsg;
260 	struct iovec iov;
261 	unsigned int i;
262 	int serrno, ret;
263 	uint8_t dummy;
264 
265 	PJDLOG_ASSERT(sock >= 0);
266 	PJDLOG_ASSERT(fds != NULL);
267 	PJDLOG_ASSERT(nfds > 0);
268 
269 	bzero(&msg, sizeof(msg));
270 
271 	/*
272 	 * XXX: Look into cred_send function for more details.
273 	 */
274 	dummy = 0;
275 	iov.iov_base = &dummy;
276 	iov.iov_len = sizeof(dummy);
277 
278 	msg.msg_iov = &iov;
279 	msg.msg_iovlen = 1;
280 	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
281 	msg.msg_control = calloc(1, msg.msg_controllen);
282 	if (msg.msg_control == NULL)
283 		return (-1);
284 
285 	ret = -1;
286 
287 	for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL;
288 	    i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) {
289 		if (msghdr_add_fd(cmsg, fds[i]) == -1)
290 			goto end;
291 	}
292 
293 	if (msg_send(sock, &msg) == -1)
294 		goto end;
295 
296 	ret = 0;
297 end:
298 	serrno = errno;
299 	free(msg.msg_control);
300 	errno = serrno;
301 	return (ret);
302 }
303 
304 static int
fd_package_recv(int sock,int * fds,size_t nfds)305 fd_package_recv(int sock, int *fds, size_t nfds)
306 {
307 	struct msghdr msg;
308 	struct cmsghdr *cmsg;
309 	unsigned int i;
310 	int serrno, ret;
311 	struct iovec iov;
312 	uint8_t dummy;
313 
314 	PJDLOG_ASSERT(sock >= 0);
315 	PJDLOG_ASSERT(nfds > 0);
316 	PJDLOG_ASSERT(fds != NULL);
317 
318 	bzero(&msg, sizeof(msg));
319 	bzero(&iov, sizeof(iov));
320 
321 	/*
322 	 * XXX: Look into cred_send function for more details.
323 	 */
324 	iov.iov_base = &dummy;
325 	iov.iov_len = sizeof(dummy);
326 
327 	msg.msg_iov = &iov;
328 	msg.msg_iovlen = 1;
329 	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
330 	msg.msg_control = calloc(1, msg.msg_controllen);
331 	if (msg.msg_control == NULL)
332 		return (-1);
333 
334 	ret = -1;
335 
336 	if (msg_recv(sock, &msg) == -1)
337 		goto end;
338 
339 	for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL;
340 	    i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) {
341 		fds[i] = msghdr_get_fd(cmsg);
342 		if (fds[i] < 0)
343 			break;
344 	}
345 
346 	if (cmsg != NULL || i < nfds) {
347 		int fd;
348 
349 		/*
350 		 * We need to close all received descriptors, even if we have
351 		 * different control message (eg. SCM_CREDS) in between.
352 		 */
353 		for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL;
354 		    cmsg = CMSG_NXTHDR(&msg, cmsg)) {
355 			fd = msghdr_get_fd(cmsg);
356 			if (fd >= 0)
357 				close(fd);
358 		}
359 		errno = EINVAL;
360 		goto end;
361 	}
362 
363 	ret = 0;
364 end:
365 	serrno = errno;
366 	free(msg.msg_control);
367 	errno = serrno;
368 	return (ret);
369 }
370 
371 int
fd_recv(int sock,int * fds,size_t nfds)372 fd_recv(int sock, int *fds, size_t nfds)
373 {
374 	unsigned int i, step, j;
375 	int ret, serrno;
376 
377 	if (nfds == 0 || fds == NULL) {
378 		errno = EINVAL;
379 		return (-1);
380 	}
381 
382 	ret = i = step = 0;
383 	while (i < nfds) {
384 		if (PKG_MAX_SIZE < nfds - i)
385 			step = PKG_MAX_SIZE;
386 		else
387 			step = nfds - i;
388 		ret = fd_package_recv(sock, fds + i, step);
389 		if (ret != 0) {
390 			/* Close all received descriptors. */
391 			serrno = errno;
392 			for (j = 0; j < i; j++)
393 				close(fds[j]);
394 			errno = serrno;
395 			break;
396 		}
397 		i += step;
398 	}
399 
400 	return (ret);
401 }
402 
403 int
fd_send(int sock,const int * fds,size_t nfds)404 fd_send(int sock, const int *fds, size_t nfds)
405 {
406 	unsigned int i, step;
407 	int ret;
408 
409 	if (nfds == 0 || fds == NULL) {
410 		errno = EINVAL;
411 		return (-1);
412 	}
413 
414 	ret = i = step = 0;
415 	while (i < nfds) {
416 		if (PKG_MAX_SIZE < nfds - i)
417 			step = PKG_MAX_SIZE;
418 		else
419 			step = nfds - i;
420 		ret = fd_package_send(sock, fds + i, step);
421 		if (ret != 0)
422 			break;
423 		i += step;
424 	}
425 
426 	return (ret);
427 }
428 
429 int
buf_send(int sock,void * buf,size_t size)430 buf_send(int sock, void *buf, size_t size)
431 {
432 	ssize_t done;
433 	unsigned char *ptr;
434 
435 	PJDLOG_ASSERT(sock >= 0);
436 	PJDLOG_ASSERT(size > 0);
437 	PJDLOG_ASSERT(buf != NULL);
438 
439 	ptr = buf;
440 	do {
441 		fd_wait(sock, false);
442 		done = send(sock, ptr, size, 0);
443 		if (done == -1) {
444 			if (errno == EINTR)
445 				continue;
446 			return (-1);
447 		} else if (done == 0) {
448 			errno = ENOTCONN;
449 			return (-1);
450 		}
451 		size -= done;
452 		ptr += done;
453 	} while (size > 0);
454 
455 	return (0);
456 }
457 
458 int
buf_recv(int sock,void * buf,size_t size)459 buf_recv(int sock, void *buf, size_t size)
460 {
461 	ssize_t done;
462 	unsigned char *ptr;
463 
464 	PJDLOG_ASSERT(sock >= 0);
465 	PJDLOG_ASSERT(buf != NULL);
466 
467 	ptr = buf;
468 	while (size > 0) {
469 		fd_wait(sock, true);
470 		done = recv(sock, ptr, size, 0);
471 		if (done == -1) {
472 			if (errno == EINTR)
473 				continue;
474 			return (-1);
475 		} else if (done == 0) {
476 			errno = ENOTCONN;
477 			return (-1);
478 		}
479 		size -= done;
480 		ptr += done;
481 	}
482 
483 	return (0);
484 }
485