1# Orca
2# Copyright (C) 2016 UrbanSim Inc.
3# See full license in LICENSE.
4
5from __future__ import print_function
6
7try:
8    from inspect import getfullargspec as getargspec
9except ImportError:
10    from inspect import getargspec
11import logging
12import time
13import warnings
14from collections import namedtuple
15try:
16    from collections.abc import Callable
17except ImportError:  # Python 2.7
18    from collections import Callable
19from contextlib import contextmanager
20from functools import wraps
21
22
23import pandas as pd
24import tables
25import tlz as tz
26
27from . import utils
28from .utils.logutil import log_start_finish
29
30warnings.filterwarnings('ignore', category=tables.NaturalNameWarning)
31logger = logging.getLogger(__name__)
32
33_TABLES = {}
34_COLUMNS = {}
35_STEPS = {}
36_BROADCASTS = {}
37_INJECTABLES = {}
38
39_CACHING = True
40_TABLE_CACHE = {}
41_COLUMN_CACHE = {}
42_INJECTABLE_CACHE = {}
43_MEMOIZED = {}
44
45_CS_FOREVER = 'forever'
46_CS_ITER = 'iteration'
47_CS_STEP = 'step'
48
49CacheItem = namedtuple('CacheItem', ['name', 'value', 'scope'])
50
51
def clear_all():
    """
    Clear any and all stored state from Orca.

    """
    # wipe every registry and cache dict in one pass
    registries = (
        _TABLES, _COLUMNS, _STEPS, _BROADCASTS, _INJECTABLES,
        _TABLE_CACHE, _COLUMN_CACHE, _INJECTABLE_CACHE)
    for registry in registries:
        registry.clear()
    # memoized functions carry their own caches; clear those too
    for item in _MEMOIZED.values():
        item.value.clear_cached()
    _MEMOIZED.clear()
    logger.debug('pipeline state cleared')
69
70
def clear_cache(scope=None):
    """
    Clear all cached data.

    Parameters
    ----------
    scope : {None, 'step', 'iteration', 'forever'}, optional
        Clear cached values with a given scope.
        By default all cached values are removed.

    """
    caches = (_TABLE_CACHE, _COLUMN_CACHE, _INJECTABLE_CACHE)
    if not scope:
        # no scope given: drop everything, including memoized results
        for cache in caches:
            cache.clear()
        for item in _MEMOIZED.values():
            item.value.clear_cached()
        logger.debug('pipeline cache cleared')
    else:
        # remove only the entries whose scope matches the request
        for cache in caches:
            matching_keys = [k for k, v in cache.items() if v.scope == scope]
            for key in matching_keys:
                del cache[key]
        for item in _MEMOIZED.values():
            if item.scope == scope:
                item.value.clear_cached()
        logger.debug('cleared cached values with scope {!r}'.format(scope))
97
98
def clear_injectable(injectable_name):
    """
    Clear the cached value of an injectable. *Added in Orca v1.6.*

    Parameters
    ----------
    injectable_name : str
        Name of injectable to clear.

    """
    # delegate to the wrapper's own cache-clearing logic
    _INJECTABLES[injectable_name].clear_cached()
110
111
def clear_table(table_name):
    """
    Clear the cached copy of an entire table. *Added in Orca v1.6.*

    Parameters
    ----------
    table_name : str
        Name of table to clear.

    """
    # delegate to the wrapper's own cache-clearing logic
    _TABLES[table_name].clear_cached()
123
124
def clear_column(table_name, column_name):
    """
    Clear the cached copy of a dynamically generated column.
    *Added in Orca v1.6.*

    Parameters
    ----------
    table_name: str
        Table containing the column to clear.
    column_name: str
        Name of the column to clear.

    """
    # columns are registered under a (table, column) key pair
    key = (table_name, column_name)
    _COLUMNS[key].clear_cached()
139
140
def clear_columns(table_name, columns=None):
    """
    Clear all (or a specified list) of the dynamically generated columns
    associated with a table. *Added in Orca v1.6.*

    Parameters
    ----------
    table_name: str
        Table name.
    columns:  list of str, optional, default None
        List of columns to clear. If None, all extra/computed
        columns in the table will be cleared.

    """
    if columns is None:
        # default to every registered (non-local) column on the table
        tab = get_table(table_name)
        local_cols = tab.local_columns
        columns = [c for c in tab.columns if c not in local_cols]

    for col in columns:
        clear_column(table_name, col)
165
166
def _update_scope(wrapper, new_scope=None):
    """
    Update the cache scope for a wrapper (in place).
    *Added in Orca v1.6.*

    Parameters
    ----------
    wrapper: object
        Should be an instance of wrapper with attributes
        `cache`, `cache_scope` and method `clear_cached`.
    new_scope: str, optional default None
        The new scope value. None implies no caching.

    """
    # map each allowable scope to its update granularity
    granularity = {
        None: 0,
        _CS_STEP: 1,
        _CS_ITER: 2,
        _CS_FOREVER: 3
    }
    if new_scope not in granularity:
        raise ValueError(
            '{} is not an allowed cache scope, '.format(new_scope) +
            'allowed scopes are {}'.format(list(granularity.keys())))

    # remember the old scope before mutating the wrapper
    prev_scope = wrapper.cache_scope
    if new_scope is None:
        # restore defaults, i.e. no caching
        wrapper.cache = False
        wrapper.cache_scope = _CS_FOREVER
    else:
        wrapper.cache = True
        wrapper.cache_scope = new_scope

    # invalidate existing cached values when the scope becomes
    # more granular than it was before
    if granularity[new_scope] < granularity[prev_scope]:
        wrapper.clear_cached()
210
211
def update_injectable_scope(name, new_scope=None):
    """
    Update the cache scope for a wrapped injectable function.
    Clears the cache if the new scope is more granular
    than the existing. *Added in Orca v1.6.*

    Parameters
    ----------
    name: str
        Name of the injectable to update.
    new_scope: str, optional default None
        Valid values: None, 'forever', 'iteration', 'step'
        None implies no caching.

    """
    wrapper = get_raw_injectable(name)
    _update_scope(wrapper, new_scope)
229
230
def update_column_scope(table_name, column_name, new_scope=None):
    """
    Update the cache scope for a wrapped column function. Clears
    the cache if the new scope is more granular than the existing.
    *Added in Orca v1.6.*

    Parameters
    ----------
    table_name: str
        Name of the table.
    column_name: str
        Name of the column to update.
    new_scope: str, optional default None
        Valid values: None, 'forever', 'iteration', 'step'
        None implies no caching.

    """
    wrapper = get_raw_column(table_name, column_name)
    _update_scope(wrapper, new_scope)
250
251
def update_table_scope(name, new_scope=None):
    """
    Update the cache scope for a wrapped table function. Clears
    the cache if the new scope is more granular than the existing.
    *Added in Orca v1.6.*

    Parameters
    ----------
    name: str
        Name of the table to update.
    new_scope: str, optional default None
        Valid values: None, 'forever', 'iteration', 'step'
        None implies no caching.

    """
    wrapper = get_raw_table(name)
    _update_scope(wrapper, new_scope)
269
270
def enable_cache():
    """
    Re-enable caching globally; registered variables that opted in
    to caching will have their results cached again.

    """
    # flip the module-wide switch checked by all the wrappers
    global _CACHING
    _CACHING = True
279
280
def disable_cache():
    """
    Disable caching globally, overriding any per-variable cache
    settings until caching is re-enabled.

    """
    # flip the module-wide switch checked by all the wrappers
    global _CACHING
    _CACHING = False
289
290
def cache_on():
    """
    Report the current global caching switch.

    Returns
    -------
    on : bool
        True if caching is enabled.

    """
    return _CACHING
