1__copyright__ = "Copyright (C) 2014-2016  Martin Blais"
2__license__ = "GNU GPLv2"
3
4import functools
5import logging
6import unittest
7import tempfile
8import textwrap
9import os
10from unittest import mock
11from os import path
12
13from beancount import loader
14from beancount.parser import parser
15from beancount.utils import test_utils
16from beancount.utils import encryption_test
17
18
# Small, valid ledger used as loader input by several tests in this module:
# two accounts opened on 2014-01-01, one transaction moving 100.00 USD between
# them, and both accounts closed on 2015-01-01.
TEST_INPUT = """

2014-01-01 open Assets:MyBank:Checking   USD
2014-01-01 open Expenses:Restaurant   USD

2014-02-22 * "Something happened."
  Assets:MyBank:Checking       100.00 USD
  Expenses:Restaurant         -100.00 USD

2015-01-01 close Assets:MyBank:Checking
2015-01-01 close Expenses:Restaurant

"""
32
33
class TestLoader(unittest.TestCase):
    """Tests for the main loader entry points: files, strings and plugins."""

    def _assert_load_result(self, entries, errors, options_map):
        """Assert that a loader call returned a (list, list, dict) triple.

        Args:
          entries: The directives returned by the loader.
          errors: The errors returned by the loader.
          options_map: The options map returned by the loader.
        """
        self.assertIsInstance(entries, list)
        self.assertIsInstance(errors, list)
        self.assertIsInstance(options_map, dict)

    def test_run_transformations(self):
        # Test success case.
        entries, errors, options_map = parser.parse_string(TEST_INPUT)
        _, trans_errors = loader.run_transformations(
            entries, errors, options_map, None)
        self.assertEqual(0, len(trans_errors))

        # Test an invalid plugin name: it must yield exactly one error.
        entries, errors, options_map = parser.parse_string(
            'plugin "invalid.module.name"\n\n' + TEST_INPUT)
        _, trans_errors = loader.run_transformations(
            entries, errors, options_map, None)
        self.assertEqual(1, len(trans_errors))

    def test_load(self):
        with test_utils.capture():
            with tempfile.NamedTemporaryFile('w') as tmpfile:
                tmpfile.write(TEST_INPUT)
                # Flush so the written contents are visible when the loader
                # reopens the file by name.
                tmpfile.flush()
                self._assert_load_result(*loader.load_file(tmpfile.name))
                self._assert_load_result(
                    *loader.load_file(tmpfile.name, log_timings=logging.info))

    def test_load_string(self):
        with test_utils.capture():
            self._assert_load_result(*loader.load_string(TEST_INPUT))
            self._assert_load_result(
                *loader.load_string(TEST_INPUT, log_timings=logging.info))

    def test_load_nonexist(self):
        entries, errors, _ = loader.load_file('/some/bullshit/filename.beancount')
        self.assertEqual([], entries)
        self.assertTrue(errors)
        self.assertRegex(errors[0].message, 'does not exist')

    @mock.patch.dict(loader.RENAMED_MODULES,
                     {"beancount.ops.auto_accounts": "beancount.plugins.auto_accounts"},
                     clear=True)
    @mock.patch('warnings.warn')
    def test_renamed_plugin_warnings(self, warn):
        # A plugin referenced by a name listed in RENAMED_MODULES must load
        # without errors while still emitting a warning.
        with test_utils.capture('stderr'):
            _, errors, _ = loader.load_string("""
              plugin "beancount.ops.auto_accounts"
            """, dedent=True)
        self.assertTrue(warn.called)
        self.assertFalse(errors)
