1# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
2# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
3#
4# This file is part of logilab-common.
5#
6# logilab-common is free software: you can redistribute it and/or modify it under
7# the terms of the GNU Lesser General Public License as published by the Free
8# Software Foundation, either version 2.1 of the License, or (at your option) any
9# later version.
10#
11# logilab-common is distributed in the hope that it will be useful, but WITHOUT
12# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
13# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
14# details.
15#
16# You should have received a copy of the GNU Lesser General Public License along
17# with logilab-common.  If not, see <http://www.gnu.org/licenses/>.
18"""unittest module for logilab.comon.testlib"""
19
20
21import os
22import sys
23from os.path import join, dirname, isdir, isfile, abspath, exists
24import tempfile
25import shutil
26
# __file__ may be undefined in some execution contexts; fall back to the
# script path so that data paths computed from dirname(__file__) in the
# tests still resolve to something sensible.
try:
    __file__
except NameError:
    __file__ = sys.argv[0]
31
32from logilab.common.compat import StringIO
33from logilab.common.testlib import (
34    TestSuite,
35    unittest_main,
36    Tags,
37    TestCase,
38    mock_object,
39    create_files,
40    InnerTest,
41    with_tempdir,
42    tag,
43    require_version,
44    require_module,
45)
46from logilab.common.pytest import SkipAwareTextTestRunner, NonStrictTestLoader
47
48
class MockTestCase(TestCase):
    """Bare TestCase whose assertion helpers can be called directly.

    ``TestlibTC`` uses an instance of it as ``self.tc`` to exercise the
    assertion methods without running an actual test.
    """

    def __init__(self):
        # Do not call unittest.TestCase's __init__: no test method name is
        # needed since only the assertion helpers are used.
        pass

    def fail(self, msg=None):
        """Raise AssertionError carrying *msg*.

        *msg* defaults to None, matching ``unittest.TestCase.fail`` so that
        callers omitting the message do not hit a TypeError.
        """
        raise AssertionError(msg)
56
57
class UtilTC(TestCase):
    """Tests for the ``mock_object`` and ``create_files`` helpers."""

    def test_mockobject(self):
        """mock_object() exposes its keyword arguments as attributes."""
        obj = mock_object(foo="bar", baz="bam")
        self.assertEqual(obj.foo, "bar")
        self.assertEqual(obj.baz, "bam")

    def test_create_files(self):
        """create_files() creates exactly the requested files and directories."""
        # TemporaryDirectory removes the whole tree on exit, including when an
        # assertion fails.
        with tempfile.TemporaryDirectory() as chroot:

            def path_to(path):
                return join(chroot, path)

            def dircontent(path):
                return sorted(os.listdir(path_to(path)))

            self.assertFalse(isdir(path_to("a/")))
            create_files(["a/b/foo.py", "a/b/c/", "a/b/c/d/e.py"], chroot)
            # make sure directories exist
            self.assertTrue(isdir(path_to("a")))
            self.assertTrue(isdir(path_to("a/b")))
            self.assertTrue(isdir(path_to("a/b/c")))
            self.assertTrue(isdir(path_to("a/b/c/d")))
            # make sure files exist
            self.assertTrue(isfile(path_to("a/b/foo.py")))
            self.assertTrue(isfile(path_to("a/b/c/d/e.py")))
            # make sure only asked files were created
            self.assertEqual(dircontent("a"), ["b"])
            self.assertEqual(dircontent("a/b"), ["c", "foo.py"])
            self.assertEqual(dircontent("a/b/c"), ["d"])
            self.assertEqual(dircontent("a/b/c/d"), ["e.py"])