302
303
@contextmanager
def cache_disabled():
    """
    Context manager that temporarily disables caching across Orca
    and restores the prior caching state on exit — even when the
    managed block raises an exception.

    """
    turn_back_on = cache_on()
    disable_cache()
    try:
        yield
    finally:
        # without try/finally an exception in the with-body would
        # leave caching permanently disabled
        if turn_back_on:
            enable_cache()
313
314
315# for errors that occur during Orca runs
class OrcaError(Exception):
    """Raised for errors that occur during Orca runs."""
318
319
class DataFrameWrapper(object):
    """
    Wraps a DataFrame so it can provide certain columns and handle
    computed columns.

    Parameters
    ----------
    name : str
        Name for the table.
    frame : pandas.DataFrame
    copy_col : bool, optional
        Whether to return copies when evaluating columns.

    Attributes
    ----------
    name : str
        Table name.
    copy_col : bool
        Whether to return copies when evaluating columns.
    local : pandas.DataFrame
        The wrapped DataFrame.

    """
    def __init__(self, name, frame, copy_col=True):
        self.name = name
        self.local = frame
        self.copy_col = copy_col

    @property
    def columns(self):
        """
        Columns in this table.

        """
        # local DataFrame columns plus any registered/computed columns
        return self.local_columns + list_columns_for_table(self.name)

    @property
    def local_columns(self):
        """
        Columns that are part of the wrapped DataFrame.

        """
        return list(self.local.columns)

    @property
    def index(self):
        """
        Table index.

        """
        return self.local.index

    def to_frame(self, columns=None):
        """
        Make a DataFrame with the given columns.

        Will always return a copy of the underlying table.

        Parameters
        ----------
        columns : sequence or string, optional
            Sequence of the column names desired in the DataFrame. A string
            can also be passed if only one column is desired.
            If None all columns are returned, including registered columns.

        Returns
        -------
        frame : pandas.DataFrame

        """
        extra_cols = _columns_for_table(self.name)

        if columns is not None:
            # normalize a single column name to a list, then split the
            # request into local columns vs. registered/computed columns
            columns = [columns] if isinstance(columns, str) else columns
            columns = set(columns)
            set_extra_cols = set(extra_cols)
            # requested local columns, excluding any that are shadowed
            # by a registered column of the same name
            local_cols = set(self.local.columns) & columns - set_extra_cols
            df = self.local[list(local_cols)].copy()
            # only evaluate the computed columns that were requested
            extra_cols = {k: extra_cols[k] for k in (columns & set_extra_cols)}
        else:
            df = self.local.copy()

        with log_start_finish(
                'computing {!r} columns for table {!r}'.format(
                    len(extra_cols), self.name),
                logger):
            for name, col in extra_cols.items():
                with log_start_finish(
                        'computing column {!r} for table {!r}'.format(
                            name, self.name),
                        logger):
                    # each value is a callable wrapper; calling it
                    # evaluates (or retrieves from cache) the column
                    df[name] = col()

        return df

    def update_col(self, column_name, series):
        """
        Add or replace a column in the underlying DataFrame.

        Parameters
        ----------
        column_name : str
            Column to add or replace.
        series : pandas.Series or sequence
            Column data.

        """
        logger.debug('updating column {!r} in table {!r}'.format(
            column_name, self.name))
        self.local[column_name] = series

    def __setitem__(self, key, value):
        # table['col'] = series delegates to update_col
        return self.update_col(key, value)

    def get_column(self, column_name):
        """
        Returns a column as a Series.

        Parameters
        ----------
        column_name : str

        Returns
        -------
        column : pandas.Series

        """
        with log_start_finish(
                'getting single column {!r} from table {!r}'.format(
                    column_name, self.name),
                logger):
            extra_cols = _columns_for_table(self.name)
            if column_name in extra_cols:
                # registered columns take precedence over local ones
                with log_start_finish(
                        'computing column {!r} for table {!r}'.format(
                            column_name, self.name),
                        logger):
                    column = extra_cols[column_name]()
            else:
                column = self.local[column_name]
            # copy_col controls whether callers can mutate the stored data
            if self.copy_col:
                return column.copy()
            else:
                return column

    def __getitem__(self, key):
        # table['col'] delegates to get_column
        return self.get_column(key)

    def __getattr__(self, key):
        # any attribute not found on the wrapper is treated as a
        # column lookup (e.g. table.col_name)
        return self.get_column(key)

    def column_type(self, column_name):
        """
        Report column type as one of 'local', 'series', or 'function'.

        Parameters
        ----------
        column_name : str

        Returns
        -------
        col_type : {'local', 'series', 'function'}
            'local' means that the column is part of the registered table,
            'series' means the column is a registered Pandas Series,
            and 'function' means the column is a registered function providing
            a Pandas Series.

        """
        extra_cols = list_columns_for_table(self.name)

        if column_name in extra_cols:
            col = _COLUMNS[(self.name, column_name)]

            if isinstance(col, _SeriesWrapper):
                return 'series'
            elif isinstance(col, _ColumnFuncWrapper):
                return 'function'

        elif column_name in self.local_columns:
            return 'local'

        # falls through here if the name is unknown, or if a registered
        # column is not one of the recognized wrapper types
        raise KeyError('column {!r} not found'.format(column_name))

    def update_col_from_series(self, column_name, series, cast=False):
        """
        Update existing values in a column from another series.
        Index values must match in both column and series. Optionally
        casts data type to match the existing column.

        Parameters
        ----------
        column_name : str
        series : pandas.Series
        cast : bool, optional, default False
            If True, cast `series` to the existing column's dtype
            instead of raising on a dtype mismatch.

        """
        logger.debug('updating column {!r} in table {!r}'.format(
            column_name, self.name))

        col_dtype = self.local[column_name].dtype
        if series.dtype != col_dtype:
            if cast:
                series = series.astype(col_dtype)
            else:
                err_msg = "Data type mismatch, existing:{}, update:{}"
                err_msg = err_msg.format(col_dtype, series.dtype)
                raise ValueError(err_msg)

        # only rows present in series.index are updated
        self.local.loc[series.index, column_name] = series

    def __len__(self):
        return len(self.local)

    def clear_cached(self):
        """
        Remove cached results from this table's computed columns.

        """
        _TABLE_CACHE.pop(self.name, None)
        for col in _columns_for_table(self.name).values():
            col.clear_cached()
        logger.debug('cleared cached columns for table {!r}'.format(self.name))
