1 // Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
2 // Licensed under the MIT License:
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a copy
5 // of this software and associated documentation files (the "Software"), to deal
6 // in the Software without restriction, including without limitation the rights
7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 // copies of the Software, and to permit persons to whom the Software is
9 // furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 // THE SOFTWARE.
21
22 #if !_WIN32
23 // For Win32 implementation, see async-io-win32.c++.
24
25 #ifndef _GNU_SOURCE
26 #define _GNU_SOURCE
27 #endif
28
29 #include "async-io.h"
30 #include "async-io-internal.h"
31 #include "async-unix.h"
32 #include "debug.h"
33 #include "thread.h"
34 #include "io.h"
35 #include "miniposix.h"
36 #include <unistd.h>
37 #include <sys/uio.h>
38 #include <errno.h>
39 #include <fcntl.h>
40 #include <sys/types.h>
41 #include <sys/socket.h>
42 #include <sys/un.h>
43 #include <netinet/in.h>
44 #include <netinet/tcp.h>
45 #include <stddef.h>
46 #include <stdlib.h>
47 #include <arpa/inet.h>
48 #include <netdb.h>
49 #include <set>
50 #include <poll.h>
51 #include <limits.h>
52 #include <sys/ioctl.h>
53
54 #if !defined(SO_PEERCRED) && defined(LOCAL_PEERCRED)
55 #include <sys/ucred.h>
56 #endif
57
58 namespace kj {
59
60 namespace {
61
setNonblocking(int fd)62 void setNonblocking(int fd) {
63 #ifdef FIONBIO
64 int opt = 1;
65 KJ_SYSCALL(ioctl(fd, FIONBIO, &opt));
66 #else
67 int flags;
68 KJ_SYSCALL(flags = fcntl(fd, F_GETFL));
69 if ((flags & O_NONBLOCK) == 0) {
70 KJ_SYSCALL(fcntl(fd, F_SETFL, flags | O_NONBLOCK));
71 }
72 #endif
73 }
74
setCloseOnExec(int fd)75 void setCloseOnExec(int fd) {
76 #ifdef FIOCLEX
77 KJ_SYSCALL(ioctl(fd, FIOCLEX));
78 #else
79 int flags;
80 KJ_SYSCALL(flags = fcntl(fd, F_GETFD));
81 if ((flags & FD_CLOEXEC) == 0) {
82 KJ_SYSCALL(fcntl(fd, F_SETFD, flags | FD_CLOEXEC));
83 }
84 #endif
85 }
86
static constexpr uint NEW_FD_FLAGS =
#if __linux__ && !__BIONIC__
    LowLevelAsyncIoProvider::ALREADY_CLOEXEC | LowLevelAsyncIoProvider::ALREADY_NONBLOCK |
#endif
    LowLevelAsyncIoProvider::TAKE_OWNERSHIP;
// Default flags for FDs created by this file. We always try to open FDs with CLOEXEC and NONBLOCK
// already set on Linux (via SOCK_CLOEXEC / SOCK_NONBLOCK and friends), but on other platforms
// this is not possible, so OwnedFileDescriptor must set those flags itself after the fact.
94
95 class OwnedFileDescriptor {
96 public:
OwnedFileDescriptor(int fd,uint flags)97 OwnedFileDescriptor(int fd, uint flags): fd(fd), flags(flags) {
98 if (flags & LowLevelAsyncIoProvider::ALREADY_NONBLOCK) {
99 KJ_DREQUIRE(fcntl(fd, F_GETFL) & O_NONBLOCK, "You claimed you set NONBLOCK, but you didn't.");
100 } else {
101 setNonblocking(fd);
102 }
103
104 if (flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) {
105 if (flags & LowLevelAsyncIoProvider::ALREADY_CLOEXEC) {
106 KJ_DREQUIRE(fcntl(fd, F_GETFD) & FD_CLOEXEC,
107 "You claimed you set CLOEXEC, but you didn't.");
108 } else {
109 setCloseOnExec(fd);
110 }
111 }
112 }
113
~OwnedFileDescriptor()114 ~OwnedFileDescriptor() noexcept(false) {
115 // Don't use SYSCALL() here because close() should not be repeated on EINTR.
116 if ((flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) && close(fd) < 0) {
117 KJ_FAIL_SYSCALL("close", errno, fd) {
118 // Recoverable exceptions are safe in destructors.
119 break;
120 }
121 }
122 }
123
124 protected:
125 const int fd;
126
127 private:
128 uint flags;
129 };
130
131 // =======================================================================================
132
class AsyncStreamFd: public OwnedFileDescriptor, public AsyncCapabilityStream {
  // AsyncCapabilityStream implementation over a non-blocking Unix socket FD. Readiness is
  // obtained via UnixEventPort::FdObserver (edge-triggered). In addition to byte streams, this
  // class supports passing file descriptors across the socket using SCM_RIGHTS ancillary
  // messages, with careful handling of truncation to avoid leaking FDs.

public:
  AsyncStreamFd(UnixEventPort& eventPort, int fd, uint flags)
      : OwnedFileDescriptor(fd, flags),
        eventPort(eventPort),
        observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ_WRITE) {}
  virtual ~AsyncStreamFd() noexcept(false) {}

  Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    // Plain byte read: delegate to the FD-aware path with no FD buffer, returning only the
    // byte count from the combined result.
    return tryReadInternal(buffer, minBytes, maxBytes, nullptr, 0, {0,0})
        .then([](ReadResult r) { return r.byteCount; });
  }

  Promise<ReadResult> tryReadWithFds(void* buffer, size_t minBytes, size_t maxBytes,
                                     AutoCloseFd* fdBuffer, size_t maxFds) override {
    // Read bytes plus up to `maxFds` file descriptors delivered via SCM_RIGHTS.
    return tryReadInternal(buffer, minBytes, maxBytes, fdBuffer, maxFds, {0,0});
  }

  Promise<ReadResult> tryReadWithStreams(
      void* buffer, size_t minBytes, size_t maxBytes,
      Own<AsyncCapabilityStream>* streamBuffer, size_t maxStreams) override {
    // Like tryReadWithFds(), but wraps each received FD in a new AsyncStreamFd before handing
    // it to the caller.
    auto fdBuffer = kj::heapArray<AutoCloseFd>(maxStreams);
    auto promise = tryReadInternal(buffer, minBytes, maxBytes, fdBuffer.begin(), maxStreams, {0,0});

    return promise.then([this, fdBuffer = kj::mv(fdBuffer), streamBuffer]
                        (ReadResult result) mutable {
      for (auto i: kj::zeroTo(result.capCount)) {
        // Received FDs arrive with CLOEXEC already set (either via MSG_CMSG_CLOEXEC or set
        // explicitly in tryReadInternal), hence ALREADY_CLOEXEC here.
        streamBuffer[i] = kj::heap<AsyncStreamFd>(eventPort, fdBuffer[i].release(),
            LowLevelAsyncIoProvider::TAKE_OWNERSHIP | LowLevelAsyncIoProvider::ALREADY_CLOEXEC);
      }
      return result;
    });
  }

  Promise<void> write(const void* buffer, size_t size) override {
    // Write `size` bytes, waiting for writability and retrying as needed until all bytes are
    // consumed by the kernel.
    ssize_t n;
    KJ_NONBLOCKING_SYSCALL(n = ::write(fd, buffer, size)) {
      // Error.

      // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to
      // a bug that exists in both Clang and GCC:
      //   http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799
      //   http://llvm.org/bugs/show_bug.cgi?id=12286
      goto error;
    }
    if (false) {
    error:
      return kj::READY_NOW;
    }

    if (n < 0) {
      // EAGAIN -- need to wait for writability and try again.
      return observer.whenBecomesWritable().then([=]() {
        return write(buffer, size);
      });
    } else if (n == size) {
      // All done.
      return READY_NOW;
    } else {
      // Fewer than `size` bytes were written, but we CANNOT assume we're out of buffer space, as
      // Linux is known to return partial reads/writes when interrupted by a signal -- yes, even
      // for non-blocking operations. So, we'll need to write() again now, even though it will
      // almost certainly fail with EAGAIN. See comments in the read path for more info.
      buffer = reinterpret_cast<const byte*>(buffer) + n;
      size -= n;
      return write(buffer, size);
    }
  }

  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    // Gather-write of multiple pieces; normalizes the "no pieces" case so writeInternal()
    // always has a well-defined first piece.
    if (pieces.size() == 0) {
      return writeInternal(nullptr, nullptr, nullptr);
    } else {
      return writeInternal(pieces[0], pieces.slice(1, pieces.size()), nullptr);
    }
  }

  Promise<void> writeWithFds(ArrayPtr<const byte> data,
                             ArrayPtr<const ArrayPtr<const byte>> moreData,
                             ArrayPtr<const int> fds) override {
    // Write bytes along with file descriptors sent via SCM_RIGHTS.
    return writeInternal(data, moreData, fds);
  }

  Promise<void> writeWithStreams(ArrayPtr<const byte> data,
                                 ArrayPtr<const ArrayPtr<const byte>> moreData,
                                 Array<Own<AsyncCapabilityStream>> streams) override {
    // Extract the raw FD from each stream (which must be an AsyncStreamFd), send them, and keep
    // the streams alive until the write completes by attaching them to the promise.
    auto fds = KJ_MAP(stream, streams) {
      return downcast<AsyncStreamFd>(*stream).fd;
    };
    auto promise = writeInternal(data, moreData, fds);
    return promise.attach(kj::mv(fds), kj::mv(streams));
  }

  Promise<void> whenWriteDisconnected() override {
    // Lazily create and cache a forked promise so multiple callers can all wait on the same
    // underlying observer event.
    KJ_IF_MAYBE(p, writeDisconnectedPromise) {
      return p->addBranch();
    } else {
      auto fork = observer.whenWriteDisconnected().fork();
      auto result = fork.addBranch();
      writeDisconnectedPromise = kj::mv(fork);
      return kj::mv(result);
    }
  }

  void shutdownWrite() override {
    // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
    // UnixAsyncIoProvider interface.
    KJ_SYSCALL(shutdown(fd, SHUT_WR));
  }

  void abortRead() override {
    // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
    // UnixAsyncIoProvider interface.
    KJ_SYSCALL(shutdown(fd, SHUT_RD));
  }

  void getsockopt(int level, int option, void* value, uint* length) override {
    // Thin wrapper over getsockopt(2); translates between uint and socklen_t for the caller.
    socklen_t socklen = *length;
    KJ_SYSCALL(::getsockopt(fd, level, option, value, &socklen));
    *length = socklen;
  }

  void setsockopt(int level, int option, const void* value, uint length) override {
    // Thin wrapper over setsockopt(2).
    KJ_SYSCALL(::setsockopt(fd, level, option, value, length));
  }

  void getsockname(struct sockaddr* addr, uint* length) override {
    // Thin wrapper over getsockname(2); translates between uint and socklen_t.
    socklen_t socklen = *length;
    KJ_SYSCALL(::getsockname(fd, addr, &socklen));
    *length = socklen;
  }

  void getpeername(struct sockaddr* addr, uint* length) override {
    // Thin wrapper over getpeername(2); translates between uint and socklen_t.
    socklen_t socklen = *length;
    KJ_SYSCALL(::getpeername(fd, addr, &socklen));
    *length = socklen;
  }

  kj::Maybe<int> getFd() const override {
    return fd;
  }

  void registerAncillaryMessageHandler(
      kj::Function<void(kj::ArrayPtr<AncillaryMessage>)> fn) override {
    // Once registered, reads go through recvmsg() so non-SCM_RIGHTS ancillary messages can be
    // surfaced to this callback.
    ancillaryMsgCallback = kj::mv(fn);
  }

  Promise<void> waitConnected() {
    // Wait until initial connection has completed. This actually just waits until it is writable.

    // Can't just go directly to writeObserver.whenBecomesWritable() because of edge triggering. We
    // need to explicitly check if the socket is already connected.

    struct pollfd pollfd;
    memset(&pollfd, 0, sizeof(pollfd));
    pollfd.fd = fd;
    pollfd.events = POLLOUT;

    // Zero timeout: this poll() only samples current state; it never blocks.
    int pollResult;
    KJ_SYSCALL(pollResult = poll(&pollfd, 1, 0));

    if (pollResult == 0) {
      // Not ready yet. We can safely use the edge-triggered observer.
      return observer.whenBecomesWritable();
    } else {
      // Ready now.
      return kj::READY_NOW;
    }
  }