86
87
class TestlibTC(TestCase):
    """Tests for TestCase assertion helpers and data-directory support."""

    def mkdir(self, path):
        """Create *path* if it is missing and remember it for tearDown."""
        if not exists(path):
            self._dirs.add(path)
            os.mkdir(path)

    def setUp(self):
        # assertion helpers are exercised on a bare MockTestCase instance
        self.tc = MockTestCase()
        # directories created through self.mkdir(), removed in tearDown()
        self._dirs = set()

    def tearDown(self):
        while self._dirs:
            shutil.rmtree(self._dirs.pop(), ignore_errors=True)

    @staticmethod
    def _expected_datadir():
        """Return the ``data`` directory located next to this test module."""
        return join(dirname(abspath(__file__)), "data")

    def test_dict_equals(self):
        """tests TestCase.assertDictEqual"""
        d1 = {"a": 1, "b": 2}
        d2 = {"a": 1, "b": 3}
        d3 = dict(d1)
        self.assertRaises(AssertionError, self.tc.assertDictEqual, d1, d2)
        self.tc.assertDictEqual(d1, d3)
        self.tc.assertDictEqual(d3, d1)
        self.tc.assertDictEqual(d1, d1)

    def test_list_equals(self):
        """tests TestCase.assertListEqual"""
        l1 = list(range(10))
        l2 = list(range(5))
        l3 = list(range(10))
        self.assertRaises(AssertionError, self.tc.assertListEqual, l1, l2)
        self.tc.assertListEqual(l1, l1)
        self.tc.assertListEqual(l1, l3)
        self.tc.assertListEqual(l3, l1)

    def test_equality_for_sets(self):
        """tests TestCase.assertSetEqual"""
        s1 = set("ab")
        s2 = set("a")
        self.assertRaises(AssertionError, self.tc.assertSetEqual, s1, s2)
        self.tc.assertSetEqual(s1, s1)
        self.tc.assertSetEqual(set(), set())

    def test_text_equality(self):
        """tests TestCase.assertMultiLineEqual"""
        # non-string operands are rejected, on either side
        self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, "toto", 12)
        self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, "toto", None)
        self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, 3.12, "toto")
        self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, None, "toto")
        self.tc.assertMultiLineEqual("toto\ntiti", "toto\ntiti")
        self.assertRaises(
            AssertionError, self.tc.assertMultiLineEqual, "toto\ntiti", "toto\n titi\n"
        )
        datadir = self._expected_datadir()
        with open(join(datadir, "foo.txt")) as fobj:
            text1 = fobj.read()
        self.tc.assertMultiLineEqual(text1, text1)
        with open(join(datadir, "spam.txt")) as fobj:
            text2 = fobj.read()
        self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, text1, text2)

    def test_default_datadir(self):
        expected_datadir = self._expected_datadir()
        self.assertEqual(self.datadir, expected_datadir)
        self.assertEqual(self.datapath("foo"), join(expected_datadir, "foo"))

    def test_multiple_args_datadir(self):
        expected_datadir = self._expected_datadir()
        self.assertEqual(self.datadir, expected_datadir)
        self.assertEqual(self.datapath("foo", "bar"), join(expected_datadir, "foo", "bar"))

    def test_custom_datadir(self):
        class MyTC(TestCase):
            datadir = "foo"

            def test_1(self):
                pass

        # class' custom datadir
        tc = MyTC("test_1")
        self.assertEqual(tc.datapath("bar"), join("foo", "bar"))

    def test_cached_datadir(self):
        """test datadir is cached on the class"""

        class MyTC(TestCase):
            def test_1(self):
                pass

        expected_datadir = self._expected_datadir()
        tc = MyTC("test_1")
        self.assertEqual(tc.datadir, expected_datadir)
        # changing module should not change the datadir
        MyTC.__module__ = "os"
        self.assertEqual(tc.datadir, expected_datadir)
        # even on new instances
        tc2 = MyTC("test_1")
        self.assertEqual(tc2.datadir, expected_datadir)

    def test_is(self):
        obj_1 = []
        obj_2 = []
        self.assertIs(obj_1, obj_1)
        self.assertRaises(AssertionError, self.assertIs, obj_1, obj_2)

    def test_isnot(self):
        obj_1 = []
        obj_2 = []
        self.assertIsNot(obj_1, obj_2)
        self.assertRaises(AssertionError, self.assertIsNot, obj_1, obj_1)

    def test_none(self):
        self.assertIsNone(None)
        self.assertRaises(AssertionError, self.assertIsNone, object())

    def test_not_none(self):
        self.assertIsNotNone(object())
        self.assertRaises(AssertionError, self.assertIsNotNone, None)

    def test_in(self):
        self.assertIn("a", "dsqgaqg")
        obj, seq = "a", ("toto", "azf", "coin")
        self.assertRaises(AssertionError, self.assertIn, obj, seq)

    def test_not_in(self):
        self.assertNotIn("a", ("toto", "azf", "coin"))
        self.assertRaises(AssertionError, self.assertNotIn, "a", "dsqgaqg")
