1import _thread as thread
2import contextlib
3import errno
4import functools
5import gc
6from io import BytesIO
7import os
8import re
9import select
10import socket
11import struct
12import sys
13import threading
14import time
15import unittest
16import warnings
17
18from waitress import compat, wasyncore as asyncore
19
TIMEOUT = 3
HAS_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HOST = "localhost"
HOSTv4 = "127.0.0.1"
HOSTv6 = "::1"

# Filename used for testing.  Jython disallows "@" in module names, so it
# gets a "$" prefix instead; the pid suffix keeps concurrent runs apart.
TESTFN = "{}_{}_tmp".format(
    "$test" if os.name == "java" else "@test",
    os.getpid(),
)
35
36
class DummyLogger:  # pragma: no cover
    """Stand-in logger that records every ``log`` call instead of emitting it."""

    def __init__(self):
        # (severity, message) tuples, in call order.
        self.messages = []

    def log(self, severity, message):
        """Remember *message* logged at *severity*."""
        self.messages += [(severity, message)]
43
44
class WarningsRecorder:  # pragma: no cover
    """View over the list yielded by ``warnings.catch_warnings(record=True)``.

    ``warnings`` exposes only the entries appended since construction or
    since the most recent ``reset()``.
    """

    def __init__(self, warnings_list):
        self._warnings = warnings_list
        # Index of the first entry not yet "seen".
        self._last = 0

    @property
    def warnings(self):
        """Entries recorded after the current watermark."""
        start = self._last
        return self._warnings[start:]

    def reset(self):
        """Move the watermark past everything recorded so far."""
        self._last = len(self._warnings)
60
61
def _filterwarnings(filters, quiet=False):  # pragma: no cover
    """Record warnings raised in a ``with`` body, then validate them.

    Generator meant to be driven by ``contextlib.contextmanager`` (as done by
    ``check_warnings``).  It yields a ``WarningsRecorder``; after the body
    finishes, every recorded warning must match some
    ``(message_regexp, category)`` pair in *filters* (regexp matched
    case-insensitively against the start of the message), otherwise
    ``AssertionError`` is raised.  Unless *quiet* is true, every filter must
    also have matched at least one warning.
    """
    # Clear the warning registry of the calling module so that warnings that
    # were already issued once get raised (and recorded) again.  When driven
    # through check_warnings, depth 2 is the frame holding the ``with``.
    registry = sys._getframe(2).f_globals.get("__warningregistry__")

    if registry:
        registry.clear()

    with warnings.catch_warnings(record=True) as recorded:
        # Record every warning.  test_warnings swaps the module, so look the
        # live one up through sys.modules.
        sys.modules["warnings"].simplefilter("always")
        yield WarningsRecorder(recorded)

    unhandled = list(recorded)
    missing = []

    for pattern, category in filters:
        matched_any = False
        still_unhandled = []

        for entry in unhandled:
            message = entry.message
            # Drop the entries this filter accounts for.
            if re.match(pattern, str(message), re.I) and issubclass(
                message.__class__, category
            ):
                matched_any = True
            else:
                still_unhandled.append(entry)

        unhandled = still_unhandled

        if not matched_any and not quiet:
            # This filter caught nothing
            missing.append((pattern, category.__name__))

    if unhandled:
        raise AssertionError("unhandled warning %s" % unhandled[0])

    if missing:
        raise AssertionError("filter (%r, %s) did not catch any warning" % missing[0])
104
105
@contextlib.contextmanager
def check_warnings(*filters, **kwargs):  # pragma: no cover
    """Context manager that captures warnings and checks them on exit.

    Positional arguments are 2-tuples ``("message regexp", WarningCategory)``.

    Keyword argument:
     - ``quiet``: when true, a filter that catches nothing is not an error.
       Defaults to True when no filters are given; otherwise it is left as
       passed (falsy when omitted).

    With no filters this behaves like
    ``check_warnings(("", Warning), quiet=True)``.
    """
    quiet = kwargs.get("quiet")

    if filters:
        # Returned (not ``yield from``-ed) so _filterwarnings' stack-depth
        # lookup of the caller's module still lines up.
        return _filterwarnings(filters, quiet)

    # No filters: accept any Warning and, for backward compatibility,
    # default to quiet mode.
    if quiet is None:
        quiet = True

    return _filterwarnings((("", Warning),), quiet)
131
132
def gc_collect():  # pragma: no cover
    """Collect garbage as thoroughly as practical.

    Runs ``gc.collect()`` three times, pausing 0.1 s after the first pass on
    "java" (Jython) platforms.  Outside CPython deallocation is not immediate,
    and even in CPython reference cycles delay it, so ``__del__`` methods and
    weakref callbacks can otherwise run later than a test expects.
    """
    gc.collect()

    if sys.platform.startswith("java"):
        time.sleep(0.1)

    for _ in range(2):
        gc.collect()
149
150
def threading_setup():  # pragma: no cover
    """Return a ``(thread count, None)`` snapshot for threading_cleanup()."""
    running = thread._count()
    return running, None
153
154
def threading_cleanup(*original_values):  # pragma: no cover
    """Wait for the thread count to return to a threading_setup() snapshot.

    Polls up to 100 times.  Between polls it sleeps 10 ms and forces garbage
    collection.  If the first poll does not match *original_values*, it sets
    the module global ``environment_altered`` to True and writes a one-line
    warning to stderr.  It gives up silently after the last poll.
    """
    global environment_altered

    _MAX_COUNT = 100

    for count in range(_MAX_COUNT):
        values = (thread._count(), None)

        if values == original_values:
            break

        if not count:
            # Display a warning at the first iteration.  The trailing newline
            # keeps it from running into whatever is written next.
            environment_altered = True
            sys.stderr.write(
                "Warning -- threading_cleanup() failed to cleanup "
                "%s threads\n" % (values[0] - original_values[0])
            )
            sys.stderr.flush()

        time.sleep(0.01)
        gc_collect()
179
180
def reap_threads(func):  # pragma: no cover
    """Decorator that waits for threads started by *func* to finish.

    Takes a threading_setup() snapshot before the call and runs
    threading_cleanup() afterwards, even when *func* raises.  Positional and
    keyword arguments are both forwarded to *func*.
    """

    @functools.wraps(func)
    def decorator(*args, **kwargs):
        key = threading_setup()
        try:
            return func(*args, **kwargs)
        finally:
            threading_cleanup(*key)

    return decorator
195
196
def join_thread(thread, timeout=30.0):  # pragma: no cover
    """Join *thread*; raise AssertionError if it is still alive afterwards.

    *timeout* is passed straight to ``Thread.join`` and echoed (one decimal)
    in the error message.
    """
    thread.join(timeout)

    if not thread.is_alive():
        return

    raise AssertionError("failed to join the thread in %.1f seconds" % timeout)
206
207
def bind_port(sock, host=HOST):  # pragma: no cover
    """Bind *sock* to an OS-chosen (ephemeral) port on *host* and return it.

    Letting the OS pick the port avoids clashes when many tests run at once,
    e.g. on a buildbot.  For AF_INET/SOCK_STREAM sockets, RuntimeError is
    raised if SO_REUSEADDR or SO_REUSEPORT is already enabled.  Tests should
    never set those on TCP/IP sockets; the only legitimate use is multicast
    testing with several UDP sockets.

    Where SO_EXCLUSIVEADDRUSE exists (i.e. Windows), it is enabled so nobody
    else can bind our host/port for the duration of the test.
    """
    is_tcp = sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM

    if is_tcp and hasattr(socket, "SO_REUSEADDR"):
        reuse_addr = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)

        if reuse_addr == 1:
            raise RuntimeError(
                "tests should never set the SO_REUSEADDR "
                "socket option on TCP/IP sockets!"
            )

    if is_tcp and hasattr(socket, "SO_REUSEPORT"):
        try:
            reuse_port = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT)
        except OSError:
            # The socket module was built against headers that define
            # SO_REUSEPORT, but the running kernel does not support it.
            reuse_port = None

        if reuse_port == 1:
            raise RuntimeError(
                "tests should never set the SO_REUSEPORT "
                "socket option on TCP/IP sockets!"
            )

    if is_tcp and hasattr(socket, "SO_EXCLUSIVEADDRUSE"):
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1)

    sock.bind((host, 0))
    return sock.getsockname()[1]