541
542
class TableFuncWrapper(object):
    """
    Wrap a function that provides a DataFrame.

    Parameters
    ----------
    name : str
        Name for the table.
    func : callable
        Callable that returns a DataFrame.
    cache : bool, optional
        Whether to cache the results of calling the wrapped function.
    cache_scope : {'step', 'iteration', 'forever'}, optional
        Scope for which to cache data. Default is to cache forever
        (or until manually cleared). 'iteration' caches data for each
        complete iteration of the pipeline, 'step' caches data for
        a single step of the pipeline.
    copy_col : bool, optional
        Whether to return copies when evaluating columns.

    Attributes
    ----------
    name : str
        Table name.
    cache : bool
        Whether caching is enabled for this table.
    copy_col : bool
        Whether to return copies when evaluating columns.

    """
    def __init__(
            self, name, func, cache=False, cache_scope=_CS_FOREVER,
            copy_col=True):
        self.name = name
        self._func = func
        # argspec is captured once so args/defaults can drive
        # dependency injection on every call
        self._argspec = getargspec(func)
        self.cache = cache
        self.cache_scope = cache_scope
        self.copy_col = copy_col
        # these are populated lazily by _call_func
        self._columns = []
        self._index = None
        self._len = 0

    @property
    def columns(self):
        """
        Columns in this table. (May contain only computed columns
        if the wrapped function has not been called yet.)

        """
        return self._columns + list_columns_for_table(self.name)

    @property
    def local_columns(self):
        """
        Only the columns contained in the DataFrame returned by the
        wrapped function. (No registered columns included.)

        """
        if self._columns:
            return self._columns
        else:
            # evaluate the function once to discover the columns
            self._call_func()
            return self._columns

    @property
    def index(self):
        """
        Index of the underlying table. Will be None if that index is
        unknown.

        """
        return self._index

    def _call_func(self):
        """
        Call the wrapped function and return the result wrapped by
        DataFrameWrapper.
        Also updates attributes like columns, index, and length.

        """
        # serve from the global table cache when allowed
        if _CACHING and self.cache and self.name in _TABLE_CACHE:
            logger.debug('returning table {!r} from cache'.format(self.name))
            return _TABLE_CACHE[self.name].value

        with log_start_finish(
                'call function to get frame for table {!r}'.format(
                    self.name),
                logger):
            # inject registered variables matching the function's
            # argument names / expression defaults
            kwargs = _collect_variables(names=self._argspec.args,
                                        expressions=self._argspec.defaults)
            frame = self._func(**kwargs)

        # record metadata about the most recent result
        self._columns = list(frame.columns)
        self._index = frame.index
        self._len = len(frame)

        wrapped = DataFrameWrapper(self.name, frame, copy_col=self.copy_col)

        if self.cache:
            _TABLE_CACHE[self.name] = CacheItem(
                self.name, wrapped, self.cache_scope)

        return wrapped

    def __call__(self):
        return self._call_func()

    def to_frame(self, columns=None):
        """
        Make a DataFrame with the given columns.

        Will always return a copy of the underlying table.

        Parameters
        ----------
        columns : sequence, optional
            Sequence of the column names desired in the DataFrame.
            If None all columns are returned.

        Returns
        -------
        frame : pandas.DataFrame

        """
        return self._call_func().to_frame(columns)

    def get_column(self, column_name):
        """
        Returns a column as a Series.

        Parameters
        ----------
        column_name : str

        Returns
        -------
        column : pandas.Series

        """
        frame = self._call_func()
        return DataFrameWrapper(self.name, frame,
                                copy_col=self.copy_col).get_column(column_name)

    def __getitem__(self, key):
        return self.get_column(key)

    def __getattr__(self, key):
        # unknown attributes are treated as column lookups
        return self.get_column(key)

    def __len__(self):
        # length of the most recently computed frame;
        # 0 if the function has never been called
        return self._len

    def column_type(self, column_name):
        """
        Report column type as one of 'local', 'series', or 'function'.

        Parameters
        ----------
        column_name : str

        Returns
        -------
        col_type : {'local', 'series', 'function'}
            'local' means that the column is part of the registered table,
            'series' means the column is a registered Pandas Series,
            and 'function' means the column is a registered function providing
            a Pandas Series.

        """
        extra_cols = list_columns_for_table(self.name)

        if column_name in extra_cols:
            col = _COLUMNS[(self.name, column_name)]

            if isinstance(col, _SeriesWrapper):
                return 'series'
            elif isinstance(col, _ColumnFuncWrapper):
                return 'function'

        elif column_name in self.local_columns:
            return 'local'

        raise KeyError('column {!r} not found'.format(column_name))

    def clear_cached(self):
        """
        Remove this table's cached result and that of associated columns.

        """
        _TABLE_CACHE.pop(self.name, None)
        for col in _columns_for_table(self.name).values():
            col.clear_cached()
        logger.debug(
            'cleared cached result and cached columns for table {!r}'.format(
                self.name))

    def func_source_data(self):
        """
        Return data about the wrapped function source, including file name,
        line number, and source code.

        Returns
        -------
        filename : str
        lineno : int
            The line number on which the function starts.
        source : str

        """
        return utils.func_source_data(self._func)
754
755
class _ColumnFuncWrapper(object):
    """
    Wrap a function that returns a Series.

    Parameters
    ----------
    table_name : str
        Table with which the column will be associated.
    column_name : str
        Name for the column.
    func : callable
        Should return a Series that has an
        index matching the table to which it is being added.
    cache : bool, optional
        Whether to cache the result of calling the wrapped function.
    cache_scope : {'step', 'iteration', 'forever'}, optional
        Scope for which to cache data. Default is to cache forever
        (or until manually cleared). 'iteration' caches data for each
        complete iteration of the pipeline, 'step' caches data for
        a single step of the pipeline.

    Attributes
    ----------
    name : str
        Column name.
    table_name : str
        Name of table this column is associated with.
    cache : bool
        Whether caching is enabled for this column.

    """
    def __init__(
            self, table_name, column_name, func, cache=False,
            cache_scope=_CS_FOREVER):
        self.table_name = table_name
        self.name = column_name
        self._func = func
        # argspec captured once to drive dependency injection per call
        self._argspec = getargspec(func)
        self.cache = cache
        self.cache_scope = cache_scope

    def __call__(self):
        """
        Evaluate the wrapped function and return the result.

        """
        # cache entries are keyed by (table_name, column_name)
        if (_CACHING and
                self.cache and
                (self.table_name, self.name) in _COLUMN_CACHE):
            logger.debug(
                'returning column {!r} for table {!r} from cache'.format(
                    self.name, self.table_name))
            return _COLUMN_CACHE[(self.table_name, self.name)].value

        with log_start_finish(
                ('call function to provide column {!r} for table {!r}'
                 ).format(self.name, self.table_name), logger):
            # inject registered variables matching the function's
            # argument names / expression defaults
            kwargs = _collect_variables(names=self._argspec.args,
                                        expressions=self._argspec.defaults)
            col = self._func(**kwargs)

        if self.cache:
            _COLUMN_CACHE[(self.table_name, self.name)] = CacheItem(
                (self.table_name, self.name), col, self.cache_scope)

        return col

    def clear_cached(self):
        """
        Remove any cached result of this column.

        """
        # pop with default so missing entries are a quiet no-op
        x = _COLUMN_CACHE.pop((self.table_name, self.name), None)
        if x is not None:
            logger.debug(
                'cleared cached value for column {!r} in table {!r}'.format(
                    self.name, self.table_name))

    def func_source_data(self):
        """
        Return data about the wrapped function source, including file name,
        line number, and source code.

        Returns
        -------
        filename : str
        lineno : int
            The line number on which the function starts.
        source : str

        """
        return utils.func_source_data(self._func)
848
849
850class _SeriesWrapper(object):
851    """
852    Wrap a Series for the purpose of giving it the same interface as a
853    `_ColumnFuncWrapper`.
854
855    Parameters
856    ----------
857    table_name : str
858        Table with which the column will be associated.
859    column_name : str
860        Name for the column.
861    series : pandas.Series
862        Series with index matching the table to which it is being added.
863
864    Attributes
865    ----------
866    name : str
867        Column name.
868    table_name : str
869        Name of table this column is associated with.
870
871    """
872    def __init__(self, table_name, column_name, series):
873        self.table_name = table_name
874        self.name = column_name
875        self._column = series
876
877    def __call__(self):
878        return self._column
879
880    def clear_cached(self):
881        """
882        Here for compatibility with `_ColumnFuncWrapper`.
883
884        """
885        pass
886
887
class _InjectableFuncWrapper(object):
    """
    Wraps a function that will provide an injectable value elsewhere.

    Parameters
    ----------
    name : str
    func : callable
    cache : bool, optional
        Whether to cache the result of calling the wrapped function.
    cache_scope : {'step', 'iteration', 'forever'}, optional
        Scope for which to cache data. Default is to cache forever
        (or until manually cleared). 'iteration' caches data for each
        complete iteration of the pipeline, 'step' caches data for
        a single step of the pipeline.

    Attributes
    ----------
    name : str
        Name of this injectable.
    cache : bool
        Whether caching is enabled for this injectable function.

    """
    def __init__(self, name, func, cache=False, cache_scope=_CS_FOREVER):
        self.name = name
        self._func = func
        # argspec captured once to drive dependency injection per call
        self._argspec = getargspec(func)
        self.cache = cache
        self.cache_scope = cache_scope

    def __call__(self):
        # serve from the global injectable cache when allowed
        if _CACHING and self.cache and self.name in _INJECTABLE_CACHE:
            logger.debug(
                'returning injectable {!r} from cache'.format(self.name))
            return _INJECTABLE_CACHE[self.name].value

        with log_start_finish(
                'call function to provide injectable {!r}'.format(self.name),
                logger):
            # inject registered variables matching the function's
            # argument names / expression defaults
            kwargs = _collect_variables(names=self._argspec.args,
                                        expressions=self._argspec.defaults)
            result = self._func(**kwargs)

        if self.cache:
            _INJECTABLE_CACHE[self.name] = CacheItem(
                self.name, result, self.cache_scope)

        return result

    def clear_cached(self):
        """
        Clear a cached result for this injectable.

        """
        # pop with default so missing entries are a quiet no-op
        x = _INJECTABLE_CACHE.pop(self.name, None)
        if x:
            logger.debug(
                'injectable {!r} removed from cache'.format(self.name))
