1# -*- coding: utf-8 -*-
2from __future__ import absolute_import
3from .load_csv import load_csv
4from .temptable import (
5    load_data,
6    new_table_name,
7    savepoint,
8)
9from .squint.query import DEFAULT_CONNECTION
10
11
def _load_temp_sqlite_table(columns, records):
    """Load *columns* and *records* into a new table on the shared
    default connection and return a (connection, table_name) tuple.
    """
    global DEFAULT_CONNECTION
    cur = DEFAULT_CONNECTION.cursor()
    with savepoint(cur):  # Roll back partial loads on error.
        table_name = new_table_name(cur)
        load_data(cur, table_name, columns, records)
    return DEFAULT_CONNECTION, table_name
19
20
21########################################################################
22# From sources/base.py
23########################################################################
24from .._compatibility.builtins import *
25from .._compatibility.collections.abc import Sequence
26from .._compatibility import decimal
27from .._compatibility import functools
28from .._utils import nonstringiter
29from .api07_comp import CompareDict
30from .api07_comp import CompareSet
31
32
class BaseSource(object):
    """Common base class for all data sources.  Custom sources can be
    created by subclassing BaseSource and implementing
    :meth:`__init__()`, :meth:`__repr__()`, :meth:`columns()` and
    :meth:`__iter__()`.

    All data sources implement a common set of methods.
    """
    def __new__(cls, *args, **kwds):
        # Guard: BaseSource is abstract and must not be instantiated
        # directly--only subclasses may be created.
        if cls is BaseSource:
            msg = ('Cannot instantiate BaseSource directly.  Use a '
                   'data source of the appropriate type or make a '
                   'subclass.')
            raise NotImplementedError(msg)
        return super(BaseSource, cls).__new__(cls)

    def __init__(self):
        """Initialize self."""
        return NotImplemented

    def __repr__(self):
        """Returns string representation describing the data source."""
        return NotImplemented

    def columns(self):
        """Returns list of column names."""
        return NotImplemented

    def __iter__(self):
        """Returns iterable of dictionary rows (like
        :class:`csv.DictReader`)."""
        return NotImplemented

    def filter_rows(self, **kwds):
        """Returns iterable of dictionary rows (like
        :class:`csv.DictReader`) filtered by keywords.  E.g., where
        column1=value1, column2=value2, etc. (unoptimized, uses
        :meth:`__iter__`).
        """
        if kwds:
            # A string value means "equals"; a non-string iterable
            # means "is one of"--normalize strings to 1-tuples so the
            # membership test below handles both cases.
            normalize = lambda v: (v,) if isinstance(v, str) else v
            kwds = dict((k, normalize(v)) for k, v in kwds.items())
            matches_kwds = lambda row: all(row[k] in v for k, v in kwds.items())
            return filter(matches_kwds, self.__iter__())
        return self.__iter__()

    def distinct(self, columns, **kwds_filter):
        """Returns :class:`CompareSet` of distinct values or distinct
        tuples of values if given multiple *columns* (unoptimized, uses
        :meth:`__iter__`).
        """
        if not nonstringiter(columns):
            columns = (columns,)
        self._assert_columns_exist(columns)
        iterable = self.filter_rows(**kwds_filter)  # Filtered rows only.
        iterable = (tuple(row[c] for c in columns) for row in iterable)
        return CompareSet(iterable)

    def sum(self, column, keys=None, **kwds_filter):
        """Returns :class:`CompareDict` containing sums of *column*
        values grouped by *keys*.
        """
        # Empty/falsy values count as zero instead of raising on
        # Decimal('') conversion.
        mapper = lambda x: decimal.Decimal(x) if x else decimal.Decimal(0)
        reducer = lambda x, y: x + y
        return self.mapreduce(mapper, reducer, column, keys, **kwds_filter)

    def count(self, column, keys=None, **kwds_filter):
        """Returns :class:`CompareDict` containing count of non-empty
        *column* values grouped by *keys*.
        """
        mapper = lambda value: 1 if value else 0  # 1 for truthy, 0 for falsy
        reducer = lambda x, y: x + y
        return self.mapreduce(mapper, reducer, column, keys, **kwds_filter)

    def mapreduce(self, mapper, reducer, columns, keys=None, **kwds_filter):
        """Apply a *mapper* to specified *columns* (which are grouped by
        *keys* and filtered by keywords) then apply a *reducer* of two
        arguments cumulatively to the mapped values, from left to right,
        so as to reduce the values to a single result (per group of
        *keys*).  If *keys* is omitted, a single result is returned,
        otherwise returns a :class:`CompareDict` object.

        *mapper* (function or other callable):
            Should accept a column value and return a computed result.
            Mapper always receives a single argument---if *columns* is a
            sequence, *mapper* will receive a tuple of values from the
            specified columns.
        *reducer* (function or other callable):
            Should accept two arguments (values produced by *mapper*)
            and apply them, from left to right, to return a single
            result.
        *columns* (string or sequence):
            Name of column or columns whose values are passed to
            *mapper*.
        *keys* (None, string, or sequence):
            Name of key or keys used to group column values.
        *kwds_filter*:
            Keywords used to filter rows.
        """
        if isinstance(columns, str):
            get_value = lambda row: row[columns]
        elif isinstance(columns, Sequence):
            get_value = lambda row: tuple(row[column] for column in columns)
        else:
            raise TypeError('columns must be str or sequence')

        filtered_rows = self.filter_rows(**kwds_filter)

        if not keys:
            filtered_values = (get_value(row) for row in filtered_rows)
            mapped_values = (mapper(value) for value in filtered_values)
            return functools.reduce(reducer, mapped_values)  # <- EXIT!

        if not nonstringiter(keys):
            keys = (keys,)
        self._assert_columns_exist(keys)

        result = {}                            # Do not remove this
        for row in filtered_rows:              # accumulator and loop
            y = get_value(row)                 # without a good reason!
            y = mapper(y)                      # While a more functional
            key = tuple(row[k] for k in keys)  # style (using sorted,
            if key in result:                  # groupby, and reduce)
                x = result[key]                # is nicer to read, this
                result[key] = reducer(x, y)    # base class should
            else:                              # prioritize memory
                result[key] = y                # efficiency over other
        return CompareDict(result, keys)       # considerations.

    def _assert_columns_exist(self, columns):
        """Asserts that given columns are present in data source,
        raises LookupError if columns are missing.
        """
        if not nonstringiter(columns):
            columns = (columns,)
        self_cols = self.columns()
        is_missing = lambda col: col not in self_cols
        missing = [c for c in columns if is_missing(c)]
        if missing:
            missing = ', '.join(repr(x) for x in missing)
            msg = '{0} not in {1}'.format(missing, self.__repr__())
            raise LookupError(msg)