private:
  UnixEventPort& eventPort;
  UnixEventPort::FdObserver observer;
  Maybe<ForkedPromise<void>> writeDisconnectedPromise;
  // Cached fork of the observer's write-disconnected promise; see whenWriteDisconnected().
  Maybe<Function<void(ArrayPtr<AncillaryMessage>)>> ancillaryMsgCallback;
  // If set, reads use recvmsg() and deliver non-SCM_RIGHTS ancillary messages here.

  Promise<ReadResult> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes,
                                      AutoCloseFd* fdBuffer, size_t maxFds,
                                      ReadResult alreadyRead) {
    // `alreadyRead` is the number of bytes we have already received via previous reads -- minBytes,
    // maxBytes, and buffer have already been adjusted to account for them, but this count must
    // be included in the final return value.

    ssize_t n;
    if (maxFds == 0 && ancillaryMsgCallback == nullptr) {
      // Fast path: no ancillary data expected, plain read() suffices.
      KJ_NONBLOCKING_SYSCALL(n = ::read(fd, buffer, maxBytes)) {
        // Error.

        // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to
        // a bug that exists in both Clang and GCC:
        //   http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799
        //   http://llvm.org/bugs/show_bug.cgi?id=12286
        goto error;
      }
    } else {
      // Slow path: use recvmsg() so we can receive SCM_RIGHTS FDs and/or other ancillary
      // messages along with the data.
      struct msghdr msg;
      memset(&msg, 0, sizeof(msg));

      struct iovec iov;
      memset(&iov, 0, sizeof(iov));
      iov.iov_base = buffer;
      iov.iov_len = maxBytes;
      msg.msg_iov = &iov;
      msg.msg_iovlen = 1;

      // Allocate space to receive a cmsg.
      size_t msgBytes;
      if (ancillaryMsgCallback == nullptr) {
#if __APPLE__ || __FreeBSD__
        // Until very recently (late 2018 / early 2019), FreeBSD suffered from a bug in which when
        // an SCM_RIGHTS message was truncated on delivery, it would not close the FDs that weren't
        // delivered -- they would simply leak: https://bugs.freebsd.org/131876
        //
        // My testing indicates that MacOS has this same bug as of today (April 2019). I don't know
        // if they plan to fix it or are even aware of it.
        //
        // To handle both cases, we will always provide space to receive 512 FDs. Hopefully, this is
        // greater than the maximum number of FDs that these kernels will transmit in one message
        // PLUS enough space for any other ancillary messages that could be sent before the
        // SCM_RIGHTS message to push it back in the buffer. I couldn't find any firm documentation
        // on these limits, though -- I only know that Linux is limited to 253, and I saw a hint in
        // a comment in someone else's application that suggested FreeBSD is the same. Hopefully,
        // then, this is sufficient to prevent attacks. But if not, there's nothing more we can do;
        // it's really up to the kernel to fix this.
        msgBytes = CMSG_SPACE(sizeof(int) * 512);
#else
        msgBytes = CMSG_SPACE(sizeof(int) * maxFds);
#endif
      } else {
        // If we want room for ancillary messages instead of or in addition to FDs, just use the
        // same amount of cushion as in the MacOS/FreeBSD case above.
        // Someday we may want to allow customization here, but there's no immediate use for it.
        msgBytes = CMSG_SPACE(sizeof(int) * 512);
      }

      // On Linux, CMSG_SPACE will align to a word-size boundary, but on Mac it always aligns to a
      // 32-bit boundary. I guess aligning to 32 bits helps avoid the problem where you
      // surprisingly end up with space for two file descriptors when you only wanted one. However,
      // cmsghdr's preferred alignment is word-size (it contains a size_t). If we stack-allocate
      // the buffer, we need to make sure it is aligned properly (maybe not on x64, but maybe on
      // other platforms), so we want to allocate an array of words (we use void*). So... we use
      // CMSG_SPACE() and then additionally round up to deal with Mac.
      size_t msgWords = (msgBytes + sizeof(void*) - 1) / sizeof(void*);
      KJ_STACK_ARRAY(void*, cmsgSpace, msgWords, 16, 256);
      auto cmsgBytes = cmsgSpace.asBytes();
      memset(cmsgBytes.begin(), 0, cmsgBytes.size());
      msg.msg_control = cmsgBytes.begin();
      msg.msg_controllen = msgBytes;

#ifdef MSG_CMSG_CLOEXEC
      static constexpr int RECVMSG_FLAGS = MSG_CMSG_CLOEXEC;
#else
      static constexpr int RECVMSG_FLAGS = 0;
#endif

      KJ_NONBLOCKING_SYSCALL(n = ::recvmsg(fd, &msg, RECVMSG_FLAGS)) {
        // Error.

        // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to
        // a bug that exists in both Clang and GCC:
        //   http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799
        //   http://llvm.org/bugs/show_bug.cgi?id=12286
        goto error;
      }

      if (n >= 0) {
        // Process all messages.
        //
        // WARNING DANGER: We have to be VERY careful not to miss a file descriptor here, because
        // if we do, then that FD will never be closed, and a malicious peer could exploit this to
        // fill up our FD table, creating a DoS attack. Some things to keep in mind:
        // - CMSG_SPACE() could have rounded up the space for alignment purposes, and this could
        //   mean we permitted the kernel to deliver more file descriptors than `maxFds`. We need
        //   to close the extras.
        // - We can receive multiple ancillary messages at once. In particular, there is also
        //   SCM_CREDENTIALS. The sender decides what to send. They could send SCM_CREDENTIALS
        //   first followed by SCM_RIGHTS. We need to make sure we see both.
        size_t nfds = 0;
        size_t spaceLeft = msg.msg_controllen;
        Vector<AncillaryMessage> ancillaryMessages;
        for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
             cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
          if (spaceLeft >= CMSG_LEN(0) &&
              cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
            // Some operating systems (like MacOS) do not adjust csmg_len when the message is
            // truncated. We must do so ourselves or risk overrunning the buffer.
            auto len = kj::min(cmsg->cmsg_len, spaceLeft);
            auto data = arrayPtr(reinterpret_cast<int*>(CMSG_DATA(cmsg)),
                                 (len - CMSG_LEN(0)) / sizeof(int));
            // Every received FD is immediately wrapped in AutoCloseFd so that FDs beyond
            // `maxFds` get closed (via `trashFds`) rather than leaked.
            kj::Vector<kj::AutoCloseFd> trashFds;
            for (auto fd: data) {
              kj::AutoCloseFd ownFd(fd);
              if (nfds < maxFds) {
                fdBuffer[nfds++] = kj::mv(ownFd);
              } else {
                trashFds.add(kj::mv(ownFd));
              }
            }
          } else if (spaceLeft >= CMSG_LEN(0) && ancillaryMsgCallback != nullptr) {
            auto len = kj::min(cmsg->cmsg_len, spaceLeft);
            auto data = ArrayPtr<const byte>(CMSG_DATA(cmsg), len - CMSG_LEN(0));
            ancillaryMessages.add(cmsg->cmsg_level, cmsg->cmsg_type, data);
          }

          // Track remaining control buffer space, clamping to zero on truncated messages.
          if (spaceLeft >= CMSG_LEN(0) && spaceLeft >= cmsg->cmsg_len) {
            spaceLeft -= cmsg->cmsg_len;
          } else {
            spaceLeft = 0;
          }
        }

#ifndef MSG_CMSG_CLOEXEC
        // Kernel couldn't set CLOEXEC atomically at receive time; do it now for each kept FD.
        for (size_t i = 0; i < nfds; i++) {
          setCloseOnExec(fdBuffer[i]);
        }
#endif

        if (ancillaryMessages.size() > 0) {
          KJ_IF_MAYBE(fn, ancillaryMsgCallback) {
            (*fn)(ancillaryMessages.asPtr());
          }
        }

        alreadyRead.capCount += nfds;
        fdBuffer += nfds;
        maxFds -= nfds;
      }
    }

    if (false) {
    error:
      return alreadyRead;
    }

    if (n < 0) {
      // Read would block.
      return observer.whenBecomesReadable().then([=]() {
        return tryReadInternal(buffer, minBytes, maxBytes, fdBuffer, maxFds, alreadyRead);
      });
    } else if (n == 0) {
      // EOF -OR- maxBytes == 0.
      return alreadyRead;
    } else if (implicitCast<size_t>(n) >= minBytes) {
      // We read enough to stop here.
      alreadyRead.byteCount += n;
      return alreadyRead;
    } else {
      // The kernel returned fewer bytes than we asked for (and fewer than we need).

      buffer = reinterpret_cast<byte*>(buffer) + n;
      minBytes -= n;
      maxBytes -= n;
      alreadyRead.byteCount += n;

      // According to David Klempner, who works on Stubby at Google, we sadly CANNOT assume that
      // we've consumed the whole read buffer here. If a signal is delivered in the middle of a
      // read() -- yes, even a non-blocking read -- it can cause the kernel to return a partial
      // result, with data still in the buffer.
      //   https://bugzilla.kernel.org/show_bug.cgi?id=199131
      //   https://twitter.com/CaptainSegfault/status/1112622245531144194
      //
      // Unfortunately, we have no choice but to issue more read()s until it either tells us EOF
      // or EAGAIN. We used to have an optimization here using observer.atEndHint() (when it is
      // non-null) to avoid a redundant call to read(). Alas...
      return tryReadInternal(buffer, minBytes, maxBytes, fdBuffer, maxFds, alreadyRead);
    }
  }

  Promise<void> writeInternal(ArrayPtr<const byte> firstPiece,
                              ArrayPtr<const ArrayPtr<const byte>> morePieces,
                              ArrayPtr<const int> fds) {
    // Gather-write `firstPiece` + `morePieces`, sending `fds` (if any) via SCM_RIGHTS on the
    // first successful sendmsg(). Recurses (via promise) until everything is written.
    const size_t iovmax = kj::miniposix::iovMax();
    // If there are more than IOV_MAX pieces, we'll only write the first IOV_MAX for now, and
    // then we'll loop later.
    KJ_STACK_ARRAY(struct iovec, iov, kj::min(1 + morePieces.size(), iovmax), 16, 128);
    size_t iovTotal = 0;

    // writev() interface is not const-correct.  :(
    iov[0].iov_base = const_cast<byte*>(firstPiece.begin());
    iov[0].iov_len = firstPiece.size();
    iovTotal += iov[0].iov_len;
    for (uint i = 1; i < iov.size(); i++) {
      iov[i].iov_base = const_cast<byte*>(morePieces[i - 1].begin());
      iov[i].iov_len = morePieces[i - 1].size();
      iovTotal += iov[i].iov_len;
    }

    if (iovTotal == 0) {
      // Nothing to write; FDs cannot be sent without at least one byte of data.
      KJ_REQUIRE(fds.size() == 0, "can't write FDs without bytes");
      return kj::READY_NOW;
    }

    ssize_t n;
    if (fds.size() == 0) {
      KJ_NONBLOCKING_SYSCALL(n = ::writev(fd, iov.begin(), iov.size()), iovTotal, iov.size()) {
        // Error.

        // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to
        // a bug that exists in both Clang and GCC:
        //   http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799
        //   http://llvm.org/bugs/show_bug.cgi?id=12286
        goto error;
      }
    } else {
      // We have FDs to send: build a msghdr with an SCM_RIGHTS control message and use sendmsg().
      struct msghdr msg;
      memset(&msg, 0, sizeof(msg));
      msg.msg_iov = iov.begin();
      msg.msg_iovlen = iov.size();

      // Allocate space to send a cmsg.
      size_t msgBytes = CMSG_SPACE(sizeof(int) * fds.size());
      // On Linux, CMSG_SPACE will align to a word-size boundary, but on Mac it always aligns to a
      // 32-bit boundary. I guess aligning to 32 bits helps avoid the problem where you
      // surprisingly end up with space for two file descriptors when you only wanted one. However,
      // cmsghdr's preferred alignment is word-size (it contains a size_t). If we stack-allocate
      // the buffer, we need to make sure it is aligned properly (maybe not on x64, but maybe on
      // other platforms), so we want to allocate an array of words (we use void*). So... we use
      // CMSG_SPACE() and then additionally round up to deal with Mac.
      size_t msgWords = (msgBytes + sizeof(void*) - 1) / sizeof(void*);
      KJ_STACK_ARRAY(void*, cmsgSpace, msgWords, 16, 256);
      auto cmsgBytes = cmsgSpace.asBytes();
      memset(cmsgBytes.begin(), 0, cmsgBytes.size());
      msg.msg_control = cmsgBytes.begin();
      msg.msg_controllen = msgBytes;

      struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
      cmsg->cmsg_level = SOL_SOCKET;
      cmsg->cmsg_type = SCM_RIGHTS;
      cmsg->cmsg_len = CMSG_LEN(sizeof(int) * fds.size());
      memcpy(CMSG_DATA(cmsg), fds.begin(), fds.asBytes().size());

      KJ_NONBLOCKING_SYSCALL(n = ::sendmsg(fd, &msg, 0)) {
        // Error.

        // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to
        // a bug that exists in both Clang and GCC:
        //   http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799
        //   http://llvm.org/bugs/show_bug.cgi?id=12286
        goto error;
      }
    }

    if (false) {
    error:
      return kj::READY_NOW;
    }

    if (n < 0) {
      // Got EAGAIN. Nothing was written.
      return observer.whenBecomesWritable().then([=]() {
        return writeInternal(firstPiece, morePieces, fds);
      });
    } else if (n == 0) {
      // Why would a sendmsg() with a non-empty message ever return 0 when writing to a stream
      // socket? If there's no room in the send buffer, it should fail with EAGAIN. If the
      // connection is closed, it should fail with EPIPE. Various documents and forum posts around
      // the internet claim this can happen but no one seems to know when. My guess is it can only
      // happen if we try to send an empty message -- which we didn't. So I think this is
      // impossible. If it is possible, we need to figure out how to correctly handle it, which
      // depends on what caused it.
      //
      // Note in particular that if 0 is a valid return here, and we sent an SCM_RIGHTS message,
      // we need to know whether the message was sent or not, in order to decide whether to retry
      // sending it!
      KJ_FAIL_ASSERT("non-empty sendmsg() returned 0");
    }

    // Non-zero bytes were written. This also implies that *all* FDs were written.

    // Discard all data that was written, then issue a new write for what's left (if any).
    for (;;) {
      if (n < firstPiece.size()) {
        // Only part of the first piece was consumed. Wait for buffer space and then write again.
        firstPiece = firstPiece.slice(n, firstPiece.size());
        iovTotal -= n;

        if (iovTotal == 0) {
          // Oops, what actually happened is that we hit the IOV_MAX limit. Don't wait.
          return writeInternal(firstPiece, morePieces, nullptr);
        }

        // As with read(), we cannot assume that a short write() really means the write buffer is
        // full (see comments in the read path above). We have to write again.
        // Note: `fds` is not passed on the retry -- they were fully transmitted along with the
        // first non-zero sendmsg() above.
        return writeInternal(firstPiece, morePieces, nullptr);
      } else if (morePieces.size() == 0) {
        // First piece was fully-consumed and there are no more pieces, so we're done.
        KJ_DASSERT(n == firstPiece.size(), n);
        return READY_NOW;
      } else {
        // First piece was fully consumed, so move on to the next piece.
        n -= firstPiece.size();
        iovTotal -= firstPiece.size();
        firstPiece = morePieces[0];
        morePieces = morePieces.slice(1, morePieces.size());
      }
    }
  }
};
631
632 // =======================================================================================
633
634 class SocketAddress {
635 public:
SocketAddress(const void * sockaddr,uint len)636 SocketAddress(const void* sockaddr, uint len): addrlen(len) {
637 KJ_REQUIRE(len <= sizeof(addr), "Sorry, your sockaddr is too big for me.");
638 memcpy(&addr.generic, sockaddr, len);
639 }
640
operator <(const SocketAddress & other) const641 bool operator<(const SocketAddress& other) const {
642 // So we can use std::set<SocketAddress>... see DNS lookup code.
643
644 if (wildcard < other.wildcard) return true;
645 if (wildcard > other.wildcard) return false;
646
647 if (addrlen < other.addrlen) return true;
648 if (addrlen > other.addrlen) return false;
649
650 return memcmp(&addr.generic, &other.addr.generic, addrlen) < 0;
651 }
652
  const struct sockaddr* getRaw() const { return &addr.generic; }
  // Raw sockaddr pointer suitable for passing to syscalls such as bind()/connect().
  socklen_t getRawSize() const { return addrlen; }
  // Length, in bytes, of the sockaddr returned by getRaw().
655
socket(int type) const656 int socket(int type) const {
657 bool isStream = type == SOCK_STREAM;
658
659 int result;
660 #if __linux__ && !__BIONIC__
661 type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
662 #endif
663 KJ_SYSCALL(result = ::socket(addr.generic.sa_family, type, 0));
664
665 if (isStream && (addr.generic.sa_family == AF_INET ||
666 addr.generic.sa_family == AF_INET6)) {
667 // TODO(perf): As a hack for the 0.4 release we are always setting
668 // TCP_NODELAY because Nagle's algorithm pretty much kills Cap'n Proto's
669 // RPC protocol. Later, we should extend the interface to provide more
670 // control over this. Perhaps write() should have a flag which
671 // specifies whether to pass MSG_MORE.
672 int one = 1;
673 KJ_SYSCALL(setsockopt(
674 result, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one)));
675 }
676
677 return result;
678 }
679
bind(int sockfd) const680 void bind(int sockfd) const {
681 #if !defined(__OpenBSD__)
682 if (wildcard) {
683 // Disable IPV6_V6ONLY because we want to handle both ipv4 and ipv6 on this socket. (The
684 // default value of this option varies across platforms.)
685 int value = 0;
686 KJ_SYSCALL(setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY, &value, sizeof(value)));
687 }
688 #endif
689
690 KJ_SYSCALL(::bind(sockfd, &addr.generic, addrlen), toString());
691 }
692
getPort() const693 uint getPort() const {
694 switch (addr.generic.sa_family) {
695 case AF_INET: return ntohs(addr.inet4.sin_port);
696 case AF_INET6: return ntohs(addr.inet6.sin6_port);
697 default: return 0;
698 }
699 }
700
  String toString() const {
    // Renders the address in a human-readable form: "*:port" for wildcards, "ip:port" for IPv4,
    // "[ip]:port" for IPv6, "unix:path" / "unix-abstract:name" for Unix domain sockets.
    if (wildcard) {
      return str("*:", getPort());
    }

    switch (addr.generic.sa_family) {
      case AF_INET: {
        char buffer[INET6_ADDRSTRLEN];
        if (inet_ntop(addr.inet4.sin_family, &addr.inet4.sin_addr,
                      buffer, sizeof(buffer)) == nullptr) {
          // Recoverable: report the failure in the returned string rather than throwing.
          KJ_FAIL_SYSCALL("inet_ntop", errno) { break; }
          return heapString("(inet_ntop error)");
        }
        return str(buffer, ':', ntohs(addr.inet4.sin_port));
      }
      case AF_INET6: {
        char buffer[INET6_ADDRSTRLEN];
        if (inet_ntop(addr.inet6.sin6_family, &addr.inet6.sin6_addr,
                      buffer, sizeof(buffer)) == nullptr) {
          // Recoverable: report the failure in the returned string rather than throwing.
          KJ_FAIL_SYSCALL("inet_ntop", errno) { break; }
          return heapString("(inet_ntop error)");
        }
        return str('[', buffer, "]:", ntohs(addr.inet6.sin6_port));
      }
      case AF_UNIX: {
        auto path = _::safeUnixPath(&addr.unixDomain, addrlen);
        // A leading NUL byte marks a Linux abstract-namespace socket.
        if (path.size() > 0 && path[0] == '\0') {
          return str("unix-abstract:", path.slice(1, path.size()));
        } else {
          return str("unix:", path);
        }
      }
      default:
        return str("(unknown address family ", addr.generic.sa_family, ")");
    }
  }