251
252
@contextlib.contextmanager
def closewrapper(sock):  # pragma: no cover
    """Yield *sock* unchanged and call its ``close()`` when the block exits."""
    with contextlib.closing(sock) as wrapped:
        yield wrapped
259
260
class dummysocket:  # pragma: no cover
    """Minimal fake socket with a fixed fileno/peername.

    ``close()`` and ``setblocking()`` only record what they were asked to do.
    """

    def __init__(self):
        self.closed = False

    def fileno(self):
        # Arbitrary constant descriptor number.
        return 42

    def getpeername(self):
        return "peername"

    def setblocking(self, yesno):
        # Remembered only; the attribute does not exist until first call.
        self.isblocking = yesno

    def close(self):
        self.closed = True
276
277
class dummychannel:  # pragma: no cover
    """Fake channel that owns a dummysocket and closes it on ``close()``."""

    def __init__(self):
        self.socket = dummysocket()

    def close(self):
        sock = self.socket
        sock.close()
284
285
class exitingdummy:  # pragma: no cover
    """Dispatcher stand-in whose every event handler raises asyncore.ExitNow."""

    def handle_read_event(self):
        raise asyncore.ExitNow()

    # The remaining event hooks behave exactly like the read hook.
    handle_write_event = handle_close = handle_expt_event = handle_read_event
296
297
class crashingdummy:
    """Dispatcher stand-in whose event handlers raise a plain ``Exception``.

    ``handle_error()`` records that it ran by setting ``error_handled``.
    """

    def __init__(self):
        self.error_handled = False

    def handle_read_event(self):
        raise Exception()

    # Every other event hook fails the same way.
    handle_write_event = handle_close = handle_expt_event = handle_read_event

    def handle_error(self):
        self.error_handled = True
311
312
313# used when testing senders; just collects what it gets until newline is sent
def capture_server(evt, buf, serv):  # pragma: no cover
    """Accept one connection on *serv* and copy what it sends into *buf*.

    Newline bytes are dropped from the captured data.  Reading stops when a
    newline arrives, the peer closes the connection, 200 reads have been made
    or 3 seconds have passed.  A ``socket.timeout`` from ``accept()`` is
    ignored.  On exit the accepted connection and *serv* are always closed
    and *evt* is always set.
    """
    try:
        serv.listen(0)
        conn, _addr = serv.accept()
    except socket.timeout:
        pass
    else:
        # ``with`` closes conn even if select()/recv() raise.
        with conn:
            n = 200
            start = time.time()

            while n > 0 and time.time() - start < 3.0:
                readable, _, _ = select.select([conn], [], [], 0.1)

                if readable:
                    n -= 1
                    data = conn.recv(10)

                    if not data:
                        # Peer closed: nothing more will arrive, so don't
                        # keep spinning on an always-readable socket.
                        break

                    # keep everything except for the newline terminator
                    buf.write(data.replace(b"\n", b""))

                    if b"\n" in data:
                        break
                time.sleep(0.01)
    finally:
        serv.close()
        evt.set()
341
342
def bind_unix_socket(sock, addr):  # pragma: no cover
    """Bind the AF_UNIX socket *sock* to the path *addr*.

    If the bind is refused with PermissionError, *sock* is closed and
    ``unittest.SkipTest`` is raised so the calling test is skipped rather
    than failed.
    """
    assert sock.family == socket.AF_UNIX

    try:
        sock.bind(addr)
    except PermissionError:
        # Treat an environment that refuses AF_UNIX binds as "unsupported".
        sock.close()
        skip_reason = "cannot bind AF_UNIX sockets"
        raise unittest.SkipTest(skip_reason)
351
352
def bind_af_aware(sock, addr):
    """Bind *sock* to *addr*, handling AF_UNIX sockets specially.

    For AF_UNIX sockets any existing file at *addr* is removed first and the
    bind goes through bind_unix_socket(); every other family uses a plain
    ``sock.bind(addr)``.
    """
    is_unix = HAS_UNIX_SOCKETS and sock.family == socket.AF_UNIX

    if not is_unix:
        sock.bind(addr)
        return

    # Make sure the path doesn't exist.
    unlink(addr)
    bind_unix_socket(sock, addr)
362
363
364if sys.platform.startswith("win"):  # pragma: no cover
365
366    def _waitfor(func, pathname, waitall=False):
367        # Perform the operation
368        func(pathname)
369        # Now setup the wait loop
370
371        if waitall:
372            dirname = pathname
373        else:
374            dirname, name = os.path.split(pathname)
375            dirname = dirname or "."
376        # Check for `pathname` to be removed from the filesystem.
377        # The exponential backoff of the timeout amounts to a total
378        # of ~1 second after which the deletion is probably an error
379        # anyway.
380        # Testing on an i7@4.3GHz shows that usually only 1 iteration is
381        # required when contention occurs.
382        timeout = 0.001
383
384        while timeout < 1.0:
385            # Note we are only testing for the existence of the file(s) in
386            # the contents of the directory regardless of any security or
387            # access rights.  If we have made it this far, we have sufficient
388            # permissions to do that much using Python's equivalent of the
389            # Windows API FindFirstFile.
390            # Other Windows APIs can fail or give incorrect results when
391            # dealing with files that are pending deletion.
392            L = os.listdir(dirname)
393
394            if not (L if waitall else name in L):
395                return
396            # Increase the timeout and try again
397            time.sleep(timeout)
398            timeout *= 2
399        warnings.warn(
400            "tests may fail, delete still pending for " + pathname,
401            RuntimeWarning,
402            stacklevel=4,
403        )
404
405    def _unlink(filename):
406        _waitfor(os.unlink, filename)
407
408
409else:
410    _unlink = os.unlink
411
412
def unlink(filename):
    """Best-effort removal of *filename*; any OSError is ignored."""
    with contextlib.suppress(OSError):
        _unlink(filename)
418
419
def _is_ipv6_enabled():  # pragma: no cover
    """Return True if an IPv6 TCP socket can be bound to the loopback ``::1``.

    Returns False when ``compat.HAS_IPV6`` is false or when creating or
    binding the probe socket raises OSError.  The probe socket is always
    closed.
    """
    if not compat.HAS_IPV6:
        return False

    probe = None
    try:
        probe = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        probe.bind(("::1", 0))
    except OSError:
        return False
    else:
        return True
    finally:
        if probe is not None:
            probe.close()


