1import asyncore
2import unittest
3import select
4import os
5import socket
6import sys
7import time
8import warnings
9import errno
10import struct
11
12from test import test_support
13from test.test_support import TESTFN, run_unittest, unlink, HOST
14from StringIO import StringIO
15
16try:
17    import threading
18except ImportError:
19    threading = None
20
21
class dummysocket:
    """Fake socket that only remembers whether close() has been called."""

    # Arbitrary descriptor number reported by fileno().
    _FAKE_FD = 42

    def __init__(self):
        # Flipped by close(); the close_all tests inspect it.
        self.closed = False

    def close(self):
        """Record that the socket was closed."""
        self.closed = True

    def fileno(self):
        """Return the fixed fake descriptor (42)."""
        return self._FAKE_FD
31
class dummychannel:
    """Fake asyncore channel whose only state is a dummysocket."""

    def __init__(self):
        self.socket = dummysocket()

    def close(self):
        """Close the wrapped fake socket."""
        sock = self.socket
        sock.close()
38
class exitingdummy:
    """Fake channel whose every event handler raises asyncore.ExitNow.

    Used to check that ExitNow propagates out of asyncore's dispatch helpers
    instead of being routed to handle_error().
    """

    def __init__(self):
        pass

    def handle_read_event(self):
        raise asyncore.ExitNow()

    # The write, close and exceptional-condition handlers behave identically.
    handle_write_event = handle_close = handle_expt_event = handle_read_event
49
class crashingdummy:
    """Fake channel whose event handlers raise a plain Exception.

    handle_error() only records that it ran, so tests can check that asyncore
    routed the failure there.
    """

    def __init__(self):
        self.error_handled = False

    def handle_read_event(self):
        raise Exception()

    # The write, close and exceptional-condition handlers behave identically.
    handle_write_event = handle_close = handle_expt_event = handle_read_event

    def handle_error(self):
        """Remember that the error hook was invoked."""
        self.error_handled = True
63
def capture_server(evt, buf, serv, timeout=3.0):
    """Accept one connection on *serv* and collect what the peer sends.

    Used when testing senders.  Received data is written to *buf* with
    newline characters removed.  Collection stops at the first of:
      - a chunk containing a newline;
      - the peer closing the connection;
      - 200 reads;
      - *timeout* seconds elapsing.
    If ``serv.accept()`` raises ``socket.timeout``, nothing is captured.
    In every case the accepted connection and *serv* are closed and *evt* is
    set, so the waiting test can proceed.

    Parameters:
        evt -- Event-like object; ``set()`` is called on exit.
        buf -- file-like object that receives the captured data.
        serv -- server socket to listen on and accept one connection from.
        timeout -- overall time budget in seconds for the read loop.
    """
    try:
        serv.listen(5)
        conn, addr = serv.accept()
    except socket.timeout:
        pass
    else:
        try:
            n = 200
            deadline = time.time() + timeout
            while n > 0 and time.time() < deadline:
                # Use a short select timeout so a silent client cannot
                # block this thread (and the test's join()) forever.
                r, w, e = select.select([conn], [], [], 0.1)
                if r:
                    n -= 1
                    data = conn.recv(10)
                    if not data:
                        # Peer closed the connection: nothing more will come.
                        break
                    # keep everything except for the newline terminator
                    buf.write(data.replace(b'\n', b''))
                    if b'\n' in data:
                        break
                time.sleep(0.01)
        finally:
            conn.close()
    finally:
        serv.close()
        evt.set()
