1"""Tests for events.py."""
2
3import collections.abc
4import concurrent.futures
5import functools
6import io
7import os
8import platform
9import re
10import signal
11import socket
12try:
13    import ssl
14except ImportError:
15    ssl = None
16import subprocess
17import sys
18import threading
19import time
20import types
21import errno
22import unittest
23from unittest import mock
24import weakref
25
26if sys.platform not in ('win32', 'vxworks'):
27    import tty
28
29import asyncio
30from asyncio import coroutines
31from asyncio import events
32from asyncio import proactor_events
33from asyncio import selector_events
34from test.test_asyncio import utils as test_utils
35from test import support
36from test.support import socket_helper
37from test.support import threading_helper
38from test.support import ALWAYS_EQ, LARGEST, SMALLEST
39
40
def tearDownModule():
    """Reset asyncio's global event loop policy once this module is done.

    ``set_event_loop_policy(None)`` reinstates the default policy, so any
    policy that may have been installed while these tests ran does not
    leak into later test modules.
    """
    asyncio.set_event_loop_policy(None)
43
44
def broken_unix_getsockname():
    """Return True if getsockname() on a UNIX socket cannot be trusted.

    AIX is always treated as broken.  On macOS, versions older than 10.5
    (i.e. Mac OS X 10.4 "Tiger" and earlier) return a zero-length address
    for UNIX sockets (see Issue #20682).  Every other platform returns
    False, as does macOS when its version cannot be determined.
    """
    if sys.platform.startswith("aix"):
        return True
    if sys.platform != 'darwin':
        return False
    release = platform.mac_ver()[0]
    if not release:
        # platform.mac_ver() reports an empty string when it cannot
        # determine the version; int('') would raise ValueError below.
        return False
    version = tuple(map(int, release.split('.')))
    return version < (10, 5)
54
55
def _test_get_event_loop_new_process__sub_proc():
    """Run a trivial coroutine on a brand-new event loop; return 'hello'.

    Judging by its name this presumably runs in a child process.  The new
    loop is installed as the current event loop only for the duration of
    the call.  Afterwards it is uninstalled and closed, so a reused worker
    is not left holding an open loop.
    """
    async def doit():
        return 'hello'

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(doit())
    finally:
        asyncio.set_event_loop(None)
        loop.close()
63
64
class CoroLike:
    """Bare object that only provides coroutine-like method names.

    It exposes ``send``, ``throw``, ``close`` and ``__await__``.  Each of
    them is a no-op that returns None; note that ``__await__`` therefore
    does not return an iterator.
    """

    def send(self, v):
        """No-op stand-in for coroutine.send(); returns None."""

    def throw(self, *exc):
        """No-op stand-in for coroutine.throw(); returns None."""

    def close(self):
        """No-op stand-in for coroutine.close(); returns None."""

    def __await__(self):
        """Return None rather than an iterator."""
77
78
class MyBaseProto(asyncio.Protocol):
    """Stream protocol that records its lifecycle for test assertions.

    ``state`` moves INITIAL -> CONNECTED -> [EOF ->] CLOSED.  Every
    callback asserts the expected state and uses the current state as the
    assertion message.  ``nbytes`` accumulates the size of received data.
    When a loop is supplied, ``connected`` and ``done`` are futures that
    are resolved by connection_made() and connection_lost(); without a
    loop they remain ``None``.
    """

    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is None:
            return
        self.connected = loop.create_future()
        self.done = loop.create_future()

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'
        connected = self.connected
        if connected:
            connected.set_result(None)

    def data_received(self, data):
        assert self.state == 'CONNECTED', self.state
        self.nbytes = self.nbytes + len(data)

    def eof_received(self):
        assert self.state == 'CONNECTED', self.state
        self.state = 'EOF'

    def connection_lost(self, exc):
        # The peer may close with or without sending EOF first.
        assert self.state in ('CONNECTED', 'EOF'), self.state
        self.state = 'CLOSED'
        done = self.done
        if done:
            done.set_result(None)
111
112
class MyProto(MyBaseProto):
    """MyBaseProto that sends a minimal HTTP/1.0 request once connected."""

    def connection_made(self, transport):
        super().connection_made(transport)
        request = (b'GET / HTTP/1.0\r\n'
                   b'Host: example.com\r\n'
                   b'\r\n')
        transport.write(request)
117
118
class MyDatagramProto(asyncio.DatagramProtocol):
    """Datagram protocol that records its lifecycle for test assertions.

    ``state`` moves INITIAL -> INITIALIZED -> CLOSED, and every callback
    asserts the expected state.  ``nbytes`` accumulates the size of
    received datagrams.  When a loop is supplied, ``done`` is a future
    resolved by connection_lost(); otherwise it stays ``None``.
    """

    done = None

    def __init__(self, loop=None):
        self.state = 'INITIAL'
        self.nbytes = 0
        # Initialised like the other test protocols so that ``transport``
        # can be inspected before connection_made() without AttributeError.
        self.transport = None
        if loop is not None:
            self.done = loop.create_future()

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'INITIALIZED'

    def datagram_received(self, data, addr):
        assert self.state == 'INITIALIZED', self.state
        self.nbytes += len(data)

    def error_received(self, exc):
        # Errors are tolerated; only the state is checked.
        assert self.state == 'INITIALIZED', self.state

    def connection_lost(self, exc):
        assert self.state == 'INITIALIZED', self.state
        self.state = 'CLOSED'
        if self.done:
            self.done.set_result(None)
145
146
class MyReadPipeProto(asyncio.Protocol):
    """Read-pipe protocol that records its lifecycle as a list of states.

    ``state`` grows INITIAL, CONNECTED, EOF, CLOSED.  Every callback
    asserts the exact list accumulated so far.  ``nbytes`` counts received
    bytes.  When a loop is supplied, ``done`` is a future resolved by
    connection_lost(); otherwise it stays ``None``.
    """

    done = None

    def __init__(self, loop=None):
        self.state = ['INITIAL']
        self.nbytes = 0
        self.transport = None
        if loop is None:
            return
        self.done = loop.create_future()

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == ['INITIAL'], self.state
        self.state.append('CONNECTED')

    def data_received(self, data):
        assert self.state == ['INITIAL', 'CONNECTED'], self.state
        self.nbytes = self.nbytes + len(data)

    def eof_received(self):
        assert self.state == ['INITIAL', 'CONNECTED'], self.state
        self.state.append('EOF')

    def connection_lost(self, exc):
        # eof_received() may legitimately never have been called; record
        # the EOF step here so the final state list is always complete.
        saw_eof = 'EOF' in self.state
        if not saw_eof:
            self.state.append('EOF')
        assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state
        self.state.append('CLOSED')
        done = self.done
        if done:
            done.set_result(None)
177
178
class MyWritePipeProto(asyncio.BaseProtocol):
    """Write-pipe protocol asserting INITIAL -> CONNECTED -> CLOSED order.

    When a loop is supplied, ``done`` is a future resolved by
    connection_lost(); otherwise it stays ``None``.
    """

    done = None

    def __init__(self, loop=None):
        self.state = 'INITIAL'
        self.transport = None
        if loop is None:
            return
        self.done = loop.create_future()

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'

    def connection_lost(self, exc):
        assert self.state == 'CONNECTED', self.state
        self.state = 'CLOSED'
        done = self.done
        if done:
            done.set_result(None)
198
199
class MySubprocessProtocol(asyncio.SubprocessProtocol):
    """Subprocess protocol that records pipe output and lifecycle events.

    * ``connected`` / ``completed``: futures resolved by connection_made()
      and connection_lost().
    * ``disconnects``: one future per fd 0-2, resolved (or failed with the
      given exception) by pipe_connection_lost().
    * ``data``: bytes received on fds 1 and 2.
    * ``got_data``: events set whenever data arrives on fds 1 and 2.
    * ``returncode``: filled in from the transport by process_exited().
    """

    def __init__(self, loop):
        self.state = 'INITIAL'
        self.transport = None
        self.connected = loop.create_future()
        self.completed = loop.create_future()
        self.disconnects = {}
        for fd in range(3):
            self.disconnects[fd] = loop.create_future()
        self.data = dict.fromkeys((1, 2), b'')
        self.returncode = None
        self.got_data = {fd: asyncio.Event() for fd in (1, 2)}

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'
        self.connected.set_result(None)

    def connection_lost(self, exc):
        assert self.state == 'CONNECTED', self.state
        self.state = 'CLOSED'
        self.completed.set_result(None)

    def pipe_data_received(self, fd, data):
        assert self.state == 'CONNECTED', self.state
        self.data[fd] = self.data[fd] + data
        self.got_data[fd].set()

    def pipe_connection_lost(self, fd, exc):
        assert self.state == 'CONNECTED', self.state
        waiter = self.disconnects[fd]
        if not exc:
            waiter.set_result(exc)
        else:
            waiter.set_exception(exc)

    def process_exited(self):
        assert self.state == 'CONNECTED', self.state
        self.returncode = self.transport.get_returncode()
239
240
241class EventLoopTestsMixin:
242
    def setUp(self):
        """Create a fresh event loop for each test and install it.

        create_event_loop() and set_event_loop() are not defined in this
        mixin; presumably the concrete test class (or its TestCase base)
        provides them -- confirm against the subclasses.
        """
        super().setUp()
        self.loop = self.create_event_loop()
        self.set_event_loop(self.loop)
247
    def tearDown(self):
        """Flush pending loop callbacks, run cleanups and collect garbage.

        All of this happens before the base class's tearDown() runs.
        NOTE(review): that base tearDown() presumably closes the loop --
        verify in the base class.
        """
        # Give pending callbacks (e.g. transport close callbacks) one loop
        # iteration to run, unless a test already closed the loop.
        if not self.loop.is_closed():
            test_utils.run_briefly(self.loop)

        # Run registered cleanups now, before the base tearDown() below,
        # rather than after tearDown() returns as unittest normally does.
        self.doCleanups()
        # Force a full collection so objects left over by the test are
        # finalized while this test is still the one running.
        support.gc_collect()
        super().tearDown()
256
    def test_run_until_complete_nesting(self):
        # run_until_complete() must refuse to run on a loop that is
        # already running.
        async def coro1():
            await asyncio.sleep(0)

        async def coro2():
            self.assertTrue(self.loop.is_running())
            # Nested call: expected to raise RuntimeError, which leaves the
            # coro1() coroutine object un-awaited.
            self.loop.run_until_complete(coro1())

        # The abandoned coro1() object triggers the "never awaited" warning.
        with self.assertWarnsRegex(
            RuntimeWarning,
            r"coroutine \S+ was never awaited"
        ):
            self.assertRaises(
                RuntimeError, self.loop.run_until_complete, coro2())
271
272    # Note: because of the default Windows timing granularity of
273    # 15.6 msec, we use fairly long sleep times here (~100 msec).
274
275    def test_run_until_complete(self):
276        t0 = self.loop.time()
277        self.loop.run_until_complete(asyncio.sleep(0.1))
278        t1 = self.loop.time()
279        self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0)
280
281    def test_run_until_complete_stopped(self):
282
283        async def cb():
284            self.loop.stop()
285            await asyncio.sleep(0.1)
286        task = cb()
287        self.assertRaises(RuntimeError,
288                          self.loop.run_until_complete, task)
289
290    def test_call_later(self):
291        results = []
292
293        def callback(arg):
294            results.append(arg)
295            self.loop.stop()
296
297        self.loop.call_later(0.1, callback, 'hello world')
298        self.loop.run_forever()
299        self.assertEqual(results, ['hello world'])
300
301    def test_call_soon(self):
302        results = []
303
304        def callback(arg1, arg2):
305            results.append((arg1, arg2))
306            self.loop.stop()
307
308        self.loop.call_soon(callback, 'hello', 'world')
309        self.loop.run_forever()
310        self.assertEqual(results, [('hello', 'world')])
311
    def test_call_soon_threadsafe(self):
        # A callback scheduled from another thread with
        # call_soon_threadsafe() runs before one scheduled afterwards with
        # call_soon() in the loop's own thread.
        results = []
        lock = threading.Lock()

        def callback(arg):
            results.append(arg)
            # Stop once both callbacks have run.
            if len(results) >= 2:
                self.loop.stop()

        def run_in_thread():
            self.loop.call_soon_threadsafe(callback, 'hello')
            # Tell the main thread that 'hello' has been scheduled.
            lock.release()

        # Hold the lock so that `with lock:` below blocks until the helper
        # thread has scheduled 'hello' and released the lock.
        lock.acquire()
        t = threading.Thread(target=run_in_thread)
        t.start()

        with lock:
            self.loop.call_soon(callback, 'world')
            self.loop.run_forever()
        t.join()
        self.assertEqual(results, ['hello', 'world'])
334
335    def test_call_soon_threadsafe_same_thread(self):
336        results = []
337
338        def callback(arg):
339            results.append(arg)
340            if len(results) >= 2:
341                self.loop.stop()
342
343        self.loop.call_soon_threadsafe(callback, 'hello')
344        self.loop.call_soon(callback, 'world')
345        self.loop.run_forever()
346        self.assertEqual(results, ['hello', 'world'])
347
348    def test_run_in_executor(self):
349        def run(arg):
350            return (arg, threading.get_ident())
351        f2 = self.loop.run_in_executor(None, run, 'yo')
352        res, thread_id = self.loop.run_until_complete(f2)
353        self.assertEqual(res, 'yo')
354        self.assertNotEqual(thread_id, threading.get_ident())
355
    def test_run_in_executor_cancel(self):
        # Once the executor future has been cancelled, the default executor
        # shut down and the loop closed, nothing may try to schedule a
        # callback on the loop any more.
        called = False

        def patched_call_soon(*args):
            nonlocal called
            called = True

        def run():
            time.sleep(0.05)

        f2 = self.loop.run_in_executor(None, run)
        f2.cancel()
        self.loop.run_until_complete(
                self.loop.shutdown_default_executor())
        self.loop.close()
        # Any later scheduling attempt would flip `called` to True.
        self.loop.call_soon = patched_call_soon
        self.loop.call_soon_threadsafe = patched_call_soon
        # Leave time for any late completion notification to show up.
        time.sleep(0.4)
        self.assertFalse(called)
