1"""Tests for streams.py."""
2
3import gc
4import os
5import queue
6import pickle
7import socket
8import sys
9import threading
10import unittest
11from unittest import mock
12from test import support
13try:
14    import ssl
15except ImportError:
16    ssl = None
17
18import asyncio
19from test.test_asyncio import utils as test_utils
20
21
class StreamTests(test_utils.TestCase):
    """Tests for asyncio.StreamReader, StreamReaderProtocol and the
    open_connection()/start_server() family of stream helpers.

    Every test runs on a private event loop created in setUp() and
    closed in tearDown().
    """

    DATA = b'line1\nline2\nline3\n'

    def setUp(self):
        super().setUp()
        self.loop = asyncio.new_event_loop()
        self.set_event_loop(self.loop)

    def tearDown(self):
        # Run the loop once more in case transports still have pending
        # close callbacks scheduled.
        test_utils.run_briefly(self.loop)

        self.loop.close()
        gc.collect()
        super().tearDown()

    @mock.patch('asyncio.streams.events')
    def test_ctor_global_loop(self, m_events):
        # Without an explicit loop, StreamReader must take the loop
        # returned by events.get_event_loop().
        stream = asyncio.StreamReader()
        self.assertIs(stream._loop, m_events.get_event_loop.return_value)

    def _basetest_open_connection(self, open_connection_fut):
        """Connect via *open_connection_fut*, send an HTTP/1.0 GET request
        and check the test server's status line and body.
        """
        reader, writer = self.loop.run_until_complete(open_connection_fut)
        writer.write(b'GET / HTTP/1.0\r\n\r\n')
        f = reader.readline()
        data = self.loop.run_until_complete(f)
        self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
        f = reader.read()
        data = self.loop.run_until_complete(f)
        self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
        writer.close()

    def test_open_connection(self):
        with test_utils.run_test_server() as httpd:
            conn_fut = asyncio.open_connection(*httpd.address,
                                               loop=self.loop)
            self._basetest_open_connection(conn_fut)

    @support.skip_unless_bind_unix_socket
    def test_open_unix_connection(self):
        with test_utils.run_test_unix_server() as httpd:
            conn_fut = asyncio.open_unix_connection(httpd.address,
                                                    loop=self.loop)
            self._basetest_open_connection(conn_fut)

    def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
        """Connect via *open_connection_fut*, then reset the current event
        loop to None and check that the HTTP exchange still works when
        driven through the explicit self.loop.
        """
        try:
            reader, writer = self.loop.run_until_complete(open_connection_fut)
        finally:
            # Clear the current loop even if connecting failed.
            asyncio.set_event_loop(None)
        writer.write(b'GET / HTTP/1.0\r\n\r\n')
        f = reader.read()
        data = self.loop.run_until_complete(f)
        self.assertTrue(data.endswith(b'\r\n\r\nTest message'))

        writer.close()

    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_open_connection_no_loop_ssl(self):
        with test_utils.run_test_server(use_ssl=True) as httpd:
            conn_fut = asyncio.open_connection(
                *httpd.address,
                ssl=test_utils.dummy_ssl_context(),
                loop=self.loop)

            self._basetest_open_connection_no_loop_ssl(conn_fut)

    @support.skip_unless_bind_unix_socket
    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_open_unix_connection_no_loop_ssl(self):
        with test_utils.run_test_unix_server(use_ssl=True) as httpd:
            conn_fut = asyncio.open_unix_connection(
                httpd.address,
                ssl=test_utils.dummy_ssl_context(),
                server_hostname='',
                loop=self.loop)

            self._basetest_open_connection_no_loop_ssl(conn_fut)

    def _basetest_open_connection_error(self, open_connection_fut):
        """Connect via *open_connection_fut*, simulate the protocol losing
        the connection with an exception, and check that a subsequent
        read() re-raises that exception.
        """
        reader, writer = self.loop.run_until_complete(open_connection_fut)
        writer._protocol.connection_lost(ZeroDivisionError())
        f = reader.read()
        with self.assertRaises(ZeroDivisionError):
            self.loop.run_until_complete(f)
        writer.close()
        test_utils.run_briefly(self.loop)

    def test_open_connection_error(self):
        with test_utils.run_test_server() as httpd:
            conn_fut = asyncio.open_connection(*httpd.address,
                                               loop=self.loop)
            self._basetest_open_connection_error(conn_fut)

    @support.skip_unless_bind_unix_socket
    def test_open_unix_connection_error(self):
        with test_utils.run_test_unix_server() as httpd:
            conn_fut = asyncio.open_unix_connection(httpd.address,
                                                    loop=self.loop)
            self._basetest_open_connection_error(conn_fut)

    def test_feed_empty_data(self):
        stream = asyncio.StreamReader(loop=self.loop)

        stream.feed_data(b'')
        self.assertEqual(b'', stream._buffer)

    def test_feed_nonempty_data(self):
        stream = asyncio.StreamReader(loop=self.loop)

        stream.feed_data(self.DATA)
        self.assertEqual(self.DATA, stream._buffer)

    def test_read_zero(self):
        # Read zero bytes.
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(self.DATA)

        data = self.loop.run_until_complete(stream.read(0))
        self.assertEqual(b'', data)
        self.assertEqual(self.DATA, stream._buffer)

    def test_read(self):
        # Read bytes.
        stream = asyncio.StreamReader(loop=self.loop)
        read_task = asyncio.Task(stream.read(30), loop=self.loop)

        def cb():
            stream.feed_data(self.DATA)
        self.loop.call_soon(cb)

        data = self.loop.run_until_complete(read_task)
        self.assertEqual(self.DATA, data)
        self.assertEqual(b'', stream._buffer)

    def test_read_line_breaks(self):
        # Read bytes without line breaks.
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(b'line1')
        stream.feed_data(b'line2')

        data = self.loop.run_until_complete(stream.read(5))

        self.assertEqual(b'line1', data)
        self.assertEqual(b'line2', stream._buffer)

    def test_read_eof(self):
        # Read bytes, stop at eof.
        stream = asyncio.StreamReader(loop=self.loop)
        read_task = asyncio.Task(stream.read(1024), loop=self.loop)

        def cb():
            stream.feed_eof()
        self.loop.call_soon(cb)

        data = self.loop.run_until_complete(read_task)
        self.assertEqual(b'', data)
        self.assertEqual(b'', stream._buffer)

    def test_read_until_eof(self):
        # Read all bytes until eof.
        stream = asyncio.StreamReader(loop=self.loop)
        read_task = asyncio.Task(stream.read(-1), loop=self.loop)

        def cb():
            stream.feed_data(b'chunk1\n')
            stream.feed_data(b'chunk2')
            stream.feed_eof()
        self.loop.call_soon(cb)

        data = self.loop.run_until_complete(read_task)

        self.assertEqual(b'chunk1\nchunk2', data)
        self.assertEqual(b'', stream._buffer)

    def test_read_exception(self):
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(b'line\n')

        data = self.loop.run_until_complete(stream.read(2))
        self.assertEqual(b'li', data)

        stream.set_exception(ValueError())
        self.assertRaises(
            ValueError, self.loop.run_until_complete, stream.read(2))

    def test_invalid_limit(self):
        with self.assertRaisesRegex(ValueError, 'imit'):
            asyncio.StreamReader(limit=0, loop=self.loop)

        with self.assertRaisesRegex(ValueError, 'imit'):
            asyncio.StreamReader(limit=-1, loop=self.loop)

    def test_read_limit(self):
        stream = asyncio.StreamReader(limit=3, loop=self.loop)
        stream.feed_data(b'chunk')
        data = self.loop.run_until_complete(stream.read(5))
        self.assertEqual(b'chunk', data)
        self.assertEqual(b'', stream._buffer)

    def test_readline(self):
        # Read one line. 'readline' will need to wait for the data
        # to come from 'cb'
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(b'chunk1 ')
        read_task = asyncio.Task(stream.readline(), loop=self.loop)

        def cb():
            stream.feed_data(b'chunk2 ')
            stream.feed_data(b'chunk3 ')
            stream.feed_data(b'\n chunk4')
        self.loop.call_soon(cb)

        line = self.loop.run_until_complete(read_task)
        self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
        self.assertEqual(b' chunk4', stream._buffer)

    def test_readline_limit_with_existing_data(self):
        # Read one line. The data is in StreamReader's buffer
        # before the event loop is run.

        stream = asyncio.StreamReader(limit=3, loop=self.loop)
        stream.feed_data(b'li')
        stream.feed_data(b'ne1\nline2\n')

        self.assertRaises(
            ValueError, self.loop.run_until_complete, stream.readline())
        # The buffer should contain the remaining data after exception
        self.assertEqual(b'line2\n', stream._buffer)

        stream = asyncio.StreamReader(limit=3, loop=self.loop)
        stream.feed_data(b'li')
        stream.feed_data(b'ne1')
        stream.feed_data(b'li')

        self.assertRaises(
            ValueError, self.loop.run_until_complete, stream.readline())
        # No b'\n' at the end. The 'limit' is set to 3. So before
        # waiting for the new data in buffer, 'readline' will consume
        # the entire buffer, and since the length of the consumed data
        # is more than 3, it will raise a ValueError. The buffer is
        # expected to be empty now.
        self.assertEqual(b'', stream._buffer)

    def test_at_eof(self):
        stream = asyncio.StreamReader(loop=self.loop)
        self.assertFalse(stream.at_eof())

        stream.feed_data(b'some data\n')
        self.assertFalse(stream.at_eof())

        self.loop.run_until_complete(stream.readline())
        self.assertFalse(stream.at_eof())

        stream.feed_data(b'some data\n')
        stream.feed_eof()
        self.loop.run_until_complete(stream.readline())
        self.assertTrue(stream.at_eof())

    def test_readline_limit(self):
        # Read one line. StreamReaders are fed with data after
        # their 'readline' methods are called.

        stream = asyncio.StreamReader(limit=7, loop=self.loop)
        def cb():
            stream.feed_data(b'chunk1')
            stream.feed_data(b'chunk2')
            stream.feed_data(b'chunk3\n')
            stream.feed_eof()
        self.loop.call_soon(cb)

        self.assertRaises(
            ValueError, self.loop.run_until_complete, stream.readline())
        # The buffer had just one line of data, and after raising
        # a ValueError it should be empty.
        self.assertEqual(b'', stream._buffer)

        stream = asyncio.StreamReader(limit=7, loop=self.loop)
        def cb():
            stream.feed_data(b'chunk1')
            stream.feed_data(b'chunk2\n')
            stream.feed_data(b'chunk3\n')
            stream.feed_eof()
        self.loop.call_soon(cb)

        self.assertRaises(
            ValueError, self.loop.run_until_complete, stream.readline())
        self.assertEqual(b'chunk3\n', stream._buffer)

        # check strictness of the limit
        stream = asyncio.StreamReader(limit=7, loop=self.loop)
        stream.feed_data(b'1234567\n')
        line = self.loop.run_until_complete(stream.readline())
        self.assertEqual(b'1234567\n', line)
        self.assertEqual(b'', stream._buffer)

        stream.feed_data(b'12345678\n')
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(stream.readline())
        self.assertEqual(b'', stream._buffer)

        stream.feed_data(b'12345678')
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(stream.readline())
        self.assertEqual(b'', stream._buffer)

    def test_readline_nolimit_nowait(self):
        # All needed data for the first 'readline' call will be
        # in the buffer.
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(self.DATA[:6])
        stream.feed_data(self.DATA[6:])

        line = self.loop.run_until_complete(stream.readline())

        self.assertEqual(b'line1\n', line)
        self.assertEqual(b'line2\nline3\n', stream._buffer)

    def test_readline_eof(self):
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(b'some data')
        stream.feed_eof()

        line = self.loop.run_until_complete(stream.readline())
        self.assertEqual(b'some data', line)

    def test_readline_empty_eof(self):
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_eof()

        line = self.loop.run_until_complete(stream.readline())
        self.assertEqual(b'', line)

    def test_readline_read_byte_count(self):
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(self.DATA)

        self.loop.run_until_complete(stream.readline())

        data = self.loop.run_until_complete(stream.read(7))

        self.assertEqual(b'line2\nl', data)
        self.assertEqual(b'ine3\n', stream._buffer)

    def test_readline_exception(self):
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(b'line\n')

        data = self.loop.run_until_complete(stream.readline())
        self.assertEqual(b'line\n', data)

        stream.set_exception(ValueError())
        self.assertRaises(
            ValueError, self.loop.run_until_complete, stream.readline())
        self.assertEqual(b'', stream._buffer)

    def test_readuntil_separator(self):
        stream = asyncio.StreamReader(loop=self.loop)
        with self.assertRaisesRegex(ValueError, 'Separator should be'):
            self.loop.run_until_complete(stream.readuntil(separator=b''))

    def test_readuntil_multi_chunks(self):
        stream = asyncio.StreamReader(loop=self.loop)

        stream.feed_data(b'lineAAA')
        data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA'))
        self.assertEqual(b'lineAAA', data)
        self.assertEqual(b'', stream._buffer)

        stream.feed_data(b'lineAAA')
        data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
        self.assertEqual(b'lineAAA', data)
        self.assertEqual(b'', stream._buffer)

        stream.feed_data(b'lineAAAxxx')
        data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
        self.assertEqual(b'lineAAA', data)
        self.assertEqual(b'xxx', stream._buffer)

    def test_readuntil_multi_chunks_1(self):
        stream = asyncio.StreamReader(loop=self.loop)

        stream.feed_data(b'QWEaa')
        stream.feed_data(b'XYaa')
        stream.feed_data(b'a')
        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
        self.assertEqual(b'QWEaaXYaaa', data)
        self.assertEqual(b'', stream._buffer)

        stream.feed_data(b'QWEaa')
        stream.feed_data(b'XYa')
        stream.feed_data(b'aa')
        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
        self.assertEqual(b'QWEaaXYaaa', data)
        self.assertEqual(b'', stream._buffer)

        stream.feed_data(b'aaa')
        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
        self.assertEqual(b'aaa', data)
        self.assertEqual(b'', stream._buffer)

        stream.feed_data(b'Xaaa')
        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
        self.assertEqual(b'Xaaa', data)
        self.assertEqual(b'', stream._buffer)

        stream.feed_data(b'XXX')
        stream.feed_data(b'a')
        stream.feed_data(b'a')
        stream.feed_data(b'a')
        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
        self.assertEqual(b'XXXaaa', data)
        self.assertEqual(b'', stream._buffer)

    def test_readuntil_eof(self):
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(b'some dataAA')
        stream.feed_eof()

        with self.assertRaises(asyncio.IncompleteReadError) as cm:
            self.loop.run_until_complete(stream.readuntil(b'AAA'))
        self.assertEqual(cm.exception.partial, b'some dataAA')
        self.assertIsNone(cm.exception.expected)
        self.assertEqual(b'', stream._buffer)

    def test_readuntil_limit_found_sep(self):
        stream = asyncio.StreamReader(loop=self.loop, limit=3)
        stream.feed_data(b'some dataAA')

        with self.assertRaisesRegex(asyncio.LimitOverrunError,
                                    'not found'):
            self.loop.run_until_complete(stream.readuntil(b'AAA'))

        self.assertEqual(b'some dataAA', stream._buffer)

        stream.feed_data(b'A')
        with self.assertRaisesRegex(asyncio.LimitOverrunError,
                                    'is found'):
            self.loop.run_until_complete(stream.readuntil(b'AAA'))

        self.assertEqual(b'some dataAAA', stream._buffer)

    def test_readexactly_zero_or_less(self):
        # Read exact number of bytes (zero or less).
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(self.DATA)

        data = self.loop.run_until_complete(stream.readexactly(0))
        self.assertEqual(b'', data)
        self.assertEqual(self.DATA, stream._buffer)

        with self.assertRaisesRegex(ValueError, 'less than zero'):
            self.loop.run_until_complete(stream.readexactly(-1))
        self.assertEqual(self.DATA, stream._buffer)

    def test_readexactly(self):
        # Read exact number of bytes.
        stream = asyncio.StreamReader(loop=self.loop)

        n = 2 * len(self.DATA)
        read_task = asyncio.Task(stream.readexactly(n), loop=self.loop)

        def cb():
            stream.feed_data(self.DATA)
            stream.feed_data(self.DATA)
            stream.feed_data(self.DATA)
        self.loop.call_soon(cb)

        data = self.loop.run_until_complete(read_task)
        self.assertEqual(self.DATA + self.DATA, data)
        self.assertEqual(self.DATA, stream._buffer)

    def test_readexactly_limit(self):
        stream = asyncio.StreamReader(limit=3, loop=self.loop)
        stream.feed_data(b'chunk')
        data = self.loop.run_until_complete(stream.readexactly(5))
        self.assertEqual(b'chunk', data)
        self.assertEqual(b'', stream._buffer)

    def test_readexactly_eof(self):
        # Read exact number of bytes (eof).
        stream = asyncio.StreamReader(loop=self.loop)
        n = 2 * len(self.DATA)
        read_task = asyncio.Task(stream.readexactly(n), loop=self.loop)

        def cb():
            stream.feed_data(self.DATA)
            stream.feed_eof()
        self.loop.call_soon(cb)

        with self.assertRaises(asyncio.IncompleteReadError) as cm:
            self.loop.run_until_complete(read_task)
        self.assertEqual(cm.exception.partial, self.DATA)
        self.assertEqual(cm.exception.expected, n)
        self.assertEqual(str(cm.exception),
                         '18 bytes read on a total of 36 expected bytes')
        self.assertEqual(b'', stream._buffer)

    def test_readexactly_exception(self):
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(b'line\n')

        data = self.loop.run_until_complete(stream.readexactly(2))
        self.assertEqual(b'li', data)

        stream.set_exception(ValueError())
        self.assertRaises(
            ValueError, self.loop.run_until_complete, stream.readexactly(2))

    def test_exception(self):
        stream = asyncio.StreamReader(loop=self.loop)
        self.assertIsNone(stream.exception())

        exc = ValueError()
        stream.set_exception(exc)
        self.assertIs(stream.exception(), exc)

    def test_exception_waiter(self):
        stream = asyncio.StreamReader(loop=self.loop)

        # Native coroutine instead of the deprecated @asyncio.coroutine
        # decorator; it is scheduled after t1 exactly as before.
        async def set_err():
            stream.set_exception(ValueError())

        t1 = asyncio.Task(stream.readline(), loop=self.loop)
        t2 = asyncio.Task(set_err(), loop=self.loop)

        self.loop.run_until_complete(asyncio.wait([t1, t2], loop=self.loop))

        self.assertRaises(ValueError, t1.result)

    def test_exception_cancel(self):
        stream = asyncio.StreamReader(loop=self.loop)

        t = asyncio.Task(stream.readline(), loop=self.loop)
        test_utils.run_briefly(self.loop)
        t.cancel()
        test_utils.run_briefly(self.loop)
        # The following line fails if set_exception() isn't careful.
        stream.set_exception(RuntimeError('message'))
        test_utils.run_briefly(self.loop)
        self.assertIs(stream._waiter, None)

    def test_start_server(self):

        class MyServer:

            def __init__(self, loop):
                self.server = None
                self.loop = loop

            async def handle_client(self, client_reader, client_writer):
                data = await client_reader.readline()
                client_writer.write(data)
                await client_writer.drain()
                client_writer.close()

            def start(self):
                sock = socket.socket()
                sock.bind(('127.0.0.1', 0))
                self.server = self.loop.run_until_complete(
                    asyncio.start_server(self.handle_client,
                                         sock=sock,
                                         loop=self.loop))
                return sock.getsockname()

            def handle_client_callback(self, client_reader, client_writer):
                self.loop.create_task(self.handle_client(client_reader,
                                                         client_writer))

            def start_callback(self):
                sock = socket.socket()
                sock.bind(('127.0.0.1', 0))
                addr = sock.getsockname()
                sock.close()
                self.server = self.loop.run_until_complete(
                    asyncio.start_server(self.handle_client_callback,
                                         host=addr[0], port=addr[1],
                                         loop=self.loop))
                return addr

            def stop(self):
                if self.server is not None:
                    self.server.close()
                    self.loop.run_until_complete(self.server.wait_closed())
                    self.server = None

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr, loop=self.loop)
            # send a line
            writer.write(b"hello world!\n")
            # read it back
            msgback = await reader.readline()
            writer.close()
            return msgback

        # test the server variant with a coroutine as client handler
        server = MyServer(self.loop)
        addr = server.start()
        msg = self.loop.run_until_complete(asyncio.Task(client(addr),
                                                        loop=self.loop))
        server.stop()
        self.assertEqual(msg, b"hello world!\n")

        # test the server variant with a callback as client handler
        server = MyServer(self.loop)
        addr = server.start_callback()
        msg = self.loop.run_until_complete(asyncio.Task(client(addr),
                                                        loop=self.loop))
        server.stop()
        self.assertEqual(msg, b"hello world!\n")

    @support.skip_unless_bind_unix_socket
    def test_start_unix_server(self):

        class MyServer:

            def __init__(self, loop, path):
                self.server = None
                self.loop = loop
                self.path = path

            async def handle_client(self, client_reader, client_writer):
                data = await client_reader.readline()
                client_writer.write(data)
                await client_writer.drain()
                client_writer.close()

            def start(self):
                self.server = self.loop.run_until_complete(
                    asyncio.start_unix_server(self.handle_client,
                                              path=self.path,
                                              loop=self.loop))

            def handle_client_callback(self, client_reader, client_writer):
                self.loop.create_task(self.handle_client(client_reader,
                                                         client_writer))

            def start_callback(self):
                start = asyncio.start_unix_server(self.handle_client_callback,
                                                  path=self.path,
                                                  loop=self.loop)
                self.server = self.loop.run_until_complete(start)

            def stop(self):
                if self.server is not None:
                    self.server.close()
                    self.loop.run_until_complete(self.server.wait_closed())
                    self.server = None

        async def client(path):
            reader, writer = await asyncio.open_unix_connection(
                path, loop=self.loop)
            # send a line
            writer.write(b"hello world!\n")
            # read it back
            msgback = await reader.readline()
            writer.close()
            return msgback

        # test the server variant with a coroutine as client handler
        with test_utils.unix_socket_path() as path:
            server = MyServer(self.loop, path)
            server.start()
            msg = self.loop.run_until_complete(asyncio.Task(client(path),
                                                            loop=self.loop))
            server.stop()
            self.assertEqual(msg, b"hello world!\n")

        # test the server variant with a callback as client handler
        with test_utils.unix_socket_path() as path:
            server = MyServer(self.loop, path)
            server.start_callback()
            msg = self.loop.run_until_complete(asyncio.Task(client(path),
                                                            loop=self.loop))
            server.stop()
            self.assertEqual(msg, b"hello world!\n")

    @unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
    def test_read_all_from_pipe_reader(self):
        # See asyncio issue 168.  This test is derived from the example
        # subprocess_attach_read_pipe.py, but we configure the
        # StreamReader's limit so that twice the limit is less than the
        # size of the data written.  Also we must explicitly attach a
        # child watcher to the event loop.

        code = """\
import os, sys
fd = int(sys.argv[1])
os.write(fd, b'data')
os.close(fd)
"""
        rfd, wfd = os.pipe()
        args = [sys.executable, '-c', code, str(wfd)]

        pipe = open(rfd, 'rb', 0)
        reader = asyncio.StreamReader(loop=self.loop, limit=1)
        protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop)
        transport, _ = self.loop.run_until_complete(
            self.loop.connect_read_pipe(lambda: protocol, pipe))

        watcher = asyncio.SafeChildWatcher()
        watcher.attach_loop(self.loop)
        try:
            asyncio.set_child_watcher(watcher)
            create = asyncio.create_subprocess_exec(*args,
                                                    pass_fds={wfd},
                                                    loop=self.loop)
            proc = self.loop.run_until_complete(create)
            self.loop.run_until_complete(proc.wait())
        finally:
            asyncio.set_child_watcher(None)

        # Close our copy of the write end so the reader can see EOF.
        os.close(wfd)
        data = self.loop.run_until_complete(reader.read(-1))
        self.assertEqual(data, b'data')

    def test_streamreader_constructor(self):
        self.addCleanup(asyncio.set_event_loop, None)
        asyncio.set_event_loop(self.loop)

        # asyncio issue #184: Ensure that the StreamReader constructor
        # retrieves the current loop if the loop parameter is not set
        reader = asyncio.StreamReader()
        self.assertIs(reader._loop, self.loop)

    def test_streamreaderprotocol_constructor(self):
        self.addCleanup(asyncio.set_event_loop, None)
        asyncio.set_event_loop(self.loop)

        # asyncio issue #184: Ensure that StreamReaderProtocol constructor
        # retrieves the current loop if the loop parameter is not set
        reader = mock.Mock()
        protocol = asyncio.StreamReaderProtocol(reader)
        self.assertIs(protocol._loop, self.loop)

    def test_drain_raises(self):
        # See http://bugs.python.org/issue25441

        # This test should not use asyncio for the mock server; the
        # whole point of the test is to test for a bug in drain()
        # where it never gives up the event loop but the socket is
        # closed on the server side.

        q = queue.Queue()

        def server():
            # Runs in a separate thread.
            sock = socket.socket()
            with sock:
                sock.bind(('localhost', 0))
                sock.listen(1)
                addr = sock.getsockname()
                q.put(addr)
                clt, _ = sock.accept()
                clt.close()

        async def client(host, port):
            reader, writer = await asyncio.open_connection(
                host, port, loop=self.loop)

            while True:
                writer.write(b"foo\n")
                await writer.drain()

        # Start the server thread and wait for it to be listening.
        # daemon=True replaces the deprecated Thread.setDaemon().
        thread = threading.Thread(target=server, daemon=True)
        thread.start()
        addr = q.get()

        # Should not be stuck in an infinite loop.
        with self.assertRaises((ConnectionResetError, BrokenPipeError)):
            self.loop.run_until_complete(client(*addr))

        # Clean up the thread.  (Only on success; on failure, it may
        # be stuck in accept().)
        thread.join()

    def test___repr__(self):
        stream = asyncio.StreamReader(loop=self.loop)
        self.assertEqual("<StreamReader>", repr(stream))

    def test___repr__nondefault_limit(self):
        stream = asyncio.StreamReader(loop=self.loop, limit=123)
        self.assertEqual("<StreamReader limit=123>", repr(stream))

    def test___repr__eof(self):
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_eof()
        self.assertEqual("<StreamReader eof>", repr(stream))

    def test___repr__data(self):
        stream = asyncio.StreamReader(loop=self.loop)
        stream.feed_data(b'data')
        self.assertEqual("<StreamReader 4 bytes>", repr(stream))

    def test___repr__exception(self):
        stream = asyncio.StreamReader(loop=self.loop)
        exc = RuntimeError()
        stream.set_exception(exc)
        self.assertEqual("<StreamReader exception=RuntimeError()>",
                         repr(stream))

    def test___repr__waiter(self):
        stream = asyncio.StreamReader(loop=self.loop)
        stream._waiter = asyncio.Future(loop=self.loop)
        self.assertRegex(
            repr(stream),
            r"<StreamReader waiter=<Future pending[\S ]*>>")
        stream._waiter.set_result(None)
        self.loop.run_until_complete(stream._waiter)
        stream._waiter = None
        self.assertEqual("<StreamReader>", repr(stream))

    def test___repr__transport(self):
        stream = asyncio.StreamReader(loop=self.loop)
        stream._transport = mock.Mock()
        stream._transport.__repr__ = mock.Mock()
        stream._transport.__repr__.return_value = "<Transport>"
        self.assertEqual("<StreamReader transport=<Transport>>", repr(stream))

    def test_IncompleteReadError_pickleable(self):
        e = asyncio.IncompleteReadError(b'abc', 10)
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(pickle_protocol=proto):
                e2 = pickle.loads(pickle.dumps(e, protocol=proto))
                self.assertEqual(str(e), str(e2))
                self.assertEqual(e.partial, e2.partial)
                self.assertEqual(e.expected, e2.expected)

    def test_LimitOverrunError_pickleable(self):
        e = asyncio.LimitOverrunError('message', 10)
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(pickle_protocol=proto):
                e2 = pickle.loads(pickle.dumps(e, protocol=proto))
                self.assertEqual(str(e), str(e2))
                self.assertEqual(e.consumed, e2.consumed)

    def test_wait_closed_on_close(self):
        with test_utils.run_test_server() as httpd:
            rd, wr = self.loop.run_until_complete(
                asyncio.open_connection(*httpd.address, loop=self.loop))

            wr.write(b'GET / HTTP/1.0\r\n\r\n')
            f = rd.readline()
            data = self.loop.run_until_complete(f)
            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
            f = rd.read()
            data = self.loop.run_until_complete(f)
            self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
            self.assertFalse(wr.is_closing())
            wr.close()
            self.assertTrue(wr.is_closing())
            self.loop.run_until_complete(wr.wait_closed())

    def test_wait_closed_on_close_with_unread_data(self):
        with test_utils.run_test_server() as httpd:
            rd, wr = self.loop.run_until_complete(
                asyncio.open_connection(*httpd.address, loop=self.loop))

            wr.write(b'GET / HTTP/1.0\r\n\r\n')
            f = rd.readline()
            data = self.loop.run_until_complete(f)
            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
            wr.close()
            self.loop.run_until_complete(wr.wait_closed())
891
892
if __name__ == "__main__":
    # Executed directly (not imported): hand control to unittest's runner.
    unittest.main()
895