96
97
class TestLoadDoc(unittest.TestCase):
    """Tests for the loader.load_doc() decorator factory.

    The decorated test's docstring is the ledger input; the decorator loads it
    and calls the test with (entries, errors, options_map).
    """

    def test_load_doc(self):
        def test_function(self_, entries, errors, options_map):
            self.assertIsInstance(entries, list)
            self.assertIsInstance(errors, list)
            self.assertIsInstance(options_map, dict)

        test_function.__doc__ = TEST_INPUT
        # load_doc() is a decorator factory (it is used with parentheses on the
        # tests below), so it must be called first and its result applied to
        # the function. Passing the function directly would hand it to the
        # factory as its 'expect_errors' argument instead of decorating it, and
        # the checks above would not be exercised.
        test_function = loader.load_doc()(test_function)
        test_function(self)

    # pylint: disable=empty-docstring
    @loader.load_doc()
    def test_load_doc_empty(self, entries, errors, options_map):
        """
        """
        self.assertIsInstance(entries, list)
        self.assertIsInstance(errors, list)
        self.assertFalse(errors)
        self.assertIsInstance(options_map, dict)

    @loader.load_doc(expect_errors=True)
    def test_load_doc_plugin(self, entries, errors, options_map):
        """
        plugin "beancount.does.not.exist"
        """
        self.assertIsInstance(entries, list)
        self.assertIsInstance(options_map, dict)
        # Exactly one error is expected, of type loader.LoadError. (assertTrue
        # would treat the second argument as a message and always pass here.)
        self.assertEqual([loader.LoadError], list(map(type, errors)))

    def test_load_doc_plugin_auto_pythonpath(self):
        # With "insert_pythonpath" enabled, a plugin module sitting next to the
        # ledger file must load without errors.
        with tempfile.TemporaryDirectory() as tmpdir:
            ledger_fn = path.join(tmpdir, 'my.beancount')
            with open(ledger_fn, 'w') as ledger_file:
                ledger_file.write('option "insert_pythonpath" "TRUE"\n')
                ledger_file.write('plugin "localplugin"\n')

            plugin_fn = path.join(tmpdir, 'localplugin.py')
            with open(plugin_fn, 'w') as plugin_file:
                plugin_file.write(textwrap.dedent("""\
                  __plugins__ = ()
                """))
            entries, errors, options_map = loader.load_file(ledger_fn)
            self.assertIsInstance(entries, list)
            self.assertIsInstance(errors, list)
            self.assertIsInstance(options_map, dict)
            self.assertFalse(errors)
