1 /* $NetBSD: msgio.c,v 1.2 2018/09/08 14:02:15 christos Exp $ */
2
3 /*-
4 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
5 *
6 * Copyright (c) 2013 The FreeBSD Foundation
7 * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org>
8 * All rights reserved.
9 *
10 * This software was developed by Pawel Jakub Dawidek under sponsorship from
11 * the FreeBSD Foundation.
12 *
13 * Redistribution and use in source and binary forms, with or without
14 * modification, are permitted provided that the following conditions
15 * are met:
16 * 1. Redistributions of source code must retain the above copyright
17 * notice, this list of conditions and the following disclaimer.
18 * 2. Redistributions in binary form must reproduce the above copyright
19 * notice, this list of conditions and the following disclaimer in the
20 * documentation and/or other materials provided with the distribution.
21 *
22 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
23 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
25 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
26 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
27 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
28 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
29 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
30 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
31 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
32 * SUCH DAMAGE.
33 */
34
35 #include <sys/cdefs.h>
36 #ifdef __FreeBSD__
37 __FBSDID("$FreeBSD: head/lib/libnv/msgio.c 326219 2017-11-26 02:00:33Z pfg $");
38 #else
39 __RCSID("$NetBSD: msgio.c,v 1.2 2018/09/08 14:02:15 christos Exp $");
40 #endif
41
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/select.h>

#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
53
54 #ifdef HAVE_PJDLOG
55 #include <pjdlog.h>
56 #endif
57
58 #include "common_impl.h"
59 #include "msgio.h"
60
/*
 * Fallback used when pjdlog is not available: assertions are routed to
 * assert(3) (PJDLOG_RASSERT keeps only the expression and drops its message
 * arguments) and PJDLOG_ABORT() discards its arguments and calls abort(3).
 */
#ifndef HAVE_PJDLOG
#include <assert.h>
#define PJDLOG_ASSERT(...) assert(__VA_ARGS__)
#define PJDLOG_RASSERT(expr, ...) assert(expr)
#define PJDLOG_ABORT(...) abort()
#endif
67
/*
 * PKG_MAX_SIZE is the chunk size used by fd_send() and fd_recv(): both split
 * the descriptor array into pieces of at most PKG_MAX_SIZE descriptors and
 * transfer each piece in its own message, so the value is a descriptor count
 * and both ends of a connection have to use the same one.
 *
 * NOTE(review): the Linux definition multiplies by CMSG_SPACE(sizeof(int)),
 * which yields something that looks like a byte count rather than the
 * "64 - 1 descriptors" the comment below suggests -- confirm the intended
 * unit before touching it, because changing the value changes how the two
 * peers chunk their messages.
 */
#ifdef __linux__
/* Linux: arbitrary size, but must be lower than SCM_MAX_FD. */
#define PKG_MAX_SIZE ((64U - 1) * CMSG_SPACE(sizeof(int)))
#else
#define PKG_MAX_SIZE (MCLBYTES / CMSG_SPACE(sizeof(int)) - 1)
#endif
74
/*
 * Turn the control message header `cmsg' into an SCM_RIGHTS message that
 * carries the single descriptor `fd'.
 *
 * Returns 0 on success.  When fd_is_valid() rejects the descriptor, `cmsg'
 * is left untouched, errno is set to EBADF and -1 is returned.
 */
static int
msghdr_add_fd(struct cmsghdr *cmsg, int fd)
{

	PJDLOG_ASSERT(fd >= 0);

	if (fd_is_valid(fd)) {
		cmsg->cmsg_level = SOL_SOCKET;
		cmsg->cmsg_type = SCM_RIGHTS;
		cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
		/* Copy byte-wise; the payload area is accessed as raw bytes. */
		memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
		return (0);
	}

	errno = EBADF;
	return (-1);
}
93
/*
 * Extract the descriptor from a control message that must be an SCM_RIGHTS
 * message holding exactly one descriptor.
 *
 * Returns the descriptor on success.  If `cmsg' is NULL, is not a
 * SOL_SOCKET/SCM_RIGHTS message, or its length is not that of a single int
 * payload, errno is set to EINVAL and -1 is returned.
 */
