1 // Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors
2 // Licensed under the MIT License:
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a copy
5 // of this software and associated documentation files (the "Software"), to deal
6 // in the Software without restriction, including without limitation the rights
7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 // copies of the Software, and to permit persons to whom the Software is
9 // furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 // THE SOFTWARE.
21 
22 #if _WIN32
23 // For Unix implementation, see async-io-unix.c++.
24 
25 // Request Vista-level APIs.
26 #define WINVER 0x0600
27 #define _WIN32_WINNT 0x0600
28 
29 #include "async-io.h"
30 #include "async-io-internal.h"
31 #include "async-win32.h"
32 #include "debug.h"
33 #include "thread.h"
34 #include "io.h"
35 #include "vector.h"
36 #include <set>
37 
38 #include <winsock2.h>
39 #include <ws2ipdef.h>
40 #include <ws2tcpip.h>
41 #include <mswsock.h>
42 #include <stdlib.h>
43 
44 #ifndef IPV6_V6ONLY
45 // MinGW's headers are missing this.
46 #define IPV6_V6ONLY 27
47 #endif
48 
49 namespace kj {
50 
51 namespace _ {  // private
52 
53 struct WinsockInitializer {
WinsockInitializerkj::_::WinsockInitializer54   WinsockInitializer() {
55     WSADATA dontcare;
56     int result = WSAStartup(MAKEWORD(2, 2), &dontcare);
57     if (result != 0) {
58       KJ_FAIL_WIN32("WSAStartup()", result);
59     }
60   }
61 };
62 
initWinsockOnce()63 void initWinsockOnce() {
64   static WinsockInitializer initializer;
65 }
66 
win32Socketpair(SOCKET socks[2])67 int win32Socketpair(SOCKET socks[2]) {
68   // This function from: https://github.com/ncm/selectable-socketpair/blob/master/socketpair.c
69   //
70   // Copyright notice:
71   //
72   //    Copyright 2007, 2010 by Nathan C. Myers <ncm@cantrip.org>
73   //    Redistribution and use in source and binary forms, with or without modification,
74   //    are permitted provided that the following conditions are met:
75   //
76   //        Redistributions of source code must retain the above copyright notice, this
77   //        list of conditions and the following disclaimer.
78   //
79   //        Redistributions in binary form must reproduce the above copyright notice,
80   //        this list of conditions and the following disclaimer in the documentation
81   //        and/or other materials provided with the distribution.
82   //
83   //        The name of the author must not be used to endorse or promote products
84   //        derived from this software without specific prior written permission.
85   //
86   //    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
87   //    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
88   //    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
89   //    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
90   //    FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
91   //    DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
92   //    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
93   //    CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
94   //    OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
95   //    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
96 
97   // Note: This function is called from some Cap'n Proto unit tests, despite not having a public
98   //   header declaration.
99   // TODO(cleanup): Consider putting this somewhere public? Note that since it depends on Winsock,
100   //   it needs to be in the kj-async library.
101 
102   initWinsockOnce();
103 
104   union {
105     struct sockaddr_in inaddr;
106     struct sockaddr addr;
107   } a;
108   SOCKET listener;
109   int e;
110   socklen_t addrlen = sizeof(a.inaddr);
111   int reuse = 1;
112 
113   if (socks == 0) {
114     WSASetLastError(WSAEINVAL);
115     return SOCKET_ERROR;
116   }
117   socks[0] = socks[1] = -1;
118 
119   listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
120   if (listener == -1)
121     return SOCKET_ERROR;
122 
123   memset(&a, 0, sizeof(a));
124   a.inaddr.sin_family = AF_INET;
125   a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
126   a.inaddr.sin_port = 0;
127 
128   for (;;) {
129     if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR,
130            (char*) &reuse, (socklen_t) sizeof(reuse)) == -1)
131       break;
132     if  (bind(listener, &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
133       break;
134 
135     memset(&a, 0, sizeof(a));
136     if  (getsockname(listener, &a.addr, &addrlen) == SOCKET_ERROR)
137       break;
138     // win32 getsockname may only set the port number, p=0.0005.
139     // ( http://msdn.microsoft.com/library/ms738543.aspx ):
140     a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
141     a.inaddr.sin_family = AF_INET;
142 
143     if (listen(listener, 1) == SOCKET_ERROR)
144       break;
145 
146     socks[0] = WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
147     if (socks[0] == -1)
148       break;
149     if (connect(socks[0], &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
150       break;
151 
152   retryAccept:
153     socks[1] = accept(listener, NULL, NULL);
154     if (socks[1] == -1)
155       break;
156 
157     // Verify that the client is actually us and not someone else who raced to connect first.
158     // (This check added by Kenton for security.)
159     union {
160       struct sockaddr_in inaddr;
161       struct sockaddr addr;
162     } b, c;
163     socklen_t bAddrlen = sizeof(b.inaddr);
164     socklen_t cAddrlen = sizeof(b.inaddr);
165     if (getpeername(socks[1], &b.addr, &bAddrlen) == SOCKET_ERROR)
166       break;
167     if (getsockname(socks[0], &c.addr, &cAddrlen) == SOCKET_ERROR)
168       break;
169     if (bAddrlen != cAddrlen || memcmp(&b.addr, &c.addr, bAddrlen) != 0) {
170       // Someone raced to connect first. Ignore.
171       closesocket(socks[1]);
172       goto retryAccept;
173     }
174 
175     closesocket(listener);
176     return 0;
177   }
178 
179   e = WSAGetLastError();
180   closesocket(listener);
181   closesocket(socks[0]);
182   closesocket(socks[1]);
183   WSASetLastError(e);
184   socks[0] = socks[1] = -1;
185   return SOCKET_ERROR;
186 }
187 
188 }  // namespace _
189 
190 namespace {
191 
192 // =======================================================================================
193 
194 static constexpr uint NEW_FD_FLAGS = LowLevelAsyncIoProvider::TAKE_OWNERSHIP;
195 
196 class OwnedFd {
197 public:
OwnedFd(SOCKET fd,uint flags)198   OwnedFd(SOCKET fd, uint flags): fd(fd), flags(flags) {
199     // TODO(perf): Maybe use SetFileCompletionNotificationModes() to tell Windows not to bother
200     //   delivering an event when the operation completes inline. Not currently implemented on
201     //   Wine, though.
202   }
203 
~OwnedFd()204   ~OwnedFd() noexcept(false) {
205     if (flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) {
206       KJ_WINSOCK(closesocket(fd)) { break; }
207     }
208   }
209 
210 protected:
211   SOCKET fd;
212 
213 private:
214   uint flags;
215 };
216 
217 // =======================================================================================
218 
class AsyncStreamFd: public OwnedFd, public AsyncIoStream {
  // AsyncIoStream over a Winsock stream socket. Reads, writes and connects are issued as
  // overlapped operations; their completions arrive through the IoObserver obtained from the
  // Win32EventPort passed to the constructor.
public:
  AsyncStreamFd(Win32EventPort& eventPort, SOCKET fd, uint flags)
      : OwnedFd(fd, flags),
        observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))) {}
  virtual ~AsyncStreamFd() noexcept(false) {}

  Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
    // Like tryRead(), but EOF before `minBytes` bytes arrive is an error ("Premature EOF"). If
    // that error is recovered from, the missing bytes are zero-filled and `minBytes` is returned.
    return tryRead(buffer, minBytes, maxBytes).then([=](size_t result) {
      KJ_REQUIRE(result >= minBytes, "Premature EOF") {
        // Pretend we read zeros from the input.
        memset(reinterpret_cast<byte*>(buffer) + result, 0, minBytes - result);
        return minBytes;
      }
      return result;
    });
  }

  Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    // Single-buffer front end for tryReadInternal(). The WSABUF array is heap-allocated and
    // attached to the returned promise, so it stays alive while the overlapped read is pending.
    // NOTE(review): WSABUF::len is a ULONG, so a maxBytes above 4 GiB would be truncated here --
    //   presumably never happens in practice; confirm with callers.
    auto bufs = heapArray<WSABUF>(1);
    bufs[0].buf = reinterpret_cast<char*>(buffer);
    bufs[0].len = maxBytes;

    ArrayPtr<WSABUF> ref = bufs;
    return tryReadInternal(ref, minBytes, 0).attach(kj::mv(bufs));
  }

  Promise<void> write(const void* buffer, size_t size) override {
    // Single-buffer write. The WSABUF array is attached to the promise so it outlives the
    // overlapped send.
    auto bufs = heapArray<WSABUF>(1);
    bufs[0].buf = const_cast<char*>(reinterpret_cast<const char*>(buffer));
    bufs[0].len = size;

    ArrayPtr<WSABUF> ref = bufs;
    return writeInternal(ref).attach(kj::mv(bufs));
  }

  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    // Gather write: one WSABUF per piece, all sent with a single WSASend() (continued by
    // writeInternal() after partial completions).
    auto bufs = heapArray<WSABUF>(pieces.size());
    for (auto i: kj::indices(pieces)) {
      bufs[i].buf = const_cast<char*>(pieces[i].asChars().begin());
      bufs[i].len = pieces[i].size();
    }

    ArrayPtr<WSABUF> ref = bufs;
    return writeInternal(ref).attach(kj::mv(bufs));
  }

  kj::Promise<void> connect(const struct sockaddr* addr, uint addrlen) {
    // Starts an asynchronous connect to `addr` via ConnectEx(). Resolves once the connection is
    // established; failures are reported through KJ_WINSOCK / KJ_FAIL_WIN32.
    //
    // In order to connect asynchronously, we need the ConnectEx() function. Apparently, we have
    // to query the socket for it dynamically, I guess because of the insanity in which winsock
    // can be implemented in userspace and old implementations may not support it.
    GUID guid = WSAID_CONNECTEX;
    LPFN_CONNECTEX connectEx = nullptr;
    DWORD n = 0;
    KJ_WINSOCK(WSAIoctl(fd, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
                        &connectEx, sizeof(connectEx), &n, NULL, NULL)) {
      goto fail;  // avoid memory leak due to compiler bugs
    }
    // The `fail` label is only reachable through the goto above. It returns from outside the
    // KJ_WINSOCK recovery block instead of returning directly inside it.
    if (false) {
    fail:
      return kj::READY_NOW;
    }

    // OK, phew, we now have our ConnectEx function pointer. Call it.
    auto op = observer->newOperation(0);

    if (!connectEx(fd, addr, addrlen, NULL, 0, NULL, op->getOverlapped())) {
      DWORD error = WSAGetLastError();
      if (error != ERROR_IO_PENDING) {
        // Failed synchronously: report it; if recovered from, resolve immediately.
        KJ_FAIL_WIN32("ConnectEx()", error) { break; }
        return kj::READY_NOW;
      }
    }

    return op->onComplete().then([this](Win32EventPort::IoResult result) {
      if (result.errorCode != ERROR_SUCCESS) {
        KJ_FAIL_WIN32("ConnectEx()", result.errorCode) { return; }
      }

      // Enable shutdown() to work.
      setsockopt(SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0);
    });
  }

  Promise<void> whenWriteDisconnected() override {
    // Currently never resolves on Windows (see TODO below).
    //
    // Windows IOCP does not provide a direct, documented way to detect when the socket disconnects
    // without actually doing a read or write. However, there is an undocumented-but-stable
    // ioctl called IOCTL_AFD_POLL which can be used for this purpose. In fact, select() is
    // implemented in terms of this ioctl -- performed synchronously -- but it's entirely possible
    // to put only one socket into the list and perform the ioctl asynchronously. Here's the
    // source code for select() in Windows 2000 (not sure how this became public...):
    //
    //     https://github.com/pustladi/Windows-2000/blob/661d000d50637ed6fab2329d30e31775046588a9/private/net/sockets/winsock2/wsp/msafd/select.c#L59-L655
    //
    // And here's an interesting discussion: https://github.com/python-trio/trio/issues/52
    //
    // TODO(someday): Implement this with IOCTL_AFD_POLL. For now I'm leaving it unimplemented
    //   because I added this method for a Linux-only use case.
    return NEVER_DONE;
  }

  void shutdownWrite() override {
    // Half-closes the sending direction (SD_SEND).
    //
    // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
    // Win32AsyncIoProvider interface.
    KJ_WINSOCK(shutdown(fd, SD_SEND));
  }

  void abortRead() override {
    // Shuts down the receiving direction (SD_RECEIVE).
    //
    // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
    // Win32AsyncIoProvider interface.
    KJ_WINSOCK(shutdown(fd, SD_RECEIVE));
  }

  void getsockopt(int level, int option, void* value, uint* length) override {
    // Thin wrapper over ::getsockopt(); `*length` is in/out (buffer size in, option size out).
    socklen_t socklen = *length;
    KJ_WINSOCK(::getsockopt(fd, level, option,
                            reinterpret_cast<char*>(value), &socklen));
    *length = socklen;
  }

  void setsockopt(int level, int option, const void* value, uint length) override {
    // Thin wrapper over ::setsockopt(); errors are reported through KJ_WINSOCK.
    KJ_WINSOCK(::setsockopt(fd, level, option,
                            reinterpret_cast<const char*>(value), length));
  }

  void getsockname(struct sockaddr* addr, uint* length) override {
    // Local address of the socket; `*length` is in/out.
    socklen_t socklen = *length;
    KJ_WINSOCK(::getsockname(fd, addr, &socklen));
    *length = socklen;
  }

  void getpeername(struct sockaddr* addr, uint* length) override {
    // Remote address of the socket; `*length` is in/out.
    socklen_t socklen = *length;
    KJ_WINSOCK(::getpeername(fd, addr, &socklen));
    *length = socklen;
  }