146
147
class TestLoadIncludes(unittest.TestCase):
    """Tests for resolving "include" directives while loading files.

    Besides entries and errors, each test checks the list of loaded files
    recorded in options_map['include'].
    """

    def test_load_file_no_includes(self):
        with test_utils.tempdir() as tmp:
            test_utils.create_temporary_files(tmp, {
                'apples.beancount': """
                  2014-01-01 open Assets:Apples
                """})
            _, errors, options_map = loader.load_file(
                path.join(tmp, 'apples.beancount'))
            self.assertEqual(0, len(errors))
            self.assertEqual(['apples.beancount'],
                             list(map(path.basename, options_map['include'])))

    def test_load_file_nonexist(self):
        # A missing top-level file yields one error and no recorded includes.
        _, errors, options_map = loader.load_file('/bull/bla/root.beancount')
        self.assertEqual(1, len(errors))
        self.assertRegex(errors[0].message, 'does not exist')
        self.assertEqual([], list(map(path.basename, options_map['include'])))

    def test_load_file_with_nonexist_include(self):
        with test_utils.tempdir() as tmp:
            test_utils.create_temporary_files(tmp, {
                'root.beancount': """
                  include "/some/file/that/does/not/exist.beancount"
                """})
            _, errors, options_map = loader.load_file(
                path.join(tmp, 'root.beancount'))
            self.assertEqual(1, len(errors))
            self.assertRegex(errors[0].message, 'does not (match any|exist)')
        # Only the top-level file is recorded; the missing include is not.
        self.assertEqual(['root.beancount'],
                         list(map(path.basename, options_map['include'])))

    def test_load_file_with_absolute_include(self):
        with test_utils.tempdir() as tmp:
            test_utils.create_temporary_files(tmp, {
                'apples.beancount': """
                  include "{root}/fruits/oranges.beancount"
                  2014-01-01 open Assets:Apples
                """,
                'fruits/oranges.beancount': """
                  2014-01-02 open Assets:Oranges
                """})
            entries, errors, options_map = loader.load_file(
                path.join(tmp, 'apples.beancount'))
        self.assertFalse(errors)
        self.assertEqual(2, len(entries))
        self.assertEqual(['apples.beancount', 'oranges.beancount'],
                         list(map(path.basename, options_map['include'])))

    def test_load_file_with_relative_include(self):
        with test_utils.tempdir() as tmp:
            test_utils.create_temporary_files(tmp, {
                'apples.beancount': """
                  include "fruits/oranges.beancount"
                  2014-01-01 open Assets:Apples
                """,
                'fruits/oranges.beancount': """
                  2014-01-02 open Assets:Oranges
                """})
            entries, errors, options_map = loader.load_file(
                path.join(tmp, 'apples.beancount'))
        self.assertFalse(errors)
        self.assertEqual(2, len(entries))
        self.assertEqual(['apples.beancount', 'oranges.beancount'],
                         list(map(path.basename, options_map['include'])))

    def test_load_file_with_multiple_includes(self):
        # Covers nested includes and a mix of relative and absolute paths.
        with test_utils.tempdir() as tmp:
            test_utils.create_temporary_files(tmp, {
                'apples.beancount': """
                  include "fruits/oranges.beancount"
                  include "{root}/legumes/patates.beancount"
                  2014-01-01 open Assets:Apples
                """,
                'fruits/oranges.beancount': """
                  include "../legumes/tomates.beancount"
                  2014-01-02 open Assets:Oranges
                """,
                'legumes/tomates.beancount': """
                  2014-01-03 open Assets:Tomates
                """,
                'legumes/patates.beancount': """
                  2014-01-04 open Assets:Patates
                """})
            entries, errors, options_map = loader.load_file(
                path.join(tmp, 'apples.beancount'))
        self.assertFalse(errors)
        self.assertEqual(4, len(entries))
        self.assertEqual(['apples.beancount', 'oranges.beancount',
                          'patates.beancount', 'tomates.beancount'],
                         list(map(path.basename, options_map['include'])))

    def test_load_file_with_duplicate_includes(self):
        # tomates.beancount is included twice: an error is reported, but its
        # entries are loaded and it is recorded only once.
        with test_utils.tempdir() as tmp:
            test_utils.create_temporary_files(tmp, {
                'apples.beancount': """
                  include "fruits/oranges.beancount"
                  include "{root}/legumes/tomates.beancount"
                  2014-01-01 open Assets:Apples
                """,
                'fruits/oranges.beancount': """
                  include "../legumes/tomates.beancount"
                  2014-01-02 open Assets:Oranges
                """,
                'legumes/tomates.beancount': """
                  2014-01-03 open Assets:Tomates
                """,
                'legumes/patates.beancount': """
                  2014-01-04 open Assets:Patates
                """})
            entries, errors, options_map = loader.load_file(
                path.join(tmp, 'apples.beancount'))
        self.assertTrue(errors)
        self.assertEqual(3, len(entries))
        self.assertEqual(['apples.beancount', 'oranges.beancount', 'tomates.beancount'],
                         list(map(path.basename, options_map['include'])))

    def test_load_string_with_relative_include(self):
        # NOTE(review): despite its name, this test calls load_file() with an
        # absolute path while the working directory is the temporary directory.
        with test_utils.tempdir() as tmp:
            test_utils.create_temporary_files(tmp, {
                'apples.beancount': """
                  include "fruits/oranges.beancount"
                  2014-01-01 open Assets:Apples
                """,
                'fruits/oranges.beancount': """
                  2014-01-02 open Assets:Oranges
                """})
            # Capture the working directory before the try block so that the
            # finally clause can never refer to an unbound name.
            cwd = os.getcwd()
            try:
                os.chdir(tmp)
                entries, errors, options_map = loader.load_file(
                    path.join(tmp, 'apples.beancount'))
            finally:
                # Restore the working directory before the tempdir is removed.
                os.chdir(cwd)
        self.assertFalse(errors)
        self.assertEqual(2, len(entries))
        self.assertEqual(['apples.beancount', 'oranges.beancount'],
                         list(map(path.basename, options_map['include'])))

    def test_load_file_return_include_filenames(self):
        # Also check that the recorded include filenames are absolute paths.
        # Note that the expected list is in sorted order, not load order.
        with test_utils.tempdir() as tmp:
            test_utils.create_temporary_files(tmp, {
                'apples.beancount': """
                  include "oranges.beancount"
                  2014-01-01 open Assets:Apples
                """,
                'oranges.beancount': """
                  include "bananas.beancount"
                  2014-01-02 open Assets:Oranges
                """,
                'bananas.beancount': """
                  2014-01-02 open Assets:Bananas
                """})
            entries, errors, options_map = loader.load_file(
                path.join(tmp, 'apples.beancount'))
        self.assertFalse(errors)
        self.assertEqual(3, len(entries))
        self.assertTrue(all(path.isabs(filename)
                            for filename in options_map['include']))
        self.assertEqual(['apples.beancount', 'bananas.beancount', 'oranges.beancount'],
                         list(map(path.basename, options_map['include'])))
