1import errno
2import os
3import random
4import selectors
5import signal
6import socket
7import sys
8from test import support
9from test.support import os_helper
10from test.support import socket_helper
11from time import sleep
12import unittest
13import unittest.mock
14import tempfile
15from time import monotonic as time
16try:
17    import resource
18except ImportError:
19    resource = None
20
21
if not hasattr(socket, 'socketpair'):
    def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
        """Fallback for platforms lacking socket.socketpair().

        Build a connected pair by listening on a loopback port, connecting a
        client to it and accepting that client.  Returns ``(client, server)``.
        """
        with socket.socket(family, type, proto) as listener:
            listener.bind((socket_helper.HOST, 0))
            listener.listen()
            client = socket.socket(family, type, proto)
            try:
                client.connect(listener.getsockname())
                client_addr = client.getsockname()
                while True:
                    server_side, peer_addr = listener.accept()
                    # Only hand back the connection made by our own client;
                    # any other accepted connection is discarded.
                    if peer_addr == client_addr:
                        return client, server_side
                    server_side.close()
            except OSError:
                client.close()
                raise
else:
    socketpair = socket.socketpair
42
43
def find_ready_matching(ready, flag):
    """Return the file objects in *ready* whose event mask shares a bit with *flag*.

    *ready* is an iterable of ``(key, events)`` pairs, such as the result of
    a selector's ``select()``; input order is preserved in the output list.
    """
    return [key.fileobj for key, events in ready if events & flag]
