1__copyright__ = "Copyright (C) 2014-2016  Martin Blais"
2__license__ = "GNU GPLv2"
3
4import unittest
5import io
6import os
7import sys
8from os import path
9
10from beancount.utils import test_utils
11
12
class TestTestUtils(unittest.TestCase):
    """Tests for the standalone helper functions of test_utils."""

    def test_run_with_args(self):
        # run_with_args() should call the function with sys.argv replaced by
        # a runner filename followed by the given arguments.
        sentinel = []
        def main():
            # Snapshot argv: keep a copy rather than a reference to whatever
            # list object is installed while main() runs.
            sentinel.append(list(sys.argv))
        test_utils.run_with_args(main, ['a', 'b', 'c'])
        self.assertEqual(1, len(sentinel))
        sys_argv = sentinel[0]
        # Build the expected suffix with the platform separator instead of a
        # hard-coded '/' so the check is not POSIX-only.
        expected_suffix = path.join('beancount', 'utils', 'test_utils_test.py')
        self.assertTrue(sys_argv[0].endswith(expected_suffix))
        self.assertEqual(['a', 'b', 'c'], sys_argv[1:])

    def test_tempdir(self):
        # Everything created inside tempdir() must be gone once it exits.
        with test_utils.tempdir() as tempdir:
            with open(path.join(tempdir, 'file1'), 'w'):
                pass
            os.mkdir(path.join(tempdir, 'directory'))
            with open(path.join(tempdir, 'directory', 'file2'), 'w'):
                pass
        self.assertFalse(path.exists(tempdir))
        self.assertFalse(path.exists(path.join(tempdir, 'file1')))
        self.assertFalse(path.exists(path.join(tempdir, 'directory')))

    def test_create_temporary_files(self):
        with test_utils.tempdir() as tmp:
            test_utils.create_temporary_files(tmp, {
                'apples.beancount': """
                  include "{root}/fruits/oranges.beancount"

                  2014-01-01 open Assets:Apples
                """,
                'fruits/oranges.beancount': """
                  2014-01-02 open Assets:Oranges
                """})

            # Check the total list of files. Join the components separately
            # so the expected path uses the same separator os.walk() yields.
            apples = path.join(tmp, 'apples.beancount')
            oranges = path.join(tmp, 'fruits', 'oranges.beancount')
            self.assertEqual({apples, oranges},
                             {path.join(root, filename)
                              for root, _, files in os.walk(tmp)
                              for filename in files})

            # Check the contents of apples: the '{root}' placeholder must be
            # gone and the temporary root directory must appear instead.
            with open(apples) as f:
                apples_content = f.read()
            self.assertIn('open Assets:Apples', apples_content)
            self.assertNotIn('{root}', apples_content)
            self.assertIn(tmp, apples_content)

            # Check the contents of oranges.
            with open(oranges) as f:
                oranges_content = f.read()
            self.assertIn('open Assets:Oranges', oranges_content)

    def test_capture(self):
        text = "b9baaa0c-0f0a-47db-bffc-a00c6f4ac1db"
        with test_utils.capture() as output:
            self.assertIsInstance(output, io.StringIO)
            print(text)
        self.assertEqual(text + "\n", output.getvalue())

    @test_utils.docfile
    def test_docfile(self, filename):
        "7f9034b1-51e7-420c-ac6b-945b5c594ebf"
        # The docstring above is expected to be the content of the file.
        with open(filename) as f:
            contents = f.read()
        self.assertEqual("7f9034b1-51e7-420c-ac6b-945b5c594ebf", contents)

    @test_utils.docfile_extra(suffix='.txt')
    def test_docfile_extra(self, filename):
        "7f9034b1-51e7-420c-ac6b-945b5c594ebf"
        # Same as test_docfile, but the requested suffix must be honored.
        with open(filename) as f:
            contents = f.read()
        self.assertEqual("7f9034b1-51e7-420c-ac6b-945b5c594ebf", contents)
        self.assertTrue(filename.endswith('.txt'))

    def test_search_words(self):
        # Words must be found in order, as a string or as a list of words.
        self.assertTrue(test_utils.search_words(
            'i walrus is', 'i am the walrus is not chicago'))
        self.assertTrue(test_utils.search_words(
            'i walrus is'.split(), 'i am the walrus is not chicago'))
        # Out-of-order words must not match: there is no standalone 'i'
        # after 'walrus' in the line.
        self.assertFalse(test_utils.search_words(
            'walrus i', 'i am the walrus is not chicago'))

    def test_environ_contextmanager(self):
        original_path = os.getenv('PATH')
        with test_utils.environ('PATH', '/unlikely-to-be-your-path'):
            self.assertEqual('/unlikely-to-be-your-path', os.getenv('PATH'))
        # The previous value must be restored on exit, not merely changed.
        self.assertEqual(original_path, os.getenv('PATH'))
