1"""Tests for futures.py."""
2
3try:
4    import concurrent.futures
5except ImportError:
6    concurrent = None
7import re
8import six
9import sys
10import threading
11
12import trollius as asyncio
13from trollius import From
14from trollius import compat
15from trollius import test_support as support
16from trollius import test_utils
17from trollius.test_utils import mock
18from trollius.test_utils import unittest
19
20
def get_thread_ident():
    """Return the identifier of the thread that calls this function."""
    caller = threading.current_thread()
    return caller.ident
23
def _fakefunc(f):
    """Identity helper used as a placeholder done-callback; returns *f*."""
    result = f
    return result
26
def first_cb():
    """Do-nothing callback registered first in the Future repr tests."""
29
def last_cb():
    """Do-nothing callback registered last in the Future repr tests."""
32
33
class FutureTests(test_utils.TestCase):
    """Tests for Future state transitions, repr, logging and wrap_future."""

    def setUp(self):
        self.loop = self.new_test_loop()
        self.addCleanup(self.loop.close)

    def test_initial_state(self):
        f = asyncio.Future(loop=self.loop)
        self.assertFalse(f.cancelled())
        self.assertFalse(f.done())
        f.cancel()
        self.assertTrue(f.cancelled())

    def test_init_constructor_default_loop(self):
        # Without an explicit loop argument the Future is expected to use the
        # loop installed with set_event_loop().
        asyncio.set_event_loop(self.loop)
        f = asyncio.Future()
        self.assertIs(f._loop, self.loop)

    def test_cancel(self):
        f = asyncio.Future(loop=self.loop)
        self.assertTrue(f.cancel())
        self.assertTrue(f.cancelled())
        self.assertTrue(f.done())
        self.assertRaises(asyncio.CancelledError, f.result)
        self.assertRaises(asyncio.CancelledError, f.exception)
        self.assertRaises(asyncio.InvalidStateError, f.set_result, None)
        self.assertRaises(asyncio.InvalidStateError, f.set_exception, None)
        # A second cancel() on an already-done Future reports failure.
        self.assertFalse(f.cancel())

    def test_result(self):
        f = asyncio.Future(loop=self.loop)
        self.assertRaises(asyncio.InvalidStateError, f.result)

        f.set_result(42)
        self.assertFalse(f.cancelled())
        self.assertTrue(f.done())
        self.assertEqual(f.result(), 42)
        self.assertEqual(f.exception(), None)
        self.assertRaises(asyncio.InvalidStateError, f.set_result, None)
        self.assertRaises(asyncio.InvalidStateError, f.set_exception, None)
        self.assertFalse(f.cancel())

    def test_exception(self):
        exc = RuntimeError()
        f = asyncio.Future(loop=self.loop)
        self.assertRaises(asyncio.InvalidStateError, f.exception)

        f.set_exception(exc)
        self.assertFalse(f.cancelled())
        self.assertTrue(f.done())
        self.assertRaises(RuntimeError, f.result)
        self.assertEqual(f.exception(), exc)
        self.assertRaises(asyncio.InvalidStateError, f.set_result, None)
        self.assertRaises(asyncio.InvalidStateError, f.set_exception, None)
        self.assertFalse(f.cancel())

    def test_exception_class(self):
        # Passing an exception class (not an instance) stores an instance.
        f = asyncio.Future(loop=self.loop)
        f.set_exception(RuntimeError)
        self.assertIsInstance(f.exception(), RuntimeError)

    def test_future_repr(self):
        # In debug mode the repr includes where the Future was created.
        self.loop.set_debug(True)
        f_pending_debug = asyncio.Future(loop=self.loop)
        frame = f_pending_debug._source_traceback[-1]
        self.assertEqual(repr(f_pending_debug),
                         '<Future pending created at %s:%s>'
                         % (frame[0], frame[1]))
        f_pending_debug.cancel()

        self.loop.set_debug(False)
        f_pending = asyncio.Future(loop=self.loop)
        self.assertEqual(repr(f_pending), '<Future pending>')
        f_pending.cancel()

        f_cancelled = asyncio.Future(loop=self.loop)
        f_cancelled.cancel()
        self.assertEqual(repr(f_cancelled), '<Future cancelled>')

        f_result = asyncio.Future(loop=self.loop)
        f_result.set_result(4)
        self.assertEqual(repr(f_result), '<Future finished result=4>')
        self.assertEqual(f_result.result(), 4)

        exc = RuntimeError()
        f_exception = asyncio.Future(loop=self.loop)
        f_exception.set_exception(exc)
        self.assertEqual(repr(f_exception),
                         '<Future finished exception=RuntimeError()>')
        self.assertIs(f_exception.exception(), exc)

        def func_repr(func):
            # Build a regex matching how a callback is rendered in the repr:
            # "<name>() at <file>:<line>".
            filename, lineno = test_utils.get_function_source(func)
            func_name = getattr(func, '__qualname__', func.__name__)
            text = '%s() at %s:%s' % (func_name, filename, lineno)
            return re.escape(text)

        f_one_callbacks = asyncio.Future(loop=self.loop)
        f_one_callbacks.add_done_callback(_fakefunc)
        fake_repr = func_repr(_fakefunc)
        self.assertRegex(repr(f_one_callbacks),
                         r'<Future pending cb=\[%s\]>' % fake_repr)
        f_one_callbacks.cancel()
        self.assertEqual(repr(f_one_callbacks),
                         '<Future cancelled>')

        f_two_callbacks = asyncio.Future(loop=self.loop)
        f_two_callbacks.add_done_callback(first_cb)
        f_two_callbacks.add_done_callback(last_cb)
        first_repr = func_repr(first_cb)
        last_repr = func_repr(last_cb)
        self.assertRegex(repr(f_two_callbacks),
                         r'<Future pending cb=\[%s, %s\]>'
                         % (first_repr, last_repr))

        # With many callbacks only the first and last are shown; the middle
        # ones are summarised as "<N more>".
        f_many_callbacks = asyncio.Future(loop=self.loop)
        f_many_callbacks.add_done_callback(first_cb)
        for i in range(8):
            f_many_callbacks.add_done_callback(_fakefunc)
        f_many_callbacks.add_done_callback(last_cb)
        cb_regex = r'%s, <8 more>, %s' % (first_repr, last_repr)
        self.assertRegex(repr(f_many_callbacks),
                         r'<Future pending cb=\[%s\]>' % cb_regex)
        f_many_callbacks.cancel()
        self.assertEqual(repr(f_many_callbacks),
                         '<Future cancelled>')

    def test_copy_state(self):
        # Test the internal _copy_state method since it's being directly
        # invoked in other modules.
        f = asyncio.Future(loop=self.loop)
        f.set_result(10)

        newf = asyncio.Future(loop=self.loop)
        newf._copy_state(f)
        self.assertTrue(newf.done())
        self.assertEqual(newf.result(), 10)

        f_exception = asyncio.Future(loop=self.loop)
        f_exception.set_exception(RuntimeError())

        newf_exception = asyncio.Future(loop=self.loop)
        newf_exception._copy_state(f_exception)
        self.assertTrue(newf_exception.done())
        self.assertRaises(RuntimeError, newf_exception.result)

        f_cancelled = asyncio.Future(loop=self.loop)
        f_cancelled.cancel()

        newf_cancelled = asyncio.Future(loop=self.loop)
        newf_cancelled._copy_state(f_cancelled)
        self.assertTrue(newf_cancelled.cancelled())

    # The test_tb_logger_* tests check that an error is logged only for a
    # Future whose exception is set and never retrieved.

    @mock.patch('trollius.base_events.logger')
    def test_tb_logger_abandoned(self, m_log):
        fut = asyncio.Future(loop=self.loop)
        del fut
        self.assertFalse(m_log.error.called)

    @mock.patch('trollius.base_events.logger')
    def test_tb_logger_result_unretrieved(self, m_log):
        fut = asyncio.Future(loop=self.loop)
        fut.set_result(42)
        del fut
        self.assertFalse(m_log.error.called)

    @mock.patch('trollius.base_events.logger')
    def test_tb_logger_result_retrieved(self, m_log):
        fut = asyncio.Future(loop=self.loop)
        fut.set_result(42)
        fut.result()
        del fut
        self.assertFalse(m_log.error.called)

    @mock.patch('trollius.base_events.logger')
    def test_tb_logger_exception_unretrieved(self, m_log):
        self.loop.set_debug(True)
        asyncio.set_event_loop(self.loop)
        fut = asyncio.Future(loop=self.loop)
        fut.set_exception(RuntimeError('boom'))
        del fut
        # Give the loop a turn, then force a garbage collection, before
        # checking that the unretrieved exception was logged.
        test_utils.run_briefly(self.loop)
        support.gc_collect()
        self.assertTrue(m_log.error.called)

    @mock.patch('trollius.base_events.logger')
    def test_tb_logger_exception_retrieved(self, m_log):
        fut = asyncio.Future(loop=self.loop)
        fut.set_exception(RuntimeError('boom'))
        fut.exception()
        del fut
        self.assertFalse(m_log.error.called)

    @mock.patch('trollius.base_events.logger')
    def test_tb_logger_exception_result_retrieved(self, m_log):
        fut = asyncio.Future(loop=self.loop)
        fut.set_exception(RuntimeError('boom'))
        self.assertRaises(RuntimeError, fut.result)
        del fut
        self.assertFalse(m_log.error.called)

    @unittest.skipIf(concurrent is None, 'need concurrent.futures')
    def test_wrap_future(self):

        def run(arg):
            return (arg, get_thread_ident())
        ex = concurrent.futures.ThreadPoolExecutor(1)
        # Join the worker thread when the test ends, even if an assertion
        # fails, so no executor thread outlives the test.
        self.addCleanup(ex.shutdown, wait=True)
        f1 = ex.submit(run, 'oi')
        f2 = asyncio.wrap_future(f1, loop=self.loop)
        res, ident = self.loop.run_until_complete(f2)
        self.assertIsInstance(f2, asyncio.Future)
        self.assertEqual(res, 'oi')
        # The callable must have run in the executor thread, not this one.
        self.assertNotEqual(ident, get_thread_ident())

    def test_wrap_future_future(self):
        # Wrapping an asyncio Future returns the very same object.
        f1 = asyncio.Future(loop=self.loop)
        f2 = asyncio.wrap_future(f1)
        self.assertIs(f1, f2)

    @unittest.skipIf(concurrent is None, 'need concurrent.futures')
    @mock.patch('trollius.futures.events')
    def test_wrap_future_use_global_loop(self, m_events):
        def run(arg):
            return (arg, get_thread_ident())
        ex = concurrent.futures.ThreadPoolExecutor(1)
        # Join the worker thread when the test ends, even if an assertion
        # fails, so no executor thread outlives the test.
        self.addCleanup(ex.shutdown, wait=True)
        f1 = ex.submit(run, 'oi')
        f2 = asyncio.wrap_future(f1)
        # Without loop=..., wrap_future should fall back to
        # events.get_event_loop() (patched here).
        self.assertIs(m_events.get_event_loop.return_value, f2._loop)

    @unittest.skipIf(concurrent is None, 'need concurrent.futures')
    def test_wrap_future_cancel(self):
        # Cancelling the wrapper propagates to the concurrent future.
        f1 = concurrent.futures.Future()
        f2 = asyncio.wrap_future(f1, loop=self.loop)
        f2.cancel()
        test_utils.run_briefly(self.loop)
        self.assertTrue(f1.cancelled())
        self.assertTrue(f2.cancelled())

    @unittest.skipIf(concurrent is None, 'need concurrent.futures')
    def test_wrap_future_cancel2(self):
        # An already-finished concurrent future is not cancelled by
        # cancelling the wrapper.
        f1 = concurrent.futures.Future()
        f2 = asyncio.wrap_future(f1, loop=self.loop)
        f1.set_result(42)
        f2.cancel()
        test_utils.run_briefly(self.loop)
        self.assertFalse(f1.cancelled())
        self.assertEqual(f1.result(), 42)
        self.assertTrue(f2.cancelled())

    def test_future_source_traceback(self):
        self.loop.set_debug(True)

        # lineno is computed relative to the Future creation, so the next
        # two statements must stay on consecutive lines.
        future = asyncio.Future(loop=self.loop)
        lineno = sys._getframe().f_lineno - 1
        self.assertIsInstance(future._source_traceback, list)
        filename = sys._getframe().f_code.co_filename
        self.assertEqual(future._source_traceback[-1][:3],
                         (filename,
                          lineno,
                          'test_future_source_traceback'))

    @mock.patch('trollius.base_events.logger')
    def check_future_exception_never_retrieved(self, debug, m_log):
        """Check the error logged for a Future whose exception is never read.

        *debug* toggles the loop's debug mode. When it is enabled, the
        expected message also includes where the Future was created.
        *m_log* is the patched ``trollius.base_events.logger`` injected by
        ``mock.patch``.
        """
        self.loop.set_debug(debug)

        def memory_error():
            # Raise and catch so the exception carries a traceback
            # (used for exc_info on Python 3.4+ below).
            try:
                raise MemoryError()
            except BaseException as exc:
                return exc
        exc = memory_error()

        # NOTE: the regexes below match this exact source line; keep it as is.
        future = asyncio.Future(loop=self.loop)
        if debug:
            source_traceback = future._source_traceback
        future.set_exception(exc)
        future = None
        test_utils.run_briefly(self.loop)
        support.gc_collect()

        # The expected message format differs between Python >= 3.4 and
        # older interpreters.
        if sys.version_info >= (3, 4):
            if debug:
                frame = source_traceback[-1]
                regex = (r'^Future exception was never retrieved\n'
                         r'future: <Future finished exception=MemoryError\(\) '
                             r'created at {filename}:{lineno}>\n'
                         r'source_traceback: Object '
                            r'created at \(most recent call last\):\n'
                         r'  File'
                         r'.*\n'
                         r'  File "{filename}", line {lineno}, '
                            r'in check_future_exception_never_retrieved\n'
                         r'    future = asyncio\.Future\(loop=self\.loop\)$'
                         ).format(filename=re.escape(frame[0]),
                                  lineno=frame[1])
            else:
                regex = (r'^Future exception was never retrieved\n'
                         r'future: '
                            r'<Future finished exception=MemoryError\(\)>$'
                         )
            exc_info = (type(exc), exc, exc.__traceback__)
            m_log.error.assert_called_once_with(mock.ANY, exc_info=exc_info)
        else:
            if debug:
                frame = source_traceback[-1]
                regex = (r'^Future/Task exception was never retrieved\n'
                         r'Future/Task created at \(most recent call last\):\n'
                         r'  File'
                         r'.*\n'
                         r'  File "{filename}", line {lineno}, '
                            r'in check_future_exception_never_retrieved\n'
                         r'    future = asyncio\.Future\(loop=self\.loop\)\n'
                         r'Traceback \(most recent call last\):\n'
                         r'.*\n'
                         r'MemoryError$'
                         ).format(filename=re.escape(frame[0]),
                                  lineno=frame[1])
            elif six.PY3:
                regex = (r'^Future/Task exception was never retrieved\n'
                         r'Traceback \(most recent call last\):\n'
                         r'.*\n'
                         r'MemoryError$'
                         )
            else:
                regex = (r'^Future/Task exception was never retrieved\n'
                         r'MemoryError$'
                         )
            m_log.error.assert_called_once_with(mock.ANY, exc_info=False)
        message = m_log.error.call_args[0][0]
        self.assertRegex(message, re.compile(regex, re.DOTALL))

    def test_future_exception_never_retrieved(self):
        self.check_future_exception_never_retrieved(False)

    def test_future_exception_never_retrieved_debug(self):
        self.check_future_exception_never_retrieved(True)

    def test_set_result_unless_cancelled(self):
        # Setting a result on a cancelled Future is silently ignored.
        fut = asyncio.Future(loop=self.loop)
        fut.cancel()
        fut._set_result_unless_cancelled(2)
        self.assertTrue(fut.cancelled())