# Probed once, at import time.
IPV6_ENABLED = _is_ipv6_enabled()
440
441
class HelperFunctionTests(unittest.TestCase):
    """Tests for read/write/_exception/readwrite, close_all and
    compact_traceback."""

    def test_readwriteexc(self):
        # Check exception handling behavior of read, write and _exception

        # check that ExitNow exceptions in the object handler method
        # bubbles all the way up through asyncore read/write/_exception calls
        tr1 = exitingdummy()
        self.assertRaises(asyncore.ExitNow, asyncore.read, tr1)
        self.assertRaises(asyncore.ExitNow, asyncore.write, tr1)
        self.assertRaises(asyncore.ExitNow, asyncore._exception, tr1)

        # check that an exception other than ExitNow in the object handler
        # method causes the handle_error method to get called
        for helper in (asyncore.read, asyncore.write, asyncore._exception):
            with self.subTest(helper=helper.__name__):
                tr2 = crashingdummy()
                helper(tr2)
                self.assertEqual(tr2.error_handled, True)

    # asyncore.readwrite uses constants in the select module that
    # are not present in Windows systems (see this thread:
    # http://mail.python.org/pipermail/python-list/2001-October/109973.html)
    # These constants should be present as long as poll is available

    @unittest.skipUnless(hasattr(select, "poll"), "select.poll required")
    def test_readwrite(self):
        # Check that correct methods are called by readwrite()

        attributes = ("read", "expt", "write", "closed", "error_handled")

        expected = (
            (select.POLLIN, "read"),
            (select.POLLPRI, "expt"),
            (select.POLLOUT, "write"),
            (select.POLLERR, "closed"),
            (select.POLLHUP, "closed"),
            (select.POLLNVAL, "closed"),
        )

        class testobj:
            def __init__(self):
                self.read = False
                self.write = False
                self.closed = False
                self.expt = False
                # Never set: none of the handlers below raise, so readwrite
                # has no reason to call handle_error.
                self.error_handled = False

            def handle_read_event(self):
                self.read = True

            def handle_write_event(self):
                self.write = True

            def handle_close(self):
                self.closed = True

            def handle_expt_event(self):
                self.expt = True

        for flag, expectedattr in expected:
            tobj = testobj()
            self.assertEqual(getattr(tobj, expectedattr), False)
            asyncore.readwrite(tobj, flag)

            # Only the attribute modified by the routine we expect to be
            # called should be True.

            for attr in attributes:
                self.assertEqual(getattr(tobj, attr), attr == expectedattr)

            # check that ExitNow exceptions in the object handler method
            # bubbles all the way up through asyncore readwrite call
            tr1 = exitingdummy()
            self.assertRaises(asyncore.ExitNow, asyncore.readwrite, tr1, flag)

            # check that an exception other than ExitNow in the object handler
            # method causes the handle_error method to get called
            tr2 = crashingdummy()
            self.assertEqual(tr2.error_handled, False)
            asyncore.readwrite(tr2, flag)
            self.assertEqual(tr2.error_handled, True)

    def test_closeall(self):
        self.closeall_check(False)

    def test_closeall_default(self):
        self.closeall_check(True)

    def closeall_check(self, usedefault):
        # Check that close_all() closes everything in a given map.
        # usedefault=True swaps the map in as asyncore.socket_map and calls
        # close_all() with no argument; False passes the map explicitly.

        channels = []
        testmap = {}

        for i in range(10):
            c = dummychannel()
            channels.append(c)
            self.assertEqual(c.socket.closed, False)
            testmap[i] = c

        if usedefault:
            socketmap = asyncore.socket_map
            try:
                asyncore.socket_map = testmap
                asyncore.close_all()
            finally:
                testmap, asyncore.socket_map = asyncore.socket_map, socketmap
        else:
            asyncore.close_all(testmap)

        self.assertEqual(len(testmap), 0)

        for c in channels:
            self.assertEqual(c.socket.closed, True)

    def test_compact_traceback(self):
        try:
            raise Exception("I don't like spam!")
        except Exception:
            real_t, real_v, _ = sys.exc_info()
            r = asyncore.compact_traceback()

        (f, function, line), t, v, info = r
        self.assertEqual(os.path.split(f)[-1], "test_wasyncore.py")
        self.assertEqual(function, "test_compact_traceback")
        self.assertEqual(t, real_t)
        self.assertEqual(v, real_v)
        self.assertEqual(info, "[%s|%s|%s]" % (f, function, line))
579
580
class DispatcherTests(unittest.TestCase):
    """Behaviour of a bare asyncore.dispatcher and the _strerror helper."""

    def tearDown(self):
        asyncore.close_all()

    def test_basic(self):
        d = asyncore.dispatcher()
        self.assertEqual(d.readable(), True)
        self.assertEqual(d.writable(), True)

    def test_repr(self):
        d = asyncore.dispatcher()
        self.assertEqual(repr(d), "<waitress.wasyncore.dispatcher at %#x>" % id(d))

    def test_log_info(self):
        import logging

        inst = asyncore.dispatcher(map={})
        logger = DummyLogger()
        inst.logger = logger
        inst.log_info("message", "warning")
        self.assertEqual(logger.messages, [(logging.WARN, "message")])

    def test_log(self):
        import logging

        inst = asyncore.dispatcher()
        logger = DummyLogger()
        inst.logger = logger
        inst.log("message")
        self.assertEqual(logger.messages, [(logging.DEBUG, "message")])

    def test_unhandled(self):
        import logging

        inst = asyncore.dispatcher()
        logger = DummyLogger()
        inst.logger = logger

        inst.handle_expt()
        inst.handle_read()
        inst.handle_write()
        inst.handle_connect()

        expected = [
            (logging.WARN, "unhandled incoming priority event"),
            (logging.WARN, "unhandled read event"),
            (logging.WARN, "unhandled write event"),
            (logging.WARN, "unhandled connect event"),
        ]
        self.assertEqual(logger.messages, expected)

    def test_strerror(self):
        # refers to bug #8573
        err = asyncore._strerror(errno.EPERM)

        if hasattr(os, "strerror"):
            self.assertEqual(err, os.strerror(errno.EPERM))
        err = asyncore._strerror(-1)
        # assertNotEqual reports the actual value on failure, unlike
        # assertTrue(err != "").
        self.assertNotEqual(err, "")
643
644
class dispatcherwithsend_noread(asyncore.dispatcher_with_send):  # pragma: no cover
    """dispatcher_with_send that never reads and ignores connect events."""

    def readable(self):
        """Never ask the event loop for read events."""
        return False

    def handle_connect(self):
        """Connection completion needs no handling here."""
        return None
651
652
class DispatcherWithSendTests(unittest.TestCase):
    """End-to-end check that dispatcher_with_send flushes its out_buffer."""

    def setUp(self):
        """Nothing to prepare; sockets are created inside the test."""

    def tearDown(self):
        asyncore.close_all()

    @reap_threads
    def test_send(self):
        done = threading.Event()
        listener = socket.socket()
        listener.settimeout(3)
        port = bind_port(listener)

        captured = BytesIO()
        server_thread = threading.Thread(
            target=capture_server, args=(done, captured, listener)
        )
        server_thread.start()
        try:
            # wait a little longer for the server to initialize (it sometimes
            # refuses connections on slow machines without this wait)
            time.sleep(0.2)

            payload = b"Suppose there isn't a 16-ton weight?"
            sender = dispatcherwithsend_noread()
            sender.create_socket()
            sender.connect((HOST, port))

            # give time for socket to connect
            time.sleep(0.1)

            for chunk in (payload, payload, b"\n"):
                sender.send(chunk)

            polls_left = 1000

            while sender.out_buffer and polls_left > 0:  # pragma: no cover
                asyncore.poll()
                polls_left -= 1

            done.wait()

            self.assertEqual(captured.getvalue(), payload * 2)
        finally:
            join_thread(server_thread, timeout=TIMEOUT)