88
89
class HelperFunctionTests(unittest.TestCase):
    """Tests for asyncore's module-level helper functions."""

    def test_readwriteexc(self):
        # Check exception handling behavior of read, write and _exception

        # check that ExitNow exceptions in the object handler method
        # bubble all the way up through asyncore read/write/_exception calls
        tr1 = exitingdummy()
        self.assertRaises(asyncore.ExitNow, asyncore.read, tr1)
        self.assertRaises(asyncore.ExitNow, asyncore.write, tr1)
        self.assertRaises(asyncore.ExitNow, asyncore._exception, tr1)

        # check that an exception other than ExitNow in the object handler
        # method causes the handle_error method to get called
        tr2 = crashingdummy()
        asyncore.read(tr2)
        self.assertEqual(tr2.error_handled, True)

        tr2 = crashingdummy()
        asyncore.write(tr2)
        self.assertEqual(tr2.error_handled, True)

        tr2 = crashingdummy()
        asyncore._exception(tr2)
        self.assertEqual(tr2.error_handled, True)

    # asyncore.readwrite uses constants in the select module that
    # are not present in Windows systems (see this thread:
    # http://mail.python.org/pipermail/python-list/2001-October/109973.html)
    # These constants should be present as long as poll is available

    @unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required')
    def test_readwrite(self):
        # Check that correct methods are called by readwrite()

        attributes = ('read', 'expt', 'write', 'closed', 'error_handled')

        expected = (
            (select.POLLIN, 'read'),
            (select.POLLPRI, 'expt'),
            (select.POLLOUT, 'write'),
            (select.POLLERR, 'closed'),
            (select.POLLHUP, 'closed'),
            (select.POLLNVAL, 'closed'),
            )

        class testobj:
            def __init__(self):
                self.read = False
                self.write = False
                self.closed = False
                self.expt = False
                self.error_handled = False

            def handle_read_event(self):
                self.read = True

            def handle_write_event(self):
                self.write = True

            def handle_close(self):
                self.closed = True

            def handle_expt_event(self):
                self.expt = True

            def handle_error(self):
                self.error_handled = True

        for flag, expectedattr in expected:
            tobj = testobj()
            self.assertEqual(getattr(tobj, expectedattr), False)
            asyncore.readwrite(tobj, flag)

            # Only the attribute modified by the routine we expect to be
            # called should be True.
            for attr in attributes:
                self.assertEqual(getattr(tobj, attr), attr==expectedattr)

            # check that ExitNow exceptions in the object handler method
            # bubble all the way up through asyncore readwrite call
            tr1 = exitingdummy()
            self.assertRaises(asyncore.ExitNow, asyncore.readwrite, tr1, flag)

            # check that an exception other than ExitNow in the object handler
            # method causes the handle_error method to get called
            tr2 = crashingdummy()
            self.assertEqual(tr2.error_handled, False)
            asyncore.readwrite(tr2, flag)
            self.assertEqual(tr2.error_handled, True)

    def test_closeall(self):
        self.closeall_check(False)

    def test_closeall_default(self):
        self.closeall_check(True)

    def closeall_check(self, usedefault):
        """Check that close_all() closes everything in a given map.

        When *usedefault* is true the test map is temporarily installed as
        asyncore.socket_map and close_all() is called without arguments.
        Otherwise the map is passed explicitly.
        """
        l = []
        testmap = {}
        for i in range(10):
            c = dummychannel()
            l.append(c)
            self.assertEqual(c.socket.closed, False)
            testmap[i] = c

        if usedefault:
            socketmap = asyncore.socket_map
            try:
                asyncore.socket_map = testmap
                asyncore.close_all()
            finally:
                testmap, asyncore.socket_map = asyncore.socket_map, socketmap
        else:
            asyncore.close_all(testmap)

        self.assertEqual(len(testmap), 0)

        for c in l:
            self.assertEqual(c.socket.closed, True)

    def test_compact_traceback(self):
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt/SystemExit are never swallowed here.
        try:
            raise Exception("I don't like spam!")
        except Exception:
            real_t, real_v, real_tb = sys.exc_info()
            r = asyncore.compact_traceback()
        else:
            self.fail("Expected exception")

        (f, function, line), t, v, info = r
        self.assertEqual(os.path.split(f)[-1], 'test_asyncore.py')
        self.assertEqual(function, 'test_compact_traceback')
        self.assertEqual(t, real_t)
        self.assertEqual(v, real_v)
        self.assertEqual(info, '[%s|%s|%s]' % (f, function, line))
