1"""Loader code. This is the main entry point to load up a file. 2""" 3__copyright__ = "Copyright (C) 2013-2016 Martin Blais" 4__license__ = "GNU GPLv2" 5 6from os import path 7import collections 8import functools 9import glob 10import hashlib 11import importlib 12import io 13import itertools 14import logging 15import os 16import pickle 17import struct 18import textwrap 19import time 20import warnings 21from typing import Optional 22 23from beancount.utils import misc_utils 24from beancount.core import data 25from beancount.parser import parser 26from beancount.parser import booking 27from beancount.parser import options 28from beancount.parser import printer 29from beancount.ops import validation 30from beancount.utils import encryption 31from beancount.utils import file_utils 32 33 34LoadError = collections.namedtuple('LoadError', 'source message entry') 35 36 37# List of default plugins to run. 38DEFAULT_PLUGINS_PRE = [ 39 ("beancount.ops.pad", None), 40 ("beancount.ops.documents", None), 41 ] 42 43DEFAULT_PLUGINS_POST = [ 44 ("beancount.ops.balance", None), 45 ] 46 47# A mapping of modules to warn about, to their renamed names. 48RENAMED_MODULES = {} 49 50 51# Filename pattern for the pickle-cache. 52PICKLE_CACHE_FILENAME = '.{filename}.picklecache' 53 54# The runtime threshold below which we don't bother creating a cache file, in 55# seconds. 56PICKLE_CACHE_THRESHOLD = 1.0 57 58 59def load_file(filename, log_timings=None, log_errors=None, extra_validations=None, 60 encoding=None): 61 """Open a Beancount input file, parse it, run transformations and validate. 62 63 Args: 64 filename: The name of the file to be parsed. 65 log_timings: A file object or function to write timings to, 66 or None, if it should remain quiet. 67 log_errors: A file object or function to write errors to, 68 or None, if it should remain quiet. 69 extra_validations: A list of extra validation functions to run after loading 70 this list of entries. 71 encoding: A string or None, the encoding to decode the input filename with. 72 Returns: 73 A triple of (entries, errors, option_map) where "entries" is a date-sorted 74 list of entries from the file, "errors" a list of error objects generated 75 while parsing and validating the file, and "options_map", a dict of the 76 options parsed from the file. 77 """ 78 filename = path.expandvars(path.expanduser(filename)) 79 if not path.isabs(filename): 80 filename = path.normpath(path.join(os.getcwd(), filename)) 81 82 if encryption.is_encrypted_file(filename): 83 # Note: Caching is not supported for encrypted files. 84 entries, errors, options_map = load_encrypted_file( 85 filename, 86 log_timings, log_errors, 87 extra_validations, False, encoding) 88 else: 89 entries, errors, options_map = _load_file( 90 filename, log_timings, 91 extra_validations, encoding) 92 _log_errors(errors, log_errors) 93 return entries, errors, options_map 94 95 96def load_encrypted_file(filename, log_timings=None, log_errors=None, extra_validations=None, 97 dedent=False, encoding=None): 98 """Load an encrypted Beancount input file. 99 100 Args: 101 filename: The name of an encrypted file to be parsed. 102 log_timings: See load_string(). 103 log_errors: See load_string(). 104 extra_validations: See load_string(). 105 dedent: See load_string(). 106 encoding: See load_string(). 


def load_encrypted_file(filename, log_timings=None, log_errors=None,
                        extra_validations=None, dedent=False, encoding=None):
    """Load an encrypted Beancount input file.

    Args:
      filename: The name of an encrypted file to be parsed.
      log_timings: See load_string().
      log_errors: See load_string().
      extra_validations: See load_string().
      dedent: See load_string().
      encoding: See load_string().
    Returns:
      A triple of (entries, errors, options_map) where "entries" is a date-sorted
      list of entries from the file, "errors" a list of error objects generated
      while parsing and validating the file, and "options_map" a dict of the
      options parsed from the file.
    """
    contents = encryption.read_encrypted_file(filename)
    return load_string(contents,
                       log_timings=log_timings,
                       log_errors=log_errors,
                       extra_validations=extra_validations,
                       dedent=dedent,
                       encoding=encoding)


def _log_errors(errors, log_errors):
    """Log errors, if 'log_errors' is set.

    Args:
      errors: A list of error objects to log.
      log_errors: A file object or function to write errors to,
        or None, if it should remain quiet.
    """
    if log_errors and errors:
        if hasattr(log_errors, 'write'):
            printer.print_errors(errors, file=log_errors)
        else:
            error_io = io.StringIO()
            printer.print_errors(errors, file=error_io)
            log_errors(error_io.getvalue())


def get_cache_filename(pattern: str, filename: str) -> str:
    """Compute the cache filename from a given pattern and the top-level filename.

    Args:
      pattern: A cache filename or pattern. If the pattern contains '{filename}' this
        will get replaced by the top-level filename. This may be absolute or relative.
      filename: The top-level filename.
    Returns:
      The resolved cache filename.
    """
    abs_filename = path.abspath(filename)
    if path.isabs(pattern):
        abs_pattern = pattern
    else:
        abs_pattern = path.join(path.dirname(abs_filename), pattern)
    return abs_pattern.format(filename=path.basename(filename))
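

# How cache filenames resolve (a sketch of assumed inputs and outputs): a
# relative pattern is anchored next to the top-level file, and '{filename}'
# expands to its basename.
#
#   get_cache_filename('.{filename}.picklecache', '/home/user/ledger.beancount')
#   # -> '/home/user/.ledger.beancount.picklecache'
#
#   get_cache_filename('/tmp/cache.{filename}', '/home/user/ledger.beancount')
#   # -> '/tmp/cache.ledger.beancount'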


def pickle_cache_function(cache_getter, time_threshold, function):
    """Decorate a loader function to make it load its result from a pickle cache.

    This considers the first argument as a top-level filename and assumes the
    function to be cached returns an (entries, errors, options_map) triple. We
    use the 'include' option value in order to check whether any of the included
    files has changed. It's essentially a special case for an on-disk memoizer.
    If any of the included files are more recent than the cache, the function is
    recomputed and the cache refreshed.

    Args:
      cache_getter: A function of one argument, the top-level filename, which
        will return the name of the corresponding cache file.
      time_threshold: A float, the number of seconds below which we don't bother
        caching.
      function: A function object to decorate for caching.
    Returns:
      A decorated function which will pull its result from a cache file if
      it is available.
    """
    @functools.wraps(function)
    def wrapped(toplevel_filename, *args, **kw):
        cache_filename = cache_getter(toplevel_filename)

        # Read the cache if it exists in order to get the list of files whose
        # timestamps to check.
        exists = path.exists(cache_filename)
        if exists:
            with open(cache_filename, 'rb') as file:
                try:
                    result = pickle.load(file)
                except Exception as exc:
                    # Note: Not a big fan of doing this, but here we handle all
                    # possible exceptions because unpickling of an old or
                    # corrupted pickle file manifests as a variety of different
                    # exception types.

                    # The cache file is corrupted; ignore it and recompute.
                    logging.error("Cache file is corrupted: %s; recomputing.", exc)
                    result = None
                else:
                    # Check that the latest timestamp has not been written after
                    # the cache file.
                    entries, errors, options_map = result
                    if not needs_refresh(options_map):
                        # All timestamps are legit; cache hit.
                        return result

        # We failed; recompute the value.
        if exists:
            try:
                os.remove(cache_filename)
            except OSError as exc:
                # Warn for errors on read-only filesystems.
                logging.warning("Could not remove picklecache file %s: %s",
                                cache_filename, exc)

        time_before = time.time()
        result = function(toplevel_filename, *args, **kw)
        time_after = time.time()

        # Overwrite the cache file if the time it takes to compute it
        # justifies it.
        if time_after - time_before > time_threshold:
            try:
                with open(cache_filename, 'wb') as file:
                    pickle.dump(result, file)
            except Exception as exc:
                logging.warning("Could not write to picklecache file %s: %s",
                                cache_filename, exc)

        return result
    return wrapped


def delete_cache_function(cache_getter, function):
    """Decorate a loader function to delete its cache file before running.

    Args:
      cache_getter: A function of one argument, the top-level filename, which
        will return the name of the corresponding cache file.
      function: A function object to decorate.
    Returns:
      A decorated function which will delete the cache file, if it exists,
      before invoking the original function.
    """
    @functools.wraps(function)
    def wrapped(toplevel_filename, *args, **kw):
        # Delete the cache.
        cache_filename = cache_getter(toplevel_filename)
        if path.exists(cache_filename):
            os.remove(cache_filename)

        # Invoke the original function.
        return function(toplevel_filename, *args, **kw)
    return wrapped


def _uncached_load_file(filename, *args, **kw):
    """Delegate to _load(). Note: This gets conditionally advised by the cache
    decorators below."""
    return _load([(filename, True)], *args, **kw)


def needs_refresh(options_map):
    """Predicate that returns true if at least one of the input files may have changed.

    Args:
      options_map: An options dict as per the parser.
    Returns:
      A boolean, true if the input is obsoleted by changes in the input files.
    """
    if options_map is None:
        return True
    input_hash = compute_input_hash(options_map['include'])
    return 'input_hash' not in options_map or input_hash != options_map['input_hash']


def compute_input_hash(filenames):
    """Compute a hash of the input data.

    Args:
      filenames: A list of input files. Order is not relevant.
    Returns:
      A hexadecimal string, a digest of the names, sizes and modification
      times of the given files.
    """
    md5 = hashlib.md5()
    for filename in sorted(filenames):
        md5.update(filename.encode('utf8'))
        if not path.exists(filename):
            continue
        stat = os.stat(filename)
        md5.update(struct.pack('dd', stat.st_mtime_ns, stat.st_size))
    return md5.hexdigest()
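

# A sketch of how staleness is detected: compute_input_hash() mixes each
# included file's name, size and mtime into one digest, which _load() stores
# under options_map['input_hash']; any touched, added or removed input then
# makes needs_refresh() return True. The path below is hypothetical:
#
#   p = '/home/user/ledger.beancount'
#   os.utime(p, ns=(0, 0))              # Force a known mtime.
#   h1 = compute_input_hash([p])
#   os.utime(p, ns=(10**9, 10**9))      # Bump the mtime by one second.
#   h2 = compute_input_hash([p])
#   assert h1 != h2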


def load_string(string, log_timings=None, log_errors=None, extra_validations=None,
                dedent=False, encoding=None):
    """Open a Beancount input string, parse it, run transformations and validate.

    Args:
      string: A Beancount input string.
      log_timings: A file object or function to write timings to,
        or None, if it should remain quiet.
      log_errors: A file object or function to write errors to,
        or None, if it should remain quiet.
      extra_validations: A list of extra validation functions to run after loading
        this list of entries.
      dedent: A boolean, if set, remove the whitespace in front of the lines.
      encoding: A string or None, the encoding to decode the input string with.
    Returns:
      A triple of (entries, errors, options_map) where "entries" is a date-sorted
      list of entries from the string, "errors" a list of error objects
      generated while parsing and validating the string, and "options_map" a
      dict of the options parsed from the string.
    """
    if dedent:
        string = textwrap.dedent(string)
    entries, errors, options_map = _load([(string, False)], log_timings,
                                         extra_validations, encoding)
    _log_errors(errors, log_errors)
    return entries, errors, options_map
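

# Example usage of load_string() (a minimal sketch); dedent=True is convenient
# in tests, where the input lives in an indented triple-quoted string:
#
#   entries, errors, options_map = load_string("""
#       2023-01-01 open Assets:Checking
#       2023-01-01 open Income:Salary
#   """, dedent=True)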


def _parse_recursive(sources, log_timings, encoding=None):
    """Recursively parse Beancount input files or strings.

    Recursively parse a list of files or strings and their include files and
    return an aggregate of parsed directives, errors, and the top-level
    options-map. If the same file is encountered twice, ignore it and issue an
    error.

    Args:
      sources: A list of (filename-or-string, is-filename) pairs, where the first
        element is a string, with either a filename or a string to be parsed directly,
        and the second element is a boolean that is true if the first is a filename.
        You may provide a list of such arguments to be parsed. Filenames must be
        absolute paths.
      log_timings: A function to write timings to, or None, if it should remain quiet.
      encoding: A string or None, the encoding to decode the input files with.
    Returns:
      A tuple of (entries, parse_errors, options_map).
    """
    assert isinstance(sources, list) and all(isinstance(el, tuple) for el in sources)

    # Current parse state.
    entries, parse_errors = [], []
    options_map = None

    # A stack of sources to be parsed.
    source_stack = list(sources)

    # A list of absolute filenames that have been parsed in the past, used to
    # detect and avoid duplicates (cycles).
    filenames_seen = set()

    with misc_utils.log_time('beancount.parser.parser', log_timings, indent=1):
        while source_stack:
            source, is_file = source_stack.pop(0)
            is_top_level = options_map is None

            # If the file is encrypted, read it in and process it as a string.
            if is_file:
                cwd = path.dirname(source)
                source_filename = source
                if encryption.is_encrypted_file(source):
                    source = encryption.read_encrypted_file(source)
                    is_file = False
            else:
                # If we're parsing a string, the CWD is the current process
                # working directory.
                cwd = os.getcwd()
                source_filename = None

            if is_file:
                # All filenames here must be absolute.
                assert path.isabs(source)
                filename = path.normpath(source)

                # Check for a file that has been previously parsed, to detect
                # duplicates (and avoid include cycles).
                if filename in filenames_seen:
                    parse_errors.append(
                        LoadError(data.new_metadata("<load>", 0),
                                  'Duplicate filename parsed: "{}"'.format(filename),
                                  None))
                    continue

                # Check for a file that does not exist.
                if not path.exists(filename):
                    parse_errors.append(
                        LoadError(data.new_metadata("<load>", 0),
                                  'File "{}" does not exist'.format(filename), None))
                    continue

                # Parse a file from disk directly.
                filenames_seen.add(filename)
                with misc_utils.log_time('beancount.parser.parser.parse_file',
                                         log_timings, indent=2):
                    (src_entries,
                     src_errors,
                     src_options_map) = parser.parse_file(filename, encoding=encoding)

                cwd = path.dirname(filename)
            else:
                # Encode the contents if necessary.
                if encoding:
                    if isinstance(source, bytes):
                        source = source.decode(encoding)
                    source = source.encode('ascii', 'replace')

                # Parse a string buffer from memory.
                with misc_utils.log_time('beancount.parser.parser.parse_string',
                                         log_timings, indent=2):
                    (src_entries,
                     src_errors,
                     src_options_map) = parser.parse_string(source, source_filename)

            # Merge the entries resulting from the parsed file.
            entries.extend(src_entries)
            parse_errors.extend(src_errors)

            # We need the options from the very top file only (the very
            # first file being processed). No merging of options should
            # occur.
            if is_top_level:
                options_map = src_options_map
            else:
                aggregate_options_map(options_map, src_options_map)

            # Add includes to the list of sources to process. chdir() for the
            # benefit of glob(), which resolves relative patterns against the
            # working directory.
            include_expanded = []
            with file_utils.chdir(cwd):
                for include_filename in src_options_map['include']:
                    matched_filenames = glob.glob(include_filename, recursive=True)
                    if matched_filenames:
                        include_expanded.extend(matched_filenames)
                    else:
                        parse_errors.append(
                            LoadError(data.new_metadata("<load>", 0),
                                      'File glob "{}" does not match any files'.format(
                                          include_filename), None))
            for include_filename in include_expanded:
                if not path.isabs(include_filename):
                    include_filename = path.join(cwd, include_filename)
                include_filename = path.normpath(include_filename)

                # Add the include filenames to be processed later.
                source_stack.append((include_filename, True))

    # Make sure we have at least a dict of valid options.
    if options_map is None:
        options_map = options.OPTIONS_DEFAULTS.copy()

    # Save the set of parsed filenames in options_map.
    options_map['include'] = sorted(filenames_seen)

    return entries, parse_errors, options_map


def aggregate_options_map(options_map, src_options_map):
    """Aggregate some of the attributes of options map.

    Args:
      options_map: The target map in which we want to aggregate attributes.
        Note: This value is mutated in-place.
      src_options_map: A source map whose values we'd like to see aggregated.
    """
    op_currencies = options_map["operating_currency"]
    for currency in src_options_map["operating_currency"]:
        if currency not in op_currencies:
            op_currencies.append(currency)
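

# The 'sources' argument consumed by _load() below (and by _parse_recursive()
# above) pairs each payload with a flag saying whether it is a filename. These
# are internal entry points; the sketch is for illustration only:
#
#   sources = [
#       ('/home/user/ledger.beancount', True),     # A file, as an absolute path.
#       ('2023-01-01 open Assets:Cash\n', False),  # A string, parsed directly.
#   ]
#   entries, errors, options_map = _load(sources, None, None, None)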


def _load(sources, log_timings, extra_validations, encoding):
    """Parse Beancount input, run its transformations and validate it.

    (This is an internal method.)
    This routine does all that is necessary to obtain a list of entries ready
    for realization and working with them. This is the principal call for all
    the scripts that load a ledger. It returns a list of entries transformed
    and ready for reporting, a list of errors, and the parser's options dict.

    Args:
      sources: A list of (filename-or-string, is-filename) pairs, where the first
        element is a string, with either a filename or a string to be parsed directly,
        and the second element is a boolean that is true if the first is a filename.
        You may provide a list of such arguments to be parsed. Filenames must be
        absolute paths.
      log_timings: A file object or function to write timings to,
        or None, if it should remain quiet.
      extra_validations: A list of extra validation functions to run after loading
        this list of entries.
      encoding: A string or None, the encoding to decode the input files with.
    Returns:
      See load_file() or load_string().
    """
    assert isinstance(sources, list) and all(isinstance(el, tuple) for el in sources)

    if hasattr(log_timings, 'write'):
        log_timings = log_timings.write

    # Parse all the files recursively. Ensure that the entries are sorted before
    # running any processes on them.
    with misc_utils.log_time('parse', log_timings, indent=1):
        entries, parse_errors, options_map = _parse_recursive(
            sources, log_timings, encoding)
        entries.sort(key=data.entry_sortkey)

    # Run interpolation on incomplete entries.
    with misc_utils.log_time('booking', log_timings, indent=1):
        entries, balance_errors = booking.book(entries, options_map)
        parse_errors.extend(balance_errors)

    # Transform the entries.
    with misc_utils.log_time('run_transformations', log_timings, indent=1):
        entries, errors = run_transformations(entries, parse_errors, options_map,
                                              log_timings)

    # Validate the list of entries.
    with misc_utils.log_time('beancount.ops.validate', log_timings, indent=1):
        valid_errors = validation.validate(entries, options_map, log_timings,
                                           extra_validations)
        errors.extend(valid_errors)

        # Note: We could go hardcore here and further verify that the entries
        # haven't been modified by user-provided validation routines, by
        # comparing hashes before and after. Not needed for now.

    # Compute the input hash.
    options_map['input_hash'] = compute_input_hash(options_map['include'])

    return entries, errors, options_map


def run_transformations(entries, parse_errors, options_map, log_timings):
    """Run the various transformations on the entries.

    This is where entries are being synthesized, checked, plugins are run, etc.

    Args:
      entries: A list of directives as read from the parser.
      parse_errors: A list of errors so far.
      options_map: An options dict as read from the parser.
      log_timings: A function to write timing log entries to, or None, if it
        should be quiet.
    Returns:
      A list of modified entries, and a list of errors, also possibly modified.
    """
    # A list of errors to extend (make a copy to avoid modifying the input).
    errors = list(parse_errors)

    # Process the plugins.
    if options_map['plugin_processing_mode'] == 'raw':
        plugins_iter = options_map["plugin"]
    elif options_map['plugin_processing_mode'] == 'default':
        plugins_iter = itertools.chain(DEFAULT_PLUGINS_PRE,
                                       options_map["plugin"],
                                       DEFAULT_PLUGINS_POST)
    else:
        raise ValueError("Invalid value for plugin_processing_mode: {}".format(
            options_map['plugin_processing_mode']))

    for plugin_name, plugin_config in plugins_iter:

        # Issue a warning on a renamed module.
        renamed_name = RENAMED_MODULES.get(plugin_name, None)
        if renamed_name:
            warnings.warn("Deprecation notice: Module '{}' has been renamed to '{}'; "
                          "please adjust your plugin directive.".format(
                              plugin_name, renamed_name))
            plugin_name = renamed_name

        # Try to import the module.
        try:
            module = importlib.import_module(plugin_name)
            if not hasattr(module, '__plugins__'):
                continue

            with misc_utils.log_time(plugin_name, log_timings, indent=2):

                # Run each transformer function in the plugin.
                for function_name in module.__plugins__:
                    if isinstance(function_name, str):
                        # Support plugin functions provided by name.
                        callback = getattr(module, function_name)
                    else:
                        # Support function types directly, not just names.
                        callback = function_name

                    if plugin_config is not None:
                        entries, plugin_errors = callback(entries, options_map,
                                                          plugin_config)
                    else:
                        entries, plugin_errors = callback(entries, options_map)
                    errors.extend(plugin_errors)

            # Ensure that the entries are sorted. Don't trust the plugins
            # themselves.
            entries.sort(key=data.entry_sortkey)

        except (ImportError, TypeError) as exc:
            # Upon failure, just issue an error.
            errors.append(LoadError(data.new_metadata("<load>", 0),
                                    'Error importing "{}": {}'.format(
                                        plugin_name, str(exc)), None))

    return entries, errors
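

# A sketch of the plugin protocol expected by run_transformations(): a module
# lists its transformer callables in __plugins__; each receives (entries,
# options_map), plus the configuration string when the plugin directive
# provides one, and returns an (entries, errors) pair. This module and
# function are hypothetical:
#
#   # my_plugins.py
#   __plugins__ = ('flag_big_transactions',)
#
#   def flag_big_transactions(entries, options_map, config=None):
#       errors = []
#       # Inspect or rewrite 'entries' here; always return the full list.
#       return entries, errors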


def combine_plugins(*plugin_modules):
    """Combine the plugins from the given plugin modules.

    This is used to create plugins of plugins.
    Args:
      *plugin_modules: A sequence of module objects.
    Returns:
      A list that can be assigned to the new module's __plugins__ attribute.
    """
    modules = []
    for module in plugin_modules:
        modules.extend([getattr(module, name)
                        for name in module.__plugins__])
    return modules


def load_doc(expect_errors=False):
    """A factory of decorators that loads the docstring and calls the function with entries.

    This is an incredibly convenient tool to write lots of tests. Write a
    unittest using the standard TestCase class and put the input entries in the
    function's docstring.

    Args:
      expect_errors: A boolean or None, with the following semantics,
        True: Expect errors and fail if there are none.
        False: Expect no errors and fail if there are some.
        None: Do nothing, no check.
    Returns:
      A wrapped method that accepts a single 'self' argument.
    """
    def decorator(fun):
        """A decorator that parses the function's docstring as an argument.

        Args:
          fun: A callable method, that accepts the three return arguments that
            load_string() returns.
        Returns:
          A decorated test function.
        """
        @functools.wraps(fun)
        def wrapper(self):
            entries, errors, options_map = load_string(fun.__doc__, dedent=True)

            if expect_errors is not None:
                if expect_errors is False and errors:
                    oss = io.StringIO()
                    printer.print_errors(errors, file=oss)
                    self.fail("Unexpected errors found:\n{}".format(oss.getvalue()))
                elif expect_errors is True and not errors:
                    self.fail("Expected errors, none found.")

            # Note: Even if we expected no errors, we call this function with an
            # empty 'errors' list. This is so that the interface does not change
            # based on the arguments to the decorator, which would be somewhat
            # ugly and which would require explanation.
            return fun(self, entries, errors, options_map)

        wrapper.__input__ = wrapper.__doc__
        wrapper.__doc__ = None
        return wrapper

    return decorator
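

# Example usage of load_doc() (a minimal sketch of a hypothetical test case;
# the input ledger lives in the test method's docstring):
#
#   class TestExample(unittest.TestCase):
#
#       @load_doc()
#       def test_simple(self, entries, errors, options_map):
#           """
#           2023-01-01 open Assets:Cash
#           """
#           self.assertEqual(1, len(entries))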


def initialize(use_cache: bool, cache_filename: Optional[str] = None):
    """Initialize the loader, installing the caching policy on _load_file.

    Args:
      use_cache: A boolean, true if the pickle cache should be used.
      cache_filename: An optional cache filename or pattern overriding the
        default (and the BEANCOUNT_LOAD_CACHE_FILENAME environment variable).
    """
    # Unless an environment variable disables it, use the pickle load cache
    # automatically. Note that this works across all Python programs running
    # the loader, which is why it's located here.
    # pylint: disable=invalid-name
    global _load_file

    # Make a function to compute the cache filename.
    cache_pattern = (cache_filename or
                     os.getenv('BEANCOUNT_LOAD_CACHE_FILENAME') or
                     PICKLE_CACHE_FILENAME)
    cache_getter = functools.partial(get_cache_filename, cache_pattern)

    if use_cache:
        _load_file = pickle_cache_function(cache_getter, PICKLE_CACHE_THRESHOLD,
                                           _uncached_load_file)
    else:
        if cache_filename is not None:
            logging.warning("Cache disabled; "
                            "explicitly overridden cache filename %s will be ignored.",
                            cache_filename)
        _load_file = delete_cache_function(cache_getter,
                                           _uncached_load_file)


# Default is to use the cache every time.
initialize(os.getenv('BEANCOUNT_DISABLE_LOAD_CACHE') is None)
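

# The cache can be controlled from the environment (a usage sketch;
# my_script.py is a hypothetical program using the loader):
#
#   BEANCOUNT_DISABLE_LOAD_CACHE=1 python my_script.py
#   BEANCOUNT_LOAD_CACHE_FILENAME='/tmp/{filename}.cache' python my_script.py
#
# or from code, by re-running initialize() before any loads:
#
#   from beancount import loader
#   loader.initialize(use_cache=False)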