699
700
@unittest.skipUnless(
    hasattr(asyncore, "file_wrapper"), "asyncore.file_wrapper required"
)
class FileWrapperTest(unittest.TestCase):
    """Tests for asyncore.file_wrapper / file_dispatcher over TESTFN."""

    def setUp(self):
        self.d = b"It's not dead, it's sleeping!"
        with open(TESTFN, "wb") as file:
            file.write(self.d)

    def tearDown(self):
        unlink(TESTFN)

    def test_recv(self):
        fd = os.open(TESTFN, os.O_RDONLY)
        w = asyncore.file_wrapper(fd)
        os.close(fd)

        self.assertNotEqual(w.fd, fd)
        self.assertNotEqual(w.fileno(), fd)
        self.assertEqual(w.recv(13), b"It's not dead")
        self.assertEqual(w.read(6), b", it's")
        w.close()
        self.assertRaises(OSError, w.read, 1)

    def test_send(self):
        d1 = b"Come again?"
        d2 = b"I want to buy some cheese."
        fd = os.open(TESTFN, os.O_WRONLY | os.O_APPEND)
        w = asyncore.file_wrapper(fd)
        os.close(fd)

        w.write(d1)
        w.send(d2)
        w.close()
        with open(TESTFN, "rb") as file:
            self.assertEqual(file.read(), self.d + d1 + d2)

    @unittest.skipUnless(
        hasattr(asyncore, "file_dispatcher"), "asyncore.file_dispatcher required"
    )
    def test_dispatcher(self):
        fd = os.open(TESTFN, os.O_RDONLY)
        data = []

        class FileDispatcher(asyncore.file_dispatcher):
            def handle_read(self):
                data.append(self.recv(29))

        FileDispatcher(fd)
        os.close(fd)
        asyncore.loop(timeout=0.01, use_poll=True, count=2)
        self.assertEqual(b"".join(data), self.d)

    def test_resource_warning(self):
        # Issue #11453
        # The outcome is not deterministic (gc_collect() may not reclaim the
        # wrapper on a given attempt), so retry, but only a bounded number of
        # times so a missing warning fails the test instead of hanging it.
        for _attempt in range(100):
            fd = os.open(TESTFN, os.O_RDONLY)
            f = asyncore.file_wrapper(fd)

            os.close(fd)

            try:
                with check_warnings(("", ResourceWarning)):
                    # Drop the last reference to the still-open wrapper; the
                    # test expects that to produce a ResourceWarning.
                    f = None
                    gc_collect()
            except AssertionError:  # pragma: no cover
                continue

            return

        self.fail("file_wrapper never emitted a ResourceWarning")  # pragma: no cover

    def test_close_twice(self):
        fd = os.open(TESTFN, os.O_RDONLY)
        f = asyncore.file_wrapper(fd)
        os.close(fd)

        os.close(f.fd)  # file_wrapper dupped fd
        with self.assertRaises(OSError):
            f.close()

        self.assertEqual(f.fd, -1)
        # calling close twice should not fail
        f.close()
787
788
class BaseTestHandler(asyncore.dispatcher):  # pragma: no cover
    """Dispatcher whose event handlers fail loudly unless overridden.

    Test subclasses override only the handler they expect to fire and set
    ``self.flag`` when it does.  ``handle_error`` re-raises the exception
    that is currently being handled.
    """

    def __init__(self, sock=None):
        super().__init__(sock)
        # Set by subclasses once the event under test has happened.
        self.flag = False

    def handle_accept(self):
        raise Exception("handle_accept not supposed to be called")

    def handle_accepted(self):
        raise Exception("handle_accepted not supposed to be called")

    def handle_connect(self):
        raise Exception("handle_connect not supposed to be called")

    def handle_expt(self):
        raise Exception("handle_expt not supposed to be called")

    def handle_close(self):
        raise Exception("handle_close not supposed to be called")

    def handle_error(self):
        # Bare raise: propagate the active exception out of the event loop.
        raise
811
812
class BaseServer(asyncore.dispatcher):
    """Listening dispatcher that hands each accepted socket to *handler*.

    Creates a socket for *family*, enables address reuse, binds it to *addr*
    (via bind_af_aware) and listens with a backlog of 5.
    """

    def __init__(self, family, addr, handler=BaseTestHandler):
        super().__init__()
        self.handler = handler
        self.create_socket(family)
        self.set_reuse_addr()
        bind_af_aware(self.socket, addr)
        self.listen(5)

    @property
    def address(self):
        """The address the listening socket is actually bound to."""
        return self.socket.getsockname()

    def handle_accepted(self, sock, addr):
        # The handler instance is not kept here.
        self.handler(sock)

    def handle_error(self):  # pragma: no cover
        # Bare raise: propagate the active exception out of the event loop.
        raise
835
836
class BaseClient(BaseTestHandler):
    """Test handler that creates a *family* socket and connects to *address*."""

    def __init__(self, family, address):
        super().__init__()
        self.create_socket(family)
        self.connect(address)

    def handle_connect(self):
        """A connect event is expected for a client, so accept it silently."""