947
948
949class _StepFuncWrapper(object):
950    """
951    Wrap a step function for argument matching.
952
953    Parameters
954    ----------
955    step_name : str
956    func : callable
957
958    Attributes
959    ----------
960    name : str
961        Name of step.
962
963    """
964    def __init__(self, step_name, func):
965        self.name = step_name
966        self._func = func
967        self._argspec = getargspec(func)
968
969    def __call__(self):
970        with log_start_finish('calling step {!r}'.format(self.name), logger):
971            kwargs = _collect_variables(names=self._argspec.args,
972                                        expressions=self._argspec.defaults)
973            return self._func(**kwargs)
974
975    def _tables_used(self):
976        """
977        Tables injected into the step.
978
979        Returns
980        -------
981        tables : set of str
982
983        """
984        args = list(self._argspec.args)
985        if self._argspec.defaults:
986            default_args = list(self._argspec.defaults)
987        else:
988            default_args = []
989        # Combine names from argument names and argument default values.
990        names = args[:len(args) - len(default_args)] + default_args
991        tables = set()
992        for name in names:
993            parent_name = name.split('.')[0]
994            if is_table(parent_name):
995                tables.add(parent_name)
996        return tables
997
998    def func_source_data(self):
999        """
1000        Return data about a step function's source, including file name,
1001        line number, and source code.
1002
1003        Returns
1004        -------
1005        filename : str
1006        lineno : int
1007            The line number on which the function starts.
1008        source : str
1009
1010        """
1011        return utils.func_source_data(self._func)
1012
1013
def is_table(name):
    """
    Returns whether a given name refers to a registered table.

    """
    # membership test against the table registry
    return name in _TABLES.keys()
1020
1021
def list_tables():
    """
    List of table names.

    """
    # a dict iterates over its keys
    return list(_TABLES)
1028
1029
def list_columns():
    """
    List of (table name, registered column name) pairs.

    """
    # keys are (table_name, column_name) tuples
    return list(_COLUMNS)
1036
1037
def list_steps():
    """
    List of registered step names.

    """
    return list(_STEPS)
1044
1045
def list_injectables():
    """
    List of registered injectables.

    """
    return list(_INJECTABLES)
1052
1053
def list_broadcasts():
    """
    List of registered broadcasts as (cast table name, onto table name).

    """
    # keys are (cast_name, onto_name) tuples
    return list(_BROADCASTS)
1060
1061
def is_expression(name):
    """
    Checks whether a given name is a simple variable name or a compound
    variable expression.

    Parameters
    ----------
    name : str

    Returns
    -------
    is_expr : bool

    """
    # expressions are dotted, e.g. 'table.column'
    return name.count('.') > 0
1077
1078
def _collect_variables(names, expressions=None):
    """
    Map labels and expressions to registered variables.

    Handles argument matching.

    Example:

        _collect_variables(names=['zones', 'zone_id'],
                           expressions=['parcels.zone_id'])

    Would return a dict representing:

        {'parcels': <DataFrameWrapper for zones>,
         'zone_id': <pandas.Series for parcels.zone_id>}

    Parameters
    ----------
    names : list of str
        List of registered variable names and/or labels.
        If mixing names and labels, labels must come at the end.
    expressions : list of str, optional
        List of registered variable expressions for labels defined
        at end of `names`. Length must match the number of labels.

    Returns
    -------
    variables : dict
        Keys match `names`. Values correspond to registered variables,
        which may be wrappers or evaluated functions if appropriate.

    """
    expressions = expressions or []

    # Labels before the offset map to themselves; labels after it map
    # to the provided expressions, pairwise.
    n_plain = len(names) - len(expressions)
    labels_map = {label: label for label in names[:n_plain]}
    labels_map.update(zip(names[n_plain:], expressions))

    # Injectables and tables share one namespace; tables win on conflict
    # (same precedence as tz.merge(_INJECTABLES, _TABLES)).
    all_variables = dict(_INJECTABLES)
    all_variables.update(_TABLES)

    variables = {}
    for label, expression in labels_map.items():
        # In the future, more registered variable expressions could be
        # supported. Currently supports names of registered variables
        # and references to table columns.
        if '.' not in expression:
            registered = all_variables[expression]
            if isinstance(registered, (_InjectableFuncWrapper,
                                       TableFuncWrapper)):
                # Registered variable object is function.
                variables[label] = registered()
            else:
                variables[label] = registered
        else:
            # Registered variable expression refers to column.
            table_name, column_name = expression.split('.')
            variables[label] = get_table(table_name).get_column(column_name)

    return variables
1139
1140
def add_table(
        table_name, table, cache=False, cache_scope=_CS_FOREVER,
        copy_col=True):
    """
    Register a table with Orca.

    Parameters
    ----------
    table_name : str
        Should be globally unique to this table.
    table : pandas.DataFrame or function
        If a function, the function should return a DataFrame.
        The function's argument names and keyword argument values
        will be matched to registered variables when the function
        needs to be evaluated by Orca.
    cache : bool, optional
        Whether to cache the results of a provided callable. Does not
        apply if `table` is a DataFrame.
    cache_scope : {'step', 'iteration', 'forever'}, optional
        Scope for which to cache data. Default is to cache forever
        (or until manually cleared). 'iteration' caches data for each
        complete iteration of the pipeline, 'step' caches data for
        a single step of the pipeline.
    copy_col : bool, optional
        Whether to return copies when evaluating columns.

    Returns
    -------
    wrapped : `DataFrameWrapper` or `TableFuncWrapper`

    """
    # Callables get a lazy wrapper; DataFrames are wrapped directly.
    if isinstance(table, Callable):
        wrapped = TableFuncWrapper(table_name, table, cache=cache,
                                   cache_scope=cache_scope, copy_col=copy_col)
    else:
        wrapped = DataFrameWrapper(table_name, table, copy_col=copy_col)

    # drop any cached data left over from a previously registered table
    # of the same name
    wrapped.clear_cached()

    logger.debug('registering table {!r}'.format(table_name))
    _TABLES[table_name] = wrapped

    return wrapped
1185
1186
def table(
        table_name=None, cache=False, cache_scope=_CS_FOREVER, copy_col=True):
    """
    Decorates functions that return DataFrames.

    Decorator version of `add_table`. Table name defaults to
    name of function.

    The function's argument names and keyword argument values
    will be matched to registered variables when the function
    needs to be evaluated by Orca.
    The argument name "iter_var" may be used to have the current
    iteration variable injected.

    """
    def decorator(func):
        # fall back to the function's own name when none was given
        name = table_name if table_name else func.__name__
        add_table(
            name, func, cache=cache, cache_scope=cache_scope,
            copy_col=copy_col)
        return func
    return decorator