376
377
class FutureDoneCallbackTests(test_utils.TestCase):
    """Check how done-callbacks are registered, removed and invoked."""

    def setUp(self):
        self.loop = self.new_test_loop()

    def run_briefly(self):
        """Give the test loop one chance to run scheduled callbacks."""
        test_utils.run_briefly(self.loop)

    def _make_callback(self, bag, thing):
        """Return a done-callback that appends *thing* to the list *bag*."""
        def bag_appender(future):
            bag.append(thing)
        return bag_appender

    def _new_future(self):
        """Create a fresh Future bound to the test loop."""
        return asyncio.Future(loop=self.loop)

    def test_callbacks_invoked_on_set_result(self):
        seen = []
        fut = self._new_future()
        for marker in (42, 17):
            fut.add_done_callback(self._make_callback(seen, marker))

        # Nothing is recorded until the result is set and the loop runs.
        self.assertEqual(seen, [])
        fut.set_result('foo')
        self.run_briefly()

        self.assertEqual(seen, [42, 17])
        self.assertEqual(fut.result(), 'foo')

    def test_callbacks_invoked_on_set_exception(self):
        seen = []
        fut = self._new_future()
        fut.add_done_callback(self._make_callback(seen, 100))
        self.assertEqual(seen, [])

        error = RuntimeError()
        fut.set_exception(error)
        self.run_briefly()

        self.assertEqual(seen, [100])
        self.assertEqual(fut.exception(), error)

    def test_remove_done_callback(self):
        seen = []
        fut = self._new_future()
        cb1, cb2, cb3 = [self._make_callback(seen, n) for n in (1, 2, 3)]

        # Register one cb1 and one cb2, then drop the single cb2.
        fut.add_done_callback(cb1)
        fut.add_done_callback(cb2)
        self.assertEqual(fut.remove_done_callback(cb2), 1)

        # cb3 was never registered, so nothing is removed.
        self.assertEqual(fut.remove_done_callback(cb3), 0)

        # Now: one cb2 plus six cb1 in total.
        fut.add_done_callback(cb2)
        for _ in range(5):
            fut.add_done_callback(cb1)

        # Removing cb1 drops all six copies; only cb2 remains.
        self.assertEqual(fut.remove_done_callback(cb1), 6)

        self.assertEqual(seen, [])
        fut.set_result('foo')
        self.run_briefly()

        self.assertEqual(seen, [2])
        self.assertEqual(fut.result(), 'foo')
455
456
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly.
    unittest.main()
459