845
846
847class BaseTestAPI:
848    def tearDown(self):
849        asyncore.close_all(ignore_all=True)
850
851    def loop_waiting_for_flag(self, instance, timeout=5):  # pragma: no cover
852        timeout = float(timeout) / 100
853        count = 100
854
855        while asyncore.socket_map and count > 0:
856            asyncore.loop(timeout=0.01, count=1, use_poll=self.use_poll)
857
858            if instance.flag:
859                return
860            count -= 1
861            time.sleep(timeout)
862        self.fail("flag not set")
863
864    def test_handle_connect(self):
865        # make sure handle_connect is called on connect()
866
867        class TestClient(BaseClient):
868            def handle_connect(self):
869                self.flag = True
870
871        server = BaseServer(self.family, self.addr)
872        client = TestClient(self.family, server.address)
873        self.loop_waiting_for_flag(client)
874
875    def test_handle_accept(self):
876        # make sure handle_accept() is called when a client connects
877
878        class TestListener(BaseTestHandler):
879            def __init__(self, family, addr):
880                BaseTestHandler.__init__(self)
881                self.create_socket(family)
882                bind_af_aware(self.socket, addr)
883                self.listen(5)
884                self.address = self.socket.getsockname()
885
886            def handle_accept(self):
887                self.flag = True
888
889        server = TestListener(self.family, self.addr)
890        client = BaseClient(self.family, server.address)
891        self.loop_waiting_for_flag(server)
892
893    def test_handle_accepted(self):
894        # make sure handle_accepted() is called when a client connects
895
896        class TestListener(BaseTestHandler):
897            def __init__(self, family, addr):
898                BaseTestHandler.__init__(self)
899                self.create_socket(family)
900                bind_af_aware(self.socket, addr)
901                self.listen(5)
902                self.address = self.socket.getsockname()
903
904            def handle_accept(self):
905                asyncore.dispatcher.handle_accept(self)
906
907            def handle_accepted(self, sock, addr):
908                sock.close()
909                self.flag = True
910
911        server = TestListener(self.family, self.addr)
912        client = BaseClient(self.family, server.address)
913        self.loop_waiting_for_flag(server)
914
915    def test_handle_read(self):
916        # make sure handle_read is called on data received
917
918        class TestClient(BaseClient):
919            def handle_read(self):
920                self.flag = True
921
922        class TestHandler(BaseTestHandler):
923            def __init__(self, conn):
924                BaseTestHandler.__init__(self, conn)
925                self.send(b"x" * 1024)
926
927        server = BaseServer(self.family, self.addr, TestHandler)
928        client = TestClient(self.family, server.address)
929        self.loop_waiting_for_flag(client)
930
931    def test_handle_write(self):
932        # make sure handle_write is called
933
934        class TestClient(BaseClient):
935            def handle_write(self):
936                self.flag = True
937
938        server = BaseServer(self.family, self.addr)
939        client = TestClient(self.family, server.address)
940        self.loop_waiting_for_flag(client)
941
942    def test_handle_close(self):
943        # make sure handle_close is called when the other end closes
944        # the connection
945
946        class TestClient(BaseClient):
947            def handle_read(self):
948                # in order to make handle_close be called we are supposed
949                # to make at least one recv() call
950                self.recv(1024)
951
952            def handle_close(self):
953                self.flag = True
954                self.close()
955
956        class TestHandler(BaseTestHandler):
957            def __init__(self, conn):
958                BaseTestHandler.__init__(self, conn)
959                self.close()
960
961        server = BaseServer(self.family, self.addr, TestHandler)
962        client = TestClient(self.family, server.address)
963        self.loop_waiting_for_flag(client)
964
965    def test_handle_close_after_conn_broken(self):
966        # Check that ECONNRESET/EPIPE is correctly handled (issues #5661 and
967        # #11265).
968
969        data = b"\0" * 128
970
971        class TestClient(BaseClient):
972            def handle_write(self):
973                self.send(data)
974
975            def handle_close(self):
976                self.flag = True
977                self.close()
978
979            def handle_expt(self):  # pragma: no cover
980                # needs to exist for MacOS testing
981                self.flag = True
982                self.close()
983
984        class TestHandler(BaseTestHandler):
985            def handle_read(self):
986                self.recv(len(data))
987                self.close()
988
989            def writable(self):
990                return False
991
992        server = BaseServer(self.family, self.addr, TestHandler)
993        client = TestClient(self.family, server.address)
994        self.loop_waiting_for_flag(client)
995
996    @unittest.skipIf(
997        sys.platform.startswith("sunos"), "OOB support is broken on Solaris"
998    )
999    def test_handle_expt(self):
1000        # Make sure handle_expt is called on OOB data received.
1001        # Note: this might fail on some platforms as OOB data is
1002        # tenuously supported and rarely used.
1003
1004        if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
1005            self.skipTest("Not applicable to AF_UNIX sockets.")
1006
1007        if sys.platform == "darwin" and self.use_poll:  # pragma: no cover
1008            self.skipTest("poll may fail on macOS; see issue #28087")
1009
1010        class TestClient(BaseClient):
1011            def handle_expt(self):
1012                self.socket.recv(1024, socket.MSG_OOB)
1013                self.flag = True
1014
1015        class TestHandler(BaseTestHandler):
1016            def __init__(self, conn):
1017                BaseTestHandler.__init__(self, conn)
1018                self.socket.send(chr(244).encode("latin-1"), socket.MSG_OOB)
1019
1020        server = BaseServer(self.family, self.addr, TestHandler)
1021        client = TestClient(self.family, server.address)
1022        self.loop_waiting_for_flag(client)
1023
1024    def test_handle_error(self):
1025        class TestClient(BaseClient):
1026            def handle_write(self):
1027                1.0 / 0
1028
1029            def handle_error(self):
1030                self.flag = True
1031                try:
1032                    raise
1033                except ZeroDivisionError:
1034                    pass
1035                else:  # pragma: no cover
1036                    raise Exception("exception not raised")
1037
1038        server = BaseServer(self.family, self.addr)
1039        client = TestClient(self.family, server.address)
1040        self.loop_waiting_for_flag(client)
1041
1042    def test_connection_attributes(self):
1043        server = BaseServer(self.family, self.addr)
1044        client = BaseClient(self.family, server.address)
1045
1046        # we start disconnected
1047        self.assertFalse(server.connected)
1048        self.assertTrue(server.accepting)
1049        # this can't be taken for granted across all platforms
1050        # self.assertFalse(client.connected)
1051        self.assertFalse(client.accepting)
1052
1053        # execute some loops so that client connects to server
1054        asyncore.loop(timeout=0.01, use_poll=self.use_poll, count=100)
1055        self.assertFalse(server.connected)
1056        self.assertTrue(server.accepting)
1057        self.assertTrue(client.connected)
1058        self.assertFalse(client.accepting)
1059
1060        # disconnect the client
1061        client.close()
1062        self.assertFalse(server.connected)
1063        self.assertTrue(server.accepting)
1064        self.assertFalse(client.connected)
1065        self.assertFalse(client.accepting)
1066
1067        # stop serving
1068        server.close()
1069        self.assertFalse(server.connected)
1070        self.assertFalse(server.accepting)
1071
1072    def test_create_socket(self):
1073        s = asyncore.dispatcher()
1074        s.create_socket(self.family)
1075        # self.assertEqual(s.socket.type, socket.SOCK_STREAM)
1076        self.assertEqual(s.socket.family, self.family)
1077        self.assertEqual(s.socket.gettimeout(), 0)
1078        # self.assertFalse(s.socket.get_inheritable())
1079
1080    def test_bind(self):
1081        if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
1082            self.skipTest("Not applicable to AF_UNIX sockets.")
1083        s1 = asyncore.dispatcher()
1084        s1.create_socket(self.family)
1085        s1.bind(self.addr)
1086        s1.listen(5)
1087        port = s1.socket.getsockname()[1]
1088
1089        s2 = asyncore.dispatcher()
1090        s2.create_socket(self.family)
1091        # EADDRINUSE indicates the socket was correctly bound
1092        self.assertRaises(socket.error, s2.bind, (self.addr[0], port))
1093
1094    def test_set_reuse_addr(self):  # pragma: no cover
1095        if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
1096            self.skipTest("Not applicable to AF_UNIX sockets.")
1097
1098        with closewrapper(socket.socket(self.family)) as sock:
1099            try:
1100                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
1101            except OSError:
1102                unittest.skip("SO_REUSEADDR not supported on this platform")
1103            else:
1104                # if SO_REUSEADDR succeeded for sock we expect asyncore
1105                # to do the same
1106                s = asyncore.dispatcher(socket.socket(self.family))
1107                self.assertFalse(
1108                    s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
1109                )
1110                s.socket.close()
1111                s.create_socket(self.family)
1112                s.set_reuse_addr()
1113                self.assertTrue(
1114                    s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
1115                )
1116
1117    @reap_threads
1118    def test_quick_connect(self):  # pragma: no cover
1119        # see: http://bugs.python.org/issue10340
1120
1121        if self.family not in (socket.AF_INET, getattr(socket, "AF_INET6", object())):
1122            self.skipTest("test specific to AF_INET and AF_INET6")
1123
1124        server = BaseServer(self.family, self.addr)
1125        # run the thread 500 ms: the socket should be connected in 200 ms
1126        t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=5))
1127        t.start()
1128        try:
1129            sock = socket.socket(self.family, socket.SOCK_STREAM)
1130            with closewrapper(sock) as s:
1131                s.settimeout(0.2)
1132                s.setsockopt(
1133                    socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
1134                )
1135
1136                try:
1137                    s.connect(server.address)
1138                except OSError:
1139                    pass
1140        finally:
1141            join_thread(t, timeout=TIMEOUT)
1142
1143
class TestAPI_UseIPv4Sockets(BaseTestAPI):
    """Parameterize BaseTestAPI for AF_INET sockets on ``HOST``, port 0."""

    addr = (HOST, 0)
    family = socket.AF_INET
1147
1148
@unittest.skipUnless(IPV6_ENABLED, "IPv6 support required")
class TestAPI_UseIPv6Sockets(BaseTestAPI):
    """Parameterize BaseTestAPI for AF_INET6 sockets on ``HOSTv6``, port 0."""

    addr = (HOSTv6, 0)
    family = socket.AF_INET6
1153
1154
@unittest.skipUnless(HAS_UNIX_SOCKETS, "Unix sockets required")
class TestAPI_UseUnixSockets(BaseTestAPI):
    """Parameterize BaseTestAPI for AF_UNIX sockets bound at ``TESTFN``."""

    # socket.AF_UNIX only exists where Unix sockets do; the guard lets the
    # class body evaluate everywhere, and the decorator skips the tests.
    if HAS_UNIX_SOCKETS:
        family = socket.AF_UNIX
    addr = TESTFN

    def tearDown(self):
        # Remove the socket file, then run the shared teardown.
        socket_path = self.addr
        unlink(socket_path)
        BaseTestAPI.tearDown(self)