375
    def test_reader_callback(self):
        # A reader registered with add_reader() receives everything written
        # by the peer and observes EOF once the peer closes.
        r, w = socket.socketpair()
        r.setblocking(False)
        bytes_read = bytearray()

        def reader():
            try:
                data = r.recv(1024)
            except BlockingIOError:
                # Spurious readiness notifications are possible
                # at least on Linux -- see man select.
                return
            if data:
                bytes_read.extend(data)
            else:
                # Empty read means EOF: unregister (remove_reader() must
                # report that a reader was registered) and close.
                self.assertTrue(self.loop.remove_reader(r.fileno()))
                r.close()

        self.loop.add_reader(r.fileno(), reader)
        self.loop.call_soon(w.send, b'abc')
        test_utils.run_until(self.loop, lambda: len(bytes_read) >= 3)
        self.loop.call_soon(w.send, b'def')
        test_utils.run_until(self.loop, lambda: len(bytes_read) >= 6)
        # Closing the writer produces EOF on the reader side.
        self.loop.call_soon(w.close)
        self.loop.call_soon(self.loop.stop)
        self.loop.run_forever()
        self.assertEqual(bytes_read, b'abcdef')
403
    def test_writer_callback(self):
        # A writer registered with add_writer() runs once the socket is
        # writable; remove_writer() reports whether a writer was present.
        r, w = socket.socketpair()
        w.setblocking(False)

        def writer(data):
            w.send(data)
            self.loop.stop()

        data = b'x' * 1024
        self.loop.add_writer(w.fileno(), writer, data)
        self.loop.run_forever()

        # The writer is still registered: the first removal succeeds and
        # the second one finds nothing left to remove.
        self.assertTrue(self.loop.remove_writer(w.fileno()))
        self.assertFalse(self.loop.remove_writer(w.fileno()))

        w.close()
        read = r.recv(len(data) * 2)
        r.close()
        self.assertEqual(read, data)
423
424    @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL')
425    def test_add_signal_handler(self):
426        caught = 0
427
428        def my_handler():
429            nonlocal caught
430            caught += 1
431
432        # Check error behavior first.
433        self.assertRaises(
434            TypeError, self.loop.add_signal_handler, 'boom', my_handler)
435        self.assertRaises(
436            TypeError, self.loop.remove_signal_handler, 'boom')
437        self.assertRaises(
438            ValueError, self.loop.add_signal_handler, signal.NSIG+1,
439            my_handler)
440        self.assertRaises(
441            ValueError, self.loop.remove_signal_handler, signal.NSIG+1)
442        self.assertRaises(
443            ValueError, self.loop.add_signal_handler, 0, my_handler)
444        self.assertRaises(
445            ValueError, self.loop.remove_signal_handler, 0)
446        self.assertRaises(
447            ValueError, self.loop.add_signal_handler, -1, my_handler)
448        self.assertRaises(
449            ValueError, self.loop.remove_signal_handler, -1)
450        self.assertRaises(
451            RuntimeError, self.loop.add_signal_handler, signal.SIGKILL,
452            my_handler)
453        # Removing SIGKILL doesn't raise, since we don't call signal().
454        self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL))
455        # Now set a handler and handle it.
456        self.loop.add_signal_handler(signal.SIGINT, my_handler)
457
458        os.kill(os.getpid(), signal.SIGINT)
459        test_utils.run_until(self.loop, lambda: caught)
460
461        # Removing it should restore the default handler.
462        self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT))
463        self.assertEqual(signal.getsignal(signal.SIGINT),
464                         signal.default_int_handler)
465        # Removing again returns False.
466        self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT))
467
    @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
    @unittest.skipUnless(hasattr(signal, 'setitimer'),
                         'need signal.setitimer()')
    def test_signal_handling_while_selecting(self):
        # Test with a signal actually arriving during a select() call.
        caught = 0

        def my_handler():
            nonlocal caught
            caught += 1
            self.loop.stop()

        self.loop.add_signal_handler(signal.SIGALRM, my_handler)

        signal.setitimer(signal.ITIMER_REAL, 0.01, 0)  # Send SIGALRM once.
        # Safety net: stop the loop anyway if the signal never arrives.
        self.loop.call_later(60, self.loop.stop)
        self.loop.run_forever()
        self.assertEqual(caught, 1)
486
    @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
    @unittest.skipUnless(hasattr(signal, 'setitimer'),
                         'need signal.setitimer()')
    def test_signal_handling_args(self):
        # Extra positional arguments given to add_signal_handler() must be
        # passed through to the handler when the signal fires.
        some_args = (42,)
        caught = 0

        def my_handler(*args):
            nonlocal caught
            caught += 1
            self.assertEqual(args, some_args)
            self.loop.stop()

        self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args)

        signal.setitimer(signal.ITIMER_REAL, 0.1, 0)  # Send SIGALRM once.
        # Safety net: stop the loop anyway if the signal never arrives.
        self.loop.call_later(60, self.loop.stop)
        self.loop.run_forever()
        self.assertEqual(caught, 1)
506
507    def _basetest_create_connection(self, connection_fut, check_sockname=True):
508        tr, pr = self.loop.run_until_complete(connection_fut)
509        self.assertIsInstance(tr, asyncio.Transport)
510        self.assertIsInstance(pr, asyncio.Protocol)
511        self.assertIs(pr.transport, tr)
512        if check_sockname:
513            self.assertIsNotNone(tr.get_extra_info('sockname'))
514        self.loop.run_until_complete(pr.done)
515        self.assertGreater(pr.nbytes, 0)
516        tr.close()
517
518    def test_create_connection(self):
519        with test_utils.run_test_server() as httpd:
520            conn_fut = self.loop.create_connection(
521                lambda: MyProto(loop=self.loop), *httpd.address)
522            self._basetest_create_connection(conn_fut)
523
524    @socket_helper.skip_unless_bind_unix_socket
525    def test_create_unix_connection(self):
526        # Issue #20682: On Mac OS X Tiger, getsockname() returns a
527        # zero-length address for UNIX socket.
528        check_sockname = not broken_unix_getsockname()
529
530        with test_utils.run_test_unix_server() as httpd:
531            conn_fut = self.loop.create_unix_connection(
532                lambda: MyProto(loop=self.loop), httpd.address)
533            self._basetest_create_connection(conn_fut, check_sockname)
534
535    def check_ssl_extra_info(self, client, check_sockname=True,
536                             peername=None, peercert={}):
537        if check_sockname:
538            self.assertIsNotNone(client.get_extra_info('sockname'))
539        if peername:
540            self.assertEqual(peername,
541                             client.get_extra_info('peername'))
542        else:
543            self.assertIsNotNone(client.get_extra_info('peername'))
544        self.assertEqual(peercert,
545                         client.get_extra_info('peercert'))
546
547        # test SSL cipher
548        cipher = client.get_extra_info('cipher')
549        self.assertIsInstance(cipher, tuple)
550        self.assertEqual(len(cipher), 3, cipher)
551        self.assertIsInstance(cipher[0], str)
552        self.assertIsInstance(cipher[1], str)
553        self.assertIsInstance(cipher[2], int)
554
555        # test SSL object
556        sslobj = client.get_extra_info('ssl_object')
557        self.assertIsNotNone(sslobj)
558        self.assertEqual(sslobj.compression(),
559                         client.get_extra_info('compression'))
560        self.assertEqual(sslobj.cipher(),
561                         client.get_extra_info('cipher'))
562        self.assertEqual(sslobj.getpeercert(),
563                         client.get_extra_info('peercert'))
564        self.assertEqual(sslobj.compression(),
565                         client.get_extra_info('compression'))
566
567    def _basetest_create_ssl_connection(self, connection_fut,
568                                        check_sockname=True,
569                                        peername=None):
570        tr, pr = self.loop.run_until_complete(connection_fut)
571        self.assertIsInstance(tr, asyncio.Transport)
572        self.assertIsInstance(pr, asyncio.Protocol)
573        self.assertTrue('ssl' in tr.__class__.__name__.lower())
574        self.check_ssl_extra_info(tr, check_sockname, peername)
575        self.loop.run_until_complete(pr.done)
576        self.assertGreater(pr.nbytes, 0)
577        tr.close()
578
579    def _test_create_ssl_connection(self, httpd, create_connection,
580                                    check_sockname=True, peername=None):
581        conn_fut = create_connection(ssl=test_utils.dummy_ssl_context())
582        self._basetest_create_ssl_connection(conn_fut, check_sockname,
583                                             peername)
584
585        # ssl.Purpose was introduced in Python 3.4
586        if hasattr(ssl, 'Purpose'):
587            def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, *,
588                                          cafile=None, capath=None,
589                                          cadata=None):
590                """
591                A ssl.create_default_context() replacement that doesn't enable
592                cert validation.
593                """
594                self.assertEqual(purpose, ssl.Purpose.SERVER_AUTH)
595                return test_utils.dummy_ssl_context()
596
597            # With ssl=True, ssl.create_default_context() should be called
598            with mock.patch('ssl.create_default_context',
599                            side_effect=_dummy_ssl_create_context) as m:
600                conn_fut = create_connection(ssl=True)
601                self._basetest_create_ssl_connection(conn_fut, check_sockname,
602                                                     peername)
603                self.assertEqual(m.call_count, 1)
604
605        # With the real ssl.create_default_context(), certificate
606        # validation will fail
607        with self.assertRaises(ssl.SSLError) as cm:
608            conn_fut = create_connection(ssl=True)
609            # Ignore the "SSL handshake failed" log in debug mode
610            with test_utils.disable_logger():
611                self._basetest_create_ssl_connection(conn_fut, check_sockname,
612                                                     peername)
613
614        self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED')
615
616    @unittest.skipIf(ssl is None, 'No ssl module')
617    def test_create_ssl_connection(self):
618        with test_utils.run_test_server(use_ssl=True) as httpd:
619            create_connection = functools.partial(
620                self.loop.create_connection,
621                lambda: MyProto(loop=self.loop),
622                *httpd.address)
623            self._test_create_ssl_connection(httpd, create_connection,
624                                             peername=httpd.address)
625
626    @socket_helper.skip_unless_bind_unix_socket
627    @unittest.skipIf(ssl is None, 'No ssl module')
628    def test_create_ssl_unix_connection(self):
629        # Issue #20682: On Mac OS X Tiger, getsockname() returns a
630        # zero-length address for UNIX socket.
631        check_sockname = not broken_unix_getsockname()
632
633        with test_utils.run_test_unix_server(use_ssl=True) as httpd:
634            create_connection = functools.partial(
635                self.loop.create_unix_connection,
636                lambda: MyProto(loop=self.loop), httpd.address,
637                server_hostname='127.0.0.1')
638
639            self._test_create_ssl_connection(httpd, create_connection,
640                                             check_sockname,
641                                             peername=httpd.address)
642
643    def test_create_connection_local_addr(self):
644        with test_utils.run_test_server() as httpd:
645            port = socket_helper.find_unused_port()
646            f = self.loop.create_connection(
647                lambda: MyProto(loop=self.loop),
648                *httpd.address, local_addr=(httpd.address[0], port))
649            tr, pr = self.loop.run_until_complete(f)
650            expected = pr.transport.get_extra_info('sockname')[1]
651            self.assertEqual(port, expected)
652            tr.close()
653
654    def test_create_connection_local_addr_in_use(self):
655        with test_utils.run_test_server() as httpd:
656            f = self.loop.create_connection(
657                lambda: MyProto(loop=self.loop),
658                *httpd.address, local_addr=httpd.address)
659            with self.assertRaises(OSError) as cm:
660                self.loop.run_until_complete(f)
661            self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
662            self.assertIn(str(httpd.address), cm.exception.strerror)
663
664    def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
665        loop = self.loop
666
667        class MyProto(MyBaseProto):
668
669            def connection_lost(self, exc):
670                super().connection_lost(exc)
671                loop.call_soon(loop.stop)
672
673            def data_received(self, data):
674                super().data_received(data)
675                self.transport.write(expected_response)
676
677        lsock = socket.create_server(('127.0.0.1', 0), backlog=1)
678        addr = lsock.getsockname()
679
680        message = b'test data'
681        response = None
682        expected_response = b'roger'
683
684        def client():
685            nonlocal response
686            try:
687                csock = socket.socket()
688                if client_ssl is not None:
689                    csock = client_ssl.wrap_socket(csock)
690                csock.connect(addr)
691                csock.sendall(message)
692                response = csock.recv(99)
693                csock.close()
694            except Exception as exc:
695                print(
696                    "Failure in client thread in test_connect_accepted_socket",
697                    exc)
698
699        thread = threading.Thread(target=client, daemon=True)
700        thread.start()
701
702        conn, _ = lsock.accept()
703        proto = MyProto(loop=loop)
704        proto.loop = loop
705        loop.run_until_complete(
706            loop.connect_accepted_socket(
707                (lambda: proto), conn, ssl=server_ssl))
708        loop.run_forever()
709        proto.transport.close()
710        lsock.close()
711
712        threading_helper.join_thread(thread)
713        self.assertFalse(thread.is_alive())
714        self.assertEqual(proto.state, 'CLOSED')
715        self.assertEqual(proto.nbytes, len(message))
716        self.assertEqual(response, expected_response)
717
718    @unittest.skipIf(ssl is None, 'No ssl module')
719    def test_ssl_connect_accepted_socket(self):
720        if (sys.platform == 'win32' and
721            sys.version_info < (3, 5) and
722            isinstance(self.loop, proactor_events.BaseProactorEventLoop)
723            ):
724            raise unittest.SkipTest(
725                'SSL not supported with proactor event loops before Python 3.5'
726                )
727
728        server_context = test_utils.simple_server_sslcontext()
729        client_context = test_utils.simple_client_sslcontext()
730
731        self.test_connect_accepted_socket(server_context, client_context)
732
733    def test_connect_accepted_socket_ssl_timeout_for_plain_socket(self):
734        sock = socket.socket()
735        self.addCleanup(sock.close)
736        coro = self.loop.connect_accepted_socket(
737            MyProto, sock, ssl_handshake_timeout=support.LOOPBACK_TIMEOUT)
738        with self.assertRaisesRegex(
739                ValueError,
740                'ssl_handshake_timeout is only meaningful with ssl'):
741            self.loop.run_until_complete(coro)
742
    @mock.patch('asyncio.base_events.socket')
    def create_server_multiple_hosts(self, family, hosts, mock_sock):
        # Helper, not a test in its own right (mock.patch injects
        # `mock_sock`).  create_server() is called with a host list that may
        # contain duplicates, and the helper checks that the set of listening
        # hosts equals the de-duplicated input.  Two things are stubbed so
        # that no real sockets or DNS lookups are used:
        # * asyncio.base_events' socket module is replaced by `mock_sock`;
        # * loop.getaddrinfo is replaced by a stub that echoes each host
        #   back (6 == IPPROTO_TCP).
        async def getaddrinfo(host, port, *args, **kw):
            if family == socket.AF_INET:
                return [(family, socket.SOCK_STREAM, 6, '', (host, port))]
            else:
                return [(family, socket.SOCK_STREAM, 6, '', (host, port, 0, 0))]

        def getaddrinfo_task(*args, **kwds):
            return self.loop.create_task(getaddrinfo(*args, **kwds))

        unique_hosts = set(hosts)

        # mock_sock.socket() returns the same Mock on every call.  Each
        # call to the Mock-only `getsockbyname()` (not a real socket method)
        # therefore hands out the next address from this list.  The list
        # holds one entry per unique host, so extra calls would run out of
        # entries.
        if family == socket.AF_INET:
            mock_sock.socket().getsockbyname.side_effect = [
                (host, 80) for host in unique_hosts]
        else:
            mock_sock.socket().getsockbyname.side_effect = [
                (host, 80, 0, 0) for host in unique_hosts]
        self.loop.getaddrinfo = getaddrinfo_task
        # Don't actually start or stop serving on the mocked sockets.
        self.loop._start_serving = mock.Mock()
        self.loop._stop_serving = mock.Mock()
        f = self.loop.create_server(lambda: MyProto(self.loop), hosts, 80)
        server = self.loop.run_until_complete(f)
        self.addCleanup(server.close)
        # Only the set of host names is compared, so the order in which the
        # mock hands them out does not matter.
        server_hosts = {sock.getsockbyname()[0] for sock in server.sockets}
        self.assertEqual(server_hosts, unique_hosts)