737
738 static Promise<Array<SocketAddress>> lookupHost(
739 LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
740 _::NetworkFilter& filter);
741 // Perform a DNS lookup.
742
parse(LowLevelAsyncIoProvider & lowLevel,StringPtr str,uint portHint,_::NetworkFilter & filter)743 static Promise<Array<SocketAddress>> parse(
744 LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) {
745 // TODO(someday): Allow commas in `str`.
746
747 SocketAddress result;
748
749 if (str.startsWith("unix:")) {
750 StringPtr path = str.slice(strlen("unix:"));
751 KJ_REQUIRE(path.size() < sizeof(addr.unixDomain.sun_path),
752 "Unix domain socket address is too long.", str);
753 KJ_REQUIRE(path.size() == strlen(path.cStr()),
754 "Unix domain socket address contains NULL. Use"
755 " 'unix-abstract:' for the abstract namespace.");
756 result.addr.unixDomain.sun_family = AF_UNIX;
757 strcpy(result.addr.unixDomain.sun_path, path.cStr());
758 result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1;
759
760 if (!result.parseAllowedBy(filter)) {
761 KJ_FAIL_REQUIRE("unix sockets blocked by restrictPeers()");
762 return Array<SocketAddress>();
763 }
764
765 auto array = kj::heapArrayBuilder<SocketAddress>(1);
766 array.add(result);
767 return array.finish();
768 }
769
770 if (str.startsWith("unix-abstract:")) {
771 StringPtr path = str.slice(strlen("unix-abstract:"));
772 KJ_REQUIRE(path.size() + 1 < sizeof(addr.unixDomain.sun_path),
773 "Unix domain socket address is too long.", str);
774 result.addr.unixDomain.sun_family = AF_UNIX;
775 result.addr.unixDomain.sun_path[0] = '\0';
776 // although not strictly required by Linux, also copy the trailing
777 // NULL terminator so that we can safely read it back in toString
778 memcpy(result.addr.unixDomain.sun_path + 1, path.cStr(), path.size() + 1);
779 result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1;
780
781 if (!result.parseAllowedBy(filter)) {
782 KJ_FAIL_REQUIRE("abstract unix sockets blocked by restrictPeers()");
783 return Array<SocketAddress>();
784 }
785
786 auto array = kj::heapArrayBuilder<SocketAddress>(1);
787 array.add(result);
788 return array.finish();
789 }
790
791 // Try to separate the address and port.
792 ArrayPtr<const char> addrPart;
793 Maybe<StringPtr> portPart;
794
795 int af;
796
797 if (str.startsWith("[")) {
798 // Address starts with a bracket, which is a common way to write an ip6 address with a port,
799 // since without brackets around the address part, the port looks like another segment of
800 // the address.
801 af = AF_INET6;
802 size_t closeBracket = KJ_ASSERT_NONNULL(str.findLast(']'),
803 "Unclosed '[' in address string.", str);
804
805 addrPart = str.slice(1, closeBracket);
806 if (str.size() > closeBracket + 1) {
807 KJ_REQUIRE(str.slice(closeBracket + 1).startsWith(":"),
808 "Expected port suffix after ']'.", str);
809 portPart = str.slice(closeBracket + 2);
810 }
811 } else {
812 KJ_IF_MAYBE(colon, str.findFirst(':')) {
813 if (str.slice(*colon + 1).findFirst(':') == nullptr) {
814 // There is exactly one colon and no brackets, so it must be an ip4 address with port.
815 af = AF_INET;
816 addrPart = str.slice(0, *colon);
817 portPart = str.slice(*colon + 1);
818 } else {
819 // There are two or more colons and no brackets, so the whole thing must be an ip6
820 // address with no port.
821 af = AF_INET6;
822 addrPart = str;
823 }
824 } else {
825 // No colons, so it must be an ip4 address without port.
826 af = AF_INET;
827 addrPart = str;
828 }
829 }
830
831 // Parse the port.
832 unsigned long port;
833 KJ_IF_MAYBE(portText, portPart) {
834 char* endptr;
835 port = strtoul(portText->cStr(), &endptr, 0);
836 if (portText->size() == 0 || *endptr != '\0') {
837 // Not a number. Maybe it's a service name. Fall back to DNS.
838 return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint,
839 filter);
840 }
841 KJ_REQUIRE(port < 65536, "Port number too large.");
842 } else {
843 port = portHint;
844 }
845
846 // Check for wildcard.
847 if (addrPart.size() == 1 && addrPart[0] == '*') {
848 result.wildcard = true;
849 #if defined(__OpenBSD__)
850 // On OpenBSD, all sockets are either v4-only or v6-only, so use v4 as a
851 // temporary workaround for wildcards.
852 result.addrlen = sizeof(addr.inet4);
853 result.addr.inet4.sin_family = AF_INET;
854 result.addr.inet4.sin_port = htons(port);
855 #else
856 // Create an ip6 socket and set IPV6_V6ONLY to 0 later.
857 result.addrlen = sizeof(addr.inet6);
858 result.addr.inet6.sin6_family = AF_INET6;
859 result.addr.inet6.sin6_port = htons(port);
860 #endif
861
862 auto array = kj::heapArrayBuilder<SocketAddress>(1);
863 array.add(result);
864 return array.finish();
865 }
866
867 void* addrTarget;
868 if (af == AF_INET6) {
869 result.addrlen = sizeof(addr.inet6);
870 result.addr.inet6.sin6_family = AF_INET6;
871 result.addr.inet6.sin6_port = htons(port);
872 addrTarget = &result.addr.inet6.sin6_addr;
873 } else {
874 result.addrlen = sizeof(addr.inet4);
875 result.addr.inet4.sin_family = AF_INET;
876 result.addr.inet4.sin_port = htons(port);
877 addrTarget = &result.addr.inet4.sin_addr;
878 }
879
880 if (addrPart.size() < INET6_ADDRSTRLEN - 1) {
881 // addrPart is not necessarily NUL-terminated so we have to make a copy. :(
882 char buffer[INET6_ADDRSTRLEN];
883 memcpy(buffer, addrPart.begin(), addrPart.size());
884 buffer[addrPart.size()] = '\0';
885
886 // OK, parse it!
887 switch (inet_pton(af, buffer, addrTarget)) {
888 case 1: {
889 // success.
890 if (!result.parseAllowedBy(filter)) {
891 KJ_FAIL_REQUIRE("address family blocked by restrictPeers()");
892 return Array<SocketAddress>();
893 }
894
895 auto array = kj::heapArrayBuilder<SocketAddress>(1);
896 array.add(result);
897 return array.finish();
898 }
899 case 0:
900 // It's apparently not a simple address... fall back to DNS.
901 break;
902 default:
903 KJ_FAIL_SYSCALL("inet_pton", errno, af, addrPart);
904 }
905 }
906
907 return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter);
908 }
909
getLocalAddress(int sockfd)910 static SocketAddress getLocalAddress(int sockfd) {
911 SocketAddress result;
912 result.addrlen = sizeof(addr);
913 KJ_SYSCALL(getsockname(sockfd, &result.addr.generic, &result.addrlen));
914 return result;
915 }
916
allowedBy(LowLevelAsyncIoProvider::NetworkFilter & filter)917 bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
918 return filter.shouldAllow(&addr.generic, addrlen);
919 }
920
parseAllowedBy(_::NetworkFilter & filter)921 bool parseAllowedBy(_::NetworkFilter& filter) {
922 return filter.shouldAllowParse(&addr.generic, addrlen);
923 }
924
925 kj::Own<PeerIdentity> getIdentity(LowLevelAsyncIoProvider& llaiop,
926 LowLevelAsyncIoProvider::NetworkFilter& filter,
927 AsyncIoStream& stream) const;
928
929 private:
  SocketAddress() {
    // Private: default-constructed SocketAddresses are only created internally (see e.g.
    // getLocalAddress() and LookupReader) and are filled in afterwards.
    //
    // We need to memset the whole object 0 otherwise Valgrind gets unhappy when we write it to a
    // pipe, due to the padding bytes being uninitialized.
    memset(this, 0, sizeof(*this));
  }