227
228
229class DispatcherTests(unittest.TestCase):
230    def setUp(self):
231        pass
232
233    def tearDown(self):
234        asyncore.close_all()
235
236    def test_basic(self):
237        d = asyncore.dispatcher()
238        self.assertEqual(d.readable(), True)
239        self.assertEqual(d.writable(), True)
240
241    def test_repr(self):
242        d = asyncore.dispatcher()
243        self.assertEqual(repr(d), '<asyncore.dispatcher at %#x>' % id(d))
244
245    def test_log(self):
246        d = asyncore.dispatcher()
247
248        # capture output of dispatcher.log() (to stderr)
249        fp = StringIO()
250        stderr = sys.stderr
251        l1 = "Lovely spam! Wonderful spam!"
252        l2 = "I don't like spam!"
253        try:
254            sys.stderr = fp
255            d.log(l1)
256            d.log(l2)
257        finally:
258            sys.stderr = stderr
259
260        lines = fp.getvalue().splitlines()
261        self.assertEqual(lines, ['log: %s' % l1, 'log: %s' % l2])
262
263    def test_log_info(self):
264        d = asyncore.dispatcher()
265
266        # capture output of dispatcher.log_info() (to stdout via print)
267        fp = StringIO()
268        stdout = sys.stdout
269        l1 = "Have you got anything without spam?"
270        l2 = "Why can't she have egg bacon spam and sausage?"
271        l3 = "THAT'S got spam in it!"
272        try:
273            sys.stdout = fp
274            d.log_info(l1, 'EGGS')
275            d.log_info(l2)
276            d.log_info(l3, 'SPAM')
277        finally:
278            sys.stdout = stdout
279
280        lines = fp.getvalue().splitlines()
281        expected = ['EGGS: %s' % l1, 'info: %s' % l2, 'SPAM: %s' % l3]
282
283        self.assertEqual(lines, expected)
284
285    def test_unhandled(self):
286        d = asyncore.dispatcher()
287        d.ignore_log_types = ()
288
289        # capture output of dispatcher.log_info() (to stdout via print)
290        fp = StringIO()
291        stdout = sys.stdout
292        try:
293            sys.stdout = fp
294            d.handle_expt()
295            d.handle_read()
296            d.handle_write()
297            d.handle_connect()
298            d.handle_accept()
299        finally:
300            sys.stdout = stdout
301
302        lines = fp.getvalue().splitlines()
303        expected = ['warning: unhandled incoming priority event',
304                    'warning: unhandled read event',
305                    'warning: unhandled write event',
306                    'warning: unhandled connect event',
307                    'warning: unhandled accept event']
308        self.assertEqual(lines, expected)
309
310    def test_issue_8594(self):
311        # XXX - this test is supposed to be removed in next major Python
312        # version
313        d = asyncore.dispatcher(socket.socket())
314        # make sure the error message no longer refers to the socket
315        # object but the dispatcher instance instead
316        self.assertRaisesRegexp(AttributeError, 'dispatcher instance',
317                                getattr, d, 'foo')
318        # cheap inheritance with the underlying socket is supposed
319        # to still work but a DeprecationWarning is expected
320        with warnings.catch_warnings(record=True) as w:
321            warnings.simplefilter("always")
322            family = d.family
323            self.assertEqual(family, socket.AF_INET)
324            self.assertEqual(len(w), 1)
325            self.assertTrue(issubclass(w[0].category, DeprecationWarning))
326
327    def test_strerror(self):
328        # refers to bug #8573
329        err = asyncore._strerror(errno.EPERM)
330        if hasattr(os, 'strerror'):
331            self.assertEqual(err, os.strerror(errno.EPERM))
332        err = asyncore._strerror(-1)
333        self.assertTrue(err != "")
334
335
class dispatcherwithsend_noread(asyncore.dispatcher_with_send):
    """dispatcher_with_send used purely as a sender in the tests."""

    def readable(self):
        """Never ask the event loop to watch this channel for input."""
        return False

    def handle_connect(self):
        """Swallow connect notifications instead of logging them."""
342
class DispatcherWithSendTests(unittest.TestCase):
    """Send data through dispatcher_with_send to a threaded capture server.

    ``usepoll`` selects which asyncore loop flavour flushes the send
    buffer; DispatcherWithSendTests_UsePoll sets it to True.
    """
    usepoll = False

    def tearDown(self):
        asyncore.close_all()

    @unittest.skipUnless(threading, 'Threading required for this test.')
    @test_support.reap_threads
    def test_send(self):
        evt = threading.Event()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(3)
        port = test_support.bind_port(sock)

        cap = StringIO()
        args = (evt, cap, sock)
        t = threading.Thread(target=capture_server, args=args)
        t.start()
        try:
            # wait a little longer for the server to initialize (it sometimes
            # refuses connections on slow machines without this wait)
            time.sleep(0.2)

            data = "Suppose there isn't a 16-ton weight?"
            d = dispatcherwithsend_noread()
            d.create_socket(socket.AF_INET, socket.SOCK_STREAM)
            d.connect((HOST, port))

            # give time for socket to connect
            time.sleep(0.1)

            d.send(data)
            d.send(data)
            d.send('\n')

            # Flush the buffer one non-blocking loop iteration at a time.
            # Honour self.usepoll so the _UsePoll variant really drives the
            # poll()-based loop instead of always using select().
            n = 1000
            while d.out_buffer and n > 0:
                asyncore.loop(timeout=0.0, count=1, use_poll=self.usepoll)
                n -= 1

            evt.wait()

            self.assertEqual(cap.getvalue(), data*2)
        finally:
            t.join()