50
51
class BaseSelectorTestCase:
    """Tests shared by every selector implementation.

    Concrete subclasses also inherit from ``unittest.TestCase`` and set the
    ``SELECTOR`` class attribute to the selector class under test.
    """

    def make_socketpair(self):
        """Return a connected ``(rd, wr)`` socket pair, closed at cleanup."""
        rd, wr = socketpair()
        self.addCleanup(rd.close)
        self.addCleanup(wr.close)
        return rd, wr

    def test_register(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertIsInstance(key, selectors.SelectorKey)
        self.assertEqual(key.fileobj, rd)
        self.assertEqual(key.fd, rd.fileno())
        self.assertEqual(key.events, selectors.EVENT_READ)
        self.assertEqual(key.data, "data")

        # register an unknown event
        self.assertRaises(ValueError, s.register, 0, 999999)

        # register an invalid FD
        self.assertRaises(ValueError, s.register, -10, selectors.EVENT_READ)

        # register twice
        self.assertRaises(KeyError, s.register, rd, selectors.EVENT_READ)

        # register the same FD, but with a different object
        self.assertRaises(KeyError, s.register, rd.fileno(),
                          selectors.EVENT_READ)

    def test_unregister(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        s.unregister(rd)

        # unregister an unknown file obj
        self.assertRaises(KeyError, s.unregister, 999999)

        # unregister twice
        self.assertRaises(KeyError, s.unregister, rd)

    def test_unregister_after_fd_close(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        r, w = rd.fileno(), wr.fileno()
        s.register(r, selectors.EVENT_READ)
        s.register(w, selectors.EVENT_WRITE)
        rd.close()
        wr.close()
        s.unregister(r)
        s.unregister(w)

    @unittest.skipUnless(os.name == 'posix', "requires posix")
    def test_unregister_after_fd_close_and_reuse(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        r, w = rd.fileno(), wr.fileno()
        s.register(r, selectors.EVENT_READ)
        s.register(w, selectors.EVENT_WRITE)
        rd2, wr2 = self.make_socketpair()
        rd.close()
        wr.close()
        # Reuse the registered fd numbers for different sockets.
        os.dup2(rd2.fileno(), r)
        os.dup2(wr2.fileno(), w)
        self.addCleanup(os.close, r)
        self.addCleanup(os.close, w)
        s.unregister(r)
        s.unregister(w)

    def test_unregister_after_socket_close(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        s.register(rd, selectors.EVENT_READ)
        s.register(wr, selectors.EVENT_WRITE)
        rd.close()
        wr.close()
        s.unregister(rd)
        s.unregister(wr)

    def test_modify(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ)

        # modify events
        key2 = s.modify(rd, selectors.EVENT_WRITE)
        self.assertNotEqual(key.events, key2.events)
        self.assertEqual(key2, s.get_key(rd))

        s.unregister(rd)

        # modify data
        d1 = object()
        d2 = object()

        key = s.register(rd, selectors.EVENT_READ, d1)
        key2 = s.modify(rd, selectors.EVENT_READ, d2)
        self.assertEqual(key.events, key2.events)
        self.assertNotEqual(key.data, key2.data)
        self.assertEqual(key2, s.get_key(rd))
        self.assertEqual(key2.data, d2)

        # modify unknown file obj
        self.assertRaises(KeyError, s.modify, 999999, selectors.EVENT_READ)

        # modify use a shortcut: changing only the data must not go through
        # unregister()/register()
        d3 = object()
        s.register = unittest.mock.Mock()
        s.unregister = unittest.mock.Mock()

        s.modify(rd, selectors.EVENT_READ, d3)
        self.assertFalse(s.register.called)
        self.assertFalse(s.unregister.called)

    def test_modify_unregister(self):
        # Make sure the fd is unregister()ed in case of error on
        # modify(): http://bugs.python.org/issue30014
        name = self.SELECTOR.__name__
        if name not in ('EpollSelector', 'PollSelector', 'DevpollSelector'):
            self.skipTest("requires EpollSelector, PollSelector "
                          "or DevpollSelector")
        patch = unittest.mock.patch(f'selectors.{name}._selector_cls')

        with patch as m:
            m.return_value.modify = unittest.mock.Mock(
                side_effect=ZeroDivisionError)
            s = self.SELECTOR()
            self.addCleanup(s.close)
            rd, wr = self.make_socketpair()
            s.register(rd, selectors.EVENT_READ)
            self.assertEqual(len(s._map), 1)
            with self.assertRaises(ZeroDivisionError):
                s.modify(rd, selectors.EVENT_WRITE)
            self.assertEqual(len(s._map), 0)

    def test_close(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        mapping = s.get_map()
        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        s.register(wr, selectors.EVENT_WRITE)

        s.close()
        self.assertRaises(RuntimeError, s.get_key, rd)
        self.assertRaises(RuntimeError, s.get_key, wr)
        self.assertRaises(KeyError, mapping.__getitem__, rd)
        self.assertRaises(KeyError, mapping.__getitem__, wr)

    def test_get_key(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertEqual(key, s.get_key(rd))

        # unknown file obj
        self.assertRaises(KeyError, s.get_key, 999999)

    def test_get_map(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        keys = s.get_map()
        self.assertFalse(keys)
        self.assertEqual(len(keys), 0)
        self.assertEqual(list(keys), [])
        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertIn(rd, keys)
        self.assertEqual(key, keys[rd])
        self.assertEqual(len(keys), 1)
        self.assertEqual(list(keys), [rd.fileno()])
        self.assertEqual(list(keys.values()), [key])

        # unknown file obj
        with self.assertRaises(KeyError):
            keys[999999]

        # Read-only mapping
        with self.assertRaises(TypeError):
            del keys[rd]

    def test_select(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        wr_key = s.register(wr, selectors.EVENT_WRITE)

        result = s.select()
        for key, events in result:
            self.assertIsInstance(key, selectors.SelectorKey)
            self.assertTrue(events)
            self.assertFalse(events & ~(selectors.EVENT_READ |
                                        selectors.EVENT_WRITE))

        self.assertEqual([(wr_key, selectors.EVENT_WRITE)], result)

    def test_context_manager(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        with s as sel:
            sel.register(rd, selectors.EVENT_READ)
            sel.register(wr, selectors.EVENT_WRITE)

        self.assertRaises(RuntimeError, s.get_key, rd)
        self.assertRaises(RuntimeError, s.get_key, wr)

    def test_fileno(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        if hasattr(s, 'fileno'):
            fd = s.fileno()
            self.assertIsInstance(fd, int)
            self.assertGreaterEqual(fd, 0)

    def test_selector(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        NUM_SOCKETS = 12
        MSG = b" This is a test."
        MSG_LEN = len(MSG)
        readers = []
        writers = []
        r2w = {}
        w2r = {}

        for _ in range(NUM_SOCKETS):
            rd, wr = self.make_socketpair()
            s.register(rd, selectors.EVENT_READ)
            s.register(wr, selectors.EVENT_WRITE)
            readers.append(rd)
            writers.append(wr)
            r2w[rd] = wr
            w2r[wr] = rd

        bufs = []

        while writers:
            ready = s.select()
            ready_writers = find_ready_matching(ready, selectors.EVENT_WRITE)
            if not ready_writers:
                self.fail("no sockets ready for writing")
            wr = random.choice(ready_writers)
            wr.send(MSG)

            for _ in range(10):
                ready = s.select()
                ready_readers = find_ready_matching(ready,
                                                    selectors.EVENT_READ)
                if ready_readers:
                    break
                # there might be a delay between the write to the write end
                # and the read end being reported ready
                sleep(0.1)
            else:
                self.fail("no sockets ready for reading")
            self.assertEqual([w2r[wr]], ready_readers)
            rd = ready_readers[0]
            buf = rd.recv(MSG_LEN)
            self.assertEqual(len(buf), MSG_LEN)
            bufs.append(buf)
            s.unregister(r2w[rd])
            s.unregister(rd)
            writers.remove(r2w[rd])

        self.assertEqual(bufs, [MSG] * NUM_SOCKETS)

    @unittest.skipIf(sys.platform == 'win32',
                     'select.select() cannot be used with empty fd sets')
    def test_empty_select(self):
        # Issue #23009: Make sure EpollSelector.select() works when no FD is
        # registered.
        s = self.SELECTOR()
        self.addCleanup(s.close)
        self.assertEqual(s.select(timeout=0), [])

    def test_timeout(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(wr, selectors.EVENT_WRITE)
        t = time()
        self.assertEqual(1, len(s.select(0)))
        self.assertEqual(1, len(s.select(-1)))
        self.assertLess(time() - t, 0.5)

        s.unregister(wr)
        s.register(rd, selectors.EVENT_READ)
        t = time()
        self.assertFalse(s.select(0))
        self.assertFalse(s.select(-1))
        self.assertLess(time() - t, 0.5)

        t0 = time()
        self.assertFalse(s.select(1))
        t1 = time()
        dt = t1 - t0
        # Tolerate 2.0 seconds for very slow buildbots
        self.assertGreaterEqual(dt, 0.8)
        self.assertLessEqual(dt, 2.0)

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() required for this test")
    def test_select_interrupt_exc(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        class InterruptSelect(Exception):
            pass

        def handler(*args):
            raise InterruptSelect

        orig_alrm_handler = signal.signal(signal.SIGALRM, handler)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)

        try:
            signal.alarm(1)

            s.register(rd, selectors.EVENT_READ)
            t = time()
            # select() is interrupted by a signal which raises an exception
            with self.assertRaises(InterruptSelect):
                s.select(30)
            # select() was interrupted before the timeout of 30 seconds
            self.assertLess(time() - t, 5.0)
        finally:
            signal.alarm(0)

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() required for this test")
    def test_select_interrupt_noraise(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        orig_alrm_handler = signal.signal(signal.SIGALRM, lambda *args: None)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)

        try:
            signal.alarm(1)

            s.register(rd, selectors.EVENT_READ)
            t = time()
            # select() is interrupted by a signal, but the signal handler
            # doesn't raise an exception, so select() should be retried with
            # a recomputed timeout
            self.assertFalse(s.select(1.5))
            self.assertGreaterEqual(time() - t, 1.0)
        finally:
            signal.alarm(0)
441
442
class ScalableSelectorMixIn:
    """Extra tests for selectors that should handle many file descriptors.

    Mixed into the poll, epoll, kqueue and /dev/poll test cases; relies on
    ``SELECTOR`` and ``make_socketpair()`` supplied by the other base class.
    """

    # see issue #18963 for why it's skipped on older OS X versions
    @support.requires_mac_ver(10, 5)
    @unittest.skipUnless(resource, "Test needs resource module")
    def test_above_fd_setsize(self):
        # A scalable implementation should have no problem with more than
        # FD_SETSIZE file descriptors.  That value is unknown here, so try to
        # raise the soft RLIMIT_NOFILE to the hard ceiling and use that much.
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
            self.addCleanup(resource.setrlimit, resource.RLIMIT_NOFILE,
                            (soft, hard))
            fd_budget = min(hard, 2**16)
        except (OSError, ValueError):
            fd_budget = soft

        # leave headroom for descriptors already open (stdin, stdout...)
        fd_budget -= 32
        pair_count = fd_budget // 2

        selector = self.SELECTOR()
        self.addCleanup(selector.close)

        for _ in range(pair_count):
            try:
                rd, wr = self.make_socketpair()
            except OSError:
                # too many FDs, skip - note that we should only catch EMFILE
                # here, but apparently *BSD and Solaris can fail upon connect()
                # or bind() with EADDRNOTAVAIL, so let's be safe
                self.skipTest("FD limit reached")

            try:
                selector.register(rd, selectors.EVENT_READ)
                selector.register(wr, selectors.EVENT_WRITE)
            except OSError as exc:
                if exc.errno != errno.ENOSPC:
                    raise
                # epoll can raise ENOSPC once the
                # fs.epoll.max_user_watches sysctl is exceeded
                self.skipTest("FD limit reached")

        try:
            ready = selector.select()
        except OSError as exc:
            if exc.errno != errno.EINVAL or sys.platform != 'darwin':
                raise
            # unexplainable errors on macOS don't need to fail the test
            self.skipTest("Invalid argument error calling poll()")
        self.assertEqual(pair_count, len(ready))
494
495
class DefaultSelectorTestCase(BaseSelectorTestCase, unittest.TestCase):
    """Run the shared selector tests against selectors.DefaultSelector."""

    SELECTOR = selectors.DefaultSelector
499
500
class SelectSelectorTestCase(BaseSelectorTestCase, unittest.TestCase):
    """Run the shared selector tests against the select()-based selector."""

    SELECTOR = selectors.SelectSelector
504
505
@unittest.skipUnless(hasattr(selectors, 'PollSelector'),
                     "Test needs selectors.PollSelector")
class PollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn,
                           unittest.TestCase):
    """Shared and scalability tests for selectors.PollSelector."""

    # getattr() keeps the class body importable where PollSelector is absent;
    # the skipUnless decorator then skips the whole case.
    SELECTOR = getattr(selectors, 'PollSelector', None)
512
513
@unittest.skipUnless(hasattr(selectors, 'EpollSelector'),
                     "Test needs selectors.EpollSelector")
class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn,
                            unittest.TestCase):
    """Shared, scalability and epoll-specific tests for EpollSelector."""

    SELECTOR = getattr(selectors, 'EpollSelector', None)

    def test_register_file(self):
        # epoll(7) returns EPERM when given a regular file to watch; the
        # failed registration must not leave a SelectorKey behind.
        s = self.SELECTOR()
        # close the epoll object even though registration fails
        self.addCleanup(s.close)
        with tempfile.NamedTemporaryFile() as f:
            with self.assertRaises(OSError):
                s.register(f, selectors.EVENT_READ)
            # the SelectorKey has been removed
            with self.assertRaises(KeyError):
                s.get_key(f)
530
531
@unittest.skipUnless(hasattr(selectors, 'KqueueSelector'),
                     "Test needs selectors.KqueueSelector")
class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn,
                             unittest.TestCase):
    """Shared, scalability and kqueue-specific tests for KqueueSelector."""

    SELECTOR = getattr(selectors, 'KqueueSelector', None)

    def test_register_bad_fd(self):
        # a file descriptor that's been closed should raise an OSError
        # with EBADF
        s = self.SELECTOR()
        # close the kqueue object even though registration fails
        self.addCleanup(s.close)
        bad_f = os_helper.make_bad_fd()
        with self.assertRaises(OSError) as cm:
            s.register(bad_f, selectors.EVENT_READ)
        self.assertEqual(cm.exception.errno, errno.EBADF)
        # the SelectorKey has been removed
        with self.assertRaises(KeyError):
            s.get_key(bad_f)

    def test_empty_select_timeout(self):
        # Issues #23009, #29255: Make sure timeout is applied when no fds
        # are registered.
        s = self.SELECTOR()
        self.addCleanup(s.close)

        t0 = time()
        self.assertEqual(s.select(1), [])
        t1 = time()
        dt = t1 - t0
        # Tolerate 2.0 seconds for very slow buildbots
        self.assertGreaterEqual(dt, 0.8)
        self.assertLessEqual(dt, 2.0)
563
564
@unittest.skipUnless(hasattr(selectors, 'DevpollSelector'),
                     "Test needs selectors.DevpollSelector")
class DevpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn,
                              unittest.TestCase):
    """Shared and scalability tests for selectors.DevpollSelector."""

    # getattr() keeps the class body importable where DevpollSelector is
    # absent; the skipUnless decorator then skips the whole case.
    SELECTOR = getattr(selectors, 'DevpollSelector', None)
571
572
def tearDownModule():
    """unittest module-level teardown: reap child processes via test.support."""
    reap = support.reap_children
    reap()
575
576
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
579