1164
1165
class TestAPI_UseIPv4Select(TestAPI_UseIPv4Sockets, unittest.TestCase):
    """IPv4 API tests driven by the select()-based loop."""

    use_poll = False
1168
1169
@unittest.skipUnless(hasattr(select, "poll"), "select.poll required")
class TestAPI_UseIPv4Poll(TestAPI_UseIPv4Sockets, unittest.TestCase):
    """IPv4 API tests driven by the poll()-based loop."""

    use_poll = True
1173
1174
class TestAPI_UseIPv6Select(TestAPI_UseIPv6Sockets, unittest.TestCase):
    """IPv6 API tests driven by the select()-based loop."""

    use_poll = False
1177
1178
@unittest.skipUnless(hasattr(select, "poll"), "select.poll required")
class TestAPI_UseIPv6Poll(TestAPI_UseIPv6Sockets, unittest.TestCase):
    """IPv6 API tests driven by the poll()-based loop."""

    use_poll = True
1182
1183
class TestAPI_UseUnixSocketsSelect(TestAPI_UseUnixSockets, unittest.TestCase):
    """AF_UNIX API tests driven by the select()-based loop."""

    use_poll = False
1186
1187
@unittest.skipUnless(hasattr(select, "poll"), "select.poll required")
class TestAPI_UseUnixSocketsPoll(TestAPI_UseUnixSockets, unittest.TestCase):
    """AF_UNIX API tests driven by the poll()-based loop."""

    use_poll = True
1191
1192
class Test__strerror(unittest.TestCase):
    """Tests for ``wasyncore._strerror``, which describes an error number."""

    def _callFUT(self, err):
        from waitress.wasyncore import _strerror

        return _strerror(err)

    def test_gardenpath(self):
        # errno 1 (EPERM) is described by its standard message.
        expected = "Operation not permitted"
        self.assertEqual(self._callFUT(1), expected)

    def test_unknown(self):
        # A value that cannot be described falls back to a generic message.
        expected = "Unknown error wut"
        self.assertEqual(self._callFUT("wut"), expected)
1204
1205
class Test_read(unittest.TestCase):
    """Tests for ``wasyncore.read``, which calls ``handle_read_event``."""

    def _callFUT(self, dispatcher):
        from waitress.wasyncore import read

        return read(dispatcher)

    def test_gardenpath(self):
        # No exception: the read handler runs and handle_error does not.
        disp = DummyDispatcher()
        self._callFUT(disp)
        self.assertTrue(disp.read_event_handled)
        self.assertFalse(disp.error_handled)

    def test_reraised(self):
        # ExitNow escapes read() without going through handle_error.
        from waitress.wasyncore import ExitNow

        disp = DummyDispatcher(ExitNow)
        with self.assertRaises(ExitNow):
            self._callFUT(disp)
        self.assertTrue(disp.read_event_handled)
        self.assertFalse(disp.error_handled)

    def test_non_reraised(self):
        # An OSError is routed to handle_error instead of propagating.
        disp = DummyDispatcher(OSError)
        self._callFUT(disp)
        self.assertTrue(disp.read_event_handled)
        self.assertTrue(disp.error_handled)
1231
1232
class Test_write(unittest.TestCase):
    """Tests for ``wasyncore.write``, which calls ``handle_write_event``."""

    def _callFUT(self, dispatcher):
        from waitress.wasyncore import write

        return write(dispatcher)

    def test_gardenpath(self):
        # No exception: the write handler runs and handle_error does not.
        disp = DummyDispatcher()
        self._callFUT(disp)
        self.assertTrue(disp.write_event_handled)
        self.assertFalse(disp.error_handled)

    def test_reraised(self):
        # ExitNow escapes write() without going through handle_error.
        from waitress.wasyncore import ExitNow

        disp = DummyDispatcher(ExitNow)
        with self.assertRaises(ExitNow):
            self._callFUT(disp)
        self.assertTrue(disp.write_event_handled)
        self.assertFalse(disp.error_handled)

    def test_non_reraised(self):
        # An OSError is routed to handle_error instead of propagating.
        disp = DummyDispatcher(OSError)
        self._callFUT(disp)
        self.assertTrue(disp.write_event_handled)
        self.assertTrue(disp.error_handled)
1258
1259
class Test__exception(unittest.TestCase):
    """Tests for ``wasyncore._exception``, which calls ``handle_expt_event``."""

    def _callFUT(self, dispatcher):
        from waitress.wasyncore import _exception

        return _exception(dispatcher)

    def test_gardenpath(self):
        # No exception: the expt handler runs and handle_error does not.
        disp = DummyDispatcher()
        self._callFUT(disp)
        self.assertTrue(disp.expt_event_handled)
        self.assertFalse(disp.error_handled)

    def test_reraised(self):
        # ExitNow escapes _exception() without going through handle_error.
        from waitress.wasyncore import ExitNow

        disp = DummyDispatcher(ExitNow)
        with self.assertRaises(ExitNow):
            self._callFUT(disp)
        self.assertTrue(disp.expt_event_handled)
        self.assertFalse(disp.error_handled)

    def test_non_reraised(self):
        # An OSError is routed to handle_error instead of propagating.
        disp = DummyDispatcher(OSError)
        self._callFUT(disp)
        self.assertTrue(disp.expt_event_handled)
        self.assertTrue(disp.error_handled)
1285
1286
@unittest.skipUnless(hasattr(select, "poll"), "select.poll required")
class Test_readwrite(unittest.TestCase):
    """Tests for ``wasyncore.readwrite``.

    ``readwrite`` maps poll() event flags to dispatcher handlers and decides
    how exceptions raised by those handlers are dealt with.
    """

    def _callFUT(self, obj, flags):
        from waitress.wasyncore import readwrite

        return readwrite(obj, flags)

    def test_handle_read_event(self):
        # POLLIN -> handle_read_event
        disp = DummyDispatcher()
        self._callFUT(disp, select.POLLIN)
        self.assertTrue(disp.read_event_handled)

    def test_handle_write_event(self):
        # POLLOUT -> handle_write_event
        disp = DummyDispatcher()
        self._callFUT(disp, select.POLLOUT)
        self.assertTrue(disp.write_event_handled)

    def test_handle_expt_event(self):
        # POLLPRI -> handle_expt_event
        disp = DummyDispatcher()
        self._callFUT(disp, select.POLLPRI)
        self.assertTrue(disp.expt_event_handled)

    def test_handle_close(self):
        # POLLHUP -> handle_close
        disp = DummyDispatcher()
        self._callFUT(disp, select.POLLHUP)
        self.assertTrue(disp.close_handled)

    def test_socketerror_not_in_disconnected(self):
        # A socket error such as EALREADY goes to handle_error.
        disp = DummyDispatcher(socket.error(errno.EALREADY, "EALREADY"))
        self._callFUT(disp, select.POLLIN)
        self.assertTrue(disp.read_event_handled)
        self.assertTrue(disp.error_handled)

    def test_socketerror_in_disconnected(self):
        # ECONNRESET is treated as a disconnect and goes to handle_close.
        disp = DummyDispatcher(socket.error(errno.ECONNRESET, "ECONNRESET"))
        self._callFUT(disp, select.POLLIN)
        self.assertTrue(disp.read_event_handled)
        self.assertTrue(disp.close_handled)

    def test_exception_in_reraised(self):
        # ExitNow propagates out of readwrite().
        from waitress import wasyncore

        disp = DummyDispatcher(wasyncore.ExitNow)
        with self.assertRaises(wasyncore.ExitNow):
            self._callFUT(disp, select.POLLIN)
        self.assertTrue(disp.read_event_handled)

    def test_exception_not_in_reraised(self):
        # Other exceptions (here ValueError) go to handle_error.
        disp = DummyDispatcher(ValueError)
        self._callFUT(disp, select.POLLIN)
        self.assertTrue(disp.error_handled)
