1# Orca 2# Copyright (C) 2016 UrbanSim Inc. 3# See full license in LICENSE. 4 5from __future__ import print_function 6 7try: 8 from inspect import getfullargspec as getargspec 9except ImportError: 10 from inspect import getargspec 11import logging 12import time 13import warnings 14from collections import namedtuple 15try: 16 from collections.abc import Callable 17except ImportError: # Python 2.7 18 from collections import Callable 19from contextlib import contextmanager 20from functools import wraps 21 22 23import pandas as pd 24import tables 25import tlz as tz 26 27from . import utils 28from .utils.logutil import log_start_finish 29 30warnings.filterwarnings('ignore', category=tables.NaturalNameWarning) 31logger = logging.getLogger(__name__) 32 33_TABLES = {} 34_COLUMNS = {} 35_STEPS = {} 36_BROADCASTS = {} 37_INJECTABLES = {} 38 39_CACHING = True 40_TABLE_CACHE = {} 41_COLUMN_CACHE = {} 42_INJECTABLE_CACHE = {} 43_MEMOIZED = {} 44 45_CS_FOREVER = 'forever' 46_CS_ITER = 'iteration' 47_CS_STEP = 'step' 48 49CacheItem = namedtuple('CacheItem', ['name', 'value', 'scope']) 50 51 52def clear_all(): 53 """ 54 Clear any and all stored state from Orca. 55 56 """ 57 _TABLES.clear() 58 _COLUMNS.clear() 59 _STEPS.clear() 60 _BROADCASTS.clear() 61 _INJECTABLES.clear() 62 _TABLE_CACHE.clear() 63 _COLUMN_CACHE.clear() 64 _INJECTABLE_CACHE.clear() 65 for m in _MEMOIZED.values(): 66 m.value.clear_cached() 67 _MEMOIZED.clear() 68 logger.debug('pipeline state cleared') 69 70 71def clear_cache(scope=None): 72 """ 73 Clear all cached data. 74 75 Parameters 76 ---------- 77 scope : {None, 'step', 'iteration', 'forever'}, optional 78 Clear cached values with a given scope. 79 By default all cached values are removed. 
80 81 """ 82 if not scope: 83 _TABLE_CACHE.clear() 84 _COLUMN_CACHE.clear() 85 _INJECTABLE_CACHE.clear() 86 for m in _MEMOIZED.values(): 87 m.value.clear_cached() 88 logger.debug('pipeline cache cleared') 89 else: 90 for d in (_TABLE_CACHE, _COLUMN_CACHE, _INJECTABLE_CACHE): 91 items = tz.valfilter(lambda x: x.scope == scope, d) 92 for k in items: 93 del d[k] 94 for m in tz.filter(lambda x: x.scope == scope, _MEMOIZED.values()): 95 m.value.clear_cached() 96 logger.debug('cleared cached values with scope {!r}'.format(scope)) 97 98 99def clear_injectable(injectable_name): 100 """ 101 Clear the cached value of an injectable. *Added in Orca v1.6.* 102 103 Parameters 104 ---------- 105 name: str 106 Name of injectable to clear. 107 108 """ 109 _INJECTABLES[injectable_name].clear_cached() 110 111 112def clear_table(table_name): 113 """ 114 Clear the cached copy of an entire table. *Added in Orca v1.6.* 115 116 Parameters 117 ---------- 118 name: str 119 Name of table to clear. 120 121 """ 122 _TABLES[table_name].clear_cached() 123 124 125def clear_column(table_name, column_name): 126 """ 127 Clear the cached copy of a dynamically generated column. 128 *Added in Orca v1.6.* 129 130 Parameters 131 ---------- 132 table_name: str 133 Table containing the column to clear. 134 column_name: str 135 Name of the column to clear. 136 137 """ 138 _COLUMNS[(table_name, column_name)].clear_cached() 139 140 141def clear_columns(table_name, columns=None): 142 """ 143 Clear all (or a specified list) of the dynamically generated columns 144 associated with a table. *Added in Orca v1.6.* 145 146 Parameters 147 ---------- 148 table_name: str 149 Table name. 150 columns: list of str, optional, default None 151 List of columns to clear. If None, all extra/computed 152 columns in the table will be cleeared. 
153 154 """ 155 if columns is None: 156 tab = get_table(table_name) 157 cols = tab.columns 158 local_cols = tab.local_columns 159 columns = [c for c in cols if c not in local_cols] 160 print('****************') 161 print(columns) 162 163 for col in columns: 164 clear_column(table_name, col) 165 166 167def _update_scope(wrapper, new_scope=None): 168 """ 169 Update the cache scope for a wrapper (in place). 170 *Added in Orca v1.6.* 171 172 Parameters 173 ---------- 174 wrapper: object 175 Should be an instance of wrapper with attributes 176 `cache`, `cache_scope` and method `clear_cached`. 177 new_scope: str, optional default None 178 The new scope value. None implies no caching. 179 180 """ 181 # allowable scopes, values indicate the update granularity 182 scopes = { 183 None: 0, 184 _CS_STEP: 1, 185 _CS_ITER: 2, 186 _CS_FOREVER: 3 187 } 188 if new_scope not in scopes.keys(): 189 msg = '{} is not an allowed cache scope, '.format(new_scope) 190 msg += 'allowed scopes are {}'.format(list(scopes.keys())) 191 raise ValueError(msg) 192 193 # update the cache properties 194 curr_cache = wrapper.cache 195 curr_scope = wrapper.cache_scope 196 if new_scope is None: 197 # set to defaults, i.e. no caching 198 wrapper.cache = False 199 wrapper.cache_scope = _CS_FOREVER 200 else: 201 wrapper.cache = True 202 wrapper.cache_scope = new_scope 203 204 # clear out any existing caches if the provided scope is 205 # more granular than the existing 206 old_granularity = scopes[curr_scope] 207 new_granularity = scopes[new_scope] 208 if new_granularity < old_granularity: 209 wrapper.clear_cached() 210 211 212def update_injectable_scope(name, new_scope=None): 213 """ 214 Update the cache scope for a wrapped injectable function. 215 Clears the cache if the new scope is more granular 216 than the existing. *Added in Orca v1.6.* 217 218 Parameters 219 ---------- 220 name: str 221 Name of the injectable to update. 
222 new_scope: str, optional default None 223 Valid values: None, 'forever', 'iteration', 'step' 224 None implies no caching. 225 226 """ 227 _update_scope( 228 get_raw_injectable(name), new_scope) 229 230 231def update_column_scope(table_name, column_name, new_scope=None): 232 """ 233 Update the cache scope for a wrapped column function. Clears 234 the cache if the new scope is more granular than the existing. 235 *Added in Orca v1.6.* 236 237 Parameters 238 ---------- 239 table_name: str 240 Name of the table. 241 column_name: str 242 Name of the column to update. 243 new_scope: str, optional default None 244 Valid values: None, 'forever', 'iteration', 'step' 245 None implies no caching. 246 247 """ 248 _update_scope( 249 get_raw_column(table_name, column_name), new_scope) 250 251 252def update_table_scope(name, new_scope=None): 253 """ 254 Update the cache scope for a wrapped table function. Clears 255 the cache if the new scope is more granular than the existing. 256 *Added in Orca v1.6.* 257 258 Parameters 259 ---------- 260 name: str 261 Name of the table to update. 262 new_scope: str, optional default None 263 Valid values: None, 'forever', 'iteration', 'step' 264 None implies no caching. 265 266 """ 267 _update_scope( 268 get_raw_table(name), new_scope) 269 270 271def enable_cache(): 272 """ 273 Allow caching of registered variables that explicitly have 274 caching enabled. 275 276 """ 277 global _CACHING 278 _CACHING = True 279 280 281def disable_cache(): 282 """ 283 Turn off caching across Orca, even for registered variables 284 that have caching enabled. 285 286 """ 287 global _CACHING 288 _CACHING = False 289 290 291def cache_on(): 292 """ 293 Whether caching is currently enabled or disabled. 294 295 Returns 296 ------- 297 on : bool 298 True if caching is enabled. 
299 300 """ 301 return _CACHING 302 303 304@contextmanager 305def cache_disabled(): 306 turn_back_on = True if cache_on() else False 307 disable_cache() 308 309 yield 310 311 if turn_back_on: 312 enable_cache() 313 314 315# for errors that occur during Orca runs 316class OrcaError(Exception): 317 pass 318 319 320class DataFrameWrapper(object): 321 """ 322 Wraps a DataFrame so it can provide certain columns and handle 323 computed columns. 324 325 Parameters 326 ---------- 327 name : str 328 Name for the table. 329 frame : pandas.DataFrame 330 copy_col : bool, optional 331 Whether to return copies when evaluating columns. 332 333 Attributes 334 ---------- 335 name : str 336 Table name. 337 copy_col : bool 338 Whether to return copies when evaluating columns. 339 local : pandas.DataFrame 340 The wrapped DataFrame. 341 342 """ 343 def __init__(self, name, frame, copy_col=True): 344 self.name = name 345 self.local = frame 346 self.copy_col = copy_col 347 348 @property 349 def columns(self): 350 """ 351 Columns in this table. 352 353 """ 354 return self.local_columns + list_columns_for_table(self.name) 355 356 @property 357 def local_columns(self): 358 """ 359 Columns that are part of the wrapped DataFrame. 360 361 """ 362 return list(self.local.columns) 363 364 @property 365 def index(self): 366 """ 367 Table index. 368 369 """ 370 return self.local.index 371 372 def to_frame(self, columns=None): 373 """ 374 Make a DataFrame with the given columns. 375 376 Will always return a copy of the underlying table. 377 378 Parameters 379 ---------- 380 columns : sequence or string, optional 381 Sequence of the column names desired in the DataFrame. A string 382 can also be passed if only one column is desired. 383 If None all columns are returned, including registered columns. 
384 385 Returns 386 ------- 387 frame : pandas.DataFrame 388 389 """ 390 extra_cols = _columns_for_table(self.name) 391 392 if columns is not None: 393 columns = [columns] if isinstance(columns, str) else columns 394 columns = set(columns) 395 set_extra_cols = set(extra_cols) 396 local_cols = set(self.local.columns) & columns - set_extra_cols 397 df = self.local[list(local_cols)].copy() 398 extra_cols = {k: extra_cols[k] for k in (columns & set_extra_cols)} 399 else: 400 df = self.local.copy() 401 402 with log_start_finish( 403 'computing {!r} columns for table {!r}'.format( 404 len(extra_cols), self.name), 405 logger): 406 for name, col in extra_cols.items(): 407 with log_start_finish( 408 'computing column {!r} for table {!r}'.format( 409 name, self.name), 410 logger): 411 df[name] = col() 412 413 return df 414 415 def update_col(self, column_name, series): 416 """ 417 Add or replace a column in the underlying DataFrame. 418 419 Parameters 420 ---------- 421 column_name : str 422 Column to add or replace. 423 series : pandas.Series or sequence 424 Column data. 425 426 """ 427 logger.debug('updating column {!r} in table {!r}'.format( 428 column_name, self.name)) 429 self.local[column_name] = series 430 431 def __setitem__(self, key, value): 432 return self.update_col(key, value) 433 434 def get_column(self, column_name): 435 """ 436 Returns a column as a Series. 
437 438 Parameters 439 ---------- 440 column_name : str 441 442 Returns 443 ------- 444 column : pandas.Series 445 446 """ 447 with log_start_finish( 448 'getting single column {!r} from table {!r}'.format( 449 column_name, self.name), 450 logger): 451 extra_cols = _columns_for_table(self.name) 452 if column_name in extra_cols: 453 with log_start_finish( 454 'computing column {!r} for table {!r}'.format( 455 column_name, self.name), 456 logger): 457 column = extra_cols[column_name]() 458 else: 459 column = self.local[column_name] 460 if self.copy_col: 461 return column.copy() 462 else: 463 return column 464 465 def __getitem__(self, key): 466 return self.get_column(key) 467 468 def __getattr__(self, key): 469 return self.get_column(key) 470 471 def column_type(self, column_name): 472 """ 473 Report column type as one of 'local', 'series', or 'function'. 474 475 Parameters 476 ---------- 477 column_name : str 478 479 Returns 480 ------- 481 col_type : {'local', 'series', 'function'} 482 'local' means that the column is part of the registered table, 483 'series' means the column is a registered Pandas Series, 484 and 'function' means the column is a registered function providing 485 a Pandas Series. 486 487 """ 488 extra_cols = list_columns_for_table(self.name) 489 490 if column_name in extra_cols: 491 col = _COLUMNS[(self.name, column_name)] 492 493 if isinstance(col, _SeriesWrapper): 494 return 'series' 495 elif isinstance(col, _ColumnFuncWrapper): 496 return 'function' 497 498 elif column_name in self.local_columns: 499 return 'local' 500 501 raise KeyError('column {!r} not found'.format(column_name)) 502 503 def update_col_from_series(self, column_name, series, cast=False): 504 """ 505 Update existing values in a column from another series. 506 Index values must match in both column and series. Optionally 507 casts data type to match the existing column. 
508 509 Parameters 510 --------------- 511 column_name : str 512 series : panas.Series 513 cast: bool, optional, default False 514 """ 515 logger.debug('updating column {!r} in table {!r}'.format( 516 column_name, self.name)) 517 518 col_dtype = self.local[column_name].dtype 519 if series.dtype != col_dtype: 520 if cast: 521 series = series.astype(col_dtype) 522 else: 523 err_msg = "Data type mismatch, existing:{}, update:{}" 524 err_msg = err_msg.format(col_dtype, series.dtype) 525 raise ValueError(err_msg) 526 527 self.local.loc[series.index, column_name] = series 528 529 def __len__(self): 530 return len(self.local) 531 532 def clear_cached(self): 533 """ 534 Remove cached results from this table's computed columns. 535 536 """ 537 _TABLE_CACHE.pop(self.name, None) 538 for col in _columns_for_table(self.name).values(): 539 col.clear_cached() 540 logger.debug('cleared cached columns for table {!r}'.format(self.name)) 541 542 543class TableFuncWrapper(object): 544 """ 545 Wrap a function that provides a DataFrame. 546 547 Parameters 548 ---------- 549 name : str 550 Name for the table. 551 func : callable 552 Callable that returns a DataFrame. 553 cache : bool, optional 554 Whether to cache the results of calling the wrapped function. 555 cache_scope : {'step', 'iteration', 'forever'}, optional 556 Scope for which to cache data. Default is to cache forever 557 (or until manually cleared). 'iteration' caches data for each 558 complete iteration of the pipeline, 'step' caches data for 559 a single step of the pipeline. 560 copy_col : bool, optional 561 Whether to return copies when evaluating columns. 562 563 Attributes 564 ---------- 565 name : str 566 Table name. 567 cache : bool 568 Whether caching is enabled for this table. 569 copy_col : bool 570 Whether to return copies when evaluating columns. 

    """
    def __init__(
            self, name, func, cache=False, cache_scope=_CS_FOREVER,
            copy_col=True):
        self.name = name
        self._func = func
        self._argspec = getargspec(func)
        self.cache = cache
        self.cache_scope = cache_scope
        self.copy_col = copy_col
        # Metadata captured from the returned DataFrame the last time the
        # wrapped function was called; unknown until the first call.
        self._columns = []
        self._index = None
        self._len = 0

    @property
    def columns(self):
        """
        Columns in this table. (May contain only computed columns
        if the wrapped function has not been called yet.)

        """
        return self._columns + list_columns_for_table(self.name)

    @property
    def local_columns(self):
        """
        Only the columns contained in the DataFrame returned by the
        wrapped function. (No registered columns included.)

        """
        if self._columns:
            return self._columns
        else:
            # Column names are unknown until the wrapped function runs;
            # call it once to populate self._columns.
            self._call_func()
            return self._columns

    @property
    def index(self):
        """
        Index of the underlying table. Will be None if that index is
        unknown.

        """
        return self._index

    def _call_func(self):
        """
        Call the wrapped function and return the result wrapped by
        DataFrameWrapper.
        Also updates attributes like columns, index, and length.

        """
        # Serve the cached wrapper when global caching is on, this table
        # opted in, and a cached entry exists.
        if _CACHING and self.cache and self.name in _TABLE_CACHE:
            logger.debug('returning table {!r} from cache'.format(self.name))
            return _TABLE_CACHE[self.name].value

        with log_start_finish(
                'call function to get frame for table {!r}'.format(
                    self.name),
                logger):
            # Resolve the function's argument names / default expressions
            # against Orca's registered variables.
            kwargs = _collect_variables(names=self._argspec.args,
                                        expressions=self._argspec.defaults)
            frame = self._func(**kwargs)

        self._columns = list(frame.columns)
        self._index = frame.index
        self._len = len(frame)

        wrapped = DataFrameWrapper(self.name, frame, copy_col=self.copy_col)

        if self.cache:
            _TABLE_CACHE[self.name] = CacheItem(
                self.name, wrapped, self.cache_scope)

        return wrapped

    def __call__(self):
        return self._call_func()

    def to_frame(self, columns=None):
        """
        Make a DataFrame with the given columns.

        Will always return a copy of the underlying table.

        Parameters
        ----------
        columns : sequence, optional
            Sequence of the column names desired in the DataFrame.
            If None all columns are returned.

        Returns
        -------
        frame : pandas.DataFrame

        """
        return self._call_func().to_frame(columns)

    def get_column(self, column_name):
        """
        Returns a column as a Series.

        Parameters
        ----------
        column_name : str

        Returns
        -------
        column : pandas.Series

        """
        # NOTE(review): _call_func returns a DataFrameWrapper, which is
        # then passed as the `frame` of a second DataFrameWrapper here.
        # Column access works through the inner wrapper's __getitem__,
        # but confirm this re-wrapping is intentional.
        frame = self._call_func()
        return DataFrameWrapper(self.name, frame,
                                copy_col=self.copy_col).get_column(column_name)

    def __getitem__(self, key):
        return self.get_column(key)

    def __getattr__(self, key):
        return self.get_column(key)

    def __len__(self):
        # Zero until the wrapped function has been called at least once.
        return self._len

    def column_type(self, column_name):
        """
        Report column type as one of 'local', 'series', or 'function'.

        Parameters
        ----------
        column_name : str

        Returns
        -------
        col_type : {'local', 'series', 'function'}
            'local' means that the column is part of the registered table,
            'series' means the column is a registered Pandas Series,
            and 'function' means the column is a registered function providing
            a Pandas Series.

        """
        extra_cols = list_columns_for_table(self.name)

        if column_name in extra_cols:
            col = _COLUMNS[(self.name, column_name)]

            if isinstance(col, _SeriesWrapper):
                return 'series'
            elif isinstance(col, _ColumnFuncWrapper):
                return 'function'

        elif column_name in self.local_columns:
            return 'local'

        raise KeyError('column {!r} not found'.format(column_name))

    def clear_cached(self):
        """
        Remove this table's cached result and that of associated columns.

        """
        _TABLE_CACHE.pop(self.name, None)
        for col in _columns_for_table(self.name).values():
            col.clear_cached()
        logger.debug(
            'cleared cached result and cached columns for table {!r}'.format(
                self.name))

    def func_source_data(self):
        """
        Return data about the wrapped function source, including file name,
        line number, and source code.

        Returns
        -------
        filename : str
        lineno : int
            The line number on which the function starts.
        source : str

        """
        return utils.func_source_data(self._func)


class _ColumnFuncWrapper(object):
    """
    Wrap a function that returns a Series.

    Parameters
    ----------
    table_name : str
        Table with which the column will be associated.
    column_name : str
        Name for the column.
    func : callable
        Should return a Series that has an
        index matching the table to which it is being added.
    cache : bool, optional
        Whether to cache the result of calling the wrapped function.
    cache_scope : {'step', 'iteration', 'forever'}, optional
        Scope for which to cache data. Default is to cache forever
        (or until manually cleared). 'iteration' caches data for each
        complete iteration of the pipeline, 'step' caches data for
        a single step of the pipeline.

    Attributes
    ----------
    name : str
        Column name.
    table_name : str
        Name of table this column is associated with.
    cache : bool
        Whether caching is enabled for this column.

    """
    def __init__(
            self, table_name, column_name, func, cache=False,
            cache_scope=_CS_FOREVER):
        self.table_name = table_name
        self.name = column_name
        self._func = func
        self._argspec = getargspec(func)
        self.cache = cache
        self.cache_scope = cache_scope

    def __call__(self):
        """
        Evaluate the wrapped function and return the result.

        """
        # Serve from the column cache when global caching is on, this
        # column opted in, and a cached entry exists.
        if (_CACHING and
                self.cache and
                (self.table_name, self.name) in _COLUMN_CACHE):
            logger.debug(
                'returning column {!r} for table {!r} from cache'.format(
                    self.name, self.table_name))
            return _COLUMN_CACHE[(self.table_name, self.name)].value

        with log_start_finish(
                ('call function to provide column {!r} for table {!r}'
                 ).format(self.name, self.table_name), logger):
            # Resolve the function's argument names / default expressions
            # against Orca's registered variables.
            kwargs = _collect_variables(names=self._argspec.args,
                                        expressions=self._argspec.defaults)
            col = self._func(**kwargs)

        if self.cache:
            _COLUMN_CACHE[(self.table_name, self.name)] = CacheItem(
                (self.table_name, self.name), col, self.cache_scope)

        return col

    def clear_cached(self):
        """
        Remove any cached result of this column.

        """
        x = _COLUMN_CACHE.pop((self.table_name, self.name), None)
        if x is not None:
            logger.debug(
                'cleared cached value for column {!r} in table {!r}'.format(
                    self.name, self.table_name))

    def func_source_data(self):
        """
        Return data about the wrapped function source, including file name,
        line number, and source code.