770
771    def test_create_server_multiple_hosts_ipv4(self):
772        self.create_server_multiple_hosts(socket.AF_INET,
773                                          ['1.2.3.4', '5.6.7.8', '1.2.3.4'])
774
775    def test_create_server_multiple_hosts_ipv6(self):
776        self.create_server_multiple_hosts(socket.AF_INET6,
777                                          ['::1', '::2', '::1'])
778
    def test_create_server(self):
        # A TCP server on 0.0.0.0 with an ephemeral port accepts a client,
        # receives its bytes and exposes transport extra info.  The factory
        # always returns the same protocol instance, so exactly one client
        # connection is expected.
        proto = MyProto(self.loop)
        f = self.loop.create_server(lambda: proto, '0.0.0.0', 0)
        server = self.loop.run_until_complete(f)
        self.assertEqual(len(server.sockets), 1)
        sock = server.sockets[0]
        host, port = sock.getsockname()
        self.assertEqual(host, '0.0.0.0')
        client = socket.socket()
        client.connect(('127.0.0.1', port))
        client.sendall(b'xxx')

        self.loop.run_until_complete(proto.connected)
        self.assertEqual('CONNECTED', proto.state)

        # Wait for the 3 bytes the client sent.
        test_utils.run_until(self.loop, lambda: proto.nbytes > 0)
        self.assertEqual(3, proto.nbytes)

        # extra info is available
        self.assertIsNotNone(proto.transport.get_extra_info('sockname'))
        self.assertEqual('127.0.0.1',
                         proto.transport.get_extra_info('peername')[0])

        # close connection
        proto.transport.close()
        self.loop.run_until_complete(proto.done)

        self.assertEqual('CLOSED', proto.state)

        # the client socket must be closed after to avoid ECONNRESET upon
        # recv()/send() on the serving socket
        client.close()

        # close server
        server.close()
814
    @unittest.skipUnless(hasattr(socket, 'SO_REUSEPORT'), 'No SO_REUSEPORT')
    def test_create_server_reuse_port(self):
        # SO_REUSEPORT is off by default and is set when reuse_port=True.
        proto = MyProto(self.loop)
        f = self.loop.create_server(
            lambda: proto, '0.0.0.0', 0)
        server = self.loop.run_until_complete(f)
        self.assertEqual(len(server.sockets), 1)
        sock = server.sockets[0]
        self.assertFalse(
            sock.getsockopt(
                socket.SOL_SOCKET, socket.SO_REUSEPORT))
        server.close()

        # Run one loop iteration so callbacks scheduled by closing the
        # first server get to run before the second one is created.
        test_utils.run_briefly(self.loop)

        proto = MyProto(self.loop)
        f = self.loop.create_server(
            lambda: proto, '0.0.0.0', 0, reuse_port=True)
        server = self.loop.run_until_complete(f)
        self.assertEqual(len(server.sockets), 1)
        sock = server.sockets[0]
        self.assertTrue(
            sock.getsockopt(
                socket.SOL_SOCKET, socket.SO_REUSEPORT))
        server.close()
840
841    def _make_unix_server(self, factory, **kwargs):
842        path = test_utils.gen_unix_socket_path()
843        self.addCleanup(lambda: os.path.exists(path) and os.unlink(path))
844
845        f = self.loop.create_unix_server(factory, path, **kwargs)
846        server = self.loop.run_until_complete(f)
847
848        return server, path
849
    @socket_helper.skip_unless_bind_unix_socket
    def test_create_unix_server(self):
        # A UNIX-socket server accepts a client and receives its bytes.  The
        # factory always returns the same protocol instance, so exactly one
        # client connection is expected.
        proto = MyProto(loop=self.loop)
        server, path = self._make_unix_server(lambda: proto)
        self.assertEqual(len(server.sockets), 1)

        client = socket.socket(socket.AF_UNIX)
        client.connect(path)
        client.sendall(b'xxx')

        self.loop.run_until_complete(proto.connected)
        self.assertEqual('CONNECTED', proto.state)
        # Wait for the 3 bytes the client sent.
        test_utils.run_until(self.loop, lambda: proto.nbytes > 0)
        self.assertEqual(3, proto.nbytes)

        # close connection
        proto.transport.close()
        self.loop.run_until_complete(proto.done)

        self.assertEqual('CLOSED', proto.state)

        # the client socket must be closed after to avoid ECONNRESET upon
        # recv()/send() on the serving socket
        client.close()

        # close server
        server.close()
877
878    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
879    def test_create_unix_server_path_socket_error(self):
880        proto = MyProto(loop=self.loop)
881        sock = socket.socket()
882        with sock:
883            f = self.loop.create_unix_server(lambda: proto, '/test', sock=sock)
884            with self.assertRaisesRegex(ValueError,
885                                        'path and sock can not be specified '
886                                        'at the same time'):
887                self.loop.run_until_complete(f)
888
889    def _create_ssl_context(self, certfile, keyfile=None):
890        sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
891        sslcontext.options |= ssl.OP_NO_SSLv2
892        sslcontext.load_cert_chain(certfile, keyfile)
893        return sslcontext
894
895    def _make_ssl_server(self, factory, certfile, keyfile=None):
896        sslcontext = self._create_ssl_context(certfile, keyfile)
897
898        f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext)
899        server = self.loop.run_until_complete(f)
900
901        sock = server.sockets[0]
902        host, port = sock.getsockname()
903        self.assertEqual(host, '127.0.0.1')
904        return server, host, port
905
906    def _make_ssl_unix_server(self, factory, certfile, keyfile=None):
907        sslcontext = self._create_ssl_context(certfile, keyfile)
908        return self._make_unix_server(factory, ssl=sslcontext)
909
    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_create_server_ssl(self):
        # TLS over TCP.  An asyncio TLS client connects to a server built
        # from ONLYCERT/ONLYKEY and writes 3 bytes.  The server protocol
        # must count them, expose SSL extra info, and close cleanly.
        proto = MyProto(loop=self.loop)
        server, host, port = self._make_ssl_server(
            lambda: proto, test_utils.ONLYCERT, test_utils.ONLYKEY)

        # `client` is the client-side transport and `pr` its protocol.
        # NOTE(review): dummy_ssl_context() is presumably non-verifying;
        # confirm in test_utils.
        f_c = self.loop.create_connection(MyBaseProto, host, port,
                                          ssl=test_utils.dummy_ssl_context())
        client, pr = self.loop.run_until_complete(f_c)

        client.write(b'xxx')
        self.loop.run_until_complete(proto.connected)
        self.assertEqual('CONNECTED', proto.state)

        test_utils.run_until(self.loop, lambda: proto.nbytes > 0)
        self.assertEqual(3, proto.nbytes)

        # extra info is available
        self.check_ssl_extra_info(client, peername=(host, port))

        # close connection
        proto.transport.close()
        self.loop.run_until_complete(proto.done)
        self.assertEqual('CLOSED', proto.state)

        # the client socket must be closed after to avoid ECONNRESET upon
        # recv()/send() on the serving socket
        client.close()

        # stop serving
        server.close()
941
    @socket_helper.skip_unless_bind_unix_socket
    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_create_unix_server_ssl(self):
        # TLS over a UNIX socket.  The same round trip as
        # test_create_server_ssl, but through create_unix_server() and
        # create_unix_connection().
        proto = MyProto(loop=self.loop)
        server, path = self._make_ssl_unix_server(
            lambda: proto, test_utils.ONLYCERT, test_utils.ONLYKEY)

        # server_hostname='' is passed explicitly.  NOTE(review): presumably
        # needed because a socket path provides no host name; confirm.
        f_c = self.loop.create_unix_connection(
            MyBaseProto, path, ssl=test_utils.dummy_ssl_context(),
            server_hostname='')

        client, pr = self.loop.run_until_complete(f_c)

        client.write(b'xxx')
        self.loop.run_until_complete(proto.connected)
        self.assertEqual('CONNECTED', proto.state)
        test_utils.run_until(self.loop, lambda: proto.nbytes > 0)
        self.assertEqual(3, proto.nbytes)

        # close connection
        proto.transport.close()
        self.loop.run_until_complete(proto.done)
        self.assertEqual('CLOSED', proto.state)

        # the client socket must be closed after to avoid ECONNRESET upon
        # recv()/send() on the serving socket
        client.close()

        # stop serving
        server.close()
972
    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_create_server_ssl_verify_failed(self):
        # The client requires certificate verification but trusts no CA.
        # The TLS handshake must therefore fail with a verification error,
        # and the server protocol must never get a transport.
        proto = MyProto(loop=self.loop)
        server, host, port = self._make_ssl_server(
            lambda: proto, test_utils.SIGNED_CERTFILE)

        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        sslcontext_client.options |= ssl.OP_NO_SSLv2
        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
        if hasattr(sslcontext_client, 'check_hostname'):
            sslcontext_client.check_hostname = True


        # no CA loaded
        f_c = self.loop.create_connection(MyProto, host, port,
                                          ssl=sslcontext_client)
        # Silence the loop's exception handler and the logger while the
        # expected handshake failure propagates.
        with mock.patch.object(self.loop, 'call_exception_handler'):
            with test_utils.disable_logger():
                with self.assertRaisesRegex(ssl.SSLError,
                                            '(?i)certificate.verify.failed'):
                    self.loop.run_until_complete(f_c)

            # execute the loop to log the connection error
            test_utils.run_briefly(self.loop)

        # close connection
        self.assertIsNone(proto.transport)
        server.close()
1001
    @socket_helper.skip_unless_bind_unix_socket
    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_create_unix_server_ssl_verify_failed(self):
        # UNIX-socket variant of test_create_server_ssl_verify_failed.  No
        # CA is loaded on the client, so verification is expected to fail
        # (see the regex below) whatever server_hostname is.
        proto = MyProto(loop=self.loop)
        server, path = self._make_ssl_unix_server(
            lambda: proto, test_utils.SIGNED_CERTFILE)

        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        sslcontext_client.options |= ssl.OP_NO_SSLv2
        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
        if hasattr(sslcontext_client, 'check_hostname'):
            sslcontext_client.check_hostname = True

        # no CA loaded
        f_c = self.loop.create_unix_connection(MyProto, path,
                                               ssl=sslcontext_client,
                                               server_hostname='invalid')
        # Silence the loop's exception handler and the logger while the
        # expected handshake failure propagates.
        with mock.patch.object(self.loop, 'call_exception_handler'):
            with test_utils.disable_logger():
                with self.assertRaisesRegex(ssl.SSLError,
                                            '(?i)certificate.verify.failed'):
                    self.loop.run_until_complete(f_c)

            # execute the loop to log the connection error
            test_utils.run_briefly(self.loop)

        # close connection
        self.assertIsNone(proto.transport)
        server.close()