223
224
class GenerativeTestsTC(TestCase):
    """Check how SkipAwareTextTestRunner accounts for generative test
    methods, i.e. test methods that yield the checks to run."""

    def setUp(self):
        self.runner = SkipAwareTextTestRunner(stream=StringIO())

    def _run_and_check(self, testcase, run, failures, errors, skipped=None):
        """Run *testcase* and assert the counters of the returned result.

        The number of skipped tests is only verified when *skipped* is given.
        """
        result = self.runner.run(testcase)
        self.assertEqual(result.testsRun, run)
        self.assertEqual(len(result.failures), failures)
        self.assertEqual(len(result.errors), errors)
        if skipped is not None:
            self.assertEqual(len(result.skipped), skipped)

    def test_generative_ok(self):
        class FooTC(TestCase):
            def test_generative(self):
                for n in range(10):
                    yield self.assertEqual, n, n

        self._run_and_check(FooTC("test_generative"), run=10, failures=0, errors=0)

    def test_generative_half_bad(self):
        class FooTC(TestCase):
            def test_generative(self):
                # odd values fail: 5 failing checks out of 10
                for n in range(10):
                    yield self.assertEqual, n % 2, 0

        self._run_and_check(FooTC("test_generative"), run=10, failures=5, errors=0)

    def test_generative_error(self):
        class FooTC(TestCase):
            def test_generative(self):
                # the generator itself raises half-way through
                for n in range(10):
                    if n == 5:
                        raise ValueError("STOP !")
                    yield self.assertEqual, n, n

        self._run_and_check(FooTC("test_generative"), run=5, failures=0, errors=1)

    def test_generative_error2(self):
        class FooTC(TestCase):
            def test_generative(self):
                # one extra, erroring check is yielded when n == 5
                for n in range(10):
                    if n == 5:
                        yield self.ouch
                    yield self.assertEqual, n, n

            def ouch(self):
                raise ValueError("stop !")

        self._run_and_check(FooTC("test_generative"), run=11, failures=0, errors=1)

    def test_generative_setup(self):
        class FooTC(TestCase):
            def setUp(self):
                raise ValueError("STOP !")

            def test_generative(self):
                for n in range(10):
                    yield self.assertEqual, n, n

        self._run_and_check(FooTC("test_generative"), run=1, failures=0, errors=1)

    def test_generative_inner_skip(self):
        class FooTC(TestCase):
            def check(self, val):
                if val != 5:
                    self.assertEqual(val, val)
                else:
                    self.innerSkip("no 5")

            def test_generative(self):
                for n in range(10):
                    yield InnerTest("check_%s" % n, self.check, n)

        self._run_and_check(
            FooTC("test_generative"), run=10, failures=0, errors=0, skipped=1
        )

    def test_generative_skip(self):
        class FooTC(TestCase):
            def check(self, val):
                if val != 5:
                    self.assertEqual(val, val)
                else:
                    self.skipTest("no 5")

            def test_generative(self):
                for n in range(10):
                    yield InnerTest("check_%s" % n, self.check, n)

        self._run_and_check(
            FooTC("test_generative"), run=10, failures=0, errors=0, skipped=1
        )

    def test_generative_inner_error(self):
        class FooTC(TestCase):
            def check(self, val):
                if val == 5:
                    raise ValueError("no 5")
                self.assertEqual(val, val)

            def test_generative(self):
                for n in range(10):
                    yield InnerTest("check_%s" % n, self.check, n)

        self._run_and_check(
            FooTC("test_generative"), run=10, failures=0, errors=1, skipped=0
        )

    def test_generative_inner_failure(self):
        class FooTC(TestCase):
            def check(self, val):
                if val != 5:
                    self.assertEqual(val, val)
                else:
                    self.assertEqual(val, val + 1)

            def test_generative(self):
                for n in range(10):
                    yield InnerTest("check_%s" % n, self.check, n)

        self._run_and_check(
            FooTC("test_generative"), run=10, failures=1, errors=0, skipped=0
        )

    def test_generative_outer_failure(self):
        class FooTC(TestCase):
            def test_generative(self):
                self.fail()
                yield

        self._run_and_check(
            FooTC("test_generative"), run=0, failures=1, errors=0, skipped=0
        )

    def test_generative_outer_skip(self):
        class FooTC(TestCase):
            def test_generative(self):
                self.skipTest("blah")
                yield

        self._run_and_check(
            FooTC("test_generative"), run=0, failures=0, errors=0, skipped=1
        )