1212
1213
def get_raw_table(table_name):
    """
    Get a wrapped table by name and don't do anything to it.

    Parameters
    ----------
    table_name : str

    Returns
    -------
    table : DataFrameWrapper or TableFuncWrapper

    Raises
    ------
    KeyError
        If no table is registered under `table_name`.

    """
    try:
        return _TABLES[table_name]
    except KeyError:
        raise KeyError('table not found: {}'.format(table_name))
1231
1232
def get_table(table_name):
    """
    Get a registered table.

    Decorated functions will be converted to `DataFrameWrapper`.

    Parameters
    ----------
    table_name : str

    Returns
    -------
    table : `DataFrameWrapper`

    """
    wrapped = get_raw_table(table_name)
    # evaluate function-backed tables so callers always get a
    # DataFrameWrapper
    return wrapped() if isinstance(wrapped, TableFuncWrapper) else wrapped
1252
1253
def table_type(table_name):
    """
    Returns the type of a registered table.

    The type can be either "dataframe" or "function".

    Parameters
    ----------
    table_name : str

    Returns
    -------
    table_type : {'dataframe', 'function'}

    """
    wrapped = get_raw_table(table_name)

    if isinstance(wrapped, DataFrameWrapper):
        return 'dataframe'
    if isinstance(wrapped, TableFuncWrapper):
        return 'function'
1275
1276
def add_column(
        table_name, column_name, column, cache=False, cache_scope=_CS_FOREVER):
    """
    Add a new column to a table from a Series or callable.

    Parameters
    ----------
    table_name : str
        Table with which the column will be associated.
    column_name : str
        Name for the column.
    column : pandas.Series or callable
        Series should have an index matching the table to which it
        is being added. If a callable, the function's argument
        names and keyword argument values will be matched to
        registered variables when the function needs to be
        evaluated by Orca. The function should return a Series.
    cache : bool, optional
        Whether to cache the results of a provided callable. Does not
        apply if `column` is a Series.
    cache_scope : {'step', 'iteration', 'forever'}, optional
        Scope for which to cache data. Default is to cache forever
        (or until manually cleared). 'iteration' caches data for each
        complete iteration of the pipeline, 'step' caches data for
        a single step of the pipeline.

    """
    # Callables are wrapped lazily; Series are wrapped directly.
    if isinstance(column, Callable):
        wrapped = _ColumnFuncWrapper(
            table_name, column_name, column,
            cache=cache, cache_scope=cache_scope)
    else:
        wrapped = _SeriesWrapper(table_name, column_name, column)

    # drop stale cached data from a previously registered column of the
    # same name
    wrapped.clear_cached()

    logger.debug('registering column {!r} on table {!r}'.format(
        column_name, table_name))
    _COLUMNS[(table_name, column_name)] = wrapped

    return wrapped
1320
1321
def column(table_name, column_name=None, cache=False, cache_scope=_CS_FOREVER):
    """
    Decorates functions that return a Series.

    Decorator version of `add_column`. Series index must match
    the named table. Column name defaults to name of function.

    The function's argument names and keyword argument values
    will be matched to registered variables when the function
    needs to be evaluated by Orca.
    The argument name "iter_var" may be used to have the current
    iteration variable injected.
    The index of the returned Series must match the named table.

    """
    def decorator(func):
        # fall back to the function's own name when none was given
        name = column_name if column_name else func.__name__
        add_column(
            table_name, name, func, cache=cache, cache_scope=cache_scope)
        return func
    return decorator
1346
1347
def list_columns_for_table(table_name):
    """
    Return a list of all the extra columns registered for a given table.

    Parameters
    ----------
    table_name : str

    Returns
    -------
    columns : list of str

    """
    found = []
    for tname, cname in _COLUMNS.keys():
        if tname == table_name:
            found.append(cname)
    return found
1362
1363
def _columns_for_table(table_name):
    """
    Return all of the columns registered for a given table.

    Parameters
    ----------
    table_name : str

    Returns
    -------
    columns : dict of column wrappers
        Keys will be column names.

    """
    result = {}
    for (tname, cname), wrapped in _COLUMNS.items():
        if tname == table_name:
            result[cname] = wrapped
    return result
1381
1382
def column_map(tables, columns):
    """
    Take a list of tables and a list of column names and resolve which
    columns come from which table.

    Parameters
    ----------
    tables : sequence of _DataFrameWrapper or _TableFuncWrapper
        Could also be sequence of modified pandas.DataFrames, the important
        thing is that they have ``.name`` and ``.columns`` attributes.
    columns : sequence of str
        The column names of interest.

    Returns
    -------
    col_map : dict
        Maps table names to lists of column names.
        (Column-name lists are built from sets, so their order is
        unspecified.)

    Raises
    ------
    RuntimeError
        If any requested column is not found on any of the tables.

    """
    if not columns:
        return {t.name: None for t in tables}

    columns = set(columns)
    colmap = {
        t.name: list(set(t.columns).intersection(columns)) for t in tables}
    # union of all columns that were actually resolved; set().union(*...)
    # also handles an empty `tables` sequence gracefully
    foundcols = set().union(*(set(v) for v in colmap.values()))
    if foundcols != columns:
        raise RuntimeError('Not all required columns were found. '
                           'Missing: {}'.format(list(columns - foundcols)))
    return colmap
1413
1414
def get_raw_column(table_name, column_name):
    """
    Get a wrapped, registered column.

    This function cannot return columns that are part of wrapped
    DataFrames, it's only for columns registered directly through Orca.

    Parameters
    ----------
    table_name : str
    column_name : str

    Returns
    -------
    wrapped : _SeriesWrapper or _ColumnFuncWrapper

    Raises
    ------
    KeyError
        If no such column has been registered.

    """
    key = (table_name, column_name)
    if key not in _COLUMNS:
        raise KeyError('column {!r} not found for table {!r}'.format(
            column_name, table_name))
    return _COLUMNS[key]
1437
1438
def _memoize_function(f, name, cache_scope=_CS_FOREVER):
    """
    Wrap a function for memoization and tie its cache into the
    Orca caching system.

    Parameters
    ----------
    f : function
    name : str
        Name of injectable.
    cache_scope : {'step', 'iteration', 'forever'}, optional
        Scope for which to cache data. Default is to cache forever
        (or until manually cleared). 'iteration' caches data for each
        complete iteration of the pipeline, 'step' caches data for
        a single step of the pipeline.

    """
    cache = {}

    @wraps(f)
    def wrapper(*args, **kwargs):
        # key on positional and keyword args; None stands in for "no args"
        try:
            key = (
                args or None, frozenset(kwargs.items()) if kwargs else None)
            hit = key in cache
        except TypeError:
            raise TypeError(
                'function arguments must be hashable for memoization')

        if _CACHING and hit:
            return cache[key]
        # compute and store (stored even when caching is globally off,
        # matching the global-cache-toggle semantics used elsewhere)
        result = f(*args, **kwargs)
        cache[key] = result
        return result

    wrapper.__wrapped__ = f
    wrapper.cache = cache
    wrapper.clear_cached = cache.clear
    # register so clear_cache/clear_all can purge this cache by scope
    _MEMOIZED[name] = CacheItem(name, wrapper, cache_scope)

    return wrapper