175
176
177########################################################################
178# For Testing
179########################################################################
class MinimalSource(BaseSource):
    """Minimal data source implementation for testing."""
    def __init__(self, data, fieldnames=None):
        if not fieldnames:
            rows = iter(data)
            fieldnames = next(rows)  # First row supplies the field names.
            data = list(rows)        # Remaining rows supply the data.
        self._data = data
        self._fieldnames = fieldnames

    def __repr__(self):
        return '{0}(<data>, <fieldnames>)'.format(self.__class__.__name__)

    def columns(self):
        return self._fieldnames

    def __iter__(self):
        names = self._fieldnames
        for values in self._data:
            yield dict(zip(names, values))
199
200
201########################################################################
202# From sources/adapter.py
203########################################################################
204from .._compatibility.builtins import *
205from .._compatibility.collections.abc import Sequence
206from .._utils import nonstringiter
207from .api07_comp import CompareDict
208from .api07_comp import CompareSet
209
210
211class _FilterValueError(ValueError):
212    """Used by AdapterSource.  This error is raised when attempting to
213    unwrap a filter that specifies an inappropriate (non-missing) value
214    for a missing column."""
215    pass
216
217
class AdapterSource(BaseSource):
    """A wrapper class that adapts a data *source* to an *interface* of
    column names. The *interface* should be a sequence of 2-tuples where
    the first item is the existing column name and the second item is
    the desired column name. If column order is not important, the
    *interface* can, alternatively, be a dictionary.

    For example, a CSV file that contains the columns 'AAA', 'BBB',
    and 'DDD' can be adapted to behave as if it has the columns
    'AAA', 'BBB', 'CCC' and 'DDD' with the following::

        source = CsvSource('mydata.csv')
        interface = [
            ('AAA', 'AAA'),
            ('BBB', 'BBB'),
            (None,  'CCC'),
            ('DDD', 'DDD'),
        ]
        subject = AdapterSource(source, interface)

    An :class:`AdapterSource` can be thought of as a virtual source that
    renames, reorders, adds, or removes columns of the original
    *source*.

    To add a column that does not exist in original, use None in place
    of a column name (see column 'CCC', above). Columns mapped to None
    will contain *missing* values (defaults to empty string).  To remove
    a column, simply omit it from the interface.

    The original source can be accessed via the :attr:`__wrapped__`
    property.
    """
    def __init__(self, source, interface, missing=''):
        """Initialize self.  Raises KeyError if *interface* references
        a column that does not exist in *source* (None entries are
        allowed--they map to the *missing* value).
        """
        if not isinstance(interface, Sequence):
            if isinstance(interface, dict):
                interface = interface.items()
            interface = sorted(interface)

        source_columns = source.columns()
        interface_cols = [x[0] for x in interface]
        for c in interface_cols:
            if c is not None and c not in source_columns:
                raise KeyError(c)

        self._interface = list(interface)
        self._missing = missing
        self.__wrapped__ = source

    def __repr__(self):
        """Return a string representation of the adapted source."""
        self_class = self.__class__.__name__
        wrapped_repr = repr(self.__wrapped__)
        interface = self._interface
        missing = self._missing
        if missing != '':
            missing = ', missing=' + repr(missing)
        return '{0}({1}, {2}{3})'.format(self_class, wrapped_repr, interface, missing)

    def columns(self):
        """Return list of adapted column names (None entries on the
        *new* side of the interface are omitted)."""
        return [new for (old, new) in self._interface if new is not None]

    def __iter__(self):
        """Return iterable of dictionary rows with adapted column
        names; columns absent from the wrapped source yield the
        *missing* value."""
        interface = self._interface
        missing = self._missing
        for row in self.__wrapped__.__iter__():
            yield dict((new, row.get(old, missing)) for old, new in interface)

    def filter_rows(self, **kwds):
        """Return iterable of adapted dictionary rows filtered by
        keywords (delegates filtering to the wrapped source)."""
        try:
            unwrap_kwds = self._unwrap_filter(kwds)
        except _FilterValueError:
            return  # <- EXIT! Ends the generator, yielding nothing.

        interface = self._interface
        missing = self._missing
        for row in self.__wrapped__.filter_rows(**unwrap_kwds):
            yield dict((new, row.get(old, missing)) for old, new in interface)

    def distinct(self, columns, **kwds_filter):
        """Return :class:`CompareSet` of distinct values for adapted
        *columns* (delegates to the wrapped source when possible)."""
        unwrap_src = self.__wrapped__  # Unwrap data source.
        unwrap_cols = self._unwrap_columns(columns)
        try:
            unwrap_flt = self._unwrap_filter(kwds_filter)
        except _FilterValueError:
            return CompareSet([])  # <- EXIT!

        if not unwrap_cols:
            # All requested columns are "missing" columns--emit a
            # single row of missing values if the source has any data.
            iterable = iter(unwrap_src)
            try:
                next(iterable)  # Check for any data at all.
                length = 1 if isinstance(columns, str) else len(columns)
                result = [tuple([self._missing]) * length]  # Make 1 row of *missing* vals.
            except StopIteration:
                result = []  # If no data, result is empty.
            return CompareSet(result)  # <- EXIT!

        results = unwrap_src.distinct(unwrap_cols, **unwrap_flt)
        rewrap_cols = self._rewrap_columns(unwrap_cols)
        return self._rebuild_compareset(results, rewrap_cols, columns)

    def sum(self, column, keys=None, **kwds_filter):
        """Return sums of adapted *column* values grouped by *keys*."""
        return self._aggregate('sum', column, keys, **kwds_filter)

    def count(self, column, keys=None, **kwds_filter):
        """Return counts of non-empty adapted *column* values grouped
        by *keys*."""
        return self._aggregate('count', column, keys, **kwds_filter)

    def _aggregate(self, method, column, keys=None, **kwds_filter):
        """Call aggregation method ('sum' or 'count'), return result."""
        unwrap_src = self.__wrapped__
        unwrap_col = self._unwrap_columns(column)
        unwrap_keys = self._unwrap_columns(keys)
        try:
            unwrap_flt = self._unwrap_filter(kwds_filter)
        except _FilterValueError:
            if keys:
                result = CompareDict({}, keys)
            else:
                result = 0
            return result  # <- EXIT!

        # If all *columns* are missing, build result of missing values.
        if not unwrap_col:
            distinct = self.distinct(keys, **kwds_filter)
            result = ((key, 0) for key in distinct)
            return CompareDict(result, keys)  # <- EXIT!

        # Get method ('sum' or 'count') and perform aggregation.
        aggregate = getattr(unwrap_src, method)
        result = aggregate(unwrap_col, unwrap_keys, **unwrap_flt)

        rewrap_col = self._rewrap_columns(unwrap_col)
        rewrap_keys = self._rewrap_columns(unwrap_keys)
        return self._rebuild_comparedict(result, rewrap_col, column,
                                         rewrap_keys, keys, missing_col=0)

    def mapreduce(self, mapper, reducer, columns, keys=None, **kwds_filter):
        """Map-reduce over adapted *columns* grouped by *keys*
        (delegates to the wrapped source, then re-wraps the result)."""
        unwrap_src = self.__wrapped__
        unwrap_cols = self._unwrap_columns(columns)
        unwrap_keys = self._unwrap_columns(keys)
        try:
            unwrap_flt = self._unwrap_filter(kwds_filter)
        except _FilterValueError:
            if keys:
                result = CompareDict({}, keys)
            else:
                result = self._missing
            return result  # <- EXIT!

        # If all *columns* are missing, build result of missing values.
        if not unwrap_cols:
            distinct = self.distinct(keys, **kwds_filter)
            if isinstance(columns, str):
                val = self._missing
            else:
                val = (self._missing,) * len(columns)
            result = ((key, val) for key in distinct)
            return CompareDict(result, keys)  # <- EXIT!

        result = unwrap_src.mapreduce(mapper, reducer,
                                      unwrap_cols, unwrap_keys, **unwrap_flt)

        rewrap_cols = self._rewrap_columns(unwrap_cols)
        rewrap_keys = self._rewrap_columns(unwrap_keys)
        return self._rebuild_comparedict(result, rewrap_cols, columns,
                                           rewrap_keys, keys,
                                           missing_col=self._missing)

    def _unwrap_columns(self, columns, interface_dict=None):
        """Unwrap adapter *columns* to reveal hidden adaptee columns."""
        if not columns:
            return None  # <- EXIT!

        if not interface_dict:
            interface_dict = dict((new, old) for old, new in self._interface)

        if isinstance(columns, str):
            return interface_dict[columns]  # <- EXIT!

        unwrapped = (interface_dict[k] for k in columns)
        return tuple(x for x in unwrapped if x is not None)

    def _unwrap_filter(self, filter_dict, interface_dict=None):
        """Unwrap adapter *filter_dict* to reveal hidden adaptee column
        names.  An unwrapped filter cannot be created if the filter
        specifies that a missing column equals a non-missing value--if
        this condition occurs, a _FilterValueError is raised.
        """
        if not interface_dict:
            interface_dict = dict((new, old) for old, new in self._interface)

        translated = {}
        for k, v in filter_dict.items():
            tran_k = interface_dict[k]
            if tran_k is not None:
                translated[tran_k] = v
            else:
                if v != self._missing:
                    raise _FilterValueError('Missing column can only be '
                                            'filtered to missing value.')
        return translated

    def _rewrap_columns(self, unwrapped_columns, rev_dict=None):
        """Take unwrapped adaptee column names and wrap them in adapter
        column names (specified by _interface).
        """
        if not unwrapped_columns:
            return None  # <- EXIT!

        if rev_dict:
            interface_dict = dict((old, new) for new, old in rev_dict.items())
        else:
            interface_dict = dict(self._interface)

        if isinstance(unwrapped_columns, str):
            return interface_dict[unwrapped_columns]
        return tuple(interface_dict[k] for k in unwrapped_columns)

    def _rebuild_compareset(self, result, rewrapped_columns, columns):
        """Take CompareSet from unwrapped source and rebuild it to match
        the CompareSet that would be expected from the wrapped source.
        """
        normalize = lambda x: x if (isinstance(x, str) or not x) else tuple(x)
        rewrapped_columns = normalize(rewrapped_columns)
        columns = normalize(columns)

        if rewrapped_columns == columns:
            return result  # <- EXIT!

        missing = self._missing
        def rebuild(x):
            lookup_dict = dict(zip(rewrapped_columns, x))
            return tuple(lookup_dict.get(c, missing) for c in columns)
        return CompareSet(rebuild(x) for x in result)

    def _rebuild_comparedict(self,
                             result,
                             rewrapped_columns,
                             columns,
                             rewrapped_keys,
                             keys,
                             missing_col):
        """Take CompareDict from unwrapped source and rebuild it to
        match the CompareDict that would be expected from the wrapped
        source.
        """
        normalize = lambda x: x if (isinstance(x, str) or not x) else tuple(x)
        rewrapped_columns = normalize(rewrapped_columns)
        rewrapped_keys = normalize(rewrapped_keys)
        columns = normalize(columns)
        keys = normalize(keys)

        if rewrapped_keys == keys and rewrapped_columns == columns:
            if isinstance(result, CompareDict):
                key_names = (keys,) if isinstance(keys, str) else keys
                result.key_names = key_names
            return result  # <- EXIT!

        try:
            item_gen = iter(result.items())
        except AttributeError:
            # Scalar result (no keys)--treat as single missing-keyed item.
            item_gen = [(self._missing, result)]

        if rewrapped_keys != keys:
            def rebuild_keys(k, missing):
                if isinstance(keys, str):
                    return k
                key_dict = dict(zip(rewrapped_keys, k))
                return tuple(key_dict.get(c, missing) for c in keys)
            missing_key = self._missing
            item_gen = ((rebuild_keys(k, missing_key), v) for k, v in item_gen)

        if rewrapped_columns != columns:
            def rebuild_values(v, missing):
                if isinstance(columns, str):
                    return v
                if not nonstringiter(v):
                    v = (v,)
                value_dict = dict(zip(rewrapped_columns, v))
                # Use a distinct loop variable (c) to avoid shadowing v.
                return tuple(value_dict.get(c, missing) for c in columns)
            item_gen = ((k, rebuild_values(v, missing_col)) for k, v in item_gen)

        return CompareDict(item_gen, key_names=keys)