935
936 socklen_t addrlen;
937 bool wildcard = false;
938 union {
939 struct sockaddr generic;
940 struct sockaddr_in inet4;
941 struct sockaddr_in6 inet6;
942 struct sockaddr_un unixDomain;
943 struct sockaddr_storage storage;
944 } addr;
945
946 struct LookupParams;
947 class LookupReader;
948 };
949
950 class SocketAddress::LookupReader {
951 // Reads SocketAddresses off of a pipe coming from another thread that is performing
952 // getaddrinfo.
953
954 public:
LookupReader(kj::Own<Thread> && thread,kj::Own<AsyncInputStream> && input,_::NetworkFilter & filter)955 LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input,
956 _::NetworkFilter& filter)
957 : thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {}
958
~LookupReader()959 ~LookupReader() {
960 if (thread) thread->detach();
961 }
962
read()963 Promise<Array<SocketAddress>> read() {
964 return input->tryRead(¤t, sizeof(current), sizeof(current)).then(
965 [this](size_t n) -> Promise<Array<SocketAddress>> {
966 if (n < sizeof(current)) {
967 thread = nullptr;
968 // getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
969 // anyway.
970 KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; }
971 return addresses.releaseAsArray();
972 } else {
973 // getaddrinfo() can return multiple copies of the same address for several reasons.
974 // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so
975 // it may return two copies of the same address, one for each type, unless it explicitly
976 // knows that the service name given is specific to one type. But we can't tell it a type,
977 // because we don't actually know which one the user wants, and if we specify SOCK_STREAM
978 // while the user specified a UDP service name then they'll get a resolution error which
979 // is lame. (At least, I think that's how it works.)
980 //
981 // So we instead resort to de-duping results.
982 if (alreadySeen.insert(current).second) {
983 if (current.parseAllowedBy(filter)) {
984 addresses.add(current);
985 }
986 }
987 return read();
988 }
989 });
990 }
991
992 private:
993 kj::Own<Thread> thread;
994 kj::Own<AsyncInputStream> input;
995 _::NetworkFilter& filter;
996 SocketAddress current;
997 kj::Vector<SocketAddress> addresses;
998 std::set<SocketAddress> alreadySeen;
999 };
1000
struct SocketAddress::LookupParams {
  // Parameters handed off to the getaddrinfo() worker thread spawned by lookupHost().
  kj::String host;
  kj::String service;
};
1005
lookupHost(LowLevelAsyncIoProvider & lowLevel,kj::String host,kj::String service,uint portHint,_::NetworkFilter & filter)1006 Promise<Array<SocketAddress>> SocketAddress::lookupHost(
1007 LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
1008 _::NetworkFilter& filter) {
1009 // This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
1010 // the only cross-platform DNS API and it is blocking.
1011 //
1012 // TODO(perf): Use a thread pool? Maybe kj::Thread should use a thread pool automatically?
1013 // Maybe use the various platform-specific asynchronous DNS libraries? Please do not implement
1014 // a custom DNS resolver...
1015
1016 int fds[2];
1017 #if __linux__ && !__BIONIC__
1018 KJ_SYSCALL(pipe2(fds, O_NONBLOCK | O_CLOEXEC));
1019 #else
1020 KJ_SYSCALL(pipe(fds));
1021 #endif
1022
1023 auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS);
1024
1025 int outFd = fds[1];
1026
1027 LookupParams params = { kj::mv(host), kj::mv(service) };
1028
1029 auto thread = heap<Thread>(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) {
1030 FdOutputStream output((AutoCloseFd(outFd)));
1031
1032 struct addrinfo* list;
1033 int status = getaddrinfo(
1034 params.host == "*" ? nullptr : params.host.cStr(),
1035 params.service == nullptr ? nullptr : params.service.cStr(),
1036 nullptr, &list);
1037 if (status == 0) {
1038 KJ_DEFER(freeaddrinfo(list));
1039
1040 struct addrinfo* cur = list;
1041 while (cur != nullptr) {
1042 if (params.service == nullptr) {
1043 switch (cur->ai_addr->sa_family) {
1044 case AF_INET:
1045 ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint);
1046 break;
1047 case AF_INET6:
1048 ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint);
1049 break;
1050 default:
1051 break;
1052 }
1053 }
1054
1055 SocketAddress addr;
1056 if (params.host == "*") {
1057 // Set up a wildcard SocketAddress. Only use the port number returned by getaddrinfo().
1058 addr.wildcard = true;
1059 addr.addrlen = sizeof(addr.addr.inet6);
1060 addr.addr.inet6.sin6_family = AF_INET6;
1061 switch (cur->ai_addr->sa_family) {
1062 case AF_INET:
1063 addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port;
1064 break;
1065 case AF_INET6:
1066 addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port;
1067 break;
1068 default:
1069 addr.addr.inet6.sin6_port = portHint;
1070 break;
1071 }
1072 } else {
1073 addr.addrlen = cur->ai_addrlen;
1074 memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen);
1075 }
1076 KJ_ASSERT_CAN_MEMCPY(SocketAddress);
1077 output.write(&addr, sizeof(addr));
1078 cur = cur->ai_next;
1079 }
1080 } else if (status == EAI_SYSTEM) {
1081 KJ_FAIL_SYSCALL("getaddrinfo", errno, params.host, params.service) {
1082 return;
1083 }
1084 } else {
1085 KJ_FAIL_REQUIRE("DNS lookup failed.",
1086 params.host, params.service, gai_strerror(status)) {
1087 return;
1088 }
1089 }
1090 }));
1091
1092 auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input), filter);
1093 return reader->read().attach(kj::mv(reader));
1094 }
1095
1096 // =======================================================================================
1097
1098 class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFileDescriptor {
1099 public:
FdConnectionReceiver(LowLevelAsyncIoProvider & lowLevel,UnixEventPort & eventPort,int fd,LowLevelAsyncIoProvider::NetworkFilter & filter,uint flags)1100 FdConnectionReceiver(LowLevelAsyncIoProvider& lowLevel,
1101 UnixEventPort& eventPort, int fd,
1102 LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
1103 : OwnedFileDescriptor(fd, flags), lowLevel(lowLevel), eventPort(eventPort), filter(filter),
1104 observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ) {}
1105
accept()1106 Promise<Own<AsyncIoStream>> accept() override {
1107 return acceptImpl(false).then([](AuthenticatedStream&& a) { return kj::mv(a.stream); });
1108 }
1109
acceptAuthenticated()1110 Promise<AuthenticatedStream> acceptAuthenticated() override {
1111 return acceptImpl(true);
1112 }
1113
acceptImpl(bool authenticated)1114 Promise<AuthenticatedStream> acceptImpl(bool authenticated) {
1115 int newFd;
1116
1117 struct sockaddr_storage addr;
1118 socklen_t addrlen = sizeof(addr);
1119
1120 retry:
1121 #if __linux__ && !__BIONIC__
1122 newFd = ::accept4(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen,
1123 SOCK_NONBLOCK | SOCK_CLOEXEC);
1124 #else
1125 newFd = ::accept(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
1126 #endif
1127
1128 if (newFd >= 0) {
1129 kj::AutoCloseFd ownFd(newFd);
1130 if (!filter.shouldAllow(reinterpret_cast<struct sockaddr*>(&addr), addrlen)) {
1131 // Ignore disallowed address.
1132 return acceptImpl(authenticated);
1133 } else {
1134 // TODO(perf): As a hack for the 0.4 release we are always setting
1135 // TCP_NODELAY because Nagle's algorithm pretty much kills Cap'n Proto's
1136 // RPC protocol. Later, we should extend the interface to provide more
1137 // control over this. Perhaps write() should have a flag which
1138 // specifies whether to pass MSG_MORE.
1139 int one = 1;
1140 KJ_SYSCALL_HANDLE_ERRORS(::setsockopt(
1141 ownFd.get(), IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one))) {
1142 case EOPNOTSUPP:
1143 case ENOPROTOOPT: // (returned for AF_UNIX in cygwin)
1144 break;
1145 default:
1146 KJ_FAIL_SYSCALL("setsocketopt(IPPROTO_TCP, TCP_NODELAY)", error);
1147 }
1148
1149 AuthenticatedStream result;
1150 result.stream = heap<AsyncStreamFd>(eventPort, ownFd.release(), NEW_FD_FLAGS);
1151 if (authenticated) {
1152 result.peerIdentity = SocketAddress(reinterpret_cast<struct sockaddr*>(&addr), addrlen)
1153 .getIdentity(lowLevel, filter, *result.stream);
1154 }
1155 return kj::mv(result);
1156 }
1157 } else {
1158 int error = errno;
1159
1160 switch (error) {
1161 case EAGAIN:
1162 #if EAGAIN != EWOULDBLOCK
1163 case EWOULDBLOCK:
1164 #endif
1165 // Not ready yet.
1166 return observer.whenBecomesReadable().then([this,authenticated]() {
1167 return acceptImpl(authenticated);
1168 });
1169
1170 case EINTR:
1171 case ENETDOWN:
1172 #ifdef EPROTO
1173 // EPROTO is not defined on OpenBSD.
1174 case EPROTO:
1175 #endif
1176 case EHOSTDOWN:
1177 case EHOSTUNREACH:
1178 case ENETUNREACH:
1179 case ECONNABORTED:
1180 case ETIMEDOUT:
1181 // According to the Linux man page, accept() may report an error if the accepted
1182 // connection is already broken. In this case, we really ought to just ignore it and
1183 // keep waiting. But it's hard to say exactly what errors are such network errors and
1184 // which ones are permanent errors. We've made a guess here.
1185 goto retry;
1186
1187 default:
1188 KJ_FAIL_SYSCALL("accept", error);
1189 }
1190
1191 }
1192 }
1193
getPort()1194 uint getPort() override {
1195 return SocketAddress::getLocalAddress(fd).getPort();
1196 }
1197
getsockopt(int level,int option,void * value,uint * length)1198 void getsockopt(int level, int option, void* value, uint* length) override {
1199 socklen_t socklen = *length;
1200 KJ_SYSCALL(::getsockopt(fd, level, option, value, &socklen));
1201 *length = socklen;
1202 }
setsockopt(int level,int option,const void * value,uint length)1203 void setsockopt(int level, int option, const void* value, uint length) override {
1204 KJ_SYSCALL(::setsockopt(fd, level, option, value, length));
1205 }
getsockname(struct sockaddr * addr,uint * length)1206 void getsockname(struct sockaddr* addr, uint* length) override {
1207 socklen_t socklen = *length;
1208 KJ_SYSCALL(::getsockname(fd, addr, &socklen));
1209 *length = socklen;
1210 }
1211
1212 public:
1213 LowLevelAsyncIoProvider& lowLevel;
1214 UnixEventPort& eventPort;
1215 LowLevelAsyncIoProvider::NetworkFilter& filter;
1216 UnixEventPort::FdObserver observer;
1217 };
1218
class DatagramPortImpl final: public DatagramPort, public OwnedFileDescriptor {
  // DatagramPort wrapping a datagram socket file descriptor. Observes the fd for both
  // readability and writability since sends may need to wait for buffer space.

public:
  DatagramPortImpl(LowLevelAsyncIoProvider& lowLevel, UnixEventPort& eventPort, int fd,
                   LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
      : OwnedFileDescriptor(fd, flags), lowLevel(lowLevel), eventPort(eventPort), filter(filter),
        observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ |
                                UnixEventPort::FdObserver::OBSERVE_WRITE) {}