1481
1482
def add_injectable(
        name, value, autocall=True, cache=False, cache_scope=_CS_FOREVER,
        memoize=False):
    """
    Add a value that will be injected into other functions.

    Parameters
    ----------
    name : str
    value
        If a callable and `autocall` is True then the function's
        argument names and keyword argument values will be matched
        to registered variables when the function needs to be
        evaluated by Orca. The return value will
        be passed to any functions using this injectable. In all other
        cases, `value` will be passed through untouched.
    autocall : bool, optional
        Set to True to have injectable functions automatically called
        (with argument matching) and the result injected instead of
        the function itself.
    cache : bool, optional
        Whether to cache the return value of an injectable function.
        Only applies when `value` is a callable and `autocall` is True.
    cache_scope : {'step', 'iteration', 'forever'}, optional
        Scope for which to cache data. Default is to cache forever
        (or until manually cleared). 'iteration' caches data for each
        complete iteration of the pipeline, 'step' caches data for
        a single step of the pipeline.
    memoize : bool, optional
        If autocall is False it is still possible to cache function results
        by setting this flag to True. Cached values are stored in a dictionary
        keyed by argument values, so the argument values must be hashable.
        Memoized functions have their caches cleared according to the same
        rules as universal caching.

    """
    if isinstance(value, Callable):
        if autocall:
            value = _InjectableFuncWrapper(
                name, value, cache=cache, cache_scope=cache_scope)
            # drop stale cached data from a previously registered value
            value.clear_cached()
        elif memoize:
            # autocall is False here, so only memoize matters
            value = _memoize_function(value, name, cache_scope=cache_scope)

    logger.debug('registering injectable {!r}'.format(name))
    _INJECTABLES[name] = value
1530
1531
def injectable(
        name=None, autocall=True, cache=False, cache_scope=_CS_FOREVER,
        memoize=False):
    """
    Decorates functions that will be injected into other functions.

    Decorator version of `add_injectable`. Name defaults to
    name of function.

    The function's argument names and keyword argument values
    will be matched to registered variables when the function
    needs to be evaluated by Orca.
    The argument name "iter_var" may be used to have the current
    iteration variable injected.

    """
    def decorator(func):
        # fall back to the function's own name when none was given
        inj_name = name if name else func.__name__
        add_injectable(
            inj_name, func, autocall=autocall, cache=cache,
            cache_scope=cache_scope, memoize=memoize)
        return func
    return decorator
1558
1559
def is_injectable(name):
    """
    Check whether a given name can be mapped to a registered injectable.

    """
    return name in _INJECTABLES.keys()
1566
1567
def get_raw_injectable(name):
    """
    Return a raw, possibly wrapped injectable.

    Parameters
    ----------
    name : str

    Returns
    -------
    inj : _InjectableFuncWrapper or object

    Raises
    ------
    KeyError
        If no injectable is registered under `name`.

    """
    try:
        return _INJECTABLES[name]
    except KeyError:
        raise KeyError('injectable not found: {!r}'.format(name))
1585
1586
def injectable_type(name):
    """
    Classify an injectable as either 'variable' or 'function'.

    Parameters
    ----------
    name : str

    Returns
    -------
    inj_type : {'variable', 'function'}
        If the injectable is an automatically called function or any other
        type of callable the type will be 'function', all other injectables
        will be have type 'variable'.

    """
    inj = get_raw_injectable(name)
    is_func = isinstance(inj, (_InjectableFuncWrapper, Callable))
    return 'function' if is_func else 'variable'
1608
1609
def get_injectable(name):
    """
    Get an injectable by name.

    An `_InjectableFuncWrapper` (i.e. a function registered with
    ``autocall=True``) *is* evaluated here and its result returned.
    Any other registered value — including plain callables registered
    with ``autocall=False`` — is returned untouched.

    Parameters
    ----------
    name : str

    Returns
    -------
    injectable
        Original value or evaluated value of an _InjectableFuncWrapper.

    """
    i = get_raw_injectable(name)
    return i() if isinstance(i, _InjectableFuncWrapper) else i
1626
1627
def get_injectable_func_source_data(name):
    """
    Return data about an injectable function's source, including file name,
    line number, and source code.

    Parameters
    ----------
    name : str

    Returns
    -------
    filename : str
    lineno : int
        The line number on which the function starts.
    source : str

    Raises
    ------
    ValueError
        If the named injectable is not a function.

    """
    if injectable_type(name) != 'function':
        raise ValueError('injectable {!r} is not a function'.format(name))

    inj = get_raw_injectable(name)

    # unwrap to the underlying user function before introspecting
    if isinstance(inj, _InjectableFuncWrapper):
        func = inj._func
    elif hasattr(inj, '__wrapped__'):
        func = inj.__wrapped__
    else:
        func = inj
    return utils.func_source_data(func)
1656
1657
def add_step(step_name, func):
    """
    Add a step function to Orca.

    The function's argument names and keyword argument values
    will be matched to registered variables when the function
    needs to be evaluated by Orca.
    The argument name "iter_var" may be used to have the current
    iteration variable injected.

    Parameters
    ----------
    step_name : str
    func : callable

    Raises
    ------
    TypeError
        If `func` is not callable.

    """
    # guard clause: reject non-callables up front
    if not isinstance(func, Callable):
        raise TypeError('func must be a callable')
    logger.debug('registering step {!r}'.format(step_name))
    _STEPS[step_name] = _StepFuncWrapper(step_name, func)
1679
1680
def step(step_name=None):
    """
    Decorates functions that will be called by the `run` function.

    Decorator version of `add_step`. step name defaults to
    name of function.

    The function's argument names and keyword argument values
    will be matched to registered variables when the function
    needs to be evaluated by Orca.
    The argument name "iter_var" may be used to have the current
    iteration variable injected.

    """
    def decorator(func):
        # fall back to the function's own name when none was given
        add_step(step_name if step_name else func.__name__, func)
        return func
    return decorator
1703
1704
def is_step(step_name):
    """
    Check whether a given name refers to a registered step.

    """
    return step_name in _STEPS.keys()
1711
1712
def get_step(step_name):
    """
    Get a wrapped step by name.

    Parameters
    ----------
    step_name : str

    Returns
    -------
    step : _StepFuncWrapper

    Raises
    ------
    KeyError
        If no step is registered under `step_name`.

    """
    try:
        return _STEPS[step_name]
    except KeyError:
        raise KeyError('no step named {}'.format(step_name))
1725
1726
# Record describing a registered merge rule: table `cast` is broadcast onto
# table `onto`, joining on the named columns (cast_on/onto_on) and/or the
# table indexes (cast_index/onto_index) — mirroring pandas.merge's
# left_on/right_on and left_index/right_index parameters.
Broadcast = namedtuple(
    'Broadcast',
    ['cast', 'onto', 'cast_on', 'onto_on', 'cast_index', 'onto_index'])
1730
1731
def broadcast(cast, onto, cast_on=None, onto_on=None,
              cast_index=False, onto_index=False):
    """
    Register a rule for merging two tables by broadcasting one onto
    the other.

    Parameters
    ----------
    cast, onto : str
        Names of registered tables.
    cast_on, onto_on : str, optional
        Column names used for merge, equivalent of ``left_on``/``right_on``
        parameters of pandas.merge.
    cast_index, onto_index : bool, optional
        Whether to use table indexes for merge. Equivalent of
        ``left_index``/``right_index`` parameters of pandas.merge.

    """
    logger.debug(
        'registering broadcast of table {!r} onto {!r}'.format(cast, onto))
    rule = Broadcast(cast, onto, cast_on, onto_on, cast_index, onto_index)
    _BROADCASTS[(cast, onto)] = rule
1754
1755
def _get_broadcasts(tables):
    """
    Get the broadcasts associated with a set of tables.

    Parameters
    ----------
    tables : sequence of str
        Table names for which broadcasts have been registered.

    Returns
    -------
    casts : dict of `Broadcast`
        Keys are tuples of strings like (cast_name, onto_name).

    Raises
    ------
    ValueError
        If some tables appear in no broadcast with another given table.

    """
    tables = set(tables)
    # keep only broadcasts where both endpoints are among `tables`
    casts = {key: bc for key, bc in _BROADCASTS.items()
             if key[0] in tables and key[1] in tables}
    # every table must appear in at least one selected broadcast
    linked = set()
    for cast_name, onto_name in casts:
        linked.add(cast_name)
        linked.add(onto_name)
    if tables - linked:
        raise ValueError('Not enough links to merge all tables.')
    return casts