391
392
class DispatcherWithSendTests_UsePoll(DispatcherWithSendTests):
    """DispatcherWithSendTests with the ``usepoll`` class flag set to True."""

    usepoll = True
395
@unittest.skipUnless(hasattr(asyncore, 'file_wrapper'),
                     'asyncore.file_wrapper required')
class FileWrapperTest(unittest.TestCase):
    """Tests for asyncore.file_wrapper and asyncore.file_dispatcher."""

    def setUp(self):
        self.d = "It's not dead, it's sleeping!"
        with open(TESTFN, 'w') as h:
            h.write(self.d)

    def tearDown(self):
        unlink(TESTFN)

    def test_recv(self):
        fd = os.open(TESTFN, os.O_RDONLY)
        w = asyncore.file_wrapper(fd)
        os.close(fd)

        try:
            # file_wrapper works on its own duplicate of the descriptor
            self.assertNotEqual(w.fd, fd)
            self.assertNotEqual(w.fileno(), fd)
            self.assertEqual(w.recv(13), "It's not dead")
            self.assertEqual(w.read(6), ", it's")
        finally:
            # Close even if an assertion fails, so the fd is not leaked.
            w.close()
        self.assertRaises(OSError, w.read, 1)

    def test_send(self):
        d1 = "Come again?"
        d2 = "I want to buy some cheese."
        fd = os.open(TESTFN, os.O_WRONLY | os.O_APPEND)
        w = asyncore.file_wrapper(fd)
        os.close(fd)

        try:
            w.write(d1)
            w.send(d2)
        finally:
            w.close()
        # Read back through a context manager so the handle is released.
        with open(TESTFN) as f:
            self.assertEqual(f.read(), self.d + d1 + d2)

    @unittest.skipUnless(hasattr(asyncore, 'file_dispatcher'),
                         'asyncore.file_dispatcher required')
    def test_dispatcher(self):
        fd = os.open(TESTFN, os.O_RDONLY)
        data = []
        class FileDispatcher(asyncore.file_dispatcher):
            def handle_read(self):
                data.append(self.recv(29))
        try:
            s = FileDispatcher(fd)
        finally:
            os.close(fd)
        try:
            asyncore.loop(timeout=0.01, use_poll=True, count=2)
        finally:
            # Unregister from the global socket map and release the dup'd
            # fd; closing an already-closed wrapper is a no-op
            # (see test_close_twice).
            s.close()
        self.assertEqual(b"".join(data), self.d)

    def test_close_twice(self):
        fd = os.open(TESTFN, os.O_RDONLY)
        f = asyncore.file_wrapper(fd)
        os.close(fd)

        os.close(f.fd)  # file_wrapper dupped fd
        with self.assertRaises(OSError):
            f.close()

        self.assertEqual(f.fd, -1)
        # calling close twice should not fail
        f.close()
457
458
class BaseTestHandler(asyncore.dispatcher):
    """Strict dispatcher: any unexpected event handler raises immediately.

    Subclasses override only the handlers a test expects to run and set
    ``flag`` once the expected event has been observed.
    """

    def __init__(self, sock=None):
        asyncore.dispatcher.__init__(self, sock)
        self.flag = False

    def _unexpected(self, name):
        """Fail loudly because handler *name* should not have run."""
        raise Exception("%s not supposed to be called" % name)

    def handle_accept(self):
        self._unexpected("handle_accept")

    def handle_connect(self):
        self._unexpected("handle_connect")

    def handle_expt(self):
        self._unexpected("handle_expt")

    def handle_close(self):
        self._unexpected("handle_close")

    def handle_error(self):
        # Re-raise the exception currently being handled so that it is not
        # swallowed by asyncore's error handling.
        raise