1031
    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_create_server_ssl_match_failed(self):
        # The client trusts the signing CA, so the certificate chain
        # verifies.  No server_hostname is given, however, and the expected
        # error shows the certificate is then checked against '127.0.0.1',
        # which it does not cover.
        proto = MyProto(loop=self.loop)
        server, host, port = self._make_ssl_server(
            lambda: proto, test_utils.SIGNED_CERTFILE)

        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        sslcontext_client.options |= ssl.OP_NO_SSLv2
        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
        sslcontext_client.load_verify_locations(
            cafile=test_utils.SIGNING_CA)
        if hasattr(sslcontext_client, 'check_hostname'):
            sslcontext_client.check_hostname = True

        # incorrect server_hostname
        f_c = self.loop.create_connection(MyProto, host, port,
                                          ssl=sslcontext_client)
        # Silence the loop's exception handler and the logger while the
        # expected hostname-mismatch error propagates.
        with mock.patch.object(self.loop, 'call_exception_handler'):
            with test_utils.disable_logger():
                with self.assertRaisesRegex(
                        ssl.CertificateError,
                        "IP address mismatch, certificate is not valid for "
                        "'127.0.0.1'"):
                    self.loop.run_until_complete(f_c)

        # close connection
        # transport is None because TLS ALERT aborted the handshake
        self.assertIsNone(proto.transport)
        server.close()
1061
    @socket_helper.skip_unless_bind_unix_socket
    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_create_unix_server_ssl_verified(self):
        # Positive case over a UNIX socket.  The client trusts the signing
        # CA and passes server_hostname='localhost', so the verified TLS
        # handshake succeeds.
        proto = MyProto(loop=self.loop)
        server, path = self._make_ssl_unix_server(
            lambda: proto, test_utils.SIGNED_CERTFILE)

        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        sslcontext_client.options |= ssl.OP_NO_SSLv2
        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
        sslcontext_client.load_verify_locations(cafile=test_utils.SIGNING_CA)
        if hasattr(sslcontext_client, 'check_hostname'):
            sslcontext_client.check_hostname = True

        # Connection succeeds with correct CA and server hostname.
        f_c = self.loop.create_unix_connection(MyProto, path,
                                               ssl=sslcontext_client,
                                               server_hostname='localhost')
        client, pr = self.loop.run_until_complete(f_c)
        self.loop.run_until_complete(proto.connected)

        # close connection
        proto.transport.close()
        client.close()
        server.close()
        self.loop.run_until_complete(proto.done)
1088
    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_create_server_ssl_verified(self):
        # Positive case over TCP.  A verified handshake with the signing CA
        # and server_hostname='localhost' succeeds, and the client transport
        # exposes the expected peer certificate.
        proto = MyProto(loop=self.loop)
        server, host, port = self._make_ssl_server(
            lambda: proto, test_utils.SIGNED_CERTFILE)

        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        sslcontext_client.options |= ssl.OP_NO_SSLv2
        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
        sslcontext_client.load_verify_locations(cafile=test_utils.SIGNING_CA)
        if hasattr(sslcontext_client, 'check_hostname'):
            sslcontext_client.check_hostname = True

        # Connection succeeds with correct CA and server hostname.
        f_c = self.loop.create_connection(MyProto, host, port,
                                          ssl=sslcontext_client,
                                          server_hostname='localhost')
        client, pr = self.loop.run_until_complete(f_c)
        self.loop.run_until_complete(proto.connected)

        # extra info is available
        self.check_ssl_extra_info(client, peername=(host, port),
                                  peercert=test_utils.PEERCERT)

        # close connection
        proto.transport.close()
        client.close()
        server.close()
        self.loop.run_until_complete(proto.done)
1118
    def test_create_server_sock(self):
        # create_server(sock=...) must serve on the caller's already-listening
        # socket instead of creating a new one.
        proto = self.loop.create_future()

        class TestMyProto(MyProto):
            def connection_made(self, transport):
                super().connection_made(transport)
                proto.set_result(self)

        sock_ob = socket.create_server(('0.0.0.0', 0))

        f = self.loop.create_server(TestMyProto, sock=sock_ob)
        server = self.loop.run_until_complete(f)
        sock = server.sockets[0]
        # Same file descriptor, so the server adopted sock_ob.
        self.assertEqual(sock.fileno(), sock_ob.fileno())

        host, port = sock.getsockname()
        self.assertEqual(host, '0.0.0.0')
        # NOTE(review): the `proto` future is never awaited, so this only
        # checks that connect/send succeed, not that connection_made ran.
        client = socket.socket()
        client.connect(('127.0.0.1', port))
        client.send(b'xxx')
        client.close()
        server.close()
1141
1142    def test_create_server_addr_in_use(self):
1143        sock_ob = socket.create_server(('0.0.0.0', 0))
1144
1145        f = self.loop.create_server(MyProto, sock=sock_ob)
1146        server = self.loop.run_until_complete(f)
1147        sock = server.sockets[0]
1148        host, port = sock.getsockname()
1149
1150        f = self.loop.create_server(MyProto, host=host, port=port)
1151        with self.assertRaises(OSError) as cm:
1152            self.loop.run_until_complete(f)
1153        self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
1154
1155        server.close()
1156
    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 not supported or enabled')
    def test_create_server_dual_stack(self):
        # With host=None the server must accept both an IPv4 connection
        # (127.0.0.1) and an IPv6 connection (::1) on the same port.
        f_proto = self.loop.create_future()

        class TestMyProto(MyProto):
            def connection_made(self, transport):
                super().connection_made(transport)
                # `f_proto` is resolved from the enclosing scope at call
                # time, so rebinding it further down routes the next
                # connection to the new future.
                f_proto.set_result(self)

        # Retry on EADDRINUSE, presumably because another process can take
        # the port between find_unused_port() and the bind.  The assertion
        # allows at most 5 retries.
        try_count = 0
        while True:
            try:
                port = socket_helper.find_unused_port()
                f = self.loop.create_server(TestMyProto, host=None, port=port)
                server = self.loop.run_until_complete(f)
            except OSError as ex:
                if ex.errno == errno.EADDRINUSE:
                    try_count += 1
                    self.assertGreaterEqual(5, try_count)
                    continue
                else:
                    raise
            else:
                break
        # IPv4 client.
        client = socket.socket()
        client.connect(('127.0.0.1', port))
        client.send(b'xxx')
        proto = self.loop.run_until_complete(f_proto)
        proto.transport.close()
        client.close()

        # IPv6 client on the same port, with a fresh future.
        f_proto = self.loop.create_future()
        client = socket.socket(socket.AF_INET6)
        client.connect(('::1', port))
        client.send(b'xxx')
        proto = self.loop.run_until_complete(f_proto)
        proto.transport.close()
        client.close()

        server.close()
1197
    def test_server_close(self):
        # After Server.close() the listening socket is gone, so a new
        # connection attempt to the same port is refused.
        f = self.loop.create_server(MyProto, '0.0.0.0', 0)
        server = self.loop.run_until_complete(f)
        sock = server.sockets[0]
        host, port = sock.getsockname()

        # One connection while the server is still listening.
        client = socket.socket()
        client.connect(('127.0.0.1', port))
        client.send(b'xxx')
        client.close()

        server.close()

        # Same port again; it must now be refused.
        client = socket.socket()
        self.assertRaises(
            ConnectionRefusedError, client.connect, ('127.0.0.1', port))
        client.close()
1215
    def _test_create_datagram_endpoint(self, local_addr, family):
        """Run a UDP echo round trip through create_datagram_endpoint().

        A server endpoint is bound to *local_addr*/*family*.  It replies to
        every datagram with ``b'resp:' + data``.  A client endpoint connected
        to the server's address sends ``b'xxx'`` (3 bytes) and must get back
        ``b'resp:xxx'`` (8 bytes).
        """
        class TestMyDatagramProto(MyDatagramProto):
            # The first parameter is named `inner_self` so that `self` below
            # still refers to the enclosing test case (for self.loop).
            def __init__(inner_self):
                super().__init__(loop=self.loop)

            def datagram_received(self, data, addr):
                super().datagram_received(data, addr)
                # Echo back with a prefix so the client can tell it apart.
                self.transport.sendto(b'resp:'+data, addr)

        coro = self.loop.create_datagram_endpoint(
            TestMyDatagramProto, local_addr=local_addr, family=family)
        s_transport, server = self.loop.run_until_complete(coro)
        # Resolve the bound address numerically so the client can target it.
        sockname = s_transport.get_extra_info('sockname')
        host, port = socket.getnameinfo(
            sockname, socket.NI_NUMERICHOST|socket.NI_NUMERICSERV)

        self.assertIsInstance(s_transport, asyncio.Transport)
        self.assertIsInstance(server, TestMyDatagramProto)
        self.assertEqual('INITIALIZED', server.state)
        self.assertIs(server.transport, s_transport)

        coro = self.loop.create_datagram_endpoint(
            lambda: MyDatagramProto(loop=self.loop),
            remote_addr=(host, port))
        transport, client = self.loop.run_until_complete(coro)

        self.assertIsInstance(transport, asyncio.Transport)
        self.assertIsInstance(client, MyDatagramProto)
        self.assertEqual('INITIALIZED', client.state)
        self.assertIs(client.transport, transport)

        transport.sendto(b'xxx')
        test_utils.run_until(self.loop, lambda: server.nbytes)
        self.assertEqual(3, server.nbytes)
        test_utils.run_until(self.loop, lambda: client.nbytes)

        # received
        self.assertEqual(8, client.nbytes)

        # extra info is available
        self.assertIsNotNone(transport.get_extra_info('sockname'))

        # close connection
        transport.close()
        self.loop.run_until_complete(client.done)
        self.assertEqual('CLOSED', client.state)
        server.transport.close()
1263
1264    def test_create_datagram_endpoint(self):
1265        self._test_create_datagram_endpoint(('127.0.0.1', 0), socket.AF_INET)
1266
1267    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 not supported or enabled')
1268    def test_create_datagram_endpoint_ipv6(self):
1269        self._test_create_datagram_endpoint(('::1', 0), socket.AF_INET6)
1270
1271    def test_create_datagram_endpoint_sock(self):
1272        sock = None
1273        local_address = ('127.0.0.1', 0)
1274        infos = self.loop.run_until_complete(
1275            self.loop.getaddrinfo(
1276                *local_address, type=socket.SOCK_DGRAM))
1277        for family, type, proto, cname, address in infos:
1278            try:
1279                sock = socket.socket(family=family, type=type, proto=proto)
1280                sock.setblocking(False)
1281                sock.bind(address)
1282            except:
1283                pass
1284            else:
1285                break
1286        else:
1287            assert False, 'Can not create socket.'
1288
1289        f = self.loop.create_datagram_endpoint(
1290            lambda: MyDatagramProto(loop=self.loop), sock=sock)
1291        tr, pr = self.loop.run_until_complete(f)
1292        self.assertIsInstance(tr, asyncio.Transport)
1293        self.assertIsInstance(pr, MyDatagramProto)
1294        tr.close()
1295        self.loop.run_until_complete(pr.done)
1296
1297    def test_internal_fds(self):
1298        loop = self.create_event_loop()
1299        if not isinstance(loop, selector_events.BaseSelectorEventLoop):
1300            loop.close()
1301            self.skipTest('loop is not a BaseSelectorEventLoop')
1302
1303        self.assertEqual(1, loop._internal_fds)
1304        loop.close()
1305        self.assertEqual(0, loop._internal_fds)
1306        self.assertIsNone(loop._csock)
1307        self.assertIsNone(loop._ssock)
1308
    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    def test_read_pipe(self):
        # connect_read_pipe() delivers bytes written to the other end of an
        # os.pipe().  The protocol reports EOF and then CLOSED once the
        # writer closes its end.
        proto = MyReadPipeProto(loop=self.loop)

        rpipe, wpipe = os.pipe()
        pipeobj = io.open(rpipe, 'rb', 1024)

        async def connect():
            t, p = await self.loop.connect_read_pipe(
                lambda: proto, pipeobj)
            self.assertIs(p, proto)
            self.assertIs(t, proto.transport)
            # MyReadPipeProto's state is a list of the states it passed through.
            self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
            self.assertEqual(0, proto.nbytes)

        self.loop.run_until_complete(connect())

        os.write(wpipe, b'1')
        test_utils.run_until(self.loop, lambda: proto.nbytes >= 1)
        self.assertEqual(1, proto.nbytes)

        os.write(wpipe, b'2345')
        test_utils.run_until(self.loop, lambda: proto.nbytes >= 5)
        self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
        self.assertEqual(5, proto.nbytes)

        # Closing the write end is enough for the reader to see EOF and
        # reach CLOSED; the transport is not closed explicitly here.
        os.close(wpipe)
        self.loop.run_until_complete(proto.done)
        self.assertEqual(
            ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state)
        # extra info is available
        self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
1342
    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    def test_unclosed_pipe_transport(self):
        # This test reproduces the issue #314 on GitHub
        # Pipe transports left open when their loop is closed must still
        # produce a repr() that reports them as open.
        loop = self.create_event_loop()
        read_proto = MyReadPipeProto(loop=loop)
        write_proto = MyWritePipeProto(loop=loop)

        rpipe, wpipe = os.pipe()
        rpipeobj = io.open(rpipe, 'rb', 1024)
        wpipeobj = io.open(wpipe, 'w', 1024, encoding="utf-8")

        async def connect():
            read_transport, _ = await loop.connect_read_pipe(
                lambda: read_proto, rpipeobj)
            write_transport, _ = await loop.connect_write_pipe(
                lambda: write_proto, wpipeobj)
            return read_transport, write_transport

        # Run and close the loop without closing the transports
        read_transport, write_transport = loop.run_until_complete(connect())
        loop.close()

        # These 'repr' calls used to raise an AttributeError
        # See Issue #314 on GitHub
        self.assertIn('open', repr(read_transport))
        self.assertIn('open', repr(write_transport))

        # Clean up (avoid ResourceWarning)
        rpipeobj.close()
        wpipeobj.close()
        # Detach the (now closed) pipe objects from the transports.
        read_transport._pipe = None
        write_transport._pipe = None