1353
1354
class Test_poll(unittest.TestCase):
    """Tests for ``wasyncore.poll`` (the select()-based loop step).

    Module globals on ``waitress.wasyncore`` are swapped by hand and restored
    in a ``finally`` block.  The original value is always captured before the
    ``try`` so that the ``finally`` clause can never hit an unbound name.
    """

    def _callFUT(self, timeout=0.0, map=None):
        from waitress.wasyncore import poll

        return poll(timeout, map)

    def test_nothing_writable_nothing_readable_but_map_not_empty(self):
        # The map's only channel is neither readable nor writable; poll() is
        # expected to sleep for ``timeout`` and return None.
        from waitress import wasyncore

        dummy_time = DummyTime()
        channel_map = {0: DummyDispatcher()}
        old_time = wasyncore.time
        wasyncore.time = dummy_time
        try:
            result = self._callFUT(map=channel_map)
        finally:
            wasyncore.time = old_time
        self.assertEqual(result, None)
        self.assertEqual(dummy_time.sleepvals, [0.0])

    def test_select_raises_EINTR(self):
        # EINTR from select() is swallowed and poll() returns None.
        from waitress import wasyncore

        dummy_select = DummySelect(select.error(errno.EINTR))
        disp = DummyDispatcher()
        disp.readable = lambda: True
        channel_map = {0: disp}
        old_select = wasyncore.select
        wasyncore.select = dummy_select
        try:
            result = self._callFUT(map=channel_map)
        finally:
            wasyncore.select = old_select
        self.assertEqual(result, None)
        # fd 0 was passed in both the read and the exceptional lists.
        self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)])

    def test_select_raises_non_EINTR(self):
        # Any other select() error (here EBADF) propagates.
        from waitress import wasyncore

        dummy_select = DummySelect(select.error(errno.EBADF))
        disp = DummyDispatcher()
        disp.readable = lambda: True
        channel_map = {0: disp}
        old_select = wasyncore.select
        wasyncore.select = dummy_select
        try:
            self.assertRaises(select.error, self._callFUT, map=channel_map)
        finally:
            wasyncore.select = old_select
        self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)])
1408
1409
class Test_poll2(unittest.TestCase):
    """Tests for ``wasyncore.poll2`` (the select.poll()-based loop step).

    ``wasyncore.select`` is replaced by a DummySelect whose ``poll()``
    returns a DummyPollster.  The original is captured before the ``try``
    so that the ``finally`` restore can never hit an unbound name.
    """

    def _callFUT(self, timeout=0.0, map=None):
        from waitress.wasyncore import poll2

        return poll2(timeout, map)

    def test_select_raises_EINTR(self):
        # EINTR from the pollster is swallowed.
        from waitress import wasyncore

        pollster = DummyPollster(exc=select.error(errno.EINTR))
        dummy_select = DummySelect(pollster=pollster)
        channel_map = {0: DummyDispatcher()}
        old_select = wasyncore.select
        wasyncore.select = dummy_select
        try:
            self._callFUT(map=channel_map)
        finally:
            wasyncore.select = old_select
        self.assertEqual(pollster.polled, [0.0])

    def test_select_raises_non_EINTR(self):
        # Any other pollster error (here EBADF) propagates.
        from waitress import wasyncore

        pollster = DummyPollster(exc=select.error(errno.EBADF))
        dummy_select = DummySelect(pollster=pollster)
        channel_map = {0: DummyDispatcher()}
        old_select = wasyncore.select
        wasyncore.select = dummy_select
        try:
            self.assertRaises(select.error, self._callFUT, map=channel_map)
        finally:
            wasyncore.select = old_select
        self.assertEqual(pollster.polled, [0.0])
1447
1448
class Test_dispatcher(unittest.TestCase):
    """Unit tests for ``wasyncore.dispatcher`` built on a ``dummysocket``.

    Individual socket methods and dispatcher hooks are monkeypatched per
    test to drive the error paths being exercised.
    """

    def _makeOne(self, sock=None, map=None):
        from waitress.wasyncore import dispatcher

        return dispatcher(sock=sock, map=map)

    def test_unexpected_getpeername_exc(self):
        # An unexpected getpeername() error escapes the constructor and the
        # map is left empty.
        sock = dummysocket()
        channels = {}

        def getpeername():
            raise OSError(errno.EBADF)

        sock.getpeername = getpeername
        with self.assertRaises(socket.error):
            self._makeOne(sock=sock, map=channels)
        self.assertEqual(channels, {})

    def test___repr__accepting(self):
        inst = self._makeOne(sock=dummysocket(), map={})
        inst.accepting = True
        inst.addr = ("localhost", 8080)
        prefix = "<waitress.wasyncore.dispatcher listening localhost:8080 at"
        self.assertEqual(repr(inst)[: len(prefix)], prefix)

    def test___repr__connected(self):
        inst = self._makeOne(sock=dummysocket(), map={})
        inst.accepting = False
        inst.connected = True
        inst.addr = ("localhost", 8080)
        prefix = "<waitress.wasyncore.dispatcher connected localhost:8080 at"
        self.assertEqual(repr(inst)[: len(prefix)], prefix)

    def test_set_reuse_addr_with_socketerror(self):
        # set_reuse_addr() does not raise when setsockopt() fails.
        sock = dummysocket()

        def setsockopt(*args, **kwargs):
            sock.errored = True
            raise OSError

        sock.setsockopt = setsockopt
        sock.getsockopt = lambda *args: 0
        inst = self._makeOne(sock=sock, map={})
        inst.set_reuse_addr()
        self.assertTrue(sock.errored)

    def test_connect_raise_socket_error(self):
        # connect_ex() returning errno 1 makes connect() raise.
        sock = dummysocket()
        sock.connect_ex = lambda *args: 1
        inst = self._makeOne(sock=sock, map={})
        with self.assertRaises(socket.error):
            inst.connect(0)

    def test_accept_raise_TypeError(self):
        # A TypeError from socket.accept() makes accept() return None.
        sock = dummysocket()

        def accept(*args, **kwargs):
            raise TypeError

        sock.accept = accept
        inst = self._makeOne(sock=sock, map={})
        self.assertEqual(inst.accept(), None)

    def test_accept_raise_unexpected_socketerror(self):
        sock = dummysocket()

        def accept(*args, **kwargs):
            raise OSError(122)

        sock.accept = accept
        inst = self._makeOne(sock=sock, map={})
        with self.assertRaises(socket.error):
            inst.accept()

    def test_send_raise_EWOULDBLOCK(self):
        # EWOULDBLOCK from socket.send() is reported as 0 bytes sent.
        sock = dummysocket()

        def send(*args, **kwargs):
            raise OSError(errno.EWOULDBLOCK)

        sock.send = send
        inst = self._makeOne(sock=sock, map={})
        self.assertEqual(inst.send("a"), 0)

    def test_send_raise_unexpected_socketerror(self):
        sock = dummysocket()

        def send(*args, **kwargs):
            raise OSError(122)

        sock.send = send
        inst = self._makeOne(sock=sock, map={})
        with self.assertRaises(socket.error):
            inst.send("a")

    def test_recv_raises_disconnect(self):
        # ECONNRESET from recv() calls handle_close() and yields b"".
        sock = dummysocket()

        def recv(*args, **kwargs):
            raise OSError(errno.ECONNRESET)

        def handle_close():
            inst.close_handled = True

        sock.recv = recv
        inst = self._makeOne(sock=sock, map={})
        inst.handle_close = handle_close
        self.assertEqual(inst.recv(1), b"")
        self.assertTrue(inst.close_handled)

    def test_close_raises_unknown_socket_error(self):
        sock = dummysocket()

        def close():
            raise OSError(122)

        sock.close = close
        inst = self._makeOne(sock=sock, map={})
        inst.del_channel = lambda: None
        with self.assertRaises(socket.error):
            inst.close()

    def test_handle_read_event_not_accepting_not_connected_connecting(self):
        # A read event on a channel that is still connecting (neither
        # accepting nor connected) calls both handle_connect_event and
        # handle_read.
        inst = self._makeOne(sock=dummysocket(), map={})

        def handle_connect_event():
            inst.connect_event_handled = True

        def handle_read():
            inst.read_handled = True

        inst.handle_connect_event = handle_connect_event
        inst.handle_read = handle_read
        inst.accepting = False
        inst.connected = False
        inst.connecting = True
        inst.handle_read_event()
        self.assertTrue(inst.connect_event_handled)
        self.assertTrue(inst.read_handled)

    def test_handle_connect_event_getsockopt_returns_error(self):
        # A non-zero getsockopt() result makes handle_connect_event raise.
        sock = dummysocket()
        sock.getsockopt = lambda *args: 122
        inst = self._makeOne(sock=sock, map={})
        with self.assertRaises(socket.error):
            inst.handle_connect_event()

    def test_handle_expt_event_getsockopt_returns_error(self):
        # A non-zero getsockopt() result makes handle_expt_event call
        # handle_close.
        sock = dummysocket()
        sock.getsockopt = lambda *args: 122
        inst = self._makeOne(sock=sock, map={})

        def handle_close():
            inst.close_handled = True

        inst.handle_close = handle_close
        inst.handle_expt_event()
        self.assertTrue(inst.close_handled)

    def test_handle_write_event_while_accepting(self):
        # A write event on an accepting channel returns None.
        inst = self._makeOne(sock=dummysocket(), map={})
        inst.accepting = True
        self.assertEqual(inst.handle_write_event(), None)

    def test_handle_error_gardenpath(self):
        # handle_error passes "error" to log_info and calls handle_close.
        inst = self._makeOne(sock=dummysocket(), map={})

        def handle_close():
            inst.close_handled = True

        def compact_traceback(*args, **kwargs):
            return None, None, None, None

        def log_info(message, *rest):
            # Installed on the instance, so this function is not bound: the
            # first positional argument lands in ``message`` and only the
            # remaining ones are recorded.
            inst.logged_info = rest

        inst.handle_close = handle_close
        inst.compact_traceback = compact_traceback
        inst.log_info = log_info
        inst.handle_error()
        self.assertTrue(inst.close_handled)
        self.assertEqual(inst.logged_info, ("error",))

    def test_handle_close(self):
        # The default handle_close() ends up calling close().
        inst = self._makeOne(sock=dummysocket(), map={})

        def log_info(message, *rest):
            inst.logged_info = rest

        def close():
            inst._closed = True

        inst.log_info = log_info
        inst.close = close
        inst.handle_close()
        self.assertTrue(inst._closed)

    def test_handle_accepted(self):
        # The base handle_accepted() closes the socket it is handed.
        sock = dummysocket()
        inst = self._makeOne(sock=sock, map={})
        inst.handle_accepted(sock, "1")
        self.assertTrue(sock.closed)
