1 // Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
2 // Licensed under the MIT License:
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a copy
5 // of this software and associated documentation files (the "Software"), to deal
6 // in the Software without restriction, including without limitation the rights
7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 // copies of the Software, and to permit persons to whom the Software is
9 // furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 // THE SOFTWARE.
21 
22 #if !_WIN32
23 // For Win32 implementation, see async-io-win32.c++.
24 
25 #ifndef _GNU_SOURCE
26 #define _GNU_SOURCE
27 #endif
28 
29 #include "async-io.h"
30 #include "async-io-internal.h"
31 #include "async-unix.h"
32 #include "debug.h"
33 #include "thread.h"
34 #include "io.h"
35 #include "miniposix.h"
36 #include <unistd.h>
37 #include <sys/uio.h>
38 #include <errno.h>
39 #include <fcntl.h>
40 #include <sys/types.h>
41 #include <sys/socket.h>
42 #include <sys/un.h>
43 #include <netinet/in.h>
44 #include <netinet/tcp.h>
45 #include <stddef.h>
46 #include <stdlib.h>
47 #include <arpa/inet.h>
48 #include <netdb.h>
49 #include <set>
50 #include <poll.h>
51 #include <limits.h>
52 #include <sys/ioctl.h>
53 
54 #if !defined(SO_PEERCRED) && defined(LOCAL_PEERCRED)
55 #include <sys/ucred.h>
56 #endif
57 
58 namespace kj {
59 
60 namespace {
61 
setNonblocking(int fd)62 void setNonblocking(int fd) {
63 #ifdef FIONBIO
64   int opt = 1;
65   KJ_SYSCALL(ioctl(fd, FIONBIO, &opt));
66 #else
67   int flags;
68   KJ_SYSCALL(flags = fcntl(fd, F_GETFL));
69   if ((flags & O_NONBLOCK) == 0) {
70     KJ_SYSCALL(fcntl(fd, F_SETFL, flags | O_NONBLOCK));
71   }
72 #endif
73 }
74 
setCloseOnExec(int fd)75 void setCloseOnExec(int fd) {
76 #ifdef FIOCLEX
77   KJ_SYSCALL(ioctl(fd, FIOCLEX));
78 #else
79   int flags;
80   KJ_SYSCALL(flags = fcntl(fd, F_GETFD));
81   if ((flags & FD_CLOEXEC) == 0) {
82     KJ_SYSCALL(fcntl(fd, F_SETFD, flags | FD_CLOEXEC));
83   }
84 #endif
85 }
86 
87 static constexpr uint NEW_FD_FLAGS =
88 #if __linux__ && !__BIONIC__
89     LowLevelAsyncIoProvider::ALREADY_CLOEXEC | LowLevelAsyncIoProvider::ALREADY_NONBLOCK |
90 #endif
91     LowLevelAsyncIoProvider::TAKE_OWNERSHIP;
92 // We always try to open FDs with CLOEXEC and NONBLOCK already set on Linux, but on other platforms
93 // this is not possible.
94 
95 class OwnedFileDescriptor {
96 public:
OwnedFileDescriptor(int fd,uint flags)97   OwnedFileDescriptor(int fd, uint flags): fd(fd), flags(flags) {
98     if (flags & LowLevelAsyncIoProvider::ALREADY_NONBLOCK) {
99       KJ_DREQUIRE(fcntl(fd, F_GETFL) & O_NONBLOCK, "You claimed you set NONBLOCK, but you didn't.");
100     } else {
101       setNonblocking(fd);
102     }
103 
104     if (flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) {
105       if (flags & LowLevelAsyncIoProvider::ALREADY_CLOEXEC) {
106         KJ_DREQUIRE(fcntl(fd, F_GETFD) & FD_CLOEXEC,
107                     "You claimed you set CLOEXEC, but you didn't.");
108       } else {
109         setCloseOnExec(fd);
110       }
111     }
112   }
113 
~OwnedFileDescriptor()114   ~OwnedFileDescriptor() noexcept(false) {
115     // Don't use SYSCALL() here because close() should not be repeated on EINTR.
116     if ((flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) && close(fd) < 0) {
117       KJ_FAIL_SYSCALL("close", errno, fd) {
118         // Recoverable exceptions are safe in destructors.
119         break;
120       }
121     }
122   }
123 
124 protected:
125   const int fd;
126 
127 private:
128   uint flags;
129 };
130 
131 // =======================================================================================
132 
133 class AsyncStreamFd: public OwnedFileDescriptor, public AsyncCapabilityStream {
134 public:
AsyncStreamFd(UnixEventPort & eventPort,int fd,uint flags)135   AsyncStreamFd(UnixEventPort& eventPort, int fd, uint flags)
136       : OwnedFileDescriptor(fd, flags),
137         eventPort(eventPort),
138         observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ_WRITE) {}
~AsyncStreamFd()139   virtual ~AsyncStreamFd() noexcept(false) {}
140 
tryRead(void * buffer,size_t minBytes,size_t maxBytes)141   Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
142     return tryReadInternal(buffer, minBytes, maxBytes, nullptr, 0, {0,0})
143         .then([](ReadResult r) { return r.byteCount; });
144   }
145 
tryReadWithFds(void * buffer,size_t minBytes,size_t maxBytes,AutoCloseFd * fdBuffer,size_t maxFds)146   Promise<ReadResult> tryReadWithFds(void* buffer, size_t minBytes, size_t maxBytes,
147                                      AutoCloseFd* fdBuffer, size_t maxFds) override {
148     return tryReadInternal(buffer, minBytes, maxBytes, fdBuffer, maxFds, {0,0});
149   }
150 
tryReadWithStreams(void * buffer,size_t minBytes,size_t maxBytes,Own<AsyncCapabilityStream> * streamBuffer,size_t maxStreams)151   Promise<ReadResult> tryReadWithStreams(
152       void* buffer, size_t minBytes, size_t maxBytes,
153       Own<AsyncCapabilityStream>* streamBuffer, size_t maxStreams) override {
154     auto fdBuffer = kj::heapArray<AutoCloseFd>(maxStreams);
155     auto promise = tryReadInternal(buffer, minBytes, maxBytes, fdBuffer.begin(), maxStreams, {0,0});
156 
157     return promise.then([this, fdBuffer = kj::mv(fdBuffer), streamBuffer]
158                         (ReadResult result) mutable {
159       for (auto i: kj::zeroTo(result.capCount)) {
160         streamBuffer[i] = kj::heap<AsyncStreamFd>(eventPort, fdBuffer[i].release(),
161             LowLevelAsyncIoProvider::TAKE_OWNERSHIP | LowLevelAsyncIoProvider::ALREADY_CLOEXEC);
162       }
163       return result;
164     });
165   }
166 
write(const void * buffer,size_t size)167   Promise<void> write(const void* buffer, size_t size) override {
168     ssize_t n;
169     KJ_NONBLOCKING_SYSCALL(n = ::write(fd, buffer, size)) {
170       // Error.
171 
172       // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to
173       // a bug that exists in both Clang and GCC:
174       //   http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799
175       //   http://llvm.org/bugs/show_bug.cgi?id=12286
176       goto error;
177     }
178     if (false) {
179     error:
180       return kj::READY_NOW;
181     }
182 
183     if (n < 0) {
184       // EAGAIN -- need to wait for writability and try again.
185       return observer.whenBecomesWritable().then([=]() {
186         return write(buffer, size);
187       });
188     } else if (n == size) {
189       // All done.
190       return READY_NOW;
191     } else {
192       // Fewer than `size` bytes were written, but we CANNOT assume we're out of buffer space, as
193       // Linux is known to return partial reads/writes when interrupted by a signal -- yes, even
194       // for non-blocking operations. So, we'll need to write() again now, even though it will
195       // almost certainly fail with EAGAIN. See comments in the read path for more info.
196       buffer = reinterpret_cast<const byte*>(buffer) + n;
197       size -= n;
198       return write(buffer, size);
199     }
200   }
201 
write(ArrayPtr<const ArrayPtr<const byte>> pieces)202   Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
203     if (pieces.size() == 0) {
204       return writeInternal(nullptr, nullptr, nullptr);
205     } else {
206       return writeInternal(pieces[0], pieces.slice(1, pieces.size()), nullptr);
207     }
208   }
209 
writeWithFds(ArrayPtr<const byte> data,ArrayPtr<const ArrayPtr<const byte>> moreData,ArrayPtr<const int> fds)210   Promise<void> writeWithFds(ArrayPtr<const byte> data,
211                              ArrayPtr<const ArrayPtr<const byte>> moreData,
212                              ArrayPtr<const int> fds) override {
213     return writeInternal(data, moreData, fds);
214   }
215 
writeWithStreams(ArrayPtr<const byte> data,ArrayPtr<const ArrayPtr<const byte>> moreData,Array<Own<AsyncCapabilityStream>> streams)216   Promise<void> writeWithStreams(ArrayPtr<const byte> data,
217                                  ArrayPtr<const ArrayPtr<const byte>> moreData,
218                                  Array<Own<AsyncCapabilityStream>> streams) override {
219     auto fds = KJ_MAP(stream, streams) {
220       return downcast<AsyncStreamFd>(*stream).fd;
221     };
222     auto promise = writeInternal(data, moreData, fds);
223     return promise.attach(kj::mv(fds), kj::mv(streams));
224   }
225 
whenWriteDisconnected()226   Promise<void> whenWriteDisconnected() override {
227     KJ_IF_MAYBE(p, writeDisconnectedPromise) {
228       return p->addBranch();
229     } else {
230       auto fork = observer.whenWriteDisconnected().fork();
231       auto result = fork.addBranch();
232       writeDisconnectedPromise = kj::mv(fork);
233       return kj::mv(result);
234     }
235   }
236 
shutdownWrite()237   void shutdownWrite() override {
238     // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
239     // UnixAsyncIoProvider interface.
240     KJ_SYSCALL(shutdown(fd, SHUT_WR));
241   }
242 
abortRead()243   void abortRead() override {
244     // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
245     // UnixAsyncIoProvider interface.
246     KJ_SYSCALL(shutdown(fd, SHUT_RD));
247   }
248 
getsockopt(int level,int option,void * value,uint * length)249   void getsockopt(int level, int option, void* value, uint* length) override {
250     socklen_t socklen = *length;
251     KJ_SYSCALL(::getsockopt(fd, level, option, value, &socklen));
252     *length = socklen;
253   }
254 
setsockopt(int level,int option,const void * value,uint length)255   void setsockopt(int level, int option, const void* value, uint length) override {
256     KJ_SYSCALL(::setsockopt(fd, level, option, value, length));
257   }
258 
getsockname(struct sockaddr * addr,uint * length)259   void getsockname(struct sockaddr* addr, uint* length) override {
260     socklen_t socklen = *length;
261     KJ_SYSCALL(::getsockname(fd, addr, &socklen));
262     *length = socklen;
263   }
264 
getpeername(struct sockaddr * addr,uint * length)265   void getpeername(struct sockaddr* addr, uint* length) override {
266     socklen_t socklen = *length;
267     KJ_SYSCALL(::getpeername(fd, addr, &socklen));
268     *length = socklen;
269   }
270 
getFd() const271   kj::Maybe<int> getFd() const override {
272     return fd;
273   }
274 
registerAncillaryMessageHandler(kj::Function<void (kj::ArrayPtr<AncillaryMessage>)> fn)275   void registerAncillaryMessageHandler(
276       kj::Function<void(kj::ArrayPtr<AncillaryMessage>)> fn) override {
277     ancillaryMsgCallback = kj::mv(fn);
278   }
279 
waitConnected()280   Promise<void> waitConnected() {
281     // Wait until initial connection has completed. This actually just waits until it is writable.
282 
283     // Can't just go directly to writeObserver.whenBecomesWritable() because of edge triggering. We
284     // need to explicitly check if the socket is already connected.
285 
286     struct pollfd pollfd;
287     memset(&pollfd, 0, sizeof(pollfd));
288     pollfd.fd = fd;
289     pollfd.events = POLLOUT;
290 
291     int pollResult;
292     KJ_SYSCALL(pollResult = poll(&pollfd, 1, 0));
293 
294     if (pollResult == 0) {
295       // Not ready yet. We can safely use the edge-triggered observer.
296       return observer.whenBecomesWritable();
297     } else {
298       // Ready now.
299       return kj::READY_NOW;
300     }
301   }
302 
303 private:
304   UnixEventPort& eventPort;
305   UnixEventPort::FdObserver observer;
306   Maybe<ForkedPromise<void>> writeDisconnectedPromise;
307   Maybe<Function<void(ArrayPtr<AncillaryMessage>)>> ancillaryMsgCallback;
308 
tryReadInternal(void * buffer,size_t minBytes,size_t maxBytes,AutoCloseFd * fdBuffer,size_t maxFds,ReadResult alreadyRead)309   Promise<ReadResult> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes,
310                                       AutoCloseFd* fdBuffer, size_t maxFds,
311                                       ReadResult alreadyRead) {
312     // `alreadyRead` is the number of bytes we have already received via previous reads -- minBytes,
313     // maxBytes, and buffer have already been adjusted to account for them, but this count must
314     // be included in the final return value.
315 
316     ssize_t n;
317     if (maxFds == 0 && ancillaryMsgCallback == nullptr) {
318       KJ_NONBLOCKING_SYSCALL(n = ::read(fd, buffer, maxBytes)) {
319         // Error.
320 
321         // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to
322         // a bug that exists in both Clang and GCC:
323         //   http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799
324         //   http://llvm.org/bugs/show_bug.cgi?id=12286
325         goto error;
326       }
327     } else {
328       struct msghdr msg;
329       memset(&msg, 0, sizeof(msg));
330 
331       struct iovec iov;
332       memset(&iov, 0, sizeof(iov));
333       iov.iov_base = buffer;
334       iov.iov_len = maxBytes;
335       msg.msg_iov = &iov;
336       msg.msg_iovlen = 1;
337 
338       // Allocate space to receive a cmsg.
339       size_t msgBytes;
340       if (ancillaryMsgCallback == nullptr) {
341 #if __APPLE__ || __FreeBSD__
342         // Until very recently (late 2018 / early 2019), FreeBSD suffered from a bug in which when
343         // an SCM_RIGHTS message was truncated on delivery, it would not close the FDs that weren't
344         // delivered -- they would simply leak: https://bugs.freebsd.org/131876
345         //
346         // My testing indicates that MacOS has this same bug as of today (April 2019). I don't know
347         // if they plan to fix it or are even aware of it.
348         //
349         // To handle both cases, we will always provide space to receive 512 FDs. Hopefully, this is
350         // greater than the maximum number of FDs that these kernels will transmit in one message
351         // PLUS enough space for any other ancillary messages that could be sent before the
352         // SCM_RIGHTS message to push it back in the buffer. I couldn't find any firm documentation
353         // on these limits, though -- I only know that Linux is limited to 253, and I saw a hint in
354         // a comment in someone else's application that suggested FreeBSD is the same. Hopefully,
355         // then, this is sufficient to prevent attacks. But if not, there's nothing more we can do;
356         // it's really up to the kernel to fix this.
357         msgBytes = CMSG_SPACE(sizeof(int) * 512);
358 #else
359         msgBytes = CMSG_SPACE(sizeof(int) * maxFds);
360 #endif
361       } else {
362         // If we want room for ancillary messages instead of or in addition to FDs, just use the
363         // same amount of cushion as in the MacOS/FreeBSD case above.
364         // Someday we may want to allow customization here, but there's no immediate use for it.
365         msgBytes = CMSG_SPACE(sizeof(int) * 512);
366       }
367 
368       // On Linux, CMSG_SPACE will align to a word-size boundary, but on Mac it always aligns to a
369       // 32-bit boundary. I guess aligning to 32 bits helps avoid the problem where you
370       // surprisingly end up with space for two file descriptors when you only wanted one. However,
371       // cmsghdr's preferred alignment is word-size (it contains a size_t). If we stack-allocate
372       // the buffer, we need to make sure it is aligned properly (maybe not on x64, but maybe on
373       // other platforms), so we want to allocate an array of words (we use void*). So... we use
374       // CMSG_SPACE() and then additionally round up to deal with Mac.
375       size_t msgWords = (msgBytes + sizeof(void*) - 1) / sizeof(void*);
376       KJ_STACK_ARRAY(void*, cmsgSpace, msgWords, 16, 256);
377       auto cmsgBytes = cmsgSpace.asBytes();
378       memset(cmsgBytes.begin(), 0, cmsgBytes.size());
379       msg.msg_control = cmsgBytes.begin();
380       msg.msg_controllen = msgBytes;
381 
382 #ifdef MSG_CMSG_CLOEXEC
383       static constexpr int RECVMSG_FLAGS = MSG_CMSG_CLOEXEC;
384 #else
385       static constexpr int RECVMSG_FLAGS = 0;
386 #endif
387 
388       KJ_NONBLOCKING_SYSCALL(n = ::recvmsg(fd, &msg, RECVMSG_FLAGS)) {
389         // Error.
390 
391         // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to
392         // a bug that exists in both Clang and GCC:
393         //   http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799
394         //   http://llvm.org/bugs/show_bug.cgi?id=12286
395         goto error;
396       }
397 
398       if (n >= 0) {
399         // Process all messages.
400         //
401         // WARNING DANGER: We have to be VERY careful not to miss a file descriptor here, because
402         // if we do, then that FD will never be closed, and a malicious peer could exploit this to
403         // fill up our FD table, creating a DoS attack. Some things to keep in mind:
404         // - CMSG_SPACE() could have rounded up the space for alignment purposes, and this could
405         //   mean we permitted the kernel to deliver more file descriptors than `maxFds`. We need
406         //   to close the extras.
407         // - We can receive multiple ancillary messages at once. In particular, there is also
408         //   SCM_CREDENTIALS. The sender decides what to send. They could send SCM_CREDENTIALS
409         //   first followed by SCM_RIGHTS. We need to make sure we see both.
410         size_t nfds = 0;
411         size_t spaceLeft = msg.msg_controllen;
412         Vector<AncillaryMessage> ancillaryMessages;
413         for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
414             cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
415           if (spaceLeft >= CMSG_LEN(0) &&
416               cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
417             // Some operating systems (like MacOS) do not adjust csmg_len when the message is
418             // truncated. We must do so ourselves or risk overrunning the buffer.
419             auto len = kj::min(cmsg->cmsg_len, spaceLeft);
420             auto data = arrayPtr(reinterpret_cast<int*>(CMSG_DATA(cmsg)),
421                                  (len - CMSG_LEN(0)) / sizeof(int));
422             kj::Vector<kj::AutoCloseFd> trashFds;
423             for (auto fd: data) {
424               kj::AutoCloseFd ownFd(fd);
425               if (nfds < maxFds) {
426                 fdBuffer[nfds++] = kj::mv(ownFd);
427               } else {
428                 trashFds.add(kj::mv(ownFd));
429               }
430             }
431           } else if (spaceLeft >= CMSG_LEN(0) && ancillaryMsgCallback != nullptr) {
432             auto len = kj::min(cmsg->cmsg_len, spaceLeft);
433             auto data = ArrayPtr<const byte>(CMSG_DATA(cmsg), len - CMSG_LEN(0));
434             ancillaryMessages.add(cmsg->cmsg_level, cmsg->cmsg_type, data);
435           }
436 
437           if (spaceLeft >= CMSG_LEN(0) && spaceLeft >= cmsg->cmsg_len) {
438             spaceLeft -= cmsg->cmsg_len;
439           } else {
440             spaceLeft = 0;
441           }
442         }
443 
444 #ifndef MSG_CMSG_CLOEXEC
445         for (size_t i = 0; i < nfds; i++) {
446           setCloseOnExec(fdBuffer[i]);
447         }
448 #endif
449 
450         if (ancillaryMessages.size() > 0) {
451           KJ_IF_MAYBE(fn, ancillaryMsgCallback) {
452             (*fn)(ancillaryMessages.asPtr());
453           }
454         }
455 
456         alreadyRead.capCount += nfds;
457         fdBuffer += nfds;
458         maxFds -= nfds;
459       }
460     }
461 
462     if (false) {
463     error:
464       return alreadyRead;
465     }
466 
467     if (n < 0) {
468       // Read would block.
469       return observer.whenBecomesReadable().then([=]() {
470         return tryReadInternal(buffer, minBytes, maxBytes, fdBuffer, maxFds, alreadyRead);
471       });
472     } else if (n == 0) {
473       // EOF -OR- maxBytes == 0.
474       return alreadyRead;
475     } else if (implicitCast<size_t>(n) >= minBytes) {
476       // We read enough to stop here.
477       alreadyRead.byteCount += n;
478       return alreadyRead;
479     } else {
480       // The kernel returned fewer bytes than we asked for (and fewer than we need).
481 
482       buffer = reinterpret_cast<byte*>(buffer) + n;
483       minBytes -= n;
484       maxBytes -= n;
485       alreadyRead.byteCount += n;
486 
487       // According to David Klempner, who works on Stubby at Google, we sadly CANNOT assume that
488       // we've consumed the whole read buffer here. If a signal is delivered in the middle of a
489       // read() -- yes, even a non-blocking read -- it can cause the kernel to return a partial
490       // result, with data still in the buffer.
491       //     https://bugzilla.kernel.org/show_bug.cgi?id=199131
492       //     https://twitter.com/CaptainSegfault/status/1112622245531144194
493       //
494       // Unfortunately, we have no choice but to issue more read()s until it either tells us EOF
495       // or EAGAIN. We used to have an optimization here using observer.atEndHint() (when it is
496       // non-null) to avoid a redundant call to read(). Alas...
497       return tryReadInternal(buffer, minBytes, maxBytes, fdBuffer, maxFds, alreadyRead);
498     }
499   }
500 
writeInternal(ArrayPtr<const byte> firstPiece,ArrayPtr<const ArrayPtr<const byte>> morePieces,ArrayPtr<const int> fds)501   Promise<void> writeInternal(ArrayPtr<const byte> firstPiece,
502                               ArrayPtr<const ArrayPtr<const byte>> morePieces,
503                               ArrayPtr<const int> fds) {
504     const size_t iovmax = kj::miniposix::iovMax();
505     // If there are more than IOV_MAX pieces, we'll only write the first IOV_MAX for now, and
506     // then we'll loop later.
507     KJ_STACK_ARRAY(struct iovec, iov, kj::min(1 + morePieces.size(), iovmax), 16, 128);
508     size_t iovTotal = 0;
509 
510     // writev() interface is not const-correct.  :(
511     iov[0].iov_base = const_cast<byte*>(firstPiece.begin());
512     iov[0].iov_len = firstPiece.size();
513     iovTotal += iov[0].iov_len;
514     for (uint i = 1; i < iov.size(); i++) {
515       iov[i].iov_base = const_cast<byte*>(morePieces[i - 1].begin());
516       iov[i].iov_len = morePieces[i - 1].size();
517       iovTotal += iov[i].iov_len;
518     }
519 
520     if (iovTotal == 0) {
521       KJ_REQUIRE(fds.size() == 0, "can't write FDs without bytes");
522       return kj::READY_NOW;
523     }
524 
525     ssize_t n;
526     if (fds.size() == 0) {
527       KJ_NONBLOCKING_SYSCALL(n = ::writev(fd, iov.begin(), iov.size()), iovTotal, iov.size()) {
528         // Error.
529 
530         // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to
531         // a bug that exists in both Clang and GCC:
532         //   http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799
533         //   http://llvm.org/bugs/show_bug.cgi?id=12286
534         goto error;
535       }
536     } else {
537       struct msghdr msg;
538       memset(&msg, 0, sizeof(msg));
539       msg.msg_iov = iov.begin();
540       msg.msg_iovlen = iov.size();
541 
542       // Allocate space to send a cmsg.
543       size_t msgBytes = CMSG_SPACE(sizeof(int) * fds.size());
544       // On Linux, CMSG_SPACE will align to a word-size boundary, but on Mac it always aligns to a
545       // 32-bit boundary. I guess aligning to 32 bits helps avoid the problem where you
546       // surprisingly end up with space for two file descriptors when you only wanted one. However,
547       // cmsghdr's preferred alignment is word-size (it contains a size_t). If we stack-allocate
548       // the buffer, we need to make sure it is aligned properly (maybe not on x64, but maybe on
549       // other platforms), so we want to allocate an array of words (we use void*). So... we use
550       // CMSG_SPACE() and then additionally round up to deal with Mac.
551       size_t msgWords = (msgBytes + sizeof(void*) - 1) / sizeof(void*);
552       KJ_STACK_ARRAY(void*, cmsgSpace, msgWords, 16, 256);
553       auto cmsgBytes = cmsgSpace.asBytes();
554       memset(cmsgBytes.begin(), 0, cmsgBytes.size());
555       msg.msg_control = cmsgBytes.begin();
556       msg.msg_controllen = msgBytes;
557 
558       struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
559       cmsg->cmsg_level = SOL_SOCKET;
560       cmsg->cmsg_type = SCM_RIGHTS;
561       cmsg->cmsg_len = CMSG_LEN(sizeof(int) * fds.size());
562       memcpy(CMSG_DATA(cmsg), fds.begin(), fds.asBytes().size());
563 
564       KJ_NONBLOCKING_SYSCALL(n = ::sendmsg(fd, &msg, 0)) {
565         // Error.
566 
567         // We can't "return kj::READY_NOW;" inside this block because it causes a memory leak due to
568         // a bug that exists in both Clang and GCC:
569         //   http://gcc.gnu.org/bugzilla/show_bug.cgi?id=33799
570         //   http://llvm.org/bugs/show_bug.cgi?id=12286
571         goto error;
572       }
573     }
574 
575     if (false) {
576     error:
577       return kj::READY_NOW;
578     }
579 
580     if (n < 0) {
581       // Got EAGAIN. Nothing was written.
582       return observer.whenBecomesWritable().then([=]() {
583         return writeInternal(firstPiece, morePieces, fds);
584       });
585     } else if (n == 0) {
586       // Why would a sendmsg() with a non-empty message ever return 0 when writing to a stream
587       // socket? If there's no room in the send buffer, it should fail with EAGAIN. If the
588       // connection is closed, it should fail with EPIPE. Various documents and forum posts around
589       // the internet claim this can happen but no one seems to know when. My guess is it can only
590       // happen if we try to send an empty message -- which we didn't. So I think this is
591       // impossible. If it is possible, we need to figure out how to correctly handle it, which
592       // depends on what caused it.
593       //
594       // Note in particular that if 0 is a valid return here, and we sent an SCM_RIGHTS message,
595       // we need to know whether the message was sent or not, in order to decide whether to retry
596       // sending it!
597       KJ_FAIL_ASSERT("non-empty sendmsg() returned 0");
598     }
599 
600     // Non-zero bytes were written. This also implies that *all* FDs were written.
601 
602     // Discard all data that was written, then issue a new write for what's left (if any).
603     for (;;) {
604       if (n < firstPiece.size()) {
605         // Only part of the first piece was consumed.  Wait for buffer space and then write again.
606         firstPiece = firstPiece.slice(n, firstPiece.size());
607         iovTotal -= n;
608 
609         if (iovTotal == 0) {
610           // Oops, what actually happened is that we hit the IOV_MAX limit. Don't wait.
611           return writeInternal(firstPiece, morePieces, nullptr);
612         }
613 
614         // As with read(), we cannot assume that a short write() really means the write buffer is
615         // full (see comments in the read path above). We have to write again.
616         return writeInternal(firstPiece, morePieces, nullptr);
617       } else if (morePieces.size() == 0) {
618         // First piece was fully-consumed and there are no more pieces, so we're done.
619         KJ_DASSERT(n == firstPiece.size(), n);
620         return READY_NOW;
621       } else {
622         // First piece was fully consumed, so move on to the next piece.
623         n -= firstPiece.size();
624         iovTotal -= firstPiece.size();
625         firstPiece = morePieces[0];
626         morePieces = morePieces.slice(1, morePieces.size());
627       }
628     }
629   }
630 };
631 
632 // =======================================================================================
633 
634 class SocketAddress {
635 public:
SocketAddress(const void * sockaddr,uint len)636   SocketAddress(const void* sockaddr, uint len): addrlen(len) {
637     KJ_REQUIRE(len <= sizeof(addr), "Sorry, your sockaddr is too big for me.");
638     memcpy(&addr.generic, sockaddr, len);
639   }
640 
operator <(const SocketAddress & other) const641   bool operator<(const SocketAddress& other) const {
642     // So we can use std::set<SocketAddress>...  see DNS lookup code.
643 
644     if (wildcard < other.wildcard) return true;
645     if (wildcard > other.wildcard) return false;
646 
647     if (addrlen < other.addrlen) return true;
648     if (addrlen > other.addrlen) return false;
649 
650     return memcmp(&addr.generic, &other.addr.generic, addrlen) < 0;
651   }
652 
getRaw() const653   const struct sockaddr* getRaw() const { return &addr.generic; }
getRawSize() const654   socklen_t getRawSize() const { return addrlen; }
655 
socket(int type) const656   int socket(int type) const {
657     bool isStream = type == SOCK_STREAM;
658 
659     int result;
660 #if __linux__ && !__BIONIC__
661     type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
662 #endif
663     KJ_SYSCALL(result = ::socket(addr.generic.sa_family, type, 0));
664 
665     if (isStream && (addr.generic.sa_family == AF_INET ||
666                      addr.generic.sa_family == AF_INET6)) {
667       // TODO(perf):  As a hack for the 0.4 release we are always setting
668       //   TCP_NODELAY because Nagle's algorithm pretty much kills Cap'n Proto's
669       //   RPC protocol.  Later, we should extend the interface to provide more
670       //   control over this.  Perhaps write() should have a flag which
671       //   specifies whether to pass MSG_MORE.
672       int one = 1;
673       KJ_SYSCALL(setsockopt(
674           result, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one)));
675     }
676 
677     return result;
678   }
679 
bind(int sockfd) const680   void bind(int sockfd) const {
681 #if !defined(__OpenBSD__)
682     if (wildcard) {
683       // Disable IPV6_V6ONLY because we want to handle both ipv4 and ipv6 on this socket.  (The
684       // default value of this option varies across platforms.)
685       int value = 0;
686       KJ_SYSCALL(setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY, &value, sizeof(value)));
687     }
688 #endif
689 
690     KJ_SYSCALL(::bind(sockfd, &addr.generic, addrlen), toString());
691   }
692 
getPort() const693   uint getPort() const {
694     switch (addr.generic.sa_family) {
695       case AF_INET: return ntohs(addr.inet4.sin_port);
696       case AF_INET6: return ntohs(addr.inet6.sin6_port);
697       default: return 0;
698     }
699   }
700 
toString() const701   String toString() const {
702     if (wildcard) {
703       return str("*:", getPort());
704     }
705 
706     switch (addr.generic.sa_family) {
707       case AF_INET: {
708         char buffer[INET6_ADDRSTRLEN];
709         if (inet_ntop(addr.inet4.sin_family, &addr.inet4.sin_addr,
710                       buffer, sizeof(buffer)) == nullptr) {
711           KJ_FAIL_SYSCALL("inet_ntop", errno) { break; }
712           return heapString("(inet_ntop error)");
713         }
714         return str(buffer, ':', ntohs(addr.inet4.sin_port));
715       }
716       case AF_INET6: {
717         char buffer[INET6_ADDRSTRLEN];
718         if (inet_ntop(addr.inet6.sin6_family, &addr.inet6.sin6_addr,
719                       buffer, sizeof(buffer)) == nullptr) {
720           KJ_FAIL_SYSCALL("inet_ntop", errno) { break; }
721           return heapString("(inet_ntop error)");
722         }
723         return str('[', buffer, "]:", ntohs(addr.inet6.sin6_port));
724       }
725       case AF_UNIX: {
726         auto path = _::safeUnixPath(&addr.unixDomain, addrlen);
727         if (path.size() > 0 && path[0] == '\0') {
728           return str("unix-abstract:", path.slice(1, path.size()));
729         } else {
730           return str("unix:", path);
731         }
732       }
733       default:
734         return str("(unknown address family ", addr.generic.sa_family, ")");
735     }
736   }
737 
  static Promise<Array<SocketAddress>> lookupHost(
      LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
      _::NetworkFilter& filter);
  // Perform a DNS lookup.
  //
  // Runs getaddrinfo() for `host` / `service` on a separate thread and resolves to the
  // de-duplicated results that pass `filter` (via parseAllowedBy()).  When `service` is null,
  // `portHint` is used as the port of every IPv4/IPv6 result.  A `host` of "*" produces wildcard
  // addresses that keep only the resolved port.