private:
  Own<Win32EventPort::IoObserver> observer;
  // Created from the event port for this socket's handle; every overlapped operation is
  // allocated from it via newOperation().

  Promise<size_t> tryReadInternal(ArrayPtr<WSABUF> bufs, size_t minBytes, size_t alreadyRead) {
    // `bufs` will remain valid until the promise completes and may be freely modified.
    //
    // `alreadyRead` is the number of bytes we have already received via previous reads -- minBytes
    // and buffer have already been adjusted to account for them, but this count must be included
    // in the final return value.
    //
    // Issues one overlapped WSARecv(). When it completes with fewer than `minBytes` bytes, the
    // continuation advances `bufs` past the received data and recurses.

    auto op = observer->newOperation(0);

    DWORD flags = 0;
    if (WSARecv(fd, bufs.begin(), bufs.size(), NULL, &flags,
                op->getOverlapped(), NULL) == SOCKET_ERROR) {
      DWORD error = WSAGetLastError();
      if (error != WSA_IO_PENDING) {
        KJ_FAIL_WIN32("WSARecv()", error) { break; }
        return alreadyRead;
      }
    }

    return op->onComplete()
        .then([this,KJ_CPCAP(bufs),minBytes,alreadyRead](Win32IocpEventPort::IoResult result) mutable
              -> Promise<size_t> {
      if (result.errorCode != ERROR_SUCCESS) {
        if (alreadyRead > 0) {
          // Report what we already read.
          return alreadyRead;
        } else {
          KJ_FAIL_WIN32("WSARecv()", result.errorCode) { break; }
          return size_t(0);
        }
      }

      // A zero-byte completion is treated as EOF.
      if (result.bytesTransferred == 0) {
        return alreadyRead;
      }

      alreadyRead += result.bytesTransferred;
      if (result.bytesTransferred >= minBytes) {
        // We can stop here.
        return alreadyRead;
      }
      minBytes -= result.bytesTransferred;

      // Skip buffers that were filled completely.
      // NOTE(review): unlike writeInternal(), there is no `bufs.size() > 0` guard. This relies on
      //   bytesTransferred < minBytes <= remaining buffer capacity, so the loop stops before the
      //   end -- assumes callers never pass minBytes > maxBytes; confirm.
      while (result.bytesTransferred >= bufs[0].len) {
        result.bytesTransferred -= bufs[0].len;
        bufs = bufs.slice(1, bufs.size());
      }

      // Advance into the partially-filled buffer.
      if (result.bytesTransferred > 0) {
        bufs[0].buf += result.bytesTransferred;
        bufs[0].len -= result.bytesTransferred;
      }

      return tryReadInternal(bufs, minBytes, alreadyRead);
    }).attach(kj::mv(bufs));
  }

  Promise<void> writeInternal(ArrayPtr<WSABUF> bufs) {
    // `bufs` will remain valid until the promise completes and may be freely modified.
    //
    // Issues one overlapped WSASend(). After a partial completion, the continuation advances
    // `bufs` past the sent data and recurses until everything has been written.

    auto op = observer->newOperation(0);

    if (WSASend(fd, bufs.begin(), bufs.size(), NULL, 0,
                op->getOverlapped(), NULL) == SOCKET_ERROR) {
      DWORD error = WSAGetLastError();
      if (error != WSA_IO_PENDING) {
        KJ_FAIL_WIN32("WSASend()", error) { break; }
        return kj::READY_NOW;
      }
    }

    return op->onComplete()
        .then([this,KJ_CPCAP(bufs)](Win32IocpEventPort::IoResult result) mutable -> Promise<void> {
      if (result.errorCode != ERROR_SUCCESS) {
        KJ_FAIL_WIN32("WSASend()", result.errorCode) { break; }
        return kj::READY_NOW;
      }

      // Drop buffers that were sent completely.
      while (bufs.size() > 0 && result.bytesTransferred >= bufs[0].len) {
        result.bytesTransferred -= bufs[0].len;
        bufs = bufs.slice(1, bufs.size());
      }

      // Advance into the partially-sent buffer.
      if (result.bytesTransferred > 0) {
        bufs[0].buf += result.bytesTransferred;
        bufs[0].len -= result.bytesTransferred;
      }

      if (bufs.size() > 0) {
        return writeInternal(bufs);
      } else {
        return kj::READY_NOW;
      }
    }).attach(kj::mv(bufs));
  }
};
455 
456 // =======================================================================================
457 
458 class SocketAddress {
459 public:
SocketAddress(const void * sockaddr,uint len)460   SocketAddress(const void* sockaddr, uint len): addrlen(len) {
461     KJ_REQUIRE(len <= sizeof(addr), "Sorry, your sockaddr is too big for me.");
462     memcpy(&addr.generic, sockaddr, len);
463   }
464 
operator <(const SocketAddress & other) const465   bool operator<(const SocketAddress& other) const {
466     // So we can use std::set<SocketAddress>...  see DNS lookup code.
467 
468     if (wildcard < other.wildcard) return true;
469     if (wildcard > other.wildcard) return false;
470 
471     if (addrlen < other.addrlen) return true;
472     if (addrlen > other.addrlen) return false;
473 
474     return memcmp(&addr.generic, &other.addr.generic, addrlen) < 0;
475   }
476 
getRaw() const477   const struct sockaddr* getRaw() const { return &addr.generic; }
getRawSize() const478   int getRawSize() const { return addrlen; }
479 
socket(int type) const480   SOCKET socket(int type) const {
481     bool isStream = type == SOCK_STREAM;
482 
483     SOCKET result = ::socket(addr.generic.sa_family, type, 0);
484 
485     if (result == INVALID_SOCKET) {
486       KJ_FAIL_WIN32("WSASocket()", WSAGetLastError()) { return INVALID_SOCKET; }
487     }
488 
489     if (isStream && (addr.generic.sa_family == AF_INET ||
490                      addr.generic.sa_family == AF_INET6)) {
491       // TODO(perf):  As a hack for the 0.4 release we are always setting
492       //   TCP_NODELAY because Nagle's algorithm pretty much kills Cap'n Proto's
493       //   RPC protocol.  Later, we should extend the interface to provide more
494       //   control over this.  Perhaps write() should have a flag which
495       //   specifies whether to pass MSG_MORE.
496       BOOL one = TRUE;
497       KJ_WINSOCK(setsockopt(result, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one)));
498     }
499 
500     return result;
501   }
502 
bind(SOCKET sockfd) const503   void bind(SOCKET sockfd) const {
504     if (wildcard) {
505       // Disable IPV6_V6ONLY because we want to handle both ipv4 and ipv6 on this socket.  (The
506       // default value of this option varies across platforms.)
507       DWORD value = 0;
508       KJ_WINSOCK(setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY,
509                             reinterpret_cast<char*>(&value), sizeof(value)));
510     }
511 
512     KJ_WINSOCK(::bind(sockfd, &addr.generic, addrlen), toString());
513   }
514 
getPort() const515   uint getPort() const {
516     switch (addr.generic.sa_family) {
517       case AF_INET: return ntohs(addr.inet4.sin_port);
518       case AF_INET6: return ntohs(addr.inet6.sin6_port);
519       default: return 0;
520     }
521   }
522 
toString() const523   String toString() const {
524     if (wildcard) {
525       return str("*:", getPort());
526     }
527 
528     switch (addr.generic.sa_family) {
529       case AF_INET: {
530         char buffer[16];
531         if (InetNtopA(addr.inet4.sin_family, const_cast<struct in_addr*>(&addr.inet4.sin_addr),
532                       buffer, sizeof(buffer)) == nullptr) {
533           KJ_FAIL_WIN32("InetNtop", WSAGetLastError()) { break; }
534           return heapString("(inet_ntop error)");
535         }
536         return str(buffer, ':', ntohs(addr.inet4.sin_port));
537       }
538       case AF_INET6: {
539         char buffer[46];
540         if (InetNtopA(addr.inet6.sin6_family, const_cast<struct in6_addr*>(&addr.inet6.sin6_addr),
541                       buffer, sizeof(buffer)) == nullptr) {
542           KJ_FAIL_WIN32("InetNtop", WSAGetLastError()) { break; }
543           return heapString("(inet_ntop error)");
544         }
545         return str('[', buffer, "]:", ntohs(addr.inet6.sin6_port));
546       }
547       default:
548         return str("(unknown address family ", addr.generic.sa_family, ")");
549     }
550   }
551 
552   static Promise<Array<SocketAddress>> lookupHost(
553       LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
554       _::NetworkFilter& filter);
555   // Perform a DNS lookup.
556 
parse(LowLevelAsyncIoProvider & lowLevel,StringPtr str,uint portHint,_::NetworkFilter & filter)557   static Promise<Array<SocketAddress>> parse(
558       LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) {
559     // TODO(someday):  Allow commas in `str`.
560 
561     SocketAddress result;
562 
563     // Try to separate the address and port.
564     ArrayPtr<const char> addrPart;
565     Maybe<StringPtr> portPart;
566 
567     int af;
568 
569     if (str.startsWith("[")) {
570       // Address starts with a bracket, which is a common way to write an ip6 address with a port,
571       // since without brackets around the address part, the port looks like another segment of
572       // the address.
573       af = AF_INET6;
574       size_t closeBracket = KJ_ASSERT_NONNULL(str.findLast(']'),
575           "Unclosed '[' in address string.", str);
576 
577       addrPart = str.slice(1, closeBracket);
578       if (str.size() > closeBracket + 1) {
579         KJ_REQUIRE(str.slice(closeBracket + 1).startsWith(":"),
580                    "Expected port suffix after ']'.", str);
581         portPart = str.slice(closeBracket + 2);
582       }
583     } else {
584       KJ_IF_MAYBE(colon, str.findFirst(':')) {
585         if (str.slice(*colon + 1).findFirst(':') == nullptr) {
586           // There is exactly one colon and no brackets, so it must be an ip4 address with port.
587           af = AF_INET;
588           addrPart = str.slice(0, *colon);
589           portPart = str.slice(*colon + 1);
590         } else {
591           // There are two or more colons and no brackets, so the whole thing must be an ip6
592           // address with no port.
593           af = AF_INET6;
594           addrPart = str;
595         }
596       } else {
597         // No colons, so it must be an ip4 address without port.
598         af = AF_INET;
599         addrPart = str;
600       }
601     }
602 
603     // Parse the port.
604     unsigned long port;
605     KJ_IF_MAYBE(portText, portPart) {
606       char* endptr;
607       port = strtoul(portText->cStr(), &endptr, 0);
608       if (portText->size() == 0 || *endptr != '\0') {
609         // Not a number.  Maybe it's a service name.  Fall back to DNS.
610         return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint,
611                           filter);
612       }
613       KJ_REQUIRE(port < 65536, "Port number too large.");
614     } else {
615       port = portHint;
616     }
617 
618     // Check for wildcard.
619     if (addrPart.size() == 1 && addrPart[0] == '*') {
620       result.wildcard = true;
621       // Create an ip6 socket and set IPV6_V6ONLY to 0 later.
622       result.addrlen = sizeof(addr.inet6);
623       result.addr.inet6.sin6_family = AF_INET6;
624       result.addr.inet6.sin6_port = htons(port);
625       auto array = kj::heapArrayBuilder<SocketAddress>(1);
626       array.add(result);
627       return array.finish();
628     }
629 
630     void* addrTarget;
631     if (af == AF_INET6) {
632       result.addrlen = sizeof(addr.inet6);
633       result.addr.inet6.sin6_family = AF_INET6;
634       result.addr.inet6.sin6_port = htons(port);
635       addrTarget = &result.addr.inet6.sin6_addr;
636     } else {
637       result.addrlen = sizeof(addr.inet4);
638       result.addr.inet4.sin_family = AF_INET;
639       result.addr.inet4.sin_port = htons(port);
640       addrTarget = &result.addr.inet4.sin_addr;
641     }
642 
643     char buffer[64];
644     if (addrPart.size() < sizeof(buffer) - 1) {
645       // addrPart is not necessarily NUL-terminated so we have to make a copy.  :(
646       memcpy(buffer, addrPart.begin(), addrPart.size());
647       buffer[addrPart.size()] = '\0';
648 
649       // OK, parse it!
650       switch (InetPtonA(af, buffer, addrTarget)) {
651         case 1: {
652           // success.
653           if (!result.parseAllowedBy(filter)) {
654             KJ_FAIL_REQUIRE("address family blocked by restrictPeers()");
655             return Array<SocketAddress>();
656           }
657 
658           auto array = kj::heapArrayBuilder<SocketAddress>(1);
659           array.add(result);
660           return array.finish();
661         }
662         case 0:
663           // It's apparently not a simple address...  fall back to DNS.
664           break;
665         default:
666           KJ_FAIL_WIN32("InetPton", WSAGetLastError(), af, addrPart);
667       }
668     }
669 
670     return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter);
671   }
672 
getLocalAddress(SOCKET sockfd)673   static SocketAddress getLocalAddress(SOCKET sockfd) {
674     SocketAddress result;
675     result.addrlen = sizeof(addr);
676     KJ_WINSOCK(getsockname(sockfd, &result.addr.generic, &result.addrlen));
677     return result;
678   }
679 
getPeerAddress(SOCKET sockfd)680   static SocketAddress getPeerAddress(SOCKET sockfd) {
681     SocketAddress result;
682     result.addrlen = sizeof(addr);
683     KJ_WINSOCK(getpeername(sockfd, &result.addr.generic, &result.addrlen));
684     return result;
685   }
686 
allowedBy(LowLevelAsyncIoProvider::NetworkFilter & filter)687   bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
688     return filter.shouldAllow(&addr.generic, addrlen);
689   }
690 
parseAllowedBy(_::NetworkFilter & filter)691   bool parseAllowedBy(_::NetworkFilter& filter) {
692     return filter.shouldAllowParse(&addr.generic, addrlen);
693   }
694 
getWildcardForFamily(int family)695   static SocketAddress getWildcardForFamily(int family) {
696     SocketAddress result;
697     switch (family) {
698       case AF_INET:
699         result.addrlen = sizeof(addr.inet4);
700         result.addr.inet4.sin_family = AF_INET;
701         return result;
702       case AF_INET6:
703         result.addrlen = sizeof(addr.inet6);
704         result.addr.inet6.sin6_family = AF_INET6;
705         return result;
706       default:
707         KJ_FAIL_REQUIRE("unknown address family", family);
708     }
709   }
710 
711 private:
SocketAddress()712   SocketAddress(): addrlen(0) {
713     memset(&addr, 0, sizeof(addr));
714   }
715 
716   socklen_t addrlen;
717   bool wildcard = false;
718   union {
719     struct sockaddr generic;
720     struct sockaddr_in inet4;
721     struct sockaddr_in6 inet6;
722     struct sockaddr_storage storage;
723   } addr;
724 
725   struct LookupParams;
726   class LookupReader;
727 };
728 
729 class SocketAddress::LookupReader {
730   // Reads SocketAddresses off of a pipe coming from another thread that is performing
731   // getaddrinfo.
732 
733 public:
LookupReader(kj::Own<Thread> && thread,kj::Own<AsyncInputStream> && input,_::NetworkFilter & filter)734   LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input,
735                _::NetworkFilter& filter)
736       : thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {}
737 
~LookupReader()738   ~LookupReader() {
739     if (thread) thread->detach();
740   }
741 
read()742   Promise<Array<SocketAddress>> read() {
743     return input->tryRead(&current, sizeof(current), sizeof(current)).then(
744         [this](size_t n) -> Promise<Array<SocketAddress>> {
745       if (n < sizeof(current)) {
746         thread = nullptr;
747         // getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
748         // anyway.
749         KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; }
750         return addresses.releaseAsArray();
751       } else {
752         // getaddrinfo() can return multiple copies of the same address for several reasons.
753         // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so
754         // it may return two copies of the same address, one for each type, unless it explicitly
755         // knows that the service name given is specific to one type.  But we can't tell it a type,
756         // because we don't actually know which one the user wants, and if we specify SOCK_STREAM
757         // while the user specified a UDP service name then they'll get a resolution error which
758         // is lame.  (At least, I think that's how it works.)
759         //
760         // So we instead resort to de-duping results.
761         if (alreadySeen.insert(current).second) {
762           if (current.parseAllowedBy(filter)) {
763             addresses.add(current);
764           }
765         }
766         return read();
767       }
768     });
769   }
770 
771 private:
772   kj::Own<Thread> thread;
773   kj::Own<AsyncInputStream> input;
774   _::NetworkFilter& filter;
775   SocketAddress current;
776   kj::Vector<SocketAddress> addresses;
777   std::set<SocketAddress> alreadySeen;
778 };
779 
struct SocketAddress::LookupParams {
  // Arguments moved into the getaddrinfo() worker thread started by lookupHost().
  kj::String host;
  // Host to resolve; the special value "*" is passed to getaddrinfo() as a null host.
  kj::String service;
  // Service name / port text; may be null, in which case the lookup's port hint is applied to
  // each result instead.
};
784 
lookupHost(LowLevelAsyncIoProvider & lowLevel,kj::String host,kj::String service,uint portHint,_::NetworkFilter & filter)785 Promise<Array<SocketAddress>> SocketAddress::lookupHost(
786     LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
787     _::NetworkFilter& filter) {
788   // This shitty function spawns a thread to run getaddrinfo().  Unfortunately, getaddrinfo() is
789   // the only cross-platform DNS API and it is blocking.
790   //
791   // TODO(perf): Use GetAddrInfoEx(). But there are problems:
792   // - Not implemented in Wine.
793   // - Doesn't seem compatible with I/O completion ports, in particular because it's not associated
794   //   with a handle. Could signal completion as an APC instead, but that requires the IOCP code
795   //   to use GetQueuedCompletionStatusEx() which it doesn't right now because it's not available
796   //   in Wine.
797   // - Requires Unicode, for some reason. Only GetAddrInfoExW() supports async, according to the
798   //   docs. Never mind that DNS itself is ASCII...
799 
800   SOCKET fds[2];
801   KJ_WINSOCK(_::win32Socketpair(fds));
802 
803   auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS);
804 
805   int outFd = fds[1];
806 
807   LookupParams params = { kj::mv(host), kj::mv(service) };
808 
809   auto thread = heap<Thread>(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) {
810     KJ_DEFER(closesocket(outFd));
811 
812     struct addrinfo* list;
813     int status = getaddrinfo(
814         params.host == "*" ? nullptr : params.host.cStr(),
815         params.service == nullptr ? nullptr : params.service.cStr(),
816         nullptr, &list);
817     if (status == 0) {
818       KJ_DEFER(freeaddrinfo(list));
819 
820       struct addrinfo* cur = list;
821       while (cur != nullptr) {
822         if (params.service == nullptr) {
823           switch (cur->ai_addr->sa_family) {
824             case AF_INET:
825               ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint);
826               break;
827             case AF_INET6:
828               ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint);
829               break;
830             default:
831               break;
832           }
833         }
834 
835         SocketAddress addr;
836         memset(&addr, 0, sizeof(addr));  // mollify valgrind
837         if (params.host == "*") {
838           // Set up a wildcard SocketAddress.  Only use the port number returned by getaddrinfo().
839           addr.wildcard = true;
840           addr.addrlen = sizeof(addr.addr.inet6);
841           addr.addr.inet6.sin6_family = AF_INET6;
842           switch (cur->ai_addr->sa_family) {
843             case AF_INET:
844               addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port;
845               break;
846             case AF_INET6:
847               addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port;
848               break;
849             default:
850               addr.addr.inet6.sin6_port = portHint;
851               break;
852           }
853         } else {
854           addr.addrlen = cur->ai_addrlen;
855           memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen);
856         }
857         KJ_ASSERT_CAN_MEMCPY(SocketAddress);
858 
859         const char* data = reinterpret_cast<const char*>(&addr);
860         size_t size = sizeof(addr);
861         while (size > 0) {
862           int n;
863           KJ_WINSOCK(n = send(outFd, data, size, 0));
864           data += n;
865           size -= n;
866         }
867 
868         cur = cur->ai_next;
869       }
870     } else {
871       KJ_FAIL_WIN32("getaddrinfo()", status, params.host, params.service) {
872         return;
873       }
874     }
875   }));
876 
877   auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input), filter);
878   return reader->read().attach(kj::mv(reader));
879 }
880 
881 // =======================================================================================
882 
883 class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFd {
884 public:
FdConnectionReceiver(Win32EventPort & eventPort,SOCKET fd,LowLevelAsyncIoProvider::NetworkFilter & filter,uint flags)885   FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd,
886                        LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
887       : OwnedFd(fd, flags), eventPort(eventPort), filter(filter),
888         observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))),
889         address(SocketAddress::getLocalAddress(fd)) {
890     // In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have
891     // to query the socket for it dynamically, I guess because of the insanity in which winsock
892     // can be implemented in userspace and old implementations may not support it.
893     GUID guid = WSAID_ACCEPTEX;
894     DWORD n = 0;
895     KJ_WINSOCK(WSAIoctl(fd, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
896                         &acceptEx, sizeof(acceptEx), &n, NULL, NULL)) {
897       acceptEx = nullptr;
898       return;
899     }
900   }
901 
accept()902   Promise<Own<AsyncIoStream>> accept() override {
903     SOCKET newFd = address.socket(SOCK_STREAM);
904     KJ_ASSERT(newFd != INVALID_SOCKET);
905     auto result = heap<AsyncStreamFd>(eventPort, newFd, NEW_FD_FLAGS);
906 
907     auto scratch = heapArray<byte>(256);
908     DWORD dummy;
909     auto op = observer->newOperation(0);
910     if (!acceptEx(fd, newFd, scratch.begin(), 0, 128, 128, &dummy, op->getOverlapped())) {
911       DWORD error = WSAGetLastError();
912       if (error != ERROR_IO_PENDING) {
913         KJ_FAIL_WIN32("AcceptEx()", error) { break; }
914         return Own<AsyncIoStream>(kj::mv(result));  // dummy, won't be used
915       }
916     }
917 
918     return op->onComplete().then(mvCapture(result, mvCapture(scratch,
919         [this,newFd]
920         (Array<byte> scratch, Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult)
921         -> Promise<Own<AsyncIoStream>> {
922       if (ioResult.errorCode != ERROR_SUCCESS) {
923         KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; }
924       } else {
925         SOCKET me = fd;
926         stream->setsockopt(SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT,
927                            reinterpret_cast<char*>(&me), sizeof(me));
928       }
929 
930       // Supposedly, AcceptEx() places the local and peer addresses into the buffer (which we've
931       // named `scratch`). However, the format in which it writes these is undocumented, and
932       // doesn't even match between native Windows and WINE. Apparently it is useless. I don't know
933       // why they require the buffer to have space for it in the first place. We'll need to call
934       // getpeername() to get the address.
935       auto addr = SocketAddress::getPeerAddress(newFd);
936       if (addr.allowedBy(filter)) {
937         return kj::mv(stream);
938       } else {
939         return accept();
940       }
941     })));
942   }
943 
getPort()944   uint getPort() override {
945     return address.getPort();
946   }
947 
getsockopt(int level,int option,void * value,uint * length)948   void getsockopt(int level, int option, void* value, uint* length) override {
949     socklen_t socklen = *length;
950     KJ_WINSOCK(::getsockopt(fd, level, option,
951                             reinterpret_cast<char*>(value), &socklen));
952     *length = socklen;
953   }
setsockopt(int level,int option,const void * value,uint length)954   void setsockopt(int level, int option, const void* value, uint length) override {
955     KJ_WINSOCK(::setsockopt(fd, level, option,
956                             reinterpret_cast<const char*>(value), length));
957   }
getsockname(struct sockaddr * addr,uint * length)958   void getsockname(struct sockaddr* addr, uint* length) override {
959     socklen_t socklen = *length;
960     KJ_WINSOCK(::getsockname(fd, addr, &socklen));
961     *length = socklen;
962   }
963 
964 public:
965   Win32EventPort& eventPort;
966   LowLevelAsyncIoProvider::NetworkFilter& filter;
967   Own<Win32EventPort::IoObserver> observer;
968   LPFN_ACCEPTEX acceptEx = nullptr;
969   SocketAddress address;
970 };
971 
972 // TODO(someday): DatagramPortImpl
973 
974 class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
975 public:
LowLevelAsyncIoProviderImpl()976   LowLevelAsyncIoProviderImpl()
977       : eventLoop(eventPort), waitScope(eventLoop) {}
978 
getWaitScope()979   inline WaitScope& getWaitScope() { return waitScope; }
980 
wrapInputFd(SOCKET fd,uint flags=0)981   Own<AsyncInputStream> wrapInputFd(SOCKET fd, uint flags = 0) override {
982     return heap<AsyncStreamFd>(eventPort, fd, flags);
983   }
wrapOutputFd(SOCKET fd,uint flags=0)984   Own<AsyncOutputStream> wrapOutputFd(SOCKET fd, uint flags = 0) override {
985     return heap<AsyncStreamFd>(eventPort, fd, flags);
986   }
wrapSocketFd(SOCKET fd,uint flags=0)987   Own<AsyncIoStream> wrapSocketFd(SOCKET fd, uint flags = 0) override {
988     return heap<AsyncStreamFd>(eventPort, fd, flags);
989   }
wrapConnectingSocketFd(SOCKET fd,const struct sockaddr * addr,uint addrlen,uint flags=0)990   Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
991       SOCKET fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) override {
992     auto result = heap<AsyncStreamFd>(eventPort, fd, flags);
993 
994     // ConnectEx requires that the socket be bound, for some reason. Bind to an arbitrary port.
995     SocketAddress::getWildcardForFamily(addr->sa_family).bind(fd);
996 
997     auto connected = result->connect(addr, addrlen);
998     return connected.then(kj::mvCapture(result, [](Own<AsyncIoStream>&& result) {
999       return kj::mv(result);
1000     }));
1001   }
wrapListenSocketFd(SOCKET fd,NetworkFilter & filter,uint flags=0)1002   Own<ConnectionReceiver> wrapListenSocketFd(
1003       SOCKET fd, NetworkFilter& filter, uint flags = 0) override {
1004     return heap<FdConnectionReceiver>(eventPort, fd, filter, flags);
1005   }
1006 
getTimer()1007   Timer& getTimer() override { return eventPort.getTimer(); }
1008 
getEventPort()1009   Win32EventPort& getEventPort() { return eventPort; }
1010 
1011 private:
1012   Win32IocpEventPort eventPort;
1013   EventLoop eventLoop;
1014   WaitScope waitScope;
1015 };
1016 
1017 // =======================================================================================
1018 
1019 class NetworkAddressImpl final: public NetworkAddress {
1020 public:
NetworkAddressImpl(LowLevelAsyncIoProvider & lowLevel,LowLevelAsyncIoProvider::NetworkFilter & filter,Array<SocketAddress> addrs)1021   NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel,
1022                      LowLevelAsyncIoProvider::NetworkFilter& filter,
1023                      Array<SocketAddress> addrs)
1024       : lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {}
1025 
connect()1026   Promise<Own<AsyncIoStream>> connect() override {
1027     auto addrsCopy = heapArray(addrs.asPtr());
1028     auto promise = connectImpl(lowLevel, filter, addrsCopy);
1029     return promise.attach(kj::mv(addrsCopy));
1030   }
1031 
listen()1032   Own<ConnectionReceiver> listen() override {
1033     if (addrs.size() > 1) {
1034       KJ_LOG(WARNING, "Bind address resolved to multiple addresses.  Only the first address will "
1035           "be used.  If this is incorrect, specify the address numerically.  This may be fixed "
1036           "in the future.", addrs[0].toString());
1037     }
1038 
1039     int fd = addrs[0].socket(SOCK_STREAM);
1040 
1041     {
1042       KJ_ON_SCOPE_FAILURE(closesocket(fd));
1043 
1044       // We always enable SO_REUSEADDR because having to take your server down for five minutes
1045       // before it can restart really sucks.
1046       int optval = 1;
1047       KJ_WINSOCK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
1048                             reinterpret_cast<char*>(&optval), sizeof(optval)));
1049 
1050       addrs[0].bind(fd);
1051 
1052       // TODO(someday):  Let queue size be specified explicitly in string addresses.
1053       KJ_WINSOCK(::listen(fd, SOMAXCONN));
1054     }
1055 
1056     return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS);
1057   }
1058 
bindDatagramPort()1059   Own<DatagramPort> bindDatagramPort() override {
1060     if (addrs.size() > 1) {
1061       KJ_LOG(WARNING, "Bind address resolved to multiple addresses.  Only the first address will "
1062           "be used.  If this is incorrect, specify the address numerically.  This may be fixed "
1063           "in the future.", addrs[0].toString());
1064     }
1065 
1066     int fd = addrs[0].socket(SOCK_DGRAM);
1067 
1068     {
1069       KJ_ON_SCOPE_FAILURE(closesocket(fd));
1070 
1071       // We always enable SO_REUSEADDR because having to take your server down for five minutes
1072       // before it can restart really sucks.
1073       int optval = 1;
1074       KJ_WINSOCK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
1075                             reinterpret_cast<char*>(&optval), sizeof(optval)));
1076 
1077       addrs[0].bind(fd);
1078     }
1079 
1080     return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS);
1081   }
1082 
clone()1083   Own<NetworkAddress> clone() override {
1084     return kj::heap<NetworkAddressImpl>(lowLevel, filter, kj::heapArray(addrs.asPtr()));
1085   }
1086 
toString()1087   String toString() override {
1088     return strArray(KJ_MAP(addr, addrs) { return addr.toString(); }, ",");
1089   }
1090 
chooseOneAddress()1091   const SocketAddress& chooseOneAddress() {
1092     KJ_REQUIRE(addrs.size() > 0, "No addresses available.");
1093     return addrs[counter++ % addrs.size()];
1094   }
1095 
1096 private:
1097   LowLevelAsyncIoProvider& lowLevel;
1098   LowLevelAsyncIoProvider::NetworkFilter& filter;
1099   Array<SocketAddress> addrs;
1100   uint counter = 0;
1101 
connectImpl(LowLevelAsyncIoProvider & lowLevel,LowLevelAsyncIoProvider::NetworkFilter & filter,ArrayPtr<SocketAddress> addrs)1102   static Promise<Own<AsyncIoStream>> connectImpl(
1103       LowLevelAsyncIoProvider& lowLevel,
1104       LowLevelAsyncIoProvider::NetworkFilter& filter,
1105       ArrayPtr<SocketAddress> addrs) {
1106     KJ_ASSERT(addrs.size() > 0);
1107 
1108     int fd = addrs[0].socket(SOCK_STREAM);
1109 
1110     return kj::evalNow([&]() -> Promise<Own<AsyncIoStream>> {
1111       if (!addrs[0].allowedBy(filter)) {
1112         return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()");
1113       } else {
1114         return lowLevel.wrapConnectingSocketFd(
1115             fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
1116       }
1117     }).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
1118       // Success, pass along.
1119       return kj::mv(stream);
1120     }, [&lowLevel,&filter,KJ_CPCAP(addrs)](Exception&& exception) mutable
1121         -> Promise<Own<AsyncIoStream>> {
1122       // Connect failed.
1123       if (addrs.size() > 1) {
1124         // Try the next address instead.
1125         return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()));
1126       } else {
1127         // No more addresses to try, so propagate the exception.
1128         return kj::mv(exception);
1129       }
1130     });
1131   }
1132 };
1133 
1134 class SocketNetwork final: public Network {
1135 public:
SocketNetwork(LowLevelAsyncIoProvider & lowLevel)1136   explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
SocketNetwork(SocketNetwork & parent,kj::ArrayPtr<const kj::StringPtr> allow,kj::ArrayPtr<const kj::StringPtr> deny)1137   explicit SocketNetwork(SocketNetwork& parent,
1138                          kj::ArrayPtr<const kj::StringPtr> allow,
1139                          kj::ArrayPtr<const kj::StringPtr> deny)
1140       : lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {}
1141 
parseAddress(StringPtr addr,uint portHint=0)1142   Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
1143     return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) {
1144       return SocketAddress::parse(lowLevel, addr, portHint, filter);
1145     })).then([this](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
1146       return heap<NetworkAddressImpl>(lowLevel, filter, kj::mv(addresses));
1147     });
1148   }
1149 
getSockaddr(const void * sockaddr,uint len)1150   Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
1151     auto array = kj::heapArrayBuilder<SocketAddress>(1);
1152     array.add(SocketAddress(sockaddr, len));
1153     KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; }
1154     return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, filter, array.finish()));
1155   }
1156 
restrictPeers(kj::ArrayPtr<const kj::StringPtr> allow,kj::ArrayPtr<const kj::StringPtr> deny=nullptr)1157   Own<Network> restrictPeers(
1158       kj::ArrayPtr<const kj::StringPtr> allow,
1159       kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
1160     return heap<SocketNetwork>(*this, allow, deny);
1161   }
1162 
1163 private:
1164   LowLevelAsyncIoProvider& lowLevel;
1165   _::NetworkFilter filter;
1166 };
1167 
1168 // =======================================================================================
1169 
1170 class AsyncIoProviderImpl final: public AsyncIoProvider {
1171 public:
AsyncIoProviderImpl(LowLevelAsyncIoProvider & lowLevel)1172   AsyncIoProviderImpl(LowLevelAsyncIoProvider& lowLevel)
1173       : lowLevel(lowLevel), network(lowLevel) {}
1174 
newOneWayPipe()1175   OneWayPipe newOneWayPipe() override {
1176     SOCKET fds[2];
1177     KJ_WINSOCK(_::win32Socketpair(fds));
1178     auto in = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
1179     auto out = lowLevel.wrapOutputFd(fds[1], NEW_FD_FLAGS);
1180     in->shutdownWrite();
1181     return { kj::mv(in), kj::mv(out) };
1182   }
1183 
newTwoWayPipe()1184   TwoWayPipe newTwoWayPipe() override {
1185     SOCKET fds[2];
1186     KJ_WINSOCK(_::win32Socketpair(fds));
1187     return TwoWayPipe { {
1188       lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS),
1189       lowLevel.wrapSocketFd(fds[1], NEW_FD_FLAGS)
1190     } };
1191   }
1192 
getNetwork()1193   Network& getNetwork() override {
1194     return network;
1195   }
1196 
newPipeThread(Function<void (AsyncIoProvider &,AsyncIoStream &,WaitScope &)> startFunc)1197   PipeThread newPipeThread(
1198       Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)> startFunc) override {
1199     SOCKET fds[2];
1200     KJ_WINSOCK(_::win32Socketpair(fds));
1201 
1202     int threadFd = fds[1];
1203     KJ_ON_SCOPE_FAILURE(closesocket(threadFd));
1204 
1205     auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
1206 
1207     auto thread = heap<Thread>(kj::mvCapture(startFunc,
1208         [threadFd](Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)>&& startFunc) {
1209       LowLevelAsyncIoProviderImpl lowLevel;
1210       auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS);
1211       AsyncIoProviderImpl ioProvider(lowLevel);
1212       startFunc(ioProvider, *stream, lowLevel.getWaitScope());
1213     }));
1214 
1215     return { kj::mv(thread), kj::mv(pipe) };
1216   }
1217 
getTimer()1218   Timer& getTimer() override { return lowLevel.getTimer(); }
1219 
1220 private:
1221   LowLevelAsyncIoProvider& lowLevel;
1222   SocketNetwork network;
1223 };
1224 
1225 }  // namespace
1226 
newAsyncIoProvider(LowLevelAsyncIoProvider & lowLevel)1227 Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) {
1228   return kj::heap<AsyncIoProviderImpl>(lowLevel);
1229 }
1230 
setupAsyncIo()1231 AsyncIoContext setupAsyncIo() {
1232   _::initWinsockOnce();
1233 
1234   auto lowLevel = heap<LowLevelAsyncIoProviderImpl>();
1235   auto ioProvider = kj::heap<AsyncIoProviderImpl>(*lowLevel);
1236   auto& waitScope = lowLevel->getWaitScope();
1237   auto& eventPort = lowLevel->getEventPort();
1238   return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope, eventPort };
1239 }
1240 
1241 }  // namespace kj
1242 
1243 #endif  // _WIN32
1244