390
391
class ExitFirstTC(TestCase):
    """With ``exitfirst=True`` the runner must stop at the first failure/error."""

    def setUp(self):
        output = StringIO()
        self.runner = SkipAwareTextTestRunner(stream=output, exitfirst=True)

    def test_failure_exit_first(self):
        class FooTC(TestCase):
            def test_1(self):
                pass

            def test_2(self):
                # explicit raise rather than ``assert False``, which would be
                # stripped when Python runs with -O
                raise AssertionError("deliberate failure")

            def test_3(self):
                pass

        # test_3 must be in the suite: without it, testsRun == 2 would hold
        # even if the runner ignored exitfirst.
        tests = [FooTC("test_1"), FooTC("test_2"), FooTC("test_3")]
        result = self.runner.run(TestSuite(tests))
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.failures), 1)
        self.assertEqual(len(result.errors), 0)

    def test_error_exit_first(self):
        class FooTC(TestCase):
            def test_1(self):
                pass

            def test_2(self):
                raise ValueError()

            def test_3(self):
                pass

        tests = [FooTC("test_1"), FooTC("test_2"), FooTC("test_3")]
        result = self.runner.run(TestSuite(tests))
        self.assertEqual(result.testsRun, 2)
        self.assertEqual(len(result.failures), 0)
        self.assertEqual(len(result.errors), 1)

    def test_generative_exit_first(self):
        class FooTC(TestCase):
            def test_generative(self):
                for i in range(10):
                    yield self.assertTrue, False

        # only the first yielded check runs: it fails and stops the run
        result = self.runner.run(FooTC("test_generative"))
        self.assertEqual(result.testsRun, 1)
        self.assertEqual(len(result.failures), 1)
        self.assertEqual(len(result.errors), 0)