838 839 Returns 840 ------- 841 filename : str 842 lineno : int 843 The line number on which the function starts. 844 source : str 845 846 """ 847 return utils.func_source_data(self._func) 848 849 850class _SeriesWrapper(object): 851 """ 852 Wrap a Series for the purpose of giving it the same interface as a 853 `_ColumnFuncWrapper`. 854 855 Parameters 856 ---------- 857 table_name : str 858 Table with which the column will be associated. 859 column_name : str 860 Name for the column. 861 series : pandas.Series 862 Series with index matching the table to which it is being added. 863 864 Attributes 865 ---------- 866 name : str 867 Column name. 868 table_name : str 869 Name of table this column is associated with. 870 871 """ 872 def __init__(self, table_name, column_name, series): 873 self.table_name = table_name 874 self.name = column_name 875 self._column = series 876 877 def __call__(self): 878 return self._column 879 880 def clear_cached(self): 881 """ 882 Here for compatibility with `_ColumnFuncWrapper`. 883 884 """ 885 pass 886 887 888class _InjectableFuncWrapper(object): 889 """ 890 Wraps a function that will provide an injectable value elsewhere. 891 892 Parameters 893 ---------- 894 name : str 895 func : callable 896 cache : bool, optional 897 Whether to cache the result of calling the wrapped function. 898 cache_scope : {'step', 'iteration', 'forever'}, optional 899 Scope for which to cache data. Default is to cache forever 900 (or until manually cleared). 'iteration' caches data for each 901 complete iteration of the pipeline, 'step' caches data for 902 a single step of the pipeline. 903 904 Attributes 905 ---------- 906 name : str 907 Name of this injectable. 908 cache : bool 909 Whether caching is enabled for this injectable function. 
910 911 """ 912 def __init__(self, name, func, cache=False, cache_scope=_CS_FOREVER): 913 self.name = name 914 self._func = func 915 self._argspec = getargspec(func) 916 self.cache = cache 917 self.cache_scope = cache_scope 918 919 def __call__(self): 920 if _CACHING and self.cache and self.name in _INJECTABLE_CACHE: 921 logger.debug( 922 'returning injectable {!r} from cache'.format(self.name)) 923 return _INJECTABLE_CACHE[self.name].value 924 925 with log_start_finish( 926 'call function to provide injectable {!r}'.format(self.name), 927 logger): 928 kwargs = _collect_variables(names=self._argspec.args, 929 expressions=self._argspec.defaults) 930 result = self._func(**kwargs) 931 932 if self.cache: 933 _INJECTABLE_CACHE[self.name] = CacheItem( 934 self.name, result, self.cache_scope) 935 936 return result 937 938 def clear_cached(self): 939 """ 940 Clear a cached result for this injectable. 941 942 """ 943 x = _INJECTABLE_CACHE.pop(self.name, None) 944 if x: 945 logger.debug( 946 'injectable {!r} removed from cache'.format(self.name)) 947 948 949class _StepFuncWrapper(object): 950 """ 951 Wrap a step function for argument matching. 952 953 Parameters 954 ---------- 955 step_name : str 956 func : callable 957 958 Attributes 959 ---------- 960 name : str 961 Name of step. 962 963 """ 964 def __init__(self, step_name, func): 965 self.name = step_name 966 self._func = func 967 self._argspec = getargspec(func) 968 969 def __call__(self): 970 with log_start_finish('calling step {!r}'.format(self.name), logger): 971 kwargs = _collect_variables(names=self._argspec.args, 972 expressions=self._argspec.defaults) 973 return self._func(**kwargs) 974 975 def _tables_used(self): 976 """ 977 Tables injected into the step. 

        Returns
        -------
        tables : set of str

        """
        args = list(self._argspec.args)
        if self._argspec.defaults:
            default_args = list(self._argspec.defaults)
        else:
            default_args = []
        # Combine names from argument names and argument default values.
        names = args[:len(args) - len(default_args)] + default_args
        tables = set()
        for name in names:
            # For expressions like 'table.column', the table is the part
            # before the first dot.
            parent_name = name.split('.')[0]
            if is_table(parent_name):
                tables.add(parent_name)
        return tables

    def func_source_data(self):
        """
        Return data about a step function's source, including file name,
        line number, and source code.

        Returns
        -------
        filename : str
        lineno : int
            The line number on which the function starts.
        source : str

        """
        return utils.func_source_data(self._func)


def is_table(name):
    """
    Returns whether a given name refers to a registered table.

    """
    return name in _TABLES


def list_tables():
    """
    List of table names.

    """
    return list(_TABLES.keys())


def list_columns():
    """
    List of (table name, registered column name) pairs.

    """
    return list(_COLUMNS.keys())


def list_steps():
    """
    List of registered step names.

    """
    return list(_STEPS.keys())


def list_injectables():
    """
    List of registered injectables.

    """
    return list(_INJECTABLES.keys())


def list_broadcasts():
    """
    List of registered broadcasts as (cast table name, onto table name).

    """
    return list(_BROADCASTS.keys())


def is_expression(name):
    """
    Checks whether a given name is a simple variable name or a compound
    variable expression.

    Parameters
    ----------
    name : str

    Returns
    -------
    is_expr : bool

    """
    return '.' in name


def _collect_variables(names, expressions=None):
    """
    Map labels and expressions to registered variables.

    Handles argument matching.

    Example:

        _collect_variables(names=['zones', 'zone_id'],
                           expressions=['parcels.zone_id'])

    Would return a dict representing:

        {'parcels': <DataFrameWrapper for zones>,
         'zone_id': <pandas.Series for parcels.zone_id>}

    Parameters
    ----------
    names : list of str
        List of registered variable names and/or labels.
        If mixing names and labels, labels must come at the end.
    expressions : list of str, optional
        List of registered variable expressions for labels defined
        at end of `names`. Length must match the number of labels.

    Returns
    -------
    variables : dict
        Keys match `names`. Values correspond to registered variables,
        which may be wrappers or evaluated functions if appropriate.

    """
    # Map registered variable labels to expressions.
    if not expressions:
        expressions = []
    # Names without a matching expression map to themselves.
    offset = len(names) - len(expressions)
    labels_map = dict(tz.concatv(
        zip(names[:offset], names[:offset]),
        zip(names[offset:], expressions)))

    all_variables = tz.merge(_INJECTABLES, _TABLES)
    variables = {}
    for label, expression in labels_map.items():
        # In the future, more registered variable expressions could be
        # supported. Currently supports names of registered variables
        # and references to table columns.
        if '.' in expression:
            # Registered variable expression refers to column.
            table_name, column_name = expression.split('.')
            table = get_table(table_name)
            variables[label] = table.get_column(column_name)
        else:
            thing = all_variables[expression]
            if isinstance(thing, (_InjectableFuncWrapper, TableFuncWrapper)):
                # Registered variable object is function.
                variables[label] = thing()
            else:
                variables[label] = thing

    return variables


def add_table(
        table_name, table, cache=False, cache_scope=_CS_FOREVER,
        copy_col=True):
    """
    Register a table with Orca.

    Parameters
    ----------
    table_name : str
        Should be globally unique to this table.
    table : pandas.DataFrame or function
        If a function, the function should return a DataFrame.
        The function's argument names and keyword argument values
        will be matched to registered variables when the function
        needs to be evaluated by Orca.
    cache : bool, optional
        Whether to cache the results of a provided callable. Does not
        apply if `table` is a DataFrame.
    cache_scope : {'step', 'iteration', 'forever'}, optional
        Scope for which to cache data. Default is to cache forever
        (or until manually cleared). 'iteration' caches data for each
        complete iteration of the pipeline, 'step' caches data for
        a single step of the pipeline.
    copy_col : bool, optional
        Whether to return copies when evaluating columns.

    Returns
    -------
    wrapped : `DataFrameWrapper` or `TableFuncWrapper`

    """
    if isinstance(table, Callable):
        table = TableFuncWrapper(table_name, table, cache=cache,
                                 cache_scope=cache_scope, copy_col=copy_col)
    else:
        table = DataFrameWrapper(table_name, table, copy_col=copy_col)

    # clear any cached data from a previously registered table
    table.clear_cached()

    logger.debug('registering table {!r}'.format(table_name))
    _TABLES[table_name] = table

    return table


def table(
        table_name=None, cache=False, cache_scope=_CS_FOREVER, copy_col=True):
    """
    Decorates functions that return DataFrames.

    Decorator version of `add_table`. Table name defaults to
    name of function.

    The function's argument names and keyword argument values
    will be matched to registered variables when the function
    needs to be evaluated by Orca.
    The argument name "iter_var" may be used to have the current
    iteration variable injected.

    """
    def decorator(func):
        if table_name:
            name = table_name
        else:
            name = func.__name__
        add_table(
            name, func, cache=cache, cache_scope=cache_scope,
            copy_col=copy_col)
        # The decorated function is returned unmodified; registration is
        # the only side effect.
        return func
    return decorator


def get_raw_table(table_name):
    """
    Get a wrapped table by name and don't do anything to it.

    Parameters
    ----------
    table_name : str

    Returns
    -------
    table : DataFrameWrapper or TableFuncWrapper

    """
    if is_table(table_name):
        return _TABLES[table_name]
    else:
        raise KeyError('table not found: {}'.format(table_name))


def get_table(table_name):
    """
    Get a registered table.

    Decorated functions will be converted to `DataFrameWrapper`.

    Parameters
    ----------
    table_name : str

    Returns
    -------
    table : `DataFrameWrapper`

    """
    table = get_raw_table(table_name)
    if isinstance(table, TableFuncWrapper):
        # Calling the wrapper evaluates the function and returns a
        # DataFrameWrapper (possibly from cache).
        table = table()
    return table


def table_type(table_name):
    """
    Returns the type of a registered table.

    The type can be either "dataframe" or "function".

    Parameters
    ----------
    table_name : str

    Returns
    -------
    table_type : {'dataframe', 'function'}

    """
    table = get_raw_table(table_name)

    if isinstance(table, DataFrameWrapper):
        return 'dataframe'
    elif isinstance(table, TableFuncWrapper):
        return 'function'


def add_column(
        table_name, column_name, column, cache=False, cache_scope=_CS_FOREVER):
    """
    Add a new column to a table from a Series or callable.

    Parameters
    ----------
    table_name : str
        Table with which the column will be associated.
    column_name : str
        Name for the column.
    column : pandas.Series or callable
        Series should have an index matching the table to which it
        is being added. If a callable, the function's argument
        names and keyword argument values will be matched to
        registered variables when the function needs to be
        evaluated by Orca. The function should return a Series.
    cache : bool, optional
        Whether to cache the results of a provided callable. Does not
        apply if `column` is a Series.
    cache_scope : {'step', 'iteration', 'forever'}, optional
        Scope for which to cache data. Default is to cache forever
        (or until manually cleared). 'iteration' caches data for each
        complete iteration of the pipeline, 'step' caches data for
        a single step of the pipeline.

    """
    # Callables are wrapped for lazy evaluation; Series are wrapped
    # directly (cache/cache_scope do not apply to them).
    if isinstance(column, Callable):
        column = \
            _ColumnFuncWrapper(
                table_name, column_name, column,
                cache=cache, cache_scope=cache_scope)
    else:
        column = _SeriesWrapper(table_name, column_name, column)

    # clear any cached data from a previously registered column
    column.clear_cached()

    logger.debug('registering column {!r} on table {!r}'.format(
        column_name, table_name))
    _COLUMNS[(table_name, column_name)] = column

    return column


def column(table_name, column_name=None, cache=False, cache_scope=_CS_FOREVER):
    """
    Decorates functions that return a Series.

    Decorator version of `add_column`. Series index must match
    the named table. Column name defaults to name of function.

    The function's argument names and keyword argument values
    will be matched to registered variables when the function
    needs to be evaluated by Orca.
    The argument name "iter_var" may be used to have the current
    iteration variable injected.
    The index of the returned Series must match the named table.

    """
    def decorator(func):
        # Fall back to the decorated function's name when no explicit
        # column_name was given.
        if column_name:
            name = column_name
        else:
            name = func.__name__
        add_column(
            table_name, name, func, cache=cache, cache_scope=cache_scope)
        return func
    return decorator


def list_columns_for_table(table_name):
    """
    Return a list of all the extra columns registered for a given table.

    Parameters
    ----------
    table_name : str

    Returns
    -------
    columns : list of str

    """
    return [cname for tname, cname in _COLUMNS.keys() if tname == table_name]


def _columns_for_table(table_name):
    """
    Return all of the columns registered for a given table.

    Parameters
    ----------
    table_name : str

    Returns
    -------
    columns : dict of column wrappers
        Keys will be column names.

    """
    return {cname: col
            for (tname, cname), col in _COLUMNS.items()
            if tname == table_name}


def column_map(tables, columns):
    """
    Take a list of tables and a list of column names and resolve which
    columns come from which table.

    Parameters
    ----------
    tables : sequence of _DataFrameWrapper or _TableFuncWrapper
        Could also be sequence of modified pandas.DataFrames, the important
        thing is that they have ``.name`` and ``.columns`` attributes.
    columns : sequence of str
        The column names of interest.

    Returns
    -------
    col_map : dict
        Maps table names to lists of column names.