static int
msghdr_get_fd(struct cmsghdr *cmsg)
{
	int received;

	if (cmsg != NULL && cmsg->cmsg_level == SOL_SOCKET &&
	    cmsg->cmsg_type == SCM_RIGHTS &&
	    cmsg->cmsg_len == CMSG_LEN(sizeof(received))) {
		memcpy(&received, CMSG_DATA(cmsg), sizeof(received));
#ifndef MSG_CMSG_CLOEXEC
		/*
		 * Without MSG_CMSG_CLOEXEC the close-on-exec flag cannot be
		 * set atomically at receive time; set it afterwards so the
		 * result is the same either way.
		 */
		(void) fcntl(received, F_SETFD, FD_CLOEXEC);
#endif
		return (received);
	}

	errno = EINVAL;
	return (-1);
}
118
/*
 * Block until `fd' becomes readable (`doread' true) or writable (`doread'
 * false).
 *
 * poll(2) is used rather than select(2): an fd_set can only describe
 * descriptors below FD_SETSIZE, and FD_SET() on a larger descriptor writes
 * outside the set.  poll(2) has no such limit.
 *
 * The result is intentionally ignored (as it was with select): every caller
 * in this file performs the actual send/receive call right afterwards and
 * handles that call's errors, including EINTR.
 */
static void
fd_wait(int fd, bool doread)
{
	struct pollfd pfd;

	PJDLOG_ASSERT(fd >= 0);

	pfd.fd = fd;
	pfd.events = doread ? POLLIN : POLLOUT;
	pfd.revents = 0;
	/* A negative timeout means "wait without a time limit". */
	(void)poll(&pfd, 1, -1);
}
131
/*
 * Receive one message described by `msg' from `sock', waiting for the socket
 * to become readable first and restarting when recvmsg(2) is interrupted by
 * a signal.  MSG_CMSG_CLOEXEC is requested where the platform defines it.
 *
 * Returns 0 once recvmsg(2) succeeds, or -1 with errno left as set by
 * recvmsg(2) for any failure other than EINTR.
 */
static int
msg_recv(int sock, struct msghdr *msg)
{
#ifdef MSG_CMSG_CLOEXEC
	const int rflags = MSG_CMSG_CLOEXEC;
#else
	const int rflags = 0;
#endif

	PJDLOG_ASSERT(sock >= 0);

	do {
		fd_wait(sock, true);
		if (recvmsg(sock, msg, rflags) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
157
/*
 * Send the message described by `msg' over `sock', waiting for the socket to
 * become writable first and restarting when sendmsg(2) is interrupted by a
 * signal.
 *
 * Returns 0 once sendmsg(2) succeeds, or -1 with errno left as set by
 * sendmsg(2) for any failure other than EINTR.
 */
static int
msg_send(int sock, const struct msghdr *msg)
{

	PJDLOG_ASSERT(sock >= 0);

	do {
		fd_wait(sock, false);
		if (sendmsg(sock, msg, 0) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
176
177 #ifdef __FreeBSD__
178 int
cred_send(int sock)179 cred_send(int sock)
180 {
181 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
182 struct msghdr msg;
183 struct cmsghdr *cmsg;
184 struct iovec iov;
185 uint8_t dummy;
186
187 bzero(credbuf, sizeof(credbuf));
188 bzero(&msg, sizeof(msg));
189 bzero(&iov, sizeof(iov));
190
191 /*
192 * XXX: We send one byte along with the control message, because
193 * setting msg_iov to NULL only works if this is the first
194 * packet send over the socket. Once we send some data we
195 * won't be able to send credentials anymore. This is most
196 * likely a kernel bug.
197 */
198 dummy = 0;
199 iov.iov_base = &dummy;
200 iov.iov_len = sizeof(dummy);
201
202 msg.msg_iov = &iov;
203 msg.msg_iovlen = 1;
204 msg.msg_control = credbuf;
205 msg.msg_controllen = sizeof(credbuf);
206
207 cmsg = CMSG_FIRSTHDR(&msg);
208 cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred));
209 cmsg->cmsg_level = SOL_SOCKET;
210 cmsg->cmsg_type = SCM_CREDS;
211
212 if (msg_send(sock, &msg) == -1)
213 return (-1);
214
215 return (0);
216 }
217
218 int
cred_recv(int sock,struct cmsgcred * cred)219 cred_recv(int sock, struct cmsgcred *cred)
220 {
221 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
222 struct msghdr msg;
223 struct cmsghdr *cmsg;
224 struct iovec iov;
225 uint8_t dummy;
226
227 bzero(credbuf, sizeof(credbuf));
228 bzero(&msg, sizeof(msg));
229 bzero(&iov, sizeof(iov));
230
231 iov.iov_base = &dummy;
232 iov.iov_len = sizeof(dummy);
233
234 msg.msg_iov = &iov;
235 msg.msg_iovlen = 1;
236 msg.msg_control = credbuf;
237 msg.msg_controllen = sizeof(credbuf);
238
239 if (msg_recv(sock, &msg) == -1)
240 return (-1);
241
242 cmsg = CMSG_FIRSTHDR(&msg);
243 if (cmsg == NULL ||
244 cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) ||
245 cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) {
246 errno = EINVAL;
247 return (-1);
248 }
249 bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred));
250
251 return (0);
252 }
253 #endif
254
/*
 * Send the `nfds' descriptors in `fds' over `sock' as a single message: one
 * SCM_RIGHTS control message per descriptor plus one zero data byte.
 *
 * Returns 0 on success.  Returns -1 if the control buffer cannot be
 * allocated, if msghdr_add_fd() rejects a descriptor, or if msg_send()
 * fails.  errno from the failing step is preserved across the cleanup.
 */
static int
fd_package_send(int sock, const int *fds, size_t nfds)
{
	struct msghdr mh;
	struct cmsghdr *hdr;
	struct iovec vec;
	unsigned int idx;
	int saved_errno, rv;
	uint8_t pad;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(fds != NULL);
	PJDLOG_ASSERT(nfds > 0);

	memset(&mh, 0, sizeof(mh));

	/*
	 * XXX: Look into cred_send function for more details.
	 */
	pad = 0;
	vec.iov_base = &pad;
	vec.iov_len = sizeof(pad);
	mh.msg_iov = &vec;
	mh.msg_iovlen = 1;

	/* Zeroed control buffer with room for one header per descriptor. */
	mh.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	mh.msg_control = calloc(1, mh.msg_controllen);
	if (mh.msg_control == NULL)
		return (-1);

	rv = -1;

	idx = 0;
	hdr = CMSG_FIRSTHDR(&mh);
	while (idx < nfds && hdr != NULL) {
		if (msghdr_add_fd(hdr, fds[idx]) == -1)
			goto end;
		idx++;
		hdr = CMSG_NXTHDR(&mh, hdr);
	}

	if (msg_send(sock, &mh) != -1)
		rv = 0;
end:
	/* free(3) may clobber errno; keep the value of the failing call. */
	saved_errno = errno;
	free(mh.msg_control);
	errno = saved_errno;
	return (rv);
}
303
/*
 * Return the number of descriptors carried by the control message `cmsg',
 * or 0 if it is not a SOL_SOCKET/SCM_RIGHTS message or is too short to hold
 * a payload.
 */
static size_t
msghdr_fd_count(struct cmsghdr *cmsg)
{

	if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS ||
	    cmsg->cmsg_len < CMSG_LEN(0))
		return (0);

	return ((cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int));
}