  // send() implementations are defined later in this file.
  Promise<size_t> send(const void* buffer, size_t size, NetworkAddress& destination) override;
  Promise<size_t> send(
      ArrayPtr<const ArrayPtr<const byte>> pieces, NetworkAddress& destination) override;

  class ReceiverImpl;

  Own<DatagramReceiver> makeReceiver(DatagramReceiver::Capacity capacity) override;

  uint getPort() override {
    return SocketAddress::getLocalAddress(fd).getPort();
  }

  void getsockopt(int level, int option, void* value, uint* length) override {
    socklen_t socklen = *length;
    KJ_SYSCALL(::getsockopt(fd, level, option, value, &socklen));
    *length = socklen;
  }
  void setsockopt(int level, int option, const void* value, uint length) override {
    KJ_SYSCALL(::setsockopt(fd, level, option, value, length));
  }

public:
  // Intentionally public: accessed by ReceiverImpl and the send() implementations.
  LowLevelAsyncIoProvider& lowLevel;
  UnixEventPort& eventPort;
  LowLevelAsyncIoProvider::NetworkFilter& filter;
  UnixEventPort::FdObserver observer;
};
1254
class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
  // Unix implementation of LowLevelAsyncIoProvider. Owns the UnixEventPort, EventLoop, and
  // WaitScope that drive all I/O objects created through it.