1400 """ 1401 if not columns: 1402 return {t.name: None for t in tables} 1403 1404 columns = set(columns) 1405 colmap = { 1406 t.name: list(set(t.columns).intersection(columns)) for t in tables} 1407 foundcols = tz.reduce( 1408 lambda x, y: x.union(y), (set(v) for v in colmap.values())) 1409 if foundcols != columns: 1410 raise RuntimeError('Not all required columns were found. ' 1411 'Missing: {}'.format(list(columns - foundcols))) 1412 return colmap 1413 1414 1415def get_raw_column(table_name, column_name): 1416 """ 1417 Get a wrapped, registered column. 1418 1419 This function cannot return columns that are part of wrapped 1420 DataFrames, it's only for columns registered directly through Orca. 1421 1422 Parameters 1423 ---------- 1424 table_name : str 1425 column_name : str 1426 1427 Returns 1428 ------- 1429 wrapped : _SeriesWrapper or _ColumnFuncWrapper 1430 1431 """ 1432 try: 1433 return _COLUMNS[(table_name, column_name)] 1434 except KeyError: 1435 raise KeyError('column {!r} not found for table {!r}'.format( 1436 column_name, table_name)) 1437 1438 1439def _memoize_function(f, name, cache_scope=_CS_FOREVER): 1440 """ 1441 Wraps a function for memoization and ties it's cache into the 1442 Orca cacheing system. 1443 1444 Parameters 1445 ---------- 1446 f : function 1447 name : str 1448 Name of injectable. 1449 cache_scope : {'step', 'iteration', 'forever'}, optional 1450 Scope for which to cache data. Default is to cache forever 1451 (or until manually cleared). 'iteration' caches data for each 1452 complete iteration of the pipeline, 'step' caches data for 1453 a single step of the pipeline. 
1454 1455 """ 1456 cache = {} 1457 1458 @wraps(f) 1459 def wrapper(*args, **kwargs): 1460 try: 1461 cache_key = ( 1462 args or None, frozenset(kwargs.items()) if kwargs else None) 1463 in_cache = cache_key in cache 1464 except TypeError: 1465 raise TypeError( 1466 'function arguments must be hashable for memoization') 1467 1468 if _CACHING and in_cache: 1469 return cache[cache_key] 1470 else: 1471 result = f(*args, **kwargs) 1472 cache[cache_key] = result 1473 return result 1474 1475 wrapper.__wrapped__ = f 1476 wrapper.cache = cache 1477 wrapper.clear_cached = lambda: cache.clear() 1478 _MEMOIZED[name] = CacheItem(name, wrapper, cache_scope) 1479 1480 return wrapper 1481 1482 1483def add_injectable( 1484 name, value, autocall=True, cache=False, cache_scope=_CS_FOREVER, 1485 memoize=False): 1486 """ 1487 Add a value that will be injected into other functions. 1488 1489 Parameters 1490 ---------- 1491 name : str 1492 value 1493 If a callable and `autocall` is True then the function's 1494 argument names and keyword argument values will be matched 1495 to registered variables when the function needs to be 1496 evaluated by Orca. The return value will 1497 be passed to any functions using this injectable. In all other 1498 cases, `value` will be passed through untouched. 1499 autocall : bool, optional 1500 Set to True to have injectable functions automatically called 1501 (with argument matching) and the result injected instead of 1502 the function itself. 1503 cache : bool, optional 1504 Whether to cache the return value of an injectable function. 1505 Only applies when `value` is a callable and `autocall` is True. 1506 cache_scope : {'step', 'iteration', 'forever'}, optional 1507 Scope for which to cache data. Default is to cache forever 1508 (or until manually cleared). 'iteration' caches data for each 1509 complete iteration of the pipeline, 'step' caches data for 1510 a single step of the pipeline. 
    memoize : bool, optional
        If autocall is False it is still possible to cache function results
        by setting this flag to True. Cached values are stored in a dictionary
        keyed by argument values, so the argument values must be hashable.
        Memoized functions have their caches cleared according to the same
        rules as universal caching.

    """
    if isinstance(value, Callable):
        if autocall:
            value = _InjectableFuncWrapper(
                name, value, cache=cache, cache_scope=cache_scope)
            # clear any cached data from a previously registered value
            value.clear_cached()
        elif not autocall and memoize:
            value = _memoize_function(value, name, cache_scope=cache_scope)

    # Non-callables (and non-autocall, non-memoized callables) are stored
    # untouched.
    logger.debug('registering injectable {!r}'.format(name))
    _INJECTABLES[name] = value


def injectable(
        name=None, autocall=True, cache=False, cache_scope=_CS_FOREVER,
        memoize=False):
    """
    Decorates functions that will be injected into other functions.

    Decorator version of `add_injectable`. Name defaults to
    name of function.

    The function's argument names and keyword argument values
    will be matched to registered variables when the function
    needs to be evaluated by Orca.
    The argument name "iter_var" may be used to have the current
    iteration variable injected.

    """
    def decorator(func):
        # Fall back to the decorated function's name when no explicit
        # name was given.
        if name:
            n = name
        else:
            n = func.__name__
        add_injectable(
            n, func, autocall=autocall, cache=cache, cache_scope=cache_scope,
            memoize=memoize)
        return func
    return decorator


def is_injectable(name):
    """
    Checks whether a given name can be mapped to an injectable.

    """
    return name in _INJECTABLES


def get_raw_injectable(name):
    """
    Return a raw, possibly wrapped injectable.

    Parameters
    ----------
    name : str

    Returns
    -------
    inj : _InjectableFuncWrapper or object

    """
    if is_injectable(name):
        return _INJECTABLES[name]
    else:
        raise KeyError('injectable not found: {!r}'.format(name))


def injectable_type(name):
    """
    Classify an injectable as either 'variable' or 'function'.

    Parameters
    ----------
    name : str

    Returns
    -------
    inj_type : {'variable', 'function'}
        If the injectable is an automatically called function or any other
        type of callable the type will be 'function', all other injectables
        will be have type 'variable'.

    """
    inj = get_raw_injectable(name)
    if isinstance(inj, (_InjectableFuncWrapper, Callable)):
        return 'function'
    else:
        return 'variable'


def get_injectable(name):
    """
    Get an injectable by name.

    Wrapped autocall functions (`_InjectableFuncWrapper`) are evaluated
    and their result returned; any other value, including non-autocall
    callables, is returned untouched.

    Parameters
    ----------
    name : str

    Returns
    -------
    injectable
        Original value or evaluated value of an _InjectableFuncWrapper.

    """
    i = get_raw_injectable(name)
    return i() if isinstance(i, _InjectableFuncWrapper) else i


def get_injectable_func_source_data(name):
    """
    Return data about an injectable function's source, including file name,
    line number, and source code.

    Parameters
    ----------
    name : str

    Returns
    -------
    filename : str
    lineno : int
        The line number on which the function starts.
    source : str

    """
    if injectable_type(name) != 'function':
        raise ValueError('injectable {!r} is not a function'.format(name))

    inj = get_raw_injectable(name)

    # Unwrap to reach the user-supplied function before introspecting it.
    if isinstance(inj, _InjectableFuncWrapper):
        return utils.func_source_data(inj._func)
    elif hasattr(inj, '__wrapped__'):
        return utils.func_source_data(inj.__wrapped__)
    else:
        return utils.func_source_data(inj)


def add_step(step_name, func):
    """
    Add a step function to Orca.

    The function's argument names and keyword argument values
    will be matched to registered variables when the function
    needs to be evaluated by Orca.
    The argument name "iter_var" may be used to have the current
    iteration variable injected.

    Parameters
    ----------
    step_name : str
    func : callable

    """
    if isinstance(func, Callable):
        logger.debug('registering step {!r}'.format(step_name))
        _STEPS[step_name] = _StepFuncWrapper(step_name, func)
    else:
        raise TypeError('func must be a callable')


def step(step_name=None):
    """
    Decorates functions that will be called by the `run` function.

    Decorator version of `add_step`. step name defaults to
    name of function.

    The function's argument names and keyword argument values
    will be matched to registered variables when the function
    needs to be evaluated by Orca.
    The argument name "iter_var" may be used to have the current
    iteration variable injected.

    """
    def decorator(func):
        # Fall back to the decorated function's name when no explicit
        # step_name was given.
        if step_name:
            name = step_name
        else:
            name = func.__name__
        add_step(name, func)
        return func
    return decorator


def is_step(step_name):
    """
    Check whether a given name refers to a registered step.

    """
    return step_name in _STEPS


def get_step(step_name):
    """
    Get a wrapped step by name.

    Parameters
    ----------
    step_name : str
        Name of the registered step.

    Returns
    -------
    step : _StepFuncWrapper

    """
    if is_step(step_name):
        return _STEPS[step_name]
    else:
        raise KeyError('no step named {}'.format(step_name))


Broadcast = namedtuple(
    'Broadcast',
    ['cast', 'onto', 'cast_on', 'onto_on', 'cast_index', 'onto_index'])


def broadcast(cast, onto, cast_on=None, onto_on=None,
              cast_index=False, onto_index=False):
    """
    Register a rule for merging two tables by broadcasting one onto
    the other.

    Parameters
    ----------
    cast, onto : str
        Names of registered tables.
    cast_on, onto_on : str, optional
        Column names used for merge, equivalent of ``left_on``/``right_on``
        parameters of pandas.merge.
    cast_index, onto_index : bool, optional
        Whether to use table indexes for merge. Equivalent of
        ``left_index``/``right_index`` parameters of pandas.merge.

    """
    logger.debug(
        'registering broadcast of table {!r} onto {!r}'.format(cast, onto))
    _BROADCASTS[(cast, onto)] = \
        Broadcast(cast, onto, cast_on, onto_on, cast_index, onto_index)


def _get_broadcasts(tables):
    """
    Get the broadcasts associated with a set of tables.

    Parameters
    ----------
    tables : sequence of str
        Table names for which broadcasts have been registered.

    Returns
    -------
    casts : dict of `Broadcast`
        Keys are tuples of strings like (cast_name, onto_name).

    """
    tables = set(tables)
    # Keep only broadcasts where both endpoints are in the requested set.
    casts = tz.keyfilter(
        lambda x: x[0] in tables and x[1] in tables, _BROADCASTS)
    if tables - set(tz.concat(casts.keys())):
        raise ValueError('Not enough links to merge all tables.')
    return casts


def is_broadcast(cast_name, onto_name):
    """
    Checks whether a relationship exists for broadcast `cast_name`
    onto `onto_name`.