742 
parse(LowLevelAsyncIoProvider & lowLevel,StringPtr str,uint portHint,_::NetworkFilter & filter)743   static Promise<Array<SocketAddress>> parse(
744       LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) {
745     // TODO(someday):  Allow commas in `str`.
746 
747     SocketAddress result;
748 
749     if (str.startsWith("unix:")) {
750       StringPtr path = str.slice(strlen("unix:"));
751       KJ_REQUIRE(path.size() < sizeof(addr.unixDomain.sun_path),
752                  "Unix domain socket address is too long.", str);
753       KJ_REQUIRE(path.size() == strlen(path.cStr()),
754                  "Unix domain socket address contains NULL. Use"
755                  " 'unix-abstract:' for the abstract namespace.");
756       result.addr.unixDomain.sun_family = AF_UNIX;
757       strcpy(result.addr.unixDomain.sun_path, path.cStr());
758       result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1;
759 
760       if (!result.parseAllowedBy(filter)) {
761         KJ_FAIL_REQUIRE("unix sockets blocked by restrictPeers()");
762         return Array<SocketAddress>();
763       }
764 
765       auto array = kj::heapArrayBuilder<SocketAddress>(1);
766       array.add(result);
767       return array.finish();
768     }
769 
770     if (str.startsWith("unix-abstract:")) {
771       StringPtr path = str.slice(strlen("unix-abstract:"));
772       KJ_REQUIRE(path.size() + 1 < sizeof(addr.unixDomain.sun_path),
773                  "Unix domain socket address is too long.", str);
774       result.addr.unixDomain.sun_family = AF_UNIX;
775       result.addr.unixDomain.sun_path[0] = '\0';
776       // although not strictly required by Linux, also copy the trailing
777       // NULL terminator so that we can safely read it back in toString
778       memcpy(result.addr.unixDomain.sun_path + 1, path.cStr(), path.size() + 1);
779       result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1;
780 
781       if (!result.parseAllowedBy(filter)) {
782         KJ_FAIL_REQUIRE("abstract unix sockets blocked by restrictPeers()");
783         return Array<SocketAddress>();
784       }
785 
786       auto array = kj::heapArrayBuilder<SocketAddress>(1);
787       array.add(result);
788       return array.finish();
789     }
790 
791     // Try to separate the address and port.
792     ArrayPtr<const char> addrPart;
793     Maybe<StringPtr> portPart;
794 
795     int af;
796 
797     if (str.startsWith("[")) {
798       // Address starts with a bracket, which is a common way to write an ip6 address with a port,
799       // since without brackets around the address part, the port looks like another segment of
800       // the address.
801       af = AF_INET6;
802       size_t closeBracket = KJ_ASSERT_NONNULL(str.findLast(']'),
803           "Unclosed '[' in address string.", str);
804 
805       addrPart = str.slice(1, closeBracket);
806       if (str.size() > closeBracket + 1) {
807         KJ_REQUIRE(str.slice(closeBracket + 1).startsWith(":"),
808                    "Expected port suffix after ']'.", str);
809         portPart = str.slice(closeBracket + 2);
810       }
811     } else {
812       KJ_IF_MAYBE(colon, str.findFirst(':')) {
813         if (str.slice(*colon + 1).findFirst(':') == nullptr) {
814           // There is exactly one colon and no brackets, so it must be an ip4 address with port.
815           af = AF_INET;
816           addrPart = str.slice(0, *colon);
817           portPart = str.slice(*colon + 1);
818         } else {
819           // There are two or more colons and no brackets, so the whole thing must be an ip6
820           // address with no port.
821           af = AF_INET6;
822           addrPart = str;
823         }
824       } else {
825         // No colons, so it must be an ip4 address without port.
826         af = AF_INET;
827         addrPart = str;
828       }
829     }
830 
831     // Parse the port.
832     unsigned long port;
833     KJ_IF_MAYBE(portText, portPart) {
834       char* endptr;
835       port = strtoul(portText->cStr(), &endptr, 0);
836       if (portText->size() == 0 || *endptr != '\0') {
837         // Not a number.  Maybe it's a service name.  Fall back to DNS.
838         return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint,
839                           filter);
840       }
841       KJ_REQUIRE(port < 65536, "Port number too large.");
842     } else {
843       port = portHint;
844     }
845 
846     // Check for wildcard.
847     if (addrPart.size() == 1 && addrPart[0] == '*') {
848       result.wildcard = true;
849 #if defined(__OpenBSD__)
850       // On OpenBSD, all sockets are either v4-only or v6-only, so use v4 as a
851       // temporary workaround for wildcards.
852       result.addrlen = sizeof(addr.inet4);
853       result.addr.inet4.sin_family = AF_INET;
854       result.addr.inet4.sin_port = htons(port);
855 #else
856       // Create an ip6 socket and set IPV6_V6ONLY to 0 later.
857       result.addrlen = sizeof(addr.inet6);
858       result.addr.inet6.sin6_family = AF_INET6;
859       result.addr.inet6.sin6_port = htons(port);
860 #endif
861 
862       auto array = kj::heapArrayBuilder<SocketAddress>(1);
863       array.add(result);
864       return array.finish();
865     }
866 
867     void* addrTarget;
868     if (af == AF_INET6) {
869       result.addrlen = sizeof(addr.inet6);
870       result.addr.inet6.sin6_family = AF_INET6;
871       result.addr.inet6.sin6_port = htons(port);
872       addrTarget = &result.addr.inet6.sin6_addr;
873     } else {
874       result.addrlen = sizeof(addr.inet4);
875       result.addr.inet4.sin_family = AF_INET;
876       result.addr.inet4.sin_port = htons(port);
877       addrTarget = &result.addr.inet4.sin_addr;
878     }
879 
880     if (addrPart.size() < INET6_ADDRSTRLEN - 1) {
881       // addrPart is not necessarily NUL-terminated so we have to make a copy.  :(
882       char buffer[INET6_ADDRSTRLEN];
883       memcpy(buffer, addrPart.begin(), addrPart.size());
884       buffer[addrPart.size()] = '\0';
885 
886       // OK, parse it!
887       switch (inet_pton(af, buffer, addrTarget)) {
888         case 1: {
889           // success.
890           if (!result.parseAllowedBy(filter)) {
891             KJ_FAIL_REQUIRE("address family blocked by restrictPeers()");
892             return Array<SocketAddress>();
893           }
894 
895           auto array = kj::heapArrayBuilder<SocketAddress>(1);
896           array.add(result);
897           return array.finish();
898         }
899         case 0:
900           // It's apparently not a simple address...  fall back to DNS.
901           break;
902         default:
903           KJ_FAIL_SYSCALL("inet_pton", errno, af, addrPart);
904       }
905     }
906 
907     return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter);
908   }
909 
getLocalAddress(int sockfd)910   static SocketAddress getLocalAddress(int sockfd) {
911     SocketAddress result;
912     result.addrlen = sizeof(addr);
913     KJ_SYSCALL(getsockname(sockfd, &result.addr.generic, &result.addrlen));
914     return result;
915   }
916 
allowedBy(LowLevelAsyncIoProvider::NetworkFilter & filter)917   bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
918     return filter.shouldAllow(&addr.generic, addrlen);
919   }
920 
parseAllowedBy(_::NetworkFilter & filter)921   bool parseAllowedBy(_::NetworkFilter& filter) {
922     return filter.shouldAllowParse(&addr.generic, addrlen);
923   }
924 
  kj::Own<PeerIdentity> getIdentity(LowLevelAsyncIoProvider& llaiop,
                                    LowLevelAsyncIoProvider::NetworkFilter& filter,
                                    AsyncIoStream& stream) const;
  // Builds a PeerIdentity describing the remote end of `stream`, which is assumed to be
  // connected to this address.  For IPv4/IPv6 it wraps this address in a NetworkAddress; for
  // Unix sockets it queries the peer's credentials via `stream.getsockopt()`.  (Handling of other
  // families is defined out-of-line -- see the definition.)