public:
  LowLevelAsyncIoProviderImpl()
      : eventLoop(eventPort), waitScope(eventLoop) {}

  inline WaitScope& getWaitScope() { return waitScope; }

  Own<AsyncInputStream> wrapInputFd(int fd, uint flags = 0) override {
    return heap<AsyncStreamFd>(eventPort, fd, flags);
  }
  Own<AsyncOutputStream> wrapOutputFd(int fd, uint flags = 0) override {
    return heap<AsyncStreamFd>(eventPort, fd, flags);
  }
  Own<AsyncIoStream> wrapSocketFd(int fd, uint flags = 0) override {
    return heap<AsyncStreamFd>(eventPort, fd, flags);
  }
  Own<AsyncCapabilityStream> wrapUnixSocketFd(Fd fd, uint flags = 0) override {
    return heap<AsyncStreamFd>(eventPort, fd, flags);
  }
  Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
      int fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) override {
    // It's important that we construct the AsyncStreamFd first, so that `flags` are honored,
    // especially setting nonblocking mode and taking ownership.
    auto result = heap<AsyncStreamFd>(eventPort, fd, flags);

    // Unfortunately connect() doesn't fit the mold of KJ_NONBLOCKING_SYSCALL, since it indicates
    // non-blocking using EINPROGRESS.
    for (;;) {
      if (::connect(fd, addr, addrlen) < 0) {
        int error = errno;
        if (error == EINPROGRESS) {
          // Fine.
          break;
        } else if (error != EINTR) {
          // The recovery block after KJ_FAIL_SYSCALL allows execution to continue when
          // exceptions are disabled; in that case, return a null stream.
          KJ_FAIL_SYSCALL("connect()", error) { break; }
          return Own<AsyncIoStream>();
        }
        // EINTR: retry the connect() call.
      } else {
        // no error
        break;
      }
    }

    auto connected = result->waitConnected();
    return connected.then(kj::mvCapture(result, [fd](Own<AsyncIoStream>&& stream) {
      // The fd became writable, but the connection may still have failed; check SO_ERROR to find
      // out whether the connect actually succeeded.
      int err;
      socklen_t errlen = sizeof(err);
      KJ_SYSCALL(getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &errlen));
      if (err != 0) {
        KJ_FAIL_SYSCALL("connect()", err) { break; }
      }
      return kj::mv(stream);
    }));
  }
  Own<ConnectionReceiver> wrapListenSocketFd(
      int fd, NetworkFilter& filter, uint flags = 0) override {
    return heap<FdConnectionReceiver>(*this, eventPort, fd, filter, flags);
  }
  Own<DatagramPort> wrapDatagramSocketFd(
      int fd, NetworkFilter& filter, uint flags = 0) override {
    return heap<DatagramPortImpl>(*this, eventPort, fd, filter, flags);
  }

  Timer& getTimer() override { return eventPort.getTimer(); }

  UnixEventPort& getEventPort() { return eventPort; }

private:
  UnixEventPort eventPort;
  EventLoop eventLoop;
  WaitScope waitScope;
};
1327
1328 // =======================================================================================
1329
class NetworkAddressImpl final: public NetworkAddress {
  // NetworkAddress backed by a list of already-resolved SocketAddresses.

public:
  NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel,
                     LowLevelAsyncIoProvider::NetworkFilter& filter,
                     Array<SocketAddress> addrs)
      : lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {}

  Promise<Own<AsyncIoStream>> connect() override {
    // Copy the address list and attach the copy to the promise so the ArrayPtr passed to
    // connectImpl() stays valid even if this object is destroyed while connecting.
    auto addrsCopy = heapArray(addrs.asPtr());
    auto promise = connectImpl(lowLevel, filter, addrsCopy, false);
    return promise.attach(kj::mv(addrsCopy))
        .then([](AuthenticatedStream&& a) { return kj::mv(a.stream); });
  }

  Promise<AuthenticatedStream> connectAuthenticated() override {
    // Like connect(), but also computes the peer's identity.
    auto addrsCopy = heapArray(addrs.asPtr());
    auto promise = connectImpl(lowLevel, filter, addrsCopy, true);
    return promise.attach(kj::mv(addrsCopy));
  }

  Own<ConnectionReceiver> listen() override {
    if (addrs.size() > 1) {
      KJ_LOG(WARNING, "Bind address resolved to multiple addresses.  Only the first address will "
          "be used.  If this is incorrect, specify the address numerically.  This may be fixed "
          "in the future.", addrs[0].toString());
    }

    int fd = addrs[0].socket(SOCK_STREAM);

    {
      // Close the fd if any of the setup below throws.
      KJ_ON_SCOPE_FAILURE(close(fd));

      // We always enable SO_REUSEADDR because having to take your server down for five minutes
      // before it can restart really sucks.
      int optval = 1;
      KJ_SYSCALL(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)));

      addrs[0].bind(fd);

      // TODO(someday):  Let queue size be specified explicitly in string addresses.
      KJ_SYSCALL(::listen(fd, SOMAXCONN));
    }

    return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS);
  }

  Own<DatagramPort> bindDatagramPort() override {
    if (addrs.size() > 1) {
      KJ_LOG(WARNING, "Bind address resolved to multiple addresses.  Only the first address will "
          "be used.  If this is incorrect, specify the address numerically.  This may be fixed "
          "in the future.", addrs[0].toString());
    }

    int fd = addrs[0].socket(SOCK_DGRAM);

    {
      // Close the fd if any of the setup below throws.
      KJ_ON_SCOPE_FAILURE(close(fd));

      // We always enable SO_REUSEADDR because having to take your server down for five minutes
      // before it can restart really sucks.
      int optval = 1;
      KJ_SYSCALL(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)));

      addrs[0].bind(fd);
    }

    return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS);
  }

  Own<NetworkAddress> clone() override {
    return kj::heap<NetworkAddressImpl>(lowLevel, filter, kj::heapArray(addrs.asPtr()));
  }

  String toString() override {
    return strArray(KJ_MAP(addr, addrs) { return addr.toString(); }, ",");
  }

  const SocketAddress& chooseOneAddress() {
    // Round-robin over the resolved addresses (used by DatagramPortImpl::send()).
    KJ_REQUIRE(addrs.size() > 0, "No addresses available.");
    return addrs[counter++ % addrs.size()];
  }