/*
 * Receive one message from `sock' and store the `nfds' descriptors it is
 * expected to carry in `fds'.
 *
 * The sending side emits one SCM_RIGHTS control message per descriptor, but
 * the descriptors are accepted here however they are grouped: a kernel may
 * deliver several descriptors merged into one SCM_RIGHTS control message
 * (NOTE(review): Linux is known to do this -- confirm on the targeted
 * kernels).  Requiring exactly one descriptor per control message would
 * reject such a message and, worse, leave its descriptors open.
 *
 * Returns 0 when exactly `nfds' descriptors were received.  Returns -1 if
 * the control buffer cannot be allocated or msg_recv() fails.  Returns -1
 * with errno set to EINVAL when the message carries a different number of
 * descriptors, a negative descriptor value, or a control message that is
 * not SCM_RIGHTS; in that case every descriptor that was received is closed
 * and the contents of `fds' are unspecified.
 */
static int
fd_package_recv(int sock, int *fds, size_t nfds)
{
	struct msghdr msg;
	struct cmsghdr *cmsg;
	struct iovec iov;
	size_t stored, count, k;
	int serrno, ret, fd;
	bool failed;
	uint8_t dummy;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(nfds > 0);
	PJDLOG_ASSERT(fds != NULL);

	memset(&msg, 0, sizeof(msg));
	memset(&iov, 0, sizeof(iov));

	/*
	 * XXX: Look into cred_send function for more details.
	 */
	iov.iov_base = &dummy;
	iov.iov_len = sizeof(dummy);

	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	ret = -1;

	if (msg_recv(sock, &msg) == -1)
		goto end;

	/*
	 * Walk every control message even after a problem has been noticed:
	 * all received descriptors have to be either handed to the caller or
	 * closed, even if a different control message (eg. SCM_CREDS) sits
	 * in between.
	 */
	stored = 0;
	failed = false;
	for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL;
	    cmsg = CMSG_NXTHDR(&msg, cmsg)) {
		count = msghdr_fd_count(cmsg);
		if (count == 0) {
			/* Not a usable SCM_RIGHTS message. */
			failed = true;
			continue;
		}
		for (k = 0; k < count; k++) {
			memcpy(&fd, CMSG_DATA(cmsg) + k * sizeof(fd),
			    sizeof(fd));
			if (fd < 0) {
				failed = true;
				continue;
			}
#ifndef MSG_CMSG_CLOEXEC
			/*
			 * If the MSG_CMSG_CLOEXEC flag is not available we
			 * cannot set the close-on-exec flag atomically, but
			 * we still want to set it for consistency.
			 */
			(void) fcntl(fd, F_SETFD, FD_CLOEXEC);
#endif
			if (failed || stored == nfds) {
				/* Unexpected or unwanted descriptor. */
				failed = true;
				(void)close(fd);
			} else {
				fds[stored++] = fd;
			}
		}
	}

	if (failed || stored != nfds) {
		/* Do not leak what was already handed over to `fds'. */
		while (stored > 0)
			(void)close(fds[--stored]);
		errno = EINVAL;
		goto end;
	}

	ret = 0;