928 
929 private:
SocketAddress()930   SocketAddress() {
931     // We need to memset the whole object 0 otherwise Valgrind gets unhappy when we write it to a
932     // pipe, due to the padding bytes being uninitialized.
933     memset(this, 0, sizeof(*this));
934   }
935 
  socklen_t addrlen;
  // Number of meaningful bytes in `addr`; passed alongside it to bind(), getsockname(), and the
  // network filter.
  bool wildcard = false;
  // True for "*" addresses.  toString() renders these as "*:<port>", and bind() (except on
  // OpenBSD) clears IPV6_V6ONLY so the IPv6 socket also accepts IPv4 peers.
  union {
    struct sockaddr generic;
    struct sockaddr_in inet4;
    struct sockaddr_in6 inet6;
    struct sockaddr_un unixDomain;
    struct sockaddr_storage storage;
  } addr;
  // The raw socket address.  Which member is meaningful is selected by `generic.sa_family`;
  // `storage` makes the union large enough for any address family.

  struct LookupParams;
  class LookupReader;
  // Helpers for lookupHost(): the arguments handed to the getaddrinfo() thread, and the reader
  // that collects the thread's results from a pipe.
948 };
949 
950 class SocketAddress::LookupReader {
951   // Reads SocketAddresses off of a pipe coming from another thread that is performing
952   // getaddrinfo.
953 
954 public:
LookupReader(kj::Own<Thread> && thread,kj::Own<AsyncInputStream> && input,_::NetworkFilter & filter)955   LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input,
956                _::NetworkFilter& filter)
957       : thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {}
958 
~LookupReader()959   ~LookupReader() {
960     if (thread) thread->detach();
961   }
962 
read()963   Promise<Array<SocketAddress>> read() {
964     return input->tryRead(&current, sizeof(current), sizeof(current)).then(
965         [this](size_t n) -> Promise<Array<SocketAddress>> {
966       if (n < sizeof(current)) {
967         thread = nullptr;
968         // getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
969         // anyway.
970         KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; }
971         return addresses.releaseAsArray();
972       } else {
973         // getaddrinfo() can return multiple copies of the same address for several reasons.
974         // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so
975         // it may return two copies of the same address, one for each type, unless it explicitly
976         // knows that the service name given is specific to one type.  But we can't tell it a type,
977         // because we don't actually know which one the user wants, and if we specify SOCK_STREAM
978         // while the user specified a UDP service name then they'll get a resolution error which
979         // is lame.  (At least, I think that's how it works.)
980         //
981         // So we instead resort to de-duping results.
982         if (alreadySeen.insert(current).second) {
983           if (current.parseAllowedBy(filter)) {
984             addresses.add(current);
985           }
986         }
987         return read();
988       }
989     });
990   }
991 
992 private:
993   kj::Own<Thread> thread;
994   kj::Own<AsyncInputStream> input;
995   _::NetworkFilter& filter;
996   SocketAddress current;
997   kj::Vector<SocketAddress> addresses;
998   std::set<SocketAddress> alreadySeen;
999 };
1000 
struct SocketAddress::LookupParams {
  // Arguments moved into the getaddrinfo() thread spawned by lookupHost().  Member order matters:
  // lookupHost() builds this with aggregate initialization.
  kj::String host;     // Host to resolve; "*" requests wildcard addresses.
  kj::String service;  // Service name or port text; null means "apply portHint instead".
};
1005 
lookupHost(LowLevelAsyncIoProvider & lowLevel,kj::String host,kj::String service,uint portHint,_::NetworkFilter & filter)1006 Promise<Array<SocketAddress>> SocketAddress::lookupHost(
1007     LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
1008     _::NetworkFilter& filter) {
1009   // This shitty function spawns a thread to run getaddrinfo().  Unfortunately, getaddrinfo() is
1010   // the only cross-platform DNS API and it is blocking.
1011   //
1012   // TODO(perf):  Use a thread pool?  Maybe kj::Thread should use a thread pool automatically?
1013   //   Maybe use the various platform-specific asynchronous DNS libraries?  Please do not implement
1014   //   a custom DNS resolver...
1015 
1016   int fds[2];
1017 #if __linux__ && !__BIONIC__
1018   KJ_SYSCALL(pipe2(fds, O_NONBLOCK | O_CLOEXEC));
1019 #else
1020   KJ_SYSCALL(pipe(fds));
1021 #endif
1022 
1023   auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS);
1024 
1025   int outFd = fds[1];
1026 
1027   LookupParams params = { kj::mv(host), kj::mv(service) };
1028 
1029   auto thread = heap<Thread>(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) {
1030     FdOutputStream output((AutoCloseFd(outFd)));
1031 
1032     struct addrinfo* list;
1033     int status = getaddrinfo(
1034         params.host == "*" ? nullptr : params.host.cStr(),
1035         params.service == nullptr ? nullptr : params.service.cStr(),
1036         nullptr, &list);
1037     if (status == 0) {
1038       KJ_DEFER(freeaddrinfo(list));
1039 
1040       struct addrinfo* cur = list;
1041       while (cur != nullptr) {
1042         if (params.service == nullptr) {
1043           switch (cur->ai_addr->sa_family) {
1044             case AF_INET:
1045               ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint);
1046               break;
1047             case AF_INET6:
1048               ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint);
1049               break;
1050             default:
1051               break;
1052           }
1053         }
1054 
1055         SocketAddress addr;
1056         if (params.host == "*") {
1057           // Set up a wildcard SocketAddress.  Only use the port number returned by getaddrinfo().
1058           addr.wildcard = true;
1059           addr.addrlen = sizeof(addr.addr.inet6);
1060           addr.addr.inet6.sin6_family = AF_INET6;
1061           switch (cur->ai_addr->sa_family) {
1062             case AF_INET:
1063               addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port;
1064               break;
1065             case AF_INET6:
1066               addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port;
1067               break;
1068             default:
1069               addr.addr.inet6.sin6_port = portHint;
1070               break;
1071           }
1072         } else {
1073           addr.addrlen = cur->ai_addrlen;
1074           memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen);
1075         }
1076         KJ_ASSERT_CAN_MEMCPY(SocketAddress);
1077         output.write(&addr, sizeof(addr));
1078         cur = cur->ai_next;
1079       }
1080     } else if (status == EAI_SYSTEM) {
1081       KJ_FAIL_SYSCALL("getaddrinfo", errno, params.host, params.service) {
1082         return;
1083       }
1084     } else {
1085       KJ_FAIL_REQUIRE("DNS lookup failed.",
1086                       params.host, params.service, gai_strerror(status)) {
1087         return;
1088       }
1089     }
1090   }));
1091 
1092   auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input), filter);
1093   return reader->read().attach(kj::mv(reader));
1094 }
1095 
1096 // =======================================================================================
1097 
1098 class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFileDescriptor {
1099 public:
FdConnectionReceiver(LowLevelAsyncIoProvider & lowLevel,UnixEventPort & eventPort,int fd,LowLevelAsyncIoProvider::NetworkFilter & filter,uint flags)1100   FdConnectionReceiver(LowLevelAsyncIoProvider& lowLevel,
1101                        UnixEventPort& eventPort, int fd,
1102                        LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
1103       : OwnedFileDescriptor(fd, flags), lowLevel(lowLevel), eventPort(eventPort), filter(filter),
1104         observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ) {}
1105 
accept()1106   Promise<Own<AsyncIoStream>> accept() override {
1107     return acceptImpl(false).then([](AuthenticatedStream&& a) { return kj::mv(a.stream); });
1108   }
1109 
acceptAuthenticated()1110   Promise<AuthenticatedStream> acceptAuthenticated() override {
1111     return acceptImpl(true);
1112   }
1113 
acceptImpl(bool authenticated)1114   Promise<AuthenticatedStream> acceptImpl(bool authenticated) {
1115     int newFd;
1116 
1117     struct sockaddr_storage addr;
1118     socklen_t addrlen = sizeof(addr);
1119 
1120   retry:
1121 #if __linux__ && !__BIONIC__
1122     newFd = ::accept4(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen,
1123                       SOCK_NONBLOCK | SOCK_CLOEXEC);
1124 #else
1125     newFd = ::accept(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
1126 #endif
1127 
1128     if (newFd >= 0) {
1129       kj::AutoCloseFd ownFd(newFd);
1130       if (!filter.shouldAllow(reinterpret_cast<struct sockaddr*>(&addr), addrlen)) {
1131         // Ignore disallowed address.
1132         return acceptImpl(authenticated);
1133       } else {
1134         // TODO(perf):  As a hack for the 0.4 release we are always setting
1135         //   TCP_NODELAY because Nagle's algorithm pretty much kills Cap'n Proto's
1136         //   RPC protocol.  Later, we should extend the interface to provide more
1137         //   control over this.  Perhaps write() should have a flag which
1138         //   specifies whether to pass MSG_MORE.
1139         int one = 1;
1140         KJ_SYSCALL_HANDLE_ERRORS(::setsockopt(
1141               ownFd.get(), IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one))) {
1142           case EOPNOTSUPP:
1143           case ENOPROTOOPT: // (returned for AF_UNIX in cygwin)
1144             break;
1145           default:
1146             KJ_FAIL_SYSCALL("setsocketopt(IPPROTO_TCP, TCP_NODELAY)", error);
1147         }
1148 
1149         AuthenticatedStream result;
1150         result.stream = heap<AsyncStreamFd>(eventPort, ownFd.release(), NEW_FD_FLAGS);
1151         if (authenticated) {
1152           result.peerIdentity = SocketAddress(reinterpret_cast<struct sockaddr*>(&addr), addrlen)
1153               .getIdentity(lowLevel, filter, *result.stream);
1154         }
1155         return kj::mv(result);
1156       }
1157     } else {
1158       int error = errno;
1159 
1160       switch (error) {
1161         case EAGAIN:
1162 #if EAGAIN != EWOULDBLOCK
1163         case EWOULDBLOCK:
1164 #endif
1165           // Not ready yet.
1166           return observer.whenBecomesReadable().then([this,authenticated]() {
1167             return acceptImpl(authenticated);
1168           });
1169 
1170         case EINTR:
1171         case ENETDOWN:
1172 #ifdef EPROTO
1173         // EPROTO is not defined on OpenBSD.
1174         case EPROTO:
1175 #endif
1176         case EHOSTDOWN:
1177         case EHOSTUNREACH:
1178         case ENETUNREACH:
1179         case ECONNABORTED:
1180         case ETIMEDOUT:
1181           // According to the Linux man page, accept() may report an error if the accepted
1182           // connection is already broken.  In this case, we really ought to just ignore it and
1183           // keep waiting.  But it's hard to say exactly what errors are such network errors and
1184           // which ones are permanent errors.  We've made a guess here.
1185           goto retry;
1186 
1187         default:
1188           KJ_FAIL_SYSCALL("accept", error);
1189       }
1190 
1191     }
1192   }
1193 
getPort()1194   uint getPort() override {
1195     return SocketAddress::getLocalAddress(fd).getPort();
1196   }
1197 
getsockopt(int level,int option,void * value,uint * length)1198   void getsockopt(int level, int option, void* value, uint* length) override {
1199     socklen_t socklen = *length;
1200     KJ_SYSCALL(::getsockopt(fd, level, option, value, &socklen));
1201     *length = socklen;
1202   }
setsockopt(int level,int option,const void * value,uint length)1203   void setsockopt(int level, int option, const void* value, uint length) override {
1204     KJ_SYSCALL(::setsockopt(fd, level, option, value, length));
1205   }
getsockname(struct sockaddr * addr,uint * length)1206   void getsockname(struct sockaddr* addr, uint* length) override {
1207     socklen_t socklen = *length;
1208     KJ_SYSCALL(::getsockname(fd, addr, &socklen));
1209     *length = socklen;
1210   }
1211 
1212 public:
1213   LowLevelAsyncIoProvider& lowLevel;
1214   UnixEventPort& eventPort;
1215   LowLevelAsyncIoProvider::NetworkFilter& filter;
1216   UnixEventPort::FdObserver observer;
1217 };
1218 
1219 class DatagramPortImpl final: public DatagramPort, public OwnedFileDescriptor {
1220 public:
DatagramPortImpl(LowLevelAsyncIoProvider & lowLevel,UnixEventPort & eventPort,int fd,LowLevelAsyncIoProvider::NetworkFilter & filter,uint flags)1221   DatagramPortImpl(LowLevelAsyncIoProvider& lowLevel, UnixEventPort& eventPort, int fd,
1222                    LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
1223       : OwnedFileDescriptor(fd, flags), lowLevel(lowLevel), eventPort(eventPort), filter(filter),
1224         observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ |
1225                                 UnixEventPort::FdObserver::OBSERVE_WRITE) {}
1226 
1227   Promise<size_t> send(const void* buffer, size_t size, NetworkAddress& destination) override;
1228   Promise<size_t> send(
1229       ArrayPtr<const ArrayPtr<const byte>> pieces, NetworkAddress& destination) override;
1230 
1231   class ReceiverImpl;
1232 
1233   Own<DatagramReceiver> makeReceiver(DatagramReceiver::Capacity capacity) override;
1234 
getPort()1235   uint getPort() override {
1236     return SocketAddress::getLocalAddress(fd).getPort();
1237   }
1238 
getsockopt(int level,int option,void * value,uint * length)1239   void getsockopt(int level, int option, void* value, uint* length) override {
1240     socklen_t socklen = *length;
1241     KJ_SYSCALL(::getsockopt(fd, level, option, value, &socklen));
1242     *length = socklen;
1243   }
setsockopt(int level,int option,const void * value,uint length)1244   void setsockopt(int level, int option, const void* value, uint length) override {
1245     KJ_SYSCALL(::setsockopt(fd, level, option, value, length));
1246   }
1247 
1248 public:
1249   LowLevelAsyncIoProvider& lowLevel;
1250   UnixEventPort& eventPort;
1251   LowLevelAsyncIoProvider::NetworkFilter& filter;
1252   UnixEventPort::FdObserver observer;
1253 };
1254 
1255 class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
1256 public:
LowLevelAsyncIoProviderImpl()1257   LowLevelAsyncIoProviderImpl()
1258       : eventLoop(eventPort), waitScope(eventLoop) {}
1259 
getWaitScope()1260   inline WaitScope& getWaitScope() { return waitScope; }
1261 
wrapInputFd(int fd,uint flags=0)1262   Own<AsyncInputStream> wrapInputFd(int fd, uint flags = 0) override {
1263     return heap<AsyncStreamFd>(eventPort, fd, flags);
1264   }
wrapOutputFd(int fd,uint flags=0)1265   Own<AsyncOutputStream> wrapOutputFd(int fd, uint flags = 0) override {
1266     return heap<AsyncStreamFd>(eventPort, fd, flags);
1267   }
wrapSocketFd(int fd,uint flags=0)1268   Own<AsyncIoStream> wrapSocketFd(int fd, uint flags = 0) override {
1269     return heap<AsyncStreamFd>(eventPort, fd, flags);
1270   }
wrapUnixSocketFd(Fd fd,uint flags=0)1271   Own<AsyncCapabilityStream> wrapUnixSocketFd(Fd fd, uint flags = 0) override {
1272     return heap<AsyncStreamFd>(eventPort, fd, flags);
1273   }
wrapConnectingSocketFd(int fd,const struct sockaddr * addr,uint addrlen,uint flags=0)1274   Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
1275       int fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) override {
1276     // It's important that we construct the AsyncStreamFd first, so that `flags` are honored,
1277     // especially setting nonblocking mode and taking ownership.
1278     auto result = heap<AsyncStreamFd>(eventPort, fd, flags);
1279 
1280     // Unfortunately connect() doesn't fit the mold of KJ_NONBLOCKING_SYSCALL, since it indicates
1281     // non-blocking using EINPROGRESS.
1282     for (;;) {
1283       if (::connect(fd, addr, addrlen) < 0) {
1284         int error = errno;
1285         if (error == EINPROGRESS) {
1286           // Fine.
1287           break;
1288         } else if (error != EINTR) {
1289           KJ_FAIL_SYSCALL("connect()", error) { break; }
1290           return Own<AsyncIoStream>();
1291         }
1292       } else {
1293         // no error
1294         break;
1295       }
1296     }
1297 
1298     auto connected = result->waitConnected();
1299     return connected.then(kj::mvCapture(result, [fd](Own<AsyncIoStream>&& stream) {
1300       int err;
1301       socklen_t errlen = sizeof(err);
1302       KJ_SYSCALL(getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &errlen));
1303       if (err != 0) {
1304         KJ_FAIL_SYSCALL("connect()", err) { break; }
1305       }
1306       return kj::mv(stream);
1307     }));
1308   }
wrapListenSocketFd(int fd,NetworkFilter & filter,uint flags=0)1309   Own<ConnectionReceiver> wrapListenSocketFd(
1310       int fd, NetworkFilter& filter, uint flags = 0) override {
1311     return heap<FdConnectionReceiver>(*this, eventPort, fd, filter, flags);
1312   }
wrapDatagramSocketFd(int fd,NetworkFilter & filter,uint flags=0)1313   Own<DatagramPort> wrapDatagramSocketFd(
1314       int fd, NetworkFilter& filter, uint flags = 0) override {
1315     return heap<DatagramPortImpl>(*this, eventPort, fd, filter, flags);
1316   }
1317 
getTimer()1318   Timer& getTimer() override { return eventPort.getTimer(); }
1319 
getEventPort()1320   UnixEventPort& getEventPort() { return eventPort; }
1321 
1322 private:
1323   UnixEventPort eventPort;
1324   EventLoop eventLoop;
1325   WaitScope waitScope;
1326 };
1327 
1328 // =======================================================================================
1329 
1330 class NetworkAddressImpl final: public NetworkAddress {
1331 public:
NetworkAddressImpl(LowLevelAsyncIoProvider & lowLevel,LowLevelAsyncIoProvider::NetworkFilter & filter,Array<SocketAddress> addrs)1332   NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel,
1333                      LowLevelAsyncIoProvider::NetworkFilter& filter,
1334                      Array<SocketAddress> addrs)
1335       : lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {}
1336 
connect()1337   Promise<Own<AsyncIoStream>> connect() override {
1338     auto addrsCopy = heapArray(addrs.asPtr());
1339     auto promise = connectImpl(lowLevel, filter, addrsCopy, false);
1340     return promise.attach(kj::mv(addrsCopy))
1341         .then([](AuthenticatedStream&& a) { return kj::mv(a.stream); });
1342   }
1343 
connectAuthenticated()1344   Promise<AuthenticatedStream> connectAuthenticated() override {
1345     auto addrsCopy = heapArray(addrs.asPtr());
1346     auto promise = connectImpl(lowLevel, filter, addrsCopy, true);
1347     return promise.attach(kj::mv(addrsCopy));
1348   }
1349 
listen()1350   Own<ConnectionReceiver> listen() override {
1351     if (addrs.size() > 1) {
1352       KJ_LOG(WARNING, "Bind address resolved to multiple addresses.  Only the first address will "
1353           "be used.  If this is incorrect, specify the address numerically.  This may be fixed "
1354           "in the future.", addrs[0].toString());
1355     }
1356 
1357     int fd = addrs[0].socket(SOCK_STREAM);
1358 
1359     {
1360       KJ_ON_SCOPE_FAILURE(close(fd));
1361 
1362       // We always enable SO_REUSEADDR because having to take your server down for five minutes
1363       // before it can restart really sucks.
1364       int optval = 1;
1365       KJ_SYSCALL(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)));
1366 
1367       addrs[0].bind(fd);
1368 
1369       // TODO(someday):  Let queue size be specified explicitly in string addresses.
1370       KJ_SYSCALL(::listen(fd, SOMAXCONN));
1371     }
1372 
1373     return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS);
1374   }
1375 
bindDatagramPort()1376   Own<DatagramPort> bindDatagramPort() override {
1377     if (addrs.size() > 1) {
1378       KJ_LOG(WARNING, "Bind address resolved to multiple addresses.  Only the first address will "
1379           "be used.  If this is incorrect, specify the address numerically.  This may be fixed "
1380           "in the future.", addrs[0].toString());
1381     }
1382 
1383     int fd = addrs[0].socket(SOCK_DGRAM);
1384 
1385     {
1386       KJ_ON_SCOPE_FAILURE(close(fd));
1387 
1388       // We always enable SO_REUSEADDR because having to take your server down for five minutes
1389       // before it can restart really sucks.
1390       int optval = 1;
1391       KJ_SYSCALL(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)));
1392 
1393       addrs[0].bind(fd);
1394     }
1395 
1396     return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS);
1397   }
1398 
clone()1399   Own<NetworkAddress> clone() override {
1400     return kj::heap<NetworkAddressImpl>(lowLevel, filter, kj::heapArray(addrs.asPtr()));
1401   }
1402 
toString()1403   String toString() override {
1404     return strArray(KJ_MAP(addr, addrs) { return addr.toString(); }, ",");
1405   }
1406 
chooseOneAddress()1407   const SocketAddress& chooseOneAddress() {
1408     KJ_REQUIRE(addrs.size() > 0, "No addresses available.");
1409     return addrs[counter++ % addrs.size()];
1410   }
1411 
1412 private:
1413   LowLevelAsyncIoProvider& lowLevel;
1414   LowLevelAsyncIoProvider::NetworkFilter& filter;
1415   Array<SocketAddress> addrs;
1416   uint counter = 0;
1417 
connectImpl(LowLevelAsyncIoProvider & lowLevel,LowLevelAsyncIoProvider::NetworkFilter & filter,ArrayPtr<SocketAddress> addrs,bool authenticated)1418   static Promise<AuthenticatedStream> connectImpl(
1419       LowLevelAsyncIoProvider& lowLevel,
1420       LowLevelAsyncIoProvider::NetworkFilter& filter,
1421       ArrayPtr<SocketAddress> addrs,
1422       bool authenticated) {
1423     KJ_ASSERT(addrs.size() > 0);
1424 
1425     return kj::evalNow([&]() -> Promise<Own<AsyncIoStream>> {
1426       if (!addrs[0].allowedBy(filter)) {
1427         return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()");
1428       } else {
1429         int fd = addrs[0].socket(SOCK_STREAM);
1430         return lowLevel.wrapConnectingSocketFd(
1431             fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
1432       }
1433     }).then([&lowLevel,&filter,addrs,authenticated](Own<AsyncIoStream>&& stream)
1434         -> Promise<AuthenticatedStream> {
1435       // Success, pass along.
1436       AuthenticatedStream result;
1437       result.stream = kj::mv(stream);
1438       if (authenticated) {
1439         result.peerIdentity = addrs[0].getIdentity(lowLevel, filter, *result.stream);
1440       }
1441       return kj::mv(result);
1442     }, [&lowLevel,&filter,addrs,authenticated](Exception&& exception) mutable
1443         -> Promise<AuthenticatedStream> {
1444       // Connect failed.
1445       if (addrs.size() > 1) {
1446         // Try the next address instead.
1447         return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()), authenticated);
1448       } else {
1449         // No more addresses to try, so propagate the exception.
1450         return kj::mv(exception);
1451       }
1452     });
1453   }
1454 };
1455 
getIdentity(kj::LowLevelAsyncIoProvider & llaiop,LowLevelAsyncIoProvider::NetworkFilter & filter,AsyncIoStream & stream) const1456 kj::Own<PeerIdentity> SocketAddress::getIdentity(kj::LowLevelAsyncIoProvider& llaiop,
1457                                                  LowLevelAsyncIoProvider::NetworkFilter& filter,
1458                                                  AsyncIoStream& stream) const {
1459   switch (addr.generic.sa_family) {
1460     case AF_INET:
1461     case AF_INET6: {
1462       auto builder = kj::heapArrayBuilder<SocketAddress>(1);
1463       builder.add(*this);
1464       return NetworkPeerIdentity::newInstance(
1465           kj::heap<NetworkAddressImpl>(llaiop, filter, builder.finish()));
1466     }
1467     case AF_UNIX: {
1468       LocalPeerIdentity::Credentials result;
1469 
1470       // There is little documentation on what happens when the uid/pid can't be obtained, but I've
1471       // seen vague references on the internet saying that a PID of 0 and a UID of uid_t(-1) are used
1472       // as invalid values.
1473 
1474 #if defined(SO_PEERCRED)
1475       struct ucred creds;
1476       uint length = sizeof(creds);
1477       stream.getsockopt(SOL_SOCKET, SO_PEERCRED, &creds, &length);
1478       if (creds.pid > 0) {
1479         result.pid = creds.pid;
1480       }
1481       if (creds.uid != static_cast<uid_t>(-1)) {
1482         result.uid = creds.uid;
1483       }
1484 
1485 #elif defined(LOCAL_PEERCRED)
1486       // MacOS / FreeBSD
1487       struct xucred creds;
1488       uint length = sizeof(creds);
1489 #if defined SOL_LOCAL
1490       stream.getsockopt(SOL_LOCAL, LOCAL_PEERCRED, &creds, &length);
1491 #else
1492       stream.getsockopt(0, LOCAL_PEERCRED, &creds, &length);
1493 #endif
1494       KJ_ASSERT(length == sizeof(creds));
1495       if (creds.cr_uid != static_cast<uid_t>(-1)) {
1496         result.uid = creds.cr_uid;
1497       }
1498 
1499 #if defined(LOCAL_PEERPID)
1500       // MacOS only?
1501       pid_t pid;
1502       length = sizeof(pid);
1503       stream.getsockopt(SOL_LOCAL, LOCAL_PEERPID, &pid, &length);
1504       KJ_ASSERT(length == sizeof(pid));
1505       if (pid > 0) {
1506         result.pid = pid;
1507       }
1508 #endif
1509 #endif
1510 
1511       return LocalPeerIdentity::newInstance(result);
1512     }
1513     default:
1514       return UnknownPeerIdentity::newInstance();
1515   }
1516 }
1517 
1518 class SocketNetwork final: public Network {
1519 public:
SocketNetwork(LowLevelAsyncIoProvider & lowLevel)1520   explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
SocketNetwork(SocketNetwork & parent,kj::ArrayPtr<const kj::StringPtr> allow,kj::ArrayPtr<const kj::StringPtr> deny)1521   explicit SocketNetwork(SocketNetwork& parent,
1522                          kj::ArrayPtr<const kj::StringPtr> allow,
1523                          kj::ArrayPtr<const kj::StringPtr> deny)
1524       : lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {}
1525 
parseAddress(StringPtr addr,uint portHint=0)1526   Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
1527     return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) {
1528       return SocketAddress::parse(lowLevel, addr, portHint, filter);
1529     })).then([this](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
1530       return heap<NetworkAddressImpl>(lowLevel, filter, kj::mv(addresses));
1531     });
1532   }
1533 
getSockaddr(const void * sockaddr,uint len)1534   Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
1535     auto array = kj::heapArrayBuilder<SocketAddress>(1);
1536     array.add(SocketAddress(sockaddr, len));
1537     KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; }
1538     return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, filter, array.finish()));
1539   }
1540 
restrictPeers(kj::ArrayPtr<const kj::StringPtr> allow,kj::ArrayPtr<const kj::StringPtr> deny=nullptr)1541   Own<Network> restrictPeers(
1542       kj::ArrayPtr<const kj::StringPtr> allow,
1543       kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
1544     return heap<SocketNetwork>(*this, allow, deny);
1545   }
1546 
1547 private:
1548   LowLevelAsyncIoProvider& lowLevel;
1549   _::NetworkFilter filter;
1550 };
1551 
1552 // =======================================================================================
1553 
send(const void * buffer,size_t size,NetworkAddress & destination)1554 Promise<size_t> DatagramPortImpl::send(
1555     const void* buffer, size_t size, NetworkAddress& destination) {
1556   auto& addr = downcast<NetworkAddressImpl>(destination).chooseOneAddress();
1557 
1558   ssize_t n;
1559   KJ_NONBLOCKING_SYSCALL(n = sendto(fd, buffer, size, 0, addr.getRaw(), addr.getRawSize()));
1560   if (n < 0) {
1561     // Write buffer full.
1562     return observer.whenBecomesWritable().then([this, buffer, size, &destination]() {
1563       return send(buffer, size, destination);
1564     });
1565   } else {
1566     // If less than the whole message was sent, then it got truncated, and there's nothing we can
1567     // do about it.
1568     return n;
1569   }
1570 }
1571 
send(ArrayPtr<const ArrayPtr<const byte>> pieces,NetworkAddress & destination)1572 Promise<size_t> DatagramPortImpl::send(
1573     ArrayPtr<const ArrayPtr<const byte>> pieces, NetworkAddress& destination) {
1574   struct msghdr msg;
1575   memset(&msg, 0, sizeof(msg));
1576 
1577   auto& addr = downcast<NetworkAddressImpl>(destination).chooseOneAddress();
1578   msg.msg_name = const_cast<void*>(implicitCast<const void*>(addr.getRaw()));
1579   msg.msg_namelen = addr.getRawSize();
1580 
1581   const size_t iovmax = kj::miniposix::iovMax();
1582   KJ_STACK_ARRAY(struct iovec, iov, kj::min(pieces.size(), iovmax), 16, 64);
1583 
1584   for (size_t i: kj::indices(pieces)) {
1585     iov[i].iov_base = const_cast<void*>(implicitCast<const void*>(pieces[i].begin()));
1586     iov[i].iov_len = pieces[i].size();
1587   }
1588 
1589   Array<byte> extra;
1590   if (pieces.size() > iovmax) {
1591     // Too many pieces, but we can't use multiple syscalls because they'd send separate
1592     // datagrams. We'll have to copy the trailing pieces into a temporary array.
1593     //
1594     // TODO(perf): On Linux we could use multiple syscalls via MSG_MORE or sendmsg/sendmmsg.
1595     size_t extraSize = 0;
1596     for (size_t i = iovmax - 1; i < pieces.size(); i++) {
1597       extraSize += pieces[i].size();
1598     }
1599     extra = kj::heapArray<byte>(extraSize);
1600     extraSize = 0;
1601     for (size_t i = iovmax - 1; i < pieces.size(); i++) {
1602       memcpy(extra.begin() + extraSize, pieces[i].begin(), pieces[i].size());
1603       extraSize += pieces[i].size();
1604     }
1605     iov.back().iov_base = extra.begin();
1606     iov.back().iov_len = extra.size();
1607   }
1608 
1609   msg.msg_iov = iov.begin();
1610   msg.msg_iovlen = iov.size();
1611 
1612   ssize_t n;
1613   KJ_NONBLOCKING_SYSCALL(n = sendmsg(fd, &msg, 0));
1614   if (n < 0) {
1615     // Write buffer full.
1616     return observer.whenBecomesWritable().then([this, pieces, &destination]() {
1617       return send(pieces, destination);
1618     });
1619   } else {
1620     // If less than the whole message was sent, then it was truncated, and there's nothing we can
1621     // do about that now.
1622     return n;
1623   }
1624 }
1625 
1626 class DatagramPortImpl::ReceiverImpl final: public DatagramReceiver {
1627 public:
ReceiverImpl(DatagramPortImpl & port,Capacity capacity)1628   explicit ReceiverImpl(DatagramPortImpl& port, Capacity capacity)
1629       : port(port),
1630         contentBuffer(heapArray<byte>(capacity.content)),
1631         ancillaryBuffer(capacity.ancillary > 0 ? heapArray<byte>(capacity.ancillary)
1632                                                : Array<byte>(nullptr)) {}
1633 
receive()1634   Promise<void> receive() override {
1635     struct msghdr msg;
1636     memset(&msg, 0, sizeof(msg));
1637 
1638     struct sockaddr_storage addr;
1639     memset(&addr, 0, sizeof(addr));
1640     msg.msg_name = &addr;
1641     msg.msg_namelen = sizeof(addr);
1642 
1643     struct iovec iov;
1644     iov.iov_base = contentBuffer.begin();
1645     iov.iov_len = contentBuffer.size();
1646     msg.msg_iov = &iov;
1647     msg.msg_iovlen = 1;
1648     msg.msg_control = ancillaryBuffer.begin();
1649     msg.msg_controllen = ancillaryBuffer.size();
1650 
1651     ssize_t n;
1652     KJ_NONBLOCKING_SYSCALL(n = recvmsg(port.fd, &msg, 0));
1653 
1654     if (n < 0) {
1655       // No data available. Wait.
1656       return port.observer.whenBecomesReadable().then([this]() {
1657         return receive();
1658       });
1659     } else {
1660       if (!port.filter.shouldAllow(reinterpret_cast<const struct sockaddr*>(msg.msg_name),
1661                                    msg.msg_namelen)) {
1662         // Ignore message from disallowed source.
1663         return receive();
1664       }
1665 
1666       receivedSize = n;
1667       contentTruncated = msg.msg_flags & MSG_TRUNC;
1668 
1669       source.emplace(port.lowLevel, port.filter, msg.msg_name, msg.msg_namelen);
1670 
1671       ancillaryList.resize(0);
1672       ancillaryTruncated = msg.msg_flags & MSG_CTRUNC;
1673 
1674       for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr;
1675            cmsg = CMSG_NXTHDR(&msg, cmsg)) {
1676         // On some platforms (OSX), a cmsghdr's length may cross the end of the ancillary buffer
1677         // when truncated. On other platforms (Linux) the length in cmsghdr will itself be
1678         // truncated to fit within the buffer.
1679 
1680 #if __APPLE__
1681 // On MacOS, `CMSG_SPACE(0)` triggers a bogus warning.
1682 #pragma GCC diagnostic ignored "-Wnull-pointer-arithmetic"
1683 #endif
1684         const byte* pos = reinterpret_cast<const byte*>(cmsg);
1685         size_t available = ancillaryBuffer.end() - pos;
1686         if (available < CMSG_SPACE(0)) {
1687           // The buffer ends in the middle of the header. We can't use this message.
1688           // (On Linux, this never happens, because the message is not included if there isn't
1689           // space for a header. I'm not sure how other systems behave, though, so let's be safe.)
1690           break;
1691         }
1692 
1693         // OK, we know the cmsghdr is valid, at least.
1694 
1695         // Find the start of the message payload.
1696         const byte* begin = (const byte *)CMSG_DATA(cmsg);
1697 
1698         // Cap the message length to the available space.
1699         const byte* end = pos + kj::min(available, cmsg->cmsg_len);
1700 
1701         ancillaryList.add(AncillaryMessage(
1702             cmsg->cmsg_level, cmsg->cmsg_type, arrayPtr(begin, end)));
1703       }
1704 
1705       return READY_NOW;
1706     }
1707   }
1708 
getContent()1709   MaybeTruncated<ArrayPtr<const byte>> getContent() override {
1710     return { contentBuffer.slice(0, receivedSize), contentTruncated };
1711   }
1712 
getAncillary()1713   MaybeTruncated<ArrayPtr<const AncillaryMessage>> getAncillary() override {
1714     return { ancillaryList.asPtr(), ancillaryTruncated };
1715   }
1716 
getSource()1717   NetworkAddress& getSource() override {
1718     return KJ_REQUIRE_NONNULL(source, "Haven't sent a message yet.").abstract;
1719   }
1720 
1721 private:
1722   DatagramPortImpl& port;
1723   Array<byte> contentBuffer;
1724   Array<byte> ancillaryBuffer;
1725   Vector<AncillaryMessage> ancillaryList;
1726   size_t receivedSize = 0;
1727   bool contentTruncated = false;
1728   bool ancillaryTruncated = false;
1729 
1730   struct StoredAddress {
StoredAddresskj::__anona1d21b730111::DatagramPortImpl::ReceiverImpl::StoredAddress1731     StoredAddress(LowLevelAsyncIoProvider& lowLevel, LowLevelAsyncIoProvider::NetworkFilter& filter,
1732                   const void* sockaddr, uint length)
1733         : raw(sockaddr, length),
1734           abstract(lowLevel, filter, Array<SocketAddress>(&raw, 1, NullArrayDisposer::instance)) {}
1735 
1736     SocketAddress raw;
1737     NetworkAddressImpl abstract;
1738   };
1739 
1740   kj::Maybe<StoredAddress> source;
1741 };
1742 
makeReceiver(DatagramReceiver::Capacity capacity)1743 Own<DatagramReceiver> DatagramPortImpl::makeReceiver(DatagramReceiver::Capacity capacity) {
1744   return kj::heap<ReceiverImpl>(*this, capacity);
1745 }
1746 
1747 // =======================================================================================
1748 
1749 class AsyncIoProviderImpl final: public AsyncIoProvider {
1750 public:
AsyncIoProviderImpl(LowLevelAsyncIoProvider & lowLevel)1751   AsyncIoProviderImpl(LowLevelAsyncIoProvider& lowLevel)
1752       : lowLevel(lowLevel), network(lowLevel) {}
1753 
newOneWayPipe()1754   OneWayPipe newOneWayPipe() override {
1755     int fds[2];
1756 #if __linux__ && !__BIONIC__
1757     KJ_SYSCALL(pipe2(fds, O_NONBLOCK | O_CLOEXEC));
1758 #else
1759     KJ_SYSCALL(pipe(fds));
1760 #endif
1761     return OneWayPipe {
1762       lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS),
1763       lowLevel.wrapOutputFd(fds[1], NEW_FD_FLAGS)
1764     };
1765   }
1766 
newTwoWayPipe()1767   TwoWayPipe newTwoWayPipe() override {
1768     int fds[2];
1769     int type = SOCK_STREAM;
1770 #if __linux__ && !__BIONIC__
1771     type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
1772 #endif
1773     KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds));
1774     return TwoWayPipe { {
1775       lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS),
1776       lowLevel.wrapSocketFd(fds[1], NEW_FD_FLAGS)
1777     } };
1778   }
1779 
newCapabilityPipe()1780   CapabilityPipe newCapabilityPipe() override {
1781     int fds[2];
1782     int type = SOCK_STREAM;
1783 #if __linux__ && !__BIONIC__
1784     type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
1785 #endif
1786     KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds));
1787     return CapabilityPipe { {
1788       lowLevel.wrapUnixSocketFd(fds[0], NEW_FD_FLAGS),
1789       lowLevel.wrapUnixSocketFd(fds[1], NEW_FD_FLAGS)
1790     } };
1791   }
1792 
getNetwork()1793   Network& getNetwork() override {
1794     return network;
1795   }
1796 
newPipeThread(Function<void (AsyncIoProvider &,AsyncIoStream &,WaitScope &)> startFunc)1797   PipeThread newPipeThread(
1798       Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)> startFunc) override {
1799     int fds[2];
1800     int type = SOCK_STREAM;
1801 #if __linux__ && !__BIONIC__
1802     type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
1803 #endif
1804     KJ_SYSCALL(socketpair(AF_UNIX, type, 0, fds));
1805 
1806     int threadFd = fds[1];
1807     KJ_ON_SCOPE_FAILURE(close(threadFd));
1808 
1809     auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
1810 
1811     auto thread = heap<Thread>(kj::mvCapture(startFunc,
1812         [threadFd](Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)>&& startFunc) {
1813       LowLevelAsyncIoProviderImpl lowLevel;
1814       auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS);
1815       AsyncIoProviderImpl ioProvider(lowLevel);
1816       startFunc(ioProvider, *stream, lowLevel.getWaitScope());
1817     }));
1818 
1819     return { kj::mv(thread), kj::mv(pipe) };
1820   }
1821 
getTimer()1822   Timer& getTimer() override { return lowLevel.getTimer(); }
1823 
1824 private:
1825   LowLevelAsyncIoProvider& lowLevel;
1826   SocketNetwork network;
1827 };
1828 
1829 }  // namespace
1830 
newAsyncIoProvider(LowLevelAsyncIoProvider & lowLevel)1831 Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) {
1832   return kj::heap<AsyncIoProviderImpl>(lowLevel);
1833 }
1834 
setupAsyncIo()1835 AsyncIoContext setupAsyncIo() {
1836   auto lowLevel = heap<LowLevelAsyncIoProviderImpl>();
1837   auto ioProvider = kj::heap<AsyncIoProviderImpl>(*lowLevel);
1838   auto& waitScope = lowLevel->getWaitScope();
1839   auto& eventPort = lowLevel->getEventPort();
1840   return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope, eventPort };
1841 }
1842 
1843 }  // namespace kj
1844 
1845 #endif  // !_WIN32
1846