312
313
class TestLoadIncludesEncrypted(encryption_test.TestEncryptedBase):
    """Loading a ledger whose include points at a GPG-encrypted file."""

    def test_include_encrypted(self):
        with test_utils.tempdir() as workdir:
            test_utils.create_temporary_files(workdir, {
                'apples.beancount': """
                  include "oranges.beancount.asc"
                  2014-01-01 open Assets:Apples
                """,
                'oranges.beancount': """
                  2014-01-02 open Assets:Oranges
                """})

            plain_name = path.join(workdir, 'oranges.beancount')
            cipher_name = path.join(workdir, 'oranges.beancount.asc')

            # Replace the plaintext include with an encrypted copy of it.
            with open(plain_name) as plain_file:
                self.encrypt_as_file(plain_file.read(), cipher_name)
            os.remove(plain_name)

            # The keyring directory must be visible to GPG while loading.
            with test_utils.environ('GNUPGHOME', self.ringdir):
                entries, errors, _ = loader.load_file(
                    path.join(workdir, 'apples.beancount'))

        self.assertFalse(errors)
        self.assertEqual(2, len(entries))
        first, second = entries[0], entries[1]
        self.assertRegex(first.meta['filename'], 'apples.beancount')
        self.assertRegex(second.meta['filename'], 'oranges.+count.asc')