1777
1778
def is_broadcast(cast_name, onto_name):
    """
    Check whether a merge rule exists for broadcasting `cast_name`
    onto `onto_name`.

    """
    return (cast_name, onto_name) in _BROADCASTS.keys()
1786
1787
def get_broadcast(cast_name, onto_name):
    """
    Get a single broadcast.

    Broadcasts are stored data about how to do a Pandas join.
    A Broadcast object is a namedtuple with these attributes:

        - cast: the name of the table being broadcast
        - onto: the name of the table onto which "cast" is broadcast
        - cast_on: The optional name of a column on which to join.
          None if the table index will be used instead.
        - onto_on: The optional name of a column on which to join.
          None if the table index will be used instead.
        - cast_index: True if the table index should be used for the join.
        - onto_index: True if the table index should be used for the join.

    Parameters
    ----------
    cast_name : str
        The name of the table being broadcast.
    onto_name : str
        The name of the table onto which `cast_name` is broadcast.

    Returns
    -------
    broadcast : Broadcast

    Raises
    ------
    KeyError
        If no such broadcast rule is registered.

    """
    try:
        return _BROADCASTS[(cast_name, onto_name)]
    except KeyError:
        raise KeyError(
            'no rule found for broadcasting {!r} onto {!r}'.format(
                cast_name, onto_name))
1822
1823
1824# utilities for merge_tables
1825def _all_reachable_tables(t):
1826    """
1827    A generator that provides all the names of tables that can be
1828    reached via merges starting at the given target table.
1829
1830    """
1831    for k, v in t.items():
1832        for tname in _all_reachable_tables(v):
1833            yield tname
1834        yield k
1835
1836
1837def _recursive_getitem(d, key):
1838    """
1839    Descend into a dict of dicts to return the one that contains
1840    a given key. Every value in the dict must be another dict.
1841
1842    """
1843    if key in d:
1844        return d
1845    else:
1846        for v in d.values():
1847            return _recursive_getitem(v, key)
1848        else:
1849            raise KeyError('Key not found: {}'.format(key))
1850
1851
1852def _dict_value_to_pairs(d):
1853    """
1854    Takes the first value of a dictionary (which it self should be
1855    a dictionary) and turns it into a series of {key: value} dicts.
1856
1857    For example, _dict_value_to_pairs({'c': {'a': 1, 'b': 2}}) will yield
1858    {'a': 1} and {'b': 2}.
1859
1860    """
1861    d = d[tz.first(d)]
1862
1863    for k, v in d.items():
1864        yield {k: v}
1865
1866
1867def _is_leaf_node(merge_node):
1868    """
1869    Returns True for dicts like {'a': {}}.
1870
1871    """
1872    return len(merge_node) == 1 and not next(iter(merge_node.values()))
1873
1874
def _next_merge(merge_node):
    """
    Return a node that has only leaf nodes below it. This table and
    the ones below are ready to be merged to make a new leaf node.

    """
    children = list(_dict_value_to_pairs(merge_node))
    if all(_is_leaf_node(child) for child in children):
        return merge_node
    # descend into the first non-leaf child
    for child in children:
        if not _is_leaf_node(child):
            return _next_merge(child)
    raise OrcaError('No node found for next merge.')
1888
1889
def merge_tables(target, tables, columns=None, drop_intersection=True):
    """
    Merge a number of tables onto a target table. Tables must have
    registered merge rules via the `broadcast` function.

    Parameters
    ----------
    target : str, DataFrameWrapper, or TableFuncWrapper
        Name of the table (or wrapped table) onto which tables will be merged.
    tables : list of `DataFrameWrapper`, `TableFuncWrapper`, or str
        All of the tables to merge. Should include the target table.
    columns : list of str, optional
        If given, columns will be mapped to `tables` and only those columns
        will be requested from each table. The final merged table will have
        only these columns. By default all columns are used from every
        table.
    drop_intersection : bool
        If True, keep the left most occurence of any column name if it occurs
        on more than one table.  This prevents getting back the same column
        with suffixes applied by pd.merge.  If false, columns names will be
        suffixed with the table names - e.g. zone_id_buildings and
        zone_id_parcels.

    Returns
    -------
    merged : pandas.DataFrame

    """
    # allow target to be string or table wrapper
    if isinstance(target, (DataFrameWrapper, TableFuncWrapper)):
        target = target.name

    # allow tables to be strings or table wrappers
    tables = [get_table(t)
              if not isinstance(t, (DataFrameWrapper, TableFuncWrapper)) else t
              for t in tables]

    merges = {t.name: {} for t in tables}
    tables = {t.name: t for t in tables}
    casts = _get_broadcasts(tables.keys())
    logger.debug(
        'attempting to merge tables {} to target table {}'.format(
            tables.keys(), target))

    # relate all the tables by registered broadcasts
    for table, onto in casts:
        merges[onto][table] = merges[table]
    merges = {target: merges[target]}

    # verify that all the tables can be merged to the target
    all_tables = set(_all_reachable_tables(merges))

    if all_tables != set(tables.keys()):
        raise RuntimeError(
            ('Not all tables can be merged to target "{}". Unlinked tables: {}'
             ).format(target, list(set(tables.keys()) - all_tables)))

    # add any columns necessary for indexing into other tables
    # during merges
    if columns:
        columns = list(columns)
        for c in casts.values():
            if c.onto_on:
                columns.append(c.onto_on)
            if c.cast_on:
                columns.append(c.cast_on)

    # get column map for which columns go with which table
    colmap = column_map(tables.values(), columns)

    # get frames
    frames = {name: t.to_frame(columns=colmap[name])
              for name, t in tables.items()}

    past_intersections = set()

    # perform merges until there's only one table left
    while merges[target]:
        nm = _next_merge(merges)
        onto = tz.first(nm)
        onto_table = frames[onto]

        # loop over all the tables that can be broadcast onto
        # the onto_table and merge them all in.
        for cast in nm[onto]:
            cast_table = frames[cast]
            bc = casts[(cast, onto)]

            with log_start_finish(
                    'merge tables {} and {}'.format(onto, cast), logger):

                intersection = set(onto_table.columns).\
                    intersection(cast_table.columns)
                # intersection is ok if it's the join key
                intersection.discard(bc.onto_on)
                intersection.discard(bc.cast_on)
                # otherwise drop so as not to create conflicts
                if drop_intersection:
                    cast_table = cast_table.drop(intersection, axis=1)
                else:
                    # add suffix to past intersections which wouldn't get
                    # picked up by the merge - these we have to rename by hand
                    renames = dict(zip(
                        past_intersections,
                        [c+'_'+onto for c in past_intersections]
                    ))
                    onto_table = onto_table.rename(columns=renames)

                # keep track of past intersections in case there's an odd
                # number of intersections
                past_intersections = past_intersections.union(intersection)

                onto_table = pd.merge(
                    onto_table, cast_table,
                    suffixes=['_'+onto, '_'+cast],
                    left_on=bc.onto_on, right_on=bc.cast_on,
                    left_index=bc.onto_index, right_index=bc.cast_index)

            # free up space by dropping the cast table now that it has
            # been merged in. (This was previously outside the loop, so
            # only the *last* cast frame per onto-table was released.)
            del frames[cast]

        # replace the existing table with the merged one
        frames[onto] = onto_table

        # mark the onto table as having no more things to broadcast
        # onto it.
        _recursive_getitem(merges, onto)[onto] = {}

    logger.debug('finished merge')
    return frames[target]
2020
2021
def get_step_table_names(steps):
    """
    Returns a list of table names injected into the provided steps.

    Parameters
    ----------
    steps: list of str
        Steps to gather table inputs from.

    Returns
    -------
    list of str

    """
    names = set()
    for step_name in steps:
        names.update(get_step(step_name)._tables_used())
    return list(names)