1376
    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    @unittest.skipUnless(hasattr(os, 'openpty'), 'need os.openpty()')
    def test_read_pty_output(self):
        # connect_read_pipe() also works on the master side of a pseudo
        # terminal.  Bytes written to the slave end are read from the master.
        proto = MyReadPipeProto(loop=self.loop)

        master, slave = os.openpty()
        # Unbuffered binary file object wrapping the master fd.
        master_read_obj = io.open(master, 'rb', 0)

        async def connect():
            t, p = await self.loop.connect_read_pipe(lambda: proto,
                                                     master_read_obj)
            self.assertIs(p, proto)
            self.assertIs(t, proto.transport)
            self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
            self.assertEqual(0, proto.nbytes)

        self.loop.run_until_complete(connect())

        os.write(slave, b'1')
        test_utils.run_until(self.loop, lambda: proto.nbytes)
        self.assertEqual(1, proto.nbytes)

        os.write(slave, b'2345')
        test_utils.run_until(self.loop, lambda: proto.nbytes >= 5)
        self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
        self.assertEqual(5, proto.nbytes)

        # Unlike test_read_pipe, the transport is closed explicitly here
        # after the slave end is closed.
        os.close(slave)
        proto.transport.close()
        self.loop.run_until_complete(proto.done)
        self.assertEqual(
            ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state)
        # extra info is available
        self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
1412
    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    def test_write_pipe(self):
        # connect_write_pipe() writes reach the read end of an os.pipe().
        # Closing the transport moves the protocol to CLOSED.
        rpipe, wpipe = os.pipe()
        pipeobj = io.open(wpipe, 'wb', 1024)

        proto = MyWritePipeProto(loop=self.loop)
        connect = self.loop.connect_write_pipe(lambda: proto, pipeobj)
        transport, p = self.loop.run_until_complete(connect)
        self.assertIs(p, proto)
        self.assertIs(transport, proto.transport)
        self.assertEqual('CONNECTED', proto.state)

        transport.write(b'1')

        data = bytearray()
        # Reads the next chunk from the read end into `data` (in place via
        # bytearray +=) and returns the total collected so far.
        # NOTE(review): rpipe is a blocking fd, so this os.read() would block
        # if the transport had not yet written anything.
        def reader(data):
            chunk = os.read(rpipe, 1024)
            data += chunk
            return len(data)

        test_utils.run_until(self.loop, lambda: reader(data) >= 1)
        self.assertEqual(b'1', data)

        transport.write(b'2345')
        test_utils.run_until(self.loop, lambda: reader(data) >= 5)
        self.assertEqual(b'12345', data)
        self.assertEqual('CONNECTED', proto.state)

        os.close(rpipe)

        # extra info is available
        self.assertIsNotNone(proto.transport.get_extra_info('pipe'))

        # close connection
        proto.transport.close()
        self.loop.run_until_complete(proto.done)
        self.assertEqual('CLOSED', proto.state)
1451
    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    def test_write_pipe_disconnect_on_close(self):
        # When the peer of a write pipe goes away, the write transport must
        # notice and reach CLOSED without being closed explicitly.
        rsock, wsock = socket.socketpair()
        rsock.setblocking(False)
        # Hand wsock's fd to a file object; detach() keeps wsock from also
        # owning (and later closing) it.
        pipeobj = io.open(wsock.detach(), 'wb', 1024)

        proto = MyWritePipeProto(loop=self.loop)
        connect = self.loop.connect_write_pipe(lambda: proto, pipeobj)
        transport, p = self.loop.run_until_complete(connect)
        self.assertIs(p, proto)
        self.assertIs(transport, proto.transport)
        self.assertEqual('CONNECTED', proto.state)

        transport.write(b'1')
        data = self.loop.run_until_complete(self.loop.sock_recv(rsock, 1024))
        self.assertEqual(b'1', data)

        # Close the reading peer; the transport is never closed here.
        rsock.close()

        self.loop.run_until_complete(proto.done)
        self.assertEqual('CLOSED', proto.state)
1474
    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    @unittest.skipUnless(hasattr(os, 'openpty'), 'need os.openpty()')
    # select, poll and kqueue don't support character devices (PTY) on Mac OS X
    # older than 10.6 (Snow Leopard)
    @support.requires_mac_ver(10, 6)
    def test_write_pty(self):
        # connect_write_pipe() on the slave side of a pseudo terminal.
        # Written bytes are read back from the master fd.
        master, slave = os.openpty()
        slave_write_obj = io.open(slave, 'wb', 0)

        proto = MyWritePipeProto(loop=self.loop)
        connect = self.loop.connect_write_pipe(lambda: proto, slave_write_obj)
        transport, p = self.loop.run_until_complete(connect)
        self.assertIs(p, proto)
        self.assertIs(transport, proto.transport)
        self.assertEqual('CONNECTED', proto.state)

        transport.write(b'1')

        data = bytearray()
        # Reads from the master fd into `data` in place and returns the
        # total collected so far.
        def reader(data):
            chunk = os.read(master, 1024)
            data += chunk
            return len(data)

        test_utils.run_until(self.loop, lambda: reader(data) >= 1,
                             timeout=support.SHORT_TIMEOUT)
        self.assertEqual(b'1', data)

        transport.write(b'2345')
        test_utils.run_until(self.loop, lambda: reader(data) >= 5,
                             timeout=support.SHORT_TIMEOUT)
        self.assertEqual(b'12345', data)
        self.assertEqual('CONNECTED', proto.state)

        os.close(master)

        # extra info is available
        self.assertIsNotNone(proto.transport.get_extra_info('pipe'))

        # close connection
        proto.transport.close()
        self.loop.run_until_complete(proto.done)
        self.assertEqual('CLOSED', proto.state)
1519
    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    @unittest.skipUnless(hasattr(os, 'openpty'), 'need os.openpty()')
    # select, poll and kqueue don't support character devices (PTY) on Mac OS X
    # older than 10.6 (Snow Leopard)
    @support.requires_mac_ver(10, 6)
    def test_bidirectional_pty(self):
        # One PTY slave is used by two transports at once: a read pipe and
        # a write pipe on a dup()ed fd.  Traffic flows both ways through
        # the master, and each direction is checked in turn.
        master, read_slave = os.openpty()
        write_slave = os.dup(read_slave)
        # NOTE(review): raw mode is presumably needed so that bytes written
        # to the master are not line-buffered or echoed back; confirm.
        tty.setraw(read_slave)

        slave_read_obj = io.open(read_slave, 'rb', 0)
        read_proto = MyReadPipeProto(loop=self.loop)
        read_connect = self.loop.connect_read_pipe(lambda: read_proto,
                                                   slave_read_obj)
        read_transport, p = self.loop.run_until_complete(read_connect)
        self.assertIs(p, read_proto)
        self.assertIs(read_transport, read_proto.transport)
        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
        self.assertEqual(0, read_proto.nbytes)


        slave_write_obj = io.open(write_slave, 'wb', 0)
        write_proto = MyWritePipeProto(loop=self.loop)
        write_connect = self.loop.connect_write_pipe(lambda: write_proto,
                                                     slave_write_obj)
        write_transport, p = self.loop.run_until_complete(write_connect)
        self.assertIs(p, write_proto)
        self.assertIs(write_transport, write_proto.transport)
        self.assertEqual('CONNECTED', write_proto.state)

        data = bytearray()
        # Reads from the master fd into `data` in place and returns the
        # total collected so far.
        def reader(data):
            chunk = os.read(master, 1024)
            data += chunk
            return len(data)

        # slave -> master
        write_transport.write(b'1')
        test_utils.run_until(self.loop, lambda: reader(data) >= 1,
                             timeout=support.SHORT_TIMEOUT)
        self.assertEqual(b'1', data)
        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
        self.assertEqual('CONNECTED', write_proto.state)

        # master -> slave
        os.write(master, b'a')
        test_utils.run_until(self.loop, lambda: read_proto.nbytes >= 1,
                             timeout=support.SHORT_TIMEOUT)
        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
        self.assertEqual(1, read_proto.nbytes)
        self.assertEqual('CONNECTED', write_proto.state)

        write_transport.write(b'2345')
        test_utils.run_until(self.loop, lambda: reader(data) >= 5,
                             timeout=support.SHORT_TIMEOUT)
        self.assertEqual(b'12345', data)
        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
        self.assertEqual('CONNECTED', write_proto.state)

        os.write(master, b'bcde')
        test_utils.run_until(self.loop, lambda: read_proto.nbytes >= 5,
                             timeout=support.SHORT_TIMEOUT)
        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
        self.assertEqual(5, read_proto.nbytes)
        self.assertEqual('CONNECTED', write_proto.state)

        os.close(master)

        read_transport.close()
        self.loop.run_until_complete(read_proto.done)
        self.assertEqual(
            ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], read_proto.state)

        write_transport.close()
        self.loop.run_until_complete(write_proto.done)
        self.assertEqual('CLOSED', write_proto.state)
1595
    def test_prompt_cancellation(self):
        # Cancelling a pending sock_recv() task must take effect promptly
        # (under 0.1 s) rather than waiting for data that never arrives.
        r, w = socket.socketpair()
        r.setblocking(False)
        f = self.loop.create_task(self.loop.sock_recv(r, 1))
        # Some implementations expose the underlying overlapped operation as
        # `ov`; when present it must be pending now and not pending later.
        # NOTE(review): `f` is a Task wrapping a coroutine here, so `ov` may
        # always be None and these checks may be no-ops; confirm.
        ov = getattr(f, 'ov', None)
        if ov is not None:
            self.assertTrue(ov.pending)

        # Schedules f.cancel() from inside the loop, awaits f, and always
        # stops the loop so that run_forever() below returns.
        async def main():
            try:
                self.loop.call_soon(f.cancel)
                await f
            except asyncio.CancelledError:
                res = 'cancelled'
            else:
                res = None
            finally:
                self.loop.stop()
            return res

        start = time.monotonic()
        t = self.loop.create_task(main())
        self.loop.run_forever()
        elapsed = time.monotonic() - start

        self.assertLess(elapsed, 0.1)
        self.assertEqual(t.result(), 'cancelled')
        self.assertRaises(asyncio.CancelledError, f.result)
        if ov is not None:
            self.assertFalse(ov.pending)
        # NOTE(review): private loop API; presumably drops any remaining
        # registration for `r` before it is closed.
        self.loop._stop_serving(r)

        r.close()
        w.close()