1783 1784 """ 1785 return (cast_name, onto_name) in _BROADCASTS 1786 1787 1788def get_broadcast(cast_name, onto_name): 1789 """ 1790 Get a single broadcast. 1791 1792 Broadcasts are stored data about how to do a Pandas join. 1793 A Broadcast object is a namedtuple with these attributes: 1794 1795 - cast: the name of the table being broadcast 1796 - onto: the name of the table onto which "cast" is broadcast 1797 - cast_on: The optional name of a column on which to join. 1798 None if the table index will be used instead. 1799 - onto_on: The optional name of a column on which to join. 1800 None if the table index will be used instead. 1801 - cast_index: True if the table index should be used for the join. 1802 - onto_index: True if the table index should be used for the join. 1803 1804 Parameters 1805 ---------- 1806 cast_name : str 1807 The name of the table being braodcast. 1808 onto_name : str 1809 The name of the table onto which `cast_name` is broadcast. 1810 1811 Returns 1812 ------- 1813 broadcast : Broadcast 1814 1815 """ 1816 if is_broadcast(cast_name, onto_name): 1817 return _BROADCASTS[(cast_name, onto_name)] 1818 else: 1819 raise KeyError( 1820 'no rule found for broadcasting {!r} onto {!r}'.format( 1821 cast_name, onto_name)) 1822 1823 1824# utilities for merge_tables 1825def _all_reachable_tables(t): 1826 """ 1827 A generator that provides all the names of tables that can be 1828 reached via merges starting at the given target table. 1829 1830 """ 1831 for k, v in t.items(): 1832 for tname in _all_reachable_tables(v): 1833 yield tname 1834 yield k 1835 1836 1837def _recursive_getitem(d, key): 1838 """ 1839 Descend into a dict of dicts to return the one that contains 1840 a given key. Every value in the dict must be another dict. 
1841 1842 """ 1843 if key in d: 1844 return d 1845 else: 1846 for v in d.values(): 1847 return _recursive_getitem(v, key) 1848 else: 1849 raise KeyError('Key not found: {}'.format(key)) 1850 1851 1852def _dict_value_to_pairs(d): 1853 """ 1854 Takes the first value of a dictionary (which it self should be 1855 a dictionary) and turns it into a series of {key: value} dicts. 1856 1857 For example, _dict_value_to_pairs({'c': {'a': 1, 'b': 2}}) will yield 1858 {'a': 1} and {'b': 2}. 1859 1860 """ 1861 d = d[tz.first(d)] 1862 1863 for k, v in d.items(): 1864 yield {k: v} 1865 1866 1867def _is_leaf_node(merge_node): 1868 """ 1869 Returns True for dicts like {'a': {}}. 1870 1871 """ 1872 return len(merge_node) == 1 and not next(iter(merge_node.values())) 1873 1874 1875def _next_merge(merge_node): 1876 """ 1877 Gets a node that has only leaf nodes below it. This table and 1878 the ones below are ready to be merged to make a new leaf node. 1879 1880 """ 1881 if all(_is_leaf_node(d) for d in _dict_value_to_pairs(merge_node)): 1882 return merge_node 1883 else: 1884 for d in tz.remove(_is_leaf_node, _dict_value_to_pairs(merge_node)): 1885 return _next_merge(d) 1886 else: 1887 raise OrcaError('No node found for next merge.') 1888 1889 1890def merge_tables(target, tables, columns=None, drop_intersection=True): 1891 """ 1892 Merge a number of tables onto a target table. Tables must have 1893 registered merge rules via the `broadcast` function. 1894 1895 Parameters 1896 ---------- 1897 target : str, DataFrameWrapper, or TableFuncWrapper 1898 Name of the table (or wrapped table) onto which tables will be merged. 1899 tables : list of `DataFrameWrapper`, `TableFuncWrapper`, or str 1900 All of the tables to merge. Should include the target table. 1901 columns : list of str, optional 1902 If given, columns will be mapped to `tables` and only those columns 1903 will be requested from each table. The final merged table will have 1904 only these columns. 
        By default all columns are used from every
        table.
    drop_intersection : bool
        If True, keep the left most occurence of any column name if it occurs
        on more than one table. This prevents getting back the same column
        with suffixes applied by pd.merge. If false, columns names will be
        suffixed with the table names - e.g. zone_id_buildings and
        zone_id_parcels.

    Returns
    -------
    merged : pandas.DataFrame

    """
    # allow target to be string or table wrapper
    if isinstance(target, (DataFrameWrapper, TableFuncWrapper)):
        target = target.name

    # allow tables to be strings or table wrappers
    tables = [get_table(t)
              if not isinstance(t, (DataFrameWrapper, TableFuncWrapper)) else t
              for t in tables]

    merges = {t.name: {} for t in tables}
    tables = {t.name: t for t in tables}
    casts = _get_broadcasts(tables.keys())
    logger.debug(
        'attempting to merge tables {} to target table {}'.format(
            tables.keys(), target))

    # relate all the tables by registered broadcasts
    for table, onto in casts:
        merges[onto][table] = merges[table]
    merges = {target: merges[target]}

    # verify that all the tables can be merged to the target
    all_tables = set(_all_reachable_tables(merges))

    if all_tables != set(tables.keys()):
        raise RuntimeError(
            ('Not all tables can be merged to target "{}". Unlinked tables: {}'
             ).format(target, list(set(tables.keys()) - all_tables)))

    # add any columns necessary for indexing into other tables
    # during merges
    if columns:
        # copy so the caller's list is not mutated by the appends below
        columns = list(columns)
        for c in casts.values():
            if c.onto_on:
                columns.append(c.onto_on)
            if c.cast_on:
                columns.append(c.cast_on)

    # get column map for which columns go with which table
    colmap = column_map(tables.values(), columns)

    # get frames
    frames = {name: t.to_frame(columns=colmap[name])
              for name, t in tables.items()}

    past_intersections = set()

    # perform merges until there's only one table left
    while merges[target]:
        nm = _next_merge(merges)
        onto = tz.first(nm)
        onto_table = frames[onto]

        # loop over all the tables that can be broadcast onto
        # the onto_table and merge them all in.
        for cast in nm[onto]:
            cast_table = frames[cast]
            bc = casts[(cast, onto)]

            with log_start_finish(
                    'merge tables {} and {}'.format(onto, cast), logger):

                intersection = set(onto_table.columns).\
                    intersection(cast_table.columns)
                # intersection is ok if it's the join key
                intersection.discard(bc.onto_on)
                intersection.discard(bc.cast_on)
                # otherwise drop so as not to create conflicts
                if drop_intersection:
                    cast_table = cast_table.drop(intersection, axis=1)
                else:
                    # add suffix to past intersections which wouldn't get
                    # picked up by the merge - these we have to rename by hand
                    renames = dict(zip(
                        past_intersections,
                        [c+'_'+onto for c in past_intersections]
                    ))
                    onto_table = onto_table.rename(columns=renames)

                # keep track of past intersections in case there's an odd
                # number of intersections
                past_intersections = past_intersections.union(intersection)

                onto_table = pd.merge(
                    onto_table, cast_table,
                    suffixes=['_'+onto, '_'+cast],
                    left_on=bc.onto_on, right_on=bc.cast_on,
                    left_index=bc.onto_index, right_index=bc.cast_index)

            # replace the existing table with the merged one
            frames[onto] = onto_table

            # free up space by dropping the cast table
            del frames[cast]

        # mark the onto table as having no more things to broadcast
        # onto it.
        _recursive_getitem(merges, onto)[onto] = {}

    logger.debug('finished merge')
    return frames[target]


def get_step_table_names(steps):
    """
    Returns a list of table names injected into the provided steps.

    Parameters
    ----------
    steps: list of str
        Steps to gather table inputs from.

    Returns
    -------
    list of str

    """
    table_names = set()
    for s in steps:
        table_names |= get_step(s)._tables_used()
    return list(table_names)


def write_tables(fname, table_names=None, prefix=None, compress=False, local=False):
    """
    Writes tables to a pandas.HDFStore file.

    Parameters
    ----------
    fname : str
        File name for HDFStore. Will be opened in append mode and closed
        at the end of this function.
    table_names: list of str, optional, default None
        List of tables to write. If None, all registered tables will
        be written.
    prefix: str
        If not None, used to prefix the output table names so that
        multiple iterations can go in the same file.
    compress: boolean
        Whether to compress output file using standard HDF5-readable
        zlib compression, default False.
