1#!/usr/bin/env python
2# pylint: skip-file
3from __future__ import absolute_import, division, print_function
4
5from salt.ext.tornado import gen
6from salt.ext.tornado.log import app_log
7from salt.ext.tornado.stack_context import (StackContext, wrap, NullContext, StackContextInconsistentError,
8                                   ExceptionStackContext, run_with_stack_context, _state)
9from salt.ext.tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
10from salt.ext.tornado.test.util import unittest
11from salt.ext.tornado.web import asynchronous, Application, RequestHandler
12import contextlib
13import functools
14import logging
15
16
class TestRequestHandler(RequestHandler):
    """Handler whose error is raised two IOLoop hops after ``get``.

    ``get`` schedules ``part2``, which schedules ``part3``, which raises.
    The hops use plain ``io_loop.add_callback`` calls.  ``write_error``
    reports whether the error that reached it is the one ``part3`` raised.
    """

    def __init__(self, app, request, io_loop):
        super(TestRequestHandler, self).__init__(app, request)
        # Loop on which the chained callbacks below are scheduled.
        self.io_loop = io_loop

    @asynchronous
    def get(self):
        logging.debug('in get()')
        # Hand off to part2 without any async_callback wrapper; the
        # exception raised further down the chain must still surface.
        schedule = self.io_loop.add_callback
        schedule(self.part2)

    def part2(self):
        logging.debug('in part2()')
        # A second hop: contexts restored for this callback must be passed
        # on again to the callback scheduled from here.
        next_step = self.part3
        self.io_loop.add_callback(next_step)

    def part3(self):
        logging.debug('in part3()')
        failure = Exception('test exception')
        raise failure

    def write_error(self, status_code, **kwargs):
        """Write a body telling the test whether the expected error arrived."""
        expected = False
        if 'exc_info' in kwargs:
            # exc_info is the sys.exc_info()-style triple; index 1 holds the
            # exception instance, whose text identifies part3's error.
            expected = str(kwargs['exc_info'][1]) == 'test exception'
        if expected:
            self.write('got expected exception')
        else:
            self.write('unexpected failure')
44
45
class HTTPStackContextTest(AsyncHTTPTestCase):
    """End-to-end check that an error raised in a chained callback reaches
    the request handler that started the chain.

    ``TestRequestHandler.get`` schedules ``part2``, which schedules ``part3``,
    which raises.  The request must still fail with a 500, and the body
    produced by ``write_error`` must show that the expected exception arrived.
    """

    def get_app(self):
        # The handler needs the test's IOLoop to schedule its callbacks.
        return Application([('/', TestRequestHandler,
                             dict(io_loop=self.io_loop))])

    def test_stack_context(self):
        # The handler's uncaught exception is logged by the application;
        # ExpectLog declares that this log record is expected here.
        with ExpectLog(app_log, "Uncaught exception GET /"):
            self.http_client.fetch(self.get_url('/'), self.handle_response)
            self.wait()
        self.assertEqual(self.response.code, 500)
        # assertIn (rather than assertTrue(x in y)) shows the actual body on
        # failure, which makes a regression much easier to diagnose.
        self.assertIn(b'got expected exception', self.response.body)

    def handle_response(self, response):
        """Store *response* for the test body and end the pending wait()."""
        self.response = response
        self.stop()