2040
2041
def write_tables(fname, table_names=None, prefix=None, compress=False, local=False):
    """
    Writes tables to a pandas.HDFStore file.

    Parameters
    ----------
    fname : str
        File name for HDFStore. Will be opened in append mode and closed
        at the end of this function.
    table_names: list of str, optional, default None
        List of tables to write. If None, all registered tables will
        be written.
    prefix: str
        If not None, used to prefix the output table names so that
        multiple iterations can go in the same file.
    compress: boolean
        Whether to compress output file using standard HDF5-readable
        zlib compression, default False.
    local: boolean, optional, default False
        If True, only a table's local columns are written; computed
        columns registered on top of it are skipped.

    """
    if table_names is None:
        table_names = list_tables()

    tables = (get_table(t) for t in table_names)
    # e.g. prefix 'base' yields keys like 'base/households'
    key_template = '{}/{{}}'.format(prefix) if prefix is not None else '{}'

    # set compression options to zlib level-1 if compress arg is True
    complib = 'zlib' if compress else None
    complevel = 1 if compress else 0

    with pd.HDFStore(fname, mode='a', complib=complib, complevel=complevel) as store:
        for t in tables:
            # if local arg is True, store only local columns
            columns = t.local_columns if local is True else None
            store[key_template.format(t.name)] = t.to_frame(columns=columns)
2079
2080
2081iter_step = namedtuple('iter_step', 'step_num,step_name')
2082
2083
def run(steps, iter_vars=None, data_out=None, out_interval=1,
        out_base_tables=None, out_run_tables=None, compress=False,
        out_base_local=True, out_run_local=True):
    """
    Run steps in series, optionally repeatedly over some sequence.
    The current iteration variable is set as a global injectable
    called ``iter_var``.

    Parameters
    ----------
    steps : list of str
        List of steps to run identified by their name.
    iter_vars : iterable, optional
        The values of `iter_vars` will be made available as an injectable
        called ``iter_var`` when repeatedly running `steps`.
    data_out : str, optional
        An optional filename to which all tables injected into any step
        in `steps` will be saved every `out_interval` iterations.
        File will be a pandas HDF data store.
    out_interval : int, optional
        Iteration interval on which to save data to `data_out`. For example,
        2 will save out every 2 iterations, 5 every 5 iterations.
        Default is every iteration.
        The results of the first and last iterations are always included.
        The input (base) tables are also included and prefixed with `base/`,
        these represent the state of the system before any steps have been
        executed.
        The interval is defined relative to the first iteration. For example,
        a run beginning in 2015 with an out_interval of 2, will write out
        results for 2015, 2017, etc.
    out_base_tables: list of str, optional, default None
        List of base tables to write. If not provided, tables injected
        into 'steps' will be written.
    out_run_tables: list of str, optional, default None
        List of run tables to write. If not provided, tables injected
        into 'steps' will be written.
    compress: boolean, optional, default False
        Whether to compress output file using standard HDF5 zlib compression.
        Compression yields much smaller files using slightly more CPU.
    out_base_local: boolean, optional, default True
        For tables in out_base_tables, whether to store only local columns (True)
        or both, local and computed columns (False).
    out_run_local: boolean, optional, default True
        For tables in out_run_tables, whether to store only local columns (True)
        or both, local and computed columns (False).
    """
    # with no explicit iteration values, do a single pass with iter_var=None
    iter_vars = iter_vars or [None]
    max_i = len(iter_vars)

    # get the tables to write out
    # (default both output table lists to the tables injected into `steps`)
    if out_base_tables is None or out_run_tables is None:
        step_tables = get_step_table_names(steps)

        if out_base_tables is None:
            out_base_tables = step_tables

        if out_run_tables is None:
            out_run_tables = step_tables

    # write out the base (inputs)
    if data_out:
        add_injectable('iter_var', iter_vars[0])
        write_tables(data_out, out_base_tables, 'base', compress=compress, local=out_base_local)

    # run the steps
    # i is 1-based so the first iteration always satisfies the
    # (i - 1) % out_interval == 0 write condition below
    for i, var in enumerate(iter_vars, start=1):
        add_injectable('iter_var', var)

        if var is not None:
            print('Running iteration {} with iteration value {!r}'.format(
                i, var))
            logger.debug(
                'running iteration {} with iteration value {!r}'.format(
                    i, var))

        t1 = time.time()
        for j, step_name in enumerate(steps):
            # expose the current (position, name) pair as an injectable
            add_injectable('iter_step', iter_step(j, step_name))
            print('Running step {!r}'.format(step_name))
            with log_start_finish(
                    'run step {!r}'.format(step_name), logger,
                    logging.INFO):
                step = get_step(step_name)
                t2 = time.time()
                step()
                print("Time to execute step '{}': {:.2f} s".format(
                      step_name, time.time() - t2))
            # drop values cached only for the duration of a single step
            clear_cache(scope=_CS_STEP)

        print(
            ('Total time to execute iteration {} '
             'with iteration value {!r}: '
             '{:.2f} s').format(i, var, time.time() - t1))

        # write out the results for the current iteration
        # (first and last iterations are always written)
        if data_out:
            if (i - 1) % out_interval == 0 or i == max_i:
                write_tables(data_out, out_run_tables, var, compress=compress, local=out_run_local)

        # drop values cached only for the duration of a single iteration
        clear_cache(scope=_CS_ITER)
2184
2185
@contextmanager
def injectables(**kwargs):
    """
    Temporarily add injectables to the pipeline environment.
    Takes only keyword arguments.

    Injectables will be returned to their original state when the context
    manager exits, even if the managed block raises an exception.

    """
    global _INJECTABLES

    original = _INJECTABLES.copy()
    _INJECTABLES.update(kwargs)
    try:
        yield
    finally:
        # restore even on exception so a failure inside the with-block
        # cannot leave the temporary injectables registered globally
        _INJECTABLES = original
2202
2203
@contextmanager
def temporary_tables(**kwargs):
    """
    Temporarily set DataFrames as registered tables.

    Tables will be returned to their original state when the context
    manager exits, even if the managed block raises an exception.
    Caching is not enabled for tables registered via this function.

    Raises
    ------
    ValueError
        If any keyword value is not a pandas.DataFrame.

    """
    global _TABLES

    original = _TABLES.copy()

    try:
        for k, v in kwargs.items():
            if not isinstance(v, pd.DataFrame):
                raise ValueError('tables only accepts DataFrames')
            add_table(k, v)

        yield
    finally:
        # restore on any exit path: an exception inside the with-block,
        # or a ValueError raised mid-registration above, must not leave
        # partially-added temporary tables in the global registry
        _TABLES = original
2226
2227
def eval_variable(name, **kwargs):
    """
    Execute a single variable function registered with Orca
    and return the result. Any keyword arguments are temporarily set
    as injectables. This gives the value as would be injected into a function.

    Parameters
    ----------
    name : str
        Name of variable to evaluate.
        Use variable expressions to specify columns.

    Returns
    -------
    object
        For injectables and columns this directly returns whatever
        object is returned by the registered function.
        For tables this returns a DataFrameWrapper as if the table
        had been injected into a function.

    """
    # evaluate inside a temporary environment so kwargs don't leak
    # into global pipeline state
    with injectables(**kwargs):
        # renamed from `vars` to avoid shadowing the builtin
        collected = _collect_variables([name], [name])
        return collected[name]
2252
2253
def eval_step(name, **kwargs):
    """
    Run a single registered step as the pipeline would and return its
    result. Keyword arguments are installed as temporary injectables
    for the duration of the call.

    Parameters
    ----------
    name : str
        Name of the step to run.

    Returns
    -------
    object
        Whatever the step function returns. (Note that during a full
        Orca run, step return values are discarded.)

    """
    with injectables(**kwargs):
        step = get_step(name)
        return step()
2274