342
343
class TestLoadCache(unittest.TestCase):
    """Tests for the pickle cache wrapped around the loader's file loading."""

    def setUp(self):
        # Number of times the underlying (uncached) load actually ran.
        self.num_calls = 0
        cache_getter = functools.partial(loader.get_cache_filename,
                                         loader.PICKLE_CACHE_FILENAME)
        # Route loader._load_file through a pickle cache around the counting
        # _load_file() below; addCleanup() removes exactly this patch.
        patcher = mock.patch('beancount.loader._load_file',
                             loader.pickle_cache_function(cache_getter,
                                                          0,  # No time threshold.
                                                          self._load_file))
        patcher.start()
        self.addCleanup(patcher.stop)

    def _load_file(self, filename, *args, **kw):
        """Load a single top-level file without caching, counting each call."""
        self.num_calls += 1
        return loader._load([(filename, True)], *args, **kw)

    @staticmethod
    def _append_newline(filename):
        """Append a newline to a file, updating its modification time."""
        with open(filename, 'a') as outfile:
            outfile.write('\n')

    def test_load_cache(self):
        # Create an initial set of files and load file, thus creating a cache.
        with test_utils.tempdir() as tmp:
            test_utils.create_temporary_files(tmp, {
                'apples.beancount': """
                  include "oranges.beancount"
                  2014-01-01 open Assets:Apples
                """,
                'oranges.beancount': """
                  include "bananas.beancount"
                  2014-01-02 open Assets:Oranges
                """,
                'bananas.beancount': """
                  2014-01-02 open Assets:Bananas
                """})
            top_filename = path.join(tmp, 'apples.beancount')
            entries, errors, _ = loader.load_file(top_filename)
            self.assertFalse(errors)
            self.assertEqual(3, len(entries))
            self.assertEqual(1, self.num_calls)

            # Make sure the cache was created.
            self.assertTrue(path.exists(path.join(tmp, '.apples.beancount.picklecache')))

            # Load the root file again, make sure the cache is being hit.
            loader.load_file(top_filename)
            self.assertEqual(1, self.num_calls)

            # Touch the top-level file and ensure it's a cache miss.
            self._append_newline(top_filename)
            loader.load_file(top_filename)
            self.assertEqual(2, self.num_calls)

            # Load the root file again, make sure the cache is being hit.
            loader.load_file(top_filename)
            self.assertEqual(2, self.num_calls)

            # Touch the top-level file and ensure it's a cache miss.
            self._append_newline(top_filename)
            loader.load_file(top_filename)
            self.assertEqual(3, self.num_calls)

    def test_load_cache_moved_file(self):
        # Create an initial set of files and load file, thus creating a cache.
        with test_utils.tempdir() as tmp:
            test_utils.create_temporary_files(tmp, {
                'apples.beancount': """
                  include "oranges.beancount"
                  2014-01-01 open Assets:Apples
                """,
                'oranges.beancount': """
                  2014-01-02 open Assets:Oranges
                """})
            top_filename = path.join(tmp, 'apples.beancount')
            entries, errors, options_map = loader.load_file(top_filename)
            self.assertFalse(errors)
            self.assertEqual(2, len(entries))
            self.assertEqual(1, self.num_calls)

            # Make sure the cache was created.
            self.assertTrue(path.exists(path.join(tmp, '.apples.beancount.picklecache')))

            # Check that it doesn't need refresh.
            self.assertFalse(loader.needs_refresh(options_map))

            # Move the input file.
            new_top_filename = path.join(tmp, 'bigapples.beancount')
            os.rename(top_filename, new_top_filename)

            # Check that it needs refresh.
            self.assertTrue(loader.needs_refresh(options_map))

            # Load the original (now missing) root file again: the stale cache
            # must not be used, so this is a cache miss.
            loader.load_file(top_filename)
            self.assertEqual(2, self.num_calls)

    @mock.patch('os.remove', side_effect=OSError)
    @mock.patch('logging.warning')
    def test_load_cache_read_only_fs(self, warn_mock, remove_mock):
        # Stacked mock.patch decorators pass their mocks bottom-up: the
        # 'logging.warning' mock comes first, then the 'os.remove' mock.
        # Create an initial set of files and load file, thus creating a cache.
        with test_utils.tempdir() as tmp:
            test_utils.create_temporary_files(tmp, {
                'apples.beancount': """
                  2014-01-01 open Assets:Apples
                """})
            filename = path.join(tmp, 'apples.beancount')
            loader.load_file(filename)
            # Rewrite the file so the cache is stale and must be evicted; the
            # eviction fails (os.remove raises) and a warning is logged.
            with open(filename, 'w'):
                pass
            loader.load_file(filename)
            self.assertTrue(remove_mock.called)
            self.assertEqual(1, len(warn_mock.mock_calls))

    @mock.patch('beancount.loader.PICKLE_CACHE_THRESHOLD', 0.0)
    @mock.patch.object(loader, 'load_file', loader.load_file)
    def test_load_cache_override_filename_pattern_by_env_var(self):
        with test_utils.environ('BEANCOUNT_LOAD_CACHE_FILENAME', '__{filename}__'):
            loader.initialize(use_cache=True)
            with test_utils.tempdir() as tmp:
                test_utils.create_temporary_files(tmp, {
                    'apples.beancount': """
                      2014-01-01 open Assets:Apples
                    """})
                filename = path.join(tmp, 'apples.beancount')
                loader.load_file(filename)
                self.assertEqual({'__apples.beancount__', 'apples.beancount'},
                                 set(os.listdir(tmp)))

    @mock.patch('beancount.loader.PICKLE_CACHE_THRESHOLD', 0.0)
    @mock.patch.object(loader, 'load_file', loader.load_file)
    def test_load_cache_override_filename_pattern_by_argument(self):
        with test_utils.tempdir() as tmp:
            cache_filename = path.join(tmp, "__{filename}__")
            loader.initialize(use_cache=True, cache_filename=cache_filename)
            test_utils.create_temporary_files(tmp, {
                'apples.beancount': """
                  2014-01-01 open Assets:Apples
                """})
            filename = path.join(tmp, 'apples.beancount')
            loader.load_file(filename)
            self.assertEqual({'__apples.beancount__', 'apples.beancount'},
                             set(os.listdir(tmp)))

    @mock.patch('beancount.loader.PICKLE_CACHE_THRESHOLD', 0.0)
    @mock.patch.object(loader, 'load_file', loader.load_file)
    def test_load_cache_disable(self):
        # With use_cache=False no cache file may appear, whether or not a
        # cache filename pattern is supplied.
        with test_utils.tempdir() as tmp:
            cache_filename = path.join(tmp, "__{filename}__")
            for kwargs in [dict(use_cache=False),
                           dict(use_cache=False, cache_filename=cache_filename)]:
                loader.initialize(**kwargs)
                test_utils.create_temporary_files(tmp, {
                    'apples.beancount': """
                      2014-01-01 open Assets:Apples
                    """})
                filename = path.join(tmp, 'apples.beancount')
                loader.load_file(filename)
                self.assertEqual({'apples.beancount'}, set(os.listdir(tmp)))