479
480
class TCPServer(asyncore.dispatcher):
    """Listening IPv4 TCP dispatcher that hands accepted sockets to a handler.

    Every accepted connection's socket is passed to ``handler(sock)``.
    With the default ``port=0`` the bound port is read back via ``address``.
    """

    def __init__(self, handler=BaseTestHandler, host=HOST, port=0):
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        endpoint = (host, port)
        self.bind(endpoint)
        self.listen(5)
        self.handler = handler

    @property
    def address(self):
        """The (host, port) pair the listening socket is bound to."""
        host, port = self.socket.getsockname()[:2]
        return host, port

    def handle_accept(self):
        pair = self.accept()
        if pair is None:
            # accept() produced nothing; there is no socket to hand off.
            return
        conn, _peer = pair
        self.handler(conn)

    def handle_error(self):
        # Propagate the exception currently being handled.
        raise
505
506
class BaseClient(BaseTestHandler):
    """IPv4 TCP client dispatcher that starts connecting to *address*.

    Unlike BaseTestHandler, the connect event itself is silently accepted.
    """

    def __init__(self, address):
        BaseTestHandler.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.connect(address)

    def handle_connect(self):
        """Ignore the connect notification."""
516
517
class BaseTestAPI(unittest.TestCase):
    """End-to-end tests of the asyncore.dispatcher API over real sockets.

    Concrete subclasses set the ``use_poll`` class attribute, which is
    forwarded to asyncore.loop().
    """

    def tearDown(self):
        asyncore.close_all()

    def loop_waiting_for_flag(self, instance, timeout=5):
        """Run the event loop until ``instance.flag`` becomes true.

        Runs at most 100 single-iteration loops, sleeping ``timeout / 100``
        seconds between them.  Fails the test if the flag never gets set
        (or the socket map empties first).
        """
        timeout = float(timeout) / 100
        count = 100
        while asyncore.socket_map and count > 0:
            asyncore.loop(timeout=0.01, count=1, use_poll=self.use_poll)
            if instance.flag:
                return
            count -= 1
            time.sleep(timeout)
        self.fail("flag not set")

    def test_handle_connect(self):
        # make sure handle_connect is called on connect()

        class TestClient(BaseClient):
            def handle_connect(self):
                self.flag = True

        server = TCPServer()
        client = TestClient(server.address)
        self.loop_waiting_for_flag(client)

    def test_handle_accept(self):
        # make sure handle_accept() is called when a client connects

        class TestListener(BaseTestHandler):

            def __init__(self):
                BaseTestHandler.__init__(self)
                self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
                self.bind((HOST, 0))
                self.listen(5)
                self.address = self.socket.getsockname()[:2]

            def handle_accept(self):
                self.flag = True

        server = TestListener()
        client = BaseClient(server.address)
        self.loop_waiting_for_flag(server)

    def test_handle_read(self):
        # make sure handle_read is called on data received

        class TestClient(BaseClient):
            def handle_read(self):
                self.flag = True

        class TestHandler(BaseTestHandler):
            def __init__(self, conn):
                BaseTestHandler.__init__(self, conn)
                self.send('x' * 1024)

        server = TCPServer(TestHandler)
        client = TestClient(server.address)
        self.loop_waiting_for_flag(client)

    def test_handle_write(self):
        # make sure handle_write is called

        class TestClient(BaseClient):
            def handle_write(self):
                self.flag = True

        server = TCPServer()
        client = TestClient(server.address)
        self.loop_waiting_for_flag(client)

    def test_handle_close(self):
        # make sure handle_close is called when the other end closes
        # the connection

        class TestClient(BaseClient):

            def handle_read(self):
                # in order to make handle_close be called we are supposed
                # to make at least one recv() call
                self.recv(1024)

            def handle_close(self):
                self.flag = True
                self.close()

        class TestHandler(BaseTestHandler):
            def __init__(self, conn):
                BaseTestHandler.__init__(self, conn)
                self.close()

        server = TCPServer(TestHandler)
        client = TestClient(server.address)
        self.loop_waiting_for_flag(client)

    @unittest.skipIf(sys.platform.startswith("sunos"),
                     "OOB support is broken on Solaris")
    def test_handle_expt(self):
        # Make sure handle_expt is called on OOB data received.
        # Note: this might fail on some platforms as OOB data is
        # tenuously supported and rarely used.

        if sys.platform == "darwin" and self.use_poll:
            self.skipTest("poll may fail on macOS; see issue #28087")

        class TestClient(BaseClient):
            def handle_expt(self):
                self.flag = True

        class TestHandler(BaseTestHandler):
            def __init__(self, conn):
                BaseTestHandler.__init__(self, conn)
                self.socket.send(chr(244), socket.MSG_OOB)

        server = TCPServer(TestHandler)
        client = TestClient(server.address)
        self.loop_waiting_for_flag(client)

    def test_handle_error(self):

        class TestClient(BaseClient):
            def handle_write(self):
                1.0 / 0
            def handle_error(self):
                self.flag = True
                try:
                    raise
                except ZeroDivisionError:
                    pass
                else:
                    raise Exception("exception not raised")

        server = TCPServer()
        client = TestClient(server.address)
        self.loop_waiting_for_flag(client)

    def test_connection_attributes(self):
        server = TCPServer()
        client = BaseClient(server.address)

        # we start disconnected
        self.assertFalse(server.connected)
        self.assertTrue(server.accepting)
        # this can't be taken for granted across all platforms
        #self.assertFalse(client.connected)
        self.assertFalse(client.accepting)

        # execute some loops so that client connects to server
        asyncore.loop(timeout=0.01, use_poll=self.use_poll, count=100)
        self.assertFalse(server.connected)
        self.assertTrue(server.accepting)
        self.assertTrue(client.connected)
        self.assertFalse(client.accepting)

        # disconnect the client
        client.close()
        self.assertFalse(server.connected)
        self.assertTrue(server.accepting)
        self.assertFalse(client.connected)
        self.assertFalse(client.accepting)

        # stop serving
        server.close()
        self.assertFalse(server.connected)
        self.assertFalse(server.accepting)

    def test_create_socket(self):
        s = asyncore.dispatcher()
        s.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.assertEqual(s.socket.family, socket.AF_INET)
        self.assertEqual(s.socket.type, socket.SOCK_STREAM)

    def test_bind(self):
        s1 = asyncore.dispatcher()
        s1.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        s1.bind((HOST, 0))
        s1.listen(5)
        port = s1.socket.getsockname()[1]

        s2 = asyncore.dispatcher()
        s2.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        # EADDRINUSE indicates the socket was correctly bound
        self.assertRaises(socket.error, s2.bind, (HOST, port))

    def test_set_reuse_addr(self):
        sock = socket.socket()
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        except socket.error:
            # unittest.skip() merely builds a decorator; skipTest() is what
            # actually marks the running test as skipped.
            self.skipTest("SO_REUSEADDR not supported on this platform")
        else:
            # if SO_REUSEADDR succeeded for sock we expect asyncore
            # to do the same
            s = asyncore.dispatcher(socket.socket())
            self.assertFalse(s.socket.getsockopt(socket.SOL_SOCKET,
                                                 socket.SO_REUSEADDR))
            # Release the probe socket (and its socket-map entry) before
            # create_socket() replaces it, instead of leaking it.
            s.close()
            s.create_socket(socket.AF_INET, socket.SOCK_STREAM)
            s.set_reuse_addr()
            self.assertTrue(s.socket.getsockopt(socket.SOL_SOCKET,
                                                 socket.SO_REUSEADDR))
        finally:
            sock.close()

    @unittest.skipUnless(threading, 'Threading required for this test.')
    @test_support.reap_threads
    def test_quick_connect(self):
        # see: http://bugs.python.org/issue10340
        server = TCPServer()
        t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=500))
        t.start()
        try:
            for x in xrange(20):
                s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                s.settimeout(.2)
                s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
                             struct.pack('ii', 1, 0))
                try:
                    s.connect(server.address)
                except socket.error:
                    pass
                finally:
                    s.close()
        finally:
            t.join()
744
745
class TestAPI_UseSelect(BaseTestAPI):
    """Run the BaseTestAPI suite with ``use_poll`` disabled."""

    use_poll = False
748
@unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required')
class TestAPI_UsePoll(BaseTestAPI):
    """Run the BaseTestAPI suite with ``use_poll`` enabled.

    Skipped where the select module provides no poll().
    """

    use_poll = True
752
753
def test_main():
    """Run every asyncore test case defined in this module."""
    run_unittest(HelperFunctionTests,
                 DispatcherTests,
                 DispatcherWithSendTests,
                 DispatcherWithSendTests_UsePoll,
                 TestAPI_UseSelect,
                 TestAPI_UsePoll,
                 FileWrapperTest)
759
# Allow running this test module directly as a script.
if __name__ == '__main__':
    test_main()
762