2060 2061 """ 2062 if table_names is None: 2063 table_names = list_tables() 2064 2065 tables = (get_table(t) for t in table_names) 2066 key_template = '{}/{{}}'.format(prefix) if prefix is not None else '{}' 2067 2068 # set compression options to zlib level-1 if compress arg is True 2069 complib = compress and 'zlib' or None 2070 complevel = compress and 1 or 0 2071 2072 with pd.HDFStore(fname, mode='a', complib=complib, complevel=complevel) as store: 2073 for t in tables: 2074 # if local arg is True, store only local columns 2075 columns = None 2076 if local is True: 2077 columns = t.local_columns 2078 store[key_template.format(t.name)] = t.to_frame(columns=columns) 2079 2080 2081iter_step = namedtuple('iter_step', 'step_num,step_name') 2082 2083 2084def run(steps, iter_vars=None, data_out=None, out_interval=1, 2085 out_base_tables=None, out_run_tables=None, compress=False, 2086 out_base_local=True, out_run_local=True): 2087 """ 2088 Run steps in series, optionally repeatedly over some sequence. 2089 The current iteration variable is set as a global injectable 2090 called ``iter_var``. 2091 2092 Parameters 2093 ---------- 2094 steps : list of str 2095 List of steps to run identified by their name. 2096 iter_vars : iterable, optional 2097 The values of `iter_vars` will be made available as an injectable 2098 called ``iter_var`` when repeatedly running `steps`. 2099 data_out : str, optional 2100 An optional filename to which all tables injected into any step 2101 in `steps` will be saved every `out_interval` iterations. 2102 File will be a pandas HDF data store. 2103 out_interval : int, optional 2104 Iteration interval on which to save data to `data_out`. For example, 2105 2 will save out every 2 iterations, 5 every 5 iterations. 2106 Default is every iteration. 2107 The results of the first and last iterations are always included. 
        The input (base) tables are also included and prefixed with `base/`,
        these represent the state of the system before any steps have been
        executed.
        The interval is defined relative to the first iteration. For example,
        a run beginning in 2015 with an out_interval of 2, will write out
        results for 2015, 2017, etc.
    out_base_tables: list of str, optional, default None
        List of base tables to write. If not provided, tables injected
        into 'steps' will be written.
    out_run_tables: list of str, optional, default None
        List of run tables to write. If not provided, tables injected
        into 'steps' will be written.
    compress: boolean, optional, default False
        Whether to compress output file using standard HDF5 zlib compression.
        Compression yields much smaller files using slightly more CPU.
    out_base_local: boolean, optional, default True
        For tables in out_base_tables, whether to store only local columns (True)
        or both, local and computed columns (False).
    out_run_local: boolean, optional, default True
        For tables in out_run_tables, whether to store only local columns (True)
        or both, local and computed columns (False).