499
500
501########################################################################
502# From sources/multi.py
503########################################################################
504from .._compatibility.builtins import *
505from .._compatibility.collections import defaultdict
506from .._compatibility import itertools
507from .._compatibility import functools
508from .api07_comp import CompareDict
509from .api07_comp import CompareSet
510
511
class MultiSource(BaseSource):
    """
    MultiSource(*sources, missing='')

    A wrapper class that allows multiple data sources to be treated
    as a single, composite data source::

        subject = datatest.MultiSource(
            datatest.CsvSource('file1.csv'),
            datatest.CsvSource('file2.csv'),
            datatest.CsvSource('file3.csv')
        )

    The original sources are stored in the :attr:`__wrapped__`
    attribute.
    """
    def __init__(self, *sources, **kwd):
        """
        __init__(self, *sources, missing='')

        Initialize self.
        """
        if not sources:
            raise TypeError('expected 1 or more sources, got 0')

        missing = kwd.pop('missing', '')  # Accept as keyword-only argument.

        if kwd:                     # Enforce keyword-only argument
            key, _ = kwd.popitem()  # behavior that works in Python 2.x.
            msg = "__init__() got an unexpected keyword argument " + repr(key)
            raise TypeError(msg)

        if not all(isinstance(s, BaseSource) for s in sources):
            raise TypeError('sources must be derived from BaseSource')

        # Union of all column names in first-seen order.
        all_columns = []
        for s in sources:
            for c in s.columns():
                if c not in all_columns:
                    all_columns.append(c)

        # Wrap any source whose columns are a strict subset of the
        # union so every normalized source exposes *all_columns*.
        normalized_sources = []
        for s in sources:
            if set(s.columns()) < set(all_columns):
                columns = s.columns()
                make_old = lambda x: x if x in columns else None
                interface = [(make_old(x), x) for x in all_columns]
                s = AdapterSource(s, interface, missing)
            normalized_sources.append(s)

        self._columns = all_columns
        self._sources = normalized_sources
        self.__wrapped__ = sources  # <- Original sources.

    def __repr__(self):
        """Return a string representation of the data source."""
        cls_name = self.__class__.__name__
        src_names = [repr(src) for src in self.__wrapped__]  # Get reprs.
        src_names = ['    ' + src for src in src_names]      # Prefix with 4 spaces.
        src_names = ',\n'.join(src_names)                    # Join w/ comma & new-line.
        return '{0}(\n{1}\n)'.format(cls_name, src_names)

    def columns(self):
        """Return list of column names."""
        return self._columns

    def __iter__(self):
        """Return iterable of dictionary rows (like csv.DictReader)."""
        for source in self._sources:
            for row in source.__iter__():
                yield row

    def filter_rows(self, **kwds):
        """Return iterable of dictionary rows filtered by keywords,
        drawn from each wrapped source in order."""
        for source in self._sources:
            for row in source.filter_rows(**kwds):
                yield row

    def distinct(self, columns, **kwds_filter):
        """Return iterable of tuples containing distinct *column*
        values.
        """
        fn = lambda source: source.distinct(columns, **kwds_filter)
        results = (fn(source) for source in self._sources)
        results = itertools.chain(*results)
        return CompareSet(results)

    def sum(self, column, keys=None, **kwds_filter):
        """Return sum of values in *column* grouped by *keys*."""
        return self._aggregate('sum', column, keys, **kwds_filter)

    def count(self, column, keys=None, **kwds_filter):
        """Return count of non-empty values in *column* grouped by
        *keys*."""
        return self._aggregate('count', column, keys, **kwds_filter)

    def _aggregate(self, method, column, keys=None, **kwds_filter):
        """Call aggregation method ('sum' or 'count'), return result."""
        fn = lambda src: getattr(src, method)(column, keys, **kwds_filter)
        results = (fn(source) for source in self._sources)  # Perform aggregation.

        if not keys:
            return sum(results)  # <- EXIT!

        # Combine per-source group totals into one overall total.
        total = defaultdict(int)
        for result in results:
            for key, val in result.items():
                total[key] += val
        return CompareDict(total, keys)

    def mapreduce(self, mapper, reducer, columns, keys=None, **kwds_filter):
        """Map-reduce over *columns* grouped by *keys*, combining the
        per-source results with *reducer*."""
        fn = lambda source: source.mapreduce(mapper, reducer, columns, keys, **kwds_filter)
        results = (fn(source) for source in self._sources)

        if not keys:
            return functools.reduce(reducer, results)  # <- EXIT!

        final_result = {}
        results = (result.items() for result in results)
        for key, y in itertools.chain(*results):
            if key in final_result:
                x = final_result[key]
                final_result[key] = reducer(x, y)
            else:
                final_result[key] = y
        return CompareDict(final_result, keys)
