1"""Tests for futures.py."""
2
3import concurrent.futures
4import re
5import sys
6import threading
7import unittest
8from unittest import mock
9
10import asyncio
11from asyncio import test_utils
12try:
13    from test import support
14except ImportError:
15    from asyncio import test_support as support
16
17
18def _fakefunc(f):
19    return f
20
def first_cb():
    """No-op callback registered first on futures in the repr tests."""
23
def last_cb():
    """No-op callback registered last on futures in the repr tests."""
26
27
class FutureTests(test_utils.TestCase):
    """Tests for asyncio.Future state transitions, repr and helpers."""

    def setUp(self):
        self.loop = self.new_test_loop()
        self.addCleanup(self.loop.close)

    def test_initial_state(self):
        f = asyncio.Future(loop=self.loop)
        self.assertFalse(f.cancelled())
        self.assertFalse(f.done())
        f.cancel()
        self.assertTrue(f.cancelled())

    def test_init_constructor_default_loop(self):
        asyncio.set_event_loop(self.loop)
        f = asyncio.Future()
        self.assertIs(f._loop, self.loop)

    def test_constructor_positional(self):
        # Make sure Future doesn't accept a positional argument
        self.assertRaises(TypeError, asyncio.Future, 42)

    def test_cancel(self):
        f = asyncio.Future(loop=self.loop)
        self.assertTrue(f.cancel())
        self.assertTrue(f.cancelled())
        self.assertTrue(f.done())
        self.assertRaises(asyncio.CancelledError, f.result)
        self.assertRaises(asyncio.CancelledError, f.exception)
        self.assertRaises(asyncio.InvalidStateError, f.set_result, None)
        self.assertRaises(asyncio.InvalidStateError, f.set_exception, None)
        # A second cancel() on an already-cancelled future reports failure.
        self.assertFalse(f.cancel())

    def test_result(self):
        f = asyncio.Future(loop=self.loop)
        self.assertRaises(asyncio.InvalidStateError, f.result)

        f.set_result(42)
        self.assertFalse(f.cancelled())
        self.assertTrue(f.done())
        self.assertEqual(f.result(), 42)
        self.assertIsNone(f.exception())
        self.assertRaises(asyncio.InvalidStateError, f.set_result, None)
        self.assertRaises(asyncio.InvalidStateError, f.set_exception, None)
        self.assertFalse(f.cancel())

    def test_exception(self):
        exc = RuntimeError()
        f = asyncio.Future(loop=self.loop)
        self.assertRaises(asyncio.InvalidStateError, f.exception)

        f.set_exception(exc)
        self.assertFalse(f.cancelled())
        self.assertTrue(f.done())
        self.assertRaises(RuntimeError, f.result)
        # The very same exception object must be handed back.
        self.assertIs(f.exception(), exc)
        self.assertRaises(asyncio.InvalidStateError, f.set_result, None)
        self.assertRaises(asyncio.InvalidStateError, f.set_exception, None)
        self.assertFalse(f.cancel())

    def test_exception_class(self):
        # Passing an exception class (not instance) yields an instance.
        f = asyncio.Future(loop=self.loop)
        f.set_exception(RuntimeError)
        self.assertIsInstance(f.exception(), RuntimeError)

    def test_yield_from_twice(self):
        f = asyncio.Future(loop=self.loop)

        def fixture():
            yield 'A'
            x = yield from f
            yield 'B', x
            y = yield from f
            yield 'C', y

        g = fixture()
        self.assertEqual(next(g), 'A')  # yield 'A'.
        self.assertEqual(next(g), f)  # First yield from f.
        f.set_result(42)
        self.assertEqual(next(g), ('B', 42))  # yield 'B', x.
        # The second "yield from f" does not yield f.
        self.assertEqual(next(g), ('C', 42))  # yield 'C', y.

    def test_future_repr(self):
        self.loop.set_debug(True)
        f_pending_debug = asyncio.Future(loop=self.loop)
        frame = f_pending_debug._source_traceback[-1]
        self.assertEqual(repr(f_pending_debug),
                         '<Future pending created at %s:%s>'
                         % (frame[0], frame[1]))
        f_pending_debug.cancel()

        self.loop.set_debug(False)
        f_pending = asyncio.Future(loop=self.loop)
        self.assertEqual(repr(f_pending), '<Future pending>')
        f_pending.cancel()

        f_cancelled = asyncio.Future(loop=self.loop)
        f_cancelled.cancel()
        self.assertEqual(repr(f_cancelled), '<Future cancelled>')

        f_result = asyncio.Future(loop=self.loop)
        f_result.set_result(4)
        self.assertEqual(repr(f_result), '<Future finished result=4>')
        self.assertEqual(f_result.result(), 4)

        exc = RuntimeError()
        f_exception = asyncio.Future(loop=self.loop)
        f_exception.set_exception(exc)
        self.assertEqual(repr(f_exception),
                         '<Future finished exception=RuntimeError()>')
        self.assertIs(f_exception.exception(), exc)

        def func_repr(func):
            # Regex-escaped "qualname() at file:line" for a callback.
            filename, lineno = test_utils.get_function_source(func)
            text = '%s() at %s:%s' % (func.__qualname__, filename, lineno)
            return re.escape(text)

        f_one_callbacks = asyncio.Future(loop=self.loop)
        f_one_callbacks.add_done_callback(_fakefunc)
        fake_repr = func_repr(_fakefunc)
        self.assertRegex(repr(f_one_callbacks),
                         r'<Future pending cb=\[%s\]>' % fake_repr)
        f_one_callbacks.cancel()
        self.assertEqual(repr(f_one_callbacks),
                         '<Future cancelled>')

        f_two_callbacks = asyncio.Future(loop=self.loop)
        f_two_callbacks.add_done_callback(first_cb)
        f_two_callbacks.add_done_callback(last_cb)
        first_repr = func_repr(first_cb)
        last_repr = func_repr(last_cb)
        self.assertRegex(repr(f_two_callbacks),
                         r'<Future pending cb=\[%s, %s\]>'
                         % (first_repr, last_repr))
        # Cancel like the other pending futures so none is left dangling.
        f_two_callbacks.cancel()

        f_many_callbacks = asyncio.Future(loop=self.loop)
        f_many_callbacks.add_done_callback(first_cb)
        for _ in range(8):
            f_many_callbacks.add_done_callback(_fakefunc)
        f_many_callbacks.add_done_callback(last_cb)
        # Middle callbacks are summarised as "<N more>" in the repr.
        cb_regex = r'%s, <8 more>, %s' % (first_repr, last_repr)
        self.assertRegex(repr(f_many_callbacks),
                         r'<Future pending cb=\[%s\]>' % cb_regex)
        f_many_callbacks.cancel()
        self.assertEqual(repr(f_many_callbacks),
                         '<Future cancelled>')

    def test_copy_state(self):
        # Test the internal _copy_state method since it's being directly
        # invoked in other modules.
        f = asyncio.Future(loop=self.loop)
        f.set_result(10)

        newf = asyncio.Future(loop=self.loop)
        newf._copy_state(f)
        self.assertTrue(newf.done())
        self.assertEqual(newf.result(), 10)

        f_exception = asyncio.Future(loop=self.loop)
        f_exception.set_exception(RuntimeError())

        newf_exception = asyncio.Future(loop=self.loop)
        newf_exception._copy_state(f_exception)
        self.assertTrue(newf_exception.done())
        self.assertRaises(RuntimeError, newf_exception.result)

        f_cancelled = asyncio.Future(loop=self.loop)
        f_cancelled.cancel()

        newf_cancelled = asyncio.Future(loop=self.loop)
        newf_cancelled._copy_state(f_cancelled)
        self.assertTrue(newf_cancelled.cancelled())

    def test_iter(self):
        fut = asyncio.Future(loop=self.loop)

        def coro():
            yield from fut

        def test():
            # Tuple unpacking drives the generator past the point where
            # the pending future is yielded; this is expected to raise
            # AssertionError.
            arg1, arg2 = coro()

        self.assertRaises(AssertionError, test)
        fut.cancel()

    @mock.patch('asyncio.base_events.logger')
    def test_tb_logger_abandoned(self, m_log):
        fut = asyncio.Future(loop=self.loop)
        del fut
        self.assertFalse(m_log.error.called)

    @mock.patch('asyncio.base_events.logger')
    def test_tb_logger_result_unretrieved(self, m_log):
        fut = asyncio.Future(loop=self.loop)
        fut.set_result(42)
        del fut
        self.assertFalse(m_log.error.called)

    @mock.patch('asyncio.base_events.logger')
    def test_tb_logger_result_retrieved(self, m_log):
        fut = asyncio.Future(loop=self.loop)
        fut.set_result(42)
        fut.result()
        del fut
        self.assertFalse(m_log.error.called)

    @mock.patch('asyncio.base_events.logger')
    def test_tb_logger_exception_unretrieved(self, m_log):
        fut = asyncio.Future(loop=self.loop)
        fut.set_exception(RuntimeError('boom'))
        del fut
        test_utils.run_briefly(self.loop)
        self.assertTrue(m_log.error.called)

    @mock.patch('asyncio.base_events.logger')
    def test_tb_logger_exception_retrieved(self, m_log):
        fut = asyncio.Future(loop=self.loop)
        fut.set_exception(RuntimeError('boom'))
        fut.exception()
        del fut
        # Mirror test_tb_logger_exception_unretrieved: let the loop run once
        # so any deferred logging would already have happened; otherwise
        # the "not called" assertion could pass vacuously.
        test_utils.run_briefly(self.loop)
        self.assertFalse(m_log.error.called)

    @mock.patch('asyncio.base_events.logger')
    def test_tb_logger_exception_result_retrieved(self, m_log):
        fut = asyncio.Future(loop=self.loop)
        fut.set_exception(RuntimeError('boom'))
        self.assertRaises(RuntimeError, fut.result)
        del fut
        # See test_tb_logger_exception_retrieved for why the loop is run.
        test_utils.run_briefly(self.loop)
        self.assertFalse(m_log.error.called)

    def test_wrap_future(self):

        def run(arg):
            return (arg, threading.get_ident())
        # The executor is a context manager so its worker thread is shut
        # down and joined instead of leaking past the end of the test.
        with concurrent.futures.ThreadPoolExecutor(1) as ex:
            f1 = ex.submit(run, 'oi')
            f2 = asyncio.wrap_future(f1, loop=self.loop)
            res, ident = self.loop.run_until_complete(f2)
        self.assertIsInstance(f2, asyncio.Future)
        self.assertEqual(res, 'oi')
        self.assertNotEqual(ident, threading.get_ident())

    def test_wrap_future_future(self):
        # Wrapping an asyncio.Future returns the very same object.
        f1 = asyncio.Future(loop=self.loop)
        f2 = asyncio.wrap_future(f1)
        self.assertIs(f1, f2)

    @mock.patch('asyncio.futures.events')
    def test_wrap_future_use_global_loop(self, m_events):
        def run(arg):
            return (arg, threading.get_ident())
        # Shut the executor down (joining its thread) when done.
        with concurrent.futures.ThreadPoolExecutor(1) as ex:
            f1 = ex.submit(run, 'oi')
            f2 = asyncio.wrap_future(f1)
        self.assertIs(m_events.get_event_loop.return_value, f2._loop)

    def test_wrap_future_cancel(self):
        f1 = concurrent.futures.Future()
        f2 = asyncio.wrap_future(f1, loop=self.loop)
        f2.cancel()
        test_utils.run_briefly(self.loop)
        self.assertTrue(f1.cancelled())
        self.assertTrue(f2.cancelled())

    def test_wrap_future_cancel2(self):
        f1 = concurrent.futures.Future()
        f2 = asyncio.wrap_future(f1, loop=self.loop)
        f1.set_result(42)
        f2.cancel()
        test_utils.run_briefly(self.loop)
        self.assertFalse(f1.cancelled())
        self.assertEqual(f1.result(), 42)
        self.assertTrue(f2.cancelled())

    def test_future_source_traceback(self):
        self.loop.set_debug(True)

        # The two statements below must stay adjacent: lineno is derived
        # from the line that created the future.
        future = asyncio.Future(loop=self.loop)
        lineno = sys._getframe().f_lineno - 1
        self.assertIsInstance(future._source_traceback, list)
        self.assertEqual(future._source_traceback[-1][:3],
                         (__file__,
                          lineno,
                          'test_future_source_traceback'))

    @mock.patch('asyncio.base_events.logger')
    def check_future_exception_never_retrieved(self, debug, m_log):
        # Shared driver: a future whose exception is never retrieved must
        # be reported exactly once via the (mocked) logger.  The source
        # line creating the future below is matched verbatim by the debug
        # regexes, so do not reformat it.
        self.loop.set_debug(debug)

        def memory_error():
            # Return a MemoryError instance carrying a real traceback.
            try:
                raise MemoryError()
            except BaseException as exc:
                return exc
        exc = memory_error()

        future = asyncio.Future(loop=self.loop)
        if debug:
            source_traceback = future._source_traceback
        future.set_exception(exc)
        future = None
        test_utils.run_briefly(self.loop)
        support.gc_collect()

        if sys.version_info >= (3, 4):
            if debug:
                frame = source_traceback[-1]
                regex = (r'^Future exception was never retrieved\n'
                         r'future: <Future finished exception=MemoryError\(\) '
                             r'created at {filename}:{lineno}>\n'
                         r'source_traceback: Object '
                            r'created at \(most recent call last\):\n'
                         r'  File'
                         r'.*\n'
                         r'  File "{filename}", line {lineno}, '
                            r'in check_future_exception_never_retrieved\n'
                         r'    future = asyncio\.Future\(loop=self\.loop\)$'
                         ).format(filename=re.escape(frame[0]),
                                  lineno=frame[1])
            else:
                regex = (r'^Future exception was never retrieved\n'
                         r'future: '
                            r'<Future finished exception=MemoryError\(\)>$'
                         )
            exc_info = (type(exc), exc, exc.__traceback__)
            m_log.error.assert_called_once_with(mock.ANY, exc_info=exc_info)
        else:
            if debug:
                frame = source_traceback[-1]
                regex = (r'^Future/Task exception was never retrieved\n'
                         r'Future/Task created at \(most recent call last\):\n'
                         r'  File'
                         r'.*\n'
                         r'  File "{filename}", line {lineno}, '
                            r'in check_future_exception_never_retrieved\n'
                         r'    future = asyncio\.Future\(loop=self\.loop\)\n'
                         r'Traceback \(most recent call last\):\n'
                         r'.*\n'
                         r'MemoryError$'
                         ).format(filename=re.escape(frame[0]),
                                  lineno=frame[1])
            else:
                regex = (r'^Future/Task exception was never retrieved\n'
                         r'Traceback \(most recent call last\):\n'
                         r'.*\n'
                         r'MemoryError$'
                         )
            m_log.error.assert_called_once_with(mock.ANY, exc_info=False)
        message = m_log.error.call_args[0][0]
        self.assertRegex(message, re.compile(regex, re.DOTALL))

    def test_future_exception_never_retrieved(self):
        self.check_future_exception_never_retrieved(False)

    def test_future_exception_never_retrieved_debug(self):
        self.check_future_exception_never_retrieved(True)

    def test_set_result_unless_cancelled(self):
        fut = asyncio.Future(loop=self.loop)
        fut.cancel()
        fut._set_result_unless_cancelled(2)
        self.assertTrue(fut.cancelled())