1674
1675
class Test_dispatcher_with_send(unittest.TestCase):
    """Tests for ``wasyncore.dispatcher_with_send``."""

    def _makeOne(self, sock=None, map=None):
        from waitress.wasyncore import dispatcher_with_send

        return dispatcher_with_send(sock=sock, map=map)

    def test_writable(self):
        # A connected channel with buffered output reports itself writable.
        inst = self._makeOne(sock=dummysocket(), map={})
        inst.connected = True
        inst.out_buffer = b"123"
        self.assertTrue(inst.writable())
1689
1690
class Test_close_all(unittest.TestCase):
    """Tests for how ``wasyncore.close_all`` handles errors from close()."""

    def _callFUT(self, map=None, ignore_all=False):
        from waitress.wasyncore import close_all

        return close_all(map, ignore_all)

    def _map_with_failing_close(self, exc):
        # One-entry channel map whose dispatcher raises ``exc`` from close().
        return {0: DummyDispatcher(exc=exc)}

    def test_socketerror_on_close_ebadf(self):
        # EBADF from close() is tolerated and the map still ends up empty.
        channels = self._map_with_failing_close(socket.error(errno.EBADF))
        self._callFUT(channels)
        self.assertEqual(channels, {})

    def test_socketerror_on_close_non_ebadf(self):
        # EAGAIN from close() propagates.
        channels = self._map_with_failing_close(socket.error(errno.EAGAIN))
        with self.assertRaises(socket.error):
            self._callFUT(channels)

    def test_reraised_exc_on_close(self):
        channels = self._map_with_failing_close(KeyboardInterrupt)
        with self.assertRaises(KeyboardInterrupt):
            self._callFUT(channels)

    def test_unknown_exc_on_close(self):
        channels = self._map_with_failing_close(RuntimeError)
        with self.assertRaises(RuntimeError):
            self._callFUT(channels)
1717
1718
class DummyDispatcher:
    """Minimal stand-in for a wasyncore dispatcher.

    Every handler records that it ran.  The three event handlers and
    ``close()`` then raise ``exc`` if one was supplied.  ``readable()`` and
    ``writable()`` always report False.
    """

    # "Was I called?" flags: class-level defaults, overridden per instance.
    read_event_handled = False
    write_event_handled = False
    expt_event_handled = False
    error_handled = False
    close_handled = False
    accepting = False

    def __init__(self, exc=None):
        self.exc = exc

    def _raise_if_configured(self):
        # Shared tail of the event handlers and close().
        if self.exc is not None:
            raise self.exc

    def handle_read_event(self):
        self.read_event_handled = True
        self._raise_if_configured()

    def handle_write_event(self):
        self.write_event_handled = True
        self._raise_if_configured()

    def handle_expt_event(self):
        self.expt_event_handled = True
        self._raise_if_configured()

    def handle_error(self):
        self.error_handled = True

    def handle_close(self):
        self.close_handled = True

    def readable(self):
        return False

    def writable(self):
        return False

    def close(self):
        self._raise_if_configured()
1763
1764
class DummyTime:
    """Fake ``time`` module: records each ``sleep()`` duration instead of sleeping."""

    def __init__(self):
        self.sleepvals = []

    def sleep(self, val):
        self.sleepvals += [val]
1771
1772
class DummySelect:
    """Fake ``select`` module.

    ``select()`` records its positional arguments and then raises ``exc``
    if one was supplied.  ``poll()`` hands back the preconfigured
    ``pollster``.
    """

    error = select.error

    def __init__(self, exc=None, pollster=None):
        self.exc = exc
        self.pollster = pollster
        self.selected = []

    def select(self, *args):
        self.selected.append(args)
        pending = self.exc
        if pending is None:
            return None
        raise pending

    def poll(self):
        return self.pollster
1789
1790
class DummyPollster:
    """Fake poll object: records each ``poll()`` timeout.

    ``poll()`` raises ``exc`` if one was supplied, otherwise it reports no
    events.
    """

    def __init__(self, exc=None):
        self.polled = []
        self.exc = exc

    def poll(self, timeout):
        self.polled.append(timeout)
        if self.exc is not None:
            raise self.exc
        return []  # pragma: no cover
1803