1 // Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors
2 // Licensed under the MIT License:
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a copy
5 // of this software and associated documentation files (the "Software"), to deal
6 // in the Software without restriction, including without limitation the rights
7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 // copies of the Software, and to permit persons to whom the Software is
9 // furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 // THE SOFTWARE.
21 
22 #if _WIN32
23 // For Unix implementation, see async-io-unix.c++.
24 
25 // Request Vista-level APIs.
26 #include "win32-api-version.h"
27 
28 #include "async-io.h"
29 #include "async-io-internal.h"
30 #include "async-win32.h"
31 #include "debug.h"
32 #include "thread.h"
33 #include "io.h"
34 #include "vector.h"
35 #include <set>
36 
37 #include <winsock2.h>
38 #include <ws2ipdef.h>
39 #include <ws2tcpip.h>
40 #include <mswsock.h>
41 #include <stdlib.h>
42 
43 #ifndef IPV6_V6ONLY
44 // MinGW's headers are missing this.
45 #define IPV6_V6ONLY 27
46 #endif
47 
48 namespace kj {
49 
50 namespace _ {  // private
51 
52 struct WinsockInitializer {
WinsockInitializerkj::_::WinsockInitializer53   WinsockInitializer() {
54     WSADATA dontcare;
55     int result = WSAStartup(MAKEWORD(2, 2), &dontcare);
56     if (result != 0) {
57       KJ_FAIL_WIN32("WSAStartup()", result);
58     }
59   }
60 };
61 
initWinsockOnce()62 void initWinsockOnce() {
63   static WinsockInitializer initializer;
64 }
65 
win32Socketpair(SOCKET socks[2])66 int win32Socketpair(SOCKET socks[2]) {
67   // This function from: https://github.com/ncm/selectable-socketpair/blob/master/socketpair.c
68   //
69   // Copyright notice:
70   //
71   //    Copyright 2007, 2010 by Nathan C. Myers <ncm@cantrip.org>
72   //    Redistribution and use in source and binary forms, with or without modification,
73   //    are permitted provided that the following conditions are met:
74   //
75   //        Redistributions of source code must retain the above copyright notice, this
76   //        list of conditions and the following disclaimer.
77   //
78   //        Redistributions in binary form must reproduce the above copyright notice,
79   //        this list of conditions and the following disclaimer in the documentation
80   //        and/or other materials provided with the distribution.
81   //
82   //        The name of the author must not be used to endorse or promote products
83   //        derived from this software without specific prior written permission.
84   //
85   //    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
86   //    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
87   //    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
88   //    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
89   //    FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
90   //    DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
91   //    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
92   //    CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
93   //    OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
94   //    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
95 
96   // Note: This function is called from some Cap'n Proto unit tests, despite not having a public
97   //   header declaration.
98   // TODO(cleanup): Consider putting this somewhere public? Note that since it depends on Winsock,
99   //   it needs to be in the kj-async library.
100 
101   initWinsockOnce();
102 
103   union {
104     struct sockaddr_in inaddr;
105     struct sockaddr addr;
106   } a;
107   SOCKET listener;
108   int e;
109   socklen_t addrlen = sizeof(a.inaddr);
110   int reuse = 1;
111 
112   if (socks == 0) {
113     WSASetLastError(WSAEINVAL);
114     return SOCKET_ERROR;
115   }
116   socks[0] = socks[1] = -1;
117 
118   listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
119   if (listener == -1)
120     return SOCKET_ERROR;
121 
122   memset(&a, 0, sizeof(a));
123   a.inaddr.sin_family = AF_INET;
124   a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
125   a.inaddr.sin_port = 0;
126 
127   for (;;) {
128     if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR,
129            (char*) &reuse, (socklen_t) sizeof(reuse)) == -1)
130       break;
131     if  (bind(listener, &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
132       break;
133 
134     memset(&a, 0, sizeof(a));
135     if  (getsockname(listener, &a.addr, &addrlen) == SOCKET_ERROR)
136       break;
137     // win32 getsockname may only set the port number, p=0.0005.
138     // ( http://msdn.microsoft.com/library/ms738543.aspx ):
139     a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
140     a.inaddr.sin_family = AF_INET;
141 
142     if (listen(listener, 1) == SOCKET_ERROR)
143       break;
144 
145     socks[0] = WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
146     if (socks[0] == -1)
147       break;
148     if (connect(socks[0], &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
149       break;
150 
151   retryAccept:
152     socks[1] = accept(listener, NULL, NULL);
153     if (socks[1] == -1)
154       break;
155 
156     // Verify that the client is actually us and not someone else who raced to connect first.
157     // (This check added by Kenton for security.)
158     union {
159       struct sockaddr_in inaddr;
160       struct sockaddr addr;
161     } b, c;
162     socklen_t bAddrlen = sizeof(b.inaddr);
163     socklen_t cAddrlen = sizeof(b.inaddr);
164     if (getpeername(socks[1], &b.addr, &bAddrlen) == SOCKET_ERROR)
165       break;
166     if (getsockname(socks[0], &c.addr, &cAddrlen) == SOCKET_ERROR)
167       break;
168     if (bAddrlen != cAddrlen || memcmp(&b.addr, &c.addr, bAddrlen) != 0) {
169       // Someone raced to connect first. Ignore.
170       closesocket(socks[1]);
171       goto retryAccept;
172     }
173 
174     closesocket(listener);
175     return 0;
176   }
177 
178   e = WSAGetLastError();
179   closesocket(listener);
180   closesocket(socks[0]);
181   closesocket(socks[1]);
182   WSASetLastError(e);
183   socks[0] = socks[1] = -1;
184   return SOCKET_ERROR;
185 }
186 
187 }  // namespace _
188 
189 namespace {
190 
191 // =======================================================================================
192 
193 static constexpr uint NEW_FD_FLAGS = LowLevelAsyncIoProvider::TAKE_OWNERSHIP;
194 
195 class OwnedFd {
196 public:
OwnedFd(SOCKET fd,uint flags)197   OwnedFd(SOCKET fd, uint flags): fd(fd), flags(flags) {
198     // TODO(perf): Maybe use SetFileCompletionNotificationModes() to tell Windows not to bother
199     //   delivering an event when the operation completes inline. Not currently implemented on
200     //   Wine, though.
201   }
202 
~OwnedFd()203   ~OwnedFd() noexcept(false) {
204     if (flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) {
205       KJ_WINSOCK(closesocket(fd)) { break; }
206     }
207   }
208 
209 protected:
210   SOCKET fd;
211 
212 private:
213   uint flags;
214 };
215 
216 // =======================================================================================
217 
218 class AsyncStreamFd: public OwnedFd, public AsyncIoStream {
219 public:
AsyncStreamFd(Win32EventPort & eventPort,SOCKET fd,uint flags)220   AsyncStreamFd(Win32EventPort& eventPort, SOCKET fd, uint flags)
221       : OwnedFd(fd, flags),
222         observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))) {}
~AsyncStreamFd()223   virtual ~AsyncStreamFd() noexcept(false) {}
224 
read(void * buffer,size_t minBytes,size_t maxBytes)225   Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
226     return tryRead(buffer, minBytes, maxBytes).then([=](size_t result) {
227       KJ_REQUIRE(result >= minBytes, "Premature EOF") {
228         // Pretend we read zeros from the input.
229         memset(reinterpret_cast<byte*>(buffer) + result, 0, minBytes - result);
230         return minBytes;
231       }
232       return result;
233     });
234   }
235 
tryRead(void * buffer,size_t minBytes,size_t maxBytes)236   Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
237     auto bufs = heapArray<WSABUF>(1);
238     bufs[0].buf = reinterpret_cast<char*>(buffer);
239     bufs[0].len = maxBytes;
240 
241     ArrayPtr<WSABUF> ref = bufs;
242     return tryReadInternal(ref, minBytes, 0).attach(kj::mv(bufs));
243   }
244 
write(const void * buffer,size_t size)245   Promise<void> write(const void* buffer, size_t size) override {
246     auto bufs = heapArray<WSABUF>(1);
247     bufs[0].buf = const_cast<char*>(reinterpret_cast<const char*>(buffer));
248     bufs[0].len = size;
249 
250     ArrayPtr<WSABUF> ref = bufs;
251     return writeInternal(ref).attach(kj::mv(bufs));
252   }
253 
write(ArrayPtr<const ArrayPtr<const byte>> pieces)254   Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
255     auto bufs = heapArray<WSABUF>(pieces.size());
256     for (auto i: kj::indices(pieces)) {
257       bufs[i].buf = const_cast<char*>(pieces[i].asChars().begin());
258       bufs[i].len = pieces[i].size();
259     }
260 
261     ArrayPtr<WSABUF> ref = bufs;
262     return writeInternal(ref).attach(kj::mv(bufs));
263   }
264 
connect(const struct sockaddr * addr,uint addrlen)265   kj::Promise<void> connect(const struct sockaddr* addr, uint addrlen) {
266     // In order to connect asynchronously, we need the ConnectEx() function. Apparently, we have
267     // to query the socket for it dynamically, I guess because of the insanity in which winsock
268     // can be implemented in userspace and old implementations may not support it.
269     GUID guid = WSAID_CONNECTEX;
270     LPFN_CONNECTEX connectEx = nullptr;
271     DWORD n = 0;
272     KJ_WINSOCK(WSAIoctl(fd, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
273                         &connectEx, sizeof(connectEx), &n, NULL, NULL)) {
274       goto fail;  // avoid memory leak due to compiler bugs
275     }
276     if (false) {
277     fail:
278       return kj::READY_NOW;
279     }
280 
281     // OK, phew, we now have our ConnectEx function pointer. Call it.
282     auto op = observer->newOperation(0);
283 
284     if (!connectEx(fd, addr, addrlen, NULL, 0, NULL, op->getOverlapped())) {
285       DWORD error = WSAGetLastError();
286       if (error != ERROR_IO_PENDING) {
287         KJ_FAIL_WIN32("ConnectEx()", error) { break; }
288         return kj::READY_NOW;
289       }
290     }
291 
292     return op->onComplete().then([this](Win32EventPort::IoResult result) {
293       if (result.errorCode != ERROR_SUCCESS) {
294         KJ_FAIL_WIN32("ConnectEx()", result.errorCode) { return; }
295       }
296 
297       // Enable shutdown() to work.
298       setsockopt(SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0);
299     });
300   }
301 
whenWriteDisconnected()302   Promise<void> whenWriteDisconnected() override {
303     // Windows IOCP does not provide a direct, documented way to detect when the socket disconnects
304     // without actually doing a read or write. However, there is an undocoumented-but-stable
305     // ioctl called IOCTL_AFD_POLL which can be used for this purpose. In fact, select() is
306     // implemented in terms of this ioctl -- performed synchronously -- but it's entirely possible
307     // to put only one socket into the list and perform the ioctl asynchronously. Here's the
308     // source code for select() in Windows 2000 (not sure how this became public...):
309     //
310     //     https://github.com/pustladi/Windows-2000/blob/661d000d50637ed6fab2329d30e31775046588a9/private/net/sockets/winsock2/wsp/msafd/select.c#L59-L655
311     //
312     // And here's an interesting discussion: https://github.com/python-trio/trio/issues/52
313     //
314     // TODO(someday): Implement this with IOCTL_AFD_POLL. For now I'm leaving it unimplemented
315     //   because I added this method for a Linux-only use case.
316     return NEVER_DONE;
317   }
318 
shutdownWrite()319   void shutdownWrite() override {
320     // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
321     // Win32AsyncIoProvider interface.
322     KJ_WINSOCK(shutdown(fd, SD_SEND));
323   }
324 
abortRead()325   void abortRead() override {
326     // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
327     // Win32AsyncIoProvider interface.
328     KJ_WINSOCK(shutdown(fd, SD_RECEIVE));
329   }
330 
getsockopt(int level,int option,void * value,uint * length)331   void getsockopt(int level, int option, void* value, uint* length) override {
332     socklen_t socklen = *length;
333     KJ_WINSOCK(::getsockopt(fd, level, option,
334                             reinterpret_cast<char*>(value), &socklen));
335     *length = socklen;
336   }
337 
setsockopt(int level,int option,const void * value,uint length)338   void setsockopt(int level, int option, const void* value, uint length) override {
339     KJ_WINSOCK(::setsockopt(fd, level, option,
340                             reinterpret_cast<const char*>(value), length));
341   }
342 
getsockname(struct sockaddr * addr,uint * length)343   void getsockname(struct sockaddr* addr, uint* length) override {
344     socklen_t socklen = *length;
345     KJ_WINSOCK(::getsockname(fd, addr, &socklen));
346     *length = socklen;
347   }
348 
getpeername(struct sockaddr * addr,uint * length)349   void getpeername(struct sockaddr* addr, uint* length) override {
350     socklen_t socklen = *length;
351     KJ_WINSOCK(::getpeername(fd, addr, &socklen));
352     *length = socklen;
353   }
354 
355 private:
356   Own<Win32EventPort::IoObserver> observer;
357 
tryReadInternal(ArrayPtr<WSABUF> bufs,size_t minBytes,size_t alreadyRead)358   Promise<size_t> tryReadInternal(ArrayPtr<WSABUF> bufs, size_t minBytes, size_t alreadyRead) {
359     // `bufs` will remain valid until the promise completes and may be freely modified.
360     //
361     // `alreadyRead` is the number of bytes we have already received via previous reads -- minBytes
362     // and buffer have already been adjusted to account for them, but this count must be included
363     // in the final return value.
364 
365     auto op = observer->newOperation(0);
366 
367     DWORD flags = 0;
368     if (WSARecv(fd, bufs.begin(), bufs.size(), NULL, &flags,
369                 op->getOverlapped(), NULL) == SOCKET_ERROR) {
370       DWORD error = WSAGetLastError();
371       if (error != WSA_IO_PENDING) {
372         KJ_FAIL_WIN32("WSARecv()", error) { break; }
373         return alreadyRead;
374       }
375     }
376 
377     return op->onComplete()
378         .then([this,KJ_CPCAP(bufs),minBytes,alreadyRead](Win32IocpEventPort::IoResult result) mutable
379               -> Promise<size_t> {
380       if (result.errorCode != ERROR_SUCCESS) {
381         if (alreadyRead > 0) {
382           // Report what we already read.
383           return alreadyRead;
384         } else {
385           KJ_FAIL_WIN32("WSARecv()", result.errorCode) { break; }
386           return size_t(0);
387         }
388       }
389 
390       if (result.bytesTransferred == 0) {
391         return alreadyRead;
392       }
393 
394       alreadyRead += result.bytesTransferred;
395       if (result.bytesTransferred >= minBytes) {
396         // We can stop here.
397         return alreadyRead;
398       }
399       minBytes -= result.bytesTransferred;
400 
401       while (result.bytesTransferred >= bufs[0].len) {
402         result.bytesTransferred -= bufs[0].len;
403         bufs = bufs.slice(1, bufs.size());
404       }
405 
406       if (result.bytesTransferred > 0) {
407         bufs[0].buf += result.bytesTransferred;
408         bufs[0].len -= result.bytesTransferred;
409       }
410 
411       return tryReadInternal(bufs, minBytes, alreadyRead);
412     }).attach(kj::mv(bufs));
413   }
414 
writeInternal(ArrayPtr<WSABUF> bufs)415   Promise<void> writeInternal(ArrayPtr<WSABUF> bufs) {
416     // `bufs` will remain valid until the promise completes and may be freely modified.
417 
418     auto op = observer->newOperation(0);
419 
420     if (WSASend(fd, bufs.begin(), bufs.size(), NULL, 0,
421                 op->getOverlapped(), NULL) == SOCKET_ERROR) {
422       DWORD error = WSAGetLastError();
423       if (error != WSA_IO_PENDING) {
424         KJ_FAIL_WIN32("WSASend()", error) { break; }
425         return kj::READY_NOW;
426       }
427     }
428 
429     return op->onComplete()
430         .then([this,KJ_CPCAP(bufs)](Win32IocpEventPort::IoResult result) mutable -> Promise<void> {
431       if (result.errorCode != ERROR_SUCCESS) {
432         KJ_FAIL_WIN32("WSASend()", result.errorCode) { break; }
433         return kj::READY_NOW;
434       }
435 
436       while (bufs.size() > 0 && result.bytesTransferred >= bufs[0].len) {
437         result.bytesTransferred -= bufs[0].len;
438         bufs = bufs.slice(1, bufs.size());
439       }
440 
441       if (result.bytesTransferred > 0) {
442         bufs[0].buf += result.bytesTransferred;
443         bufs[0].len -= result.bytesTransferred;
444       }
445 
446       if (bufs.size() > 0) {
447         return writeInternal(bufs);
448       } else {
449         return kj::READY_NOW;
450       }
451     }).attach(kj::mv(bufs));
452   }
453 };
454 
455 // =======================================================================================
456 
457 class SocketAddress {
458 public:
SocketAddress(const void * sockaddr,uint len)459   SocketAddress(const void* sockaddr, uint len): addrlen(len) {
460     KJ_REQUIRE(len <= sizeof(addr), "Sorry, your sockaddr is too big for me.");
461     memcpy(&addr.generic, sockaddr, len);
462   }
463 
operator <(const SocketAddress & other) const464   bool operator<(const SocketAddress& other) const {
465     // So we can use std::set<SocketAddress>...  see DNS lookup code.
466 
467     if (wildcard < other.wildcard) return true;
468     if (wildcard > other.wildcard) return false;
469 
470     if (addrlen < other.addrlen) return true;
471     if (addrlen > other.addrlen) return false;
472 
473     return memcmp(&addr.generic, &other.addr.generic, addrlen) < 0;
474   }
475 
getRaw() const476   const struct sockaddr* getRaw() const { return &addr.generic; }
getRawSize() const477   int getRawSize() const { return addrlen; }
478 
socket(int type) const479   SOCKET socket(int type) const {
480     bool isStream = type == SOCK_STREAM;
481 
482     SOCKET result = ::socket(addr.generic.sa_family, type, 0);
483 
484     if (result == INVALID_SOCKET) {
485       KJ_FAIL_WIN32("WSASocket()", WSAGetLastError()) { return INVALID_SOCKET; }
486     }
487 
488     if (isStream && (addr.generic.sa_family == AF_INET ||
489                      addr.generic.sa_family == AF_INET6)) {
490       // TODO(perf):  As a hack for the 0.4 release we are always setting
491       //   TCP_NODELAY because Nagle's algorithm pretty much kills Cap'n Proto's
492       //   RPC protocol.  Later, we should extend the interface to provide more
493       //   control over this.  Perhaps write() should have a flag which
494       //   specifies whether to pass MSG_MORE.
495       BOOL one = TRUE;
496       KJ_WINSOCK(setsockopt(result, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one)));
497     }
498 
499     return result;
500   }
501 
bind(SOCKET sockfd) const502   void bind(SOCKET sockfd) const {
503     if (wildcard) {
504       // Disable IPV6_V6ONLY because we want to handle both ipv4 and ipv6 on this socket.  (The
505       // default value of this option varies across platforms.)
506       DWORD value = 0;
507       KJ_WINSOCK(setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY,
508                             reinterpret_cast<char*>(&value), sizeof(value)));
509     }
510 
511     KJ_WINSOCK(::bind(sockfd, &addr.generic, addrlen), toString());
512   }
513 
getPort() const514   uint getPort() const {
515     switch (addr.generic.sa_family) {
516       case AF_INET: return ntohs(addr.inet4.sin_port);
517       case AF_INET6: return ntohs(addr.inet6.sin6_port);
518       default: return 0;
519     }
520   }
521 
toString() const522   String toString() const {
523     if (wildcard) {
524       return str("*:", getPort());
525     }
526 
527     switch (addr.generic.sa_family) {
528       case AF_INET: {
529         char buffer[16];
530         if (InetNtopA(addr.inet4.sin_family, const_cast<struct in_addr*>(&addr.inet4.sin_addr),
531                       buffer, sizeof(buffer)) == nullptr) {
532           KJ_FAIL_WIN32("InetNtop", WSAGetLastError()) { break; }
533           return heapString("(inet_ntop error)");
534         }
535         return str(buffer, ':', ntohs(addr.inet4.sin_port));
536       }
537       case AF_INET6: {
538         char buffer[46];
539         if (InetNtopA(addr.inet6.sin6_family, const_cast<struct in6_addr*>(&addr.inet6.sin6_addr),
540                       buffer, sizeof(buffer)) == nullptr) {
541           KJ_FAIL_WIN32("InetNtop", WSAGetLastError()) { break; }
542           return heapString("(inet_ntop error)");
543         }
544         return str('[', buffer, "]:", ntohs(addr.inet6.sin6_port));
545       }
546       default:
547         return str("(unknown address family ", addr.generic.sa_family, ")");
548     }
549   }
550 
551   static Promise<Array<SocketAddress>> lookupHost(
552       LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
553       _::NetworkFilter& filter);
554   // Perform a DNS lookup.
555 
parse(LowLevelAsyncIoProvider & lowLevel,StringPtr str,uint portHint,_::NetworkFilter & filter)556   static Promise<Array<SocketAddress>> parse(
557       LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) {
558     // TODO(someday):  Allow commas in `str`.
559 
560     SocketAddress result;
561 
562     // Try to separate the address and port.
563     ArrayPtr<const char> addrPart;
564     Maybe<StringPtr> portPart;
565 
566     int af;
567 
568     if (str.startsWith("[")) {
569       // Address starts with a bracket, which is a common way to write an ip6 address with a port,
570       // since without brackets around the address part, the port looks like another segment of
571       // the address.
572       af = AF_INET6;
573       size_t closeBracket = KJ_ASSERT_NONNULL(str.findLast(']'),
574           "Unclosed '[' in address string.", str);
575 
576       addrPart = str.slice(1, closeBracket);
577       if (str.size() > closeBracket + 1) {
578         KJ_REQUIRE(str.slice(closeBracket + 1).startsWith(":"),
579                    "Expected port suffix after ']'.", str);
580         portPart = str.slice(closeBracket + 2);
581       }
582     } else {
583       KJ_IF_MAYBE(colon, str.findFirst(':')) {
584         if (str.slice(*colon + 1).findFirst(':') == nullptr) {
585           // There is exactly one colon and no brackets, so it must be an ip4 address with port.
586           af = AF_INET;
587           addrPart = str.slice(0, *colon);
588           portPart = str.slice(*colon + 1);
589         } else {
590           // There are two or more colons and no brackets, so the whole thing must be an ip6
591           // address with no port.
592           af = AF_INET6;
593           addrPart = str;
594         }
595       } else {
596         // No colons, so it must be an ip4 address without port.
597         af = AF_INET;
598         addrPart = str;
599       }
600     }
601 
602     // Parse the port.
603     unsigned long port;
604     KJ_IF_MAYBE(portText, portPart) {
605       char* endptr;
606       port = strtoul(portText->cStr(), &endptr, 0);
607       if (portText->size() == 0 || *endptr != '\0') {
608         // Not a number.  Maybe it's a service name.  Fall back to DNS.
609         return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint,
610                           filter);
611       }
612       KJ_REQUIRE(port < 65536, "Port number too large.");
613     } else {
614       port = portHint;
615     }
616 
617     // Check for wildcard.
618     if (addrPart.size() == 1 && addrPart[0] == '*') {
619       result.wildcard = true;
620       // Create an ip6 socket and set IPV6_V6ONLY to 0 later.
621       result.addrlen = sizeof(addr.inet6);
622       result.addr.inet6.sin6_family = AF_INET6;
623       result.addr.inet6.sin6_port = htons(port);
624       auto array = kj::heapArrayBuilder<SocketAddress>(1);
625       array.add(result);
626       return array.finish();
627     }
628 
629     void* addrTarget;
630     if (af == AF_INET6) {
631       result.addrlen = sizeof(addr.inet6);
632       result.addr.inet6.sin6_family = AF_INET6;
633       result.addr.inet6.sin6_port = htons(port);
634       addrTarget = &result.addr.inet6.sin6_addr;
635     } else {
636       result.addrlen = sizeof(addr.inet4);
637       result.addr.inet4.sin_family = AF_INET;
638       result.addr.inet4.sin_port = htons(port);
639       addrTarget = &result.addr.inet4.sin_addr;
640     }
641 
642     char buffer[64];
643     if (addrPart.size() < sizeof(buffer) - 1) {
644       // addrPart is not necessarily NUL-terminated so we have to make a copy.  :(
645       memcpy(buffer, addrPart.begin(), addrPart.size());
646       buffer[addrPart.size()] = '\0';
647 
648       // OK, parse it!
649       switch (InetPtonA(af, buffer, addrTarget)) {
650         case 1: {
651           // success.
652           if (!result.parseAllowedBy(filter)) {
653             KJ_FAIL_REQUIRE("address family blocked by restrictPeers()");
654             return Array<SocketAddress>();
655           }
656 
657           auto array = kj::heapArrayBuilder<SocketAddress>(1);
658           array.add(result);
659           return array.finish();
660         }
661         case 0:
662           // It's apparently not a simple address...  fall back to DNS.
663           break;
664         default:
665           KJ_FAIL_WIN32("InetPton", WSAGetLastError(), af, addrPart);
666       }
667     }
668 
669     return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter);
670   }
671 
getLocalAddress(SOCKET sockfd)672   static SocketAddress getLocalAddress(SOCKET sockfd) {
673     SocketAddress result;
674     result.addrlen = sizeof(addr);
675     KJ_WINSOCK(getsockname(sockfd, &result.addr.generic, &result.addrlen));
676     return result;
677   }
678 
getPeerAddress(SOCKET sockfd)679   static SocketAddress getPeerAddress(SOCKET sockfd) {
680     SocketAddress result;
681     result.addrlen = sizeof(addr);
682     KJ_WINSOCK(getpeername(sockfd, &result.addr.generic, &result.addrlen));
683     return result;
684   }
685 
allowedBy(LowLevelAsyncIoProvider::NetworkFilter & filter)686   bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
687     return filter.shouldAllow(&addr.generic, addrlen);
688   }
689 
parseAllowedBy(_::NetworkFilter & filter)690   bool parseAllowedBy(_::NetworkFilter& filter) {
691     return filter.shouldAllowParse(&addr.generic, addrlen);
692   }
693 
getWildcardForFamily(int family)694   static SocketAddress getWildcardForFamily(int family) {
695     SocketAddress result;
696     switch (family) {
697       case AF_INET:
698         result.addrlen = sizeof(addr.inet4);
699         result.addr.inet4.sin_family = AF_INET;
700         return result;
701       case AF_INET6:
702         result.addrlen = sizeof(addr.inet6);
703         result.addr.inet6.sin6_family = AF_INET6;
704         return result;
705       default:
706         KJ_FAIL_REQUIRE("unknown address family", family);
707     }
708   }
709 
710 private:
SocketAddress()711   SocketAddress(): addrlen(0) {
712     memset(&addr, 0, sizeof(addr));
713   }
714 
715   socklen_t addrlen;
716   bool wildcard = false;
717   union {
718     struct sockaddr generic;
719     struct sockaddr_in inet4;
720     struct sockaddr_in6 inet6;
721     struct sockaddr_storage storage;
722   } addr;
723 
724   struct LookupParams;
725   class LookupReader;
726 };
727 
728 class SocketAddress::LookupReader {
729   // Reads SocketAddresses off of a pipe coming from another thread that is performing
730   // getaddrinfo.
731 
732 public:
LookupReader(kj::Own<Thread> && thread,kj::Own<AsyncInputStream> && input,_::NetworkFilter & filter)733   LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input,
734                _::NetworkFilter& filter)
735       : thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {}
736 
~LookupReader()737   ~LookupReader() {
738     if (thread) thread->detach();
739   }
740 
read()741   Promise<Array<SocketAddress>> read() {
742     return input->tryRead(&current, sizeof(current), sizeof(current)).then(
743         [this](size_t n) -> Promise<Array<SocketAddress>> {
744       if (n < sizeof(current)) {
745         thread = nullptr;
746         // getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
747         // anyway.
748         KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; }
749         return addresses.releaseAsArray();
750       } else {
751         // getaddrinfo() can return multiple copies of the same address for several reasons.
752         // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so
753         // it may return two copies of the same address, one for each type, unless it explicitly
754         // knows that the service name given is specific to one type.  But we can't tell it a type,
755         // because we don't actually know which one the user wants, and if we specify SOCK_STREAM
756         // while the user specified a UDP service name then they'll get a resolution error which
757         // is lame.  (At least, I think that's how it works.)
758         //
759         // So we instead resort to de-duping results.
760         if (alreadySeen.insert(current).second) {
761           if (current.parseAllowedBy(filter)) {
762             addresses.add(current);
763           }
764         }
765         return read();
766       }
767     });
768   }
769 
770 private:
771   kj::Own<Thread> thread;
772   kj::Own<AsyncInputStream> input;
773   _::NetworkFilter& filter;
774   SocketAddress current;
775   kj::Vector<SocketAddress> addresses;
776   std::set<SocketAddress> alreadySeen;
777 };
778 
779 struct SocketAddress::LookupParams {
  // Host and service strings for a DNS lookup, bundled so that lookupHost() can move both into
  // the closure run by its lookup thread. lookupHost() brace-initializes this struct as
  // { host, service }, so the field order matters.

  // Passed to getaddrinfo() as the node name; lookupHost() passes nullptr instead when this
  // is "*".
780   kj::String host;

  // Passed to getaddrinfo() as the service name; lookupHost() passes nullptr instead when this
  // is null.
781   kj::String service;
782 };
783 
lookupHost(LowLevelAsyncIoProvider & lowLevel,kj::String host,kj::String service,uint portHint,_::NetworkFilter & filter)784 Promise<Array<SocketAddress>> SocketAddress::lookupHost(
785     LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
786     _::NetworkFilter& filter) {
787   // This shitty function spawns a thread to run getaddrinfo().  Unfortunately, getaddrinfo() is
788   // the only cross-platform DNS API and it is blocking.
789   //
790   // TODO(perf): Use GetAddrInfoEx(). But there are problems:
791   // - Not implemented in Wine.
792   // - Doesn't seem compatible with I/O completion ports, in particular because it's not associated
793   //   with a handle. Could signal completion as an APC instead, but that requires the IOCP code
794   //   to use GetQueuedCompletionStatusEx() which it doesn't right now because it's not available
795   //   in Wine.
796   // - Requires Unicode, for some reason. Only GetAddrInfoExW() supports async, according to the
797   //   docs. Never mind that DNS itself is ASCII...
798 
799   SOCKET fds[2];
800   KJ_WINSOCK(_::win32Socketpair(fds));
801 
802   auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS);
803 
804   int outFd = fds[1];
805 
806   LookupParams params = { kj::mv(host), kj::mv(service) };
807 
808   auto thread = heap<Thread>(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) {
809     KJ_DEFER(closesocket(outFd));
810 
811     struct addrinfo* list;
812     int status = getaddrinfo(
813         params.host == "*" ? nullptr : params.host.cStr(),
814         params.service == nullptr ? nullptr : params.service.cStr(),
815         nullptr, &list);
816     if (status == 0) {
817       KJ_DEFER(freeaddrinfo(list));
818 
819       struct addrinfo* cur = list;
820       while (cur != nullptr) {
821         if (params.service == nullptr) {
822           switch (cur->ai_addr->sa_family) {
823             case AF_INET:
824               ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint);
825               break;
826             case AF_INET6:
827               ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint);
828               break;
829             default:
830               break;
831           }
832         }
833 
834         SocketAddress addr;
835         memset(&addr, 0, sizeof(addr));  // mollify valgrind
836         if (params.host == "*") {
837           // Set up a wildcard SocketAddress.  Only use the port number returned by getaddrinfo().
838           addr.wildcard = true;
839           addr.addrlen = sizeof(addr.addr.inet6);
840           addr.addr.inet6.sin6_family = AF_INET6;
841           switch (cur->ai_addr->sa_family) {
842             case AF_INET:
843               addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port;
844               break;
845             case AF_INET6:
846               addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port;
847               break;
848             default:
849               addr.addr.inet6.sin6_port = portHint;
850               break;
851           }
852         } else {
853           addr.addrlen = cur->ai_addrlen;
854           memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen);
855         }
856         KJ_ASSERT_CAN_MEMCPY(SocketAddress);
857 
858         const char* data = reinterpret_cast<const char*>(&addr);
859         size_t size = sizeof(addr);
860         while (size > 0) {
861           int n;
862           KJ_WINSOCK(n = send(outFd, data, size, 0));
863           data += n;
864           size -= n;
865         }
866 
867         cur = cur->ai_next;
868       }
869     } else {
870       KJ_FAIL_WIN32("getaddrinfo()", status, params.host, params.service) {
871         return;
872       }
873     }
874   }));
875 
876   auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input), filter);
877   return reader->read().attach(kj::mv(reader));
878 }
879 
880 // =======================================================================================
881 
882 class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFd {
883 public:
FdConnectionReceiver(Win32EventPort & eventPort,SOCKET fd,LowLevelAsyncIoProvider::NetworkFilter & filter,uint flags)884   FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd,
885                        LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
886       : OwnedFd(fd, flags), eventPort(eventPort), filter(filter),
887         observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))),
888         address(SocketAddress::getLocalAddress(fd)) {
889     // In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have
890     // to query the socket for it dynamically, I guess because of the insanity in which winsock
891     // can be implemented in userspace and old implementations may not support it.
892     GUID guid = WSAID_ACCEPTEX;
893     DWORD n = 0;
894     KJ_WINSOCK(WSAIoctl(fd, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
895                         &acceptEx, sizeof(acceptEx), &n, NULL, NULL)) {
896       acceptEx = nullptr;
897       return;
898     }
899   }
900 
accept()901   Promise<Own<AsyncIoStream>> accept() override {
902     SOCKET newFd = address.socket(SOCK_STREAM);
903     KJ_ASSERT(newFd != INVALID_SOCKET);
904     auto result = heap<AsyncStreamFd>(eventPort, newFd, NEW_FD_FLAGS);
905 
906     auto scratch = heapArray<byte>(256);
907     DWORD dummy;
908     auto op = observer->newOperation(0);
909     if (!acceptEx(fd, newFd, scratch.begin(), 0, 128, 128, &dummy, op->getOverlapped())) {
910       DWORD error = WSAGetLastError();
911       if (error != ERROR_IO_PENDING) {
912         KJ_FAIL_WIN32("AcceptEx()", error) { break; }
913         return Own<AsyncIoStream>(kj::mv(result));  // dummy, won't be used
914       }
915     }
916 
917     return op->onComplete().then(mvCapture(result, mvCapture(scratch,
918         [this,newFd]
919         (Array<byte> scratch, Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult)
920         -> Promise<Own<AsyncIoStream>> {
921       if (ioResult.errorCode != ERROR_SUCCESS) {
922         KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; }
923       } else {
924         SOCKET me = fd;
925         stream->setsockopt(SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT,
926                            reinterpret_cast<char*>(&me), sizeof(me));
927       }
928 
929       // Supposedly, AcceptEx() places the local and peer addresses into the buffer (which we've
930       // named `scratch`). However, the format in which it writes these is undocumented, and
931       // doesn't even match between native Windows and WINE. Apparently it is useless. I don't know
932       // why they require the buffer to have space for it in the first place. We'll need to call
933       // getpeername() to get the address.
934       auto addr = SocketAddress::getPeerAddress(newFd);
935       if (addr.allowedBy(filter)) {
936         return kj::mv(stream);
937       } else {
938         return accept();
939       }
940     })));
941   }
942 
getPort()943   uint getPort() override {
944     return address.getPort();
945   }
946 
getsockopt(int level,int option,void * value,uint * length)947   void getsockopt(int level, int option, void* value, uint* length) override {
948     socklen_t socklen = *length;
949     KJ_WINSOCK(::getsockopt(fd, level, option,
950                             reinterpret_cast<char*>(value), &socklen));
951     *length = socklen;
952   }
setsockopt(int level,int option,const void * value,uint length)953   void setsockopt(int level, int option, const void* value, uint length) override {
954     KJ_WINSOCK(::setsockopt(fd, level, option,
955                             reinterpret_cast<const char*>(value), length));
956   }
getsockname(struct sockaddr * addr,uint * length)957   void getsockname(struct sockaddr* addr, uint* length) override {
958     socklen_t socklen = *length;
959     KJ_WINSOCK(::getsockname(fd, addr, &socklen));
960     *length = socklen;
961   }
962 
963 public:
964   Win32EventPort& eventPort;
965   LowLevelAsyncIoProvider::NetworkFilter& filter;
966   Own<Win32EventPort::IoObserver> observer;
967   LPFN_ACCEPTEX acceptEx = nullptr;
968   SocketAddress address;
969 };
970 
971 // TODO(someday): DatagramPortImpl
972 
973 class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
974 public:
LowLevelAsyncIoProviderImpl()975   LowLevelAsyncIoProviderImpl()
976       : eventLoop(eventPort), waitScope(eventLoop) {}
977 
getWaitScope()978   inline WaitScope& getWaitScope() { return waitScope; }
979 
wrapInputFd(SOCKET fd,uint flags=0)980   Own<AsyncInputStream> wrapInputFd(SOCKET fd, uint flags = 0) override {
981     return heap<AsyncStreamFd>(eventPort, fd, flags);
982   }
wrapOutputFd(SOCKET fd,uint flags=0)983   Own<AsyncOutputStream> wrapOutputFd(SOCKET fd, uint flags = 0) override {
984     return heap<AsyncStreamFd>(eventPort, fd, flags);
985   }
wrapSocketFd(SOCKET fd,uint flags=0)986   Own<AsyncIoStream> wrapSocketFd(SOCKET fd, uint flags = 0) override {
987     return heap<AsyncStreamFd>(eventPort, fd, flags);
988   }
wrapConnectingSocketFd(SOCKET fd,const struct sockaddr * addr,uint addrlen,uint flags=0)989   Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
990       SOCKET fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) override {
991     auto result = heap<AsyncStreamFd>(eventPort, fd, flags);
992 
993     // ConnectEx requires that the socket be bound, for some reason. Bind to an arbitrary port.
994     SocketAddress::getWildcardForFamily(addr->sa_family).bind(fd);
995 
996     auto connected = result->connect(addr, addrlen);
997     return connected.then(kj::mvCapture(result, [](Own<AsyncIoStream>&& result) {
998       return kj::mv(result);
999     }));
1000   }
wrapListenSocketFd(SOCKET fd,NetworkFilter & filter,uint flags=0)1001   Own<ConnectionReceiver> wrapListenSocketFd(
1002       SOCKET fd, NetworkFilter& filter, uint flags = 0) override {
1003     return heap<FdConnectionReceiver>(eventPort, fd, filter, flags);
1004   }
1005 
getTimer()1006   Timer& getTimer() override { return eventPort.getTimer(); }
1007 
getEventPort()1008   Win32EventPort& getEventPort() { return eventPort; }
1009 
1010 private:
1011   Win32IocpEventPort eventPort;
1012   EventLoop eventLoop;
1013   WaitScope waitScope;
1014 };
1015 
1016 // =======================================================================================
1017 
1018 class NetworkAddressImpl final: public NetworkAddress {
1019 public:
NetworkAddressImpl(LowLevelAsyncIoProvider & lowLevel,LowLevelAsyncIoProvider::NetworkFilter & filter,Array<SocketAddress> addrs)1020   NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel,
1021                      LowLevelAsyncIoProvider::NetworkFilter& filter,
1022                      Array<SocketAddress> addrs)
1023       : lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {}
1024 
connect()1025   Promise<Own<AsyncIoStream>> connect() override {
1026     auto addrsCopy = heapArray(addrs.asPtr());
1027     auto promise = connectImpl(lowLevel, filter, addrsCopy);
1028     return promise.attach(kj::mv(addrsCopy));
1029   }
1030 
listen()1031   Own<ConnectionReceiver> listen() override {
1032     if (addrs.size() > 1) {
1033       KJ_LOG(WARNING, "Bind address resolved to multiple addresses.  Only the first address will "
1034           "be used.  If this is incorrect, specify the address numerically.  This may be fixed "
1035           "in the future.", addrs[0].toString());
1036     }
1037 
1038     int fd = addrs[0].socket(SOCK_STREAM);
1039 
1040     {
1041       KJ_ON_SCOPE_FAILURE(closesocket(fd));
1042 
1043       // We always enable SO_REUSEADDR because having to take your server down for five minutes
1044       // before it can restart really sucks.
1045       int optval = 1;
1046       KJ_WINSOCK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
1047                             reinterpret_cast<char*>(&optval), sizeof(optval)));
1048 
1049       addrs[0].bind(fd);
1050 
1051       // TODO(someday):  Let queue size be specified explicitly in string addresses.
1052       KJ_WINSOCK(::listen(fd, SOMAXCONN));
1053     }
1054 
1055     return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS);
1056   }
1057 
bindDatagramPort()1058   Own<DatagramPort> bindDatagramPort() override {
1059     if (addrs.size() > 1) {
1060       KJ_LOG(WARNING, "Bind address resolved to multiple addresses.  Only the first address will "
1061           "be used.  If this is incorrect, specify the address numerically.  This may be fixed "
1062           "in the future.", addrs[0].toString());
1063     }
1064 
1065     int fd = addrs[0].socket(SOCK_DGRAM);
1066 
1067     {
1068       KJ_ON_SCOPE_FAILURE(closesocket(fd));
1069 
1070       // We always enable SO_REUSEADDR because having to take your server down for five minutes
1071       // before it can restart really sucks.
1072       int optval = 1;
1073       KJ_WINSOCK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
1074                             reinterpret_cast<char*>(&optval), sizeof(optval)));
1075 
1076       addrs[0].bind(fd);
1077     }
1078 
1079     return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS);
1080   }
1081 
clone()1082   Own<NetworkAddress> clone() override {
1083     return kj::heap<NetworkAddressImpl>(lowLevel, filter, kj::heapArray(addrs.asPtr()));
1084   }
1085 
toString()1086   String toString() override {
1087     return strArray(KJ_MAP(addr, addrs) { return addr.toString(); }, ",");
1088   }
1089 
chooseOneAddress()1090   const SocketAddress& chooseOneAddress() {
1091     KJ_REQUIRE(addrs.size() > 0, "No addresses available.");
1092     return addrs[counter++ % addrs.size()];
1093   }
1094 
1095 private:
1096   LowLevelAsyncIoProvider& lowLevel;
1097   LowLevelAsyncIoProvider::NetworkFilter& filter;
1098   Array<SocketAddress> addrs;
1099   uint counter = 0;
1100 
connectImpl(LowLevelAsyncIoProvider & lowLevel,LowLevelAsyncIoProvider::NetworkFilter & filter,ArrayPtr<SocketAddress> addrs)1101   static Promise<Own<AsyncIoStream>> connectImpl(
1102       LowLevelAsyncIoProvider& lowLevel,
1103       LowLevelAsyncIoProvider::NetworkFilter& filter,
1104       ArrayPtr<SocketAddress> addrs) {
1105     KJ_ASSERT(addrs.size() > 0);
1106 
1107     int fd = addrs[0].socket(SOCK_STREAM);
1108 
1109     return kj::evalNow([&]() -> Promise<Own<AsyncIoStream>> {
1110       if (!addrs[0].allowedBy(filter)) {
1111         return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()");
1112       } else {
1113         return lowLevel.wrapConnectingSocketFd(
1114             fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
1115       }
1116     }).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
1117       // Success, pass along.
1118       return kj::mv(stream);
1119     }, [&lowLevel,&filter,KJ_CPCAP(addrs)](Exception&& exception) mutable
1120         -> Promise<Own<AsyncIoStream>> {
1121       // Connect failed.
1122       if (addrs.size() > 1) {
1123         // Try the next address instead.
1124         return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()));
1125       } else {
1126         // No more addresses to try, so propagate the exception.
1127         return kj::mv(exception);
1128       }
1129     });
1130   }
1131 };
1132 
1133 class SocketNetwork final: public Network {
1134 public:
SocketNetwork(LowLevelAsyncIoProvider & lowLevel)1135   explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
SocketNetwork(SocketNetwork & parent,kj::ArrayPtr<const kj::StringPtr> allow,kj::ArrayPtr<const kj::StringPtr> deny)1136   explicit SocketNetwork(SocketNetwork& parent,
1137                          kj::ArrayPtr<const kj::StringPtr> allow,
1138                          kj::ArrayPtr<const kj::StringPtr> deny)
1139       : lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {}
1140 
parseAddress(StringPtr addr,uint portHint=0)1141   Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
1142     return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) {
1143       return SocketAddress::parse(lowLevel, addr, portHint, filter);
1144     })).then([this](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
1145       return heap<NetworkAddressImpl>(lowLevel, filter, kj::mv(addresses));
1146     });
1147   }
1148 
getSockaddr(const void * sockaddr,uint len)1149   Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
1150     auto array = kj::heapArrayBuilder<SocketAddress>(1);
1151     array.add(SocketAddress(sockaddr, len));
1152     KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; }
1153     return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, filter, array.finish()));
1154   }
1155 
restrictPeers(kj::ArrayPtr<const kj::StringPtr> allow,kj::ArrayPtr<const kj::StringPtr> deny=nullptr)1156   Own<Network> restrictPeers(
1157       kj::ArrayPtr<const kj::StringPtr> allow,
1158       kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
1159     return heap<SocketNetwork>(*this, allow, deny);
1160   }
1161 
1162 private:
1163   LowLevelAsyncIoProvider& lowLevel;
1164   _::NetworkFilter filter;
1165 };
1166 
1167 // =======================================================================================
1168 
1169 class AsyncIoProviderImpl final: public AsyncIoProvider {
1170 public:
AsyncIoProviderImpl(LowLevelAsyncIoProvider & lowLevel)1171   AsyncIoProviderImpl(LowLevelAsyncIoProvider& lowLevel)
1172       : lowLevel(lowLevel), network(lowLevel) {}
1173 
newOneWayPipe()1174   OneWayPipe newOneWayPipe() override {
1175     SOCKET fds[2];
1176     KJ_WINSOCK(_::win32Socketpair(fds));
1177     auto in = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
1178     auto out = lowLevel.wrapOutputFd(fds[1], NEW_FD_FLAGS);
1179     in->shutdownWrite();
1180     return { kj::mv(in), kj::mv(out) };
1181   }
1182 
newTwoWayPipe()1183   TwoWayPipe newTwoWayPipe() override {
1184     SOCKET fds[2];
1185     KJ_WINSOCK(_::win32Socketpair(fds));
1186     return TwoWayPipe { {
1187       lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS),
1188       lowLevel.wrapSocketFd(fds[1], NEW_FD_FLAGS)
1189     } };
1190   }
1191 
getNetwork()1192   Network& getNetwork() override {
1193     return network;
1194   }
1195 
newPipeThread(Function<void (AsyncIoProvider &,AsyncIoStream &,WaitScope &)> startFunc)1196   PipeThread newPipeThread(
1197       Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)> startFunc) override {
1198     SOCKET fds[2];
1199     KJ_WINSOCK(_::win32Socketpair(fds));
1200 
1201     int threadFd = fds[1];
1202     KJ_ON_SCOPE_FAILURE(closesocket(threadFd));
1203 
1204     auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
1205 
1206     auto thread = heap<Thread>(kj::mvCapture(startFunc,
1207         [threadFd](Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)>&& startFunc) {
1208       LowLevelAsyncIoProviderImpl lowLevel;
1209       auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS);
1210       AsyncIoProviderImpl ioProvider(lowLevel);
1211       startFunc(ioProvider, *stream, lowLevel.getWaitScope());
1212     }));
1213 
1214     return { kj::mv(thread), kj::mv(pipe) };
1215   }
1216 
getTimer()1217   Timer& getTimer() override { return lowLevel.getTimer(); }
1218 
1219 private:
1220   LowLevelAsyncIoProvider& lowLevel;
1221   SocketNetwork network;
1222 };
1223 
1224 }  // namespace
1225 
newAsyncIoProvider(LowLevelAsyncIoProvider & lowLevel)1226 Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) {
1227   return kj::heap<AsyncIoProviderImpl>(lowLevel);
1228 }
1229 
setupAsyncIo()1230 AsyncIoContext setupAsyncIo() {
1231   _::initWinsockOnce();
1232 
1233   auto lowLevel = heap<LowLevelAsyncIoProviderImpl>();
1234   auto ioProvider = kj::heap<AsyncIoProviderImpl>(*lowLevel);
1235   auto& waitScope = lowLevel->getWaitScope();
1236   auto& eventPort = lowLevel->getEventPort();
1237   return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope, eventPort };
1238 }
1239 
1240 }  // namespace kj
1241 
1242 #endif  // _WIN32
1243