441
442
class TestLoaderTC(TestCase):
    """Tests for NonStrictTestLoader's pattern-based selection and skipping."""

    # internal classes for test purposes ########
    class FooTC(TestCase):
        def test_foo1(self):
            pass

        def test_foo2(self):
            pass

        def test_bar1(self):
            pass

    class BarTC(TestCase):
        def test_bar2(self):
            pass

    ##############################################

    def setUp(self):
        self.loader = NonStrictTestLoader()
        # TestLoaderTC itself serves as the "module" tests are loaded from:
        # its FooTC and BarTC attributes are the test case classes to collect.
        self.module = TestLoaderTC
        self.output = StringIO()
        self.runner = SkipAwareTextTestRunner(stream=self.output)

    def assertRunCount(self, pattern, module, expected_count, skipped=()):
        """Load tests from *module* and assert how many of them were run.

        :param pattern: test selection pattern, or a false value to load
            every test from *module*
        :param module: object whose TestCase attributes are collected
        :param expected_count: expected ``testsRun`` of the run result
        :param skipped: patterns of tests to skip
        """
        self.loader.test_pattern = pattern
        self.loader.skipped_patterns = skipped
        try:
            if pattern:
                suite = self.loader.loadTestsFromNames([pattern], module)
            else:
                suite = self.loader.loadTestsFromModule(module)
            result = self.runner.run(suite)
        finally:
            # restore the loader defaults even if loading or running raised,
            # so later checks sharing this loader start from a clean state
            self.loader.test_pattern = None
            self.loader.skipped_patterns = ()
        self.assertEqual(result.testsRun, expected_count)

    def test_collect_everything(self):
        """make sure we don't change the default behaviour
        for loadTestsFromModule() and loadTestsFromTestCase
        """
        testsuite = self.loader.loadTestsFromModule(self.module)
        self.assertEqual(len(testsuite._tests), 2)
        suite1, suite2 = testsuite._tests
        self.assertEqual(len(suite1._tests) + len(suite2._tests), 4)

    def test_collect_with_classname(self):
        self.assertRunCount("FooTC", self.module, 3)
        self.assertRunCount("BarTC", self.module, 1)

    def test_collect_with_classname_and_pattern(self):
        data = [
            ("FooTC.test_foo1", 1),
            ("FooTC.test_foo", 2),
            ("FooTC.test_fo", 2),
            ("FooTC.foo1", 1),
            ("FooTC.foo", 2),
            ("FooTC.whatever", 0),
        ]
        for pattern, expected_count in data:
            yield self.assertRunCount, pattern, self.module, expected_count

    def test_collect_with_pattern(self):
        data = [
            ("test_foo1", 1),
            ("test_foo", 2),
            ("test_bar", 2),
            ("foo1", 1),
            ("foo", 2),
            ("bar", 2),
            ("ba", 2),
            ("test", 4),
            ("ab", 0),
        ]
        for pattern, expected_count in data:
            yield self.assertRunCount, pattern, self.module, expected_count

    def test_testcase_with_custom_metaclass(self):
        # Derive from TestCase's own metaclass so that combining it with
        # TestCase cannot raise a metaclass conflict.
        class mymetaclass(type(TestCase)):
            pass

        class MyMod:
            # ``metaclass=`` (not a ``__metaclass__`` attribute, which Python 3
            # ignores) so the custom metaclass is really in effect.
            class MyTestCase(TestCase, metaclass=mymetaclass):
                def test_foo1(self):
                    pass

                def test_foo2(self):
                    pass

                def test_bar(self):
                    pass

        data = [
            ("test_foo1", 1),
            ("test_foo", 2),
            ("test_bar", 1),
            ("foo1", 1),
            ("foo", 2),
            ("bar", 1),
            ("ba", 1),
            ("test", 3),
            ("ab", 0),
            ("MyTestCase.test_foo1", 1),
            ("MyTestCase.test_foo", 2),
            ("MyTestCase.test_fo", 2),
            ("MyTestCase.foo1", 1),
            ("MyTestCase.foo", 2),
            ("MyTestCase.whatever", 0),
        ]
        for pattern, expected_count in data:
            yield self.assertRunCount, pattern, MyMod, expected_count

    def test_collect_everything_and_skipped_patterns(self):
        testdata = [
            (["foo1"], 3),
            (["foo"], 2),
            (["foo", "bar"], 0),
        ]
        for skipped, expected_count in testdata:
            yield self.assertRunCount, None, self.module, expected_count, skipped

    def test_collect_specific_pattern_and_skip_some(self):
        testdata = [
            ("bar", ["foo1"], 2),
            ("bar", [], 2),
            ("bar", ["bar"], 0),
        ]
        for runpattern, skipped, expected_count in testdata:
            yield self.assertRunCount, runpattern, self.module, expected_count, skipped

    def test_skip_classname(self):
        testdata = [
            (["BarTC"], 3),
            (["FooTC"], 1),
        ]
        for skipped, expected_count in testdata:
            yield self.assertRunCount, None, self.module, expected_count, skipped

    def test_skip_classname_and_specific_collect(self):
        testdata = [
            ("bar", ["BarTC"], 1),
            ("foo", ["FooTC"], 0),
        ]
        for runpattern, skipped, expected_count in testdata:
            yield self.assertRunCount, runpattern, self.module, expected_count, skipped

    def test_nonregr_dotted_path(self):
        self.assertRunCount("FooTC.test_foo", self.module, 2)

    def test_inner_tests_selection(self):
        class MyMod:
            class MyTestCase(TestCase):
                def test_foo(self):
                    pass

                def test_foobar(self):
                    for i in range(5):
                        if i % 2 == 0:
                            yield InnerTest("even", lambda: None)
                        else:
                            yield InnerTest("odd", lambda: None)
                    yield lambda: None

        # FIXME InnerTest masked by pattern usage
        # data = [('foo', 7), ('test_foobar', 6), ('even', 3), ('odd', 2), ]
        data = [
            ("foo", 7),
            ("test_foobar", 6),
            ("even", 0),
            ("odd", 0),
        ]
        for pattern, expected_count in data:
            yield self.assertRunCount, pattern, MyMod, expected_count

    def test_nonregr_class_skipped_option(self):
        class MyMod:
            class MyTestCase(TestCase):
                def test_foo(self):
                    pass

                def test_bar(self):
                    pass

            class FooTC(TestCase):
                def test_foo(self):
                    pass

        self.assertRunCount("foo", MyMod, 2)
        self.assertRunCount(None, MyMod, 3)
        self.assertRunCount("foo", MyMod, 1, ["FooTC"])
        self.assertRunCount(None, MyMod, 2, ["FooTC"])

    def test__classes_are_ignored(self):
        class MyMod:
            class _Base(TestCase):
                def test_1(self):
                    pass

            class MyTestCase(_Base):
                def test_2(self):
                    pass

        self.assertRunCount(None, MyMod, 2)
