"""Loader code. This is the main entry point to load up a file.
"""
__copyright__ = "Copyright (C) 2013-2016  Martin Blais"
__license__ = "GNU GPLv2"

from os import path
import collections
import functools
import glob
import hashlib
import importlib
import io
import itertools
import logging
import os
import pickle
import struct
import textwrap
import time
import warnings
from typing import Optional

from beancount.utils import misc_utils
from beancount.core import data
from beancount.parser import parser
from beancount.parser import booking
from beancount.parser import options
from beancount.parser import printer
from beancount.ops import validation
from beancount.utils import encryption
from beancount.utils import file_utils


LoadError = collections.namedtuple('LoadError', 'source message entry')


# List of default plugins to run.
DEFAULT_PLUGINS_PRE = [
    ("beancount.ops.pad", None),
    ("beancount.ops.documents", None),
    ]

DEFAULT_PLUGINS_POST = [
    ("beancount.ops.balance", None),
    ]

# A mapping of modules to warn about, to their renamed names.
RENAMED_MODULES = {}


# Filename pattern for the pickle-cache.
PICKLE_CACHE_FILENAME = '.{filename}.picklecache'

# The runtime threshold below which we don't bother creating a cache file, in
# seconds.
PICKLE_CACHE_THRESHOLD = 1.0


def load_file(filename, log_timings=None, log_errors=None, extra_validations=None,
              encoding=None):
    """Open a Beancount input file, parse it, run transformations and validate.

    Args:
      filename: The name of the file to be parsed.
      log_timings: A file object or function to write timings to,
        or None, if it should remain quiet.
      log_errors: A file object or function to write errors to,
        or None, if it should remain quiet.
      extra_validations: A list of extra validation functions to run after loading
        this list of entries.
      encoding: A string or None, the encoding to decode the input file with.
    Returns:
      A triple of (entries, errors, options_map) where "entries" is a date-sorted
      list of entries from the file, "errors" a list of error objects generated
      while parsing and validating the file, and "options_map", a dict of the
      options parsed from the file.
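
    Example (illustrative; the ledger path is hypothetical):

      entries, errors, options_map = load_file('/home/user/ledger.beancount',
                                               log_errors=print)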
    """
    filename = path.expandvars(path.expanduser(filename))
    if not path.isabs(filename):
        filename = path.normpath(path.join(os.getcwd(), filename))

    if encryption.is_encrypted_file(filename):
        # Note: Caching is not supported for encrypted files.
        entries, errors, options_map = load_encrypted_file(
            filename,
            log_timings, log_errors,
            extra_validations, False, encoding)
    else:
        entries, errors, options_map = _load_file(
            filename, log_timings,
            extra_validations, encoding)
        _log_errors(errors, log_errors)
    return entries, errors, options_map


def load_encrypted_file(filename, log_timings=None, log_errors=None,
                        extra_validations=None, dedent=False, encoding=None):
    """Load an encrypted Beancount input file.

    Args:
      filename: The name of an encrypted file to be parsed.
      log_timings: See load_string().
      log_errors: See load_string().
      extra_validations: See load_string().
      dedent: See load_string().
      encoding: See load_string().
    Returns:
      A triple of (entries, errors, options_map) where "entries" is a date-sorted
      list of entries from the file, "errors" a list of error objects generated
      while parsing and validating the file, and "options_map", a dict of the
      options parsed from the file.
    """
    contents = encryption.read_encrypted_file(filename)
    return load_string(contents,
                       log_timings=log_timings,
                       log_errors=log_errors,
                       extra_validations=extra_validations,
                       dedent=dedent,
                       encoding=encoding)


def _log_errors(errors, log_errors):
    """Log errors, if 'log_errors' is set.

    Args:
      errors: A list of error objects, as returned by the loader.
      log_errors: A file object or function to write errors to,
        or None, if it should remain quiet.
    """
    if log_errors and errors:
        if hasattr(log_errors, 'write'):
            printer.print_errors(errors, file=log_errors)
        else:
            error_io = io.StringIO()
            printer.print_errors(errors, file=error_io)
            log_errors(error_io.getvalue())


def get_cache_filename(pattern: str, filename: str) -> str:
    """Compute the cache filename from a given pattern and the top-level filename.

    Args:
      pattern: A cache filename or pattern. If the pattern contains '{filename}' this
        will get replaced by the top-level filename. This may be absolute or relative.
      filename: The top-level filename.
    Returns:
      The resolved cache filename.
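
    Example (illustrative paths):

      get_cache_filename('.{filename}.picklecache', '/accounting/main.beancount')
      # -> '/accounting/.main.beancount.picklecache'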
    """
    abs_filename = path.abspath(filename)
    if path.isabs(pattern):
        abs_pattern = pattern
    else:
        abs_pattern = path.join(path.dirname(abs_filename), pattern)
    return abs_pattern.format(filename=path.basename(filename))


def pickle_cache_function(cache_getter, time_threshold, function):
    """Decorate a loader function to make it load its result from a pickle cache.

    This considers the first argument as a top-level filename and assumes the
    function to be cached returns an (entries, errors, options_map) triple. We
    use the 'include' option value in order to check whether any of the included
    files has changed. It's essentially a special case of an on-disk memoizer.
    If any of the included files are more recent than the cache, the function is
    recomputed and the cache refreshed.

    Args:
      cache_getter: A function of one argument, the top-level filename, which
        will return the name of the corresponding cache file.
      time_threshold: A float, the number of seconds below which we don't bother
        caching.
      function: A function object to decorate for caching.
    Returns:
      A decorated function which will pull its result from a cache file if
      it is available.
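
    Example (illustrative; this mirrors what initialize() does below):

      cache_getter = functools.partial(get_cache_filename, PICKLE_CACHE_FILENAME)
      cached_load = pickle_cache_function(cache_getter, PICKLE_CACHE_THRESHOLD,
                                          _uncached_load_file)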
    """
    @functools.wraps(function)
    def wrapped(toplevel_filename, *args, **kw):
        cache_filename = cache_getter(toplevel_filename)

        # Read the cache if it exists in order to get the list of files whose
        # timestamps to check.
        exists = path.exists(cache_filename)
        if exists:
            with open(cache_filename, 'rb') as file:
                try:
                    result = pickle.load(file)
                except Exception as exc:
                    # Note: Not a big fan of doing this, but here we handle all
                    # possible exceptions because unpickling of an old or
                    # corrupted pickle file manifests as a variety of different
                    # exception types.

                    # The cache file is corrupted; ignore it and recompute.
                    logging.error("Cache file is corrupted: %s; recomputing.", exc)
                    result = None

                else:
                    # Check that the latest timestamp has not been written after the
                    # cache file.
                    entries, errors, options_map = result
                    if not needs_refresh(options_map):
                        # All timestamps are legit; cache hit.
                        return result

        # The cache is absent, corrupted or stale; recompute the value.
        if exists:
            try:
                os.remove(cache_filename)
            except OSError as exc:
                # Warn for errors on read-only filesystems.
                logging.warning("Could not remove picklecache file %s: %s",
                                cache_filename, exc)

        time_before = time.time()
        result = function(toplevel_filename, *args, **kw)
        time_after = time.time()

        # Overwrite the cache file if the time it takes to compute it
        # justifies it.
        if time_after - time_before > time_threshold:
            try:
                with open(cache_filename, 'wb') as file:
                    pickle.dump(result, file)
            except Exception as exc:
                logging.warning("Could not write to picklecache file %s: %s",
                                cache_filename, exc)

        return result
    return wrapped


def delete_cache_function(cache_getter, function):
    """Decorate a loader function to delete its cache file before loading.

    Args:
      cache_getter: A function of one argument, the top-level filename, which
        will return the name of the corresponding cache file.
      function: A function object to decorate.
    Returns:
      A decorated function which deletes the cache file, if it exists, before
      invoking the original function.
    """
    @functools.wraps(function)
    def wrapped(toplevel_filename, *args, **kw):
        # Delete the cache.
        cache_filename = cache_getter(toplevel_filename)
        if path.exists(cache_filename):
            os.remove(cache_filename)

        # Invoke the original function.
        return function(toplevel_filename, *args, **kw)
    return wrapped


def _uncached_load_file(filename, *args, **kw):
    """Delegate to _load. Note: This gets conditionally advised by caching below."""
    return _load([(filename, True)], *args, **kw)


def needs_refresh(options_map):
    """Predicate that returns true if at least one of the input files may have changed.

    Args:
      options_map: An options dict as per the parser.
    Returns:
      A boolean, true if the cached result is obsoleted by changes in the input files.
    """
    if options_map is None:
        return True
    input_hash = compute_input_hash(options_map['include'])
    return 'input_hash' not in options_map or input_hash != options_map['input_hash']


def compute_input_hash(filenames):
    """Compute a hash of the input data.

    Args:
      filenames: A list of input files. Order is not relevant.
    Returns:
      A string, the hexadecimal digest computed over the names, modification
      times and sizes of the input files.
    """
    md5 = hashlib.md5()
    for filename in sorted(filenames):
        md5.update(filename.encode('utf8'))
        if not path.exists(filename):
            continue
        stat = os.stat(filename)
        md5.update(struct.pack('dd', stat.st_mtime_ns, stat.st_size))
    return md5.hexdigest()


def load_string(string, log_timings=None, log_errors=None, extra_validations=None,
                dedent=False, encoding=None):
    """Load a Beancount input string: parse it, run transformations and validate.

    Args:
      string: A Beancount input string.
      log_timings: A file object or function to write timings to,
        or None, if it should remain quiet.
      log_errors: A file object or function to write errors to,
        or None, if it should remain quiet.
      extra_validations: A list of extra validation functions to run after loading
        this list of entries.
      dedent: A boolean, if set, remove the whitespace in front of the lines.
      encoding: A string or None, the encoding to decode the input string with.
    Returns:
      A triple of (entries, errors, options_map) where "entries" is a date-sorted
      list of entries from the string, "errors" a list of error objects
      generated while parsing and validating the string, and "options_map", a
      dict of the options parsed from the string.
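
    Example (illustrative input):

      entries, errors, options_map = load_string('''
          2020-01-01 open Assets:Checking
          2020-01-01 open Income:Salary
      ''', dedent=True)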
    """
    if dedent:
        string = textwrap.dedent(string)
    entries, errors, options_map = _load([(string, False)], log_timings,
                                         extra_validations, encoding)
    _log_errors(errors, log_errors)
    return entries, errors, options_map


def _parse_recursive(sources, log_timings, encoding=None):
    """Recursively parse Beancount input files or strings and their includes.

    Recursively parse a list of files or strings and their include files and
    return an aggregate of parsed directives, errors, and the top-level
    options-map. If the same file is being parsed twice, ignore it and issue an
    error.

    Args:
      sources: A list of (filename-or-string, is-filename) where the first
        element is a string, with either a filename or a string to be parsed directly,
        and the second argument is a boolean that is true if the first is a filename.
        You may provide a list of such arguments to be parsed. Filenames must be absolute
        paths.
      log_timings: A function to write timings to, or None, if it should remain quiet.
      encoding: A string or None, the encoding to decode the input files with.
    Returns:
      A tuple of (entries, parse_errors, options_map).
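
    For example (illustrative values), sources might be
    [('/ledger/main.beancount', True)] to parse a file, or
    [('2020-01-01 open Assets:Cash', False)] to parse a string directly.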
    """
    assert isinstance(sources, list) and all(isinstance(el, tuple) for el in sources)

    # Current parse state.
    entries, parse_errors = [], []
    options_map = None

    # A stack of sources to be parsed.
    source_stack = list(sources)

    # A list of absolute filenames that have been parsed in the past, used to
    # detect and avoid duplicates (cycles).
    filenames_seen = set()

    with misc_utils.log_time('beancount.parser.parser', log_timings, indent=1):
        while source_stack:
            source, is_file = source_stack.pop(0)
            is_top_level = options_map is None

            # If the file is encrypted, read it in and process it as a string.
            if is_file:
                cwd = path.dirname(source)
                source_filename = source
                if encryption.is_encrypted_file(source):
                    source = encryption.read_encrypted_file(source)
                    is_file = False
            else:
                # If we're parsing a string, the CWD is the current process
                # working directory.
                cwd = os.getcwd()
                source_filename = None

            if is_file:
                # All filenames here must be absolute.
                assert path.isabs(source)
                filename = path.normpath(source)

                # Check whether this file has been parsed before; detect duplicates.
                if filename in filenames_seen:
                    parse_errors.append(
                        LoadError(data.new_metadata("<load>", 0),
                                  'Duplicate filename parsed: "{}"'.format(filename),
                                  None))
                    continue

                # Check for a file that does not exist.
                if not path.exists(filename):
                    parse_errors.append(
                        LoadError(data.new_metadata("<load>", 0),
                                  'File "{}" does not exist'.format(filename), None))
                    continue

                # Parse a file from disk directly.
                filenames_seen.add(filename)
                with misc_utils.log_time('beancount.parser.parser.parse_file',
                                         log_timings, indent=2):
                    (src_entries,
                     src_errors,
                     src_options_map) = parser.parse_file(filename, encoding=encoding)

                cwd = path.dirname(filename)
            else:
                # Encode the contents if necessary.
                if encoding:
                    if isinstance(source, bytes):
                        source = source.decode(encoding)
                    source = source.encode('ascii', 'replace')

                # Parse a string buffer from memory.
                with misc_utils.log_time('beancount.parser.parser.parse_string',
                                         log_timings, indent=2):
                    (src_entries,
                     src_errors,
                     src_options_map) = parser.parse_string(source, source_filename)

            # Merge the entries resulting from the parsed file.
            entries.extend(src_entries)
            parse_errors.extend(src_errors)

            # We need the options from the very top file only (the very
            # first file being processed). No merging of options should
            # occur.
            if is_top_level:
                options_map = src_options_map
            else:
                aggregate_options_map(options_map, src_options_map)

            # Add includes to the list of sources to process. chdir() is needed
            # because glob() resolves relative patterns against the current
            # working directory.
            include_expanded = []
            with file_utils.chdir(cwd):
                for include_filename in src_options_map['include']:
                    matched_filenames = glob.glob(include_filename, recursive=True)
                    if matched_filenames:
                        include_expanded.extend(matched_filenames)
                    else:
                        parse_errors.append(
                            LoadError(data.new_metadata("<load>", 0),
                                      'File glob "{}" does not match any files'.format(
                                          include_filename), None))
            for include_filename in include_expanded:
                if not path.isabs(include_filename):
                    include_filename = path.join(cwd, include_filename)
                include_filename = path.normpath(include_filename)

                # Add the include filenames to be processed later.
                source_stack.append((include_filename, True))

    # Make sure we have at least a dict of valid options.
    if options_map is None:
        options_map = options.OPTIONS_DEFAULTS.copy()

    # Save the set of parsed filenames in options_map.
    options_map['include'] = sorted(filenames_seen)

    return entries, parse_errors, options_map


def aggregate_options_map(options_map, src_options_map):
    """Aggregate some of the attributes of options map.

    Args:
      options_map: The target map in which we want to aggregate attributes.
        Note: This value is mutated in-place.
      src_options_map: A source map whose values we'd like to see aggregated.
    """
    op_currencies = options_map["operating_currency"]
    for currency in src_options_map["operating_currency"]:
        if currency not in op_currencies:
            op_currencies.append(currency)


def _load(sources, log_timings, extra_validations, encoding):
    """Parse Beancount input, run its transformations and validate it.

    (This is an internal method.)
    This routine does all that is necessary to obtain a list of entries ready
    for realization and working with them. This is the principal call for all
    the scripts that load a ledger. It returns a list of entries transformed and
    ready for reporting, a list of errors, and the parser's options dict.

    Args:
      sources: A list of (filename-or-string, is-filename) where the first
        element is a string, with either a filename or a string to be parsed directly,
        and the second argument is a boolean that is true if the first is a filename.
        You may provide a list of such arguments to be parsed. Filenames must be absolute
        paths.
      log_timings: A file object or function to write timings to,
        or None, if it should remain quiet.
      extra_validations: A list of extra validation functions to run after loading
        this list of entries.
      encoding: A string or None, the encoding to decode the input files with.
    Returns:
      See load_file() or load_string().
    """
    assert isinstance(sources, list) and all(isinstance(el, tuple) for el in sources)

    if hasattr(log_timings, 'write'):
        log_timings = log_timings.write

    # Parse all the files recursively. Ensure that the entries are sorted before
    # running any processing on them.
    with misc_utils.log_time('parse', log_timings, indent=1):
        entries, parse_errors, options_map = _parse_recursive(
            sources, log_timings, encoding)
        entries.sort(key=data.entry_sortkey)

    # Run interpolation on incomplete entries.
    with misc_utils.log_time('booking', log_timings, indent=1):
        entries, balance_errors = booking.book(entries, options_map)
        parse_errors.extend(balance_errors)

    # Transform the entries.
    with misc_utils.log_time('run_transformations', log_timings, indent=1):
        entries, errors = run_transformations(entries, parse_errors, options_map,
                                              log_timings)

    # Validate the list of entries.
    with misc_utils.log_time('beancount.ops.validate', log_timings, indent=1):
        valid_errors = validation.validate(entries, options_map, log_timings,
                                           extra_validations)
        errors.extend(valid_errors)

        # Note: We could go hardcore here and further verify that the entries
        # haven't been modified by user-provided validation routines, by
        # comparing hashes before and after. Not needed for now.

    # Compute the input hash.
    options_map['input_hash'] = compute_input_hash(options_map['include'])

    return entries, errors, options_map


def run_transformations(entries, parse_errors, options_map, log_timings):
    """Run the various transformations on the entries.

    This is where entries are synthesized and checked, and where plugins are run.

    Args:
      entries: A list of directives as read from the parser.
      parse_errors: A list of errors so far.
      options_map: An options dict as read from the parser.
      log_timings: A function to write timing log entries to, or None, if it
        should be quiet.
    Returns:
      A list of modified entries, and a list of errors, also possibly modified.
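
    A plugin module is expected to define a '__plugins__' attribute listing its
    transformer functions. A minimal sketch of such a module (illustrative names,
    not an actual beancount plugin):

      __plugins__ = ('check_nothing',)

      def check_nothing(entries, options_map):
          errors = []
          return entries, errors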
    """
    # A list of errors to extend (make a copy to avoid modifying the input).
    errors = list(parse_errors)

    # Process the plugins.
    if options_map['plugin_processing_mode'] == 'raw':
        plugins_iter = options_map["plugin"]
    elif options_map['plugin_processing_mode'] == 'default':
        plugins_iter = itertools.chain(DEFAULT_PLUGINS_PRE,
                                       options_map["plugin"],
                                       DEFAULT_PLUGINS_POST)
    else:
        raise ValueError("Invalid value for plugin_processing_mode: {}".format(
            options_map['plugin_processing_mode']))

    for plugin_name, plugin_config in plugins_iter:

        # Issue a warning on a renamed module.
        renamed_name = RENAMED_MODULES.get(plugin_name, None)
        if renamed_name:
            warnings.warn("Deprecation notice: Module '{}' has been renamed to '{}'; "
                          "please adjust your plugin directive.".format(
                              plugin_name, renamed_name))
            plugin_name = renamed_name

        # Try to import the module.
        try:
            module = importlib.import_module(plugin_name)
            if not hasattr(module, '__plugins__'):
                continue

            with misc_utils.log_time(plugin_name, log_timings, indent=2):

                # Run each transformer function in the plugin.
                for function_name in module.__plugins__:
                    if isinstance(function_name, str):
                        # Support plugin functions provided by name.
                        callback = getattr(module, function_name)
                    else:
                        # Support function types directly, not just names.
                        callback = function_name

                    if plugin_config is not None:
                        entries, plugin_errors = callback(entries, options_map,
                                                          plugin_config)
                    else:
                        entries, plugin_errors = callback(entries, options_map)
                    errors.extend(plugin_errors)

            # Ensure that the entries are sorted. Don't trust the plugins
            # themselves.
            entries.sort(key=data.entry_sortkey)

        except (ImportError, TypeError) as exc:
            # Upon failure, just issue an error.
            errors.append(LoadError(data.new_metadata("<load>", 0),
                                    'Error importing "{}": {}'.format(
                                        plugin_name, str(exc)), None))

    return entries, errors


def combine_plugins(*plugin_modules):
    """Combine the plugins from the given plugin modules.

    This is used to create plugins of plugins.
    Args:
      *plugin_modules: A sequence of module objects.
    Returns:
      A list that can be assigned to the new module's __plugins__ attribute.
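
    Example (hypothetical module names):

      # In a module that aggregates two other plugin modules:
      from myplugins import plugin_a, plugin_b
      __plugins__ = combine_plugins(plugin_a, plugin_b)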
    """
    modules = []
    for module in plugin_modules:
        modules.extend([getattr(module, name)
                        for name in module.__plugins__])
    return modules


def load_doc(expect_errors=False):
    """A factory of decorators that loads the docstring and calls the function with entries.

    This is an incredibly convenient tool to write lots of tests. Write a
    unittest using the standard TestCase class and put the input entries in the
    function's docstring.

    Args:
      expect_errors: A boolean or None, with the following semantics,
        True: Expect errors and fail if there are none.
        False: Expect no errors and fail if there are some.
        None: Do nothing, no check.
    Returns:
      A wrapped method that accepts a single 'self' argument.
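
    Example (illustrative test case):

      class TestExample(unittest.TestCase):

          @load_doc()
          def test_simple(self, entries, errors, options_map):
              '''
              2020-01-01 open Assets:Checking
              '''
              self.assertEqual(1, len(entries))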
    """
    def decorator(fun):
        """A decorator that parses the function's docstring as an argument.

        Args:
          fun: A callable method that accepts the three return values that
              load_string() returns.
        Returns:
          A decorated test function.
        """
        @functools.wraps(fun)
        def wrapper(self):
            entries, errors, options_map = load_string(fun.__doc__, dedent=True)

            if expect_errors is not None:
                if expect_errors is False and errors:
                    oss = io.StringIO()
                    printer.print_errors(errors, file=oss)
                    self.fail("Unexpected errors found:\n{}".format(oss.getvalue()))
                elif expect_errors is True and not errors:
                    self.fail("Expected errors, none found:")

            # Note: Even if we expected no errors, we call this function with an
            # empty 'errors' list. This is so that the interface does not change
            # based on the arguments to the decorator, which would be somewhat
            # ugly and which would require explanation.
            return fun(self, entries, errors, options_map)

        wrapper.__input__ = wrapper.__doc__
        wrapper.__doc__ = None
        return wrapper

    return decorator


def initialize(use_cache: bool, cache_filename: Optional[str] = None):
    """Initialize the loader.

    Args:
      use_cache: A boolean, true if the loader should cache loaded files in a
        pickle file on disk.
      cache_filename: An optional string, a cache filename or pattern that
        overrides the default (see get_cache_filename()).
    """

    # Unless an environment variable disables it, use the pickle load cache
    # automatically. Note that this works across all Python programs running
    # the loader, which is why it's located here.
    # pylint: disable=invalid-name
    global _load_file

    # Make a function to compute the cache filename.
    cache_pattern = (cache_filename or
                     os.getenv('BEANCOUNT_LOAD_CACHE_FILENAME') or
                     PICKLE_CACHE_FILENAME)
    cache_getter = functools.partial(get_cache_filename, cache_pattern)

    if use_cache:
        _load_file = pickle_cache_function(cache_getter, PICKLE_CACHE_THRESHOLD,
                                           _uncached_load_file)
    else:
        if cache_filename is not None:
            logging.warning("Cache disabled; "
                            "explicitly overridden cache filename %s will be ignored.",
                            cache_filename)
        _load_file = delete_cache_function(cache_getter,
                                           _uncached_load_file)
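

# A program embedding the loader can call initialize() again to relocate or
# disable the cache; for example (illustrative path, not part of this module):
#
#   initialize(use_cache=True, cache_filename='/tmp/ledger.picklecache')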


# Default is to use the cache every time.
initialize(os.getenv('BEANCOUNT_DISABLE_LOAD_CACHE') is None)