/*
 * wepoll - epoll for Windows
 * https://github.com/piscisaureus/wepoll
 *
 * Copyright 2012-2018, Bert Belder <bertbelder@gmail.com>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are
 * met:
 *
 *   * Redistributions of source code must retain the above copyright
 *     notice, this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#ifndef WEPOLL_EXPORT
#define WEPOLL_EXPORT
#endif

#include <stdint.h>

/* clang-format off */

enum EPOLL_EVENTS {
  EPOLLIN      = (int) (1U <<  0),
  EPOLLPRI     = (int) (1U <<  1),
  EPOLLOUT     = (int) (1U <<  2),
  EPOLLERR     = (int) (1U <<  3),
  EPOLLHUP     = (int) (1U <<  4),
  EPOLLRDNORM  = (int) (1U <<  6),
  EPOLLRDBAND  = (int) (1U <<  7),
  EPOLLWRNORM  = (int) (1U <<  8),
  EPOLLWRBAND  = (int) (1U <<  9),
  EPOLLMSG     = (int) (1U << 10), /* Never reported. */
  EPOLLRDHUP   = (int) (1U << 13),
  EPOLLONESHOT = (int) (1U << 31)
};

#define EPOLLIN      (1U <<  0)
#define EPOLLPRI     (1U <<  1)
#define EPOLLOUT     (1U <<  2)
#define EPOLLERR     (1U <<  3)
#define EPOLLHUP     (1U <<  4)
#define EPOLLRDNORM  (1U <<  6)
#define EPOLLRDBAND  (1U <<  7)
#define EPOLLWRNORM  (1U <<  8)
#define EPOLLWRBAND  (1U <<  9)
#define EPOLLMSG     (1U << 10)
#define EPOLLRDHUP   (1U << 13)
#define EPOLLONESHOT (1U << 31)

#define EPOLL_CTL_ADD 1
#define EPOLL_CTL_MOD 2
#define EPOLL_CTL_DEL 3

/* clang-format on */

typedef void* HANDLE;
typedef uintptr_t SOCKET;

typedef union epoll_data {
  void* ptr;
  int fd;
  uint32_t u32;
  uint64_t u64;
  SOCKET sock; /* Windows specific */
  HANDLE hnd;  /* Windows specific */
} epoll_data_t;

struct epoll_event {
  uint32_t events;   /* Epoll events and flags */
  epoll_data_t data; /* User data variable */
};

#ifdef __cplusplus
extern "C" {
#endif

WEPOLL_EXPORT HANDLE epoll_create(int size);
WEPOLL_EXPORT HANDLE epoll_create1(int flags);

WEPOLL_EXPORT int epoll_close(HANDLE ephnd);

WEPOLL_EXPORT int epoll_ctl(HANDLE ephnd,
                            int op,
                            SOCKET sock,
                            struct epoll_event* event);

WEPOLL_EXPORT int epoll_wait(HANDLE ephnd,
                             struct epoll_event* events,
                             int maxevents,
                             int timeout);

#ifdef __cplusplus
} /* extern "C" */
#endif
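
/* A minimal usage sketch (illustrative only, not part of the library). The
 * API mirrors Linux epoll, except that a port is a HANDLE, only SOCKETs can
 * be registered, and a port is closed with epoll_close() rather than close().
 * Error checking is elided and `my_socket` is a hypothetical connected
 * socket:
 *
 *   HANDLE ephnd = epoll_create1(0);
 *   struct epoll_event ev, out[8];
 *   ev.events = EPOLLIN | EPOLLONESHOT;
 *   ev.data.sock = my_socket;
 *   epoll_ctl(ephnd, EPOLL_CTL_ADD, my_socket, &ev);
 *   int n = epoll_wait(ephnd, out, 8, 1000);     (waits at most 1000 ms)
 *   epoll_close(ephnd);
 */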

#include <malloc.h>
#include <stdlib.h>

#define WEPOLL_INTERNAL static
#define WEPOLL_INTERNAL_VAR static

#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif

#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-id-macro"
#endif

#ifdef _WIN32_WINNT
#undef _WIN32_WINNT
#endif

#define _WIN32_WINNT 0x0600

#ifdef __clang__
#pragma clang diagnostic pop
#endif

#ifndef __GNUC__
#pragma warning(push, 1)
#endif

#include <WS2tcpip.h>
#include <WinSock2.h>
#include <Windows.h>

#ifndef __GNUC__
#pragma warning(pop)
#endif

WEPOLL_INTERNAL int nt_global_init(void);

typedef LONG NTSTATUS;
typedef NTSTATUS* PNTSTATUS;

#ifndef NT_SUCCESS
#define NT_SUCCESS(status) (((NTSTATUS)(status)) >= 0)
#endif

#ifndef STATUS_SUCCESS
#define STATUS_SUCCESS ((NTSTATUS) 0x00000000L)
#endif

#ifndef STATUS_PENDING
#define STATUS_PENDING ((NTSTATUS) 0x00000103L)
#endif

#ifndef STATUS_CANCELLED
#define STATUS_CANCELLED ((NTSTATUS) 0xC0000120L)
#endif

typedef struct _IO_STATUS_BLOCK {
  NTSTATUS Status;
  ULONG_PTR Information;
} IO_STATUS_BLOCK, *PIO_STATUS_BLOCK;

typedef VOID(NTAPI* PIO_APC_ROUTINE)(PVOID ApcContext,
                                     PIO_STATUS_BLOCK IoStatusBlock,
                                     ULONG Reserved);

typedef struct _LSA_UNICODE_STRING {
  USHORT Length;
  USHORT MaximumLength;
  PWSTR Buffer;
} LSA_UNICODE_STRING, *PLSA_UNICODE_STRING, UNICODE_STRING, *PUNICODE_STRING;

#define RTL_CONSTANT_STRING(s) \
  { sizeof(s) - sizeof((s)[0]), sizeof(s), s }

typedef struct _OBJECT_ATTRIBUTES {
  ULONG Length;
  HANDLE RootDirectory;
  PUNICODE_STRING ObjectName;
  ULONG Attributes;
  PVOID SecurityDescriptor;
  PVOID SecurityQualityOfService;
} OBJECT_ATTRIBUTES, *POBJECT_ATTRIBUTES;

#define RTL_CONSTANT_OBJECT_ATTRIBUTES(ObjectName, Attributes) \
  { sizeof(OBJECT_ATTRIBUTES), NULL, ObjectName, Attributes, NULL, NULL }

#ifndef FILE_OPEN
#define FILE_OPEN 0x00000001UL
#endif

#define NT_NTDLL_IMPORT_LIST(X)                                              \
  X(NTSTATUS,                                                                \
    NTAPI,                                                                   \
    NtCreateFile,                                                            \
    (PHANDLE FileHandle,                                                     \
     ACCESS_MASK DesiredAccess,                                              \
     POBJECT_ATTRIBUTES ObjectAttributes,                                    \
     PIO_STATUS_BLOCK IoStatusBlock,                                         \
     PLARGE_INTEGER AllocationSize,                                          \
     ULONG FileAttributes,                                                   \
     ULONG ShareAccess,                                                      \
     ULONG CreateDisposition,                                                \
     ULONG CreateOptions,                                                    \
     PVOID EaBuffer,                                                         \
     ULONG EaLength))                                                        \
                                                                             \
  X(NTSTATUS,                                                                \
    NTAPI,                                                                   \
    NtDeviceIoControlFile,                                                   \
    (HANDLE FileHandle,                                                      \
     HANDLE Event,                                                           \
     PIO_APC_ROUTINE ApcRoutine,                                             \
     PVOID ApcContext,                                                       \
     PIO_STATUS_BLOCK IoStatusBlock,                                         \
     ULONG IoControlCode,                                                    \
     PVOID InputBuffer,                                                      \
     ULONG InputBufferLength,                                                \
     PVOID OutputBuffer,                                                     \
     ULONG OutputBufferLength))                                              \
                                                                             \
  X(ULONG, WINAPI, RtlNtStatusToDosError, (NTSTATUS Status))                 \
                                                                             \
  X(NTSTATUS,                                                                \
    NTAPI,                                                                   \
    NtCreateKeyedEvent,                                                      \
    (PHANDLE handle,                                                         \
     ACCESS_MASK access,                                                     \
     POBJECT_ATTRIBUTES attr,                                                \
     ULONG flags))                                                           \
                                                                             \
  X(NTSTATUS,                                                                \
    NTAPI,                                                                   \
    NtWaitForKeyedEvent,                                                     \
    (HANDLE handle, PVOID key, BOOLEAN alertable, PLARGE_INTEGER mstimeout)) \
                                                                             \
  X(NTSTATUS,                                                                \
    NTAPI,                                                                   \
    NtReleaseKeyedEvent,                                                     \
    (HANDLE handle, PVOID key, BOOLEAN alertable, PLARGE_INTEGER mstimeout))

#define X(return_type, attributes, name, parameters) \
  WEPOLL_INTERNAL_VAR return_type(attributes* name) parameters;
NT_NTDLL_IMPORT_LIST(X)
#undef X
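
/* For reference, each NT_NTDLL_IMPORT_LIST entry expands (via the X macro
 * above) into a function-pointer declaration; e.g. the RtlNtStatusToDosError
 * entry becomes:
 *
 *   WEPOLL_INTERNAL_VAR ULONG(WINAPI* RtlNtStatusToDosError)(NTSTATUS Status);
 */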

#include <assert.h>
#include <stddef.h>

#ifndef _SSIZE_T_DEFINED
typedef intptr_t ssize_t;
#endif

#define array_count(a) (sizeof(a) / (sizeof((a)[0])))

/* clang-format off */
#define container_of(ptr, type, member) \
  ((type*) ((uintptr_t) (ptr) - offsetof(type, member)))
/* clang-format on */
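
/* Example (illustrative): container_of() recovers the struct that embeds a
 * given member, which is how the intrusive queue and tree nodes below are
 * mapped back to their owners, e.g. in poll_group_from_queue_node():
 *
 *   poll_group_t* pg = container_of(node, poll_group_t, queue_node);
 */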

#define unused_var(v) ((void) (v))

/* Polyfill `inline` for older versions of msvc (up to Visual Studio 2013) */
#if defined(_MSC_VER) && _MSC_VER < 1900
#define inline __inline
#endif

/* clang-format off */
#define AFD_POLL_RECEIVE           0x0001
#define AFD_POLL_RECEIVE_EXPEDITED 0x0002
#define AFD_POLL_SEND              0x0004
#define AFD_POLL_DISCONNECT        0x0008
#define AFD_POLL_ABORT             0x0010
#define AFD_POLL_LOCAL_CLOSE       0x0020
#define AFD_POLL_ACCEPT            0x0080
#define AFD_POLL_CONNECT_FAIL      0x0100
/* clang-format on */

typedef struct _AFD_POLL_HANDLE_INFO {
  HANDLE Handle;
  ULONG Events;
  NTSTATUS Status;
} AFD_POLL_HANDLE_INFO, *PAFD_POLL_HANDLE_INFO;

typedef struct _AFD_POLL_INFO {
  LARGE_INTEGER Timeout;
  ULONG NumberOfHandles;
  ULONG Exclusive;
  AFD_POLL_HANDLE_INFO Handles[1];
} AFD_POLL_INFO, *PAFD_POLL_INFO;
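
/* A sketch of how an AFD_POLL_INFO request is typically filled in before it
 * is passed to afd_poll() below; wepoll polls exactly one handle per request,
 * and the AFD_POLL_* event mask shown is just an example:
 *
 *   AFD_POLL_INFO info;
 *   info.Timeout.QuadPart = INT64_MAX;    (no driver-side timeout)
 *   info.NumberOfHandles = 1;
 *   info.Exclusive = FALSE;
 *   info.Handles[0].Handle = (HANDLE) base_socket;
 *   info.Handles[0].Status = 0;
 *   info.Handles[0].Events = AFD_POLL_RECEIVE | AFD_POLL_ABORT;
 */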

WEPOLL_INTERNAL int afd_create_helper_handle(HANDLE iocp,
                                             HANDLE* afd_helper_handle_out);

WEPOLL_INTERNAL int afd_poll(HANDLE afd_helper_handle,
                             AFD_POLL_INFO* poll_info,
                             OVERLAPPED* overlapped);

#define return_map_error(value) \
  do {                          \
    err_map_win_error();        \
    return (value);             \
  } while (0)

#define return_set_error(value, error) \
  do {                                 \
    err_set_win_error(error);          \
    return (value);                    \
  } while (0)

WEPOLL_INTERNAL void err_map_win_error(void);
WEPOLL_INTERNAL void err_set_win_error(DWORD error);
WEPOLL_INTERNAL int err_check_handle(HANDLE handle);

WEPOLL_INTERNAL int ws_global_init(void);
WEPOLL_INTERNAL SOCKET ws_get_base_socket(SOCKET socket);

#define IOCTL_AFD_POLL 0x00012024

static UNICODE_STRING afd__helper_name =
    RTL_CONSTANT_STRING(L"\\Device\\Afd\\Wepoll");

static OBJECT_ATTRIBUTES afd__helper_attributes =
    RTL_CONSTANT_OBJECT_ATTRIBUTES(&afd__helper_name, 0);

int afd_create_helper_handle(HANDLE iocp, HANDLE* afd_helper_handle_out) {
  HANDLE afd_helper_handle;
  IO_STATUS_BLOCK iosb;
  NTSTATUS status;

  /* By opening \Device\Afd without specifying any extended attributes, we'll
   * get a handle that lets us talk to the AFD driver, but that doesn't have an
   * associated endpoint (so it's not a socket). */
  status = NtCreateFile(&afd_helper_handle,
                        SYNCHRONIZE,
                        &afd__helper_attributes,
                        &iosb,
                        NULL,
                        0,
                        FILE_SHARE_READ | FILE_SHARE_WRITE,
                        FILE_OPEN,
                        0,
                        NULL,
                        0);
  if (status != STATUS_SUCCESS)
    return_set_error(-1, RtlNtStatusToDosError(status));

  if (CreateIoCompletionPort(afd_helper_handle, iocp, 0, 0) == NULL)
    goto error;

  if (!SetFileCompletionNotificationModes(afd_helper_handle,
                                          FILE_SKIP_SET_EVENT_ON_HANDLE))
    goto error;

  *afd_helper_handle_out = afd_helper_handle;
  return 0;

error:
  CloseHandle(afd_helper_handle);
  return_map_error(-1);
}

int afd_poll(HANDLE afd_helper_handle,
             AFD_POLL_INFO* poll_info,
             OVERLAPPED* overlapped) {
  IO_STATUS_BLOCK* iosb;
  HANDLE event;
  void* apc_context;
  NTSTATUS status;

  /* Blocking operation is not supported. */
  assert(overlapped != NULL);

  iosb = (IO_STATUS_BLOCK*) &overlapped->Internal;
  event = overlapped->hEvent;
  /* Do what other Windows APIs would do: if hEvent has its lowest bit set,
   * don't post a completion to the completion port. */
  if ((uintptr_t) event & 1) {
    event = (HANDLE)((uintptr_t) event & ~(uintptr_t) 1);
    apc_context = NULL;
  } else {
    apc_context = overlapped;
  }

  iosb->Status = STATUS_PENDING;
  status = NtDeviceIoControlFile(afd_helper_handle,
                                 event,
                                 NULL,
                                 apc_context,
                                 iosb,
                                 IOCTL_AFD_POLL,
                                 poll_info,
                                 sizeof *poll_info,
                                 poll_info,
                                 sizeof *poll_info);

  if (status == STATUS_SUCCESS)
    return 0;
  else if (status == STATUS_PENDING)
    return_set_error(-1, ERROR_IO_PENDING);
  else
    return_set_error(-1, RtlNtStatusToDosError(status));
}
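
/* Illustrative only: a caller that waits on the event itself instead of on
 * the I/O completion port can opt out of the completion notification by
 * tagging the lowest bit of the event handle, per the convention handled
 * above (`my_event` is a hypothetical event handle):
 *
 *   OVERLAPPED ov;
 *   memset(&ov, 0, sizeof ov);
 *   ov.hEvent = (HANDLE)((uintptr_t) my_event | 1);
 *   afd_poll(afd_helper_handle, &poll_info, &ov);
 */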

WEPOLL_INTERNAL int epoll_global_init(void);

WEPOLL_INTERNAL int init(void);

#include <stdbool.h>

typedef struct queue_node queue_node_t;

typedef struct queue_node {
  queue_node_t* prev;
  queue_node_t* next;
} queue_node_t;

typedef struct queue {
  queue_node_t head;
} queue_t;

WEPOLL_INTERNAL void queue_init(queue_t* queue);
WEPOLL_INTERNAL void queue_node_init(queue_node_t* node);

WEPOLL_INTERNAL queue_node_t* queue_first(const queue_t* queue);
WEPOLL_INTERNAL queue_node_t* queue_last(const queue_t* queue);

WEPOLL_INTERNAL void queue_prepend(queue_t* queue, queue_node_t* node);
WEPOLL_INTERNAL void queue_append(queue_t* queue, queue_node_t* node);
WEPOLL_INTERNAL void queue_move_first(queue_t* queue, queue_node_t* node);
WEPOLL_INTERNAL void queue_move_last(queue_t* queue, queue_node_t* node);
WEPOLL_INTERNAL void queue_remove(queue_node_t* node);

WEPOLL_INTERNAL bool queue_empty(const queue_t* queue);
WEPOLL_INTERNAL bool queue_enqueued(const queue_node_t* node);
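
/* The queue is an intrusive, circular, doubly-linked list: a user struct
 * embeds a queue_node_t and is recovered with container_of(). The sentinel
 * head node points at itself when the queue is empty; see the implementation
 * further below. */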

typedef struct port_state port_state_t;
typedef struct poll_group poll_group_t;

WEPOLL_INTERNAL poll_group_t* poll_group_acquire(port_state_t* port);
WEPOLL_INTERNAL void poll_group_release(poll_group_t* poll_group);

WEPOLL_INTERNAL void poll_group_delete(poll_group_t* poll_group);

WEPOLL_INTERNAL poll_group_t* poll_group_from_queue_node(
    queue_node_t* queue_node);
WEPOLL_INTERNAL HANDLE
    poll_group_get_afd_helper_handle(poll_group_t* poll_group);

/* N.b.: the tree functions do not set errno or LastError when they fail. Each
 * of the API functions has at most one failure mode. It is up to the caller to
 * set an appropriate error code when necessary. */

typedef struct tree tree_t;
typedef struct tree_node tree_node_t;

typedef struct tree {
  tree_node_t* root;
} tree_t;

typedef struct tree_node {
  tree_node_t* left;
  tree_node_t* right;
  tree_node_t* parent;
  uintptr_t key;
  bool red;
} tree_node_t;

WEPOLL_INTERNAL void tree_init(tree_t* tree);
WEPOLL_INTERNAL void tree_node_init(tree_node_t* node);

WEPOLL_INTERNAL int tree_add(tree_t* tree, tree_node_t* node, uintptr_t key);
WEPOLL_INTERNAL void tree_del(tree_t* tree, tree_node_t* node);

WEPOLL_INTERNAL tree_node_t* tree_find(const tree_t* tree, uintptr_t key);
WEPOLL_INTERNAL tree_node_t* tree_root(const tree_t* tree);

typedef struct port_state port_state_t;
typedef struct sock_state sock_state_t;

WEPOLL_INTERNAL sock_state_t* sock_new(port_state_t* port_state,
                                       SOCKET socket);
WEPOLL_INTERNAL void sock_delete(port_state_t* port_state,
                                 sock_state_t* sock_state);
WEPOLL_INTERNAL void sock_force_delete(port_state_t* port_state,
                                       sock_state_t* sock_state);

WEPOLL_INTERNAL int sock_set_event(port_state_t* port_state,
                                   sock_state_t* sock_state,
                                   const struct epoll_event* ev);

WEPOLL_INTERNAL int sock_update(port_state_t* port_state,
                                sock_state_t* sock_state);
WEPOLL_INTERNAL int sock_feed_event(port_state_t* port_state,
                                    OVERLAPPED* overlapped,
                                    struct epoll_event* ev);

WEPOLL_INTERNAL sock_state_t* sock_state_from_queue_node(
    queue_node_t* queue_node);
WEPOLL_INTERNAL queue_node_t* sock_state_to_queue_node(
    sock_state_t* sock_state);
WEPOLL_INTERNAL sock_state_t* sock_state_from_tree_node(
    tree_node_t* tree_node);
WEPOLL_INTERNAL tree_node_t* sock_state_to_tree_node(sock_state_t* sock_state);

/* The reflock is a special kind of lock that normally prevents a chunk of
 * memory from being freed, but does allow the chunk of memory to eventually be
 * released in a coordinated fashion.
 *
 * Under normal operation, threads increase and decrease the reference count;
 * both of these operations are wait-free.
 *
 * Exactly once during the reflock's lifecycle, a thread holding a reference to
 * the lock may "destroy" the lock; this operation blocks until all other
 * threads holding a reference to the lock have dereferenced it. After
 * "destroy" returns, the calling thread may assume that no other threads have
 * a reference to the lock.
 *
 * Attempting to lock or destroy a lock after reflock_unref_and_destroy() has
 * been called is invalid and results in undefined behavior. Therefore the user
 * should use another lock to guarantee that this can't happen.
 */

typedef struct reflock {
  volatile long state; /* 32-bit Interlocked APIs operate on `long` values. */
} reflock_t;

WEPOLL_INTERNAL int reflock_global_init(void);

WEPOLL_INTERNAL void reflock_init(reflock_t* reflock);
WEPOLL_INTERNAL void reflock_ref(reflock_t* reflock);
WEPOLL_INTERNAL void reflock_unref(reflock_t* reflock);
WEPOLL_INTERNAL void reflock_unref_and_destroy(reflock_t* reflock);

typedef struct ts_tree {
  tree_t tree;
  SRWLOCK lock;
} ts_tree_t;

typedef struct ts_tree_node {
  tree_node_t tree_node;
  reflock_t reflock;
} ts_tree_node_t;

WEPOLL_INTERNAL void ts_tree_init(ts_tree_t* rtl);
WEPOLL_INTERNAL void ts_tree_node_init(ts_tree_node_t* node);

WEPOLL_INTERNAL int ts_tree_add(ts_tree_t* ts_tree,
                                ts_tree_node_t* node,
                                uintptr_t key);

WEPOLL_INTERNAL ts_tree_node_t* ts_tree_del_and_ref(ts_tree_t* ts_tree,
                                                    uintptr_t key);
WEPOLL_INTERNAL ts_tree_node_t* ts_tree_find_and_ref(ts_tree_t* ts_tree,
                                                     uintptr_t key);

WEPOLL_INTERNAL void ts_tree_node_unref(ts_tree_node_t* node);
WEPOLL_INTERNAL void ts_tree_node_unref_and_destroy(ts_tree_node_t* node);

typedef struct port_state port_state_t;
typedef struct sock_state sock_state_t;

typedef struct port_state {
  HANDLE iocp;
  tree_t sock_tree;
  queue_t sock_update_queue;
  queue_t sock_deleted_queue;
  queue_t poll_group_queue;
  ts_tree_node_t handle_tree_node;
  CRITICAL_SECTION lock;
  size_t active_poll_count;
} port_state_t;

WEPOLL_INTERNAL port_state_t* port_new(HANDLE* iocp_out);
WEPOLL_INTERNAL int port_close(port_state_t* port_state);
WEPOLL_INTERNAL int port_delete(port_state_t* port_state);

WEPOLL_INTERNAL int port_wait(port_state_t* port_state,
                              struct epoll_event* events,
                              int maxevents,
                              int timeout);

WEPOLL_INTERNAL int port_ctl(port_state_t* port_state,
                             int op,
                             SOCKET sock,
                             struct epoll_event* ev);

WEPOLL_INTERNAL int port_register_socket_handle(port_state_t* port_state,
                                                sock_state_t* sock_state,
                                                SOCKET socket);
WEPOLL_INTERNAL void port_unregister_socket_handle(port_state_t* port_state,
                                                   sock_state_t* sock_state);
WEPOLL_INTERNAL sock_state_t* port_find_socket(port_state_t* port_state,
                                               SOCKET socket);

WEPOLL_INTERNAL void port_request_socket_update(port_state_t* port_state,
                                                sock_state_t* sock_state);
WEPOLL_INTERNAL void port_cancel_socket_update(port_state_t* port_state,
                                               sock_state_t* sock_state);

WEPOLL_INTERNAL void port_add_deleted_socket(port_state_t* port_state,
                                             sock_state_t* sock_state);
WEPOLL_INTERNAL void port_remove_deleted_socket(port_state_t* port_state,
                                                sock_state_t* sock_state);

static ts_tree_t epoll__handle_tree;

static inline port_state_t* epoll__handle_tree_node_to_port(
    ts_tree_node_t* tree_node) {
  return container_of(tree_node, port_state_t, handle_tree_node);
}

int epoll_global_init(void) {
  ts_tree_init(&epoll__handle_tree);
  return 0;
}

static HANDLE epoll__create(void) {
  port_state_t* port_state;
  HANDLE ephnd;

  if (init() < 0)
    return NULL;

  port_state = port_new(&ephnd);
  if (port_state == NULL)
    return NULL;

  if (ts_tree_add(&epoll__handle_tree,
                  &port_state->handle_tree_node,
                  (uintptr_t) ephnd) < 0) {
    /* This should never happen. */
    port_delete(port_state);
    return_set_error(NULL, ERROR_ALREADY_EXISTS);
  }

  return ephnd;
}

HANDLE epoll_create(int size) {
  if (size <= 0)
    return_set_error(NULL, ERROR_INVALID_PARAMETER);

  return epoll__create();
}

HANDLE epoll_create1(int flags) {
  if (flags != 0)
    return_set_error(NULL, ERROR_INVALID_PARAMETER);

  return epoll__create();
}

int epoll_close(HANDLE ephnd) {
  ts_tree_node_t* tree_node;
  port_state_t* port_state;

  if (init() < 0)
    return -1;

  tree_node = ts_tree_del_and_ref(&epoll__handle_tree, (uintptr_t) ephnd);
  if (tree_node == NULL) {
    err_set_win_error(ERROR_INVALID_PARAMETER);
    goto err;
  }

  port_state = epoll__handle_tree_node_to_port(tree_node);
  port_close(port_state);

  ts_tree_node_unref_and_destroy(tree_node);

  return port_delete(port_state);

err:
  err_check_handle(ephnd);
  return -1;
}

int epoll_ctl(HANDLE ephnd, int op, SOCKET sock, struct epoll_event* ev) {
  ts_tree_node_t* tree_node;
  port_state_t* port_state;
  int r;

  if (init() < 0)
    return -1;

  tree_node = ts_tree_find_and_ref(&epoll__handle_tree, (uintptr_t) ephnd);
  if (tree_node == NULL) {
    err_set_win_error(ERROR_INVALID_PARAMETER);
    goto err;
  }

  port_state = epoll__handle_tree_node_to_port(tree_node);
  r = port_ctl(port_state, op, sock, ev);

  ts_tree_node_unref(tree_node);

  if (r < 0)
    goto err;

  return 0;

err:
  /* On Linux, in the case of epoll_ctl_mod(), EBADF takes priority over other
   * errors. Wepoll mimics this behavior. */
  err_check_handle(ephnd);
  err_check_handle((HANDLE) sock);
  return -1;
}

int epoll_wait(HANDLE ephnd,
               struct epoll_event* events,
               int maxevents,
               int timeout) {
  ts_tree_node_t* tree_node;
  port_state_t* port_state;
  int num_events;

  if (maxevents <= 0)
    return_set_error(-1, ERROR_INVALID_PARAMETER);

  if (init() < 0)
    return -1;

  tree_node = ts_tree_find_and_ref(&epoll__handle_tree, (uintptr_t) ephnd);
  if (tree_node == NULL) {
    err_set_win_error(ERROR_INVALID_PARAMETER);
    goto err;
  }

  port_state = epoll__handle_tree_node_to_port(tree_node);
  num_events = port_wait(port_state, events, maxevents, timeout);

  ts_tree_node_unref(tree_node);

  if (num_events < 0)
    goto err;

  return num_events;

err:
  err_check_handle(ephnd);
  return -1;
}

#include <errno.h>

#define ERR__ERRNO_MAPPINGS(X)               \
  X(ERROR_ACCESS_DENIED, EACCES)             \
  X(ERROR_ALREADY_EXISTS, EEXIST)            \
  X(ERROR_BAD_COMMAND, EACCES)               \
  X(ERROR_BAD_EXE_FORMAT, ENOEXEC)           \
  X(ERROR_BAD_LENGTH, EACCES)                \
  X(ERROR_BAD_NETPATH, ENOENT)               \
  X(ERROR_BAD_NET_NAME, ENOENT)              \
  X(ERROR_BAD_NET_RESP, ENETDOWN)            \
  X(ERROR_BAD_PATHNAME, ENOENT)              \
  X(ERROR_BROKEN_PIPE, EPIPE)                \
  X(ERROR_CANNOT_MAKE, EACCES)               \
  X(ERROR_COMMITMENT_LIMIT, ENOMEM)          \
  X(ERROR_CONNECTION_ABORTED, ECONNABORTED)  \
  X(ERROR_CONNECTION_ACTIVE, EISCONN)        \
  X(ERROR_CONNECTION_REFUSED, ECONNREFUSED)  \
  X(ERROR_CRC, EACCES)                       \
  X(ERROR_DIR_NOT_EMPTY, ENOTEMPTY)          \
  X(ERROR_DISK_FULL, ENOSPC)                 \
  X(ERROR_DUP_NAME, EADDRINUSE)              \
  X(ERROR_FILENAME_EXCED_RANGE, ENOENT)      \
  X(ERROR_FILE_NOT_FOUND, ENOENT)            \
  X(ERROR_GEN_FAILURE, EACCES)               \
  X(ERROR_GRACEFUL_DISCONNECT, EPIPE)        \
  X(ERROR_HOST_DOWN, EHOSTUNREACH)           \
  X(ERROR_HOST_UNREACHABLE, EHOSTUNREACH)    \
  X(ERROR_INSUFFICIENT_BUFFER, EFAULT)       \
  X(ERROR_INVALID_ADDRESS, EADDRNOTAVAIL)    \
  X(ERROR_INVALID_FUNCTION, EINVAL)          \
  X(ERROR_INVALID_HANDLE, EBADF)             \
  X(ERROR_INVALID_NETNAME, EADDRNOTAVAIL)    \
  X(ERROR_INVALID_PARAMETER, EINVAL)         \
  X(ERROR_INVALID_USER_BUFFER, EMSGSIZE)     \
  X(ERROR_IO_PENDING, EINPROGRESS)           \
  X(ERROR_LOCK_VIOLATION, EACCES)            \
  X(ERROR_MORE_DATA, EMSGSIZE)               \
  X(ERROR_NETNAME_DELETED, ECONNABORTED)     \
  X(ERROR_NETWORK_ACCESS_DENIED, EACCES)     \
  X(ERROR_NETWORK_BUSY, ENETDOWN)            \
  X(ERROR_NETWORK_UNREACHABLE, ENETUNREACH)  \
  X(ERROR_NOACCESS, EFAULT)                  \
  X(ERROR_NONPAGED_SYSTEM_RESOURCES, ENOMEM) \
  X(ERROR_NOT_ENOUGH_MEMORY, ENOMEM)         \
  X(ERROR_NOT_ENOUGH_QUOTA, ENOMEM)          \
  X(ERROR_NOT_FOUND, ENOENT)                 \
  X(ERROR_NOT_LOCKED, EACCES)                \
  X(ERROR_NOT_READY, EACCES)                 \
  X(ERROR_NOT_SAME_DEVICE, EXDEV)            \
  X(ERROR_NOT_SUPPORTED, ENOTSUP)            \
  X(ERROR_NO_MORE_FILES, ENOENT)             \
  X(ERROR_NO_SYSTEM_RESOURCES, ENOMEM)       \
  X(ERROR_OPERATION_ABORTED, EINTR)          \
  X(ERROR_OUT_OF_PAPER, EACCES)              \
  X(ERROR_PAGED_SYSTEM_RESOURCES, ENOMEM)    \
  X(ERROR_PAGEFILE_QUOTA, ENOMEM)            \
  X(ERROR_PATH_NOT_FOUND, ENOENT)            \
  X(ERROR_PIPE_NOT_CONNECTED, EPIPE)         \
  X(ERROR_PORT_UNREACHABLE, ECONNRESET)      \
  X(ERROR_PROTOCOL_UNREACHABLE, ENETUNREACH) \
  X(ERROR_REM_NOT_LIST, ECONNREFUSED)        \
  X(ERROR_REQUEST_ABORTED, EINTR)            \
  X(ERROR_REQ_NOT_ACCEP, EWOULDBLOCK)        \
  X(ERROR_SECTOR_NOT_FOUND, EACCES)          \
  X(ERROR_SEM_TIMEOUT, ETIMEDOUT)            \
  X(ERROR_SHARING_VIOLATION, EACCES)         \
  X(ERROR_TOO_MANY_NAMES, ENOMEM)            \
  X(ERROR_TOO_MANY_OPEN_FILES, EMFILE)       \
  X(ERROR_UNEXP_NET_ERR, ECONNABORTED)       \
  X(ERROR_WAIT_NO_CHILDREN, ECHILD)          \
  X(ERROR_WORKING_SET_QUOTA, ENOMEM)         \
  X(ERROR_WRITE_PROTECT, EACCES)             \
  X(ERROR_WRONG_DISK, EACCES)                \
  X(WSAEACCES, EACCES)                       \
  X(WSAEADDRINUSE, EADDRINUSE)               \
  X(WSAEADDRNOTAVAIL, EADDRNOTAVAIL)         \
  X(WSAEAFNOSUPPORT, EAFNOSUPPORT)           \
  X(WSAECONNABORTED, ECONNABORTED)           \
  X(WSAECONNREFUSED, ECONNREFUSED)           \
  X(WSAECONNRESET, ECONNRESET)               \
  X(WSAEDISCON, EPIPE)                       \
  X(WSAEFAULT, EFAULT)                       \
  X(WSAEHOSTDOWN, EHOSTUNREACH)              \
  X(WSAEHOSTUNREACH, EHOSTUNREACH)           \
  X(WSAEINPROGRESS, EBUSY)                   \
  X(WSAEINTR, EINTR)                         \
  X(WSAEINVAL, EINVAL)                       \
  X(WSAEISCONN, EISCONN)                     \
  X(WSAEMSGSIZE, EMSGSIZE)                   \
  X(WSAENETDOWN, ENETDOWN)                   \
  X(WSAENETRESET, EHOSTUNREACH)              \
  X(WSAENETUNREACH, ENETUNREACH)             \
  X(WSAENOBUFS, ENOMEM)                      \
  X(WSAENOTCONN, ENOTCONN)                   \
  X(WSAENOTSOCK, ENOTSOCK)                   \
  X(WSAEOPNOTSUPP, EOPNOTSUPP)               \
  X(WSAEPROCLIM, ENOMEM)                     \
  X(WSAESHUTDOWN, EPIPE)                     \
  X(WSAETIMEDOUT, ETIMEDOUT)                 \
  X(WSAEWOULDBLOCK, EWOULDBLOCK)             \
  X(WSANOTINITIALISED, ENETDOWN)             \
  X(WSASYSNOTREADY, ENETDOWN)                \
  X(WSAVERNOTSUPPORTED, ENOSYS)

static errno_t err__map_win_error_to_errno(DWORD error) {
  switch (error) {
#define X(error_sym, errno_sym) \
  case error_sym:               \
    return errno_sym;
    ERR__ERRNO_MAPPINGS(X)
#undef X
  }
  return EINVAL;
}
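
/* Each ERR__ERRNO_MAPPINGS entry expands to one switch case above; e.g.
 * X(ERROR_ACCESS_DENIED, EACCES) becomes:
 *
 *   case ERROR_ACCESS_DENIED:
 *     return EACCES;
 */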

void err_map_win_error(void) {
  errno = err__map_win_error_to_errno(GetLastError());
}

void err_set_win_error(DWORD error) {
  SetLastError(error);
  errno = err__map_win_error_to_errno(error);
}

int err_check_handle(HANDLE handle) {
  DWORD flags;

  /* GetHandleInformation() succeeds when passed INVALID_HANDLE_VALUE, so check
   * for this condition explicitly. */
  if (handle == INVALID_HANDLE_VALUE)
    return_set_error(-1, ERROR_INVALID_HANDLE);

  if (!GetHandleInformation(handle, &flags))
    return_map_error(-1);

  return 0;
}

static bool init__done = false;
static INIT_ONCE init__once = INIT_ONCE_STATIC_INIT;

static BOOL CALLBACK init__once_callback(INIT_ONCE* once,
                                         void* parameter,
                                         void** context) {
  unused_var(once);
  unused_var(parameter);
  unused_var(context);

  /* N.b. that initialization order matters here. */
  if (ws_global_init() < 0 || nt_global_init() < 0 ||
      reflock_global_init() < 0 || epoll_global_init() < 0)
    return FALSE;

  init__done = true;
  return TRUE;
}

int init(void) {
  if (!init__done &&
      !InitOnceExecuteOnce(&init__once, init__once_callback, NULL, NULL))
    return -1; /* LastError and errno aren't touched by InitOnceExecuteOnce. */

  return 0;
}

/* Set up a workaround for the following problem:
 *   FARPROC addr = GetProcAddress(...);
 *   MY_FUNC func = (MY_FUNC) addr;          <-- GCC 8 warning/error.
 *   MY_FUNC func = (MY_FUNC) (void*) addr;  <-- MSVC  warning/error.
 * To compile cleanly with either compiler, do casts with this "bridge" type:
 *   MY_FUNC func = (MY_FUNC) (nt__fn_ptr_cast_t) addr; */
#ifdef __GNUC__
typedef void* nt__fn_ptr_cast_t;
#else
typedef FARPROC nt__fn_ptr_cast_t;
#endif

#define X(return_type, attributes, name, parameters) \
  WEPOLL_INTERNAL return_type(attributes* name) parameters = NULL;
NT_NTDLL_IMPORT_LIST(X)
#undef X

int nt_global_init(void) {
  HMODULE ntdll;
  FARPROC fn_ptr;

  ntdll = GetModuleHandleW(L"ntdll.dll");
  if (ntdll == NULL)
    return -1;

#define X(return_type, attributes, name, parameters) \
  fn_ptr = GetProcAddress(ntdll, #name);             \
  if (fn_ptr == NULL)                                \
    return -1;                                       \
  name = (return_type(attributes*) parameters)(nt__fn_ptr_cast_t) fn_ptr;
  NT_NTDLL_IMPORT_LIST(X)
#undef X

  return 0;
}

#include <string.h>

static const size_t POLL_GROUP__MAX_GROUP_SIZE = 32;

typedef struct poll_group {
  port_state_t* port_state;
  queue_node_t queue_node;
  HANDLE afd_helper_handle;
  size_t group_size;
} poll_group_t;

static poll_group_t* poll_group__new(port_state_t* port_state) {
  poll_group_t* poll_group = malloc(sizeof *poll_group);
  if (poll_group == NULL)
    return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY);

  memset(poll_group, 0, sizeof *poll_group);

  queue_node_init(&poll_group->queue_node);
  poll_group->port_state = port_state;

  if (afd_create_helper_handle(port_state->iocp,
                               &poll_group->afd_helper_handle) < 0) {
    free(poll_group);
    return NULL;
  }

  queue_append(&port_state->poll_group_queue, &poll_group->queue_node);

  return poll_group;
}

void poll_group_delete(poll_group_t* poll_group) {
  assert(poll_group->group_size == 0);
  CloseHandle(poll_group->afd_helper_handle);
  queue_remove(&poll_group->queue_node);
  free(poll_group);
}

poll_group_t* poll_group_from_queue_node(queue_node_t* queue_node) {
  return container_of(queue_node, poll_group_t, queue_node);
}

HANDLE poll_group_get_afd_helper_handle(poll_group_t* poll_group) {
  return poll_group->afd_helper_handle;
}

poll_group_t* poll_group_acquire(port_state_t* port_state) {
  queue_t* queue = &port_state->poll_group_queue;
  poll_group_t* poll_group =
      !queue_empty(queue)
          ? container_of(queue_last(queue), poll_group_t, queue_node)
          : NULL;

  if (poll_group == NULL ||
      poll_group->group_size >= POLL_GROUP__MAX_GROUP_SIZE)
    poll_group = poll_group__new(port_state);
  if (poll_group == NULL)
    return NULL;

  if (++poll_group->group_size == POLL_GROUP__MAX_GROUP_SIZE)
    queue_move_first(&port_state->poll_group_queue, &poll_group->queue_node);

  return poll_group;
}

void poll_group_release(poll_group_t* poll_group) {
  port_state_t* port_state = poll_group->port_state;

  poll_group->group_size--;
  assert(poll_group->group_size < POLL_GROUP__MAX_GROUP_SIZE);

  queue_move_last(&port_state->poll_group_queue, &poll_group->queue_node);

  /* Poll groups are currently only freed when the epoll port is closed. */
}

#define PORT__MAX_ON_STACK_COMPLETIONS 256

static port_state_t* port__alloc(void) {
  port_state_t* port_state = malloc(sizeof *port_state);
  if (port_state == NULL)
    return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY);

  return port_state;
}

static void port__free(port_state_t* port) {
  assert(port != NULL);
  free(port);
}

static HANDLE port__create_iocp(void) {
  HANDLE iocp = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
  if (iocp == NULL)
    return_map_error(NULL);

  return iocp;
}

port_state_t* port_new(HANDLE* iocp_out) {
  port_state_t* port_state;
  HANDLE iocp;

  port_state = port__alloc();
  if (port_state == NULL)
    goto err1;

  iocp = port__create_iocp();
  if (iocp == NULL)
    goto err2;

  memset(port_state, 0, sizeof *port_state);

  port_state->iocp = iocp;
  tree_init(&port_state->sock_tree);
  queue_init(&port_state->sock_update_queue);
  queue_init(&port_state->sock_deleted_queue);
  queue_init(&port_state->poll_group_queue);
  ts_tree_node_init(&port_state->handle_tree_node);
  InitializeCriticalSection(&port_state->lock);

  *iocp_out = iocp;
  return port_state;

err2:
  port__free(port_state);
err1:
  return NULL;
}

static int port__close_iocp(port_state_t* port_state) {
  HANDLE iocp = port_state->iocp;
  port_state->iocp = NULL;

  if (!CloseHandle(iocp))
    return_map_error(-1);

  return 0;
}

int port_close(port_state_t* port_state) {
  int result;

  EnterCriticalSection(&port_state->lock);
  result = port__close_iocp(port_state);
  LeaveCriticalSection(&port_state->lock);

  return result;
}

int port_delete(port_state_t* port_state) {
  tree_node_t* tree_node;
  queue_node_t* queue_node;

  /* At this point the IOCP port should have been closed. */
  assert(port_state->iocp == NULL);

  while ((tree_node = tree_root(&port_state->sock_tree)) != NULL) {
    sock_state_t* sock_state = sock_state_from_tree_node(tree_node);
    sock_force_delete(port_state, sock_state);
  }

  while ((queue_node = queue_first(&port_state->sock_deleted_queue)) != NULL) {
    sock_state_t* sock_state = sock_state_from_queue_node(queue_node);
    sock_force_delete(port_state, sock_state);
  }

  while ((queue_node = queue_first(&port_state->poll_group_queue)) != NULL) {
    poll_group_t* poll_group = poll_group_from_queue_node(queue_node);
    poll_group_delete(poll_group);
  }

  assert(queue_empty(&port_state->sock_update_queue));

  DeleteCriticalSection(&port_state->lock);

  port__free(port_state);

  return 0;
}

static int port__update_events(port_state_t* port_state) {
  queue_t* sock_update_queue = &port_state->sock_update_queue;

  /* Walk the queue, submitting new poll requests for every socket that needs
   * it. */
  while (!queue_empty(sock_update_queue)) {
    queue_node_t* queue_node = queue_first(sock_update_queue);
    sock_state_t* sock_state = sock_state_from_queue_node(queue_node);

    if (sock_update(port_state, sock_state) < 0)
      return -1;

    /* sock_update() removes the socket from the update queue. */
  }

  return 0;
}

static void port__update_events_if_polling(port_state_t* port_state) {
  if (port_state->active_poll_count > 0)
    port__update_events(port_state);
}

static int port__feed_events(port_state_t* port_state,
                             struct epoll_event* epoll_events,
                             OVERLAPPED_ENTRY* iocp_events,
                             DWORD iocp_event_count) {
  int epoll_event_count = 0;
  DWORD i;

  for (i = 0; i < iocp_event_count; i++) {
    OVERLAPPED* overlapped = iocp_events[i].lpOverlapped;
    struct epoll_event* ev = &epoll_events[epoll_event_count];

    epoll_event_count += sock_feed_event(port_state, overlapped, ev);
  }

  return epoll_event_count;
}

static int port__poll(port_state_t* port_state,
                      struct epoll_event* epoll_events,
                      OVERLAPPED_ENTRY* iocp_events,
                      DWORD maxevents,
                      DWORD timeout) {
  DWORD completion_count;
  BOOL r;

  if (port__update_events(port_state) < 0)
    return -1;

  port_state->active_poll_count++;

  LeaveCriticalSection(&port_state->lock);

  r = GetQueuedCompletionStatusEx(port_state->iocp,
                                  iocp_events,
                                  maxevents,
                                  &completion_count,
                                  timeout,
                                  FALSE);

  EnterCriticalSection(&port_state->lock);

  port_state->active_poll_count--;

  if (!r)
    return_map_error(-1);

  return port__feed_events(
      port_state, epoll_events, iocp_events, completion_count);
}

int port_wait(port_state_t* port_state,
              struct epoll_event* events,
              int maxevents,
              int timeout) {
  OVERLAPPED_ENTRY stack_iocp_events[PORT__MAX_ON_STACK_COMPLETIONS];
  OVERLAPPED_ENTRY* iocp_events;
  uint64_t due = 0;
  DWORD gqcs_timeout;
  int result;

  /* Check whether `maxevents` is in range. */
  if (maxevents <= 0)
    return_set_error(-1, ERROR_INVALID_PARAMETER);

  /* Decide whether the IOCP completion list can live on the stack, or allocate
   * memory for it on the heap. */
  if ((size_t) maxevents <= array_count(stack_iocp_events)) {
    iocp_events = stack_iocp_events;
  } else if ((iocp_events =
                  malloc((size_t) maxevents * sizeof *iocp_events)) == NULL) {
    iocp_events = stack_iocp_events;
    maxevents = array_count(stack_iocp_events);
  }

  /* Compute the timeout for GetQueuedCompletionStatus, and the wait end
   * time, if the user specified a timeout other than zero or infinite. */
  if (timeout > 0) {
    due = GetTickCount64() + (uint64_t) timeout;
    gqcs_timeout = (DWORD) timeout;
  } else if (timeout == 0) {
    gqcs_timeout = 0;
  } else {
    gqcs_timeout = INFINITE;
  }

  EnterCriticalSection(&port_state->lock);

  /* Dequeue completion packets until either at least one interesting event
   * has been discovered, or the timeout is reached. */
  for (;;) {
    uint64_t now;

    result = port__poll(
        port_state, events, iocp_events, (DWORD) maxevents, gqcs_timeout);
    if (result < 0 || result > 0)
      break; /* Result, error, or time-out. */

    if (timeout < 0)
      continue; /* When timeout is negative, never time out. */

    /* Update time. */
    now = GetTickCount64();

    /* Do not allow the due time to be in the past. */
    if (now >= due) {
      SetLastError(WAIT_TIMEOUT);
      break;
    }

    /* Recompute time-out argument for GetQueuedCompletionStatus. */
    gqcs_timeout = (DWORD)(due - now);
  }

  port__update_events_if_polling(port_state);

  LeaveCriticalSection(&port_state->lock);

  if (iocp_events != stack_iocp_events)
    free(iocp_events);

  if (result >= 0)
    return result;
  else if (GetLastError() == WAIT_TIMEOUT)
    return 0;
  else
    return -1;
}

static int port__ctl_add(port_state_t* port_state,
                         SOCKET sock,
                         struct epoll_event* ev) {
  sock_state_t* sock_state = sock_new(port_state, sock);
  if (sock_state == NULL)
    return -1;

  if (sock_set_event(port_state, sock_state, ev) < 0) {
    sock_delete(port_state, sock_state);
    return -1;
  }

  port__update_events_if_polling(port_state);

  return 0;
}

static int port__ctl_mod(port_state_t* port_state,
                         SOCKET sock,
                         struct epoll_event* ev) {
  sock_state_t* sock_state = port_find_socket(port_state, sock);
  if (sock_state == NULL)
    return -1;

  if (sock_set_event(port_state, sock_state, ev) < 0)
    return -1;

  port__update_events_if_polling(port_state);

  return 0;
}

static int port__ctl_del(port_state_t* port_state, SOCKET sock) {
  sock_state_t* sock_state = port_find_socket(port_state, sock);
  if (sock_state == NULL)
    return -1;

  sock_delete(port_state, sock_state);

  return 0;
}

static int port__ctl_op(port_state_t* port_state,
                        int op,
                        SOCKET sock,
                        struct epoll_event* ev) {
  switch (op) {
    case EPOLL_CTL_ADD:
      return port__ctl_add(port_state, sock, ev);
    case EPOLL_CTL_MOD:
      return port__ctl_mod(port_state, sock, ev);
    case EPOLL_CTL_DEL:
      return port__ctl_del(port_state, sock);
    default:
      return_set_error(-1, ERROR_INVALID_PARAMETER);
  }
}

int port_ctl(port_state_t* port_state,
             int op,
             SOCKET sock,
             struct epoll_event* ev) {
  int result;

  EnterCriticalSection(&port_state->lock);
  result = port__ctl_op(port_state, op, sock, ev);
  LeaveCriticalSection(&port_state->lock);

  return result;
}

int port_register_socket_handle(port_state_t* port_state,
                                sock_state_t* sock_state,
                                SOCKET socket) {
  if (tree_add(&port_state->sock_tree,
               sock_state_to_tree_node(sock_state),
               socket) < 0)
    return_set_error(-1, ERROR_ALREADY_EXISTS);
  return 0;
}

void port_unregister_socket_handle(port_state_t* port_state,
                                   sock_state_t* sock_state) {
  tree_del(&port_state->sock_tree, sock_state_to_tree_node(sock_state));
}

sock_state_t* port_find_socket(port_state_t* port_state, SOCKET socket) {
  tree_node_t* tree_node = tree_find(&port_state->sock_tree, socket);
  if (tree_node == NULL)
    return_set_error(NULL, ERROR_NOT_FOUND);
  return sock_state_from_tree_node(tree_node);
}

void port_request_socket_update(port_state_t* port_state,
                                sock_state_t* sock_state) {
  if (queue_enqueued(sock_state_to_queue_node(sock_state)))
    return;
  queue_append(&port_state->sock_update_queue,
               sock_state_to_queue_node(sock_state));
}

void port_cancel_socket_update(port_state_t* port_state,
                               sock_state_t* sock_state) {
  unused_var(port_state);
  if (!queue_enqueued(sock_state_to_queue_node(sock_state)))
    return;
  queue_remove(sock_state_to_queue_node(sock_state));
}

void port_add_deleted_socket(port_state_t* port_state,
                             sock_state_t* sock_state) {
  if (queue_enqueued(sock_state_to_queue_node(sock_state)))
    return;
  queue_append(&port_state->sock_deleted_queue,
               sock_state_to_queue_node(sock_state));
}

void port_remove_deleted_socket(port_state_t* port_state,
                                sock_state_t* sock_state) {
  unused_var(port_state);
  if (!queue_enqueued(sock_state_to_queue_node(sock_state)))
    return;
  queue_remove(sock_state_to_queue_node(sock_state));
}

void queue_init(queue_t* queue) {
  queue_node_init(&queue->head);
}

void queue_node_init(queue_node_t* node) {
  node->prev = node;
  node->next = node;
}

static inline void queue__detach_node(queue_node_t* node) {
  node->prev->next = node->next;
  node->next->prev = node->prev;
}

queue_node_t* queue_first(const queue_t* queue) {
  return !queue_empty(queue) ? queue->head.next : NULL;
}

queue_node_t* queue_last(const queue_t* queue) {
  return !queue_empty(queue) ? queue->head.prev : NULL;
}

void queue_prepend(queue_t* queue, queue_node_t* node) {
  node->next = queue->head.next;
  node->prev = &queue->head;
  node->next->prev = node;
  queue->head.next = node;
}

void queue_append(queue_t* queue, queue_node_t* node) {
  node->next = &queue->head;
  node->prev = queue->head.prev;
  node->prev->next = node;
  queue->head.prev = node;
}

void queue_move_first(queue_t* queue, queue_node_t* node) {
  queue__detach_node(node);
  queue_prepend(queue, node);
}

void queue_move_last(queue_t* queue, queue_node_t* node) {
  queue__detach_node(node);
  queue_append(queue, node);
}

void queue_remove(queue_node_t* node) {
  queue__detach_node(node);
  queue_node_init(node);
}

bool queue_empty(const queue_t* queue) {
  return !queue_enqueued(&queue->head);
}

bool queue_enqueued(const queue_node_t* node) {
  return node->prev != node;
}

/* clang-format off */
static const long REFLOCK__REF          = (long) 0x00000001;
static const long REFLOCK__REF_MASK     = (long) 0x0fffffff;
static const long REFLOCK__DESTROY      = (long) 0x10000000;
static const long REFLOCK__DESTROY_MASK = (long) 0xf0000000;
static const long REFLOCK__POISON       = (long) 0x300DEAD0;
/* clang-format on */
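
/* State word layout implied by the constants above: bits 0-27 hold the
 * reference count (REFLOCK__REF_MASK), bits 28-31 hold the destroy flag
 * (REFLOCK__DESTROY sets bit 28), and REFLOCK__POISON is written into a
 * destroyed reflock so that stray late use trips the assertions below. */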
1486 
1487 static HANDLE reflock__keyed_event = NULL;
1488 
reflock_global_init(void)1489 int reflock_global_init(void) {
1490   NTSTATUS status =
1491       NtCreateKeyedEvent(&reflock__keyed_event, ~(ACCESS_MASK) 0, NULL, 0);
1492   if (status != STATUS_SUCCESS)
1493     return_set_error(-1, RtlNtStatusToDosError(status));
1494   return 0;
1495 }
1496 
reflock_init(reflock_t * reflock)1497 void reflock_init(reflock_t* reflock) {
1498   reflock->state = 0;
1499 }
1500 
reflock__signal_event(void * address)1501 static void reflock__signal_event(void* address) {
1502   NTSTATUS status =
1503       NtReleaseKeyedEvent(reflock__keyed_event, address, FALSE, NULL);
1504   if (status != STATUS_SUCCESS)
1505     abort();
1506 }
1507 
reflock__await_event(void * address)1508 static void reflock__await_event(void* address) {
1509   NTSTATUS status =
1510       NtWaitForKeyedEvent(reflock__keyed_event, address, FALSE, NULL);
1511   if (status != STATUS_SUCCESS)
1512     abort();
1513 }
1514 
reflock_ref(reflock_t * reflock)1515 void reflock_ref(reflock_t* reflock) {
1516   long state = InterlockedAdd(&reflock->state, REFLOCK__REF);
1517   unused_var(state);
1518   assert((state & REFLOCK__DESTROY_MASK) == 0); /* Overflow or destroyed. */
1519 }
1520 
reflock_unref(reflock_t * reflock)1521 void reflock_unref(reflock_t* reflock) {
1522   long state = InterlockedAdd(&reflock->state, -REFLOCK__REF);
1523   long ref_count = state & REFLOCK__REF_MASK;
1524   long destroy = state & REFLOCK__DESTROY_MASK;
1525 
1526   unused_var(ref_count);
1527   unused_var(destroy);
1528 
1529   if (state == REFLOCK__DESTROY)
1530     reflock__signal_event(reflock);
1531   else
1532     assert(destroy == 0 || ref_count > 0);
1533 }
1534 
reflock_unref_and_destroy(reflock_t * reflock)1535 void reflock_unref_and_destroy(reflock_t* reflock) {
1536   long state =
1537       InterlockedAdd(&reflock->state, REFLOCK__DESTROY - REFLOCK__REF);
1538   long ref_count = state & REFLOCK__REF_MASK;
1539 
1540   assert((state & REFLOCK__DESTROY_MASK) ==
1541          REFLOCK__DESTROY); /* Underflow or already destroyed. */
1542 
1543   if (ref_count != 0)
1544     reflock__await_event(reflock);
1545 
1546   state = InterlockedExchange(&reflock->state, REFLOCK__POISON);
1547   assert(state == REFLOCK__DESTROY);
1548 }
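
/* The reflock packs its entire state into one atomically-updated `long`:
 * bits 0-27 hold the reference count (REFLOCK__REF_MASK) and bit 28 is the
 * "destroy requested" flag. reflock_unref_and_destroy() sets the destroy
 * flag and drops one reference in a single InterlockedAdd(); if other
 * references remain it blocks on the NT keyed event, using the reflock's own
 * address as the key, until the last reflock_unref() observes
 * state == REFLOCK__DESTROY and wakes it. A rough sequence sketch
 * (illustrative only; the owner/worker roles are hypothetical):
 *
 *   reflock_t lock;
 *   reflock_init(&lock);               // 0 refs
 *   reflock_ref(&lock);                // owner thread:  1 ref
 *   reflock_ref(&lock);                // worker thread: 2 refs
 *   reflock_unref_and_destroy(&lock);  // owner: destroy flag set, 1 ref
 *                                      // left, so the owner blocks...
 *   reflock_unref(&lock);              // worker: 0 refs -> keyed event
 *                                      // signaled; owner wakes and poisons
 *                                      // the state to catch use-after-destroy
 */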

static const uint32_t SOCK__KNOWN_EPOLL_EVENTS =
    EPOLLIN | EPOLLPRI | EPOLLOUT | EPOLLERR | EPOLLHUP | EPOLLRDNORM |
    EPOLLRDBAND | EPOLLWRNORM | EPOLLWRBAND | EPOLLMSG | EPOLLRDHUP;

typedef enum sock__poll_status {
  SOCK__POLL_IDLE = 0,
  SOCK__POLL_PENDING,
  SOCK__POLL_CANCELLED
} sock__poll_status_t;

typedef struct sock_state {
  OVERLAPPED overlapped;
  AFD_POLL_INFO poll_info;
  queue_node_t queue_node;
  tree_node_t tree_node;
  poll_group_t* poll_group;
  SOCKET base_socket;
  epoll_data_t user_data;
  uint32_t user_events;
  uint32_t pending_events;
  sock__poll_status_t poll_status;
  bool delete_pending;
} sock_state_t;

static inline sock_state_t* sock__alloc(void) {
  sock_state_t* sock_state = malloc(sizeof *sock_state);
  if (sock_state == NULL)
    return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY);
  return sock_state;
}

static inline void sock__free(sock_state_t* sock_state) {
  free(sock_state);
}

static int sock__cancel_poll(sock_state_t* sock_state) {
  HANDLE afd_helper_handle =
      poll_group_get_afd_helper_handle(sock_state->poll_group);
  assert(sock_state->poll_status == SOCK__POLL_PENDING);

  /* CancelIoEx() may fail with ERROR_NOT_FOUND if the overlapped operation has
   * already completed. This is not a problem and we proceed normally. */
  if (!HasOverlappedIoCompleted(&sock_state->overlapped) &&
      !CancelIoEx(afd_helper_handle, &sock_state->overlapped) &&
      GetLastError() != ERROR_NOT_FOUND)
    return_map_error(-1);

  sock_state->poll_status = SOCK__POLL_CANCELLED;
  sock_state->pending_events = 0;
  return 0;
}

sock_state_t* sock_new(port_state_t* port_state, SOCKET socket) {
  SOCKET base_socket;
  poll_group_t* poll_group;
  sock_state_t* sock_state;

  if (socket == 0 || socket == INVALID_SOCKET)
    return_set_error(NULL, ERROR_INVALID_HANDLE);

  base_socket = ws_get_base_socket(socket);
  if (base_socket == INVALID_SOCKET)
    return NULL;

  poll_group = poll_group_acquire(port_state);
  if (poll_group == NULL)
    return NULL;

  sock_state = sock__alloc();
  if (sock_state == NULL)
    goto err1;

  memset(sock_state, 0, sizeof *sock_state);

  sock_state->base_socket = base_socket;
  sock_state->poll_group = poll_group;

  tree_node_init(&sock_state->tree_node);
  queue_node_init(&sock_state->queue_node);

  if (port_register_socket_handle(port_state, sock_state, socket) < 0)
    goto err2;

  return sock_state;

err2:
  sock__free(sock_state);
err1:
  poll_group_release(poll_group);

  return NULL;
}

static int sock__delete(port_state_t* port_state,
                        sock_state_t* sock_state,
                        bool force) {
  if (!sock_state->delete_pending) {
    if (sock_state->poll_status == SOCK__POLL_PENDING)
      sock__cancel_poll(sock_state);

    port_cancel_socket_update(port_state, sock_state);
    port_unregister_socket_handle(port_state, sock_state);

    sock_state->delete_pending = true;
  }

  /* If the poll request still needs to complete, the sock_state object can't
   * be free()d yet. `sock_feed_event()` or `port_close()` will take care
   * of this later. */
  if (force || sock_state->poll_status == SOCK__POLL_IDLE) {
    /* Free the sock_state now. */
    port_remove_deleted_socket(port_state, sock_state);
    poll_group_release(sock_state->poll_group);
    sock__free(sock_state);
  } else {
    /* Free the socket later. */
    port_add_deleted_socket(port_state, sock_state);
  }

  return 0;
}

void sock_delete(port_state_t* port_state, sock_state_t* sock_state) {
  sock__delete(port_state, sock_state, false);
}

void sock_force_delete(port_state_t* port_state, sock_state_t* sock_state) {
  sock__delete(port_state, sock_state, true);
}

int sock_set_event(port_state_t* port_state,
                   sock_state_t* sock_state,
                   const struct epoll_event* ev) {
  /* EPOLLERR and EPOLLHUP are always reported, even when not requested by the
   * caller. However, they are disabled after an event has been reported for a
   * socket for which the EPOLLONESHOT flag was set. */
  uint32_t events = ev->events | EPOLLERR | EPOLLHUP;

  sock_state->user_events = events;
  sock_state->user_data = ev->data;

  if ((events & SOCK__KNOWN_EPOLL_EVENTS & ~sock_state->pending_events) != 0)
    port_request_socket_update(port_state, sock_state);

  return 0;
}

static inline DWORD sock__epoll_events_to_afd_events(uint32_t epoll_events) {
  /* Always monitor for AFD_POLL_LOCAL_CLOSE, which is triggered when the
   * socket is closed with closesocket() or CloseHandle(). */
  DWORD afd_events = AFD_POLL_LOCAL_CLOSE;

  if (epoll_events & (EPOLLIN | EPOLLRDNORM))
    afd_events |= AFD_POLL_RECEIVE | AFD_POLL_ACCEPT;
  if (epoll_events & (EPOLLPRI | EPOLLRDBAND))
    afd_events |= AFD_POLL_RECEIVE_EXPEDITED;
  if (epoll_events & (EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND))
    afd_events |= AFD_POLL_SEND;
  if (epoll_events & (EPOLLIN | EPOLLRDNORM | EPOLLRDHUP))
    afd_events |= AFD_POLL_DISCONNECT;
  if (epoll_events & EPOLLHUP)
    afd_events |= AFD_POLL_ABORT;
  if (epoll_events & EPOLLERR)
    afd_events |= AFD_POLL_CONNECT_FAIL;

  return afd_events;
}

static inline uint32_t sock__afd_events_to_epoll_events(DWORD afd_events) {
  uint32_t epoll_events = 0;

  if (afd_events & (AFD_POLL_RECEIVE | AFD_POLL_ACCEPT))
    epoll_events |= EPOLLIN | EPOLLRDNORM;
  if (afd_events & AFD_POLL_RECEIVE_EXPEDITED)
    epoll_events |= EPOLLPRI | EPOLLRDBAND;
  if (afd_events & AFD_POLL_SEND)
    epoll_events |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
  if (afd_events & AFD_POLL_DISCONNECT)
    epoll_events |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
  if (afd_events & AFD_POLL_ABORT)
    epoll_events |= EPOLLHUP;
  if (afd_events & AFD_POLL_CONNECT_FAIL)
    /* Linux reports all these events after connect() has failed. */
    epoll_events |=
        EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLRDNORM | EPOLLWRNORM | EPOLLRDHUP;

  return epoll_events;
}
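
/* Note that the two mappings above are deliberately not inverses of each
 * other: one epoll event may fan out to several AFD events and vice versa.
 * For example (values derived from the tables above, illustrative only):
 *
 *   DWORD afd = sock__epoll_events_to_afd_events(EPOLLIN);
 *   // afd == AFD_POLL_LOCAL_CLOSE | AFD_POLL_RECEIVE | AFD_POLL_ACCEPT |
 *   //        AFD_POLL_DISCONNECT
 *
 *   uint32_t ep = sock__afd_events_to_epoll_events(AFD_POLL_DISCONNECT);
 *   // ep == EPOLLIN | EPOLLRDNORM | EPOLLRDHUP
 *
 * Events the caller never asked for are masked off later, in
 * sock_feed_event(). */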

int sock_update(port_state_t* port_state, sock_state_t* sock_state) {
  assert(!sock_state->delete_pending);

  if ((sock_state->poll_status == SOCK__POLL_PENDING) &&
      (sock_state->user_events & SOCK__KNOWN_EPOLL_EVENTS &
       ~sock_state->pending_events) == 0) {
    /* All the events the user is interested in are already being monitored by
     * the pending poll operation. It might spuriously complete because of an
     * event that we're no longer interested in; when that happens we'll submit
     * a new poll operation with the updated event mask. */

  } else if (sock_state->poll_status == SOCK__POLL_PENDING) {
    /* A poll operation is already pending, but it's not monitoring for all the
     * events that the user is interested in. Therefore, cancel the pending
     * poll operation; when we receive its completion packet, a new poll
     * operation will be submitted with the correct event mask. */
    if (sock__cancel_poll(sock_state) < 0)
      return -1;

  } else if (sock_state->poll_status == SOCK__POLL_CANCELLED) {
    /* The poll operation has already been cancelled; we're still waiting for
     * it to return. For now, there's nothing that needs to be done. */

  } else if (sock_state->poll_status == SOCK__POLL_IDLE) {
    /* No poll operation is pending; start one. */
    sock_state->poll_info.Exclusive = FALSE;
    sock_state->poll_info.NumberOfHandles = 1;
    sock_state->poll_info.Timeout.QuadPart = INT64_MAX;
    sock_state->poll_info.Handles[0].Handle = (HANDLE) sock_state->base_socket;
    sock_state->poll_info.Handles[0].Status = 0;
    sock_state->poll_info.Handles[0].Events =
        sock__epoll_events_to_afd_events(sock_state->user_events);

    memset(&sock_state->overlapped, 0, sizeof sock_state->overlapped);

    if (afd_poll(poll_group_get_afd_helper_handle(sock_state->poll_group),
                 &sock_state->poll_info,
                 &sock_state->overlapped) < 0) {
      switch (GetLastError()) {
        case ERROR_IO_PENDING:
          /* Overlapped poll operation in progress; this is expected. */
          break;
        case ERROR_INVALID_HANDLE:
          /* Socket closed; it'll be dropped from the epoll set. */
          return sock__delete(port_state, sock_state, false);
        default:
          /* Other errors are propagated to the caller. */
          return_map_error(-1);
      }
    }

    /* The poll request was successfully submitted. */
    sock_state->poll_status = SOCK__POLL_PENDING;
    sock_state->pending_events = sock_state->user_events;

  } else {
    /* Unreachable. */
    assert(false);
  }

  port_cancel_socket_update(port_state, sock_state);
  return 0;
}
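
/* sock_update() and sock_feed_event() together drive a small per-socket
 * state machine: SOCK__POLL_IDLE (no AFD poll outstanding) ->
 * SOCK__POLL_PENDING (one AFD_POLL ioctl in flight) -> back to IDLE once the
 * completion packet is consumed, with SOCK__POLL_CANCELLED as a detour taken
 * when the event mask changes while a poll is in flight. At most one AFD
 * poll is ever outstanding per socket; a cancelled poll still produces a
 * completion packet, which is why the CANCELLED state simply waits for
 * sock_feed_event() instead of resubmitting immediately. */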

int sock_feed_event(port_state_t* port_state,
                    OVERLAPPED* overlapped,
                    struct epoll_event* ev) {
  sock_state_t* sock_state =
      container_of(overlapped, sock_state_t, overlapped);
  AFD_POLL_INFO* poll_info = &sock_state->poll_info;
  uint32_t epoll_events = 0;

  sock_state->poll_status = SOCK__POLL_IDLE;
  sock_state->pending_events = 0;

  if (sock_state->delete_pending) {
    /* Socket has been deleted earlier and can now be freed. */
    return sock__delete(port_state, sock_state, false);

  } else if ((NTSTATUS) overlapped->Internal == STATUS_CANCELLED) {
    /* The poll request was cancelled by CancelIoEx. */

  } else if (!NT_SUCCESS(overlapped->Internal)) {
    /* The overlapped request itself failed in an unexpected way. */
    epoll_events = EPOLLERR;

  } else if (poll_info->NumberOfHandles < 1) {
    /* This poll operation succeeded but didn't report any socket events. */

  } else if (poll_info->Handles[0].Events & AFD_POLL_LOCAL_CLOSE) {
    /* The poll operation reported that the socket was closed. */
    return sock__delete(port_state, sock_state, false);

  } else {
    /* Events related to our socket were reported. */
    epoll_events =
        sock__afd_events_to_epoll_events(poll_info->Handles[0].Events);
  }

  /* Requeue the socket so a new poll request will be submitted. */
  port_request_socket_update(port_state, sock_state);

  /* Filter out events that the user didn't ask for. */
  epoll_events &= sock_state->user_events;

  /* Return if there are no epoll events to report. */
  if (epoll_events == 0)
    return 0;

  /* If the socket has the EPOLLONESHOT flag set, unmonitor all events,
   * even EPOLLERR and EPOLLHUP. But always keep looking for closed sockets. */
  if (sock_state->user_events & EPOLLONESHOT)
    sock_state->user_events = 0;

  ev->data = sock_state->user_data;
  ev->events = epoll_events;
  return 1;
}

queue_node_t* sock_state_to_queue_node(sock_state_t* sock_state) {
  return &sock_state->queue_node;
}

sock_state_t* sock_state_from_tree_node(tree_node_t* tree_node) {
  return container_of(tree_node, sock_state_t, tree_node);
}

tree_node_t* sock_state_to_tree_node(sock_state_t* sock_state) {
  return &sock_state->tree_node;
}

sock_state_t* sock_state_from_queue_node(queue_node_t* queue_node) {
  return container_of(queue_node, sock_state_t, queue_node);
}

void ts_tree_init(ts_tree_t* ts_tree) {
  tree_init(&ts_tree->tree);
  InitializeSRWLock(&ts_tree->lock);
}

void ts_tree_node_init(ts_tree_node_t* node) {
  tree_node_init(&node->tree_node);
  reflock_init(&node->reflock);
}

int ts_tree_add(ts_tree_t* ts_tree, ts_tree_node_t* node, uintptr_t key) {
  int r;

  AcquireSRWLockExclusive(&ts_tree->lock);
  r = tree_add(&ts_tree->tree, &node->tree_node, key);
  ReleaseSRWLockExclusive(&ts_tree->lock);

  return r;
}

static inline ts_tree_node_t* ts_tree__find_node(ts_tree_t* ts_tree,
                                                 uintptr_t key) {
  tree_node_t* tree_node = tree_find(&ts_tree->tree, key);
  if (tree_node == NULL)
    return NULL;

  return container_of(tree_node, ts_tree_node_t, tree_node);
}

ts_tree_node_t* ts_tree_del_and_ref(ts_tree_t* ts_tree, uintptr_t key) {
  ts_tree_node_t* ts_tree_node;

  AcquireSRWLockExclusive(&ts_tree->lock);

  ts_tree_node = ts_tree__find_node(ts_tree, key);
  if (ts_tree_node != NULL) {
    tree_del(&ts_tree->tree, &ts_tree_node->tree_node);
    reflock_ref(&ts_tree_node->reflock);
  }

  ReleaseSRWLockExclusive(&ts_tree->lock);

  return ts_tree_node;
}

ts_tree_node_t* ts_tree_find_and_ref(ts_tree_t* ts_tree, uintptr_t key) {
  ts_tree_node_t* ts_tree_node;

  AcquireSRWLockShared(&ts_tree->lock);

  ts_tree_node = ts_tree__find_node(ts_tree, key);
  if (ts_tree_node != NULL)
    reflock_ref(&ts_tree_node->reflock);

  ReleaseSRWLockShared(&ts_tree->lock);

  return ts_tree_node;
}

void ts_tree_node_unref(ts_tree_node_t* node) {
  reflock_unref(&node->reflock);
}

void ts_tree_node_unref_and_destroy(ts_tree_node_t* node) {
  reflock_unref_and_destroy(&node->reflock);
}
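
/* The ts_tree ("thread-safe tree") pairs an SRW lock with a per-node
 * reflock: the lock serializes structural changes, while the reflock keeps a
 * node alive after it has been found (or removed) under the lock but is
 * still in use outside it. A lookup sketched for illustration (`tree` and
 * `my_key` are hypothetical):
 *
 *   ts_tree_node_t* node = ts_tree_find_and_ref(&tree, my_key);
 *   if (node != NULL) {
 *     // ... use the node; another thread may ts_tree_del_and_ref() it in
 *     // the meantime, but it cannot be destroyed until this reference is
 *     // released ...
 *     ts_tree_node_unref(node);
 *   }
 */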

void tree_init(tree_t* tree) {
  memset(tree, 0, sizeof *tree);
}

void tree_node_init(tree_node_t* node) {
  memset(node, 0, sizeof *node);
}

#define TREE__ROTATE(cis, trans)   \
  tree_node_t* p = node;           \
  tree_node_t* q = node->trans;    \
  tree_node_t* parent = p->parent; \
                                   \
  if (parent) {                    \
    if (parent->left == p)         \
      parent->left = q;            \
    else                           \
      parent->right = q;           \
  } else {                         \
    tree->root = q;                \
  }                                \
                                   \
  q->parent = parent;              \
  p->parent = q;                   \
  p->trans = q->cis;               \
  if (p->trans)                    \
    p->trans->parent = p;          \
  q->cis = p;

static inline void tree__rotate_left(tree_t* tree, tree_node_t* node) {
  TREE__ROTATE(left, right)
}

static inline void tree__rotate_right(tree_t* tree, tree_node_t* node) {
  TREE__ROTATE(right, left)
}

#define TREE__INSERT_OR_DESCEND(side) \
  if (parent->side) {                 \
    parent = parent->side;            \
  } else {                            \
    parent->side = node;              \
    break;                            \
  }

#define TREE__FIXUP_AFTER_INSERT(cis, trans) \
  tree_node_t* grandparent = parent->parent; \
  tree_node_t* uncle = grandparent->trans;   \
                                             \
  if (uncle && uncle->red) {                 \
    parent->red = uncle->red = false;        \
    grandparent->red = true;                 \
    node = grandparent;                      \
  } else {                                   \
    if (node == parent->trans) {             \
      tree__rotate_##cis(tree, parent);      \
      node = parent;                         \
      parent = node->parent;                 \
    }                                        \
    parent->red = false;                     \
    grandparent->red = true;                 \
    tree__rotate_##trans(tree, grandparent); \
  }

int tree_add(tree_t* tree, tree_node_t* node, uintptr_t key) {
  tree_node_t* parent;

  parent = tree->root;
  if (parent) {
    for (;;) {
      if (key < parent->key) {
        TREE__INSERT_OR_DESCEND(left)
      } else if (key > parent->key) {
        TREE__INSERT_OR_DESCEND(right)
      } else {
        return -1;
      }
    }
  } else {
    tree->root = node;
  }

  node->key = key;
  node->left = node->right = NULL;
  node->parent = parent;
  node->red = true;

  for (; parent && parent->red; parent = node->parent) {
    if (parent == parent->parent->left) {
      TREE__FIXUP_AFTER_INSERT(left, right)
    } else {
      TREE__FIXUP_AFTER_INSERT(right, left)
    }
  }
  tree->root->red = false;

  return 0;
}

#define TREE__FIXUP_AFTER_REMOVE(cis, trans)       \
  tree_node_t* sibling = parent->trans;            \
                                                   \
  if (sibling->red) {                              \
    sibling->red = false;                          \
    parent->red = true;                            \
    tree__rotate_##cis(tree, parent);              \
    sibling = parent->trans;                       \
  }                                                \
  if ((sibling->left && sibling->left->red) ||     \
      (sibling->right && sibling->right->red)) {   \
    if (!sibling->trans || !sibling->trans->red) { \
      sibling->cis->red = false;                   \
      sibling->red = true;                         \
      tree__rotate_##trans(tree, sibling);         \
      sibling = parent->trans;                     \
    }                                              \
    sibling->red = parent->red;                    \
    parent->red = sibling->trans->red = false;     \
    tree__rotate_##cis(tree, parent);              \
    node = tree->root;                             \
    break;                                         \
  }                                                \
  sibling->red = true;

void tree_del(tree_t* tree, tree_node_t* node) {
  tree_node_t* parent = node->parent;
  tree_node_t* left = node->left;
  tree_node_t* right = node->right;
  tree_node_t* next;
  bool red;

  if (!left) {
    next = right;
  } else if (!right) {
    next = left;
  } else {
    next = right;
    while (next->left)
      next = next->left;
  }

  if (parent) {
    if (parent->left == node)
      parent->left = next;
    else
      parent->right = next;
  } else {
    tree->root = next;
  }

  if (left && right) {
    red = next->red;
    next->red = node->red;
    next->left = left;
    left->parent = next;
    if (next != right) {
      parent = next->parent;
      next->parent = node->parent;
      node = next->right;
      parent->left = node;
      next->right = right;
      right->parent = next;
    } else {
      next->parent = parent;
      parent = next;
      node = next->right;
    }
  } else {
    red = node->red;
    node = next;
  }

  if (node)
    node->parent = parent;
  if (red)
    return;
  if (node && node->red) {
    node->red = false;
    return;
  }

  do {
    if (node == tree->root)
      break;
    if (node == parent->left) {
      TREE__FIXUP_AFTER_REMOVE(left, right)
    } else {
      TREE__FIXUP_AFTER_REMOVE(right, left)
    }
    node = parent;
    parent = parent->parent;
  } while (!node->red);

  if (node)
    node->red = false;
}

tree_node_t* tree_find(const tree_t* tree, uintptr_t key) {
  tree_node_t* node = tree->root;
  while (node) {
    if (key < node->key)
      node = node->left;
    else if (key > node->key)
      node = node->right;
    else
      return node;
  }
  return NULL;
}

tree_node_t* tree_root(const tree_t* tree) {
  return tree->root;
}
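
/* tree_t is an intrusive red-black tree keyed by uintptr_t: lookups,
 * insertions and deletions are O(log n), and keys must be unique (tree_add()
 * returns -1 on a duplicate). A minimal usage sketch (illustrative only, not
 * part of wepoll; `entry_t` and `some_key` are hypothetical):
 *
 *   typedef struct entry {
 *     tree_node_t tree_node;  // embedded node, no separate allocation
 *     int payload;
 *   } entry_t;
 *
 *   tree_t tree;
 *   entry_t entry;
 *   tree_init(&tree);
 *   tree_node_init(&entry.tree_node);
 *   if (tree_add(&tree, &entry.tree_node, (uintptr_t) some_key) < 0)
 *     abort();  // duplicate key
 *
 *   tree_node_t* found = tree_find(&tree, (uintptr_t) some_key);
 *   entry_t* e = found ? container_of(found, entry_t, tree_node) : NULL;
 */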

#ifndef SIO_BASE_HANDLE
#define SIO_BASE_HANDLE 0x48000022
#endif

int ws_global_init(void) {
  int r;
  WSADATA wsa_data;

  r = WSAStartup(MAKEWORD(2, 2), &wsa_data);
  if (r != 0)
    return_set_error(-1, (DWORD) r);

  return 0;
}

SOCKET ws_get_base_socket(SOCKET socket) {
  SOCKET base_socket;
  DWORD bytes;

  if (WSAIoctl(socket,
               SIO_BASE_HANDLE,
               NULL,
               0,
               &base_socket,
               sizeof base_socket,
               &bytes,
               NULL,
               NULL) == SOCKET_ERROR)
    return_map_error(INVALID_SOCKET);

  return base_socket;
}
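
/* ws_get_base_socket() asks Winsock, via the SIO_BASE_HANDLE ioctl, for the
 * "base" provider socket underneath any layered service providers (LSPs)
 * that may wrap the application's socket; the AFD poll operation must be
 * issued against that base handle, which is why sock_new() resolves it up
 * front. The #ifndef above supplies the ioctl code for SDK headers that
 * don't define it. */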