end:
	/* free(3) may clobber errno; keep the value describing the failure. */
	serrno = errno;
	free(msg.msg_control);
	errno = serrno;
	return (ret);
}
370
371 int
fd_recv(int sock,int * fds,size_t nfds)372 fd_recv(int sock, int *fds, size_t nfds)
373 {
374 unsigned int i, step, j;
375 int ret, serrno;
376
377 if (nfds == 0 || fds == NULL) {
378 errno = EINVAL;
379 return (-1);
380 }
381
382 ret = i = step = 0;
383 while (i < nfds) {
384 if (PKG_MAX_SIZE < nfds - i)
385 step = PKG_MAX_SIZE;
386 else
387 step = nfds - i;
388 ret = fd_package_recv(sock, fds + i, step);
389 if (ret != 0) {
390 /* Close all received descriptors. */
391 serrno = errno;
392 for (j = 0; j < i; j++)
393 close(fds[j]);
394 errno = serrno;
395 break;
396 }
397 i += step;
398 }
399
400 return (ret);
401 }
402
403 int
fd_send(int sock,const int * fds,size_t nfds)404 fd_send(int sock, const int *fds, size_t nfds)
405 {
406 unsigned int i, step;
407 int ret;
408
409 if (nfds == 0 || fds == NULL) {
410 errno = EINVAL;
411 return (-1);
412 }
413
414 ret = i = step = 0;
415 while (i < nfds) {
416 if (PKG_MAX_SIZE < nfds - i)
417 step = PKG_MAX_SIZE;
418 else
419 step = nfds - i;
420 ret = fd_package_send(sock, fds + i, step);
421 if (ret != 0)
422 break;
423 i += step;
424 }
425
426 return (ret);
427 }
428
429 int
buf_send(int sock,void * buf,size_t size)430 buf_send(int sock, void *buf, size_t size)
431 {
432 ssize_t done;
433 unsigned char *ptr;
434
435 PJDLOG_ASSERT(sock >= 0);
436 PJDLOG_ASSERT(size > 0);
437 PJDLOG_ASSERT(buf != NULL);
438
439 ptr = buf;
440 do {
441 fd_wait(sock, false);
442 done = send(sock, ptr, size, 0);
443 if (done == -1) {
444 if (errno == EINTR)
445 continue;
446 return (-1);
447 } else if (done == 0) {
448 errno = ENOTCONN;
449 return (-1);
450 }
451 size -= done;
452 ptr += done;
453 } while (size > 0);
454
455 return (0);
456 }
457
458 int
buf_recv(int sock,void * buf,size_t size)459 buf_recv(int sock, void *buf, size_t size)
460 {
461 ssize_t done;
462 unsigned char *ptr;
463
464 PJDLOG_ASSERT(sock >= 0);
465 PJDLOG_ASSERT(buf != NULL);
466
467 ptr = buf;
468 while (size > 0) {
469 fd_wait(sock, true);
470 done = recv(sock, ptr, size, 0);
471 if (done == -1) {
472 if (errno == EINTR)
473 continue;
474 return (-1);
475 } else if (done == 0) {
476 errno = ENOTCONN;
477 return (-1);
478 }
479 size -= done;
480 ptr += done;
481 }
482
483 return (0);
484 }
485