93
94
class TestTestCase(test_utils.TestCase):
    """Exercise the extra assertion helpers provided by test_utils.TestCase."""

    def test_assertLines(self):
        # Same lines, differing only in indentation and surrounding blank
        # lines: this test expects assertLines to accept them.
        expected = """
           43c62bff-8504-44ea-b5c0-afa218a7a973
           95ef1cc4-0016-4452-9f4e-1a053db2bc83
        """
        actual = """

             43c62bff-8504-44ea-b5c0-afa218a7a973
               95ef1cc4-0016-4452-9f4e-1a053db2bc83

        """
        self.assertLines(expected, actual)

        # Different content must be reported as an assertion failure.
        other_expected = """
               43c62bff-8504-44ea-b5c0-afa218a7a973
            """
        other_actual = """
                683f111f-f921-4db3-a3e8-daae344981e8
            """
        with self.assertRaises(AssertionError):
            self.assertLines(other_expected, other_actual)

    def test_assertOutput(self):
        marker = '3165efbc-c775-4503-be13-06b7167697a9'

        # Printing the expected text satisfies assertOutput.
        matching = """
           3165efbc-c775-4503-be13-06b7167697a9
        """
        with self.assertOutput(matching):
            print(marker)

        # Printing something else must make assertOutput fail.
        mismatched = """
               3165efbc-c775-4503-be13-06b7167697a9
            """
        with self.assertRaises(AssertionError), self.assertOutput(mismatched):
            print('78d58502a15e')
126
127
class TestSkipIfRaises(unittest.TestCase):
    """Tests for test_utils.skipIfRaises as a decorator and a context manager."""

    def test_decorator(self):
        @test_utils.skipIfRaises(ValueError)
        def decorator_no_skip():
            pass
        decorator_no_skip()

        @test_utils.skipIfRaises(ValueError)
        def decorator_skip():
            raise ValueError
        with self.assertRaises(unittest.SkipTest):
            decorator_skip()

    def test_decorator_many(self):
        @test_utils.skipIfRaises(ValueError, IndexError)
        def decorator_skip_first():
            raise ValueError
        with self.assertRaises(unittest.SkipTest):
            decorator_skip_first()

        # Every listed exception type must be converted, not just the first.
        @test_utils.skipIfRaises(ValueError, IndexError)
        def decorator_skip_second():
            raise IndexError
        with self.assertRaises(unittest.SkipTest):
            decorator_skip_second()

    def test_unlisted_exception_propagates(self):
        # Exceptions that are not listed must propagate unchanged; turning
        # them into skips would hide genuine test failures.
        @test_utils.skipIfRaises(ValueError, IndexError)
        def decorator_raise_other():
            raise KeyError
        with self.assertRaises(KeyError):
            decorator_raise_other()

        with self.assertRaises(KeyError):
            with test_utils.skipIfRaises(ValueError):
                raise KeyError

    def test_contextmanager(self):
        with test_utils.skipIfRaises(ValueError):
            pass

        with self.assertRaises(unittest.SkipTest):
            with test_utils.skipIfRaises(ValueError):
                raise ValueError
156
157
@test_utils.nottest
def test_not_really():
    # Presumably marked by nottest so test runners do not collect it despite
    # its test_ prefix; if it ever runs, fail with an explanatory message.
    assert False, "test_not_really should not be collected as a test"
161
162
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
165