1630
1631    def test_timeout_rounding(self):
1632        def _run_once():
1633            self.loop._run_once_counter += 1
1634            orig_run_once()
1635
1636        orig_run_once = self.loop._run_once
1637        self.loop._run_once_counter = 0
1638        self.loop._run_once = _run_once
1639
1640        async def wait():
1641            loop = self.loop
1642            await asyncio.sleep(1e-2)
1643            await asyncio.sleep(1e-4)
1644            await asyncio.sleep(1e-6)
1645            await asyncio.sleep(1e-8)
1646            await asyncio.sleep(1e-10)
1647
1648        self.loop.run_until_complete(wait())
1649        # The ideal number of call is 12, but on some platforms, the selector
1650        # may sleep at little bit less than timeout depending on the resolution
1651        # of the clock used by the kernel. Tolerate a few useless calls on
1652        # these platforms.
1653        self.assertLessEqual(self.loop._run_once_counter, 20,
1654            {'clock_resolution': self.loop._clock_resolution,
1655             'selector': self.loop._selector.__class__.__name__})
1656
1657    def test_remove_fds_after_closing(self):
1658        loop = self.create_event_loop()
1659        callback = lambda: None
1660        r, w = socket.socketpair()
1661        self.addCleanup(r.close)
1662        self.addCleanup(w.close)
1663        loop.add_reader(r, callback)
1664        loop.add_writer(w, callback)
1665        loop.close()
1666        self.assertFalse(loop.remove_reader(r))
1667        self.assertFalse(loop.remove_writer(w))
1668
1669    def test_add_fds_after_closing(self):
1670        loop = self.create_event_loop()
1671        callback = lambda: None
1672        r, w = socket.socketpair()
1673        self.addCleanup(r.close)
1674        self.addCleanup(w.close)
1675        loop.close()
1676        with self.assertRaises(RuntimeError):
1677            loop.add_reader(r, callback)
1678        with self.assertRaises(RuntimeError):
1679            loop.add_writer(w, callback)
1680
1681    def test_close_running_event_loop(self):
1682        async def close_loop(loop):
1683            self.loop.close()
1684
1685        coro = close_loop(self.loop)
1686        with self.assertRaises(RuntimeError):
1687            self.loop.run_until_complete(coro)
1688
1689    def test_close(self):
1690        self.loop.close()
1691
1692        async def test():
1693            pass
1694
1695        func = lambda: False
1696        coro = test()
1697        self.addCleanup(coro.close)
1698
1699        # operation blocked when the loop is closed
1700        with self.assertRaises(RuntimeError):
1701            self.loop.run_forever()
1702        with self.assertRaises(RuntimeError):
1703            fut = self.loop.create_future()
1704            self.loop.run_until_complete(fut)
1705        with self.assertRaises(RuntimeError):
1706            self.loop.call_soon(func)
1707        with self.assertRaises(RuntimeError):
1708            self.loop.call_soon_threadsafe(func)
1709        with self.assertRaises(RuntimeError):
1710            self.loop.call_later(1.0, func)
1711        with self.assertRaises(RuntimeError):
1712            self.loop.call_at(self.loop.time() + .0, func)
1713        with self.assertRaises(RuntimeError):
1714            self.loop.create_task(coro)
1715        with self.assertRaises(RuntimeError):
1716            self.loop.add_signal_handler(signal.SIGTERM, func)
1717
1718        # run_in_executor test is tricky: the method is a coroutine,
1719        # but run_until_complete cannot be called on closed loop.
1720        # Thus iterate once explicitly.
1721        with self.assertRaises(RuntimeError):
1722            it = self.loop.run_in_executor(None, func).__await__()
1723            next(it)
1724
1725
class SubprocessTestsMixin:
    """Subprocess tests mixed into the per-event-loop test case classes.

    The concrete test class must provide ``self.loop``, the event loop under
    test.  Child processes run the helper scripts ``echo.py``, ``echo2.py``
    and ``echo3.py`` located next to this file.
    """

    def _check_returncode_signal(self, returncode, signame):
        # Shared body of check_terminated()/check_killed(). The signal is
        # looked up by name so POSIX-only signals (e.g. SIGKILL) are never
        # referenced on Windows.
        if sys.platform == 'win32':
            self.assertIsInstance(returncode, int)
            # expect 1 but sometimes get 0
        else:
            self.assertEqual(-getattr(signal, signame), returncode)

    def check_terminated(self, returncode):
        """Assert the child died from SIGTERM (any int on Windows)."""
        self._check_returncode_signal(returncode, 'SIGTERM')

    def check_killed(self, returncode):
        """Assert the child died from SIGKILL (any int on Windows)."""
        self._check_returncode_signal(returncode, 'SIGKILL')

    def _subprocess_exec_script(self, script, **kwds):
        """Run ``sys.executable <script>`` and wait until it is connected.

        *script* names a file in this test's directory; *kwds* are forwarded
        to ``loop.subprocess_exec()``.  Return ``(transport, protocol)``.
        """
        prog = os.path.join(os.path.dirname(__file__), script)
        connect = self.loop.subprocess_exec(
                        functools.partial(MySubprocessProtocol, self.loop),
                        sys.executable, prog, **kwds)
        transp, proto = self.loop.run_until_complete(connect)
        self.assertIsInstance(proto, MySubprocessProtocol)
        self.loop.run_until_complete(proto.connected)
        return transp, proto

    def _subprocess_shell_cmd(self, cmd, **kwds):
        """Start *cmd* through ``loop.subprocess_shell()``.

        *kwds* are forwarded to ``loop.subprocess_shell()``.  This does not
        wait for ``proto.connected``.  Return ``(transport, protocol)``.
        """
        connect = self.loop.subprocess_shell(
                        functools.partial(MySubprocessProtocol, self.loop),
                        cmd, **kwds)
        transp, proto = self.loop.run_until_complete(connect)
        self.assertIsInstance(proto, MySubprocessProtocol)
        return transp, proto

    def test_subprocess_exec(self):
        transp, proto = self._subprocess_exec_script('echo.py')
        self.assertEqual('CONNECTED', proto.state)

        stdin = transp.get_pipe_transport(0)
        stdin.write(b'Python The Winner')
        self.loop.run_until_complete(proto.got_data[1].wait())
        with test_utils.disable_logger():
            transp.close()
        self.loop.run_until_complete(proto.completed)
        self.check_killed(proto.returncode)
        self.assertEqual(b'Python The Winner', proto.data[1])

    def test_subprocess_interactive(self):
        transp, proto = self._subprocess_exec_script('echo.py')
        self.assertEqual('CONNECTED', proto.state)

        stdin = transp.get_pipe_transport(0)
        stdin.write(b'Python ')
        self.loop.run_until_complete(proto.got_data[1].wait())
        proto.got_data[1].clear()
        self.assertEqual(b'Python ', proto.data[1])

        stdin.write(b'The Winner')
        self.loop.run_until_complete(proto.got_data[1].wait())
        self.assertEqual(b'Python The Winner', proto.data[1])

        with test_utils.disable_logger():
            transp.close()
        self.loop.run_until_complete(proto.completed)
        self.check_killed(proto.returncode)

    def test_subprocess_shell(self):
        transp, proto = self._subprocess_shell_cmd('echo Python')
        self.loop.run_until_complete(proto.connected)

        transp.get_pipe_transport(0).close()
        self.loop.run_until_complete(proto.completed)
        self.assertEqual(0, proto.returncode)
        self.assertTrue(all(f.done() for f in proto.disconnects.values()))
        self.assertEqual(proto.data[1].rstrip(b'\r\n'), b'Python')
        self.assertEqual(proto.data[2], b'')
        transp.close()

    def test_subprocess_exitcode(self):
        transp, proto = self._subprocess_shell_cmd(
            'exit 7', stdin=None, stdout=None, stderr=None)
        self.loop.run_until_complete(proto.completed)
        self.assertEqual(7, proto.returncode)
        transp.close()

    def test_subprocess_close_after_finish(self):
        transp, proto = self._subprocess_shell_cmd(
            'exit 7', stdin=None, stdout=None, stderr=None)
        # No pipes were requested, so there are no pipe transports.
        self.assertIsNone(transp.get_pipe_transport(0))
        self.assertIsNone(transp.get_pipe_transport(1))
        self.assertIsNone(transp.get_pipe_transport(2))
        self.loop.run_until_complete(proto.completed)
        self.assertEqual(7, proto.returncode)
        self.assertIsNone(transp.close())

    def test_subprocess_kill(self):
        transp, proto = self._subprocess_exec_script('echo.py')

        transp.kill()
        self.loop.run_until_complete(proto.completed)
        self.check_killed(proto.returncode)
        transp.close()

    def test_subprocess_terminate(self):
        transp, proto = self._subprocess_exec_script('echo.py')

        transp.terminate()
        self.loop.run_until_complete(proto.completed)
        self.check_terminated(proto.returncode)
        transp.close()

    @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP")
    def test_subprocess_send_signal(self):
        # bpo-31034: Make sure that we get the default signal handler (killing
        # the process). The parent process may have decided to ignore SIGHUP,
        # and signal handlers are inherited.
        old_handler = signal.signal(signal.SIGHUP, signal.SIG_DFL)
        try:
            transp, proto = self._subprocess_exec_script('echo.py')

            transp.send_signal(signal.SIGHUP)
            self.loop.run_until_complete(proto.completed)
            self.assertEqual(-signal.SIGHUP, proto.returncode)
            transp.close()
        finally:
            signal.signal(signal.SIGHUP, old_handler)

    def test_subprocess_stderr(self):
        transp, proto = self._subprocess_exec_script('echo2.py')

        stdin = transp.get_pipe_transport(0)
        stdin.write(b'test')

        self.loop.run_until_complete(proto.completed)

        transp.close()
        self.assertEqual(b'OUT:test', proto.data[1])
        self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2])
        self.assertEqual(0, proto.returncode)

    def test_subprocess_stderr_redirect_to_stdout(self):
        transp, proto = self._subprocess_exec_script(
            'echo2.py', stderr=subprocess.STDOUT)

        stdin = transp.get_pipe_transport(0)
        # stderr is merged into stdout, so there is no separate fd 2 pipe.
        self.assertIsNotNone(transp.get_pipe_transport(1))
        self.assertIsNone(transp.get_pipe_transport(2))

        stdin.write(b'test')
        self.loop.run_until_complete(proto.completed)
        self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'),
                        proto.data[1])
        self.assertEqual(b'', proto.data[2])

        transp.close()
        self.assertEqual(0, proto.returncode)

    def test_subprocess_close_client_stream(self):
        transp, proto = self._subprocess_exec_script('echo3.py')

        stdin = transp.get_pipe_transport(0)
        stdout = transp.get_pipe_transport(1)
        stdin.write(b'test')
        self.loop.run_until_complete(proto.got_data[1].wait())
        self.assertEqual(b'OUT:test', proto.data[1])

        stdout.close()
        self.loop.run_until_complete(proto.disconnects[1])
        stdin.write(b'xxx')
        self.loop.run_until_complete(proto.got_data[2].wait())
        if sys.platform != 'win32':
            self.assertEqual(b'ERR:BrokenPipeError', proto.data[2])
        else:
            # After closing the read-end of a pipe, writing to the
            # write-end using os.write() fails with errno==EINVAL and
            # GetLastError()==ERROR_INVALID_NAME on Windows!?!  (Using
            # WriteFile() we get ERROR_BROKEN_PIPE as expected.)
            self.assertEqual(b'ERR:OSError', proto.data[2])
        with test_utils.disable_logger():
            transp.close()
        self.loop.run_until_complete(proto.completed)
        self.check_killed(proto.returncode)

    def test_subprocess_wait_no_same_group(self):
        # start the new process in a new session
        transp, proto = self._subprocess_shell_cmd(
            'exit 7', stdin=None, stdout=None, stderr=None,
            start_new_session=True)
        self.loop.run_until_complete(proto.completed)
        self.assertEqual(7, proto.returncode)
        transp.close()

    def test_subprocess_exec_invalid_args(self):
        async def connect(**kwds):
            await self.loop.subprocess_exec(
                asyncio.SubprocessProtocol,
                'pwd', **kwds)

        with self.assertRaises(ValueError):
            self.loop.run_until_complete(connect(universal_newlines=True))
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(connect(bufsize=4096))
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(connect(shell=True))

    def test_subprocess_shell_invalid_args(self):

        async def connect(cmd=None, **kwds):
            if not cmd:
                cmd = 'pwd'
            await self.loop.subprocess_shell(
                asyncio.SubprocessProtocol,
                cmd, **kwds)

        with self.assertRaises(ValueError):
            self.loop.run_until_complete(connect(['ls', '-l']))
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(connect(universal_newlines=True))
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(connect(bufsize=4096))
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(connect(shell=False))
2010
2011
if sys.platform == 'win32':

    # On Windows the selector loop only runs the generic event loop tests
    # (asyncio's SelectorEventLoop does not support subprocesses there); the
    # proactor loop additionally runs the subprocess tests.
    class SelectEventLoopTests(EventLoopTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()

    class ProactorEventLoopTests(EventLoopTestsMixin,
                                 SubprocessTestsMixin,
                                 test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()

        # The proactor loop has no add_reader()/add_writer(), so these
        # inherited tests are skipped. The decorator form skips before
        # setUp() runs, so no event loop is created for them.
        @unittest.skip("IocpEventLoop does not have add_reader()")
        def test_reader_callback(self):
            pass

        @unittest.skip("IocpEventLoop does not have add_reader()")
        def test_reader_callback_cancel(self):
            pass

        @unittest.skip("IocpEventLoop does not have add_writer()")
        def test_writer_callback(self):
            pass

        @unittest.skip("IocpEventLoop does not have add_writer()")
        def test_writer_callback_cancel(self):
            pass

        @unittest.skip("IocpEventLoop does not have add_reader()")
        def test_remove_fds_after_closing(self):
            pass
else:
    import selectors

    class UnixEventLoopTestsMixin(EventLoopTestsMixin):
        """Install a SafeChildWatcher attached to the test loop per test."""

        def setUp(self):
            super().setUp()
            watcher = asyncio.SafeChildWatcher()
            watcher.attach_loop(self.loop)
            asyncio.set_child_watcher(watcher)

        def tearDown(self):
            asyncio.set_child_watcher(None)
            super().tearDown()

    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(UnixEventLoopTestsMixin,
                                   SubprocessTestsMixin,
                                   test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(
                    selectors.KqueueSelector())

            # kqueue doesn't support character devices (PTY) on Mac OS X older
            # than 10.9 (Mavericks)
            @support.requires_mac_ver(10, 9)
            # Issue #20667: KqueueEventLoopTests.test_read_pty_output()
            # hangs on OpenBSD 5.5
            @unittest.skipIf(sys.platform.startswith('openbsd'),
                             'test hangs on OpenBSD')
            def test_read_pty_output(self):
                super().test_read_pty_output()

            # kqueue doesn't support character devices (PTY) on Mac OS X older
            # than 10.9 (Mavericks)
            @support.requires_mac_ver(10, 9)
            def test_write_pty(self):
                super().test_write_pty()

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(UnixEventLoopTestsMixin,
                                  SubprocessTestsMixin,
                                  test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.EpollSelector())

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(UnixEventLoopTestsMixin,
                                 SubprocessTestsMixin,
                                 test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.PollSelector())

    # Should always exist.
    class SelectEventLoopTests(UnixEventLoopTestsMixin,
                               SubprocessTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop(selectors.SelectSelector())
2104
2105
def noop(*args, **kwargs):
    """Accept and ignore any arguments; always return None.

    Used as a trivial callback whose name and source location appear in the
    Handle/TimerHandle repr tests.
    """
    return None
2108
2109
class HandleTests(test_utils.TestCase):
    """Tests for asyncio.Handle; most of them use a mocked event loop."""

    def setUp(self):
        super().setUp()
        self.loop = mock.Mock()
        self.loop.get_debug.return_value = True

    def test_handle(self):
        def callback(*args):
            return args

        args = ()
        h = asyncio.Handle(callback, args, self.loop)
        self.assertIs(h._callback, callback)
        self.assertIs(h._args, args)
        self.assertFalse(h.cancelled())

        h.cancel()
        self.assertTrue(h.cancelled())

    def test_callback_with_exception(self):
        def callback():
            raise ValueError()

        self.loop = mock.Mock()
        self.loop.call_exception_handler = mock.Mock()

        h = asyncio.Handle(callback, (), self.loop)
        h._run()

        # The callback's exception is handed to the loop's exception handler
        # instead of propagating out of _run().
        self.loop.call_exception_handler.assert_called_with({
            'message': test_utils.MockPattern('Exception in callback.*'),
            'exception': mock.ANY,
            'handle': h,
            'source_traceback': h._source_traceback,
        })

    def test_handle_weakref(self):
        wd = weakref.WeakValueDictionary()
        h = asyncio.Handle(lambda: None, (), self.loop)
        wd['h'] = h  # Would fail without __weakref__ slot.

    def test_handle_repr(self):
        self.loop.get_debug.return_value = False

        # simple function
        h = asyncio.Handle(noop, (1, 2), self.loop)
        filename, lineno = test_utils.get_function_source(noop)
        self.assertEqual(repr(h),
                        '<Handle noop(1, 2) at %s:%s>'
                        % (filename, lineno))

        # cancelled handle
        h.cancel()
        self.assertEqual(repr(h),
                        '<Handle cancelled>')

        # decorated function
        cb = types.coroutine(noop)
        h = asyncio.Handle(cb, (), self.loop)
        self.assertEqual(repr(h),
                        '<Handle noop() at %s:%s>'
                        % (filename, lineno))

        # partial function
        cb = functools.partial(noop, 1, 2)
        h = asyncio.Handle(cb, (3,), self.loop)
        regex = (r'^<Handle noop\(1, 2\)\(3\) at %s:%s>$'
                 % (re.escape(filename), lineno))
        self.assertRegex(repr(h), regex)

        # partial function with keyword args
        cb = functools.partial(noop, x=1)
        h = asyncio.Handle(cb, (2, 3), self.loop)
        regex = (r'^<Handle noop\(x=1\)\(2, 3\) at %s:%s>$'
                 % (re.escape(filename), lineno))
        self.assertRegex(repr(h), regex)

        # partial method (functools.partialmethod is always available on the
        # Python versions this file targets, so no version check is needed)
        method = HandleTests.test_handle_repr
        cb = functools.partialmethod(method)
        filename, lineno = test_utils.get_function_source(method)
        h = asyncio.Handle(cb, (), self.loop)

        cb_regex = r'<function HandleTests.test_handle_repr .*>'
        cb_regex = (r'functools.partialmethod\(%s, , \)\(\)' % cb_regex)
        regex = (r'^<Handle %s at %s:%s>$'
                 % (cb_regex, re.escape(filename), lineno))
        self.assertRegex(repr(h), regex)

    def test_handle_repr_debug(self):
        self.loop.get_debug.return_value = True

        # simple function; create_lineno must point at the line that builds
        # the handle, which is the line right after it.
        create_filename = __file__
        create_lineno = sys._getframe().f_lineno + 1
        h = asyncio.Handle(noop, (1, 2), self.loop)
        filename, lineno = test_utils.get_function_source(noop)
        self.assertEqual(repr(h),
                        '<Handle noop(1, 2) at %s:%s created at %s:%s>'
                        % (filename, lineno, create_filename, create_lineno))

        # cancelled handle
        h.cancel()
        self.assertEqual(
            repr(h),
            '<Handle cancelled noop(1, 2) at %s:%s created at %s:%s>'
            % (filename, lineno, create_filename, create_lineno))

        # double cancellation won't overwrite _repr
        h.cancel()
        self.assertEqual(
            repr(h),
            '<Handle cancelled noop(1, 2) at %s:%s created at %s:%s>'
            % (filename, lineno, create_filename, create_lineno))

    def test_handle_source_traceback(self):
        loop = asyncio.get_event_loop_policy().new_event_loop()
        loop.set_debug(True)
        self.set_event_loop(loop)

        def check_source_traceback(h):
            # The handle must have been created on the line directly above
            # the call to this helper.
            lineno = sys._getframe(1).f_lineno - 1
            self.assertIsInstance(h._source_traceback, list)
            self.assertEqual(h._source_traceback[-1][:3],
                             (__file__,
                              lineno,
                              'test_handle_source_traceback'))

        # call_soon
        h = loop.call_soon(noop)
        check_source_traceback(h)

        # call_soon_threadsafe
        h = loop.call_soon_threadsafe(noop)
        check_source_traceback(h)

        # call_later
        h = loop.call_later(0, noop)
        check_source_traceback(h)

        # call_at
        h = loop.call_at(loop.time(), noop)
        check_source_traceback(h)

    @unittest.skipUnless(hasattr(collections.abc, 'Coroutine'),
                         'No collections.abc.Coroutine')
    def test_coroutine_like_object_debug_formatting(self):
        # Test that asyncio can format coroutines that are instances of
        # collections.abc.Coroutine, but lack cr_code or gi_code attributes
        # (such as ones compiled with Cython).

        coro = CoroLike()
        coro.__name__ = 'AAA'
        self.assertTrue(asyncio.iscoroutine(coro))
        self.assertEqual(coroutines._format_coroutine(coro), 'AAA()')

        coro.__qualname__ = 'BBB'
        self.assertEqual(coroutines._format_coroutine(coro), 'BBB()')

        coro.cr_running = True
        self.assertEqual(coroutines._format_coroutine(coro), 'BBB() running')

        coro.__name__ = coro.__qualname__ = None
        self.assertEqual(coroutines._format_coroutine(coro),
                         '<CoroLike without __name__>() running')

        coro = CoroLike()
        coro.__qualname__ = 'CoroLike'
        # Some coroutines might not have '__name__', such as
        # built-in async_gen.asend().
        self.assertEqual(coroutines._format_coroutine(coro), 'CoroLike()')

        coro = CoroLike()
        coro.__qualname__ = 'AAA'
        coro.cr_code = None
        self.assertEqual(coroutines._format_coroutine(coro), 'AAA()')
2288
2289
class TimerTests(unittest.TestCase):
    """Tests for asyncio.TimerHandle against a mocked event loop."""

    def setUp(self):
        super().setUp()
        self.loop = mock.Mock()

    def test_hash(self):
        when = time.monotonic()
        h = asyncio.TimerHandle(when, lambda: False, (), self.loop)
        self.assertEqual(hash(h), hash(when))

    def test_when(self):
        when = time.monotonic()
        h = asyncio.TimerHandle(when, lambda: False, (), self.loop)
        self.assertEqual(when, h.when())

    def test_timer(self):
        def callback(*args):
            return args

        args = (1, 2, 3)
        when = time.monotonic()
        h = asyncio.TimerHandle(when, callback, args, self.loop)
        self.assertIs(h._callback, callback)
        self.assertIs(h._args, args)
        self.assertFalse(h.cancelled())

        # cancel() drops the references to the callback and its arguments
        h.cancel()
        self.assertTrue(h.cancelled())
        self.assertIsNone(h._callback)
        self.assertIsNone(h._args)

    def test_timer_repr(self):
        self.loop.get_debug.return_value = False

        # simple function
        h = asyncio.TimerHandle(123, noop, (), self.loop)
        src = test_utils.get_function_source(noop)
        self.assertEqual(repr(h),
                        '<TimerHandle when=123 noop() at %s:%s>' % src)

        # cancelled handle
        h.cancel()
        self.assertEqual(repr(h),
                        '<TimerHandle cancelled when=123>')

    def test_timer_repr_debug(self):
        self.loop.get_debug.return_value = True

        # simple function; create_lineno must point at the line that builds
        # the handle, which is the line right after it.
        create_filename = __file__
        create_lineno = sys._getframe().f_lineno + 1
        h = asyncio.TimerHandle(123, noop, (), self.loop)
        filename, lineno = test_utils.get_function_source(noop)
        self.assertEqual(repr(h),
                        '<TimerHandle when=123 noop() '
                        'at %s:%s created at %s:%s>'
                        % (filename, lineno, create_filename, create_lineno))

        # cancelled handle
        h.cancel()
        self.assertEqual(repr(h),
                        '<TimerHandle cancelled when=123 noop() '
                        'at %s:%s created at %s:%s>'
                        % (filename, lineno, create_filename, create_lineno))

    def test_timer_comparison(self):
        # Positive checks use the assert* helper that evaluates the same
        # operator (for clearer failure messages). Negative checks keep
        # assertFalse so the exact operator under test is still exercised.
        def callback(*args):
            return args

        when = time.monotonic()

        h1 = asyncio.TimerHandle(when, callback, (), self.loop)
        h2 = asyncio.TimerHandle(when, callback, (), self.loop)
        self.assertFalse(h1 < h2)
        self.assertFalse(h2 < h1)
        self.assertLessEqual(h1, h2)
        self.assertLessEqual(h2, h1)
        self.assertFalse(h1 > h2)
        self.assertFalse(h2 > h1)
        self.assertGreaterEqual(h1, h2)
        self.assertGreaterEqual(h2, h1)
        self.assertEqual(h1, h2)
        self.assertFalse(h1 != h2)

        # a cancelled handle no longer compares equal
        h2.cancel()
        self.assertFalse(h1 == h2)

        h1 = asyncio.TimerHandle(when, callback, (), self.loop)
        h2 = asyncio.TimerHandle(when + 10.0, callback, (), self.loop)
        self.assertLess(h1, h2)
        self.assertFalse(h2 < h1)
        self.assertLessEqual(h1, h2)
        self.assertFalse(h2 <= h1)
        self.assertFalse(h1 > h2)
        self.assertGreater(h2, h1)
        self.assertFalse(h1 >= h2)
        self.assertGreaterEqual(h2, h1)
        self.assertFalse(h1 == h2)
        self.assertNotEqual(h1, h2)

        h3 = asyncio.Handle(callback, (), self.loop)
        self.assertIs(NotImplemented, h1.__eq__(h3))
        self.assertIs(NotImplemented, h1.__ne__(h3))

        with self.assertRaises(TypeError):
            h1 < ()
        with self.assertRaises(TypeError):
            h1 > ()
        with self.assertRaises(TypeError):
            h1 <= ()
        with self.assertRaises(TypeError):
            h1 >= ()
        self.assertFalse(h1 == ())
        self.assertNotEqual(h1, ())

        self.assertEqual(h1, ALWAYS_EQ)
        self.assertFalse(h1 != ALWAYS_EQ)
        self.assertLess(h1, LARGEST)
        self.assertFalse(h1 > LARGEST)
        self.assertLessEqual(h1, LARGEST)
        self.assertFalse(h1 >= LARGEST)
        self.assertFalse(h1 < SMALLEST)
        self.assertGreater(h1, SMALLEST)
        self.assertFalse(h1 <= SMALLEST)
        self.assertGreaterEqual(h1, SMALLEST)
2422
2423
class AbstractEventLoopTests(unittest.TestCase):
    """Every public method of AbstractEventLoop must raise
    NotImplementedError, whether it is a plain method or a coroutine."""

    def test_not_implemented(self):
        f = mock.Mock()
        loop = asyncio.AbstractEventLoop()
        self.assertRaises(
            NotImplementedError, loop.run_forever)
        self.assertRaises(
            NotImplementedError, loop.run_until_complete, None)
        self.assertRaises(
            NotImplementedError, loop.stop)
        self.assertRaises(
            NotImplementedError, loop.is_running)
        self.assertRaises(
            NotImplementedError, loop.is_closed)
        self.assertRaises(
            NotImplementedError, loop.close)
        self.assertRaises(
            NotImplementedError, loop.create_future)
        self.assertRaises(
            NotImplementedError, loop.create_task, None)
        self.assertRaises(
            NotImplementedError, loop.call_later, None, None)
        self.assertRaises(
            NotImplementedError, loop.call_at, f, f)
        self.assertRaises(
            NotImplementedError, loop.call_soon, None)
        self.assertRaises(
            NotImplementedError, loop.time)
        self.assertRaises(
            NotImplementedError, loop.call_soon_threadsafe, None)
        self.assertRaises(
            NotImplementedError, loop.set_default_executor, f)
        self.assertRaises(
            NotImplementedError, loop.add_reader, 1, f)
        self.assertRaises(
            NotImplementedError, loop.remove_reader, 1)
        self.assertRaises(
            NotImplementedError, loop.add_writer, 1, f)
        self.assertRaises(
            NotImplementedError, loop.remove_writer, 1)
        self.assertRaises(
            NotImplementedError, loop.add_signal_handler, 1, f)
        self.assertRaises(
            NotImplementedError, loop.remove_signal_handler, 1)
        self.assertRaises(
            NotImplementedError, loop.set_task_factory, f)
        self.assertRaises(
            NotImplementedError, loop.get_task_factory)
        self.assertRaises(
            NotImplementedError, loop.get_exception_handler)
        self.assertRaises(
            NotImplementedError, loop.set_exception_handler, f)
        self.assertRaises(
            NotImplementedError, loop.default_exception_handler, f)
        self.assertRaises(
            NotImplementedError, loop.call_exception_handler, f)
        self.assertRaises(
            NotImplementedError, loop.get_debug)
        self.assertRaises(
            NotImplementedError, loop.set_debug, f)

    def test_not_implemented_async(self):

        async def inner():
            f = mock.Mock()
            loop = asyncio.AbstractEventLoop()

            with self.assertRaises(NotImplementedError):
                await loop.run_in_executor(f, f)
            with self.assertRaises(NotImplementedError):
                await loop.getaddrinfo('localhost', 8080)
            with self.assertRaises(NotImplementedError):
                await loop.getnameinfo(('localhost', 8080))
            with self.assertRaises(NotImplementedError):
                await loop.create_connection(f)
            with self.assertRaises(NotImplementedError):
                await loop.create_server(f)
            with self.assertRaises(NotImplementedError):
                await loop.create_datagram_endpoint(f)
            with self.assertRaises(NotImplementedError):
                await loop.sock_recv(f, 10)
            with self.assertRaises(NotImplementedError):
                await loop.sock_recv_into(f, 10)
            with self.assertRaises(NotImplementedError):
                await loop.sock_sendall(f, 10)
            with self.assertRaises(NotImplementedError):
                await loop.sock_connect(f, f)
            with self.assertRaises(NotImplementedError):
                await loop.sock_accept(f)
            with self.assertRaises(NotImplementedError):
                await loop.sock_sendfile(f, f)
            with self.assertRaises(NotImplementedError):
                await loop.sendfile(f, f)
            with self.assertRaises(NotImplementedError):
                await loop.connect_read_pipe(f, mock.sentinel.pipe)
            with self.assertRaises(NotImplementedError):
                await loop.connect_write_pipe(f, mock.sentinel.pipe)
            with self.assertRaises(NotImplementedError):
                await loop.subprocess_shell(f, mock.sentinel)
            with self.assertRaises(NotImplementedError):
                await loop.subprocess_exec(f)

        # A real loop is needed to drive the coroutines above; close it even
        # if one of the assertions fails so it is not leaked.
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(inner())
        finally:
            loop.close()
2523        loop.close()
2524
2525
2526class PolicyTests(unittest.TestCase):
2527
2528    def test_event_loop_policy(self):
2529        policy = asyncio.AbstractEventLoopPolicy()
2530        self.assertRaises(NotImplementedError, policy.get_event_loop)
2531        self.assertRaises(NotImplementedError, policy.set_event_loop, object())
2532        self.assertRaises(NotImplementedError, policy.new_event_loop)
2533        self.assertRaises(NotImplementedError, policy.get_child_watcher)
2534        self.assertRaises(NotImplementedError, policy.set_child_watcher,
2535                          object())
2536
2537    def test_get_event_loop(self):
2538        policy = asyncio.DefaultEventLoopPolicy()
2539        self.assertIsNone(policy._local._loop)
2540
2541        loop = policy.get_event_loop()
2542        self.assertIsInstance(loop, asyncio.AbstractEventLoop)
2543
2544        self.assertIs(policy._local._loop, loop)
2545        self.assertIs(loop, policy.get_event_loop())
2546        loop.close()
2547
2548    def test_get_event_loop_calls_set_event_loop(self):
2549        policy = asyncio.DefaultEventLoopPolicy()
2550
2551        with mock.patch.object(
2552                policy, "set_event_loop",
2553                wraps=policy.set_event_loop) as m_set_event_loop:
2554
2555            loop = policy.get_event_loop()
2556
2557            # policy._local._loop must be set through .set_event_loop()
2558            # (the unix DefaultEventLoopPolicy needs this call to attach
2559            # the child watcher correctly)
2560            m_set_event_loop.assert_called_with(loop)
2561
2562        loop.close()
2563
2564    def test_get_event_loop_after_set_none(self):
2565        policy = asyncio.DefaultEventLoopPolicy()
2566        policy.set_event_loop(None)
2567        self.assertRaises(RuntimeError, policy.get_event_loop)
2568
2569    @mock.patch('asyncio.events.threading.current_thread')
2570    def test_get_event_loop_thread(self, m_current_thread):
2571
2572        def f():
2573            policy = asyncio.DefaultEventLoopPolicy()
2574            self.assertRaises(RuntimeError, policy.get_event_loop)
2575
2576        th = threading.Thread(target=f)
2577        th.start()
2578        th.join()
2579
2580    def test_new_event_loop(self):
2581        policy = asyncio.DefaultEventLoopPolicy()
2582
2583        loop = policy.new_event_loop()
2584        self.assertIsInstance(loop, asyncio.AbstractEventLoop)
2585        loop.close()
2586
2587    def test_set_event_loop(self):
2588        policy = asyncio.DefaultEventLoopPolicy()
2589        old_loop = policy.get_event_loop()
2590
2591        self.assertRaises(TypeError, policy.set_event_loop, object())
2592
2593        loop = policy.new_event_loop()
2594        policy.set_event_loop(loop)
2595        self.assertIs(loop, policy.get_event_loop())
2596        self.assertIsNot(old_loop, policy.get_event_loop())
2597        loop.close()
2598        old_loop.close()
2599
2600    def test_get_event_loop_policy(self):
2601        policy = asyncio.get_event_loop_policy()
2602        self.assertIsInstance(policy, asyncio.AbstractEventLoopPolicy)
2603        self.assertIs(policy, asyncio.get_event_loop_policy())
2604
2605    def test_set_event_loop_policy(self):
2606        self.assertRaises(
2607            TypeError, asyncio.set_event_loop_policy, object())
2608
2609        old_policy = asyncio.get_event_loop_policy()
2610
2611        policy = asyncio.DefaultEventLoopPolicy()
2612        asyncio.set_event_loop_policy(policy)
2613        self.assertIs(policy, asyncio.get_event_loop_policy())
2614        self.assertIsNot(policy, old_policy)
2615
2616
class GetEventLoopTestsMixin:
    """Shared tests for get_event_loop() / get_running_loop().

    Concrete subclasses set the four ``*_impl`` class attributes.  setUp()
    installs them into both ``asyncio.events`` and the top-level ``asyncio``
    namespace, and tearDown() restores the saved originals.
    """

    _get_running_loop_impl = None
    _set_running_loop_impl = None
    get_running_loop_impl = None
    get_event_loop_impl = None

    def setUp(self):
        # Save the current implementations so tearDown() can restore them.
        self._get_running_loop_saved = events._get_running_loop
        self._set_running_loop_saved = events._set_running_loop
        self.get_running_loop_saved = events.get_running_loop
        self.get_event_loop_saved = events.get_event_loop

        events._get_running_loop = type(self)._get_running_loop_impl
        events._set_running_loop = type(self)._set_running_loop_impl
        events.get_running_loop = type(self).get_running_loop_impl
        events.get_event_loop = type(self).get_event_loop_impl

        asyncio._get_running_loop = type(self)._get_running_loop_impl
        asyncio._set_running_loop = type(self)._set_running_loop_impl
        asyncio.get_running_loop = type(self).get_running_loop_impl
        asyncio.get_event_loop = type(self).get_event_loop_impl

        super().setUp()

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        if sys.platform != 'win32':
            watcher = asyncio.SafeChildWatcher()
            watcher.attach_loop(self.loop)
            asyncio.set_child_watcher(watcher)

    def tearDown(self):
        try:
            if sys.platform != 'win32':
                asyncio.set_child_watcher(None)

            super().tearDown()
        finally:
            self.loop.close()
            asyncio.set_event_loop(None)

            events._get_running_loop = self._get_running_loop_saved
            events._set_running_loop = self._set_running_loop_saved
            events.get_running_loop = self.get_running_loop_saved
            events.get_event_loop = self.get_event_loop_saved

            asyncio._get_running_loop = self._get_running_loop_saved
            asyncio._set_running_loop = self._set_running_loop_saved
            asyncio.get_running_loop = self.get_running_loop_saved
            asyncio.get_event_loop = self.get_event_loop_saved

    if sys.platform != 'win32':

        def test_get_event_loop_new_process(self):
            # bpo-32126: The multiprocessing module used by
            # ProcessPoolExecutor is not functional when the
            # multiprocessing.synchronize module cannot be imported.
            support.skip_if_broken_multiprocessing_synchronize()

            async def main():
                pool = concurrent.futures.ProcessPoolExecutor()
                result = await self.loop.run_in_executor(
                    pool, _test_get_event_loop_new_process__sub_proc)
                pool.shutdown()
                return result

            self.assertEqual(
                self.loop.run_until_complete(main()),
                'hello')

    def test_get_event_loop_returns_running_loop(self):
        class TestError(Exception):
            pass

        class Policy(asyncio.DefaultEventLoopPolicy):
            def get_event_loop(self):
                raise TestError

        old_policy = asyncio.get_event_loop_policy()
        # Bind before the try so the finally clause cannot hit an
        # UnboundLocalError (masking the real failure) when installing the
        # policy or creating the loop raises.
        loop = None
        try:
            asyncio.set_event_loop_policy(Policy())
            loop = asyncio.new_event_loop()

            # Outside a running loop, get_event_loop() defers to the policy
            # and warns with the caller's file as the warning location.
            with self.assertWarns(DeprecationWarning) as cm:
                with self.assertRaises(TestError):
                    asyncio.get_event_loop()
            self.assertEqual(cm.warnings[0].filename, __file__)
            asyncio.set_event_loop(None)
            with self.assertWarns(DeprecationWarning) as cm:
                with self.assertRaises(TestError):
                    asyncio.get_event_loop()
            self.assertEqual(cm.warnings[0].filename, __file__)

            with self.assertRaisesRegex(RuntimeError, 'no running'):
                asyncio.get_running_loop()
            self.assertIs(asyncio._get_running_loop(), None)

            # Inside a running loop, the running loop wins over the policy.
            async def func():
                self.assertIs(asyncio.get_event_loop(), loop)
                self.assertIs(asyncio.get_running_loop(), loop)
                self.assertIs(asyncio._get_running_loop(), loop)

            loop.run_until_complete(func())

            asyncio.set_event_loop(loop)
            with self.assertWarns(DeprecationWarning) as cm:
                with self.assertRaises(TestError):
                    asyncio.get_event_loop()
            self.assertEqual(cm.warnings[0].filename, __file__)

            asyncio.set_event_loop(None)
            with self.assertWarns(DeprecationWarning) as cm:
                with self.assertRaises(TestError):
                    asyncio.get_event_loop()
            self.assertEqual(cm.warnings[0].filename, __file__)

        finally:
            asyncio.set_event_loop_policy(old_policy)
            if loop is not None:
                loop.close()

        with self.assertRaisesRegex(RuntimeError, 'no running'):
            asyncio.get_running_loop()

        self.assertIs(asyncio._get_running_loop(), None)

    def test_get_event_loop_returns_running_loop2(self):
        old_policy = asyncio.get_event_loop_policy()
        # Bind before the try so the finally clause cannot hit an
        # UnboundLocalError when installing the policy or creating the
        # loop raises.
        loop = None
        try:
            asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
            loop = asyncio.new_event_loop()
            self.addCleanup(loop.close)

            with self.assertWarns(DeprecationWarning) as cm:
                loop2 = asyncio.get_event_loop()
            self.addCleanup(loop2.close)
            self.assertEqual(cm.warnings[0].filename, __file__)
            asyncio.set_event_loop(None)
            with self.assertWarns(DeprecationWarning) as cm:
                with self.assertRaisesRegex(RuntimeError, 'no current'):
                    asyncio.get_event_loop()
            self.assertEqual(cm.warnings[0].filename, __file__)

            with self.assertRaisesRegex(RuntimeError, 'no running'):
                asyncio.get_running_loop()
            self.assertIs(asyncio._get_running_loop(), None)

            async def func():
                self.assertIs(asyncio.get_event_loop(), loop)
                self.assertIs(asyncio.get_running_loop(), loop)
                self.assertIs(asyncio._get_running_loop(), loop)

            loop.run_until_complete(func())

            asyncio.set_event_loop(loop)
            with self.assertWarns(DeprecationWarning) as cm:
                self.assertIs(asyncio.get_event_loop(), loop)
            self.assertEqual(cm.warnings[0].filename, __file__)

            asyncio.set_event_loop(None)
            with self.assertWarns(DeprecationWarning) as cm:
                with self.assertRaisesRegex(RuntimeError, 'no current'):
                    asyncio.get_event_loop()
            self.assertEqual(cm.warnings[0].filename, __file__)

        finally:
            asyncio.set_event_loop_policy(old_policy)
            if loop is not None:
                loop.close()

        with self.assertRaisesRegex(RuntimeError, 'no running'):
            asyncio.get_running_loop()

        self.assertIs(asyncio._get_running_loop(), None)
2793
2794
class TestPyGetEventLoop(GetEventLoopTestsMixin, unittest.TestCase):
    """Run the shared get/running-loop tests against the pure-Python
    implementations from ``asyncio.events``."""

    # Public accessors.
    get_event_loop_impl = events._py_get_event_loop
    get_running_loop_impl = events._py_get_running_loop
    # Private running-loop getter/setter.
    _set_running_loop_impl = events._py__set_running_loop
    _get_running_loop_impl = events._py__get_running_loop
2801
2802
try:
    # The C accelerator is optional; only test it when it is available.
    import _asyncio  # NoQA
except ImportError:
    pass
else:

    class TestCGetEventLoop(GetEventLoopTestsMixin, unittest.TestCase):
        """Run the shared get/running-loop tests against the C
        implementations exposed through ``asyncio.events``."""

        # Public accessors.
        get_event_loop_impl = events._c_get_event_loop
        get_running_loop_impl = events._c_get_running_loop
        # Private running-loop getter/setter.
        _set_running_loop_impl = events._c__set_running_loop
        _get_running_loop_impl = events._c__get_running_loop
2815
2816
class TestServer(unittest.TestCase):
    """Tests for the Server object returned by loop.create_server()."""

    def test_get_loop(self):
        loop = asyncio.new_event_loop()
        self.addCleanup(loop.close)
        proto = MyProto(loop)
        # Listen on loopback only; a local test has no reason to bind every
        # interface ('0.0.0.0').
        server = loop.run_until_complete(
            loop.create_server(lambda: proto, socket_helper.HOST, 0))
        try:
            # The server must report the very loop that created it.
            self.assertIs(server.get_loop(), loop)
        finally:
            # Close the listener even if the assertion fails.
            server.close()
            loop.run_until_complete(server.wait_closed())
2827
2828
class TestAbstractServer(unittest.TestCase):
    """AbstractServer leaves every operation to concrete subclasses."""

    def test_close(self):
        server = events.AbstractServer()
        self.assertRaises(NotImplementedError, server.close)

    def test_wait_closed(self):
        loop = asyncio.new_event_loop()
        self.addCleanup(loop.close)

        # wait_closed() is a coroutine, so it needs a loop to run it.
        server = events.AbstractServer()
        self.assertRaises(NotImplementedError,
                          loop.run_until_complete, server.wait_closed())

    def test_get_loop(self):
        server = events.AbstractServer()
        self.assertRaises(NotImplementedError, server.get_loop)
2845
2846
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
2849