635
636
637########################################################################
638# From sources/sqlite.py
639########################################################################
640import sqlite3
641from .._compatibility.builtins import *
642from .._compatibility import decimal
643from .._utils import nonstringiter
644from .api07_comp import CompareDict
645from .api07_comp import CompareSet
646
647
# Store Decimal parameter values as floats so sqlite3 can bind them.
sqlite3.register_adapter(decimal.Decimal, float)
649
650class SqliteBase(BaseSource):
651    """Base class four SqliteSource and CsvSource (not intended to be
652    instantiated directly).
653    """
654    def __new__(cls, *args, **kwds):
655        if cls is SqliteBase:
656            msg = 'cannot instantiate SqliteBase directly - make a subclass'
657            raise NotImplementedError(msg)
658        return super(SqliteBase, cls).__new__(cls)
659
660    def __init__(self, connection, table):
661        """Initialize self."""
662        self._connection = connection
663        self._table = table
664
665    def __repr__(self):
666        """Return a string representation of the data source."""
667        cls_name = self.__class__.__name__
668        conn_name = str(self._connection)
669        tbl_name = self._table
670        return '{0}({1}, table={2!r})'.format(cls_name, conn_name, tbl_name)
671
672    def columns(self):
673        """Return list of column names."""
674        cursor = self._connection.cursor()
675        cursor.execute('PRAGMA table_info(' + self._table + ')')
676        return [x[1] for x in cursor.fetchall()]
677
678    def __iter__(self):
679        """Return iterable of dictionary rows (like csv.DictReader)."""
680        cursor = self._connection.cursor()
681        cursor.execute('SELECT * FROM ' + self._table)
682
683        column_names = self.columns()
684        dict_row = lambda x: dict(zip(column_names, x))
685        return (dict_row(row) for row in cursor.fetchall())
686
687    def filter_rows(self, **kwds):
688        if kwds:
689            cursor = self._connection.cursor()
690            cursor = self._execute_query('*', **kwds)  # <- applies filter
691            column_names = self.columns()
692            dict_row = lambda row: dict(zip(column_names, row))
693            return (dict_row(row) for row in cursor)
694        return self.__iter__()
695
696    def distinct(self, columns, **kwds_filter):
697        """Return iterable of tuples containing distinct *columns*
698        values.
699        """
700        if not nonstringiter(columns):
701            columns = (columns,)
702        self._assert_columns_exist(columns)
703        select_clause = [self._normalize_column(x) for x in columns]
704        select_clause = ', '.join(select_clause)
705        select_clause = 'DISTINCT ' + select_clause
706
707        cursor = self._execute_query(select_clause, **kwds_filter)
708        return CompareSet(cursor)
709
710    def sum(self, column, keys=None, **kwds_filter):
711        """Returns :class:`CompareDict` containing sums of *column*
712        values grouped by *keys*.
713        """
714        self._assert_columns_exist(column)
715        column = self._normalize_column(column)
716        sql_functions = 'SUM({0})'.format(column)
717        return self._sql_aggregate(sql_functions, keys, **kwds_filter)
718
719    def count(self, column, keys=None, **kwds_filter):
720        """Returns :class:`CompareDict` containing count of non-empty
721        *column* values grouped by *keys*.
722        """
723        self._assert_columns_exist(column)
724        sql_function = "SUM(CASE COALESCE({0}, '') WHEN '' THEN 0 ELSE 1 END)"
725        sql_function = sql_function.format(self._normalize_column(column))
726        return self._sql_aggregate(sql_function, keys, **kwds_filter)
727
728    def _sql_aggregate(self, sql_function, keys=None, **kwds_filter):
729        """Aggregates values using SQL function select--e.g.,
730        'COUNT(*)', 'SUM(col1)', etc.
731        """
732        # TODO: _sql_aggregate has grown messy after a handful of
733        # iterations look to refactor it in the future to improve
734        # maintainability.
735        if not nonstringiter(sql_function):
736            sql_function = (sql_function,)
737
738        if keys == None:
739            sql_function = ', '.join(sql_function)
740            cursor = self._execute_query(sql_function, **kwds_filter)
741            result = cursor.fetchone()
742            if len(result) == 1:
743                return result[0]
744            return result  # <- EXIT!
745
746        if not nonstringiter(keys):
747            keys = (keys,)
748        group_clause = [self._normalize_column(x) for x in keys]
749        group_clause = ', '.join(group_clause)
750
751        select_clause = '{0}, {1}'.format(group_clause, ', '.join(sql_function))
752        trailing_clause = 'GROUP BY ' + group_clause
753
754        cursor = self._execute_query(select_clause, trailing_clause, **kwds_filter)
755        pos = len(sql_function)
756        iterable = ((row[:-pos], getvals(row)) for row in cursor)
757        if pos > 1:
758            # Gets values by slicing (i.e., row[-pos:]).
759            iterable = ((row[:-pos], row[-pos:]) for row in cursor)
760        else:
761            # Gets value by index (i.e., row[-pos]).
762            iterable = ((row[:-pos], row[-pos]) for row in cursor)
763        return CompareDict(iterable, keys)
764
765    def mapreduce(self, mapper, reducer, columns, keys=None, **kwds_filter):
766        obj = super(SqliteBase, self)  # 2.x compatible calling convention.
767        return obj.mapreduce(mapper, reducer, columns, keys, **kwds_filter)
768        # SqliteBase doesn't implement its own mapreduce() optimization.
769        # A generalized, SQL optimization could do little more than the
770        # already-optmized filter_rows() method.  Since the super-class'
771        # mapreduce() already uses filter_rows() internally, a separate
772        # optimization is unnecessary.
773
774    def _execute_query(self, select_clause, trailing_clause=None, **kwds_filter):
775        """Execute query and return cursor object."""
776        try:
777            stmnt, params = self._build_query(self._table, select_clause, **kwds_filter)
778            if trailing_clause:
779                stmnt += '\n' + trailing_clause
780            cursor = self._connection.cursor()
781            #print(stmnt, params)
782            cursor.execute(stmnt, params)
783        except Exception as e:
784            exc_cls = e.__class__
785            msg = '%s\n  query: %s\n  params: %r' % (e, stmnt, params)
786            raise exc_cls(msg)
787        return cursor
788
789    @classmethod
790    def _build_query(cls, table, select_clause, **kwds_filter):
791        """Return 'SELECT' query."""
792        query = 'SELECT ' + select_clause + ' FROM ' + table
793        where_clause, params = cls._build_where_clause(**kwds_filter)
794        if where_clause:
795            query = query + ' WHERE ' + where_clause
796        return query, params
797
798    @staticmethod
799    def _build_where_clause(**kwds_filter):
800        """Return 'WHERE' clause that implements *kwds_filter*
801        constraints.
802        """
803        clause = []
804        params = []
805        items = kwds_filter.items()
806        items = sorted(items, key=lambda x: x[0])  # Ordered by key.
807        for key, val in items:
808            if nonstringiter(val):
809                clause.append(key + ' IN (%s)' % (', '.join('?' * len(val))))
810                for x in val:
811                    params.append(x)
812            else:
813                clause.append(key + '=?')
814                params.append(val)
815
816        clause = ' AND '.join(clause) if clause else ''
817        return clause, params
818
819    def create_index(self, *columns):
820        """Create an index for specified columns---can speed up testing
821        in some cases.
822
823        See :meth:`SqliteSource.create_index` for more details.
824        """
825        self._assert_columns_exist(columns)
826
827        # Build index name.
828        whitelist = lambda col: ''.join(x for x in col if x.isalnum())
829        idx_name = '_'.join(whitelist(col) for col in columns)
830        idx_name = 'idx_{0}_{1}'.format(self._table, idx_name)
831
832        # Build column names.
833        col_names = [self._normalize_column(x) for x in columns]
834        col_names = ', '.join(col_names)
835
836        # Prepare statement.
837        statement = 'CREATE INDEX IF NOT EXISTS {0} ON {1} ({2})'
838        statement = statement.format(idx_name, self._table, col_names)
839
840        # Create index.
841        cursor = self._connection.cursor()
842        cursor.execute(statement)
843
844    @staticmethod
845    def _normalize_column(column):
846        """Normalize value for use as SQLite column name."""
847        if not isinstance(column, str):
848            msg = "expected column of type 'str', got {0!r} instead"
849            raise TypeError(msg.format(column.__class__.__name__))
850        column = column.strip()
851        column = column.replace('"', '""')  # Escape quotes.
852        if column == '':
853            column = '_empty_'
854        return '"' + column + '"'
855
856
class SqliteSource(SqliteBase):
    """Loads *table* data from given SQLite *connection*:
    ::

        conn = sqlite3.connect('mydatabase.sqlite3')
        subject = datatest.SqliteSource(conn, 'mytable')
    """
    @classmethod
    def from_records(cls, data, columns=None):
        """Alternate constructor to load an existing collection of
        records into a temporary SQLite database.  Loads *data* (an
        iterable of lists, tuples, or dicts) into a temporary table
        using the named *columns*::

            records = [
                ('a', 'x'),
                ('b', 'y'),
                ('c', 'z'),
                ...
            ]
            subject = datatest.SqliteSource.from_records(records, ['col1', 'col2'])

        The *columns* argument can be omitted if *data* is a collection
        of dictionary or namedtuple records::

            dict_rows = [
                {'col1': 'a', 'col2': 'x'},
                {'col1': 'b', 'col2': 'y'},
                {'col1': 'c', 'col2': 'z'},
                ...
            ]
            subject = datatest.SqliteSource.from_records(dict_rows)
        """
        # Load records into a temp table on the module-wide connection.
        connection, table = _load_temp_sqlite_table(columns, data)
        return cls(connection, table)

    def create_index(self, *columns):
        """Create an index for specified columns---can speed up testing
        in some cases.

        Indexes should be added one-by-one to tune a test suite's
        over-all performance.  Creating several indexes before testing
        even begins could lead to worse performance so use them with
        discretion.

        An example:  If you're using "town" to group aggregation tests
        (like ``self.assertSubjectSum('population', ['town'])``), then
        you might be able to improve performance by adding an index for
        the "town" column::

            subject.create_index('town')

        Using two or more columns creates a multi-column index::

            subject.create_index('town', 'zipcode')

        Calling the function multiple times will create multiple
        indexes::

            subject.create_index('town')
            subject.create_index('zipcode')
        """
        # Calling super() with older convention to support Python 2.7 & 2.6.
        super(SqliteSource, self).create_index(*columns)
