1 // Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors
2 // Licensed under the MIT License:
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a copy
5 // of this software and associated documentation files (the "Software"), to deal
6 // in the Software without restriction, including without limitation the rights
7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 // copies of the Software, and to permit persons to whom the Software is
9 // furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 // THE SOFTWARE.
21
22 #if _WIN32
23 // For Unix implementation, see async-io-unix.c++.
24
25 // Request Vista-level APIs.
26 #define WINVER 0x0600
27 #define _WIN32_WINNT 0x0600
28
29 #include "async-io.h"
30 #include "async-io-internal.h"
31 #include "async-win32.h"
32 #include "debug.h"
33 #include "thread.h"
34 #include "io.h"
35 #include "vector.h"
36 #include <set>
37
38 #include <winsock2.h>
39 #include <ws2ipdef.h>
40 #include <ws2tcpip.h>
41 #include <mswsock.h>
42 #include <stdlib.h>
43
44 #ifndef IPV6_V6ONLY
45 // MinGW's headers are missing this.
46 #define IPV6_V6ONLY 27
47 #endif
48
49 namespace kj {
50
51 namespace _ { // private
52
53 struct WinsockInitializer {
WinsockInitializerkj::_::WinsockInitializer54 WinsockInitializer() {
55 WSADATA dontcare;
56 int result = WSAStartup(MAKEWORD(2, 2), &dontcare);
57 if (result != 0) {
58 KJ_FAIL_WIN32("WSAStartup()", result);
59 }
60 }
61 };
62
initWinsockOnce()63 void initWinsockOnce() {
64 static WinsockInitializer initializer;
65 }
66
win32Socketpair(SOCKET socks[2])67 int win32Socketpair(SOCKET socks[2]) {
68 // This function from: https://github.com/ncm/selectable-socketpair/blob/master/socketpair.c
69 //
70 // Copyright notice:
71 //
72 // Copyright 2007, 2010 by Nathan C. Myers <ncm@cantrip.org>
73 // Redistribution and use in source and binary forms, with or without modification,
74 // are permitted provided that the following conditions are met:
75 //
76 // Redistributions of source code must retain the above copyright notice, this
77 // list of conditions and the following disclaimer.
78 //
79 // Redistributions in binary form must reproduce the above copyright notice,
80 // this list of conditions and the following disclaimer in the documentation
81 // and/or other materials provided with the distribution.
82 //
83 // The name of the author must not be used to endorse or promote products
84 // derived from this software without specific prior written permission.
85 //
86 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
87 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
88 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
89 // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
90 // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
91 // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
92 // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
93 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
94 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
95 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
96
97 // Note: This function is called from some Cap'n Proto unit tests, despite not having a public
98 // header declaration.
99 // TODO(cleanup): Consider putting this somewhere public? Note that since it depends on Winsock,
100 // it needs to be in the kj-async library.
101
102 initWinsockOnce();
103
104 union {
105 struct sockaddr_in inaddr;
106 struct sockaddr addr;
107 } a;
108 SOCKET listener;
109 int e;
110 socklen_t addrlen = sizeof(a.inaddr);
111 int reuse = 1;
112
113 if (socks == 0) {
114 WSASetLastError(WSAEINVAL);
115 return SOCKET_ERROR;
116 }
117 socks[0] = socks[1] = -1;
118
119 listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
120 if (listener == -1)
121 return SOCKET_ERROR;
122
123 memset(&a, 0, sizeof(a));
124 a.inaddr.sin_family = AF_INET;
125 a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
126 a.inaddr.sin_port = 0;
127
128 for (;;) {
129 if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR,
130 (char*) &reuse, (socklen_t) sizeof(reuse)) == -1)
131 break;
132 if (bind(listener, &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
133 break;
134
135 memset(&a, 0, sizeof(a));
136 if (getsockname(listener, &a.addr, &addrlen) == SOCKET_ERROR)
137 break;
138 // win32 getsockname may only set the port number, p=0.0005.
139 // ( http://msdn.microsoft.com/library/ms738543.aspx ):
140 a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
141 a.inaddr.sin_family = AF_INET;
142
143 if (listen(listener, 1) == SOCKET_ERROR)
144 break;
145
146 socks[0] = WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
147 if (socks[0] == -1)
148 break;
149 if (connect(socks[0], &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
150 break;
151
152 retryAccept:
153 socks[1] = accept(listener, NULL, NULL);
154 if (socks[1] == -1)
155 break;
156
157 // Verify that the client is actually us and not someone else who raced to connect first.
158 // (This check added by Kenton for security.)
159 union {
160 struct sockaddr_in inaddr;
161 struct sockaddr addr;
162 } b, c;
163 socklen_t bAddrlen = sizeof(b.inaddr);
164 socklen_t cAddrlen = sizeof(b.inaddr);
165 if (getpeername(socks[1], &b.addr, &bAddrlen) == SOCKET_ERROR)
166 break;
167 if (getsockname(socks[0], &c.addr, &cAddrlen) == SOCKET_ERROR)
168 break;
169 if (bAddrlen != cAddrlen || memcmp(&b.addr, &c.addr, bAddrlen) != 0) {
170 // Someone raced to connect first. Ignore.
171 closesocket(socks[1]);
172 goto retryAccept;
173 }
174
175 closesocket(listener);
176 return 0;
177 }
178
179 e = WSAGetLastError();
180 closesocket(listener);
181 closesocket(socks[0]);
182 closesocket(socks[1]);
183 WSASetLastError(e);
184 socks[0] = socks[1] = -1;
185 return SOCKET_ERROR;
186 }
187
188 } // namespace _
189
190 namespace {
191
192 // =======================================================================================
193
194 static constexpr uint NEW_FD_FLAGS = LowLevelAsyncIoProvider::TAKE_OWNERSHIP;
195
196 class OwnedFd {
197 public:
OwnedFd(SOCKET fd,uint flags)198 OwnedFd(SOCKET fd, uint flags): fd(fd), flags(flags) {
199 // TODO(perf): Maybe use SetFileCompletionNotificationModes() to tell Windows not to bother
200 // delivering an event when the operation completes inline. Not currently implemented on
201 // Wine, though.
202 }
203
~OwnedFd()204 ~OwnedFd() noexcept(false) {
205 if (flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) {
206 KJ_WINSOCK(closesocket(fd)) { break; }
207 }
208 }
209
210 protected:
211 SOCKET fd;
212
213 private:
214 uint flags;
215 };
216
217 // =======================================================================================
218
219 class AsyncStreamFd: public OwnedFd, public AsyncIoStream {
220 public:
AsyncStreamFd(Win32EventPort & eventPort,SOCKET fd,uint flags)221 AsyncStreamFd(Win32EventPort& eventPort, SOCKET fd, uint flags)
222 : OwnedFd(fd, flags),
223 observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))) {}
~AsyncStreamFd()224 virtual ~AsyncStreamFd() noexcept(false) {}
225
read(void * buffer,size_t minBytes,size_t maxBytes)226 Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
227 return tryRead(buffer, minBytes, maxBytes).then([=](size_t result) {
228 KJ_REQUIRE(result >= minBytes, "Premature EOF") {
229 // Pretend we read zeros from the input.
230 memset(reinterpret_cast<byte*>(buffer) + result, 0, minBytes - result);
231 return minBytes;
232 }
233 return result;
234 });
235 }
236
tryRead(void * buffer,size_t minBytes,size_t maxBytes)237 Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
238 auto bufs = heapArray<WSABUF>(1);
239 bufs[0].buf = reinterpret_cast<char*>(buffer);
240 bufs[0].len = maxBytes;
241
242 ArrayPtr<WSABUF> ref = bufs;
243 return tryReadInternal(ref, minBytes, 0).attach(kj::mv(bufs));
244 }
245
write(const void * buffer,size_t size)246 Promise<void> write(const void* buffer, size_t size) override {
247 auto bufs = heapArray<WSABUF>(1);
248 bufs[0].buf = const_cast<char*>(reinterpret_cast<const char*>(buffer));
249 bufs[0].len = size;
250
251 ArrayPtr<WSABUF> ref = bufs;
252 return writeInternal(ref).attach(kj::mv(bufs));
253 }
254
write(ArrayPtr<const ArrayPtr<const byte>> pieces)255 Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
256 auto bufs = heapArray<WSABUF>(pieces.size());
257 for (auto i: kj::indices(pieces)) {
258 bufs[i].buf = const_cast<char*>(pieces[i].asChars().begin());
259 bufs[i].len = pieces[i].size();
260 }
261
262 ArrayPtr<WSABUF> ref = bufs;
263 return writeInternal(ref).attach(kj::mv(bufs));
264 }
265
connect(const struct sockaddr * addr,uint addrlen)266 kj::Promise<void> connect(const struct sockaddr* addr, uint addrlen) {
267 // In order to connect asynchronously, we need the ConnectEx() function. Apparently, we have
268 // to query the socket for it dynamically, I guess because of the insanity in which winsock
269 // can be implemented in userspace and old implementations may not support it.
270 GUID guid = WSAID_CONNECTEX;
271 LPFN_CONNECTEX connectEx = nullptr;
272 DWORD n = 0;
273 KJ_WINSOCK(WSAIoctl(fd, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
274 &connectEx, sizeof(connectEx), &n, NULL, NULL)) {
275 goto fail; // avoid memory leak due to compiler bugs
276 }
277 if (false) {
278 fail:
279 return kj::READY_NOW;
280 }
281
282 // OK, phew, we now have our ConnectEx function pointer. Call it.
283 auto op = observer->newOperation(0);
284
285 if (!connectEx(fd, addr, addrlen, NULL, 0, NULL, op->getOverlapped())) {
286 DWORD error = WSAGetLastError();
287 if (error != ERROR_IO_PENDING) {
288 KJ_FAIL_WIN32("ConnectEx()", error) { break; }
289 return kj::READY_NOW;
290 }
291 }
292
293 return op->onComplete().then([this](Win32EventPort::IoResult result) {
294 if (result.errorCode != ERROR_SUCCESS) {
295 KJ_FAIL_WIN32("ConnectEx()", result.errorCode) { return; }
296 }
297
298 // Enable shutdown() to work.
299 setsockopt(SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0);
300 });
301 }
302
whenWriteDisconnected()303 Promise<void> whenWriteDisconnected() override {
304 // Windows IOCP does not provide a direct, documented way to detect when the socket disconnects
305 // without actually doing a read or write. However, there is an undocoumented-but-stable
306 // ioctl called IOCTL_AFD_POLL which can be used for this purpose. In fact, select() is
307 // implemented in terms of this ioctl -- performed synchronously -- but it's entirely possible
308 // to put only one socket into the list and perform the ioctl asynchronously. Here's the
309 // source code for select() in Windows 2000 (not sure how this became public...):
310 //
311 // https://github.com/pustladi/Windows-2000/blob/661d000d50637ed6fab2329d30e31775046588a9/private/net/sockets/winsock2/wsp/msafd/select.c#L59-L655
312 //
313 // And here's an interesting discussion: https://github.com/python-trio/trio/issues/52
314 //
315 // TODO(someday): Implement this with IOCTL_AFD_POLL. For now I'm leaving it unimplemented
316 // because I added this method for a Linux-only use case.
317 return NEVER_DONE;
318 }
319
shutdownWrite()320 void shutdownWrite() override {
321 // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
322 // Win32AsyncIoProvider interface.
323 KJ_WINSOCK(shutdown(fd, SD_SEND));
324 }
325
abortRead()326 void abortRead() override {
327 // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
328 // Win32AsyncIoProvider interface.
329 KJ_WINSOCK(shutdown(fd, SD_RECEIVE));
330 }
331
getsockopt(int level,int option,void * value,uint * length)332 void getsockopt(int level, int option, void* value, uint* length) override {
333 socklen_t socklen = *length;
334 KJ_WINSOCK(::getsockopt(fd, level, option,
335 reinterpret_cast<char*>(value), &socklen));
336 *length = socklen;
337 }
338
setsockopt(int level,int option,const void * value,uint length)339 void setsockopt(int level, int option, const void* value, uint length) override {
340 KJ_WINSOCK(::setsockopt(fd, level, option,
341 reinterpret_cast<const char*>(value), length));
342 }
343
getsockname(struct sockaddr * addr,uint * length)344 void getsockname(struct sockaddr* addr, uint* length) override {
345 socklen_t socklen = *length;
346 KJ_WINSOCK(::getsockname(fd, addr, &socklen));
347 *length = socklen;
348 }
349
getpeername(struct sockaddr * addr,uint * length)350 void getpeername(struct sockaddr* addr, uint* length) override {
351 socklen_t socklen = *length;
352 KJ_WINSOCK(::getpeername(fd, addr, &socklen));
353 *length = socklen;
354 }
355
356 private:
357 Own<Win32EventPort::IoObserver> observer;
358
tryReadInternal(ArrayPtr<WSABUF> bufs,size_t minBytes,size_t alreadyRead)359 Promise<size_t> tryReadInternal(ArrayPtr<WSABUF> bufs, size_t minBytes, size_t alreadyRead) {
360 // `bufs` will remain valid until the promise completes and may be freely modified.
361 //
362 // `alreadyRead` is the number of bytes we have already received via previous reads -- minBytes
363 // and buffer have already been adjusted to account for them, but this count must be included
364 // in the final return value.
365
366 auto op = observer->newOperation(0);
367
368 DWORD flags = 0;
369 if (WSARecv(fd, bufs.begin(), bufs.size(), NULL, &flags,
370 op->getOverlapped(), NULL) == SOCKET_ERROR) {
371 DWORD error = WSAGetLastError();
372 if (error != WSA_IO_PENDING) {
373 KJ_FAIL_WIN32("WSARecv()", error) { break; }
374 return alreadyRead;
375 }
376 }
377
378 return op->onComplete()
379 .then([this,KJ_CPCAP(bufs),minBytes,alreadyRead](Win32IocpEventPort::IoResult result) mutable
380 -> Promise<size_t> {
381 if (result.errorCode != ERROR_SUCCESS) {
382 if (alreadyRead > 0) {
383 // Report what we already read.
384 return alreadyRead;
385 } else {
386 KJ_FAIL_WIN32("WSARecv()", result.errorCode) { break; }
387 return size_t(0);
388 }
389 }
390
391 if (result.bytesTransferred == 0) {
392 return alreadyRead;
393 }
394
395 alreadyRead += result.bytesTransferred;
396 if (result.bytesTransferred >= minBytes) {
397 // We can stop here.
398 return alreadyRead;
399 }
400 minBytes -= result.bytesTransferred;
401
402 while (result.bytesTransferred >= bufs[0].len) {
403 result.bytesTransferred -= bufs[0].len;
404 bufs = bufs.slice(1, bufs.size());
405 }
406
407 if (result.bytesTransferred > 0) {
408 bufs[0].buf += result.bytesTransferred;
409 bufs[0].len -= result.bytesTransferred;
410 }
411
412 return tryReadInternal(bufs, minBytes, alreadyRead);
413 }).attach(kj::mv(bufs));
414 }
415
writeInternal(ArrayPtr<WSABUF> bufs)416 Promise<void> writeInternal(ArrayPtr<WSABUF> bufs) {
417 // `bufs` will remain valid until the promise completes and may be freely modified.
418
419 auto op = observer->newOperation(0);
420
421 if (WSASend(fd, bufs.begin(), bufs.size(), NULL, 0,
422 op->getOverlapped(), NULL) == SOCKET_ERROR) {
423 DWORD error = WSAGetLastError();
424 if (error != WSA_IO_PENDING) {
425 KJ_FAIL_WIN32("WSASend()", error) { break; }
426 return kj::READY_NOW;
427 }
428 }
429
430 return op->onComplete()
431 .then([this,KJ_CPCAP(bufs)](Win32IocpEventPort::IoResult result) mutable -> Promise<void> {
432 if (result.errorCode != ERROR_SUCCESS) {
433 KJ_FAIL_WIN32("WSASend()", result.errorCode) { break; }
434 return kj::READY_NOW;
435 }
436
437 while (bufs.size() > 0 && result.bytesTransferred >= bufs[0].len) {
438 result.bytesTransferred -= bufs[0].len;
439 bufs = bufs.slice(1, bufs.size());
440 }
441
442 if (result.bytesTransferred > 0) {
443 bufs[0].buf += result.bytesTransferred;
444 bufs[0].len -= result.bytesTransferred;
445 }
446
447 if (bufs.size() > 0) {
448 return writeInternal(bufs);
449 } else {
450 return kj::READY_NOW;
451 }
452 }).attach(kj::mv(bufs));
453 }
454 };
455
456 // =======================================================================================
457
458 class SocketAddress {
459 public:
SocketAddress(const void * sockaddr,uint len)460 SocketAddress(const void* sockaddr, uint len): addrlen(len) {
461 KJ_REQUIRE(len <= sizeof(addr), "Sorry, your sockaddr is too big for me.");
462 memcpy(&addr.generic, sockaddr, len);
463 }
464
operator <(const SocketAddress & other) const465 bool operator<(const SocketAddress& other) const {
466 // So we can use std::set<SocketAddress>... see DNS lookup code.
467
468 if (wildcard < other.wildcard) return true;
469 if (wildcard > other.wildcard) return false;
470
471 if (addrlen < other.addrlen) return true;
472 if (addrlen > other.addrlen) return false;
473
474 return memcmp(&addr.generic, &other.addr.generic, addrlen) < 0;
475 }
476
getRaw() const477 const struct sockaddr* getRaw() const { return &addr.generic; }
getRawSize() const478 int getRawSize() const { return addrlen; }
479
socket(int type) const480 SOCKET socket(int type) const {
481 bool isStream = type == SOCK_STREAM;
482
483 SOCKET result = ::socket(addr.generic.sa_family, type, 0);
484
485 if (result == INVALID_SOCKET) {
486 KJ_FAIL_WIN32("WSASocket()", WSAGetLastError()) { return INVALID_SOCKET; }
487 }
488
489 if (isStream && (addr.generic.sa_family == AF_INET ||
490 addr.generic.sa_family == AF_INET6)) {
491 // TODO(perf): As a hack for the 0.4 release we are always setting
492 // TCP_NODELAY because Nagle's algorithm pretty much kills Cap'n Proto's
493 // RPC protocol. Later, we should extend the interface to provide more
494 // control over this. Perhaps write() should have a flag which
495 // specifies whether to pass MSG_MORE.
496 BOOL one = TRUE;
497 KJ_WINSOCK(setsockopt(result, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one)));
498 }
499
500 return result;
501 }
502
bind(SOCKET sockfd) const503 void bind(SOCKET sockfd) const {
504 if (wildcard) {
505 // Disable IPV6_V6ONLY because we want to handle both ipv4 and ipv6 on this socket. (The
506 // default value of this option varies across platforms.)
507 DWORD value = 0;
508 KJ_WINSOCK(setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY,
509 reinterpret_cast<char*>(&value), sizeof(value)));
510 }
511
512 KJ_WINSOCK(::bind(sockfd, &addr.generic, addrlen), toString());
513 }
514
getPort() const515 uint getPort() const {
516 switch (addr.generic.sa_family) {
517 case AF_INET: return ntohs(addr.inet4.sin_port);
518 case AF_INET6: return ntohs(addr.inet6.sin6_port);
519 default: return 0;
520 }
521 }
522
toString() const523 String toString() const {
524 if (wildcard) {
525 return str("*:", getPort());
526 }
527
528 switch (addr.generic.sa_family) {
529 case AF_INET: {
530 char buffer[16];
531 if (InetNtopA(addr.inet4.sin_family, const_cast<struct in_addr*>(&addr.inet4.sin_addr),
532 buffer, sizeof(buffer)) == nullptr) {
533 KJ_FAIL_WIN32("InetNtop", WSAGetLastError()) { break; }
534 return heapString("(inet_ntop error)");
535 }
536 return str(buffer, ':', ntohs(addr.inet4.sin_port));
537 }
538 case AF_INET6: {
539 char buffer[46];
540 if (InetNtopA(addr.inet6.sin6_family, const_cast<struct in6_addr*>(&addr.inet6.sin6_addr),
541 buffer, sizeof(buffer)) == nullptr) {
542 KJ_FAIL_WIN32("InetNtop", WSAGetLastError()) { break; }
543 return heapString("(inet_ntop error)");
544 }
545 return str('[', buffer, "]:", ntohs(addr.inet6.sin6_port));
546 }
547 default:
548 return str("(unknown address family ", addr.generic.sa_family, ")");
549 }
550 }
551
552 static Promise<Array<SocketAddress>> lookupHost(
553 LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
554 _::NetworkFilter& filter);
555 // Perform a DNS lookup.
556
parse(LowLevelAsyncIoProvider & lowLevel,StringPtr str,uint portHint,_::NetworkFilter & filter)557 static Promise<Array<SocketAddress>> parse(
558 LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) {
559 // TODO(someday): Allow commas in `str`.
560
561 SocketAddress result;
562
563 // Try to separate the address and port.
564 ArrayPtr<const char> addrPart;
565 Maybe<StringPtr> portPart;
566
567 int af;
568
569 if (str.startsWith("[")) {
570 // Address starts with a bracket, which is a common way to write an ip6 address with a port,
571 // since without brackets around the address part, the port looks like another segment of
572 // the address.
573 af = AF_INET6;
574 size_t closeBracket = KJ_ASSERT_NONNULL(str.findLast(']'),
575 "Unclosed '[' in address string.", str);
576
577 addrPart = str.slice(1, closeBracket);
578 if (str.size() > closeBracket + 1) {
579 KJ_REQUIRE(str.slice(closeBracket + 1).startsWith(":"),
580 "Expected port suffix after ']'.", str);
581 portPart = str.slice(closeBracket + 2);
582 }
583 } else {
584 KJ_IF_MAYBE(colon, str.findFirst(':')) {
585 if (str.slice(*colon + 1).findFirst(':') == nullptr) {
586 // There is exactly one colon and no brackets, so it must be an ip4 address with port.
587 af = AF_INET;
588 addrPart = str.slice(0, *colon);
589 portPart = str.slice(*colon + 1);
590 } else {
591 // There are two or more colons and no brackets, so the whole thing must be an ip6
592 // address with no port.
593 af = AF_INET6;
594 addrPart = str;
595 }
596 } else {
597 // No colons, so it must be an ip4 address without port.
598 af = AF_INET;
599 addrPart = str;
600 }
601 }
602
603 // Parse the port.
604 unsigned long port;
605 KJ_IF_MAYBE(portText, portPart) {
606 char* endptr;
607 port = strtoul(portText->cStr(), &endptr, 0);
608 if (portText->size() == 0 || *endptr != '\0') {
609 // Not a number. Maybe it's a service name. Fall back to DNS.
610 return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint,
611 filter);
612 }
613 KJ_REQUIRE(port < 65536, "Port number too large.");
614 } else {
615 port = portHint;
616 }
617
618 // Check for wildcard.
619 if (addrPart.size() == 1 && addrPart[0] == '*') {
620 result.wildcard = true;
621 // Create an ip6 socket and set IPV6_V6ONLY to 0 later.
622 result.addrlen = sizeof(addr.inet6);
623 result.addr.inet6.sin6_family = AF_INET6;
624 result.addr.inet6.sin6_port = htons(port);
625 auto array = kj::heapArrayBuilder<SocketAddress>(1);
626 array.add(result);
627 return array.finish();
628 }
629
630 void* addrTarget;
631 if (af == AF_INET6) {
632 result.addrlen = sizeof(addr.inet6);
633 result.addr.inet6.sin6_family = AF_INET6;
634 result.addr.inet6.sin6_port = htons(port);
635 addrTarget = &result.addr.inet6.sin6_addr;
636 } else {
637 result.addrlen = sizeof(addr.inet4);
638 result.addr.inet4.sin_family = AF_INET;
639 result.addr.inet4.sin_port = htons(port);
640 addrTarget = &result.addr.inet4.sin_addr;
641 }
642
643 char buffer[64];
644 if (addrPart.size() < sizeof(buffer) - 1) {
645 // addrPart is not necessarily NUL-terminated so we have to make a copy. :(
646 memcpy(buffer, addrPart.begin(), addrPart.size());
647 buffer[addrPart.size()] = '\0';
648
649 // OK, parse it!
650 switch (InetPtonA(af, buffer, addrTarget)) {
651 case 1: {
652 // success.
653 if (!result.parseAllowedBy(filter)) {
654 KJ_FAIL_REQUIRE("address family blocked by restrictPeers()");
655 return Array<SocketAddress>();
656 }
657
658 auto array = kj::heapArrayBuilder<SocketAddress>(1);
659 array.add(result);
660 return array.finish();
661 }
662 case 0:
663 // It's apparently not a simple address... fall back to DNS.
664 break;
665 default:
666 KJ_FAIL_WIN32("InetPton", WSAGetLastError(), af, addrPart);
667 }
668 }
669
670 return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter);
671 }
672
getLocalAddress(SOCKET sockfd)673 static SocketAddress getLocalAddress(SOCKET sockfd) {
674 SocketAddress result;
675 result.addrlen = sizeof(addr);
676 KJ_WINSOCK(getsockname(sockfd, &result.addr.generic, &result.addrlen));
677 return result;
678 }
679
getPeerAddress(SOCKET sockfd)680 static SocketAddress getPeerAddress(SOCKET sockfd) {
681 SocketAddress result;
682 result.addrlen = sizeof(addr);
683 KJ_WINSOCK(getpeername(sockfd, &result.addr.generic, &result.addrlen));
684 return result;
685 }
686
allowedBy(LowLevelAsyncIoProvider::NetworkFilter & filter)687 bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
688 return filter.shouldAllow(&addr.generic, addrlen);
689 }
690
parseAllowedBy(_::NetworkFilter & filter)691 bool parseAllowedBy(_::NetworkFilter& filter) {
692 return filter.shouldAllowParse(&addr.generic, addrlen);
693 }
694
getWildcardForFamily(int family)695 static SocketAddress getWildcardForFamily(int family) {
696 SocketAddress result;
697 switch (family) {
698 case AF_INET:
699 result.addrlen = sizeof(addr.inet4);
700 result.addr.inet4.sin_family = AF_INET;
701 return result;
702 case AF_INET6:
703 result.addrlen = sizeof(addr.inet6);
704 result.addr.inet6.sin6_family = AF_INET6;
705 return result;
706 default:
707 KJ_FAIL_REQUIRE("unknown address family", family);
708 }
709 }
710
711 private:
SocketAddress()712 SocketAddress(): addrlen(0) {
713 memset(&addr, 0, sizeof(addr));
714 }
715
716 socklen_t addrlen;
717 bool wildcard = false;
718 union {
719 struct sockaddr generic;
720 struct sockaddr_in inet4;
721 struct sockaddr_in6 inet6;
722 struct sockaddr_storage storage;
723 } addr;
724
725 struct LookupParams;
726 class LookupReader;
727 };
728
729 class SocketAddress::LookupReader {
730 // Reads SocketAddresses off of a pipe coming from another thread that is performing
731 // getaddrinfo.
732
733 public:
LookupReader(kj::Own<Thread> && thread,kj::Own<AsyncInputStream> && input,_::NetworkFilter & filter)734 LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input,
735 _::NetworkFilter& filter)
736 : thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {}
737
~LookupReader()738 ~LookupReader() {
739 if (thread) thread->detach();
740 }
741
read()742 Promise<Array<SocketAddress>> read() {
743 return input->tryRead(¤t, sizeof(current), sizeof(current)).then(
744 [this](size_t n) -> Promise<Array<SocketAddress>> {
745 if (n < sizeof(current)) {
746 thread = nullptr;
747 // getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
748 // anyway.
749 KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; }
750 return addresses.releaseAsArray();
751 } else {
752 // getaddrinfo() can return multiple copies of the same address for several reasons.
753 // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so
754 // it may return two copies of the same address, one for each type, unless it explicitly
755 // knows that the service name given is specific to one type. But we can't tell it a type,
756 // because we don't actually know which one the user wants, and if we specify SOCK_STREAM
757 // while the user specified a UDP service name then they'll get a resolution error which
758 // is lame. (At least, I think that's how it works.)
759 //
760 // So we instead resort to de-duping results.
761 if (alreadySeen.insert(current).second) {
762 if (current.parseAllowedBy(filter)) {
763 addresses.add(current);
764 }
765 }
766 return read();
767 }
768 });
769 }
770
771 private:
772 kj::Own<Thread> thread;
773 kj::Own<AsyncInputStream> input;
774 _::NetworkFilter& filter;
775 SocketAddress current;
776 kj::Vector<SocketAddress> addresses;
777 std::set<SocketAddress> alreadySeen;
778 };
779
780 struct SocketAddress::LookupParams {
781 kj::String host;
782 kj::String service;
783 };
784
lookupHost(LowLevelAsyncIoProvider & lowLevel,kj::String host,kj::String service,uint portHint,_::NetworkFilter & filter)785 Promise<Array<SocketAddress>> SocketAddress::lookupHost(
786 LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
787 _::NetworkFilter& filter) {
788 // This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
789 // the only cross-platform DNS API and it is blocking.
790 //
791 // TODO(perf): Use GetAddrInfoEx(). But there are problems:
792 // - Not implemented in Wine.
793 // - Doesn't seem compatible with I/O completion ports, in particular because it's not associated
794 // with a handle. Could signal completion as an APC instead, but that requires the IOCP code
795 // to use GetQueuedCompletionStatusEx() which it doesn't right now because it's not available
796 // in Wine.
797 // - Requires Unicode, for some reason. Only GetAddrInfoExW() supports async, according to the
798 // docs. Never mind that DNS itself is ASCII...
799
800 SOCKET fds[2];
801 KJ_WINSOCK(_::win32Socketpair(fds));
802
803 auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS);
804
805 int outFd = fds[1];
806
807 LookupParams params = { kj::mv(host), kj::mv(service) };
808
809 auto thread = heap<Thread>(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) {
810 KJ_DEFER(closesocket(outFd));
811
812 struct addrinfo* list;
813 int status = getaddrinfo(
814 params.host == "*" ? nullptr : params.host.cStr(),
815 params.service == nullptr ? nullptr : params.service.cStr(),
816 nullptr, &list);
817 if (status == 0) {
818 KJ_DEFER(freeaddrinfo(list));
819
820 struct addrinfo* cur = list;
821 while (cur != nullptr) {
822 if (params.service == nullptr) {
823 switch (cur->ai_addr->sa_family) {
824 case AF_INET:
825 ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint);
826 break;
827 case AF_INET6:
828 ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint);
829 break;
830 default:
831 break;
832 }
833 }
834
835 SocketAddress addr;
836 memset(&addr, 0, sizeof(addr)); // mollify valgrind
837 if (params.host == "*") {
838 // Set up a wildcard SocketAddress. Only use the port number returned by getaddrinfo().
839 addr.wildcard = true;
840 addr.addrlen = sizeof(addr.addr.inet6);
841 addr.addr.inet6.sin6_family = AF_INET6;
842 switch (cur->ai_addr->sa_family) {
843 case AF_INET:
844 addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port;
845 break;
846 case AF_INET6:
847 addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port;
848 break;
849 default:
850 addr.addr.inet6.sin6_port = portHint;
851 break;
852 }
853 } else {
854 addr.addrlen = cur->ai_addrlen;
855 memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen);
856 }
857 KJ_ASSERT_CAN_MEMCPY(SocketAddress);
858
859 const char* data = reinterpret_cast<const char*>(&addr);
860 size_t size = sizeof(addr);
861 while (size > 0) {
862 int n;
863 KJ_WINSOCK(n = send(outFd, data, size, 0));
864 data += n;
865 size -= n;
866 }
867
868 cur = cur->ai_next;
869 }
870 } else {
871 KJ_FAIL_WIN32("getaddrinfo()", status, params.host, params.service) {
872 return;
873 }
874 }
875 }));
876
877 auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input), filter);
878 return reader->read().attach(kj::mv(reader));
879 }
880
881 // =======================================================================================
882
883 class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFd {
884 public:
FdConnectionReceiver(Win32EventPort & eventPort,SOCKET fd,LowLevelAsyncIoProvider::NetworkFilter & filter,uint flags)885 FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd,
886 LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
887 : OwnedFd(fd, flags), eventPort(eventPort), filter(filter),
888 observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))),
889 address(SocketAddress::getLocalAddress(fd)) {
890 // In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have
891 // to query the socket for it dynamically, I guess because of the insanity in which winsock
892 // can be implemented in userspace and old implementations may not support it.
893 GUID guid = WSAID_ACCEPTEX;
894 DWORD n = 0;
895 KJ_WINSOCK(WSAIoctl(fd, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
896 &acceptEx, sizeof(acceptEx), &n, NULL, NULL)) {
897 acceptEx = nullptr;
898 return;
899 }
900 }
901
accept()902 Promise<Own<AsyncIoStream>> accept() override {
903 SOCKET newFd = address.socket(SOCK_STREAM);
904 KJ_ASSERT(newFd != INVALID_SOCKET);
905 auto result = heap<AsyncStreamFd>(eventPort, newFd, NEW_FD_FLAGS);
906
907 auto scratch = heapArray<byte>(256);
908 DWORD dummy;
909 auto op = observer->newOperation(0);
910 if (!acceptEx(fd, newFd, scratch.begin(), 0, 128, 128, &dummy, op->getOverlapped())) {
911 DWORD error = WSAGetLastError();
912 if (error != ERROR_IO_PENDING) {
913 KJ_FAIL_WIN32("AcceptEx()", error) { break; }
914 return Own<AsyncIoStream>(kj::mv(result)); // dummy, won't be used
915 }
916 }
917
918 return op->onComplete().then(mvCapture(result, mvCapture(scratch,
919 [this,newFd]
920 (Array<byte> scratch, Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult)
921 -> Promise<Own<AsyncIoStream>> {
922 if (ioResult.errorCode != ERROR_SUCCESS) {
923 KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; }
924 } else {
925 SOCKET me = fd;
926 stream->setsockopt(SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT,
927 reinterpret_cast<char*>(&me), sizeof(me));
928 }
929
930 // Supposedly, AcceptEx() places the local and peer addresses into the buffer (which we've
931 // named `scratch`). However, the format in which it writes these is undocumented, and
932 // doesn't even match between native Windows and WINE. Apparently it is useless. I don't know
933 // why they require the buffer to have space for it in the first place. We'll need to call
934 // getpeername() to get the address.
935 auto addr = SocketAddress::getPeerAddress(newFd);
936 if (addr.allowedBy(filter)) {
937 return kj::mv(stream);
938 } else {
939 return accept();
940 }
941 })));
942 }
943
getPort()944 uint getPort() override {
945 return address.getPort();
946 }
947
getsockopt(int level,int option,void * value,uint * length)948 void getsockopt(int level, int option, void* value, uint* length) override {
949 socklen_t socklen = *length;
950 KJ_WINSOCK(::getsockopt(fd, level, option,
951 reinterpret_cast<char*>(value), &socklen));
952 *length = socklen;
953 }
setsockopt(int level,int option,const void * value,uint length)954 void setsockopt(int level, int option, const void* value, uint length) override {
955 KJ_WINSOCK(::setsockopt(fd, level, option,
956 reinterpret_cast<const char*>(value), length));
957 }
getsockname(struct sockaddr * addr,uint * length)958 void getsockname(struct sockaddr* addr, uint* length) override {
959 socklen_t socklen = *length;
960 KJ_WINSOCK(::getsockname(fd, addr, &socklen));
961 *length = socklen;
962 }
963
964 public:
965 Win32EventPort& eventPort;
966 LowLevelAsyncIoProvider::NetworkFilter& filter;
967 Own<Win32EventPort::IoObserver> observer;
968 LPFN_ACCEPTEX acceptEx = nullptr;
969 SocketAddress address;
970 };
971
972 // TODO(someday): DatagramPortImpl
973
974 class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
975 public:
LowLevelAsyncIoProviderImpl()976 LowLevelAsyncIoProviderImpl()
977 : eventLoop(eventPort), waitScope(eventLoop) {}
978
getWaitScope()979 inline WaitScope& getWaitScope() { return waitScope; }
980
wrapInputFd(SOCKET fd,uint flags=0)981 Own<AsyncInputStream> wrapInputFd(SOCKET fd, uint flags = 0) override {
982 return heap<AsyncStreamFd>(eventPort, fd, flags);
983 }
wrapOutputFd(SOCKET fd,uint flags=0)984 Own<AsyncOutputStream> wrapOutputFd(SOCKET fd, uint flags = 0) override {
985 return heap<AsyncStreamFd>(eventPort, fd, flags);
986 }
wrapSocketFd(SOCKET fd,uint flags=0)987 Own<AsyncIoStream> wrapSocketFd(SOCKET fd, uint flags = 0) override {
988 return heap<AsyncStreamFd>(eventPort, fd, flags);
989 }
wrapConnectingSocketFd(SOCKET fd,const struct sockaddr * addr,uint addrlen,uint flags=0)990 Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
991 SOCKET fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) override {
992 auto result = heap<AsyncStreamFd>(eventPort, fd, flags);
993
994 // ConnectEx requires that the socket be bound, for some reason. Bind to an arbitrary port.
995 SocketAddress::getWildcardForFamily(addr->sa_family).bind(fd);
996
997 auto connected = result->connect(addr, addrlen);
998 return connected.then(kj::mvCapture(result, [](Own<AsyncIoStream>&& result) {
999 return kj::mv(result);
1000 }));
1001 }
wrapListenSocketFd(SOCKET fd,NetworkFilter & filter,uint flags=0)1002 Own<ConnectionReceiver> wrapListenSocketFd(
1003 SOCKET fd, NetworkFilter& filter, uint flags = 0) override {
1004 return heap<FdConnectionReceiver>(eventPort, fd, filter, flags);
1005 }
1006
getTimer()1007 Timer& getTimer() override { return eventPort.getTimer(); }
1008
getEventPort()1009 Win32EventPort& getEventPort() { return eventPort; }
1010
1011 private:
1012 Win32IocpEventPort eventPort;
1013 EventLoop eventLoop;
1014 WaitScope waitScope;
1015 };
1016
1017 // =======================================================================================
1018
1019 class NetworkAddressImpl final: public NetworkAddress {
1020 public:
NetworkAddressImpl(LowLevelAsyncIoProvider & lowLevel,LowLevelAsyncIoProvider::NetworkFilter & filter,Array<SocketAddress> addrs)1021 NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel,
1022 LowLevelAsyncIoProvider::NetworkFilter& filter,
1023 Array<SocketAddress> addrs)
1024 : lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {}
1025
connect()1026 Promise<Own<AsyncIoStream>> connect() override {
1027 auto addrsCopy = heapArray(addrs.asPtr());
1028 auto promise = connectImpl(lowLevel, filter, addrsCopy);
1029 return promise.attach(kj::mv(addrsCopy));
1030 }
1031
listen()1032 Own<ConnectionReceiver> listen() override {
1033 if (addrs.size() > 1) {
1034 KJ_LOG(WARNING, "Bind address resolved to multiple addresses. Only the first address will "
1035 "be used. If this is incorrect, specify the address numerically. This may be fixed "
1036 "in the future.", addrs[0].toString());
1037 }
1038
1039 int fd = addrs[0].socket(SOCK_STREAM);
1040
1041 {
1042 KJ_ON_SCOPE_FAILURE(closesocket(fd));
1043
1044 // We always enable SO_REUSEADDR because having to take your server down for five minutes
1045 // before it can restart really sucks.
1046 int optval = 1;
1047 KJ_WINSOCK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
1048 reinterpret_cast<char*>(&optval), sizeof(optval)));
1049
1050 addrs[0].bind(fd);
1051
1052 // TODO(someday): Let queue size be specified explicitly in string addresses.
1053 KJ_WINSOCK(::listen(fd, SOMAXCONN));
1054 }
1055
1056 return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS);
1057 }
1058
bindDatagramPort()1059 Own<DatagramPort> bindDatagramPort() override {
1060 if (addrs.size() > 1) {
1061 KJ_LOG(WARNING, "Bind address resolved to multiple addresses. Only the first address will "
1062 "be used. If this is incorrect, specify the address numerically. This may be fixed "
1063 "in the future.", addrs[0].toString());
1064 }
1065
1066 int fd = addrs[0].socket(SOCK_DGRAM);
1067
1068 {
1069 KJ_ON_SCOPE_FAILURE(closesocket(fd));
1070
1071 // We always enable SO_REUSEADDR because having to take your server down for five minutes
1072 // before it can restart really sucks.
1073 int optval = 1;
1074 KJ_WINSOCK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
1075 reinterpret_cast<char*>(&optval), sizeof(optval)));
1076
1077 addrs[0].bind(fd);
1078 }
1079
1080 return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS);
1081 }
1082
clone()1083 Own<NetworkAddress> clone() override {
1084 return kj::heap<NetworkAddressImpl>(lowLevel, filter, kj::heapArray(addrs.asPtr()));
1085 }
1086
toString()1087 String toString() override {
1088 return strArray(KJ_MAP(addr, addrs) { return addr.toString(); }, ",");
1089 }
1090
chooseOneAddress()1091 const SocketAddress& chooseOneAddress() {
1092 KJ_REQUIRE(addrs.size() > 0, "No addresses available.");
1093 return addrs[counter++ % addrs.size()];
1094 }
1095
1096 private:
1097 LowLevelAsyncIoProvider& lowLevel;
1098 LowLevelAsyncIoProvider::NetworkFilter& filter;
1099 Array<SocketAddress> addrs;
1100 uint counter = 0;
1101
connectImpl(LowLevelAsyncIoProvider & lowLevel,LowLevelAsyncIoProvider::NetworkFilter & filter,ArrayPtr<SocketAddress> addrs)1102 static Promise<Own<AsyncIoStream>> connectImpl(
1103 LowLevelAsyncIoProvider& lowLevel,
1104 LowLevelAsyncIoProvider::NetworkFilter& filter,
1105 ArrayPtr<SocketAddress> addrs) {
1106 KJ_ASSERT(addrs.size() > 0);
1107
1108 int fd = addrs[0].socket(SOCK_STREAM);
1109
1110 return kj::evalNow([&]() -> Promise<Own<AsyncIoStream>> {
1111 if (!addrs[0].allowedBy(filter)) {
1112 return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()");
1113 } else {
1114 return lowLevel.wrapConnectingSocketFd(
1115 fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
1116 }
1117 }).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
1118 // Success, pass along.
1119 return kj::mv(stream);
1120 }, [&lowLevel,&filter,KJ_CPCAP(addrs)](Exception&& exception) mutable
1121 -> Promise<Own<AsyncIoStream>> {
1122 // Connect failed.
1123 if (addrs.size() > 1) {
1124 // Try the next address instead.
1125 return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()));
1126 } else {
1127 // No more addresses to try, so propagate the exception.
1128 return kj::mv(exception);
1129 }
1130 });
1131 }
1132 };
1133
1134 class SocketNetwork final: public Network {
1135 public:
SocketNetwork(LowLevelAsyncIoProvider & lowLevel)1136 explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
SocketNetwork(SocketNetwork & parent,kj::ArrayPtr<const kj::StringPtr> allow,kj::ArrayPtr<const kj::StringPtr> deny)1137 explicit SocketNetwork(SocketNetwork& parent,
1138 kj::ArrayPtr<const kj::StringPtr> allow,
1139 kj::ArrayPtr<const kj::StringPtr> deny)
1140 : lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {}
1141
parseAddress(StringPtr addr,uint portHint=0)1142 Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
1143 return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) {
1144 return SocketAddress::parse(lowLevel, addr, portHint, filter);
1145 })).then([this](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
1146 return heap<NetworkAddressImpl>(lowLevel, filter, kj::mv(addresses));
1147 });
1148 }
1149
getSockaddr(const void * sockaddr,uint len)1150 Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
1151 auto array = kj::heapArrayBuilder<SocketAddress>(1);
1152 array.add(SocketAddress(sockaddr, len));
1153 KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; }
1154 return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, filter, array.finish()));
1155 }
1156
restrictPeers(kj::ArrayPtr<const kj::StringPtr> allow,kj::ArrayPtr<const kj::StringPtr> deny=nullptr)1157 Own<Network> restrictPeers(
1158 kj::ArrayPtr<const kj::StringPtr> allow,
1159 kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
1160 return heap<SocketNetwork>(*this, allow, deny);
1161 }
1162
1163 private:
1164 LowLevelAsyncIoProvider& lowLevel;
1165 _::NetworkFilter filter;
1166 };
1167
1168 // =======================================================================================
1169
1170 class AsyncIoProviderImpl final: public AsyncIoProvider {
1171 public:
AsyncIoProviderImpl(LowLevelAsyncIoProvider & lowLevel)1172 AsyncIoProviderImpl(LowLevelAsyncIoProvider& lowLevel)
1173 : lowLevel(lowLevel), network(lowLevel) {}
1174
newOneWayPipe()1175 OneWayPipe newOneWayPipe() override {
1176 SOCKET fds[2];
1177 KJ_WINSOCK(_::win32Socketpair(fds));
1178 auto in = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
1179 auto out = lowLevel.wrapOutputFd(fds[1], NEW_FD_FLAGS);
1180 in->shutdownWrite();
1181 return { kj::mv(in), kj::mv(out) };
1182 }
1183
newTwoWayPipe()1184 TwoWayPipe newTwoWayPipe() override {
1185 SOCKET fds[2];
1186 KJ_WINSOCK(_::win32Socketpair(fds));
1187 return TwoWayPipe { {
1188 lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS),
1189 lowLevel.wrapSocketFd(fds[1], NEW_FD_FLAGS)
1190 } };
1191 }
1192
getNetwork()1193 Network& getNetwork() override {
1194 return network;
1195 }
1196
newPipeThread(Function<void (AsyncIoProvider &,AsyncIoStream &,WaitScope &)> startFunc)1197 PipeThread newPipeThread(
1198 Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)> startFunc) override {
1199 SOCKET fds[2];
1200 KJ_WINSOCK(_::win32Socketpair(fds));
1201
1202 int threadFd = fds[1];
1203 KJ_ON_SCOPE_FAILURE(closesocket(threadFd));
1204
1205 auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
1206
1207 auto thread = heap<Thread>(kj::mvCapture(startFunc,
1208 [threadFd](Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)>&& startFunc) {
1209 LowLevelAsyncIoProviderImpl lowLevel;
1210 auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS);
1211 AsyncIoProviderImpl ioProvider(lowLevel);
1212 startFunc(ioProvider, *stream, lowLevel.getWaitScope());
1213 }));
1214
1215 return { kj::mv(thread), kj::mv(pipe) };
1216 }
1217
getTimer()1218 Timer& getTimer() override { return lowLevel.getTimer(); }
1219
1220 private:
1221 LowLevelAsyncIoProvider& lowLevel;
1222 SocketNetwork network;
1223 };
1224
1225 } // namespace
1226
newAsyncIoProvider(LowLevelAsyncIoProvider & lowLevel)1227 Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) {
1228 return kj::heap<AsyncIoProviderImpl>(lowLevel);
1229 }
1230
setupAsyncIo()1231 AsyncIoContext setupAsyncIo() {
1232 _::initWinsockOnce();
1233
1234 auto lowLevel = heap<LowLevelAsyncIoProviderImpl>();
1235 auto ioProvider = kj::heap<AsyncIoProviderImpl>(*lowLevel);
1236 auto& waitScope = lowLevel->getWaitScope();
1237 auto& eventPort = lowLevel->getEventPort();
1238 return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope, eventPort };
1239 }
1240
1241 } // namespace kj
1242
1243 #endif // _WIN32
1244