649
650
class DecoratorTC(TestCase):
    """Tests for the ``with_tempdir``, ``require_version`` and
    ``require_module`` decorators."""

    def setUp(self):
        # the require_version tests overwrite sys.version_info: keep the real
        # value so that tearDown can restore it
        self.pyversion = sys.version_info

    def tearDown(self):
        sys.version_info = self.pyversion

    @with_tempdir
    def test_tmp_dir_normal_1(self):
        tempdir = tempfile.gettempdir()
        # assert temp directory is empty
        self.assertListEqual(list(os.walk(tempdir)), [(tempdir, [], [])])

        witness = []

        @with_tempdir
        def createfile(witness_list):
            # populate the temporary directory; everything created here must
            # be gone once the decorated call returns
            fd1, _ = tempfile.mkstemp()
            fd2, _ = tempfile.mkstemp()
            subdir = tempfile.mkdtemp()
            fd3, _ = tempfile.mkstemp(dir=subdir)
            tempfile.mkdtemp()
            witness_list.append(True)
            for fd in (fd1, fd2, fd3):
                os.close(fd)

        self.assertFalse(witness)
        createfile(witness)
        self.assertTrue(witness)

        # assert tempdir didn't change
        self.assertEqual(tempfile.gettempdir(), tempdir)

        # assert temp directory is empty
        self.assertListEqual(list(os.walk(tempdir)), [(tempdir, [], [])])

    @with_tempdir
    def test_tmp_dir_normal_2(self):
        tempdir = tempfile.gettempdir()
        # assert temp directory is empty
        self.assertListEqual(list(os.walk(tempdir)), [(tempdir, [], [])])

        class WitnessException(Exception):
            pass

        @with_tempdir
        def createfile():
            # same as above, but leave through an exception: cleanup must
            # still happen
            fd1, _ = tempfile.mkstemp()
            fd2, _ = tempfile.mkstemp()
            subdir = tempfile.mkdtemp()
            fd3, _ = tempfile.mkstemp(dir=subdir)
            tempfile.mkdtemp()
            for fd in (fd1, fd2, fd3):
                os.close(fd)
            raise WitnessException()

        self.assertRaises(WitnessException, createfile)

        # assert tempdir didn't change
        self.assertEqual(tempfile.gettempdir(), tempdir)

        # assert temp directory is empty
        self.assertListEqual(list(os.walk(tempdir)), [(tempdir, [], [])])

    def test_tmpdir_generator(self):
        orig_tempdir = tempfile.gettempdir()

        @with_tempdir
        def gen():
            yield tempfile.gettempdir()

        for tempdir in gen():
            self.assertNotEqual(orig_tempdir, tempdir)
        self.assertEqual(orig_tempdir, tempfile.gettempdir())

    def test_require_version_good(self):
        """should return the same function"""

        def func():
            pass

        sys.version_info = (2, 5, 5, "final", 4)
        current = ".".join(str(element) for element in sys.version_info[:3])
        for version in ("2.4", "2.5", "2.5.4", "2.5.5"):
            decorator = require_version(version)
            self.assertEqual(
                func,
                decorator(func),
                "%s <= %s : function returned by the decorator should be the same."
                % (version, current),
            )

    def test_require_version_bad(self):
        """should return a different function : skipping test"""

        def func():
            pass

        sys.version_info = (2, 5, 5, "final", 4)
        current = ".".join(str(element) for element in sys.version_info[:3])
        for version in ("2.5.6", "2.6", "2.6.5"):
            decorator = require_version(version)
            # here the running version is *older* than the required one
            self.assertNotEqual(
                func,
                decorator(func),
                "%s < %s : function returned by the decorator should NOT be the same."
                % (current, version),
            )

    def test_require_version_exception(self):
        """should throw a ValueError exception"""

        def func():
            pass

        compare = ("2.5.a", "2.a", "azerty")
        for version in compare:
            decorator = require_version(version)
            self.assertRaises(ValueError, decorator, func)

    def test_require_module_good(self):
        """should return the same function"""

        def func():
            pass

        module = "sys"
        decorator = require_module(module)
        self.assertEqual(
            func,
            decorator(func),
            "module %s exists : function returned by the decorator should be the same."
            % module,
        )

    def test_require_module_bad(self):
        """should return a different function : skipping test"""

        def func():
            pass

        modules = ("bla", "blo", "bli")
        # find a module name that really is missing on this system
        for module in modules:
            try:
                __import__(module)
            except ImportError:
                break
        else:
            # report the situation as a skip instead of silently passing
            self.skipTest(
                "all modules in %s exist. Could not test require_module"
                % ", ".join(modules)
            )
        decorator = require_module(module)
        self.assertNotEqual(
            func,
            decorator(func),
            "module %s does not exist : function returned by the decorator "
            "should NOT be the same." % module,
        )