921
922
923########################################################################
924# From sources/csv.py
925########################################################################
926import inspect
927import os
928import sys
929import warnings
930from .._compatibility.builtins import *
931
932
class CsvSource(SqliteBase):
    """Loads CSV data from *file* (path or file-like object):
    ::

        subject = datatest.CsvSource('mydata.csv')
    """
    def __init__(self, file, encoding=None, in_memory=False, **fmtparams):
        """Initialize self."""
        # The arg *in_memory* is now unused but should be kept in signature
        # so that old code doesn't error-out.

        # NOTE: The "global DEFAULT_CONNECTION" statement previously used
        # here was unnecessary---the name is only read, never rebound.

        self._file_repr = repr(file)

        # If *file* is relative path, uses directory of calling file as base.
        # sys._getframe(1) must be called directly from __init__ so the
        # frame one level up belongs to the caller.
        if isinstance(file, str) and not os.path.isabs(file):
            calling_frame = sys._getframe(1)
            calling_file = inspect.getfile(calling_frame)
            base_path = os.path.dirname(calling_file)
            file = os.path.join(base_path, file)
            file = os.path.normpath(file)

        # Create temporary SQLite table object.
        connection = DEFAULT_CONNECTION
        cursor = connection.cursor()
        with savepoint(cursor):
            table = new_table_name(cursor)
            load_csv(cursor, table, file, encoding=encoding, **fmtparams)

        # Calling super() with older convention to support Python 2.7 & 2.6.
        super(CsvSource, self).__init__(connection, table)

    def __repr__(self):
        """Return a string representation of the data source."""
        cls_name = self.__class__.__name__
        src_file = self._file_repr
        return '{0}({1})'.format(cls_name, src_file)
