1 /*
2  * wepoll - epoll for Windows
3  * https://github.com/piscisaureus/wepoll
4  *
5  * Copyright 2012-2020, Bert Belder <bertbelder@gmail.com>
6  * All rights reserved.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions are
10  * met:
11  *
12  *   * Redistributions of source code must retain the above copyright
13  *     notice, this list of conditions and the following disclaimer.
14  *
15  *   * Redistributions in binary form must reproduce the above copyright
16  *     notice, this list of conditions and the following disclaimer in the
17  *     documentation and/or other materials provided with the distribution.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30  */
31 
32 #ifndef WEPOLL_EXPORT
33 #define WEPOLL_EXPORT
34 #endif
35 
36 #include <stdint.h>
37 
38 enum EPOLL_EVENTS {
39   EPOLLIN      = (int) (1U <<  0),
40   EPOLLPRI     = (int) (1U <<  1),
41   EPOLLOUT     = (int) (1U <<  2),
42   EPOLLERR     = (int) (1U <<  3),
43   EPOLLHUP     = (int) (1U <<  4),
44   EPOLLRDNORM  = (int) (1U <<  6),
45   EPOLLRDBAND  = (int) (1U <<  7),
46   EPOLLWRNORM  = (int) (1U <<  8),
47   EPOLLWRBAND  = (int) (1U <<  9),
48   EPOLLMSG     = (int) (1U << 10), /* Never reported. */
49   EPOLLRDHUP   = (int) (1U << 13),
50   EPOLLONESHOT = (int) (1U << 31)
51 };
52 
53 #define EPOLLIN      (1U <<  0)
54 #define EPOLLPRI     (1U <<  1)
55 #define EPOLLOUT     (1U <<  2)
56 #define EPOLLERR     (1U <<  3)
57 #define EPOLLHUP     (1U <<  4)
58 #define EPOLLRDNORM  (1U <<  6)
59 #define EPOLLRDBAND  (1U <<  7)
60 #define EPOLLWRNORM  (1U <<  8)
61 #define EPOLLWRBAND  (1U <<  9)
62 #define EPOLLMSG     (1U << 10)
63 #define EPOLLRDHUP   (1U << 13)
64 #define EPOLLONESHOT (1U << 31)
65 
66 #define EPOLL_CTL_ADD 1
67 #define EPOLL_CTL_MOD 2
68 #define EPOLL_CTL_DEL 3
69 
70 typedef void* HANDLE;
71 typedef uintptr_t SOCKET;
72 
73 typedef union epoll_data {
74   void* ptr;
75   int fd;
76   uint32_t u32;
77   uint64_t u64;
78   SOCKET sock; /* Windows specific */
79   HANDLE hnd;  /* Windows specific */
80 } epoll_data_t;
81 
82 struct epoll_event {
83   uint32_t events;   /* Epoll events and flags */
84   epoll_data_t data; /* User data variable */
85 };
86 
87 #ifdef __cplusplus
88 extern "C" {
89 #endif
90 
91 WEPOLL_EXPORT HANDLE epoll_create(int size);
92 WEPOLL_EXPORT HANDLE epoll_create1(int flags);
93 
94 WEPOLL_EXPORT int epoll_close(HANDLE ephnd);
95 
96 WEPOLL_EXPORT int epoll_ctl(HANDLE ephnd,
97                             int op,
98                             SOCKET sock,
99                             struct epoll_event* event);
100 
101 WEPOLL_EXPORT int epoll_wait(HANDLE ephnd,
102                              struct epoll_event* events,
103                              int maxevents,
104                              int timeout);
105 
106 #ifdef __cplusplus
107 } /* extern "C" */
108 #endif
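/* Usage sketch (illustration only, not part of the implementation): `sock` is
 * assumed to be a connected, non-blocking Winsock SOCKET created elsewhere.
 *
 *   HANDLE ephnd = epoll_create1(0);
 *   struct epoll_event ev, out[8];
 *   ev.events = EPOLLIN | EPOLLOUT;
 *   ev.data.sock = sock;
 *   if (epoll_ctl(ephnd, EPOLL_CTL_ADD, sock, &ev) < 0)
 *     abort();
 *   int n = epoll_wait(ephnd, out, 8, 1000); // wait up to 1000 ms
 *   // n > 0: ready events in out[0..n-1]; n == 0: timed out; n < 0: error
 *   epoll_close(ephnd);
 */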
109 
110 #include <assert.h>
111 
112 #include <stdlib.h>
113 
114 #define WEPOLL_INTERNAL static
115 #define WEPOLL_INTERNAL_EXTERN static
116 
117 #if defined(__clang__)
118 #pragma clang diagnostic push
119 #pragma clang diagnostic ignored "-Wnonportable-system-include-path"
120 #pragma clang diagnostic ignored "-Wreserved-id-macro"
121 #elif defined(_MSC_VER)
122 #pragma warning(push, 1)
123 #endif
124 
125 #undef WIN32_LEAN_AND_MEAN
126 #define WIN32_LEAN_AND_MEAN
127 
128 #undef _WIN32_WINNT
129 #define _WIN32_WINNT 0x0600
130 
131 #include <winsock2.h>
132 #include <ws2tcpip.h>
133 #include <windows.h>
134 
135 #if defined(__clang__)
136 #pragma clang diagnostic pop
137 #elif defined(_MSC_VER)
138 #pragma warning(pop)
139 #endif
140 
141 WEPOLL_INTERNAL int nt_global_init(void);
142 
143 typedef LONG NTSTATUS;
144 typedef NTSTATUS* PNTSTATUS;
145 
146 #ifndef NT_SUCCESS
147 #define NT_SUCCESS(status) (((NTSTATUS)(status)) >= 0)
148 #endif
149 
150 #ifndef STATUS_SUCCESS
151 #define STATUS_SUCCESS ((NTSTATUS) 0x00000000L)
152 #endif
153 
154 #ifndef STATUS_PENDING
155 #define STATUS_PENDING ((NTSTATUS) 0x00000103L)
156 #endif
157 
158 #ifndef STATUS_CANCELLED
159 #define STATUS_CANCELLED ((NTSTATUS) 0xC0000120L)
160 #endif
161 
162 #ifndef STATUS_NOT_FOUND
163 #define STATUS_NOT_FOUND ((NTSTATUS) 0xC0000225L)
164 #endif
165 
166 typedef struct _IO_STATUS_BLOCK {
167   NTSTATUS Status;
168   ULONG_PTR Information;
169 } IO_STATUS_BLOCK, *PIO_STATUS_BLOCK;
170 
171 typedef VOID(NTAPI* PIO_APC_ROUTINE)(PVOID ApcContext,
172                                      PIO_STATUS_BLOCK IoStatusBlock,
173                                      ULONG Reserved);
174 
175 typedef struct _UNICODE_STRING {
176   USHORT Length;
177   USHORT MaximumLength;
178   PWSTR Buffer;
179 } UNICODE_STRING, *PUNICODE_STRING;
180 
181 #define RTL_CONSTANT_STRING(s) \
182   { sizeof(s) - sizeof((s)[0]), sizeof(s), s }
183 
184 typedef struct _OBJECT_ATTRIBUTES {
185   ULONG Length;
186   HANDLE RootDirectory;
187   PUNICODE_STRING ObjectName;
188   ULONG Attributes;
189   PVOID SecurityDescriptor;
190   PVOID SecurityQualityOfService;
191 } OBJECT_ATTRIBUTES, *POBJECT_ATTRIBUTES;
192 
193 #define RTL_CONSTANT_OBJECT_ATTRIBUTES(ObjectName, Attributes) \
194   { sizeof(OBJECT_ATTRIBUTES), NULL, ObjectName, Attributes, NULL, NULL }
195 
196 #ifndef FILE_OPEN
197 #define FILE_OPEN 0x00000001UL
198 #endif
199 
200 #define KEYEDEVENT_WAIT 0x00000001UL
201 #define KEYEDEVENT_WAKE 0x00000002UL
202 #define KEYEDEVENT_ALL_ACCESS \
203   (STANDARD_RIGHTS_REQUIRED | KEYEDEVENT_WAIT | KEYEDEVENT_WAKE)
204 
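/* X-macro listing the ntdll functions that wepoll resolves at run time.
 * Expanding NT_NTDLL_IMPORT_LIST(X) with different definitions of X produces
 * the function-pointer declarations below and the GetProcAddress() loop in
 * nt_global_init(). */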
205 #define NT_NTDLL_IMPORT_LIST(X)           \
206   X(NTSTATUS,                             \
207     NTAPI,                                \
208     NtCancelIoFileEx,                     \
209     (HANDLE FileHandle,                   \
210      PIO_STATUS_BLOCK IoRequestToCancel,  \
211      PIO_STATUS_BLOCK IoStatusBlock))     \
212                                           \
213   X(NTSTATUS,                             \
214     NTAPI,                                \
215     NtCreateFile,                         \
216     (PHANDLE FileHandle,                  \
217      ACCESS_MASK DesiredAccess,           \
218      POBJECT_ATTRIBUTES ObjectAttributes, \
219      PIO_STATUS_BLOCK IoStatusBlock,      \
220      PLARGE_INTEGER AllocationSize,       \
221      ULONG FileAttributes,                \
222      ULONG ShareAccess,                   \
223      ULONG CreateDisposition,             \
224      ULONG CreateOptions,                 \
225      PVOID EaBuffer,                      \
226      ULONG EaLength))                     \
227                                           \
228   X(NTSTATUS,                             \
229     NTAPI,                                \
230     NtCreateKeyedEvent,                   \
231     (PHANDLE KeyedEventHandle,            \
232      ACCESS_MASK DesiredAccess,           \
233      POBJECT_ATTRIBUTES ObjectAttributes, \
234      ULONG Flags))                        \
235                                           \
236   X(NTSTATUS,                             \
237     NTAPI,                                \
238     NtDeviceIoControlFile,                \
239     (HANDLE FileHandle,                   \
240      HANDLE Event,                        \
241      PIO_APC_ROUTINE ApcRoutine,          \
242      PVOID ApcContext,                    \
243      PIO_STATUS_BLOCK IoStatusBlock,      \
244      ULONG IoControlCode,                 \
245      PVOID InputBuffer,                   \
246      ULONG InputBufferLength,             \
247      PVOID OutputBuffer,                  \
248      ULONG OutputBufferLength))           \
249                                           \
250   X(NTSTATUS,                             \
251     NTAPI,                                \
252     NtReleaseKeyedEvent,                  \
253     (HANDLE KeyedEventHandle,             \
254      PVOID KeyValue,                      \
255      BOOLEAN Alertable,                   \
256      PLARGE_INTEGER Timeout))             \
257                                           \
258   X(NTSTATUS,                             \
259     NTAPI,                                \
260     NtWaitForKeyedEvent,                  \
261     (HANDLE KeyedEventHandle,             \
262      PVOID KeyValue,                      \
263      BOOLEAN Alertable,                   \
264      PLARGE_INTEGER Timeout))             \
265                                           \
266   X(ULONG, WINAPI, RtlNtStatusToDosError, (NTSTATUS Status))
267 
268 #define X(return_type, attributes, name, parameters) \
269   WEPOLL_INTERNAL_EXTERN return_type(attributes* name) parameters;
270 NT_NTDLL_IMPORT_LIST(X)
271 #undef X
272 
273 #define AFD_POLL_RECEIVE           0x0001
274 #define AFD_POLL_RECEIVE_EXPEDITED 0x0002
275 #define AFD_POLL_SEND              0x0004
276 #define AFD_POLL_DISCONNECT        0x0008
277 #define AFD_POLL_ABORT             0x0010
278 #define AFD_POLL_LOCAL_CLOSE       0x0020
279 #define AFD_POLL_ACCEPT            0x0080
280 #define AFD_POLL_CONNECT_FAIL      0x0100
281 
282 typedef struct _AFD_POLL_HANDLE_INFO {
283   HANDLE Handle;
284   ULONG Events;
285   NTSTATUS Status;
286 } AFD_POLL_HANDLE_INFO, *PAFD_POLL_HANDLE_INFO;
287 
288 typedef struct _AFD_POLL_INFO {
289   LARGE_INTEGER Timeout;
290   ULONG NumberOfHandles;
291   ULONG Exclusive;
292   AFD_POLL_HANDLE_INFO Handles[1];
293 } AFD_POLL_INFO, *PAFD_POLL_INFO;
294 
295 WEPOLL_INTERNAL int afd_create_device_handle(HANDLE iocp_handle,
296                                              HANDLE* afd_device_handle_out);
297 
298 WEPOLL_INTERNAL int afd_poll(HANDLE afd_device_handle,
299                              AFD_POLL_INFO* poll_info,
300                              IO_STATUS_BLOCK* io_status_block);
301 WEPOLL_INTERNAL int afd_cancel_poll(HANDLE afd_device_handle,
302                                     IO_STATUS_BLOCK* io_status_block);
303 
304 #define return_map_error(value) \
305   do {                          \
306     err_map_win_error();        \
307     return (value);             \
308   } while (0)
309 
310 #define return_set_error(value, error) \
311   do {                                 \
312     err_set_win_error(error);          \
313     return (value);                    \
314   } while (0)
315 
316 WEPOLL_INTERNAL void err_map_win_error(void);
317 WEPOLL_INTERNAL void err_set_win_error(DWORD error);
318 WEPOLL_INTERNAL int err_check_handle(HANDLE handle);
319 
320 #define IOCTL_AFD_POLL 0x00012024
321 
322 static UNICODE_STRING afd__device_name =
323     RTL_CONSTANT_STRING(L"\\Device\\Afd\\Wepoll");
324 
325 static OBJECT_ATTRIBUTES afd__device_attributes =
326     RTL_CONSTANT_OBJECT_ATTRIBUTES(&afd__device_name, 0);
327 
328 int afd_create_device_handle(HANDLE iocp_handle,
329                              HANDLE* afd_device_handle_out) {
330   HANDLE afd_device_handle;
331   IO_STATUS_BLOCK iosb;
332   NTSTATUS status;
333 
334   /* By opening \Device\Afd without specifying any extended attributes, we'll
335    * get a handle that lets us talk to the AFD driver, but that doesn't have an
336    * associated endpoint (so it's not a socket). */
337   status = NtCreateFile(&afd_device_handle,
338                         SYNCHRONIZE,
339                         &afd__device_attributes,
340                         &iosb,
341                         NULL,
342                         0,
343                         FILE_SHARE_READ | FILE_SHARE_WRITE,
344                         FILE_OPEN,
345                         0,
346                         NULL,
347                         0);
348   if (status != STATUS_SUCCESS)
349     return_set_error(-1, RtlNtStatusToDosError(status));
350 
351   if (CreateIoCompletionPort(afd_device_handle, iocp_handle, 0, 0) == NULL)
352     goto error;
353 
354   if (!SetFileCompletionNotificationModes(afd_device_handle,
355                                           FILE_SKIP_SET_EVENT_ON_HANDLE))
356     goto error;
357 
358   *afd_device_handle_out = afd_device_handle;
359   return 0;
360 
361 error:
362   CloseHandle(afd_device_handle);
363   return_map_error(-1);
364 }
365 
366 int afd_poll(HANDLE afd_device_handle,
367              AFD_POLL_INFO* poll_info,
368              IO_STATUS_BLOCK* io_status_block) {
369   NTSTATUS status;
370 
371   /* Blocking operation is not supported. */
372   assert(io_status_block != NULL);
373 
374   io_status_block->Status = STATUS_PENDING;
375   status = NtDeviceIoControlFile(afd_device_handle,
376                                  NULL,
377                                  NULL,
378                                  io_status_block,
379                                  io_status_block,
380                                  IOCTL_AFD_POLL,
381                                  poll_info,
382                                  sizeof *poll_info,
383                                  poll_info,
384                                  sizeof *poll_info);
385 
386   if (status == STATUS_SUCCESS)
387     return 0;
388   else if (status == STATUS_PENDING)
389     return_set_error(-1, ERROR_IO_PENDING);
390   else
391     return_set_error(-1, RtlNtStatusToDosError(status));
392 }
393 
394 int afd_cancel_poll(HANDLE afd_device_handle,
395                     IO_STATUS_BLOCK* io_status_block) {
396   NTSTATUS cancel_status;
397   IO_STATUS_BLOCK cancel_iosb;
398 
399   /* If the poll operation has already completed or has been cancelled earlier,
400    * there's nothing left for us to do. */
401   if (io_status_block->Status != STATUS_PENDING)
402     return 0;
403 
404   cancel_status =
405       NtCancelIoFileEx(afd_device_handle, io_status_block, &cancel_iosb);
406 
407   /* NtCancelIoFileEx() may return STATUS_NOT_FOUND if the operation completed
408    * just before calling NtCancelIoFileEx(). This is not an error. */
409   if (cancel_status == STATUS_SUCCESS || cancel_status == STATUS_NOT_FOUND)
410     return 0;
411   else
412     return_set_error(-1, RtlNtStatusToDosError(cancel_status));
413 }
414 
415 WEPOLL_INTERNAL int epoll_global_init(void);
416 
417 WEPOLL_INTERNAL int init(void);
418 
419 typedef struct port_state port_state_t;
420 typedef struct queue queue_t;
421 typedef struct sock_state sock_state_t;
422 typedef struct ts_tree_node ts_tree_node_t;
423 
424 WEPOLL_INTERNAL port_state_t* port_new(HANDLE* iocp_handle_out);
425 WEPOLL_INTERNAL int port_close(port_state_t* port_state);
426 WEPOLL_INTERNAL int port_delete(port_state_t* port_state);
427 
428 WEPOLL_INTERNAL int port_wait(port_state_t* port_state,
429                               struct epoll_event* events,
430                               int maxevents,
431                               int timeout);
432 
433 WEPOLL_INTERNAL int port_ctl(port_state_t* port_state,
434                              int op,
435                              SOCKET sock,
436                              struct epoll_event* ev);
437 
438 WEPOLL_INTERNAL int port_register_socket(port_state_t* port_state,
439                                          sock_state_t* sock_state,
440                                          SOCKET socket);
441 WEPOLL_INTERNAL void port_unregister_socket(port_state_t* port_state,
442                                             sock_state_t* sock_state);
443 WEPOLL_INTERNAL sock_state_t* port_find_socket(port_state_t* port_state,
444                                                SOCKET socket);
445 
446 WEPOLL_INTERNAL void port_request_socket_update(port_state_t* port_state,
447                                                 sock_state_t* sock_state);
448 WEPOLL_INTERNAL void port_cancel_socket_update(port_state_t* port_state,
449                                                sock_state_t* sock_state);
450 
451 WEPOLL_INTERNAL void port_add_deleted_socket(port_state_t* port_state,
452                                              sock_state_t* sock_state);
453 WEPOLL_INTERNAL void port_remove_deleted_socket(port_state_t* port_state,
454                                                 sock_state_t* sock_state);
455 
456 WEPOLL_INTERNAL HANDLE port_get_iocp_handle(port_state_t* port_state);
457 WEPOLL_INTERNAL queue_t* port_get_poll_group_queue(port_state_t* port_state);
458 
459 WEPOLL_INTERNAL port_state_t* port_state_from_handle_tree_node(
460     ts_tree_node_t* tree_node);
461 WEPOLL_INTERNAL ts_tree_node_t* port_state_to_handle_tree_node(
462     port_state_t* port_state);
463 
464 /* The reflock is a special kind of lock that normally prevents a chunk of
465  * memory from being freed, but does allow the chunk of memory to eventually be
466  * released in a coordinated fashion.
467  *
468  * Under normal operation, threads increase and decrease the reference count;
469  * both operations are wait-free.
470  *
471  * Exactly once during the reflock's lifecycle, a thread holding a reference to
472  * the lock may "destroy" the lock; this operation blocks until all other
473  * threads holding a reference to the lock have dereferenced it. After
474  * "destroy" returns, the calling thread may assume that no other threads have
475  * a reference to the lock.
476  *
477  * Attempting to lock or destroy a lock after reflock_unref_and_destroy() has
478  * been called is invalid and results in undefined behavior. Therefore the user
479  * should use another lock to guarantee that this can't happen.
480  */
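/* A minimal usage sketch (threads A and B are hypothetical):
 *
 *   reflock_t lock;
 *   reflock_init(&lock);
 *   reflock_ref(&lock);               // thread A takes a reference
 *   reflock_ref(&lock);               // thread B takes a reference
 *   reflock_unref(&lock);             // thread B drops its reference
 *   reflock_unref_and_destroy(&lock); // thread A: returns once no other
 *                                     // thread holds a reference
 */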
481 
482 typedef struct reflock {
483   volatile long state; /* 32-bit Interlocked APIs operate on `long` values. */
484 } reflock_t;
485 
486 WEPOLL_INTERNAL int reflock_global_init(void);
487 
488 WEPOLL_INTERNAL void reflock_init(reflock_t* reflock);
489 WEPOLL_INTERNAL void reflock_ref(reflock_t* reflock);
490 WEPOLL_INTERNAL void reflock_unref(reflock_t* reflock);
491 WEPOLL_INTERNAL void reflock_unref_and_destroy(reflock_t* reflock);
492 
493 #include <stdbool.h>
494 
495 /* N.b.: the tree functions do not set errno or LastError when they fail. Each
496  * of the API functions has at most one failure mode. It is up to the caller to
497  * set an appropriate error code when necessary. */
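/* For example, port_register_socket() below maps a failed tree_add() to
 * ERROR_ALREADY_EXISTS via return_set_error(). */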
498 
499 typedef struct tree tree_t;
500 typedef struct tree_node tree_node_t;
501 
502 typedef struct tree {
503   tree_node_t* root;
504 } tree_t;
505 
506 typedef struct tree_node {
507   tree_node_t* left;
508   tree_node_t* right;
509   tree_node_t* parent;
510   uintptr_t key;
511   bool red;
512 } tree_node_t;
513 
514 WEPOLL_INTERNAL void tree_init(tree_t* tree);
515 WEPOLL_INTERNAL void tree_node_init(tree_node_t* node);
516 
517 WEPOLL_INTERNAL int tree_add(tree_t* tree, tree_node_t* node, uintptr_t key);
518 WEPOLL_INTERNAL void tree_del(tree_t* tree, tree_node_t* node);
519 
520 WEPOLL_INTERNAL tree_node_t* tree_find(const tree_t* tree, uintptr_t key);
521 WEPOLL_INTERNAL tree_node_t* tree_root(const tree_t* tree);
522 
523 typedef struct ts_tree {
524   tree_t tree;
525   SRWLOCK lock;
526 } ts_tree_t;
527 
528 typedef struct ts_tree_node {
529   tree_node_t tree_node;
530   reflock_t reflock;
531 } ts_tree_node_t;
532 
533 WEPOLL_INTERNAL void ts_tree_init(ts_tree_t* rtl);
534 WEPOLL_INTERNAL void ts_tree_node_init(ts_tree_node_t* node);
535 
536 WEPOLL_INTERNAL int ts_tree_add(ts_tree_t* ts_tree,
537                                 ts_tree_node_t* node,
538                                 uintptr_t key);
539 
540 WEPOLL_INTERNAL ts_tree_node_t* ts_tree_del_and_ref(ts_tree_t* ts_tree,
541                                                     uintptr_t key);
542 WEPOLL_INTERNAL ts_tree_node_t* ts_tree_find_and_ref(ts_tree_t* ts_tree,
543                                                      uintptr_t key);
544 
545 WEPOLL_INTERNAL void ts_tree_node_unref(ts_tree_node_t* node);
546 WEPOLL_INTERNAL void ts_tree_node_unref_and_destroy(ts_tree_node_t* node);
547 
548 static ts_tree_t epoll__handle_tree;
549 
550 int epoll_global_init(void) {
551   ts_tree_init(&epoll__handle_tree);
552   return 0;
553 }
554 
555 static HANDLE epoll__create(void) {
556   port_state_t* port_state;
557   HANDLE ephnd;
558   ts_tree_node_t* tree_node;
559 
560   if (init() < 0)
561     return NULL;
562 
563   port_state = port_new(&ephnd);
564   if (port_state == NULL)
565     return NULL;
566 
567   tree_node = port_state_to_handle_tree_node(port_state);
568   if (ts_tree_add(&epoll__handle_tree, tree_node, (uintptr_t) ephnd) < 0) {
569     /* This should never happen. */
570     port_delete(port_state);
571     return_set_error(NULL, ERROR_ALREADY_EXISTS);
572   }
573 
574   return ephnd;
575 }
576 
577 HANDLE epoll_create(int size) {
578   if (size <= 0)
579     return_set_error(NULL, ERROR_INVALID_PARAMETER);
580 
581   return epoll__create();
582 }
583 
584 HANDLE epoll_create1(int flags) {
585   if (flags != 0)
586     return_set_error(NULL, ERROR_INVALID_PARAMETER);
587 
588   return epoll__create();
589 }
590 
591 int epoll_close(HANDLE ephnd) {
592   ts_tree_node_t* tree_node;
593   port_state_t* port_state;
594 
595   if (init() < 0)
596     return -1;
597 
598   tree_node = ts_tree_del_and_ref(&epoll__handle_tree, (uintptr_t) ephnd);
599   if (tree_node == NULL) {
600     err_set_win_error(ERROR_INVALID_PARAMETER);
601     goto err;
602   }
603 
604   port_state = port_state_from_handle_tree_node(tree_node);
605   port_close(port_state);
606 
607   ts_tree_node_unref_and_destroy(tree_node);
608 
609   return port_delete(port_state);
610 
611 err:
612   err_check_handle(ephnd);
613   return -1;
614 }
615 
616 int epoll_ctl(HANDLE ephnd, int op, SOCKET sock, struct epoll_event* ev) {
617   ts_tree_node_t* tree_node;
618   port_state_t* port_state;
619   int r;
620 
621   if (init() < 0)
622     return -1;
623 
624   tree_node = ts_tree_find_and_ref(&epoll__handle_tree, (uintptr_t) ephnd);
625   if (tree_node == NULL) {
626     err_set_win_error(ERROR_INVALID_PARAMETER);
627     goto err;
628   }
629 
630   port_state = port_state_from_handle_tree_node(tree_node);
631   r = port_ctl(port_state, op, sock, ev);
632 
633   ts_tree_node_unref(tree_node);
634 
635   if (r < 0)
636     goto err;
637 
638   return 0;
639 
640 err:
641   /* On Linux, in the case of epoll_ctl(), EBADF takes priority over other
642    * errors. Wepoll mimics this behavior. */
643   err_check_handle(ephnd);
644   err_check_handle((HANDLE) sock);
645   return -1;
646 }
647 
648 int epoll_wait(HANDLE ephnd,
649                struct epoll_event* events,
650                int maxevents,
651                int timeout) {
652   ts_tree_node_t* tree_node;
653   port_state_t* port_state;
654   int num_events;
655 
656   if (maxevents <= 0)
657     return_set_error(-1, ERROR_INVALID_PARAMETER);
658 
659   if (init() < 0)
660     return -1;
661 
662   tree_node = ts_tree_find_and_ref(&epoll__handle_tree, (uintptr_t) ephnd);
663   if (tree_node == NULL) {
664     err_set_win_error(ERROR_INVALID_PARAMETER);
665     goto err;
666   }
667 
668   port_state = port_state_from_handle_tree_node(tree_node);
669   num_events = port_wait(port_state, events, maxevents, timeout);
670 
671   ts_tree_node_unref(tree_node);
672 
673   if (num_events < 0)
674     goto err;
675 
676   return num_events;
677 
678 err:
679   err_check_handle(ephnd);
680   return -1;
681 }
682 
683 #include <errno.h>
684 
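/* Maps Win32 and Winsock error codes to the closest POSIX errno values; the
 * list is expanded into the switch in err__map_win_error_to_errno() below. */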
685 #define ERR__ERRNO_MAPPINGS(X)               \
686   X(ERROR_ACCESS_DENIED, EACCES)             \
687   X(ERROR_ALREADY_EXISTS, EEXIST)            \
688   X(ERROR_BAD_COMMAND, EACCES)               \
689   X(ERROR_BAD_EXE_FORMAT, ENOEXEC)           \
690   X(ERROR_BAD_LENGTH, EACCES)                \
691   X(ERROR_BAD_NETPATH, ENOENT)               \
692   X(ERROR_BAD_NET_NAME, ENOENT)              \
693   X(ERROR_BAD_NET_RESP, ENETDOWN)            \
694   X(ERROR_BAD_PATHNAME, ENOENT)              \
695   X(ERROR_BROKEN_PIPE, EPIPE)                \
696   X(ERROR_CANNOT_MAKE, EACCES)               \
697   X(ERROR_COMMITMENT_LIMIT, ENOMEM)          \
698   X(ERROR_CONNECTION_ABORTED, ECONNABORTED)  \
699   X(ERROR_CONNECTION_ACTIVE, EISCONN)        \
700   X(ERROR_CONNECTION_REFUSED, ECONNREFUSED)  \
701   X(ERROR_CRC, EACCES)                       \
702   X(ERROR_DIR_NOT_EMPTY, ENOTEMPTY)          \
703   X(ERROR_DISK_FULL, ENOSPC)                 \
704   X(ERROR_DUP_NAME, EADDRINUSE)              \
705   X(ERROR_FILENAME_EXCED_RANGE, ENOENT)      \
706   X(ERROR_FILE_NOT_FOUND, ENOENT)            \
707   X(ERROR_GEN_FAILURE, EACCES)               \
708   X(ERROR_GRACEFUL_DISCONNECT, EPIPE)        \
709   X(ERROR_HOST_DOWN, EHOSTUNREACH)           \
710   X(ERROR_HOST_UNREACHABLE, EHOSTUNREACH)    \
711   X(ERROR_INSUFFICIENT_BUFFER, EFAULT)       \
712   X(ERROR_INVALID_ADDRESS, EADDRNOTAVAIL)    \
713   X(ERROR_INVALID_FUNCTION, EINVAL)          \
714   X(ERROR_INVALID_HANDLE, EBADF)             \
715   X(ERROR_INVALID_NETNAME, EADDRNOTAVAIL)    \
716   X(ERROR_INVALID_PARAMETER, EINVAL)         \
717   X(ERROR_INVALID_USER_BUFFER, EMSGSIZE)     \
718   X(ERROR_IO_PENDING, EINPROGRESS)           \
719   X(ERROR_LOCK_VIOLATION, EACCES)            \
720   X(ERROR_MORE_DATA, EMSGSIZE)               \
721   X(ERROR_NETNAME_DELETED, ECONNABORTED)     \
722   X(ERROR_NETWORK_ACCESS_DENIED, EACCES)     \
723   X(ERROR_NETWORK_BUSY, ENETDOWN)            \
724   X(ERROR_NETWORK_UNREACHABLE, ENETUNREACH)  \
725   X(ERROR_NOACCESS, EFAULT)                  \
726   X(ERROR_NONPAGED_SYSTEM_RESOURCES, ENOMEM) \
727   X(ERROR_NOT_ENOUGH_MEMORY, ENOMEM)         \
728   X(ERROR_NOT_ENOUGH_QUOTA, ENOMEM)          \
729   X(ERROR_NOT_FOUND, ENOENT)                 \
730   X(ERROR_NOT_LOCKED, EACCES)                \
731   X(ERROR_NOT_READY, EACCES)                 \
732   X(ERROR_NOT_SAME_DEVICE, EXDEV)            \
733   X(ERROR_NOT_SUPPORTED, ENOTSUP)            \
734   X(ERROR_NO_MORE_FILES, ENOENT)             \
735   X(ERROR_NO_SYSTEM_RESOURCES, ENOMEM)       \
736   X(ERROR_OPERATION_ABORTED, EINTR)          \
737   X(ERROR_OUT_OF_PAPER, EACCES)              \
738   X(ERROR_PAGED_SYSTEM_RESOURCES, ENOMEM)    \
739   X(ERROR_PAGEFILE_QUOTA, ENOMEM)            \
740   X(ERROR_PATH_NOT_FOUND, ENOENT)            \
741   X(ERROR_PIPE_NOT_CONNECTED, EPIPE)         \
742   X(ERROR_PORT_UNREACHABLE, ECONNRESET)      \
743   X(ERROR_PROTOCOL_UNREACHABLE, ENETUNREACH) \
744   X(ERROR_REM_NOT_LIST, ECONNREFUSED)        \
745   X(ERROR_REQUEST_ABORTED, EINTR)            \
746   X(ERROR_REQ_NOT_ACCEP, EWOULDBLOCK)        \
747   X(ERROR_SECTOR_NOT_FOUND, EACCES)          \
748   X(ERROR_SEM_TIMEOUT, ETIMEDOUT)            \
749   X(ERROR_SHARING_VIOLATION, EACCES)         \
750   X(ERROR_TOO_MANY_NAMES, ENOMEM)            \
751   X(ERROR_TOO_MANY_OPEN_FILES, EMFILE)       \
752   X(ERROR_UNEXP_NET_ERR, ECONNABORTED)       \
753   X(ERROR_WAIT_NO_CHILDREN, ECHILD)          \
754   X(ERROR_WORKING_SET_QUOTA, ENOMEM)         \
755   X(ERROR_WRITE_PROTECT, EACCES)             \
756   X(ERROR_WRONG_DISK, EACCES)                \
757   X(WSAEACCES, EACCES)                       \
758   X(WSAEADDRINUSE, EADDRINUSE)               \
759   X(WSAEADDRNOTAVAIL, EADDRNOTAVAIL)         \
760   X(WSAEAFNOSUPPORT, EAFNOSUPPORT)           \
761   X(WSAECONNABORTED, ECONNABORTED)           \
762   X(WSAECONNREFUSED, ECONNREFUSED)           \
763   X(WSAECONNRESET, ECONNRESET)               \
764   X(WSAEDISCON, EPIPE)                       \
765   X(WSAEFAULT, EFAULT)                       \
766   X(WSAEHOSTDOWN, EHOSTUNREACH)              \
767   X(WSAEHOSTUNREACH, EHOSTUNREACH)           \
768   X(WSAEINPROGRESS, EBUSY)                   \
769   X(WSAEINTR, EINTR)                         \
770   X(WSAEINVAL, EINVAL)                       \
771   X(WSAEISCONN, EISCONN)                     \
772   X(WSAEMSGSIZE, EMSGSIZE)                   \
773   X(WSAENETDOWN, ENETDOWN)                   \
774   X(WSAENETRESET, EHOSTUNREACH)              \
775   X(WSAENETUNREACH, ENETUNREACH)             \
776   X(WSAENOBUFS, ENOMEM)                      \
777   X(WSAENOTCONN, ENOTCONN)                   \
778   X(WSAENOTSOCK, ENOTSOCK)                   \
779   X(WSAEOPNOTSUPP, EOPNOTSUPP)               \
780   X(WSAEPROCLIM, ENOMEM)                     \
781   X(WSAESHUTDOWN, EPIPE)                     \
782   X(WSAETIMEDOUT, ETIMEDOUT)                 \
783   X(WSAEWOULDBLOCK, EWOULDBLOCK)             \
784   X(WSANOTINITIALISED, ENETDOWN)             \
785   X(WSASYSNOTREADY, ENETDOWN)                \
786   X(WSAVERNOTSUPPORTED, ENOSYS)
787 
788 static errno_t err__map_win_error_to_errno(DWORD error) {
789   switch (error) {
790 #define X(error_sym, errno_sym) \
791   case error_sym:               \
792     return errno_sym;
793     ERR__ERRNO_MAPPINGS(X)
794 #undef X
795   }
796   return EINVAL;
797 }
798 
799 void err_map_win_error(void) {
800   errno = err__map_win_error_to_errno(GetLastError());
801 }
802 
803 void err_set_win_error(DWORD error) {
804   SetLastError(error);
805   errno = err__map_win_error_to_errno(error);
806 }
807 
808 int err_check_handle(HANDLE handle) {
809   DWORD flags;
810 
811   /* GetHandleInformation() succeeds when passed INVALID_HANDLE_VALUE, so check
812    * for this condition explicitly. */
813   if (handle == INVALID_HANDLE_VALUE)
814     return_set_error(-1, ERROR_INVALID_HANDLE);
815 
816   if (!GetHandleInformation(handle, &flags))
817     return_map_error(-1);
818 
819   return 0;
820 }
821 
822 #include <stddef.h>
823 
824 #define array_count(a) (sizeof(a) / (sizeof((a)[0])))
825 
826 #define container_of(ptr, type, member) \
827   ((type*) ((uintptr_t) (ptr) - offsetof(type, member)))
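/* container_of() recovers a pointer to the enclosing struct from a pointer to
 * one of its members; e.g. poll_group_from_queue_node() below uses
 * container_of(queue_node, poll_group_t, queue_node). */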
828 
829 #define unused_var(v) ((void) (v))
830 
831 /* Polyfill `inline` for older versions of msvc (up to Visual Studio 2013) */
832 #if defined(_MSC_VER) && _MSC_VER < 1900
833 #define inline __inline
834 #endif
835 
836 WEPOLL_INTERNAL int ws_global_init(void);
837 WEPOLL_INTERNAL SOCKET ws_get_base_socket(SOCKET socket);
838 
839 static bool init__done = false;
840 static INIT_ONCE init__once = INIT_ONCE_STATIC_INIT;
841 
842 static BOOL CALLBACK init__once_callback(INIT_ONCE* once,
843                                          void* parameter,
844                                          void** context) {
845   unused_var(once);
846   unused_var(parameter);
847   unused_var(context);
848 
849   /* N.b. that initialization order matters here. */
850   if (ws_global_init() < 0 || nt_global_init() < 0 ||
851       reflock_global_init() < 0 || epoll_global_init() < 0)
852     return FALSE;
853 
854   init__done = true;
855   return TRUE;
856 }
857 
858 int init(void) {
859   if (!init__done &&
860       !InitOnceExecuteOnce(&init__once, init__once_callback, NULL, NULL))
861     /* `InitOnceExecuteOnce()` itself is infallible, and it doesn't set any
862      * error code when the once-callback returns FALSE. We return -1 here to
863      * indicate that global initialization failed; the failing init function is
864      * responsible for setting `errno` and calling `SetLastError()`. */
865     return -1;
866 
867   return 0;
868 }
869 
870 /* Set up a workaround for the following problem:
871  *   FARPROC addr = GetProcAddress(...);
872  *   MY_FUNC func = (MY_FUNC) addr;          <-- GCC 8 warning/error.
873  *   MY_FUNC func = (MY_FUNC) (void*) addr;  <-- MSVC  warning/error.
874  * To compile cleanly with either compiler, do casts with this "bridge" type:
875  *   MY_FUNC func = (MY_FUNC) (nt__fn_ptr_cast_t) addr; */
876 #ifdef __GNUC__
877 typedef void* nt__fn_ptr_cast_t;
878 #else
879 typedef FARPROC nt__fn_ptr_cast_t;
880 #endif
881 
882 #define X(return_type, attributes, name, parameters) \
883   WEPOLL_INTERNAL return_type(attributes* name) parameters = NULL;
884 NT_NTDLL_IMPORT_LIST(X)
885 #undef X
886 
887 int nt_global_init(void) {
888   HMODULE ntdll;
889   FARPROC fn_ptr;
890 
891   ntdll = GetModuleHandleW(L"ntdll.dll");
892   if (ntdll == NULL)
893     return -1;
894 
895 #define X(return_type, attributes, name, parameters) \
896   fn_ptr = GetProcAddress(ntdll, #name);             \
897   if (fn_ptr == NULL)                                \
898     return -1;                                       \
899   name = (return_type(attributes*) parameters)(nt__fn_ptr_cast_t) fn_ptr;
900   NT_NTDLL_IMPORT_LIST(X)
901 #undef X
902 
903   return 0;
904 }
905 
906 #include <string.h>
907 
908 typedef struct poll_group poll_group_t;
909 
910 typedef struct queue_node queue_node_t;
911 
912 WEPOLL_INTERNAL poll_group_t* poll_group_acquire(port_state_t* port);
913 WEPOLL_INTERNAL void poll_group_release(poll_group_t* poll_group);
914 
915 WEPOLL_INTERNAL void poll_group_delete(poll_group_t* poll_group);
916 
917 WEPOLL_INTERNAL poll_group_t* poll_group_from_queue_node(
918     queue_node_t* queue_node);
919 WEPOLL_INTERNAL HANDLE
920     poll_group_get_afd_device_handle(poll_group_t* poll_group);
921 
922 typedef struct queue_node {
923   queue_node_t* prev;
924   queue_node_t* next;
925 } queue_node_t;
926 
927 typedef struct queue {
928   queue_node_t head;
929 } queue_t;
930 
931 WEPOLL_INTERNAL void queue_init(queue_t* queue);
932 WEPOLL_INTERNAL void queue_node_init(queue_node_t* node);
933 
934 WEPOLL_INTERNAL queue_node_t* queue_first(const queue_t* queue);
935 WEPOLL_INTERNAL queue_node_t* queue_last(const queue_t* queue);
936 
937 WEPOLL_INTERNAL void queue_prepend(queue_t* queue, queue_node_t* node);
938 WEPOLL_INTERNAL void queue_append(queue_t* queue, queue_node_t* node);
939 WEPOLL_INTERNAL void queue_move_to_start(queue_t* queue, queue_node_t* node);
940 WEPOLL_INTERNAL void queue_move_to_end(queue_t* queue, queue_node_t* node);
941 WEPOLL_INTERNAL void queue_remove(queue_node_t* node);
942 
943 WEPOLL_INTERNAL bool queue_is_empty(const queue_t* queue);
944 WEPOLL_INTERNAL bool queue_is_enqueued(const queue_node_t* node);
945 
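/* Sockets do not each get a private \Device\Afd handle; they share one through
 * a poll group. Every poll group owns a single AFD device handle and serves at
 * most POLL_GROUP__MAX_GROUP_SIZE sockets at a time. */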
946 #define POLL_GROUP__MAX_GROUP_SIZE 32
947 
948 typedef struct poll_group {
949   port_state_t* port_state;
950   queue_node_t queue_node;
951   HANDLE afd_device_handle;
952   size_t group_size;
953 } poll_group_t;
954 
955 static poll_group_t* poll_group__new(port_state_t* port_state) {
956   HANDLE iocp_handle = port_get_iocp_handle(port_state);
957   queue_t* poll_group_queue = port_get_poll_group_queue(port_state);
958 
959   poll_group_t* poll_group = malloc(sizeof *poll_group);
960   if (poll_group == NULL)
961     return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY);
962 
963   memset(poll_group, 0, sizeof *poll_group);
964 
965   queue_node_init(&poll_group->queue_node);
966   poll_group->port_state = port_state;
967 
968   if (afd_create_device_handle(iocp_handle, &poll_group->afd_device_handle) <
969       0) {
970     free(poll_group);
971     return NULL;
972   }
973 
974   queue_append(poll_group_queue, &poll_group->queue_node);
975 
976   return poll_group;
977 }
978 
979 void poll_group_delete(poll_group_t* poll_group) {
980   assert(poll_group->group_size == 0);
981   CloseHandle(poll_group->afd_device_handle);
982   queue_remove(&poll_group->queue_node);
983   free(poll_group);
984 }
985 
986 poll_group_t* poll_group_from_queue_node(queue_node_t* queue_node) {
987   return container_of(queue_node, poll_group_t, queue_node);
988 }
989 
990 HANDLE poll_group_get_afd_device_handle(poll_group_t* poll_group) {
991   return poll_group->afd_device_handle;
992 }
993 
994 poll_group_t* poll_group_acquire(port_state_t* port_state) {
995   queue_t* poll_group_queue = port_get_poll_group_queue(port_state);
996   poll_group_t* poll_group =
997       !queue_is_empty(poll_group_queue)
998           ? container_of(
999                 queue_last(poll_group_queue), poll_group_t, queue_node)
1000           : NULL;
1001 
1002   if (poll_group == NULL ||
1003       poll_group->group_size >= POLL_GROUP__MAX_GROUP_SIZE)
1004     poll_group = poll_group__new(port_state);
1005   if (poll_group == NULL)
1006     return NULL;
1007 
1008   if (++poll_group->group_size == POLL_GROUP__MAX_GROUP_SIZE)
1009     queue_move_to_start(poll_group_queue, &poll_group->queue_node);
1010 
1011   return poll_group;
1012 }
1013 
1014 void poll_group_release(poll_group_t* poll_group) {
1015   port_state_t* port_state = poll_group->port_state;
1016   queue_t* poll_group_queue = port_get_poll_group_queue(port_state);
1017 
1018   poll_group->group_size--;
1019   assert(poll_group->group_size < POLL_GROUP__MAX_GROUP_SIZE);
1020 
1021   queue_move_to_end(poll_group_queue, &poll_group->queue_node);
1022 
1023   /* Poll groups are currently only freed when the epoll port is closed. */
1024 }
1025 
1026 WEPOLL_INTERNAL sock_state_t* sock_new(port_state_t* port_state,
1027                                        SOCKET socket);
1028 WEPOLL_INTERNAL void sock_delete(port_state_t* port_state,
1029                                  sock_state_t* sock_state);
1030 WEPOLL_INTERNAL void sock_force_delete(port_state_t* port_state,
1031                                        sock_state_t* sock_state);
1032 
1033 WEPOLL_INTERNAL int sock_set_event(port_state_t* port_state,
1034                                    sock_state_t* sock_state,
1035                                    const struct epoll_event* ev);
1036 
1037 WEPOLL_INTERNAL int sock_update(port_state_t* port_state,
1038                                 sock_state_t* sock_state);
1039 WEPOLL_INTERNAL int sock_feed_event(port_state_t* port_state,
1040                                     IO_STATUS_BLOCK* io_status_block,
1041                                     struct epoll_event* ev);
1042 
1043 WEPOLL_INTERNAL sock_state_t* sock_state_from_queue_node(
1044     queue_node_t* queue_node);
1045 WEPOLL_INTERNAL queue_node_t* sock_state_to_queue_node(
1046     sock_state_t* sock_state);
1047 WEPOLL_INTERNAL sock_state_t* sock_state_from_tree_node(
1048     tree_node_t* tree_node);
1049 WEPOLL_INTERNAL tree_node_t* sock_state_to_tree_node(sock_state_t* sock_state);
1050 
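/* port_wait() keeps up to this many OVERLAPPED_ENTRY completions on the stack;
 * larger `maxevents` values fall back to a heap-allocated buffer. */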
1051 #define PORT__MAX_ON_STACK_COMPLETIONS 256
1052 
1053 typedef struct port_state {
1054   HANDLE iocp_handle;
1055   tree_t sock_tree;
1056   queue_t sock_update_queue;
1057   queue_t sock_deleted_queue;
1058   queue_t poll_group_queue;
1059   ts_tree_node_t handle_tree_node;
1060   CRITICAL_SECTION lock;
1061   size_t active_poll_count;
1062 } port_state_t;
1063 
1064 static inline port_state_t* port__alloc(void) {
1065   port_state_t* port_state = malloc(sizeof *port_state);
1066   if (port_state == NULL)
1067     return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY);
1068 
1069   return port_state;
1070 }
1071 
1072 static inline void port__free(port_state_t* port) {
1073   assert(port != NULL);
1074   free(port);
1075 }
1076 
1077 static inline HANDLE port__create_iocp(void) {
1078   HANDLE iocp_handle =
1079       CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
1080   if (iocp_handle == NULL)
1081     return_map_error(NULL);
1082 
1083   return iocp_handle;
1084 }
1085 
1086 port_state_t* port_new(HANDLE* iocp_handle_out) {
1087   port_state_t* port_state;
1088   HANDLE iocp_handle;
1089 
1090   port_state = port__alloc();
1091   if (port_state == NULL)
1092     goto err1;
1093 
1094   iocp_handle = port__create_iocp();
1095   if (iocp_handle == NULL)
1096     goto err2;
1097 
1098   memset(port_state, 0, sizeof *port_state);
1099 
1100   port_state->iocp_handle = iocp_handle;
1101   tree_init(&port_state->sock_tree);
1102   queue_init(&port_state->sock_update_queue);
1103   queue_init(&port_state->sock_deleted_queue);
1104   queue_init(&port_state->poll_group_queue);
1105   ts_tree_node_init(&port_state->handle_tree_node);
1106   InitializeCriticalSection(&port_state->lock);
1107 
1108   *iocp_handle_out = iocp_handle;
1109   return port_state;
1110 
1111 err2:
1112   port__free(port_state);
1113 err1:
1114   return NULL;
1115 }
1116 
1117 static inline int port__close_iocp(port_state_t* port_state) {
1118   HANDLE iocp_handle = port_state->iocp_handle;
1119   port_state->iocp_handle = NULL;
1120 
1121   if (!CloseHandle(iocp_handle))
1122     return_map_error(-1);
1123 
1124   return 0;
1125 }
1126 
1127 int port_close(port_state_t* port_state) {
1128   int result;
1129 
1130   EnterCriticalSection(&port_state->lock);
1131   result = port__close_iocp(port_state);
1132   LeaveCriticalSection(&port_state->lock);
1133 
1134   return result;
1135 }
1136 
1137 int port_delete(port_state_t* port_state) {
1138   tree_node_t* tree_node;
1139   queue_node_t* queue_node;
1140 
1141   /* At this point the IOCP port should have been closed. */
1142   assert(port_state->iocp_handle == NULL);
1143 
1144   while ((tree_node = tree_root(&port_state->sock_tree)) != NULL) {
1145     sock_state_t* sock_state = sock_state_from_tree_node(tree_node);
1146     sock_force_delete(port_state, sock_state);
1147   }
1148 
1149   while ((queue_node = queue_first(&port_state->sock_deleted_queue)) != NULL) {
1150     sock_state_t* sock_state = sock_state_from_queue_node(queue_node);
1151     sock_force_delete(port_state, sock_state);
1152   }
1153 
1154   while ((queue_node = queue_first(&port_state->poll_group_queue)) != NULL) {
1155     poll_group_t* poll_group = poll_group_from_queue_node(queue_node);
1156     poll_group_delete(poll_group);
1157   }
1158 
1159   assert(queue_is_empty(&port_state->sock_update_queue));
1160 
1161   DeleteCriticalSection(&port_state->lock);
1162 
1163   port__free(port_state);
1164 
1165   return 0;
1166 }
1167 
1168 static int port__update_events(port_state_t* port_state) {
1169   queue_t* sock_update_queue = &port_state->sock_update_queue;
1170 
1171   /* Walk the queue, submitting new poll requests for every socket that needs
1172    * it. */
1173   while (!queue_is_empty(sock_update_queue)) {
1174     queue_node_t* queue_node = queue_first(sock_update_queue);
1175     sock_state_t* sock_state = sock_state_from_queue_node(queue_node);
1176 
1177     if (sock_update(port_state, sock_state) < 0)
1178       return -1;
1179 
1180     /* sock_update() removes the socket from the update queue. */
1181   }
1182 
1183   return 0;
1184 }
1185 
1186 static inline void port__update_events_if_polling(port_state_t* port_state) {
1187   if (port_state->active_poll_count > 0)
1188     port__update_events(port_state);
1189 }
1190 
1191 #ifdef NULL_OVERLAPPED_WAKEUPS_PATCH
1192 const int RECEIVED_WAKEUP_EVENT = 0xFEFEFEFE;
1193 #endif
1194 
1195 /*
1196  * port__feed_events has been modified based on:
1197  * - https://github.com/piscisaureus/wepoll/pull/20
1198  * - https://github.com/piscisaureus/wepoll/pull/20#issuecomment-677646507
1199  *
1200  * In the polling crate, notify() calls PostQueuedCompletionStatus() with a
1201  * null lpOverlapped parameter. The resulting completion, received via
1202  * GetQueuedCompletionStatusEx(), then also carries a null lpOverlapped. When
1203  * that happens, we bail out and signal that the caller should stop waiting.
1204  */
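/* A wakeup would be posted by the notifying thread roughly like this (sketch;
 * the actual call lives in the polling crate, not in this file):
 *
 *   // `iocp` is the I/O completion port handle backing the epoll port.
 *   PostQueuedCompletionStatus(iocp, 0, 0, NULL); // NULL lpOverlapped
 *
 * The NULL lpOverlapped surfaces below as a NULL IO_STATUS_BLOCK pointer. */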
1205 static inline int port__feed_events(port_state_t* port_state,
1206                                     struct epoll_event* epoll_events,
1207                                     OVERLAPPED_ENTRY* iocp_events,
1208                                     DWORD iocp_event_count) {
1209   int epoll_event_count = 0;
1210   DWORD i;
1211 #ifdef NULL_OVERLAPPED_WAKEUPS_PATCH
1212   bool received_wakeup_event = false;
1213 #endif /* NULL_OVERLAPPED_WAKEUPS_PATCH */
1214 
1215   for (i = 0; i < iocp_event_count; i++) {
1216     IO_STATUS_BLOCK* io_status_block =
1217         (IO_STATUS_BLOCK*) iocp_events[i].lpOverlapped;
1218     struct epoll_event* ev = &epoll_events[epoll_event_count];
1219 
1220 #ifdef NULL_OVERLAPPED_WAKEUPS_PATCH
1221     if (!io_status_block) {
1222       received_wakeup_event = true;
1223       continue;
1224     }
1225 #endif /* NULL_OVERLAPPED_WAKEUPS_PATCH */
1226 
1227     epoll_event_count += sock_feed_event(port_state, io_status_block, ev);
1228   }
1229 
1230 #ifdef NULL_OVERLAPPED_WAKEUPS_PATCH
1231   if (epoll_event_count == 0 && received_wakeup_event)
1232   {
1233     return RECEIVED_WAKEUP_EVENT;
1234   }
1235 #endif /* NULL_OVERLAPPED_WAKEUPS_PATCH */
1236 
1237   return epoll_event_count;
1238 }
1239 
1240 static inline int port__poll(port_state_t* port_state,
1241                              struct epoll_event* epoll_events,
1242                              OVERLAPPED_ENTRY* iocp_events,
1243                              DWORD maxevents,
1244                              DWORD timeout) {
1245   DWORD completion_count;
1246 
1247   if (port__update_events(port_state) < 0)
1248     return -1;
1249 
1250   port_state->active_poll_count++;
1251 
1252   LeaveCriticalSection(&port_state->lock);
1253 
1254   BOOL r = GetQueuedCompletionStatusEx(port_state->iocp_handle,
1255                                        iocp_events,
1256                                        maxevents,
1257                                        &completion_count,
1258                                        timeout,
1259                                        FALSE);
1260 
1261   EnterCriticalSection(&port_state->lock);
1262 
1263   port_state->active_poll_count--;
1264 
1265   if (!r)
1266     return_map_error(-1);
1267 
1268   return port__feed_events(
1269       port_state, epoll_events, iocp_events, completion_count);
1270 }
1271 
1272 int port_wait(port_state_t* port_state,
1273               struct epoll_event* events,
1274               int maxevents,
1275               int timeout) {
1276   OVERLAPPED_ENTRY stack_iocp_events[PORT__MAX_ON_STACK_COMPLETIONS];
1277   OVERLAPPED_ENTRY* iocp_events;
1278   uint64_t due = 0;
1279   DWORD gqcs_timeout;
1280   int result;
1281 
1282   /* Check whether `maxevents` is in range. */
1283   if (maxevents <= 0)
1284     return_set_error(-1, ERROR_INVALID_PARAMETER);
1285 
1286   /* Decide whether the IOCP completion list can live on the stack, or allocate
1287    * memory for it on the heap. */
1288   if ((size_t) maxevents <= array_count(stack_iocp_events)) {
1289     iocp_events = stack_iocp_events;
1290   } else if ((iocp_events =
1291                   malloc((size_t) maxevents * sizeof *iocp_events)) == NULL) {
1292     iocp_events = stack_iocp_events;
1293     maxevents = array_count(stack_iocp_events);
1294   }
1295 
1296   /* Compute the timeout for GetQueuedCompletionStatus, and the wait end
1297    * time, if the user specified a timeout other than zero or infinite. */
1298   if (timeout > 0) {
1299     due = GetTickCount64() + (uint64_t) timeout;
1300     gqcs_timeout = (DWORD) timeout;
1301   } else if (timeout == 0) {
1302     gqcs_timeout = 0;
1303   } else {
1304     gqcs_timeout = INFINITE;
1305   }
1306 
1307   EnterCriticalSection(&port_state->lock);
1308 
1309   /* Dequeue completion packets until either at least one interesting event
1310    * has been discovered, or the timeout is reached. */
1311   for (;;) {
1312     uint64_t now;
1313 
1314     result = port__poll(
1315         port_state, events, iocp_events, (DWORD) maxevents, gqcs_timeout);
1316 
1317 #ifdef NULL_OVERLAPPED_WAKEUPS_PATCH
1318     /* A "wakeup" event triggered by a null lpOverlapped param */
1319     if (result == RECEIVED_WAKEUP_EVENT) {
1320       result = 0;
1321       break;
1322     }
1323 #endif /* NULL_OVERLAPPED_WAKEUPS_PATCH */
1324 
1325     if (result < 0 || result > 0)
1326       break; /* Result, error, or time-out. */
1327 
1328     if (timeout < 0)
1329       continue; /* When timeout is negative, never time out. */
1330 
1331     /* Update time. */
1332     now = GetTickCount64();
1333 
1334     /* Do not allow the due time to be in the past. */
1335     if (now >= due) {
1336       SetLastError(WAIT_TIMEOUT);
1337       break;
1338     }
1339 
1340     /* Recompute time-out argument for GetQueuedCompletionStatus. */
1341     gqcs_timeout = (DWORD)(due - now);
1342   }
1343 
1344   port__update_events_if_polling(port_state);
1345 
1346   LeaveCriticalSection(&port_state->lock);
1347 
1348   if (iocp_events != stack_iocp_events)
1349     free(iocp_events);
1350 
1351   if (result >= 0)
1352     return result;
1353   else if (GetLastError() == WAIT_TIMEOUT)
1354     return 0;
1355   else
1356     return -1;
1357 }
1358 
1359 static inline int port__ctl_add(port_state_t* port_state,
1360                                 SOCKET sock,
1361                                 struct epoll_event* ev) {
1362   sock_state_t* sock_state = sock_new(port_state, sock);
1363   if (sock_state == NULL)
1364     return -1;
1365 
1366   if (sock_set_event(port_state, sock_state, ev) < 0) {
1367     sock_delete(port_state, sock_state);
1368     return -1;
1369   }
1370 
1371   port__update_events_if_polling(port_state);
1372 
1373   return 0;
1374 }
1375 
1376 static inline int port__ctl_mod(port_state_t* port_state,
1377                                 SOCKET sock,
1378                                 struct epoll_event* ev) {
1379   sock_state_t* sock_state = port_find_socket(port_state, sock);
1380   if (sock_state == NULL)
1381     return -1;
1382 
1383   if (sock_set_event(port_state, sock_state, ev) < 0)
1384     return -1;
1385 
1386   port__update_events_if_polling(port_state);
1387 
1388   return 0;
1389 }
1390 
1391 static inline int port__ctl_del(port_state_t* port_state, SOCKET sock) {
1392   sock_state_t* sock_state = port_find_socket(port_state, sock);
1393   if (sock_state == NULL)
1394     return -1;
1395 
1396   sock_delete(port_state, sock_state);
1397 
1398   return 0;
1399 }
1400 
1401 static inline int port__ctl_op(port_state_t* port_state,
1402                                int op,
1403                                SOCKET sock,
1404                                struct epoll_event* ev) {
1405   switch (op) {
1406     case EPOLL_CTL_ADD:
1407       return port__ctl_add(port_state, sock, ev);
1408     case EPOLL_CTL_MOD:
1409       return port__ctl_mod(port_state, sock, ev);
1410     case EPOLL_CTL_DEL:
1411       return port__ctl_del(port_state, sock);
1412     default:
1413       return_set_error(-1, ERROR_INVALID_PARAMETER);
1414   }
1415 }
1416 
1417 int port_ctl(port_state_t* port_state,
1418              int op,
1419              SOCKET sock,
1420              struct epoll_event* ev) {
1421   int result;
1422 
1423   EnterCriticalSection(&port_state->lock);
1424   result = port__ctl_op(port_state, op, sock, ev);
1425   LeaveCriticalSection(&port_state->lock);
1426 
1427   return result;
1428 }
1429 
1430 int port_register_socket(port_state_t* port_state,
1431                          sock_state_t* sock_state,
1432                          SOCKET socket) {
1433   if (tree_add(&port_state->sock_tree,
1434                sock_state_to_tree_node(sock_state),
1435                socket) < 0)
1436     return_set_error(-1, ERROR_ALREADY_EXISTS);
1437   return 0;
1438 }
1439 
1440 void port_unregister_socket(port_state_t* port_state,
1441                             sock_state_t* sock_state) {
1442   tree_del(&port_state->sock_tree, sock_state_to_tree_node(sock_state));
1443 }
1444 
1445 sock_state_t* port_find_socket(port_state_t* port_state, SOCKET socket) {
1446   tree_node_t* tree_node = tree_find(&port_state->sock_tree, socket);
1447   if (tree_node == NULL)
1448     return_set_error(NULL, ERROR_NOT_FOUND);
1449   return sock_state_from_tree_node(tree_node);
1450 }
1451 
1452 void port_request_socket_update(port_state_t* port_state,
1453                                 sock_state_t* sock_state) {
1454   if (queue_is_enqueued(sock_state_to_queue_node(sock_state)))
1455     return;
1456   queue_append(&port_state->sock_update_queue,
1457                sock_state_to_queue_node(sock_state));
1458 }
1459 
1460 void port_cancel_socket_update(port_state_t* port_state,
1461                                sock_state_t* sock_state) {
1462   unused_var(port_state);
1463   if (!queue_is_enqueued(sock_state_to_queue_node(sock_state)))
1464     return;
1465   queue_remove(sock_state_to_queue_node(sock_state));
1466 }
1467 
1468 void port_add_deleted_socket(port_state_t* port_state,
1469                              sock_state_t* sock_state) {
1470   if (queue_is_enqueued(sock_state_to_queue_node(sock_state)))
1471     return;
1472   queue_append(&port_state->sock_deleted_queue,
1473                sock_state_to_queue_node(sock_state));
1474 }
1475 
1476 void port_remove_deleted_socket(port_state_t* port_state,
1477                                 sock_state_t* sock_state) {
1478   unused_var(port_state);
1479   if (!queue_is_enqueued(sock_state_to_queue_node(sock_state)))
1480     return;
1481   queue_remove(sock_state_to_queue_node(sock_state));
1482 }
1483 
1484 HANDLE port_get_iocp_handle(port_state_t* port_state) {
1485   assert(port_state->iocp_handle != NULL);
1486   return port_state->iocp_handle;
1487 }
1488 
1489 queue_t* port_get_poll_group_queue(port_state_t* port_state) {
1490   return &port_state->poll_group_queue;
1491 }
1492 
1493 port_state_t* port_state_from_handle_tree_node(ts_tree_node_t* tree_node) {
1494   return container_of(tree_node, port_state_t, handle_tree_node);
1495 }
1496 
1497 ts_tree_node_t* port_state_to_handle_tree_node(port_state_t* port_state) {
1498   return &port_state->handle_tree_node;
1499 }
1500 
1501 void queue_init(queue_t* queue) {
1502   queue_node_init(&queue->head);
1503 }
1504 
1505 void queue_node_init(queue_node_t* node) {
1506   node->prev = node;
1507   node->next = node;
1508 }
1509 
1510 static inline void queue__detach_node(queue_node_t* node) {
1511   node->prev->next = node->next;
1512   node->next->prev = node->prev;
1513 }
1514 
1515 queue_node_t* queue_first(const queue_t* queue) {
1516   return !queue_is_empty(queue) ? queue->head.next : NULL;
1517 }
1518 
1519 queue_node_t* queue_last(const queue_t* queue) {
1520   return !queue_is_empty(queue) ? queue->head.prev : NULL;
1521 }
1522 
1523 void queue_prepend(queue_t* queue, queue_node_t* node) {
1524   node->next = queue->head.next;
1525   node->prev = &queue->head;
1526   node->next->prev = node;
1527   queue->head.next = node;
1528 }
1529 
1530 void queue_append(queue_t* queue, queue_node_t* node) {
1531   node->next = &queue->head;
1532   node->prev = queue->head.prev;
1533   node->prev->next = node;
1534   queue->head.prev = node;
1535 }
1536 
1537 void queue_move_to_start(queue_t* queue, queue_node_t* node) {
1538   queue__detach_node(node);
1539   queue_prepend(queue, node);
1540 }
1541 
1542 void queue_move_to_end(queue_t* queue, queue_node_t* node) {
1543   queue__detach_node(node);
1544   queue_append(queue, node);
1545 }
1546 
1547 void queue_remove(queue_node_t* node) {
1548   queue__detach_node(node);
1549   queue_node_init(node);
1550 }
1551 
1552 bool queue_is_empty(const queue_t* queue) {
1553   return !queue_is_enqueued(&queue->head);
1554 }
1555 
1556 bool queue_is_enqueued(const queue_node_t* node) {
1557   return node->prev != node;
1558 }
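/* Illustrative sketch (editorial, not part of the original source): the queue
 * above is an intrusive, circular, doubly-linked list. The `head` node is a
 * sentinel, so an empty queue is one whose head points at itself. Callers
 * embed a queue_node_t in their own struct and convert back with
 * container_of(), e.g. via sock_state_from_queue_node(). A typical drain
 * loop, assuming a hypothetical handle_update() helper, might look like:
 *
 *   queue_node_t* node;
 *   while ((node = queue_first(&port_state->sock_update_queue)) != NULL) {
 *     sock_state_t* sock_state = sock_state_from_queue_node(node);
 *     queue_remove(node);           // detach and re-initialize the node
 *     handle_update(sock_state);    // hypothetical per-socket work
 *   }
 */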
1559 
1560 #define REFLOCK__REF          ((long) 0x00000001UL)
1561 #define REFLOCK__REF_MASK     ((long) 0x0fffffffUL)
1562 #define REFLOCK__DESTROY      ((long) 0x10000000UL)
1563 #define REFLOCK__DESTROY_MASK ((long) 0xf0000000UL)
1564 #define REFLOCK__POISON       ((long) 0x300dead0UL)
1565 
1566 static HANDLE reflock__keyed_event = NULL;
1567 
1568 int reflock_global_init(void) {
1569   NTSTATUS status = NtCreateKeyedEvent(
1570       &reflock__keyed_event, KEYEDEVENT_ALL_ACCESS, NULL, 0);
1571   if (status != STATUS_SUCCESS)
1572     return_set_error(-1, RtlNtStatusToDosError(status));
1573   return 0;
1574 }
1575 
1576 void reflock_init(reflock_t* reflock) {
1577   reflock->state = 0;
1578 }
1579 
1580 static void reflock__signal_event(void* address) {
1581   NTSTATUS status =
1582       NtReleaseKeyedEvent(reflock__keyed_event, address, FALSE, NULL);
1583   if (status != STATUS_SUCCESS)
1584     abort();
1585 }
1586 
1587 static void reflock__await_event(void* address) {
1588   NTSTATUS status =
1589       NtWaitForKeyedEvent(reflock__keyed_event, address, FALSE, NULL);
1590   if (status != STATUS_SUCCESS)
1591     abort();
1592 }
1593 
1594 void reflock_ref(reflock_t* reflock) {
1595   long state = InterlockedAdd(&reflock->state, REFLOCK__REF);
1596 
1597   /* Verify that the counter didn't overflow and the lock isn't destroyed. */
1598   assert((state & REFLOCK__DESTROY_MASK) == 0);
1599   unused_var(state);
1600 }
1601 
1602 void reflock_unref(reflock_t* reflock) {
1603   long state = InterlockedAdd(&reflock->state, -REFLOCK__REF);
1604 
1605   /* Verify that the lock was referenced and not already destroyed. */
1606   assert((state & REFLOCK__DESTROY_MASK & ~REFLOCK__DESTROY) == 0);
1607 
1608   if (state == REFLOCK__DESTROY)
1609     reflock__signal_event(reflock);
1610 }
1611 
1612 void reflock_unref_and_destroy(reflock_t* reflock) {
1613   long state =
1614       InterlockedAdd(&reflock->state, REFLOCK__DESTROY - REFLOCK__REF);
1615   long ref_count = state & REFLOCK__REF_MASK;
1616 
1617   /* Verify that the lock was referenced and not already destroyed. */
1618   assert((state & REFLOCK__DESTROY_MASK) == REFLOCK__DESTROY);
1619 
1620   if (ref_count != 0)
1621     reflock__await_event(reflock);
1622 
1623   state = InterlockedExchange(&reflock->state, REFLOCK__POISON);
1624   assert(state == REFLOCK__DESTROY);
1625 }
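/* Editorial sketch (not part of the original source): the reflock packs a
 * reference count (REFLOCK__REF increments, masked by REFLOCK__REF_MASK) and
 * a destroy flag (REFLOCK__DESTROY) into one interlocked `long`. Readers pin
 * the object with reflock_ref()/reflock_unref(); exactly one thread calls
 * reflock_unref_and_destroy(), which sets the destroy bit, drops its own
 * reference, and blocks on the keyed event until the last reflock_unref()
 * wakes it. Roughly:
 *
 *   reflock_ref(&obj->reflock);               // thread A: pin the object
 *   ...use obj...
 *   reflock_unref(&obj->reflock);             // last unref wakes the destroyer
 *
 *   reflock_unref_and_destroy(&obj->reflock); // thread B: waits, then poisons
 *   free(obj);
 */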
1626 
1627 #define SOCK__KNOWN_EPOLL_EVENTS                                       \
1628   (EPOLLIN | EPOLLPRI | EPOLLOUT | EPOLLERR | EPOLLHUP | EPOLLRDNORM | \
1629    EPOLLRDBAND | EPOLLWRNORM | EPOLLWRBAND | EPOLLMSG | EPOLLRDHUP)
1630 
1631 typedef enum sock__poll_status {
1632   SOCK__POLL_IDLE = 0,
1633   SOCK__POLL_PENDING,
1634   SOCK__POLL_CANCELLED
1635 } sock__poll_status_t;
1636 
1637 typedef struct sock_state {
1638   IO_STATUS_BLOCK io_status_block;
1639   AFD_POLL_INFO poll_info;
1640   queue_node_t queue_node;
1641   tree_node_t tree_node;
1642   poll_group_t* poll_group;
1643   SOCKET base_socket;
1644   epoll_data_t user_data;
1645   uint32_t user_events;
1646   uint32_t pending_events;
1647   sock__poll_status_t poll_status;
1648   bool delete_pending;
1649 } sock_state_t;
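/* Editorial note (not part of the original source): `user_events` holds the
 * event mask requested by the user (with EPOLLERR/EPOLLHUP forced on, see
 * sock_set_event() below), while `pending_events` is the mask that the
 * currently outstanding AFD poll operation was submitted with. sock_update()
 * compares the two to decide whether the pending poll can be kept, must be
 * cancelled, or a new one must be started. */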
1650 
1651 static inline sock_state_t* sock__alloc(void) {
1652   sock_state_t* sock_state = malloc(sizeof *sock_state);
1653   if (sock_state == NULL)
1654     return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY);
1655   return sock_state;
1656 }
1657 
1658 static inline void sock__free(sock_state_t* sock_state) {
1659   assert(sock_state != NULL);
1660   free(sock_state);
1661 }
1662 
1663 static inline int sock__cancel_poll(sock_state_t* sock_state) {
1664   assert(sock_state->poll_status == SOCK__POLL_PENDING);
1665 
1666   if (afd_cancel_poll(poll_group_get_afd_device_handle(sock_state->poll_group),
1667                       &sock_state->io_status_block) < 0)
1668     return -1;
1669 
1670   sock_state->poll_status = SOCK__POLL_CANCELLED;
1671   sock_state->pending_events = 0;
1672   return 0;
1673 }
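/* Editorial note (not part of the original source): cancelling an AFD poll is
 * asynchronous. sock__cancel_poll() only requests cancellation and flips the
 * state to SOCK__POLL_CANCELLED; the completion packet still arrives later and
 * is handled by sock_feed_event(), which sees STATUS_CANCELLED (or a pending
 * delete) and resets the state to SOCK__POLL_IDLE. */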
1674 
1675 sock_state_t* sock_new(port_state_t* port_state, SOCKET socket) {
1676   SOCKET base_socket;
1677   poll_group_t* poll_group;
1678   sock_state_t* sock_state;
1679 
1680   if (socket == 0 || socket == INVALID_SOCKET)
1681     return_set_error(NULL, ERROR_INVALID_HANDLE);
1682 
1683   base_socket = ws_get_base_socket(socket);
1684   if (base_socket == INVALID_SOCKET)
1685     return NULL;
1686 
1687   poll_group = poll_group_acquire(port_state);
1688   if (poll_group == NULL)
1689     return NULL;
1690 
1691   sock_state = sock__alloc();
1692   if (sock_state == NULL)
1693     goto err1;
1694 
1695   memset(sock_state, 0, sizeof *sock_state);
1696 
1697   sock_state->base_socket = base_socket;
1698   sock_state->poll_group = poll_group;
1699 
1700   tree_node_init(&sock_state->tree_node);
1701   queue_node_init(&sock_state->queue_node);
1702 
1703   if (port_register_socket(port_state, sock_state, socket) < 0)
1704     goto err2;
1705 
1706   return sock_state;
1707 
1708 err2:
1709   sock__free(sock_state);
1710 err1:
1711   poll_group_release(poll_group);
1712 
1713   return NULL;
1714 }
1715 
1716 static int sock__delete(port_state_t* port_state,
1717                         sock_state_t* sock_state,
1718                         bool force) {
1719   if (!sock_state->delete_pending) {
1720     if (sock_state->poll_status == SOCK__POLL_PENDING)
1721       sock__cancel_poll(sock_state);
1722 
1723     port_cancel_socket_update(port_state, sock_state);
1724     port_unregister_socket(port_state, sock_state);
1725 
1726     sock_state->delete_pending = true;
1727   }
1728 
1729   /* If the poll request still needs to complete, the sock_state object can't
1730    * be free()d yet. `sock_feed_event()` or `port_close()` will take care
1731    * of this later. */
1732   if (force || sock_state->poll_status == SOCK__POLL_IDLE) {
1733     /* Free the sock_state now. */
1734     port_remove_deleted_socket(port_state, sock_state);
1735     poll_group_release(sock_state->poll_group);
1736     sock__free(sock_state);
1737   } else {
1738     /* Free the socket later. */
1739     port_add_deleted_socket(port_state, sock_state);
1740   }
1741 
1742   return 0;
1743 }
1744 
1745 void sock_delete(port_state_t* port_state, sock_state_t* sock_state) {
1746   sock__delete(port_state, sock_state, false);
1747 }
1748 
1749 void sock_force_delete(port_state_t* port_state, sock_state_t* sock_state) {
1750   sock__delete(port_state, sock_state, true);
1751 }
1752 
1753 int sock_set_event(port_state_t* port_state,
1754                    sock_state_t* sock_state,
1755                    const struct epoll_event* ev) {
1756   /* EPOLLERR and EPOLLHUP are always reported, even when not requested by the
1757    * caller. However, they are disabled after an event has been reported for
1758    * a socket for which the EPOLLONESHOT flag was set. */
1759   uint32_t events = ev->events | EPOLLERR | EPOLLHUP;
1760 
1761   sock_state->user_events = events;
1762   sock_state->user_data = ev->data;
1763 
1764   if ((events & SOCK__KNOWN_EPOLL_EVENTS & ~sock_state->pending_events) != 0)
1765     port_request_socket_update(port_state, sock_state);
1766 
1767   return 0;
1768 }
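/* Illustrative sketch (editorial, not part of the original source): from the
 * public API side, the caller's epoll_event reaches this function when a
 * socket is added or modified with epoll_ctl(). Asking for a one-shot
 * readability notification might look like this (ephnd and sock are
 * illustrative variables):
 *
 *   struct epoll_event ev;
 *   ev.events = EPOLLIN | EPOLLONESHOT;   // EPOLLERR/EPOLLHUP are implied
 *   ev.data.sock = sock;                  // user data is echoed back verbatim
 *   epoll_ctl(ephnd, EPOLL_CTL_MOD, sock, &ev);
 *
 * After the event fires, sock_feed_event() clears user_events, so the caller
 * must re-arm the socket with another EPOLL_CTL_MOD. */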
1769 
1770 static inline DWORD sock__epoll_events_to_afd_events(uint32_t epoll_events) {
1771   /* Always monitor for AFD_POLL_LOCAL_CLOSE, which is triggered when the
1772    * socket is closed with closesocket() or CloseHandle(). */
1773   DWORD afd_events = AFD_POLL_LOCAL_CLOSE;
1774 
1775   if (epoll_events & (EPOLLIN | EPOLLRDNORM))
1776     afd_events |= AFD_POLL_RECEIVE | AFD_POLL_ACCEPT;
1777   if (epoll_events & (EPOLLPRI | EPOLLRDBAND))
1778     afd_events |= AFD_POLL_RECEIVE_EXPEDITED;
1779   if (epoll_events & (EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND))
1780     afd_events |= AFD_POLL_SEND;
1781   if (epoll_events & (EPOLLIN | EPOLLRDNORM | EPOLLRDHUP))
1782     afd_events |= AFD_POLL_DISCONNECT;
1783   if (epoll_events & EPOLLHUP)
1784     afd_events |= AFD_POLL_ABORT;
1785   if (epoll_events & EPOLLERR)
1786     afd_events |= AFD_POLL_CONNECT_FAIL;
1787 
1788   return afd_events;
1789 }
1790 
1791 static inline uint32_t sock__afd_events_to_epoll_events(DWORD afd_events) {
1792   uint32_t epoll_events = 0;
1793 
1794   if (afd_events & (AFD_POLL_RECEIVE | AFD_POLL_ACCEPT))
1795     epoll_events |= EPOLLIN | EPOLLRDNORM;
1796   if (afd_events & AFD_POLL_RECEIVE_EXPEDITED)
1797     epoll_events |= EPOLLPRI | EPOLLRDBAND;
1798   if (afd_events & AFD_POLL_SEND)
1799     epoll_events |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
1800   if (afd_events & AFD_POLL_DISCONNECT)
1801     epoll_events |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
1802   if (afd_events & AFD_POLL_ABORT)
1803     epoll_events |= EPOLLHUP;
1804   if (afd_events & AFD_POLL_CONNECT_FAIL)
1805     /* Linux reports all these events after connect() has failed. */
1806     epoll_events |=
1807         EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLRDNORM | EPOLLWRNORM | EPOLLRDHUP;
1808 
1809   return epoll_events;
1810 }
1811 
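/* Editorial note (not part of the original source): sock_update() drives a
 * small per-socket state machine over sock__poll_status_t:
 *
 *   IDLE --(afd_poll submitted)--> PENDING --(completion)--> IDLE
 *   PENDING --(sock__cancel_poll)--> CANCELLED --(completion)--> IDLE
 *
 * A pending poll is reused when it already covers every event the user asked
 * for; otherwise it is cancelled and resubmitted once its completion packet
 * comes back through sock_feed_event(). */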
1812 int sock_update(port_state_t* port_state, sock_state_t* sock_state) {
1813   assert(!sock_state->delete_pending);
1814 
1815   if ((sock_state->poll_status == SOCK__POLL_PENDING) &&
1816       (sock_state->user_events & SOCK__KNOWN_EPOLL_EVENTS &
1817        ~sock_state->pending_events) == 0) {
1818     /* All the events the user is interested in are already being monitored by
1819      * the pending poll operation. It might spuriously complete because of an
1820      * event that we're no longer interested in; when that happens we'll submit
1821      * a new poll operation with the updated event mask. */
1822 
1823   } else if (sock_state->poll_status == SOCK__POLL_PENDING) {
1824     /* A poll operation is already pending, but it's not monitoring for all the
1825      * events that the user is interested in. Therefore, cancel the pending
1826      * poll operation; when we receive its completion packet, a new poll
1827      * operation will be submitted with the correct event mask. */
1828     if (sock__cancel_poll(sock_state) < 0)
1829       return -1;
1830 
1831   } else if (sock_state->poll_status == SOCK__POLL_CANCELLED) {
1832     /* The poll operation has already been cancelled; we're still waiting for
1833      * it to return. For now, there's nothing that needs to be done. */
1834 
1835   } else if (sock_state->poll_status == SOCK__POLL_IDLE) {
1836     /* No poll operation is pending; start one. */
1837     sock_state->poll_info.Exclusive = FALSE;
1838     sock_state->poll_info.NumberOfHandles = 1;
1839     sock_state->poll_info.Timeout.QuadPart = INT64_MAX;
1840     sock_state->poll_info.Handles[0].Handle = (HANDLE) sock_state->base_socket;
1841     sock_state->poll_info.Handles[0].Status = 0;
1842     sock_state->poll_info.Handles[0].Events =
1843         sock__epoll_events_to_afd_events(sock_state->user_events);
1844 
1845     if (afd_poll(poll_group_get_afd_device_handle(sock_state->poll_group),
1846                  &sock_state->poll_info,
1847                  &sock_state->io_status_block) < 0) {
1848       switch (GetLastError()) {
1849         case ERROR_IO_PENDING:
1850           /* Overlapped poll operation in progress; this is expected. */
1851           break;
1852         case ERROR_INVALID_HANDLE:
1853           /* Socket closed; it'll be dropped from the epoll set. */
1854           return sock__delete(port_state, sock_state, false);
1855         default:
1856           /* Other errors are propagated to the caller. */
1857           return_map_error(-1);
1858       }
1859     }
1860 
1861     /* The poll request was successfully submitted. */
1862     sock_state->poll_status = SOCK__POLL_PENDING;
1863     sock_state->pending_events = sock_state->user_events;
1864 
1865   } else {
1866     /* Unreachable. */
1867     assert(false);
1868   }
1869 
1870   port_cancel_socket_update(port_state, sock_state);
1871   return 0;
1872 }
1873 
1874 int sock_feed_event(port_state_t* port_state,
1875                     IO_STATUS_BLOCK* io_status_block,
1876                     struct epoll_event* ev) {
1877   sock_state_t* sock_state =
1878       container_of(io_status_block, sock_state_t, io_status_block);
1879   AFD_POLL_INFO* poll_info = &sock_state->poll_info;
1880   uint32_t epoll_events = 0;
1881 
1882   sock_state->poll_status = SOCK__POLL_IDLE;
1883   sock_state->pending_events = 0;
1884 
1885   if (sock_state->delete_pending) {
1886     /* Socket has been deleted earlier and can now be freed. */
1887     return sock__delete(port_state, sock_state, false);
1888 
1889   } else if (io_status_block->Status == STATUS_CANCELLED) {
1890     /* The poll request was cancelled by CancelIoEx. */
1891 
1892   } else if (!NT_SUCCESS(io_status_block->Status)) {
1893     /* The overlapped request itself failed in an unexpected way. */
1894     epoll_events = EPOLLERR;
1895 
1896   } else if (poll_info->NumberOfHandles < 1) {
1897     /* This poll operation succeeded but didn't report any socket events. */
1898 
1899   } else if (poll_info->Handles[0].Events & AFD_POLL_LOCAL_CLOSE) {
1900     /* The poll operation reported that the socket was closed. */
1901     return sock__delete(port_state, sock_state, false);
1902 
1903   } else {
1904     /* Events related to our socket were reported. */
1905     epoll_events =
1906         sock__afd_events_to_epoll_events(poll_info->Handles[0].Events);
1907   }
1908 
1909   /* Requeue the socket so a new poll request will be submitted. */
1910   port_request_socket_update(port_state, sock_state);
1911 
1912   /* Filter out events that the user didn't ask for. */
1913   epoll_events &= sock_state->user_events;
1914 
1915   /* Return if there are no epoll events to report. */
1916   if (epoll_events == 0)
1917     return 0;
1918 
1919   /* If the socket has the EPOLLONESHOT flag set, unmonitor all events,
1920    * even EPOLLERR and EPOLLHUP. But always keep looking for closed sockets. */
1921   if (sock_state->user_events & EPOLLONESHOT)
1922     sock_state->user_events = 0;
1923 
1924   ev->data = sock_state->user_data;
1925   ev->events = epoll_events;
1926   return 1;
1927 }
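/* Editorial note (not part of the original source): the port's completion
 * handler receives the IO_STATUS_BLOCK pointer that was passed to afd_poll()
 * and hands it to sock_feed_event() above; container_of() then recovers the
 * owning sock_state_t because the status block is embedded directly in that
 * struct, so no separate lookup table is needed to route completions. */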
1928 
1929 sock_state_t* sock_state_from_queue_node(queue_node_t* queue_node) {
1930   return container_of(queue_node, sock_state_t, queue_node);
1931 }
1932 
1933 queue_node_t* sock_state_to_queue_node(sock_state_t* sock_state) {
1934   return &sock_state->queue_node;
1935 }
1936 
1937 sock_state_t* sock_state_from_tree_node(tree_node_t* tree_node) {
1938   return container_of(tree_node, sock_state_t, tree_node);
1939 }
1940 
1941 tree_node_t* sock_state_to_tree_node(sock_state_t* sock_state) {
1942   return &sock_state->tree_node;
1943 }
1944 
1945 void ts_tree_init(ts_tree_t* ts_tree) {
1946   tree_init(&ts_tree->tree);
1947   InitializeSRWLock(&ts_tree->lock);
1948 }
1949 
1950 void ts_tree_node_init(ts_tree_node_t* node) {
1951   tree_node_init(&node->tree_node);
1952   reflock_init(&node->reflock);
1953 }
1954 
1955 int ts_tree_add(ts_tree_t* ts_tree, ts_tree_node_t* node, uintptr_t key) {
1956   int r;
1957 
1958   AcquireSRWLockExclusive(&ts_tree->lock);
1959   r = tree_add(&ts_tree->tree, &node->tree_node, key);
1960   ReleaseSRWLockExclusive(&ts_tree->lock);
1961 
1962   return r;
1963 }
1964 
1965 static inline ts_tree_node_t* ts_tree__find_node(ts_tree_t* ts_tree,
1966                                                  uintptr_t key) {
1967   tree_node_t* tree_node = tree_find(&ts_tree->tree, key);
1968   if (tree_node == NULL)
1969     return NULL;
1970 
1971   return container_of(tree_node, ts_tree_node_t, tree_node);
1972 }
1973 
1974 ts_tree_node_t* ts_tree_del_and_ref(ts_tree_t* ts_tree, uintptr_t key) {
1975   ts_tree_node_t* ts_tree_node;
1976 
1977   AcquireSRWLockExclusive(&ts_tree->lock);
1978 
1979   ts_tree_node = ts_tree__find_node(ts_tree, key);
1980   if (ts_tree_node != NULL) {
1981     tree_del(&ts_tree->tree, &ts_tree_node->tree_node);
1982     reflock_ref(&ts_tree_node->reflock);
1983   }
1984 
1985   ReleaseSRWLockExclusive(&ts_tree->lock);
1986 
1987   return ts_tree_node;
1988 }
1989 
1990 ts_tree_node_t* ts_tree_find_and_ref(ts_tree_t* ts_tree, uintptr_t key) {
1991   ts_tree_node_t* ts_tree_node;
1992 
1993   AcquireSRWLockShared(&ts_tree->lock);
1994 
1995   ts_tree_node = ts_tree__find_node(ts_tree, key);
1996   if (ts_tree_node != NULL)
1997     reflock_ref(&ts_tree_node->reflock);
1998 
1999   ReleaseSRWLockShared(&ts_tree->lock);
2000 
2001   return ts_tree_node;
2002 }
2003 
2004 void ts_tree_node_unref(ts_tree_node_t* node) {
2005   reflock_unref(&node->reflock);
2006 }
2007 
2008 void ts_tree_node_unref_and_destroy(ts_tree_node_t* node) {
2009   reflock_unref_and_destroy(&node->reflock);
2010 }
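/* Illustrative sketch (editorial, not part of the original source): ts_tree
 * combines the red-black tree below with an SRW lock for structural changes
 * and a per-node reflock so a node stays usable after the lock is dropped.
 * The lookup pattern is roughly:
 *
 *   ts_tree_node_t* node = ts_tree_find_and_ref(&tree, key);
 *   if (node != NULL) {
 *     ...use container_of(node, ...) while holding the reference...
 *     ts_tree_node_unref(node);
 *   }
 *
 * Deletion goes through ts_tree_del_and_ref() followed by
 * ts_tree_node_unref_and_destroy(), which blocks until all other references
 * have been dropped. */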
2011 
2012 void tree_init(tree_t* tree) {
2013   memset(tree, 0, sizeof *tree);
2014 }
2015 
2016 void tree_node_init(tree_node_t* node) {
2017   memset(node, 0, sizeof *node);
2018 }
2019 
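/* Editorial note (not part of the original source): TREE__ROTATE(cis, trans)
 * below promotes node->trans into node's position and makes node its `cis`
 * child, so TREE__ROTATE(left, right) is a left rotation and
 * TREE__ROTATE(right, left) a right rotation. Generating both directions from
 * one macro keeps the insert/delete rebalancing code symmetric. */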
2020 #define TREE__ROTATE(cis, trans)   \
2021   tree_node_t* p = node;           \
2022   tree_node_t* q = node->trans;    \
2023   tree_node_t* parent = p->parent; \
2024                                    \
2025   if (parent) {                    \
2026     if (parent->left == p)         \
2027       parent->left = q;            \
2028     else                           \
2029       parent->right = q;           \
2030   } else {                         \
2031     tree->root = q;                \
2032   }                                \
2033                                    \
2034   q->parent = parent;              \
2035   p->parent = q;                   \
2036   p->trans = q->cis;               \
2037   if (p->trans)                    \
2038     p->trans->parent = p;          \
2039   q->cis = p;
2040 
2041 static inline void tree__rotate_left(tree_t* tree, tree_node_t* node) {
2042   TREE__ROTATE(left, right)
2043 }
2044 
2045 static inline void tree__rotate_right(tree_t* tree, tree_node_t* node) {
2046   TREE__ROTATE(right, left)
2047 }
2048 
2049 #define TREE__INSERT_OR_DESCEND(side) \
2050   if (parent->side) {                 \
2051     parent = parent->side;            \
2052   } else {                            \
2053     parent->side = node;              \
2054     break;                            \
2055   }
2056 
2057 #define TREE__REBALANCE_AFTER_INSERT(cis, trans) \
2058   tree_node_t* grandparent = parent->parent;     \
2059   tree_node_t* uncle = grandparent->trans;       \
2060                                                  \
2061   if (uncle && uncle->red) {                     \
2062     parent->red = uncle->red = false;            \
2063     grandparent->red = true;                     \
2064     node = grandparent;                          \
2065   } else {                                       \
2066     if (node == parent->trans) {                 \
2067       tree__rotate_##cis(tree, parent);          \
2068       node = parent;                             \
2069       parent = node->parent;                     \
2070     }                                            \
2071     parent->red = false;                         \
2072     grandparent->red = true;                     \
2073     tree__rotate_##trans(tree, grandparent);     \
2074   }
2075 
2076 int tree_add(tree_t* tree, tree_node_t* node, uintptr_t key) {
2077   tree_node_t* parent;
2078 
2079   parent = tree->root;
2080   if (parent) {
2081     for (;;) {
2082       if (key < parent->key) {
2083         TREE__INSERT_OR_DESCEND(left)
2084       } else if (key > parent->key) {
2085         TREE__INSERT_OR_DESCEND(right)
2086       } else {
2087         return -1;
2088       }
2089     }
2090   } else {
2091     tree->root = node;
2092   }
2093 
2094   node->key = key;
2095   node->left = node->right = NULL;
2096   node->parent = parent;
2097   node->red = true;
2098 
2099   for (; parent && parent->red; parent = node->parent) {
2100     if (parent == parent->parent->left) {
2101       TREE__REBALANCE_AFTER_INSERT(left, right)
2102     } else {
2103       TREE__REBALANCE_AFTER_INSERT(right, left)
2104     }
2105   }
2106   tree->root->red = false;
2107 
2108   return 0;
2109 }
2110 
2111 #define TREE__REBALANCE_AFTER_REMOVE(cis, trans)   \
2112   tree_node_t* sibling = parent->trans;            \
2113                                                    \
2114   if (sibling->red) {                              \
2115     sibling->red = false;                          \
2116     parent->red = true;                            \
2117     tree__rotate_##cis(tree, parent);              \
2118     sibling = parent->trans;                       \
2119   }                                                \
2120   if ((sibling->left && sibling->left->red) ||     \
2121       (sibling->right && sibling->right->red)) {   \
2122     if (!sibling->trans || !sibling->trans->red) { \
2123       sibling->cis->red = false;                   \
2124       sibling->red = true;                         \
2125       tree__rotate_##trans(tree, sibling);         \
2126       sibling = parent->trans;                     \
2127     }                                              \
2128     sibling->red = parent->red;                    \
2129     parent->red = sibling->trans->red = false;     \
2130     tree__rotate_##cis(tree, parent);              \
2131     node = tree->root;                             \
2132     break;                                         \
2133   }                                                \
2134   sibling->red = true;
2135 
2136 void tree_del(tree_t* tree, tree_node_t* node) {
2137   tree_node_t* parent = node->parent;
2138   tree_node_t* left = node->left;
2139   tree_node_t* right = node->right;
2140   tree_node_t* next;
2141   bool red;
2142 
2143   if (!left) {
2144     next = right;
2145   } else if (!right) {
2146     next = left;
2147   } else {
2148     next = right;
2149     while (next->left)
2150       next = next->left;
2151   }
2152 
2153   if (parent) {
2154     if (parent->left == node)
2155       parent->left = next;
2156     else
2157       parent->right = next;
2158   } else {
2159     tree->root = next;
2160   }
2161 
2162   if (left && right) {
2163     red = next->red;
2164     next->red = node->red;
2165     next->left = left;
2166     left->parent = next;
2167     if (next != right) {
2168       parent = next->parent;
2169       next->parent = node->parent;
2170       node = next->right;
2171       parent->left = node;
2172       next->right = right;
2173       right->parent = next;
2174     } else {
2175       next->parent = parent;
2176       parent = next;
2177       node = next->right;
2178     }
2179   } else {
2180     red = node->red;
2181     node = next;
2182   }
2183 
2184   if (node)
2185     node->parent = parent;
2186   if (red)
2187     return;
2188   if (node && node->red) {
2189     node->red = false;
2190     return;
2191   }
2192 
2193   do {
2194     if (node == tree->root)
2195       break;
2196     if (node == parent->left) {
2197       TREE__REBALANCE_AFTER_REMOVE(left, right)
2198     } else {
2199       TREE__REBALANCE_AFTER_REMOVE(right, left)
2200     }
2201     node = parent;
2202     parent = parent->parent;
2203   } while (!node->red);
2204 
2205   if (node)
2206     node->red = false;
2207 }
2208 
2209 tree_node_t* tree_find(const tree_t* tree, uintptr_t key) {
2210   tree_node_t* node = tree->root;
2211   while (node) {
2212     if (key < node->key)
2213       node = node->left;
2214     else if (key > node->key)
2215       node = node->right;
2216     else
2217       return node;
2218   }
2219   return NULL;
2220 }
2221 
2222 tree_node_t* tree_root(const tree_t* tree) {
2223   return tree->root;
2224 }
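/* Illustrative sketch (editorial, not part of the original source): this is a
 * plain intrusive red-black tree keyed by uintptr_t; wepoll uses SOCKET values
 * as keys (see port_register_socket()/port_find_socket() above). Typical use,
 * inside some function with a `socket` in scope, might look like:
 *
 *   tree_t tree;
 *   tree_node_t node;
 *   tree_init(&tree);
 *   tree_node_init(&node);
 *   if (tree_add(&tree, &node, (uintptr_t) socket) < 0)
 *     return -1;                        // key already present
 *   tree_node_t* found = tree_find(&tree, (uintptr_t) socket);
 */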
2225 
2226 #ifndef SIO_BSP_HANDLE_POLL
2227 #define SIO_BSP_HANDLE_POLL 0x4800001D
2228 #endif
2229 
2230 #ifndef SIO_BASE_HANDLE
2231 #define SIO_BASE_HANDLE 0x48000022
2232 #endif
2233 
2234 int ws_global_init(void) {
2235   int r;
2236   WSADATA wsa_data;
2237 
2238   r = WSAStartup(MAKEWORD(2, 2), &wsa_data);
2239   if (r != 0)
2240     return_set_error(-1, (DWORD) r);
2241 
2242   return 0;
2243 }
2244 
2245 static inline SOCKET ws__ioctl_get_bsp_socket(SOCKET socket, DWORD ioctl) {
2246   SOCKET bsp_socket;
2247   DWORD bytes;
2248 
2249   if (WSAIoctl(socket,
2250                ioctl,
2251                NULL,
2252                0,
2253                &bsp_socket,
2254                sizeof bsp_socket,
2255                &bytes,
2256                NULL,
2257                NULL) != SOCKET_ERROR)
2258     return bsp_socket;
2259   else
2260     return INVALID_SOCKET;
2261 }
2262 
2263 SOCKET ws_get_base_socket(SOCKET socket) {
2264   SOCKET base_socket;
2265   DWORD error;
2266 
2267   for (;;) {
2268     base_socket = ws__ioctl_get_bsp_socket(socket, SIO_BASE_HANDLE);
2269     if (base_socket != INVALID_SOCKET)
2270       return base_socket;
2271 
2272     error = GetLastError();
2273     if (error == WSAENOTSOCK)
2274       return_set_error(INVALID_SOCKET, error);
2275 
2276     /* Even though Microsoft documentation clearly states that LSPs should
2277      * never intercept the `SIO_BASE_HANDLE` ioctl [1], Komodia-based LSPs do
2278      * so anyway, breaking it, with the apparent intention of preventing LSP
2279      * bypass [2]. Fortunately they don't handle `SIO_BSP_HANDLE_POLL`, which
2280      * will at least let us obtain the socket associated with the next winsock
2281      * protocol chain entry. If this succeeds, loop around and call
2282      * `SIO_BASE_HANDLE` again with the returned BSP socket, to make sure that
2283      * we unwrap all layers and retrieve the actual base socket.
2284      *  [1] https://docs.microsoft.com/en-us/windows/win32/winsock/winsock-ioctls
2285      *  [2] https://www.komodia.com/newwiki/index.php?title=Komodia%27s_Redirector_bug_fixes#Version_2.2.2.6
2286      */
2287     base_socket = ws__ioctl_get_bsp_socket(socket, SIO_BSP_HANDLE_POLL);
2288     if (base_socket != INVALID_SOCKET && base_socket != socket)
2289       socket = base_socket;
2290     else
2291       return_set_error(INVALID_SOCKET, error);
2292   }
2293 }
2294