1 // Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors
2 // Licensed under the MIT License:
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a copy
5 // of this software and associated documentation files (the "Software"), to deal
6 // in the Software without restriction, including without limitation the rights
7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 // copies of the Software, and to permit persons to whom the Software is
9 // furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 // THE SOFTWARE.
21
22 #if _WIN32
23 // For Unix implementation, see async-io-unix.c++.
24
25 // Request Vista-level APIs.
26 #include "win32-api-version.h"
27
28 #include "async-io.h"
29 #include "async-io-internal.h"
30 #include "async-win32.h"
31 #include "debug.h"
32 #include "thread.h"
33 #include "io.h"
34 #include "vector.h"
35 #include <set>
36
37 #include <winsock2.h>
38 #include <ws2ipdef.h>
39 #include <ws2tcpip.h>
40 #include <mswsock.h>
41 #include <stdlib.h>
42
43 #ifndef IPV6_V6ONLY
44 // MinGW's headers are missing this.
45 #define IPV6_V6ONLY 27
46 #endif
47
48 namespace kj {
49
50 namespace _ { // private
51
52 struct WinsockInitializer {
WinsockInitializerkj::_::WinsockInitializer53 WinsockInitializer() {
54 WSADATA dontcare;
55 int result = WSAStartup(MAKEWORD(2, 2), &dontcare);
56 if (result != 0) {
57 KJ_FAIL_WIN32("WSAStartup()", result);
58 }
59 }
60 };
61
initWinsockOnce()62 void initWinsockOnce() {
63 static WinsockInitializer initializer;
64 }
65
win32Socketpair(SOCKET socks[2])66 int win32Socketpair(SOCKET socks[2]) {
67 // This function from: https://github.com/ncm/selectable-socketpair/blob/master/socketpair.c
68 //
69 // Copyright notice:
70 //
71 // Copyright 2007, 2010 by Nathan C. Myers <ncm@cantrip.org>
72 // Redistribution and use in source and binary forms, with or without modification,
73 // are permitted provided that the following conditions are met:
74 //
75 // Redistributions of source code must retain the above copyright notice, this
76 // list of conditions and the following disclaimer.
77 //
78 // Redistributions in binary form must reproduce the above copyright notice,
79 // this list of conditions and the following disclaimer in the documentation
80 // and/or other materials provided with the distribution.
81 //
82 // The name of the author must not be used to endorse or promote products
83 // derived from this software without specific prior written permission.
84 //
85 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
86 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
87 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
88 // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
89 // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
90 // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
91 // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
92 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
93 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
94 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
95
96 // Note: This function is called from some Cap'n Proto unit tests, despite not having a public
97 // header declaration.
98 // TODO(cleanup): Consider putting this somewhere public? Note that since it depends on Winsock,
99 // it needs to be in the kj-async library.
100
101 initWinsockOnce();
102
103 union {
104 struct sockaddr_in inaddr;
105 struct sockaddr addr;
106 } a;
107 SOCKET listener;
108 int e;
109 socklen_t addrlen = sizeof(a.inaddr);
110 int reuse = 1;
111
112 if (socks == 0) {
113 WSASetLastError(WSAEINVAL);
114 return SOCKET_ERROR;
115 }
116 socks[0] = socks[1] = -1;
117
118 listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
119 if (listener == -1)
120 return SOCKET_ERROR;
121
122 memset(&a, 0, sizeof(a));
123 a.inaddr.sin_family = AF_INET;
124 a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
125 a.inaddr.sin_port = 0;
126
127 for (;;) {
128 if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR,
129 (char*) &reuse, (socklen_t) sizeof(reuse)) == -1)
130 break;
131 if (bind(listener, &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
132 break;
133
134 memset(&a, 0, sizeof(a));
135 if (getsockname(listener, &a.addr, &addrlen) == SOCKET_ERROR)
136 break;
137 // win32 getsockname may only set the port number, p=0.0005.
138 // ( http://msdn.microsoft.com/library/ms738543.aspx ):
139 a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
140 a.inaddr.sin_family = AF_INET;
141
142 if (listen(listener, 1) == SOCKET_ERROR)
143 break;
144
145 socks[0] = WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
146 if (socks[0] == -1)
147 break;
148 if (connect(socks[0], &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
149 break;
150
151 retryAccept:
152 socks[1] = accept(listener, NULL, NULL);
153 if (socks[1] == -1)
154 break;
155
156 // Verify that the client is actually us and not someone else who raced to connect first.
157 // (This check added by Kenton for security.)
158 union {
159 struct sockaddr_in inaddr;
160 struct sockaddr addr;
161 } b, c;
162 socklen_t bAddrlen = sizeof(b.inaddr);
163 socklen_t cAddrlen = sizeof(b.inaddr);
164 if (getpeername(socks[1], &b.addr, &bAddrlen) == SOCKET_ERROR)
165 break;
166 if (getsockname(socks[0], &c.addr, &cAddrlen) == SOCKET_ERROR)
167 break;
168 if (bAddrlen != cAddrlen || memcmp(&b.addr, &c.addr, bAddrlen) != 0) {
169 // Someone raced to connect first. Ignore.
170 closesocket(socks[1]);
171 goto retryAccept;
172 }
173
174 closesocket(listener);
175 return 0;
176 }
177
178 e = WSAGetLastError();
179 closesocket(listener);
180 closesocket(socks[0]);
181 closesocket(socks[1]);
182 WSASetLastError(e);
183 socks[0] = socks[1] = -1;
184 return SOCKET_ERROR;
185 }
186
187 } // namespace _
188
189 namespace {
190
191 // =======================================================================================
192
193 static constexpr uint NEW_FD_FLAGS = LowLevelAsyncIoProvider::TAKE_OWNERSHIP;
194
195 class OwnedFd {
196 public:
OwnedFd(SOCKET fd,uint flags)197 OwnedFd(SOCKET fd, uint flags): fd(fd), flags(flags) {
198 // TODO(perf): Maybe use SetFileCompletionNotificationModes() to tell Windows not to bother
199 // delivering an event when the operation completes inline. Not currently implemented on
200 // Wine, though.
201 }
202
~OwnedFd()203 ~OwnedFd() noexcept(false) {
204 if (flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) {
205 KJ_WINSOCK(closesocket(fd)) { break; }
206 }
207 }
208
209 protected:
210 SOCKET fd;
211
212 private:
213 uint flags;
214 };
215
216 // =======================================================================================
217
218 class AsyncStreamFd: public OwnedFd, public AsyncIoStream {
219 public:
AsyncStreamFd(Win32EventPort & eventPort,SOCKET fd,uint flags)220 AsyncStreamFd(Win32EventPort& eventPort, SOCKET fd, uint flags)
221 : OwnedFd(fd, flags),
222 observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))) {}
~AsyncStreamFd()223 virtual ~AsyncStreamFd() noexcept(false) {}
224
read(void * buffer,size_t minBytes,size_t maxBytes)225 Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
226 return tryRead(buffer, minBytes, maxBytes).then([=](size_t result) {
227 KJ_REQUIRE(result >= minBytes, "Premature EOF") {
228 // Pretend we read zeros from the input.
229 memset(reinterpret_cast<byte*>(buffer) + result, 0, minBytes - result);
230 return minBytes;
231 }
232 return result;
233 });
234 }
235
tryRead(void * buffer,size_t minBytes,size_t maxBytes)236 Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
237 auto bufs = heapArray<WSABUF>(1);
238 bufs[0].buf = reinterpret_cast<char*>(buffer);
239 bufs[0].len = maxBytes;
240
241 ArrayPtr<WSABUF> ref = bufs;
242 return tryReadInternal(ref, minBytes, 0).attach(kj::mv(bufs));
243 }
244
write(const void * buffer,size_t size)245 Promise<void> write(const void* buffer, size_t size) override {
246 auto bufs = heapArray<WSABUF>(1);
247 bufs[0].buf = const_cast<char*>(reinterpret_cast<const char*>(buffer));
248 bufs[0].len = size;
249
250 ArrayPtr<WSABUF> ref = bufs;
251 return writeInternal(ref).attach(kj::mv(bufs));
252 }
253
write(ArrayPtr<const ArrayPtr<const byte>> pieces)254 Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
255 auto bufs = heapArray<WSABUF>(pieces.size());
256 for (auto i: kj::indices(pieces)) {
257 bufs[i].buf = const_cast<char*>(pieces[i].asChars().begin());
258 bufs[i].len = pieces[i].size();
259 }
260
261 ArrayPtr<WSABUF> ref = bufs;
262 return writeInternal(ref).attach(kj::mv(bufs));
263 }
264
connect(const struct sockaddr * addr,uint addrlen)265 kj::Promise<void> connect(const struct sockaddr* addr, uint addrlen) {
266 // In order to connect asynchronously, we need the ConnectEx() function. Apparently, we have
267 // to query the socket for it dynamically, I guess because of the insanity in which winsock
268 // can be implemented in userspace and old implementations may not support it.
269 GUID guid = WSAID_CONNECTEX;
270 LPFN_CONNECTEX connectEx = nullptr;
271 DWORD n = 0;
272 KJ_WINSOCK(WSAIoctl(fd, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
273 &connectEx, sizeof(connectEx), &n, NULL, NULL)) {
274 goto fail; // avoid memory leak due to compiler bugs
275 }
276 if (false) {
277 fail:
278 return kj::READY_NOW;
279 }
280
281 // OK, phew, we now have our ConnectEx function pointer. Call it.
282 auto op = observer->newOperation(0);
283
284 if (!connectEx(fd, addr, addrlen, NULL, 0, NULL, op->getOverlapped())) {
285 DWORD error = WSAGetLastError();
286 if (error != ERROR_IO_PENDING) {
287 KJ_FAIL_WIN32("ConnectEx()", error) { break; }
288 return kj::READY_NOW;
289 }
290 }
291
292 return op->onComplete().then([this](Win32EventPort::IoResult result) {
293 if (result.errorCode != ERROR_SUCCESS) {
294 KJ_FAIL_WIN32("ConnectEx()", result.errorCode) { return; }
295 }
296
297 // Enable shutdown() to work.
298 setsockopt(SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0);
299 });
300 }
301
whenWriteDisconnected()302 Promise<void> whenWriteDisconnected() override {
303 // Windows IOCP does not provide a direct, documented way to detect when the socket disconnects
304 // without actually doing a read or write. However, there is an undocoumented-but-stable
305 // ioctl called IOCTL_AFD_POLL which can be used for this purpose. In fact, select() is
306 // implemented in terms of this ioctl -- performed synchronously -- but it's entirely possible
307 // to put only one socket into the list and perform the ioctl asynchronously. Here's the
308 // source code for select() in Windows 2000 (not sure how this became public...):
309 //
310 // https://github.com/pustladi/Windows-2000/blob/661d000d50637ed6fab2329d30e31775046588a9/private/net/sockets/winsock2/wsp/msafd/select.c#L59-L655
311 //
312 // And here's an interesting discussion: https://github.com/python-trio/trio/issues/52
313 //
314 // TODO(someday): Implement this with IOCTL_AFD_POLL. For now I'm leaving it unimplemented
315 // because I added this method for a Linux-only use case.
316 return NEVER_DONE;
317 }
318
shutdownWrite()319 void shutdownWrite() override {
320 // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
321 // Win32AsyncIoProvider interface.
322 KJ_WINSOCK(shutdown(fd, SD_SEND));
323 }
324
abortRead()325 void abortRead() override {
326 // There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
327 // Win32AsyncIoProvider interface.
328 KJ_WINSOCK(shutdown(fd, SD_RECEIVE));
329 }
330
getsockopt(int level,int option,void * value,uint * length)331 void getsockopt(int level, int option, void* value, uint* length) override {
332 socklen_t socklen = *length;
333 KJ_WINSOCK(::getsockopt(fd, level, option,
334 reinterpret_cast<char*>(value), &socklen));
335 *length = socklen;
336 }
337
setsockopt(int level,int option,const void * value,uint length)338 void setsockopt(int level, int option, const void* value, uint length) override {
339 KJ_WINSOCK(::setsockopt(fd, level, option,
340 reinterpret_cast<const char*>(value), length));
341 }
342
getsockname(struct sockaddr * addr,uint * length)343 void getsockname(struct sockaddr* addr, uint* length) override {
344 socklen_t socklen = *length;
345 KJ_WINSOCK(::getsockname(fd, addr, &socklen));
346 *length = socklen;
347 }
348
getpeername(struct sockaddr * addr,uint * length)349 void getpeername(struct sockaddr* addr, uint* length) override {
350 socklen_t socklen = *length;
351 KJ_WINSOCK(::getpeername(fd, addr, &socklen));
352 *length = socklen;
353 }
354
355 private:
356 Own<Win32EventPort::IoObserver> observer;
357
tryReadInternal(ArrayPtr<WSABUF> bufs,size_t minBytes,size_t alreadyRead)358 Promise<size_t> tryReadInternal(ArrayPtr<WSABUF> bufs, size_t minBytes, size_t alreadyRead) {
359 // `bufs` will remain valid until the promise completes and may be freely modified.
360 //
361 // `alreadyRead` is the number of bytes we have already received via previous reads -- minBytes
362 // and buffer have already been adjusted to account for them, but this count must be included
363 // in the final return value.
364
365 auto op = observer->newOperation(0);
366
367 DWORD flags = 0;
368 if (WSARecv(fd, bufs.begin(), bufs.size(), NULL, &flags,
369 op->getOverlapped(), NULL) == SOCKET_ERROR) {
370 DWORD error = WSAGetLastError();
371 if (error != WSA_IO_PENDING) {
372 KJ_FAIL_WIN32("WSARecv()", error) { break; }
373 return alreadyRead;
374 }
375 }
376
377 return op->onComplete()
378 .then([this,KJ_CPCAP(bufs),minBytes,alreadyRead](Win32IocpEventPort::IoResult result) mutable
379 -> Promise<size_t> {
380 if (result.errorCode != ERROR_SUCCESS) {
381 if (alreadyRead > 0) {
382 // Report what we already read.
383 return alreadyRead;
384 } else {
385 KJ_FAIL_WIN32("WSARecv()", result.errorCode) { break; }
386 return size_t(0);
387 }
388 }
389
390 if (result.bytesTransferred == 0) {
391 return alreadyRead;
392 }
393
394 alreadyRead += result.bytesTransferred;
395 if (result.bytesTransferred >= minBytes) {
396 // We can stop here.
397 return alreadyRead;
398 }
399 minBytes -= result.bytesTransferred;
400
401 while (result.bytesTransferred >= bufs[0].len) {
402 result.bytesTransferred -= bufs[0].len;
403 bufs = bufs.slice(1, bufs.size());
404 }
405
406 if (result.bytesTransferred > 0) {
407 bufs[0].buf += result.bytesTransferred;
408 bufs[0].len -= result.bytesTransferred;
409 }
410
411 return tryReadInternal(bufs, minBytes, alreadyRead);
412 }).attach(kj::mv(bufs));
413 }
414
writeInternal(ArrayPtr<WSABUF> bufs)415 Promise<void> writeInternal(ArrayPtr<WSABUF> bufs) {
416 // `bufs` will remain valid until the promise completes and may be freely modified.
417
418 auto op = observer->newOperation(0);
419
420 if (WSASend(fd, bufs.begin(), bufs.size(), NULL, 0,
421 op->getOverlapped(), NULL) == SOCKET_ERROR) {
422 DWORD error = WSAGetLastError();
423 if (error != WSA_IO_PENDING) {
424 KJ_FAIL_WIN32("WSASend()", error) { break; }
425 return kj::READY_NOW;
426 }
427 }
428
429 return op->onComplete()
430 .then([this,KJ_CPCAP(bufs)](Win32IocpEventPort::IoResult result) mutable -> Promise<void> {
431 if (result.errorCode != ERROR_SUCCESS) {
432 KJ_FAIL_WIN32("WSASend()", result.errorCode) { break; }
433 return kj::READY_NOW;
434 }
435
436 while (bufs.size() > 0 && result.bytesTransferred >= bufs[0].len) {
437 result.bytesTransferred -= bufs[0].len;
438 bufs = bufs.slice(1, bufs.size());
439 }
440
441 if (result.bytesTransferred > 0) {
442 bufs[0].buf += result.bytesTransferred;
443 bufs[0].len -= result.bytesTransferred;
444 }
445
446 if (bufs.size() > 0) {
447 return writeInternal(bufs);
448 } else {
449 return kj::READY_NOW;
450 }
451 }).attach(kj::mv(bufs));
452 }
453 };
454
455 // =======================================================================================
456
457 class SocketAddress {
458 public:
SocketAddress(const void * sockaddr,uint len)459 SocketAddress(const void* sockaddr, uint len): addrlen(len) {
460 KJ_REQUIRE(len <= sizeof(addr), "Sorry, your sockaddr is too big for me.");
461 memcpy(&addr.generic, sockaddr, len);
462 }
463
operator <(const SocketAddress & other) const464 bool operator<(const SocketAddress& other) const {
465 // So we can use std::set<SocketAddress>... see DNS lookup code.
466
467 if (wildcard < other.wildcard) return true;
468 if (wildcard > other.wildcard) return false;
469
470 if (addrlen < other.addrlen) return true;
471 if (addrlen > other.addrlen) return false;
472
473 return memcmp(&addr.generic, &other.addr.generic, addrlen) < 0;
474 }
475
getRaw() const476 const struct sockaddr* getRaw() const { return &addr.generic; }
getRawSize() const477 int getRawSize() const { return addrlen; }
478
socket(int type) const479 SOCKET socket(int type) const {
480 bool isStream = type == SOCK_STREAM;
481
482 SOCKET result = ::socket(addr.generic.sa_family, type, 0);
483
484 if (result == INVALID_SOCKET) {
485 KJ_FAIL_WIN32("WSASocket()", WSAGetLastError()) { return INVALID_SOCKET; }
486 }
487
488 if (isStream && (addr.generic.sa_family == AF_INET ||
489 addr.generic.sa_family == AF_INET6)) {
490 // TODO(perf): As a hack for the 0.4 release we are always setting
491 // TCP_NODELAY because Nagle's algorithm pretty much kills Cap'n Proto's
492 // RPC protocol. Later, we should extend the interface to provide more
493 // control over this. Perhaps write() should have a flag which
494 // specifies whether to pass MSG_MORE.
495 BOOL one = TRUE;
496 KJ_WINSOCK(setsockopt(result, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one)));
497 }
498
499 return result;
500 }
501
bind(SOCKET sockfd) const502 void bind(SOCKET sockfd) const {
503 if (wildcard) {
504 // Disable IPV6_V6ONLY because we want to handle both ipv4 and ipv6 on this socket. (The
505 // default value of this option varies across platforms.)
506 DWORD value = 0;
507 KJ_WINSOCK(setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY,
508 reinterpret_cast<char*>(&value), sizeof(value)));
509 }
510
511 KJ_WINSOCK(::bind(sockfd, &addr.generic, addrlen), toString());
512 }
513
getPort() const514 uint getPort() const {
515 switch (addr.generic.sa_family) {
516 case AF_INET: return ntohs(addr.inet4.sin_port);
517 case AF_INET6: return ntohs(addr.inet6.sin6_port);
518 default: return 0;
519 }
520 }
521
toString() const522 String toString() const {
523 if (wildcard) {
524 return str("*:", getPort());
525 }
526
527 switch (addr.generic.sa_family) {
528 case AF_INET: {
529 char buffer[16];
530 if (InetNtopA(addr.inet4.sin_family, const_cast<struct in_addr*>(&addr.inet4.sin_addr),
531 buffer, sizeof(buffer)) == nullptr) {
532 KJ_FAIL_WIN32("InetNtop", WSAGetLastError()) { break; }
533 return heapString("(inet_ntop error)");
534 }
535 return str(buffer, ':', ntohs(addr.inet4.sin_port));
536 }
537 case AF_INET6: {
538 char buffer[46];
539 if (InetNtopA(addr.inet6.sin6_family, const_cast<struct in6_addr*>(&addr.inet6.sin6_addr),
540 buffer, sizeof(buffer)) == nullptr) {
541 KJ_FAIL_WIN32("InetNtop", WSAGetLastError()) { break; }
542 return heapString("(inet_ntop error)");
543 }
544 return str('[', buffer, "]:", ntohs(addr.inet6.sin6_port));
545 }
546 default:
547 return str("(unknown address family ", addr.generic.sa_family, ")");
548 }
549 }
550
551 static Promise<Array<SocketAddress>> lookupHost(
552 LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
553 _::NetworkFilter& filter);
554 // Perform a DNS lookup.
555
parse(LowLevelAsyncIoProvider & lowLevel,StringPtr str,uint portHint,_::NetworkFilter & filter)556 static Promise<Array<SocketAddress>> parse(
557 LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) {
558 // TODO(someday): Allow commas in `str`.
559
560 SocketAddress result;
561
562 // Try to separate the address and port.
563 ArrayPtr<const char> addrPart;
564 Maybe<StringPtr> portPart;
565
566 int af;
567
568 if (str.startsWith("[")) {
569 // Address starts with a bracket, which is a common way to write an ip6 address with a port,
570 // since without brackets around the address part, the port looks like another segment of
571 // the address.
572 af = AF_INET6;
573 size_t closeBracket = KJ_ASSERT_NONNULL(str.findLast(']'),
574 "Unclosed '[' in address string.", str);
575
576 addrPart = str.slice(1, closeBracket);
577 if (str.size() > closeBracket + 1) {
578 KJ_REQUIRE(str.slice(closeBracket + 1).startsWith(":"),
579 "Expected port suffix after ']'.", str);
580 portPart = str.slice(closeBracket + 2);
581 }
582 } else {
583 KJ_IF_MAYBE(colon, str.findFirst(':')) {
584 if (str.slice(*colon + 1).findFirst(':') == nullptr) {
585 // There is exactly one colon and no brackets, so it must be an ip4 address with port.
586 af = AF_INET;
587 addrPart = str.slice(0, *colon);
588 portPart = str.slice(*colon + 1);
589 } else {
590 // There are two or more colons and no brackets, so the whole thing must be an ip6
591 // address with no port.
592 af = AF_INET6;
593 addrPart = str;
594 }
595 } else {
596 // No colons, so it must be an ip4 address without port.
597 af = AF_INET;
598 addrPart = str;
599 }
600 }
601
602 // Parse the port.
603 unsigned long port;
604 KJ_IF_MAYBE(portText, portPart) {
605 char* endptr;
606 port = strtoul(portText->cStr(), &endptr, 0);
607 if (portText->size() == 0 || *endptr != '\0') {
608 // Not a number. Maybe it's a service name. Fall back to DNS.
609 return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint,
610 filter);
611 }
612 KJ_REQUIRE(port < 65536, "Port number too large.");
613 } else {
614 port = portHint;
615 }
616
617 // Check for wildcard.
618 if (addrPart.size() == 1 && addrPart[0] == '*') {
619 result.wildcard = true;
620 // Create an ip6 socket and set IPV6_V6ONLY to 0 later.
621 result.addrlen = sizeof(addr.inet6);
622 result.addr.inet6.sin6_family = AF_INET6;
623 result.addr.inet6.sin6_port = htons(port);
624 auto array = kj::heapArrayBuilder<SocketAddress>(1);
625 array.add(result);
626 return array.finish();
627 }
628
629 void* addrTarget;
630 if (af == AF_INET6) {
631 result.addrlen = sizeof(addr.inet6);
632 result.addr.inet6.sin6_family = AF_INET6;
633 result.addr.inet6.sin6_port = htons(port);
634 addrTarget = &result.addr.inet6.sin6_addr;
635 } else {
636 result.addrlen = sizeof(addr.inet4);
637 result.addr.inet4.sin_family = AF_INET;
638 result.addr.inet4.sin_port = htons(port);
639 addrTarget = &result.addr.inet4.sin_addr;
640 }
641
642 char buffer[64];
643 if (addrPart.size() < sizeof(buffer) - 1) {
644 // addrPart is not necessarily NUL-terminated so we have to make a copy. :(
645 memcpy(buffer, addrPart.begin(), addrPart.size());
646 buffer[addrPart.size()] = '\0';
647
648 // OK, parse it!
649 switch (InetPtonA(af, buffer, addrTarget)) {
650 case 1: {
651 // success.
652 if (!result.parseAllowedBy(filter)) {
653 KJ_FAIL_REQUIRE("address family blocked by restrictPeers()");
654 return Array<SocketAddress>();
655 }
656
657 auto array = kj::heapArrayBuilder<SocketAddress>(1);
658 array.add(result);
659 return array.finish();
660 }
661 case 0:
662 // It's apparently not a simple address... fall back to DNS.
663 break;
664 default:
665 KJ_FAIL_WIN32("InetPton", WSAGetLastError(), af, addrPart);
666 }
667 }
668
669 return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter);
670 }
671
getLocalAddress(SOCKET sockfd)672 static SocketAddress getLocalAddress(SOCKET sockfd) {
673 SocketAddress result;
674 result.addrlen = sizeof(addr);
675 KJ_WINSOCK(getsockname(sockfd, &result.addr.generic, &result.addrlen));
676 return result;
677 }
678
getPeerAddress(SOCKET sockfd)679 static SocketAddress getPeerAddress(SOCKET sockfd) {
680 SocketAddress result;
681 result.addrlen = sizeof(addr);
682 KJ_WINSOCK(getpeername(sockfd, &result.addr.generic, &result.addrlen));
683 return result;
684 }
685
allowedBy(LowLevelAsyncIoProvider::NetworkFilter & filter)686 bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
687 return filter.shouldAllow(&addr.generic, addrlen);
688 }
689
parseAllowedBy(_::NetworkFilter & filter)690 bool parseAllowedBy(_::NetworkFilter& filter) {
691 return filter.shouldAllowParse(&addr.generic, addrlen);
692 }
693
getWildcardForFamily(int family)694 static SocketAddress getWildcardForFamily(int family) {
695 SocketAddress result;
696 switch (family) {
697 case AF_INET:
698 result.addrlen = sizeof(addr.inet4);
699 result.addr.inet4.sin_family = AF_INET;
700 return result;
701 case AF_INET6:
702 result.addrlen = sizeof(addr.inet6);
703 result.addr.inet6.sin6_family = AF_INET6;
704 return result;
705 default:
706 KJ_FAIL_REQUIRE("unknown address family", family);
707 }
708 }
709
710 private:
SocketAddress()711 SocketAddress(): addrlen(0) {
712 memset(&addr, 0, sizeof(addr));
713 }
714
715 socklen_t addrlen;
716 bool wildcard = false;
717 union {
718 struct sockaddr generic;
719 struct sockaddr_in inet4;
720 struct sockaddr_in6 inet6;
721 struct sockaddr_storage storage;
722 } addr;
723
724 struct LookupParams;
725 class LookupReader;
726 };
727
728 class SocketAddress::LookupReader {
729 // Reads SocketAddresses off of a pipe coming from another thread that is performing
730 // getaddrinfo.
731
732 public:
LookupReader(kj::Own<Thread> && thread,kj::Own<AsyncInputStream> && input,_::NetworkFilter & filter)733 LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input,
734 _::NetworkFilter& filter)
735 : thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {}
736
~LookupReader()737 ~LookupReader() {
738 if (thread) thread->detach();
739 }
740
read()741 Promise<Array<SocketAddress>> read() {
742 return input->tryRead(¤t, sizeof(current), sizeof(current)).then(
743 [this](size_t n) -> Promise<Array<SocketAddress>> {
744 if (n < sizeof(current)) {
745 thread = nullptr;
746 // getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
747 // anyway.
748 KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; }
749 return addresses.releaseAsArray();
750 } else {
751 // getaddrinfo() can return multiple copies of the same address for several reasons.
752 // A major one is that we don't give it a socket type (SOCK_STREAM vs. SOCK_DGRAM), so
753 // it may return two copies of the same address, one for each type, unless it explicitly
754 // knows that the service name given is specific to one type. But we can't tell it a type,
755 // because we don't actually know which one the user wants, and if we specify SOCK_STREAM
756 // while the user specified a UDP service name then they'll get a resolution error which
757 // is lame. (At least, I think that's how it works.)
758 //
759 // So we instead resort to de-duping results.
760 if (alreadySeen.insert(current).second) {
761 if (current.parseAllowedBy(filter)) {
762 addresses.add(current);
763 }
764 }
765 return read();
766 }
767 });
768 }
769
770 private:
771 kj::Own<Thread> thread;
772 kj::Own<AsyncInputStream> input;
773 _::NetworkFilter& filter;
774 SocketAddress current;
775 kj::Vector<SocketAddress> addresses;
776 std::set<SocketAddress> alreadySeen;
777 };
778
779 struct SocketAddress::LookupParams {
780 kj::String host;
781 kj::String service;
782 };
783
lookupHost(LowLevelAsyncIoProvider & lowLevel,kj::String host,kj::String service,uint portHint,_::NetworkFilter & filter)784 Promise<Array<SocketAddress>> SocketAddress::lookupHost(
785 LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
786 _::NetworkFilter& filter) {
787 // This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
788 // the only cross-platform DNS API and it is blocking.
789 //
790 // TODO(perf): Use GetAddrInfoEx(). But there are problems:
791 // - Not implemented in Wine.
792 // - Doesn't seem compatible with I/O completion ports, in particular because it's not associated
793 // with a handle. Could signal completion as an APC instead, but that requires the IOCP code
794 // to use GetQueuedCompletionStatusEx() which it doesn't right now because it's not available
795 // in Wine.
796 // - Requires Unicode, for some reason. Only GetAddrInfoExW() supports async, according to the
797 // docs. Never mind that DNS itself is ASCII...
798
799 SOCKET fds[2];
800 KJ_WINSOCK(_::win32Socketpair(fds));
801
802 auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS);
803
804 int outFd = fds[1];
805
806 LookupParams params = { kj::mv(host), kj::mv(service) };
807
808 auto thread = heap<Thread>(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) {
809 KJ_DEFER(closesocket(outFd));
810
811 struct addrinfo* list;
812 int status = getaddrinfo(
813 params.host == "*" ? nullptr : params.host.cStr(),
814 params.service == nullptr ? nullptr : params.service.cStr(),
815 nullptr, &list);
816 if (status == 0) {
817 KJ_DEFER(freeaddrinfo(list));
818
819 struct addrinfo* cur = list;
820 while (cur != nullptr) {
821 if (params.service == nullptr) {
822 switch (cur->ai_addr->sa_family) {
823 case AF_INET:
824 ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint);
825 break;
826 case AF_INET6:
827 ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint);
828 break;
829 default:
830 break;
831 }
832 }
833
834 SocketAddress addr;
835 memset(&addr, 0, sizeof(addr)); // mollify valgrind
836 if (params.host == "*") {
837 // Set up a wildcard SocketAddress. Only use the port number returned by getaddrinfo().
838 addr.wildcard = true;
839 addr.addrlen = sizeof(addr.addr.inet6);
840 addr.addr.inet6.sin6_family = AF_INET6;
841 switch (cur->ai_addr->sa_family) {
842 case AF_INET:
843 addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port;
844 break;
845 case AF_INET6:
846 addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port;
847 break;
848 default:
849 addr.addr.inet6.sin6_port = portHint;
850 break;
851 }
852 } else {
853 addr.addrlen = cur->ai_addrlen;
854 memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen);
855 }
856 KJ_ASSERT_CAN_MEMCPY(SocketAddress);
857
858 const char* data = reinterpret_cast<const char*>(&addr);
859 size_t size = sizeof(addr);
860 while (size > 0) {
861 int n;
862 KJ_WINSOCK(n = send(outFd, data, size, 0));
863 data += n;
864 size -= n;
865 }
866
867 cur = cur->ai_next;
868 }
869 } else {
870 KJ_FAIL_WIN32("getaddrinfo()", status, params.host, params.service) {
871 return;
872 }
873 }
874 }));
875
876 auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input), filter);
877 return reader->read().attach(kj::mv(reader));
878 }
879
880 // =======================================================================================
881
882 class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFd {
883 public:
FdConnectionReceiver(Win32EventPort & eventPort,SOCKET fd,LowLevelAsyncIoProvider::NetworkFilter & filter,uint flags)884 FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd,
885 LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
886 : OwnedFd(fd, flags), eventPort(eventPort), filter(filter),
887 observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))),
888 address(SocketAddress::getLocalAddress(fd)) {
889 // In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have
890 // to query the socket for it dynamically, I guess because of the insanity in which winsock
891 // can be implemented in userspace and old implementations may not support it.
892 GUID guid = WSAID_ACCEPTEX;
893 DWORD n = 0;
894 KJ_WINSOCK(WSAIoctl(fd, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
895 &acceptEx, sizeof(acceptEx), &n, NULL, NULL)) {
896 acceptEx = nullptr;
897 return;
898 }
899 }
900
accept()901 Promise<Own<AsyncIoStream>> accept() override {
902 SOCKET newFd = address.socket(SOCK_STREAM);
903 KJ_ASSERT(newFd != INVALID_SOCKET);
904 auto result = heap<AsyncStreamFd>(eventPort, newFd, NEW_FD_FLAGS);
905
906 auto scratch = heapArray<byte>(256);
907 DWORD dummy;
908 auto op = observer->newOperation(0);
909 if (!acceptEx(fd, newFd, scratch.begin(), 0, 128, 128, &dummy, op->getOverlapped())) {
910 DWORD error = WSAGetLastError();
911 if (error != ERROR_IO_PENDING) {
912 KJ_FAIL_WIN32("AcceptEx()", error) { break; }
913 return Own<AsyncIoStream>(kj::mv(result)); // dummy, won't be used
914 }
915 }
916
917 return op->onComplete().then(mvCapture(result, mvCapture(scratch,
918 [this,newFd]
919 (Array<byte> scratch, Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult)
920 -> Promise<Own<AsyncIoStream>> {
921 if (ioResult.errorCode != ERROR_SUCCESS) {
922 KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; }
923 } else {
924 SOCKET me = fd;
925 stream->setsockopt(SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT,
926 reinterpret_cast<char*>(&me), sizeof(me));
927 }
928
929 // Supposedly, AcceptEx() places the local and peer addresses into the buffer (which we've
930 // named `scratch`). However, the format in which it writes these is undocumented, and
931 // doesn't even match between native Windows and WINE. Apparently it is useless. I don't know
932 // why they require the buffer to have space for it in the first place. We'll need to call
933 // getpeername() to get the address.
934 auto addr = SocketAddress::getPeerAddress(newFd);
935 if (addr.allowedBy(filter)) {
936 return kj::mv(stream);
937 } else {
938 return accept();
939 }
940 })));
941 }
942
getPort()943 uint getPort() override {
944 return address.getPort();
945 }
946
getsockopt(int level,int option,void * value,uint * length)947 void getsockopt(int level, int option, void* value, uint* length) override {
948 socklen_t socklen = *length;
949 KJ_WINSOCK(::getsockopt(fd, level, option,
950 reinterpret_cast<char*>(value), &socklen));
951 *length = socklen;
952 }
setsockopt(int level,int option,const void * value,uint length)953 void setsockopt(int level, int option, const void* value, uint length) override {
954 KJ_WINSOCK(::setsockopt(fd, level, option,
955 reinterpret_cast<const char*>(value), length));
956 }
getsockname(struct sockaddr * addr,uint * length)957 void getsockname(struct sockaddr* addr, uint* length) override {
958 socklen_t socklen = *length;
959 KJ_WINSOCK(::getsockname(fd, addr, &socklen));
960 *length = socklen;
961 }
962
963 public:
964 Win32EventPort& eventPort;
965 LowLevelAsyncIoProvider::NetworkFilter& filter;
966 Own<Win32EventPort::IoObserver> observer;
967 LPFN_ACCEPTEX acceptEx = nullptr;
968 SocketAddress address;
969 };
970
971 // TODO(someday): DatagramPortImpl
972
973 class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
974 public:
LowLevelAsyncIoProviderImpl()975 LowLevelAsyncIoProviderImpl()
976 : eventLoop(eventPort), waitScope(eventLoop) {}
977
getWaitScope()978 inline WaitScope& getWaitScope() { return waitScope; }
979
wrapInputFd(SOCKET fd,uint flags=0)980 Own<AsyncInputStream> wrapInputFd(SOCKET fd, uint flags = 0) override {
981 return heap<AsyncStreamFd>(eventPort, fd, flags);
982 }
wrapOutputFd(SOCKET fd,uint flags=0)983 Own<AsyncOutputStream> wrapOutputFd(SOCKET fd, uint flags = 0) override {
984 return heap<AsyncStreamFd>(eventPort, fd, flags);
985 }
wrapSocketFd(SOCKET fd,uint flags=0)986 Own<AsyncIoStream> wrapSocketFd(SOCKET fd, uint flags = 0) override {
987 return heap<AsyncStreamFd>(eventPort, fd, flags);
988 }
wrapConnectingSocketFd(SOCKET fd,const struct sockaddr * addr,uint addrlen,uint flags=0)989 Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
990 SOCKET fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) override {
991 auto result = heap<AsyncStreamFd>(eventPort, fd, flags);
992
993 // ConnectEx requires that the socket be bound, for some reason. Bind to an arbitrary port.
994 SocketAddress::getWildcardForFamily(addr->sa_family).bind(fd);
995
996 auto connected = result->connect(addr, addrlen);
997 return connected.then(kj::mvCapture(result, [](Own<AsyncIoStream>&& result) {
998 return kj::mv(result);
999 }));
1000 }
wrapListenSocketFd(SOCKET fd,NetworkFilter & filter,uint flags=0)1001 Own<ConnectionReceiver> wrapListenSocketFd(
1002 SOCKET fd, NetworkFilter& filter, uint flags = 0) override {
1003 return heap<FdConnectionReceiver>(eventPort, fd, filter, flags);
1004 }
1005
getTimer()1006 Timer& getTimer() override { return eventPort.getTimer(); }
1007
getEventPort()1008 Win32EventPort& getEventPort() { return eventPort; }
1009
1010 private:
1011 Win32IocpEventPort eventPort;
1012 EventLoop eventLoop;
1013 WaitScope waitScope;
1014 };
1015
1016 // =======================================================================================
1017
1018 class NetworkAddressImpl final: public NetworkAddress {
1019 public:
NetworkAddressImpl(LowLevelAsyncIoProvider & lowLevel,LowLevelAsyncIoProvider::NetworkFilter & filter,Array<SocketAddress> addrs)1020 NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel,
1021 LowLevelAsyncIoProvider::NetworkFilter& filter,
1022 Array<SocketAddress> addrs)
1023 : lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {}
1024
connect()1025 Promise<Own<AsyncIoStream>> connect() override {
1026 auto addrsCopy = heapArray(addrs.asPtr());
1027 auto promise = connectImpl(lowLevel, filter, addrsCopy);
1028 return promise.attach(kj::mv(addrsCopy));
1029 }
1030
listen()1031 Own<ConnectionReceiver> listen() override {
1032 if (addrs.size() > 1) {
1033 KJ_LOG(WARNING, "Bind address resolved to multiple addresses. Only the first address will "
1034 "be used. If this is incorrect, specify the address numerically. This may be fixed "
1035 "in the future.", addrs[0].toString());
1036 }
1037
1038 int fd = addrs[0].socket(SOCK_STREAM);
1039
1040 {
1041 KJ_ON_SCOPE_FAILURE(closesocket(fd));
1042
1043 // We always enable SO_REUSEADDR because having to take your server down for five minutes
1044 // before it can restart really sucks.
1045 int optval = 1;
1046 KJ_WINSOCK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
1047 reinterpret_cast<char*>(&optval), sizeof(optval)));
1048
1049 addrs[0].bind(fd);
1050
1051 // TODO(someday): Let queue size be specified explicitly in string addresses.
1052 KJ_WINSOCK(::listen(fd, SOMAXCONN));
1053 }
1054
1055 return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS);
1056 }
1057
bindDatagramPort()1058 Own<DatagramPort> bindDatagramPort() override {
1059 if (addrs.size() > 1) {
1060 KJ_LOG(WARNING, "Bind address resolved to multiple addresses. Only the first address will "
1061 "be used. If this is incorrect, specify the address numerically. This may be fixed "
1062 "in the future.", addrs[0].toString());
1063 }
1064
1065 int fd = addrs[0].socket(SOCK_DGRAM);
1066
1067 {
1068 KJ_ON_SCOPE_FAILURE(closesocket(fd));
1069
1070 // We always enable SO_REUSEADDR because having to take your server down for five minutes
1071 // before it can restart really sucks.
1072 int optval = 1;
1073 KJ_WINSOCK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
1074 reinterpret_cast<char*>(&optval), sizeof(optval)));
1075
1076 addrs[0].bind(fd);
1077 }
1078
1079 return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS);
1080 }
1081
clone()1082 Own<NetworkAddress> clone() override {
1083 return kj::heap<NetworkAddressImpl>(lowLevel, filter, kj::heapArray(addrs.asPtr()));
1084 }
1085
toString()1086 String toString() override {
1087 return strArray(KJ_MAP(addr, addrs) { return addr.toString(); }, ",");
1088 }
1089
chooseOneAddress()1090 const SocketAddress& chooseOneAddress() {
1091 KJ_REQUIRE(addrs.size() > 0, "No addresses available.");
1092 return addrs[counter++ % addrs.size()];
1093 }
1094
1095 private:
1096 LowLevelAsyncIoProvider& lowLevel;
1097 LowLevelAsyncIoProvider::NetworkFilter& filter;
1098 Array<SocketAddress> addrs;
1099 uint counter = 0;
1100
connectImpl(LowLevelAsyncIoProvider & lowLevel,LowLevelAsyncIoProvider::NetworkFilter & filter,ArrayPtr<SocketAddress> addrs)1101 static Promise<Own<AsyncIoStream>> connectImpl(
1102 LowLevelAsyncIoProvider& lowLevel,
1103 LowLevelAsyncIoProvider::NetworkFilter& filter,
1104 ArrayPtr<SocketAddress> addrs) {
1105 KJ_ASSERT(addrs.size() > 0);
1106
1107 int fd = addrs[0].socket(SOCK_STREAM);
1108
1109 return kj::evalNow([&]() -> Promise<Own<AsyncIoStream>> {
1110 if (!addrs[0].allowedBy(filter)) {
1111 return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()");
1112 } else {
1113 return lowLevel.wrapConnectingSocketFd(
1114 fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
1115 }
1116 }).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
1117 // Success, pass along.
1118 return kj::mv(stream);
1119 }, [&lowLevel,&filter,KJ_CPCAP(addrs)](Exception&& exception) mutable
1120 -> Promise<Own<AsyncIoStream>> {
1121 // Connect failed.
1122 if (addrs.size() > 1) {
1123 // Try the next address instead.
1124 return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()));
1125 } else {
1126 // No more addresses to try, so propagate the exception.
1127 return kj::mv(exception);
1128 }
1129 });
1130 }
1131 };
1132
1133 class SocketNetwork final: public Network {
1134 public:
SocketNetwork(LowLevelAsyncIoProvider & lowLevel)1135 explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
SocketNetwork(SocketNetwork & parent,kj::ArrayPtr<const kj::StringPtr> allow,kj::ArrayPtr<const kj::StringPtr> deny)1136 explicit SocketNetwork(SocketNetwork& parent,
1137 kj::ArrayPtr<const kj::StringPtr> allow,
1138 kj::ArrayPtr<const kj::StringPtr> deny)
1139 : lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {}
1140
parseAddress(StringPtr addr,uint portHint=0)1141 Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
1142 return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) {
1143 return SocketAddress::parse(lowLevel, addr, portHint, filter);
1144 })).then([this](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
1145 return heap<NetworkAddressImpl>(lowLevel, filter, kj::mv(addresses));
1146 });
1147 }
1148
getSockaddr(const void * sockaddr,uint len)1149 Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
1150 auto array = kj::heapArrayBuilder<SocketAddress>(1);
1151 array.add(SocketAddress(sockaddr, len));
1152 KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; }
1153 return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, filter, array.finish()));
1154 }
1155
restrictPeers(kj::ArrayPtr<const kj::StringPtr> allow,kj::ArrayPtr<const kj::StringPtr> deny=nullptr)1156 Own<Network> restrictPeers(
1157 kj::ArrayPtr<const kj::StringPtr> allow,
1158 kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
1159 return heap<SocketNetwork>(*this, allow, deny);
1160 }
1161
1162 private:
1163 LowLevelAsyncIoProvider& lowLevel;
1164 _::NetworkFilter filter;
1165 };
1166
1167 // =======================================================================================
1168
1169 class AsyncIoProviderImpl final: public AsyncIoProvider {
1170 public:
AsyncIoProviderImpl(LowLevelAsyncIoProvider & lowLevel)1171 AsyncIoProviderImpl(LowLevelAsyncIoProvider& lowLevel)
1172 : lowLevel(lowLevel), network(lowLevel) {}
1173
newOneWayPipe()1174 OneWayPipe newOneWayPipe() override {
1175 SOCKET fds[2];
1176 KJ_WINSOCK(_::win32Socketpair(fds));
1177 auto in = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
1178 auto out = lowLevel.wrapOutputFd(fds[1], NEW_FD_FLAGS);
1179 in->shutdownWrite();
1180 return { kj::mv(in), kj::mv(out) };
1181 }
1182
newTwoWayPipe()1183 TwoWayPipe newTwoWayPipe() override {
1184 SOCKET fds[2];
1185 KJ_WINSOCK(_::win32Socketpair(fds));
1186 return TwoWayPipe { {
1187 lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS),
1188 lowLevel.wrapSocketFd(fds[1], NEW_FD_FLAGS)
1189 } };
1190 }
1191
getNetwork()1192 Network& getNetwork() override {
1193 return network;
1194 }
1195
newPipeThread(Function<void (AsyncIoProvider &,AsyncIoStream &,WaitScope &)> startFunc)1196 PipeThread newPipeThread(
1197 Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)> startFunc) override {
1198 SOCKET fds[2];
1199 KJ_WINSOCK(_::win32Socketpair(fds));
1200
1201 int threadFd = fds[1];
1202 KJ_ON_SCOPE_FAILURE(closesocket(threadFd));
1203
1204 auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
1205
1206 auto thread = heap<Thread>(kj::mvCapture(startFunc,
1207 [threadFd](Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)>&& startFunc) {
1208 LowLevelAsyncIoProviderImpl lowLevel;
1209 auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS);
1210 AsyncIoProviderImpl ioProvider(lowLevel);
1211 startFunc(ioProvider, *stream, lowLevel.getWaitScope());
1212 }));
1213
1214 return { kj::mv(thread), kj::mv(pipe) };
1215 }
1216
getTimer()1217 Timer& getTimer() override { return lowLevel.getTimer(); }
1218
1219 private:
1220 LowLevelAsyncIoProvider& lowLevel;
1221 SocketNetwork network;
1222 };
1223
1224 } // namespace
1225
newAsyncIoProvider(LowLevelAsyncIoProvider & lowLevel)1226 Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) {
1227 return kj::heap<AsyncIoProviderImpl>(lowLevel);
1228 }
1229
setupAsyncIo()1230 AsyncIoContext setupAsyncIo() {
1231 _::initWinsockOnce();
1232
1233 auto lowLevel = heap<LowLevelAsyncIoProviderImpl>();
1234 auto ioProvider = kj::heap<AsyncIoProviderImpl>(*lowLevel);
1235 auto& waitScope = lowLevel->getWaitScope();
1236 auto& eventPort = lowLevel->getEventPort();
1237 return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope, eventPort };
1238 }
1239
1240 } // namespace kj
1241
1242 #endif // _WIN32
1243