2129 """ 2130 iter_vars = iter_vars or [None] 2131 max_i = len(iter_vars) 2132 2133 # get the tables to write out 2134 if out_base_tables is None or out_run_tables is None: 2135 step_tables = get_step_table_names(steps) 2136 2137 if out_base_tables is None: 2138 out_base_tables = step_tables 2139 2140 if out_run_tables is None: 2141 out_run_tables = step_tables 2142 2143 # write out the base (inputs) 2144 if data_out: 2145 add_injectable('iter_var', iter_vars[0]) 2146 write_tables(data_out, out_base_tables, 'base', compress=compress, local=out_base_local) 2147 2148 # run the steps 2149 for i, var in enumerate(iter_vars, start=1): 2150 add_injectable('iter_var', var) 2151 2152 if var is not None: 2153 print('Running iteration {} with iteration value {!r}'.format( 2154 i, var)) 2155 logger.debug( 2156 'running iteration {} with iteration value {!r}'.format( 2157 i, var)) 2158 2159 t1 = time.time() 2160 for j, step_name in enumerate(steps): 2161 add_injectable('iter_step', iter_step(j, step_name)) 2162 print('Running step {!r}'.format(step_name)) 2163 with log_start_finish( 2164 'run step {!r}'.format(step_name), logger, 2165 logging.INFO): 2166 step = get_step(step_name) 2167 t2 = time.time() 2168 step() 2169 print("Time to execute step '{}': {:.2f} s".format( 2170 step_name, time.time() - t2)) 2171 clear_cache(scope=_CS_STEP) 2172 2173 print( 2174 ('Total time to execute iteration {} ' 2175 'with iteration value {!r}: ' 2176 '{:.2f} s').format(i, var, time.time() - t1)) 2177 2178 # write out the results for the current iteration 2179 if data_out: 2180 if (i - 1) % out_interval == 0 or i == max_i: 2181 write_tables(data_out, out_run_tables, var, compress=compress, local=out_run_local) 2182 2183 clear_cache(scope=_CS_ITER) 2184 2185 2186@contextmanager 2187def injectables(**kwargs): 2188 """ 2189 Temporarily add injectables to the pipeline environment. 2190 Takes only keyword arguments. 
2191 2192 Injectables will be returned to their original state when the context 2193 manager exits. 2194 2195 """ 2196 global _INJECTABLES 2197 2198 original = _INJECTABLES.copy() 2199 _INJECTABLES.update(kwargs) 2200 yield 2201 _INJECTABLES = original 2202 2203 2204@contextmanager 2205def temporary_tables(**kwargs): 2206 """ 2207 Temporarily set DataFrames as registered tables. 2208 2209 Tables will be returned to their original state when the context 2210 manager exits. Caching is not enabled for tables registered via 2211 this function. 2212 2213 """ 2214 global _TABLES 2215 2216 original = _TABLES.copy() 2217 2218 for k, v in kwargs.items(): 2219 if not isinstance(v, pd.DataFrame): 2220 raise ValueError('tables only accepts DataFrames') 2221 add_table(k, v) 2222 2223 yield 2224 2225 _TABLES = original 2226 2227 2228def eval_variable(name, **kwargs): 2229 """ 2230 Execute a single variable function registered with Orca 2231 and return the result. Any keyword arguments are temporarily set 2232 as injectables. This gives the value as would be injected into a function. 2233 2234 Parameters 2235 ---------- 2236 name : str 2237 Name of variable to evaluate. 2238 Use variable expressions to specify columns. 2239 2240 Returns 2241 ------- 2242 object 2243 For injectables and columns this directly returns whatever 2244 object is returned by the registered function. 2245 For tables this returns a DataFrameWrapper as if the table 2246 had been injected into a function. 2247 2248 """ 2249 with injectables(**kwargs): 2250 vars = _collect_variables([name], [name]) 2251 return vars[name] 2252 2253 2254def eval_step(name, **kwargs): 2255 """ 2256 Evaluate a step as would be done within the pipeline environment 2257 and return the result. Any keyword arguments are temporarily set 2258 as injectables. 2259 2260 Parameters 2261 ---------- 2262 name : str 2263 Name of step to run. 2264 2265 Returns 2266 ------- 2267 object 2268 Anything returned by a step. 
        (Though note that in Orca runs
        return values from steps are ignored.)

    """
    with injectables(**kwargs):
        return get_step(name)()