971
972
973########################################################################
974# From sources/excel.py
975########################################################################
976
class ExcelSource(SqliteBase):
    """Loads first worksheet from XLSX or XLS file *path*::

        subject = datatest.ExcelSource('mydata.xlsx')

    Specific worksheets can be accessed by name::

        subject = datatest.ExcelSource('mydata.xlsx', 'Sheet 2')

    .. note::
        This data source is optional---it requires the third-party
        library `xlrd <https://pypi.org/project/xlrd/>`_.
    """
    def __init__(self, path, worksheet=None, in_memory=False):
        """Initialize self."""
        try:
            import xlrd
        except ImportError:
            raise ImportError(
                "No module named 'xlrd'\n"
                "\n"
                "This is an optional data source that requires the "
                "third-party library 'xlrd'."
            )

        self._file_repr = repr(path)

        # Open the workbook and select a worksheet (first one by default).
        book = xlrd.open_workbook(path, on_demand=True)
        if worksheet:
            sheet = book.sheet_by_name(worksheet)
        else:
            sheet = book.sheet_by_index(0)

        # Lazily extract cell values row-by-row; the first row supplies
        # the column names for the temporary SQLite table.
        rows = ([cell.value for cell in sheet.row(i)] for i in range(sheet.nrows))
        columns = next(rows)  # <- Get header row.
        connection, table = _load_temp_sqlite_table(columns, rows)
        book.release_resources()

        # Calling super() with older convention to support Python 2.7 & 2.6.
        super(ExcelSource, self).__init__(connection, table)