499
500
class TestEncoding(unittest.TestCase):
    """Loading ledger text given as bytes together with an explicit encoding."""

    def test_string_unicode(self):
        utf8_bytes = textwrap.dedent("""
          2015-01-01 open Assets:Something
          2015-05-23 note Assets:Something "¡¢£¤¥¦§¨©ª«¬®¯°±²³´µ¶·¸¹º»¼ "
        """).encode('utf-8')
        _, errors, _ = loader.load_string(utf8_bytes, encoding='utf8')
        self.assertFalse(errors)

    def test_string_latin1(self):
        # Same non-ASCII input, this time encoded as Latin-1.
        latin1_bytes = textwrap.dedent("""
          2015-01-01 open Assets:Something
          2015-05-23 note Assets:Something "¡¢£¤¥¦§¨©ª«¬®¯°±²³´µ¶·¸¹º»¼ "
        """).encode('latin1')
        _, errors, _ = loader.load_string(latin1_bytes, encoding='latin1')
        self.assertFalse(errors)
518
519
class TestOptionsAggregation(unittest.TestCase):
    """Options declared in included files are merged into one options map."""

    def test_aggregate_operating_currencies(self):
        with test_utils.tempdir() as tmp:
            test_utils.create_temporary_files(tmp, {
                'apples.beancount': """
                  include "oranges.beancount"
                  include "bananas.beancount"
                  option "operating_currency" "USD"
                """,
                'oranges.beancount': """
                  option "operating_currency" "CAD"
                """,
                'bananas.beancount': """
                  option "operating_currency" "EUR"
                """})
            _, _, options_map = loader.load_file(path.join(tmp, 'apples.beancount'))

            currencies = set(options_map['operating_currency'])
            self.assertEqual({'USD', 'EUR', 'CAD'}, currencies)
540
541
if __name__ == '__main__':
    # Run every test case in this module when executed as a script.
    unittest.main()
544