391
392
class FutureDoneCallbackTests(test_utils.TestCase):
    """Tests for registering, removing and running Future done callbacks."""

    def setUp(self):
        self.loop = self.new_test_loop()
        # Close the test loop explicitly once the test finishes.
        self.addCleanup(self.loop.close)

    def run_briefly(self):
        """Run one iteration of the test loop so scheduled callbacks fire."""
        test_utils.run_briefly(self.loop)

    def _make_callback(self, bag, thing):
        """Return a done-callback that appends *thing* to the list *bag*."""
        def bag_appender(future):
            bag.append(thing)
        return bag_appender

    def _new_future(self):
        """Return a fresh Future bound to this test's loop."""
        return asyncio.Future(loop=self.loop)

    def test_callbacks_invoked_on_set_result(self):
        bag = []
        f = self._new_future()
        f.add_done_callback(self._make_callback(bag, 42))
        f.add_done_callback(self._make_callback(bag, 17))

        self.assertEqual(bag, [])
        f.set_result('foo')

        # Callbacks are not run synchronously by set_result(); the loop
        # must iterate before they are observed.
        self.run_briefly()

        self.assertEqual(bag, [42, 17])
        self.assertEqual(f.result(), 'foo')

    def test_callbacks_invoked_on_set_exception(self):
        bag = []
        f = self._new_future()
        f.add_done_callback(self._make_callback(bag, 100))

        self.assertEqual(bag, [])
        exc = RuntimeError()
        f.set_exception(exc)

        self.run_briefly()

        self.assertEqual(bag, [100])
        self.assertEqual(f.exception(), exc)

    def test_remove_done_callback(self):
        bag = []
        f = self._new_future()
        cb1 = self._make_callback(bag, 1)
        cb2 = self._make_callback(bag, 2)
        cb3 = self._make_callback(bag, 3)

        # Add one cb1 and one cb2.
        f.add_done_callback(cb1)
        f.add_done_callback(cb2)

        # One instance of cb2 removed. Now there's only one cb1.
        self.assertEqual(f.remove_done_callback(cb2), 1)

        # Never had any cb3 in there.
        self.assertEqual(f.remove_done_callback(cb3), 0)

        # After this there will be 6 instances of cb1 and one of cb2.
        f.add_done_callback(cb2)
        for _ in range(5):
            f.add_done_callback(cb1)

        # Remove all instances of cb1. One cb2 remains.
        self.assertEqual(f.remove_done_callback(cb1), 6)

        self.assertEqual(bag, [])
        f.set_result('foo')

        self.run_briefly()

        self.assertEqual(bag, [2])
        self.assertEqual(f.result(), 'foo')
470
471
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
474