61
62
class StackContextTest(AsyncTestCase):
    """Tests for StackContext, ExceptionStackContext, NullContext, wrap()
    and run_with_stack_context(), driven by the test case's IOLoop.

    Tests track which ``self.context(name)`` managers are currently entered
    through ``self.active_contexts``, and assert on that list from inside
    scheduled callbacks.
    """

    def setUp(self):
        super(StackContextTest, self).setUp()
        # Names of the self.context() managers currently entered, in entry
        # order (innermost last).
        self.active_contexts = []

    @contextlib.contextmanager
    def context(self, name):
        """Keep *name* in ``self.active_contexts`` while this context is active.

        Passed to StackContext as ``functools.partial(self.context, name)``.
        On a normal exit it asserts that *name* is still the innermost entry,
        which catches mis-nested enter/exit sequences.  There is no
        try/finally, so if the managed body raises, *name* is not popped.
        """
        self.active_contexts.append(name)
        yield
        self.assertEqual(self.active_contexts.pop(), name)

    # Simulates the effect of an asynchronous library that uses its own
    # StackContext internally and then returns control to the application.
    def test_exit_library_context(self):
        def library_function(callback):
            # capture the caller's context before introducing our own
            callback = wrap(callback)
            with StackContext(functools.partial(self.context, 'library')):
                # Scheduled while 'library' is entered, so the inner callback
                # runs with both 'application' and 'library' active.
                self.io_loop.add_callback(
                    functools.partial(library_inner_callback, callback))

        def library_inner_callback(callback):
            self.assertEqual(self.active_contexts[-2:],
                             ['application', 'library'])
            # `callback` was wrapped before 'library' was entered, so running
            # it should make the application's context innermost again.
            callback()

        def final_callback():
            # implementation detail:  the full context stack at this point
            # is ['application', 'library', 'application'].  The 'library'
            # context was not removed, but is no longer innermost so
            # the application context takes precedence.
            self.assertEqual(self.active_contexts[-1], 'application')
            self.stop()
        with StackContext(functools.partial(self.context, 'application')):
            library_function(final_callback)
        self.wait()

    def test_deactivate(self):
        # The object bound by `with StackContext(...) as c` is a callable that
        # deactivates that context.  c1, c2 and c3 are entered in successive
        # callbacks, each nested in the previous one.  c2 is then
        # deactivated, and it must be absent from callbacks scheduled
        # afterwards.
        deactivate_callbacks = []

        def f1():
            with StackContext(functools.partial(self.context, 'c1')) as c1:
                deactivate_callbacks.append(c1)
                self.io_loop.add_callback(f2)

        def f2():
            with StackContext(functools.partial(self.context, 'c2')) as c2:
                deactivate_callbacks.append(c2)
                self.io_loop.add_callback(f3)

        def f3():
            with StackContext(functools.partial(self.context, 'c3')) as c3:
                deactivate_callbacks.append(c3)
                self.io_loop.add_callback(f4)

        def f4():
            self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
            deactivate_callbacks[1]()  # deactivate c2
            # deactivating a context doesn't remove it immediately,
            # but it will be missing from the next iteration
            self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
            self.io_loop.add_callback(f5)

        def f5():
            self.assertEqual(self.active_contexts, ['c1', 'c3'])
            self.stop()
        self.io_loop.add_callback(f1)
        self.wait()

    def test_deactivate_order(self):
        # Stack context deactivation has separate logic for deactivation at
        # the head and tail of the stack, so make sure it works in any order.
        def check_contexts():
            # Make sure that the full-context array and the exception-context
            # linked lists are consistent with each other.
            # _state.contexts unpacks into (full contexts, head of the
            # exception-context chain); each link's old_contexts[1] points to
            # the next link, and the chain ends in None.
            full_contexts, chain = _state.contexts
            exception_contexts = []
            while chain is not None:
                exception_contexts.append(chain)
                chain = chain.old_contexts[1]
            # The chain is expected to list the contexts in the reverse order
            # of full_contexts.
            self.assertEqual(list(reversed(full_contexts)), exception_contexts)
            # Snapshot of the context names entered for this call.
            return list(self.active_contexts)

        def make_wrapped_function():
            """Wraps a function in three stack contexts, and returns
            the function along with the deactivation functions.
            """
            # Remove the test's stack context to make sure we can cover
            # the case where the last context is deactivated.
            with NullContext():
                partial = functools.partial
                with StackContext(partial(self.context, 'c0')) as c0:
                    with StackContext(partial(self.context, 'c1')) as c1:
                        with StackContext(partial(self.context, 'c2')) as c2:
                            return (wrap(check_contexts), [c0, c1, c2])

        # First make sure the test mechanism works without any deactivations
        func, deactivate_callbacks = make_wrapped_function()
        self.assertEqual(func(), ['c0', 'c1', 'c2'])

        # Deactivate the tail
        func, deactivate_callbacks = make_wrapped_function()
        deactivate_callbacks[0]()
        self.assertEqual(func(), ['c1', 'c2'])

        # Deactivate the middle
        func, deactivate_callbacks = make_wrapped_function()
        deactivate_callbacks[1]()
        self.assertEqual(func(), ['c0', 'c2'])

        # Deactivate the head
        func, deactivate_callbacks = make_wrapped_function()
        deactivate_callbacks[2]()
        self.assertEqual(func(), ['c0', 'c1'])

    def test_isolation_nonempty(self):
        # f2 and f3 are a chain of operations started in context c1.
        # f2 is incidentally run under context c2, but that context should
        # not be passed along to f3.
        def f1():
            with StackContext(functools.partial(self.context, 'c1')):
                wrapped = wrap(f2)
            with StackContext(functools.partial(self.context, 'c2')):
                wrapped()

        def f2():
            self.assertIn('c1', self.active_contexts)
            self.io_loop.add_callback(f3)

        def f3():
            self.assertIn('c1', self.active_contexts)
            self.assertNotIn('c2', self.active_contexts)
            self.stop()

        self.io_loop.add_callback(f1)
        self.wait()

    def test_isolation_empty(self):
        # Similar to test_isolation_nonempty, but here the f2/f3 chain
        # is started without any context.  Behavior should be equivalent
        # to the nonempty case (although historically it was not)
        def f1():
            with NullContext():
                wrapped = wrap(f2)
            with StackContext(functools.partial(self.context, 'c2')):
                wrapped()

        def f2():
            self.io_loop.add_callback(f3)

        def f3():
            self.assertNotIn('c2', self.active_contexts)
            self.stop()

        self.io_loop.add_callback(f1)
        self.wait()

    def test_yield_in_with(self):
        @gen.engine
        def f():
            # Register key 'a' and keep its callback on self so the cleanup
            # below can invoke it.
            self.callback = yield gen.Callback('a')
            with StackContext(functools.partial(self.context, 'c1')):
                # This yield is a problem: the generator will be suspended
                # and the StackContext's __exit__ is not called yet, so
                # the context will be left on _state.contexts for anything
                # that runs before the yield resolves.
                yield gen.Wait('a')

        # The misuse is expected to surface as StackContextInconsistentError.
        with self.assertRaises(StackContextInconsistentError):
            f()
            self.wait()
        # Cleanup: to avoid GC warnings (which for some reason only seem
        # to show up on py33-asyncio), invoke the callback (which will do
        # nothing since the gen.Runner is already finished) and delete it.
        self.callback()
        del self.callback

    @gen_test
    def test_yield_outside_with(self):
        # This pattern avoids the problem in the previous test.
        cb = yield gen.Callback('k1')
        with StackContext(functools.partial(self.context, 'c1')):
            self.io_loop.add_callback(cb)
        yield gen.Wait('k1')

    def test_yield_in_with_exception_stack_context(self):
        # As above, but with ExceptionStackContext instead of StackContext.
        @gen.engine
        def f():
            # Yielding while the ExceptionStackContext is entered leaves it
            # on the stack while the generator is suspended.
            with ExceptionStackContext(lambda t, v, tb: False):
                yield gen.Task(self.io_loop.add_callback)

        with self.assertRaises(StackContextInconsistentError):
            f()
            self.wait()

    @gen_test
    def test_yield_outside_with_exception_stack_context(self):
        # ExceptionStackContext counterpart of test_yield_outside_with: the
        # yield happens only after the with block has exited.
        cb = yield gen.Callback('k1')
        with ExceptionStackContext(lambda t, v, tb: False):
            self.io_loop.add_callback(cb)
        yield gen.Wait('k1')

    @gen_test
    def test_run_with_stack_context(self):
        # run_with_stack_context() runs a coroutine inside the given
        # StackContext.  The assertions check that nested contexts are
        # visible inside the coroutines, survive the yield in f2, and are
        # gone again once each call has completed.
        @gen.coroutine
        def f1():
            self.assertEqual(self.active_contexts, ['c1'])
            yield run_with_stack_context(
                StackContext(functools.partial(self.context, 'c2')),
                f2)
            self.assertEqual(self.active_contexts, ['c1'])

        @gen.coroutine
        def f2():
            self.assertEqual(self.active_contexts, ['c1', 'c2'])
            yield gen.Task(self.io_loop.add_callback)
            self.assertEqual(self.active_contexts, ['c1', 'c2'])

        self.assertEqual(self.active_contexts, [])
        yield run_with_stack_context(
            StackContext(functools.partial(self.context, 'c1')),
            f1)
        self.assertEqual(self.active_contexts, [])
287
288
if __name__ == '__main__':
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()
291