1020
1021
1022########################################################################
1023# From sources/pandas.py
1024########################################################################
1025import re
1026
1027
1028def _version_info(module):
1029    """Helper function returns a tuple containing the version number
1030    components for a given module.
1031    """
1032    try:
1033        version = module.__version__
1034    except AttributeError:
1035        version = str(module)
1036
1037    def cast_as_int(value):
1038        try:
1039            return int(value)
1040        except ValueError:
1041            return value
1042
1043    return tuple(cast_as_int(x) for x in re.split('[.+]', version))
1044
1045
class PandasSource(BaseSource):
    """Loads pandas DataFrame as a data source:

    .. code-block:: python

        subject = datatest.PandasSource(df)

    .. note::
        This data source is optional---it requires the third-party
        library `pandas <https://pypi.org/project/pandas/>`_.
    """
    def __init__(self, df):
        """Initialize self."""
        self._df = df
        # A DataFrame with an unnamed default index reports its index
        # names as [None].
        self._default_index = (df.index.names == [None])
        self._pandas = __import__('pandas')

    def __repr__(self):
        """Return a string representation of the data source."""
        hex_id = hex(id(self._df))
        return "{0}(<pandas.DataFrame object at {1}>)".format(
            self.__class__.__name__, hex_id)

    def __iter__(self):
        """Return iterable of dictionary rows (like csv.DictReader)."""
        columns = self.columns()
        if self._default_index:
            for record in self._df.itertuples(index=False):
                yield dict(zip(columns, record))
        else:
            def as_tuple(obj):
                # Wrap scalar index values so they concatenate cleanly.
                return obj if isinstance(obj, tuple) else tuple([obj])
            for record in self._df.itertuples(index=True):
                # Flatten the (possibly multi-level) index into the row.
                flat = as_tuple(record[0]) + as_tuple(record[1:])
                yield dict(zip(columns, flat))

    def columns(self):
        """Return list of column names."""
        if self._default_index:
            return list(self._df.columns)
        return list(self._df.index.names) + list(self._df.columns)

    def count(self, column, keys=None, **kwds_filter):
        """Returns CompareDict containing count of non-empty *column*
        values grouped by *keys*.
        """
        isnull = self._pandas.isnull
        mapper = lambda value: 1 if (value and not isnull(value)) else 0
        reducer = lambda x, y: x + y
        return self.mapreduce(mapper, reducer, column, keys, **kwds_filter)
1095