815
816
class TagTC(TestCase):
    """Exercise the ``tag`` decorator, ``Tags`` objects and the runner's
    tag-based test selection."""

    def setUp(self):
        @tag("testing", "bob")
        def bob(a, b, c):
            return (a + b) * c

        self.func = bob

        class TagTestTC(TestCase):
            tags = Tags("one", "two")

            def test_one(self):
                self.assertTrue(True)

            @tag("two", "three")
            def test_two(self):
                self.assertTrue(True)

            @tag("three", inherit=False)
            def test_three(self):
                self.assertTrue(True)

        self.cls = TagTestTC

    def test_tag_decorator(self):
        decorated = self.func
        # decoration keeps the function callable as before...
        self.assertEqual(decorated(2, 3, 7), 35)
        # ...and attaches the requested tags to it
        self.assertTrue(hasattr(decorated, "tags"))
        self.assertSetEqual(decorated.tags, {"testing", "bob"})

    def test_tags_class(self):
        func_tags = self.func.tags
        self.assertTrue(func_tags["testing"])
        self.assertFalse(func_tags["Not inside"])

    def test_tags_match(self):
        match = self.func.tags.match
        # single tag names
        self.assertTrue(match("testing"))
        self.assertFalse(match("other"))
        # boolean expressions over tag names
        self.assertFalse(match("testing and coin"))
        self.assertTrue(match("testing or other"))
        self.assertTrue(match("not other"))
        self.assertTrue(match("not other or (testing and bibi)"))
        self.assertTrue(match("other or (testing and bob)"))

    def test_tagged_class(self):
        def options(tags):
            class Options:
                tags_pattern = tags

            return Options()

        tc = self.cls("test_one")
        method_names = ("test_one", "test_two", "test_three")
        # (tags pattern, expected match of test_one, test_two, test_three);
        # a None pattern means the runner is built without any options
        expectations = [
            (None, (True, True, True)),
            ("one", (True, True, False)),
            ("two", (True, True, False)),
            ("three", (False, True, True)),
            ("two or three", (True, True, True)),
            ("two and three", (False, True, False)),
        ]
        for pattern, expected in expectations:
            if pattern is None:
                runner = SkipAwareTextTestRunner()
            else:
                runner = SkipAwareTextTestRunner(options=options(pattern))
            for name, should_match in zip(method_names, expected):
                check = self.assertTrue if should_match else self.assertFalse
                check(runner.does_match_tags(getattr(tc, name)))
906
907
if __name__ == "__main__":
    # Run this module's tests through logilab.common.testlib's unittest_main.
    unittest_main()
910