private:
  LowLevelAsyncIoProvider& lowLevel;
  LowLevelAsyncIoProvider::NetworkFilter& filter;
  Array<SocketAddress> addrs;
  uint counter = 0;  // next index for chooseOneAddress()'s round-robin

  static Promise<AuthenticatedStream> connectImpl(
      LowLevelAsyncIoProvider& lowLevel,
      LowLevelAsyncIoProvider::NetworkFilter& filter,
      ArrayPtr<SocketAddress> addrs,
      bool authenticated) {
    // Try addrs[0]; on failure, recurse on the rest of the list. `addrs` must remain valid until
    // the returned promise resolves -- the public connect() methods attach a copy for this
    // purpose.
    KJ_ASSERT(addrs.size() > 0);

    return kj::evalNow([&]() -> Promise<Own<AsyncIoStream>> {
      if (!addrs[0].allowedBy(filter)) {
        return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()");
      } else {
        int fd = addrs[0].socket(SOCK_STREAM);
        return lowLevel.wrapConnectingSocketFd(
            fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
      }
    }).then([&lowLevel,&filter,addrs,authenticated](Own<AsyncIoStream>&& stream)
        -> Promise<AuthenticatedStream> {
      // Success, pass along.
      AuthenticatedStream result;
      result.stream = kj::mv(stream);
      if (authenticated) {
        result.peerIdentity = addrs[0].getIdentity(lowLevel, filter, *result.stream);
      }
      return kj::mv(result);
    }, [&lowLevel,&filter,addrs,authenticated](Exception&& exception) mutable
        -> Promise<AuthenticatedStream> {
      // Connect failed.
      if (addrs.size() > 1) {
        // Try the next address instead.
        return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()), authenticated);
      } else {
        // No more addresses to try, so propagate the exception.
        return kj::mv(exception);
      }
    });
  }
};
1455
kj::Own<PeerIdentity> SocketAddress::getIdentity(kj::LowLevelAsyncIoProvider& llaiop,
                                                 LowLevelAsyncIoProvider::NetworkFilter& filter,
                                                 AsyncIoStream& stream) const {
  // Build a PeerIdentity for a stream connected to this address: NetworkPeerIdentity for IP
  // peers, LocalPeerIdentity (uid/pid where obtainable) for Unix-domain peers, and
  // UnknownPeerIdentity for anything else.
  switch (addr.generic.sa_family) {
    case AF_INET:
    case AF_INET6: {
      auto builder = kj::heapArrayBuilder<SocketAddress>(1);
      builder.add(*this);
      return NetworkPeerIdentity::newInstance(
          kj::heap<NetworkAddressImpl>(llaiop, filter, builder.finish()));
    }
    case AF_UNIX: {
      LocalPeerIdentity::Credentials result;

      // There is little documentation on what happens when the uid/pid can't be obtained, but I've
      // seen vague references on the internet saying that a PID of 0 and a UID of uid_t(-1) are used
      // as invalid values.

#if defined(SO_PEERCRED)
      // Linux-style credentials passing.
      struct ucred creds;
      uint length = sizeof(creds);
      stream.getsockopt(SOL_SOCKET, SO_PEERCRED, &creds, &length);
      if (creds.pid > 0) {
        result.pid = creds.pid;
      }
      if (creds.uid != static_cast<uid_t>(-1)) {
        result.uid = creds.uid;
      }

#elif defined(LOCAL_PEERCRED)
      // MacOS / FreeBSD
      struct xucred creds;
      uint length = sizeof(creds);
#if defined SOL_LOCAL
      stream.getsockopt(SOL_LOCAL, LOCAL_PEERCRED, &creds, &length);
#else
      stream.getsockopt(0, LOCAL_PEERCRED, &creds, &length);
#endif
      KJ_ASSERT(length == sizeof(creds));
      if (creds.cr_uid != static_cast<uid_t>(-1)) {
        result.uid = creds.cr_uid;
      }

#if defined(LOCAL_PEERPID)
      // MacOS only?  LOCAL_PEERCRED doesn't carry a pid, so fetch it separately.
      pid_t pid;
      length = sizeof(pid);
      stream.getsockopt(SOL_LOCAL, LOCAL_PEERPID, &pid, &length);
      KJ_ASSERT(length == sizeof(pid));
      if (pid > 0) {
        result.pid = pid;
      }
#endif
#endif

      return LocalPeerIdentity::newInstance(result);
    }
    default:
      return UnknownPeerIdentity::newInstance();
  }
}
1517
1518 class SocketNetwork final: public Network {
1519 public:
SocketNetwork(LowLevelAsyncIoProvider & lowLevel)1520 explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
SocketNetwork(SocketNetwork & parent,kj::ArrayPtr<const kj::StringPtr> allow,kj::ArrayPtr<const kj::StringPtr> deny)1521 explicit SocketNetwork(SocketNetwork& parent,
1522 kj::ArrayPtr<const kj::StringPtr> allow,
1523 kj::ArrayPtr<const kj::StringPtr> deny)
1524 : lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {}
1525
parseAddress(StringPtr addr,uint portHint=0)1526 Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
1527 return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) {
1528 return SocketAddress::parse(lowLevel, addr, portHint, filter);
1529 })).then([this](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
1530 return heap<NetworkAddressImpl>(lowLevel, filter, kj::mv(addresses));
1531 });
1532 }
1533
getSockaddr(const void * sockaddr,uint len)1534 Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
1535 auto array = kj::heapArrayBuilder<SocketAddress>(1);
1536 array.add(SocketAddress(sockaddr, len));
1537 KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; }
1538 return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, filter, array.finish()));
1539 }
1540
restrictPeers(kj::ArrayPtr<const kj::StringPtr> allow,kj::ArrayPtr<const kj::StringPtr> deny=nullptr)1541 Own<Network> restrictPeers(
1542 kj::ArrayPtr<const kj::StringPtr> allow,
1543 kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
1544 return heap<SocketNetwork>(*this, allow, deny);
1545 }
1546
1547 private:
1548 LowLevelAsyncIoProvider& lowLevel;
1549 _::NetworkFilter filter;
1550 };
1551
1552 // =======================================================================================
1553
send(const void * buffer,size_t size,NetworkAddress & destination)1554 Promise<size_t> DatagramPortImpl::send(
1555 const void* buffer, size_t size, NetworkAddress& destination) {
1556 auto& addr = downcast<NetworkAddressImpl>(destination).chooseOneAddress();
1557
1558 ssize_t n;
1559 KJ_NONBLOCKING_SYSCALL(n = sendto(fd, buffer, size, 0, addr.getRaw(), addr.getRawSize()));
1560 if (n < 0) {
1561 // Write buffer full.
1562 return observer.whenBecomesWritable().then([this, buffer, size, &destination]() {
1563 return send(buffer, size, destination);
1564 });
1565 } else {
1566 // If less than the whole message was sent, then it got truncated, and there's nothing we can
1567 // do about it.
1568 return n;
1569 }
1570 }
1571
send(ArrayPtr<const ArrayPtr<const byte>> pieces,NetworkAddress & destination)1572 Promise<size_t> DatagramPortImpl::send(
1573 ArrayPtr<const ArrayPtr<const byte>> pieces, NetworkAddress& destination) {
1574 struct msghdr msg;
1575 memset(&msg, 0, sizeof(msg));
1576
1577 auto& addr = downcast<NetworkAddressImpl>(destination).chooseOneAddress();
1578 msg.msg_name = const_cast<void*>(implicitCast<const void*>(addr.getRaw()));
1579 msg.msg_namelen = addr.getRawSize();
1580
1581 const size_t iovmax = kj::miniposix::iovMax();
1582 KJ_STACK_ARRAY(struct iovec, iov, kj::min(pieces.size(), iovmax), 16, 64);
1583
1584 for (size_t i: kj::indices(pieces)) {
1585 iov[i].iov_base = const_cast<void*>(implicitCast<const void*>(pieces[i].begin()));
1586 iov[i].iov_len = pieces[i].size();
1587 }
1588
1589 Array<byte> extra;
1590 if (pieces.size() > iovmax) {
1591 // Too many pieces, but we can't use multiple syscalls because they'd send separate
1592 // datagrams. We'll have to copy the trailing pieces into a temporary array.
1593 //
1594 // TODO(perf): On Linux we could use multiple syscalls via MSG_MORE or sendmsg/sendmmsg.
1595 size_t extraSize = 0;
1596 for (size_t i = iovmax - 1; i < pieces.size(); i++) {
1597 extraSize += pieces[i].size();
1598 }
1599 extra = kj::heapArray<byte>(extraSize);
1600 extraSize = 0;
1601 for (size_t i = iovmax - 1; i < pieces.size(); i++) {
1602 memcpy(extra.begin() + extraSize, pieces[i].begin(), pieces[i].size());
1603 extraSize += pieces[i].size();
1604 }
1605 iov.back().iov_base = extra.begin();
1606 iov.back().iov_len = extra.size();
1607 }
1608
1609 msg.msg_iov = iov.begin();
1610 msg.msg_iovlen = iov.size();
1611
1612 ssize_t n;
1613 KJ_NONBLOCKING_SYSCALL(n = sendmsg(fd, &msg, 0));
1614 if (n < 0) {
1615 // Write buffer full.
1616 return observer.whenBecomesWritable().then([this, pieces, &destination]() {
1617 return send(pieces, destination);
1618 });
1619 } else {
1620 // If less than the whole message was sent, then it was truncated, and there's nothing we can
1621 // do about that now.
1622 return n;
1623 }
1624 }
1625
1626 class DatagramPortImpl::ReceiverImpl final: public DatagramReceiver {
1627 public:
  explicit ReceiverImpl(DatagramPortImpl& port, Capacity capacity)
      : port(port),
        contentBuffer(heapArray<byte>(capacity.content)),
        // A zero ancillary capacity means no ancillary buffer is allocated at all.
        ancillaryBuffer(capacity.ancillary > 0 ? heapArray<byte>(capacity.ancillary)
                                               : Array<byte>(nullptr)) {}
1633
receive()1634 Promise<void> receive() override {
1635 struct msghdr msg;
1636 memset(&msg, 0, sizeof(msg));
1637
1638 struct sockaddr_storage addr;
1639 memset(&addr, 0, sizeof(addr));
1640 msg.msg_name = &addr;
1641 msg.msg_namelen = sizeof(addr);
1642
1643 struct iovec iov;
1644 iov.iov_base = contentBuffer.begin();
1645 iov.iov_len = contentBuffer.size();
1646 msg.msg_iov = &iov;
1647 msg.msg_iovlen = 1;
1648 msg.msg_control = ancillaryBuffer.begin();
1649 msg.msg_controllen = ancillaryBuffer.size();
1650
1651 ssize_t n;
1652 KJ_NONBLOCKING_SYSCALL(n = recvmsg(port.fd, &msg, 0));
1653
1654 if (n < 0) {
1655 // No data available. Wait.
1656 return port.observer.whenBecomesReadable().then([this]() {
1657 return receive();
1658 });
1659 } else {
1660 if (!port.filter.shouldAllow(reinterpret_cast<const struct sockaddr*>(msg.msg_name),
1661 msg.msg_namelen)) {
1662 // Ignore message from disallowed source.
1663 return receive();
1664 }
1665
1666 receivedSize = n;
1667 contentTruncated = msg.msg_flags & MSG_TRUNC;
1668
1669 source.emplace(port.lowLevel, port.filter, msg.msg_name, msg.msg_namelen);
1670
1671 ancillaryList.resize(0);
1672 ancillaryTruncated = msg.msg_flags & MSG_CTRUNC;
1673
1674 for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr;
1675 cmsg = CMSG_NXTHDR(&msg, cmsg)) {
1676 // On some platforms (OSX), a cmsghdr's length may cross the end of the ancillary buffer
1677 // when truncated. On other platforms (Linux) the length in cmsghdr will itself be
1678 // truncated to fit within the buffer.
1679
1680 #if __APPLE__
1681 // On MacOS, `CMSG_SPACE(0)` triggers a bogus warning.
1682 #pragma GCC diagnostic ignored "-Wnull-pointer-arithmetic"
1683 #endif
1684 const byte* pos = reinterpret_cast<const byte*>(cmsg);
1685 size_t available = ancillaryBuffer.end() - pos;
1686 if (available < CMSG_SPACE(0)) {
1687 // The buffer ends in the middle of the header. We can't use this message.
1688 // (On Linux, this never happens, because the message is not included if there isn't
1689 // space for a header. I'm not sure how other systems behave, though, so let's be safe.)
1690 break;
1691 }
1692
1693 // OK, we know the cmsghdr is valid, at least.
1694
1695 // Find the start of the message payload.
1696 const byte* begin = (const byte *)CMSG_DATA(cmsg);
1697
1698 // Cap the message length to the available space.
1699 const byte* end = pos + kj::min(available, cmsg->cmsg_len);
1700
1701 ancillaryList.add(AncillaryMessage(
1702 cmsg->cmsg_level, cmsg->cmsg_type, arrayPtr(begin, end)));
1703 }
1704
1705 return READY_NOW;
1706 }
1707 }
1708
getContent()1709 MaybeTruncated<ArrayPtr<const byte>> getContent() override {
1710 return { contentBuffer.slice(0, receivedSize), contentTruncated };
1711 }
1712
getAncillary()1713 MaybeTruncated<ArrayPtr<const AncillaryMessage>> getAncillary() override {
1714 return { ancillaryList.asPtr(), ancillaryTruncated };
1715 }
1716
getSource()1717 NetworkAddress& getSource() override {
1718 return KJ_REQUIRE_NONNULL(source, "Haven't sent a message yet.").abstract;
1719 }
1720
1721 private:
1722 DatagramPortImpl& port;
1723 Array<byte> contentBuffer;
1724 Array<byte> ancillaryBuffer;
1725 Vector<AncillaryMessage> ancillaryList;
1726 size_t receivedSize = 0;
1727 bool contentTruncated = false;
1728 bool ancillaryTruncated = false;
1729
1730 struct StoredAddress {
StoredAddresskj::__anona1d21b730111::DatagramPortImpl::ReceiverImpl::StoredAddress1731 StoredAddress(LowLevelAsyncIoProvider& lowLevel, LowLevelAsyncIoProvider::NetworkFilter& filter,
1732 const void* sockaddr, uint length)
1733 : raw(sockaddr, length),
1734 abstract(lowLevel, filter, Array<SocketAddress>(&raw, 1, NullArrayDisposer::instance)) {}
1735
1736 SocketAddress raw;
1737 NetworkAddressImpl abstract;
1738 };
1739
1740 kj::Maybe<StoredAddress> source;
1741 };
1742
makeReceiver(DatagramReceiver::Capacity capacity)1743 Own<DatagramReceiver> DatagramPortImpl::makeReceiver(DatagramReceiver::Capacity capacity) {
1744 return kj::heap<ReceiverImpl>(*this, capacity);
1745 }
1746
1747 // =======================================================================================
1748
1749 class AsyncIoProviderImpl final: public AsyncIoProvider {
1750 public:
AsyncIoProviderImpl(LowLevelAsyncIoProvider & lowLevel)1751 AsyncIoProviderImpl(LowLevelAsyncIoProvider& lowLevel)
1752 : lowLevel(lowLevel), network(lowLevel) {}
1753
newOneWayPipe()1754 OneWayPipe newOneWayPipe() override {
1755 int fds[2];
1756 #if __linux__ && !__BIONIC__
1757 KJ_SYSCALL(pipe2(fds, O_NONBLOCK | O_CLOEXEC));
1758 #else
1759 KJ_SYSCALL(pipe(fds));
1760 #endif
1761 return OneWayPipe {
1762 lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS),
1763 lowLevel.wrapOutputFd(fds[1], NEW_FD_FLAGS)
1764 };
1765 }
1766
newTwoWayPipe()1767 TwoWayPipe newTwoWayPipe() override {
1768 int fds[2];
1769 int type = SOCK_STREAM;
1770 #if __linux__ && !__BIONIC__
1771 type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
1772 #endif
1773 KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds));
1774 return TwoWayPipe { {
1775 lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS),
1776 lowLevel.wrapSocketFd(fds[1], NEW_FD_FLAGS)
1777 } };
1778 }
1779
newCapabilityPipe()1780 CapabilityPipe newCapabilityPipe() override {
1781 int fds[2];
1782 int type = SOCK_STREAM;
1783 #if __linux__ && !__BIONIC__
1784 type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
1785 #endif
1786 KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds));
1787 return CapabilityPipe { {
1788 lowLevel.wrapUnixSocketFd(fds[0], NEW_FD_FLAGS),
1789 lowLevel.wrapUnixSocketFd(fds[1], NEW_FD_FLAGS)
1790 } };
1791 }
1792
getNetwork()1793 Network& getNetwork() override {
1794 return network;
1795 }
1796
newPipeThread(Function<void (AsyncIoProvider &,AsyncIoStream &,WaitScope &)> startFunc)1797 PipeThread newPipeThread(
1798 Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)> startFunc) override {
1799 int fds[2];
1800 int type = SOCK_STREAM;
1801 #if __linux__ && !__BIONIC__
1802 type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
1803 #endif
1804 KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds));
1805
1806 int threadFd = fds[1];
1807 KJ_ON_SCOPE_FAILURE(close(threadFd));
1808
1809 auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
1810
1811 auto thread = heap<Thread>(kj::mvCapture(startFunc,
1812 [threadFd](Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)>&& startFunc) {
1813 LowLevelAsyncIoProviderImpl lowLevel;
1814 auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS);
1815 AsyncIoProviderImpl ioProvider(lowLevel);
1816 startFunc(ioProvider, *stream, lowLevel.getWaitScope());
1817 }));
1818
1819 return { kj::mv(thread), kj::mv(pipe) };
1820 }
1821
getTimer()1822 Timer& getTimer() override { return lowLevel.getTimer(); }
1823
1824 private:
1825 LowLevelAsyncIoProvider& lowLevel;
1826 SocketNetwork network;
1827 };
1828
1829 } // namespace
1830
newAsyncIoProvider(LowLevelAsyncIoProvider & lowLevel)1831 Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) {
1832 return kj::heap<AsyncIoProviderImpl>(lowLevel);
1833 }
1834
setupAsyncIo()1835 AsyncIoContext setupAsyncIo() {
1836 auto lowLevel = heap<LowLevelAsyncIoProviderImpl>();
1837 auto ioProvider = kj::heap<AsyncIoProviderImpl>(*lowLevel);
1838 auto& waitScope = lowLevel->getWaitScope();
1839 auto& eventPort = lowLevel->getEventPort();
1840 return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope, eventPort };
1841 }
1842
1843 } // namespace kj
1844
1845 #endif // !_WIN32
1846