1"""Unit tests for contextlib.py, and other context managers."""
2
3import sys
4import tempfile
5import unittest
6from contextlib import *  # Tests __all__
7from test import test_support
8try:
9    import threading
10except ImportError:
11    threading = None
12
13
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with @contextmanager."""

    def test_contextmanager_plain(self):
        # Normal exit: code before the yield runs on entry, the yielded value
        # is bound by "as", and code after the yield runs on exit.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # An exception raised in the with-body must still run the generator's
        # finally clause, and must propagate out of the with-statement.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        # The generator does not handle the exception handed to __exit__.
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception; it must report
        # "not handled" by returning a false value instead.
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields a second time after catching the thrown
        # exception is misbehaving; __exit__ is expected to raise RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except:
                # Deliberately bare: any exception thrown in must be caught so
                # that the second yield is reached.
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # The generator catches the exception raised in the with-body, so the
        # with-statement completes normally and execution continues after it.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def _create_contextmanager_attribs(self):
        # Build a @contextmanager-wrapped function that carries a custom
        # attribute (foo='bar') and a docstring, for the metadata tests.
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        # The wrapper must expose the wrapped function's name and attributes.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_contextmanager_doc_attrib(self):
        # The wrapper must expose the wrapped function's docstring.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @contextmanager
        def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))
116
class NestedTestCase(unittest.TestCase):
    """Tests for nested(), which combines several context managers."""

    # XXX This needs more work

    def test_nested(self):
        # Each manager's yielded value is delivered, in order, to the
        # targets of the "as" clause.
        @contextmanager
        def yielding(value):
            yield value
        with nested(yielding(1), yielding(2), yielding(3)) as (
                first, second, third):
            self.assertEqual(first, 1)
            self.assertEqual(second, 2)
            self.assertEqual(third, 3)

    def test_nested_cleanup(self):
        # Entry runs left to right, cleanup runs right to left, and the
        # exception from the with-body still propagates.
        trace = []
        @contextmanager
        def outer():
            trace.append(1)
            try:
                yield 2
            finally:
                trace.append(3)
        @contextmanager
        def inner():
            trace.append(4)
            try:
                yield 5
            finally:
                trace.append(6)
        with self.assertRaises(ZeroDivisionError):
            with nested(outer(), inner()) as (from_outer, from_inner):
                trace.append(from_outer)
                trace.append(from_inner)
                1 // 0
        self.assertEqual(trace, [1, 4, 2, 5, 6, 3])

    def test_nested_right_exception(self):
        @contextmanager
        def plain():
            yield 1
        class NoisyExit(object):
            def __enter__(self):
                return 2
            def __exit__(self, *exc_info):
                # Raise and handle an unrelated exception internally; the
                # ZeroDivisionError from the with-body must still be the one
                # that propagates.
                try:
                    raise Exception()
                except:
                    pass
        with self.assertRaises(ZeroDivisionError):
            with nested(plain(), NoisyExit()) as (x, y):
                1 // 0
        self.assertEqual((x, y), (1, 2))

    def test_nested_b_swallows(self):
        # When the second manager suppresses the exception, nothing may
        # escape the with-statement.
        @contextmanager
        def passthrough():
            yield
        @contextmanager
        def swallower():
            try:
                yield
            except:
                # Swallow the exception
                pass
        try:
            with nested(passthrough(), swallower()):
                1 // 0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # "break" inside the with-body must leave the loop immediately.
        @contextmanager
        def noop():
            yield
        counter = 0
        while True:
            counter += 1
            with nested(noop(), noop()):
                break
            counter += 10
        self.assertEqual(counter, 1)

    def test_nested_continue(self):
        # "continue" inside the with-body must skip the rest of the
        # iteration, so the += 10 is never reached.
        @contextmanager
        def noop():
            yield
        counter = 0
        while counter < 3:
            counter += 1
            with nested(noop(), noop()):
                continue
            counter += 10
        self.assertEqual(counter, 3)

    def test_nested_return(self):
        # "return" inside the with-body must deliver its value even though
        # the managers would swallow any exception thrown into them.
        @contextmanager
        def forgiving():
            try:
                yield
            except:
                pass
        def return_from_with():
            with nested(forgiving(), forgiving()):
                return 1
            return 10
        self.assertEqual(return_from_with(), 1)
229
class ClosingTestCase(unittest.TestCase):
    """Tests for closing(), which calls close() on leaving the with-block."""

    # XXX This needs more work

    def test_closing(self):
        # Normal exit: "as" binds the wrapped object itself, and close() is
        # called exactly once, only after the block finishes.
        close_calls = []
        class Closable:
            def close(self):
                close_calls.append(1)
        resource = Closable()
        self.assertEqual(close_calls, [])
        with closing(resource) as bound:
            self.assertEqual(resource, bound)
        self.assertEqual(close_calls, [1])

    def test_closing_error(self):
        # Exceptional exit: close() is still called, and the exception from
        # the with-body propagates.
        close_calls = []
        class Closable:
            def close(self):
                close_calls.append(1)
        resource = Closable()
        self.assertEqual(close_calls, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(resource) as bound:
                self.assertEqual(resource, bound)
                1 // 0
        self.assertEqual(close_calls, [1])
257
class FileContextTestCase(unittest.TestCase):
    """Tests that file objects returned by open() work as context managers."""

    def testWithOpen(self):
        # Create the scratch file up front instead of using tempfile.mktemp(),
        # which only returns a name and leaves a window in which another
        # process could claim the path.  delete=False keeps the file on disk
        # after this handle is closed so it can be reopened by name below.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tfn = tmp.name
        try:
            # Normal exit from the with-block closes the file.
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            # Exit via an exception closes the file as well.
            f = None
            with self.assertRaises(ZeroDivisionError):
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1 // 0
            self.assertTrue(f.closed)
        finally:
            test_support.unlink(tfn)
277
@unittest.skipUnless(threading, 'Threading required for this test.')
class LockContextTestCase(unittest.TestCase):
    """Tests that the threading synchronization objects work in with-statements."""

    def boilerPlate(self, lock, locked):
        # Shared scenario.  `lock` is the object used in the with-statement;
        # `locked` is a zero-argument callable reporting whether it is
        # currently held.  The lock must be held inside the block and free
        # again afterwards, both on normal exit and on exit via an exception.
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 // 0
        self.assertFalse(locked())

    def _acquire_probe(self, sem):
        # Return a callable that reports whether `sem` is currently
        # unavailable, by trying a non-blocking acquire and undoing it at
        # once when it succeeds.
        def locked():
            acquired = sem.acquire(False)
            if acquired:
                sem.release()
            return not acquired
        return locked

    def testWithLock(self):
        mutex = threading.Lock()
        self.boilerPlate(mutex, mutex.locked)

    def testWithRLock(self):
        rlock = threading.RLock()
        self.boilerPlate(rlock, rlock._is_owned)

    def testWithCondition(self):
        cond = threading.Condition()
        self.boilerPlate(cond, lambda: cond._is_owned())

    def testWithSemaphore(self):
        sem = threading.Semaphore()
        self.boilerPlate(sem, self._acquire_probe(sem))

    def testWithBoundedSemaphore(self):
        sem = threading.BoundedSemaphore()
        self.boilerPlate(sem, self._acquire_probe(sem))
325
# regrtest.py needs this entry point; without it the tests in this module
# would not actually be run there.
def test_main():
    # Warning expected while the suite runs, given as (message, category).
    # NOTE(review): presumably emitted by the nested() calls in this module
    # -- confirm.
    expected_warning = (
        "With-statements now directly support multiple context managers",
        DeprecationWarning,
    )
    with test_support.check_warnings(expected_warning):
        test_support.run_unittest(__name__)


# Also allow running this module directly as a script.
if __name__ == "__main__":
    test_main()
335