1# -*- coding: utf-8 -*- 2from __future__ import absolute_import 3from .load_csv import load_csv 4from .temptable import ( 5 load_data, 6 new_table_name, 7 savepoint, 8) 9from .squint.query import DEFAULT_CONNECTION 10 11 12def _load_temp_sqlite_table(columns, records): 13 global DEFAULT_CONNECTION 14 cursor = DEFAULT_CONNECTION.cursor() 15 with savepoint(cursor): 16 table = new_table_name(cursor) 17 load_data(cursor, table, columns, records) 18 return DEFAULT_CONNECTION, table 19 20 21######################################################################## 22# From sources/base.py 23######################################################################## 24from .._compatibility.builtins import * 25from .._compatibility.collections.abc import Sequence 26from .._compatibility import decimal 27from .._compatibility import functools 28from .._utils import nonstringiter 29from .api07_comp import CompareDict 30from .api07_comp import CompareSet 31 32 33class BaseSource(object): 34 """Common base class for all data sources. Custom sources can be 35 created by subclassing BaseSource and implementing 36 :meth:`__init__()`, :meth:`__repr__()`, :meth:`columns()` and 37 :meth:`__iter__()`. 38 39 All data sources implement a common set of methods. 40 """ 41 def __new__(cls, *args, **kwds): 42 if cls is BaseSource: 43 msg = ('Cannot instantiate BaseSource directly. Use a ' 44 'data source of the appropriate type or make a ' 45 'subclass.') 46 raise NotImplementedError(msg) 47 return super(BaseSource, cls).__new__(cls) 48 49 def __init__(self): 50 """Initialize self.""" 51 return NotImplemented 52 53 def __repr__(self): 54 """Returns string representation describing the data source.""" 55 return NotImplemented 56 57 def columns(self): 58 """Returns list of column names.""" 59 return NotImplemented 60 61 def __iter__(self): 62 """Returns iterable of dictionary rows (like 63 :class:`csv.DictReader`).""" 64 return NotImplemented 65 66 def filter_rows(self, **kwds): 67 """Returns iterable of dictionary rows (like 68 :class:`csv.DictReader`) filtered by keywords. E.g., where 69 column1=value1, column2=value2, etc. (unoptimized, uses 70 :meth:`__iter__`). 71 """ 72 if kwds: 73 normalize = lambda v: (v,) if isinstance(v, str) else v 74 kwds = dict((k, normalize(v)) for k, v in kwds.items()) 75 matches_kwds = lambda row: all(row[k] in v for k, v in kwds.items()) 76 return filter(matches_kwds, self.__iter__()) 77 return self.__iter__() 78 79 def distinct(self, columns, **kwds_filter): 80 """Returns :class:`CompareSet` of distinct values or distinct 81 tuples of values if given multiple *columns* (unoptimized, uses 82 :meth:`__iter__`). 83 """ 84 if not nonstringiter(columns): 85 columns = (columns,) 86 self._assert_columns_exist(columns) 87 iterable = self.filter_rows(**kwds_filter) # Filtered rows only. 88 iterable = (tuple(row[c] for c in columns) for row in iterable) 89 return CompareSet(iterable) 90 91 def sum(self, column, keys=None, **kwds_filter): 92 """Returns :class:`CompareDict` containing sums of *column* 93 values grouped by *keys*. 94 """ 95 mapper = lambda x: decimal.Decimal(x) if x else decimal.Decimal(0) 96 reducer = lambda x, y: x + y 97 return self.mapreduce(mapper, reducer, column, keys, **kwds_filter) 98 99 def count(self, column, keys=None, **kwds_filter): 100 """Returns :class:`CompareDict` containing count of non-empty 101 *column* values grouped by *keys*. 102 """ 103 mapper = lambda value: 1 if value else 0 # 1 for truthy, 0 for falsy 104 reducer = lambda x, y: x + y 105 return self.mapreduce(mapper, reducer, column, keys, **kwds_filter) 106 107 def mapreduce(self, mapper, reducer, columns, keys=None, **kwds_filter): 108 """Apply a *mapper* to specified *columns* (which are grouped by 109 *keys* and filtered by keywords) then apply a *reducer* of two 110 arguments cumulatively to the mapped values, from left to right, 111 so as to reduce the values to a single result (per group of 112 *keys*). If *keys* is omitted, a single result is returned, 113 otherwise returns a :class:`CompareDict` object. 114 115 *mapper* (function or other callable): 116 Should accept a column value and return a computed result. 117 Mapper always receives a single argument---if *columns* is a 118 sequence, *mapper* will receive a tuple of values from the 119 specified columns. 120 *reducer* (function or other callable): 121 Should accept two arguments (values produced by *mapper*) 122 and apply them, from left to right, to return a single 123 result. 124 *columns* (string or sequence): 125 Name of column or columns whose values are passed to 126 *mapper*. 127 *keys* (None, string, or sequence): 128 Name of key or keys used to group column values. 129 *kwds_filter*: 130 Keywords used to filter rows. 131 """ 132 if isinstance(columns, str): 133 get_value = lambda row: row[columns] 134 elif isinstance(columns, Sequence): 135 get_value = lambda row: tuple(row[column] for column in columns) 136 else: 137 raise TypeError('colums must be str or sequence') 138 139 filtered_rows = self.filter_rows(**kwds_filter) 140 141 if not keys: 142 filtered_values = (get_value(row) for row in filtered_rows) 143 mapped_values = (mapper(value) for value in filtered_values) 144 return functools.reduce(reducer, mapped_values) # <- EXIT! 145 146 if not nonstringiter(keys): 147 keys = (keys,) 148 self._assert_columns_exist(keys) 149 150 result = {} # Do not remove this 151 for row in filtered_rows: # accumulator and loop 152 y = get_value(row) # without a good reason! 153 y = mapper(y) # While a more functional 154 key = tuple(row[k] for k in keys) # style (using sorted, 155 if key in result: # groupby, and reduce) 156 x = result[key] # is nicer to read, this 157 result[key] = reducer(x, y) # base class should 158 else: # prioritize memory 159 result[key] = y # efficiency over other 160 return CompareDict(result, keys) # considerations. 161 162 def _assert_columns_exist(self, columns): 163 """Asserts that given columns are present in data source, 164 raises LookupError if columns are missing. 165 """ 166 if not nonstringiter(columns): 167 columns = (columns,) 168 self_cols = self.columns() 169 is_missing = lambda col: col not in self_cols 170 missing = [c for c in columns if is_missing(c)] 171 if missing: 172 missing = ', '.join(repr(x) for x in missing) 173 msg = '{0} not in {1}'.format(missing, self.__repr__()) 174 raise LookupError(msg) 175 176 177######################################################################## 178# For Testing 179######################################################################## 180class MinimalSource(BaseSource): 181 """Minimal data source implementation for testing.""" 182 def __init__(self, data, fieldnames=None): 183 if not fieldnames: 184 data_iter = iter(data) 185 fieldnames = next(data_iter) # <- First row. 186 data = list(data_iter) # <- Remaining rows. 187 self._data = data 188 self._fieldnames = fieldnames 189 190 def __repr__(self): 191 return self.__class__.__name__ + '(<data>, <fieldnames>)' 192 193 def columns(self): 194 return self._fieldnames 195 196 def __iter__(self): 197 for row in self._data: 198 yield dict(zip(self._fieldnames, row)) 199 200 201######################################################################## 202# From sources/adapter.py 203######################################################################## 204from .._compatibility.builtins import * 205from .._compatibility.collections.abc import Sequence 206from .._utils import nonstringiter 207from .api07_comp import CompareDict 208from .api07_comp import CompareSet 209 210 211class _FilterValueError(ValueError): 212 """Used by AdapterSource. This error is raised when attempting to 213 unwrap a filter that specifies an inappropriate (non-missing) value 214 for a missing column.""" 215 pass 216 217 218class AdapterSource(BaseSource): 219 """A wrapper class that adapts a data *source* to an *interface* of 220 column names. The *interface* should be a sequence of 2-tuples where 221 the first item is the existing column name and the second item is 222 the desired column name. If column order is not important, the 223 *interface* can, alternatively, be a dictionary. 224 225 For example, a CSV file that contains the columns 'AAA', 'BBB', 226 and 'DDD' can be adapted to behave as if it has the columns 227 'AAA', 'BBB', 'CCC' and 'DDD' with the following:: 228 229 source = CsvSource('mydata.csv') 230 interface = [ 231 ('AAA', 'AAA'), 232 ('BBB', 'BBB'), 233 (None, 'CCC'), 234 ('DDD', 'DDD'), 235 ] 236 subject = AdapterSource(source, interface) 237 238 An :class:`AdapterSource` can be thought of as a virtual source that 239 renames, reorders, adds, or removes columns of the original 240 *source*. 241 242 To add a column that does not exist in original, use None in place 243 of a column name (see column 'CCC', above). Columns mapped to None 244 will contain *missing* values (defaults to empty string). To remove 245 a column, simply omit it from the interface. 246 247 The original source can be accessed via the :attr:`__wrapped__` 248 property. 249 """ 250 def __init__(self, source, interface, missing=''): 251 if not isinstance(interface, Sequence): 252 if isinstance(interface, dict): 253 interface = interface.items() 254 interface = sorted(interface) 255 256 source_columns = source.columns() 257 interface_cols = [x[0] for x in interface] 258 for c in interface_cols: 259 if c != None and c not in source_columns: 260 raise KeyError(c) 261 262 self._interface = list(interface) 263 self._missing = missing 264 self.__wrapped__ = source 265 266 def __repr__(self): 267 self_class = self.__class__.__name__ 268 wrapped_repr = repr(self.__wrapped__) 269 interface = self._interface 270 missing = self._missing 271 if missing != '': 272 missing = ', missing=' + repr(missing) 273 return '{0}({1}, {2}{3})'.format(self_class, wrapped_repr, interface, missing) 274 275 def columns(self): 276 return [new for (old, new) in self._interface if new != None] 277 278 def __iter__(self): 279 interface = self._interface 280 missing = self._missing 281 for row in self.__wrapped__.__iter__(): 282 yield dict((new, row.get(old, missing)) for old, new in interface) 283 284 def filter_rows(self, **kwds): 285 try: 286 unwrap_kwds = self._unwrap_filter(kwds) 287 except _FilterValueError: 288 return # <- EXIT! Raises StopIteration to signify empty generator. 289 290 interface = self._interface 291 missing = self._missing 292 for row in self.__wrapped__.filter_rows(**unwrap_kwds): 293 yield dict((new, row.get(old, missing)) for old, new in interface) 294 295 def distinct(self, columns, **kwds_filter): 296 unwrap_src = self.__wrapped__ # Unwrap data source. 297 unwrap_cols = self._unwrap_columns(columns) 298 try: 299 unwrap_flt = self._unwrap_filter(kwds_filter) 300 except _FilterValueError: 301 return CompareSet([]) # <- EXIT! 302 303 if not unwrap_cols: 304 iterable = iter(unwrap_src) 305 try: 306 next(iterable) # Check for any data at all. 307 length = 1 if isinstance(columns, str) else len(columns) 308 result = [tuple([self._missing]) * length] # Make 1 row of *missing* vals. 309 except StopIteration: 310 result = [] # If no data, result is empty. 311 return CompareSet(result) # <- EXIT! 312 313 results = unwrap_src.distinct(unwrap_cols, **unwrap_flt) 314 rewrap_cols = self._rewrap_columns(unwrap_cols) 315 return self._rebuild_compareset(results, rewrap_cols, columns) 316 317 def sum(self, column, keys=None, **kwds_filter): 318 return self._aggregate('sum', column, keys, **kwds_filter) 319 320 def count(self, column, keys=None, **kwds_filter): 321 return self._aggregate('count', column, keys, **kwds_filter) 322 323 def _aggregate(self, method, column, keys=None, **kwds_filter): 324 """Call aggregation method ('sum' or 'count'), return result.""" 325 unwrap_src = self.__wrapped__ 326 unwrap_col = self._unwrap_columns(column) 327 unwrap_keys = self._unwrap_columns(keys) 328 try: 329 unwrap_flt = self._unwrap_filter(kwds_filter) 330 except _FilterValueError: 331 if keys: 332 result = CompareDict({}, keys) 333 else: 334 result = 0 335 return result # <- EXIT! 336 337 # If all *columns* are missing, build result of missing values. 338 if not unwrap_col: 339 distinct = self.distinct(keys, **kwds_filter) 340 result = ((key, 0) for key in distinct) 341 return CompareDict(result, keys) # <- EXIT! 342 343 # Get method ('sum' or 'count') and perform aggregation. 344 aggregate = getattr(unwrap_src, method) 345 result = aggregate(unwrap_col, unwrap_keys, **unwrap_flt) 346 347 rewrap_col = self._rewrap_columns(unwrap_col) 348 rewrap_keys = self._rewrap_columns(unwrap_keys) 349 return self._rebuild_comparedict(result, rewrap_col, column, 350 rewrap_keys, keys, missing_col=0) 351 352 def mapreduce(self, mapper, reducer, columns, keys=None, **kwds_filter): 353 unwrap_src = self.__wrapped__ 354 unwrap_cols = self._unwrap_columns(columns) 355 unwrap_keys = self._unwrap_columns(keys) 356 try: 357 unwrap_flt = self._unwrap_filter(kwds_filter) 358 except _FilterValueError: 359 if keys: 360 result = CompareDict({}, keys) 361 else: 362 result = self._missing 363 return result # <- EXIT! 364 365 # If all *columns* are missing, build result of missing values. 366 if not unwrap_cols: 367 distinct = self.distinct(keys, **kwds_filter) 368 if isinstance(columns, str): 369 val = self._missing 370 else: 371 val = (self._missing,) * len(columns) 372 result = ((key, val) for key in distinct) 373 return CompareDict(result, keys) # <- EXIT! 374 375 result = unwrap_src.mapreduce(mapper, reducer, 376 unwrap_cols, unwrap_keys, **unwrap_flt) 377 378 rewrap_cols = self._rewrap_columns(unwrap_cols) 379 rewrap_keys = self._rewrap_columns(unwrap_keys) 380 return self._rebuild_comparedict(result, rewrap_cols, columns, 381 rewrap_keys, keys, 382 missing_col=self._missing) 383 384 def _unwrap_columns(self, columns, interface_dict=None): 385 """Unwrap adapter *columns* to reveal hidden adaptee columns.""" 386 if not columns: 387 return None # <- EXIT! 388 389 if not interface_dict: 390 interface_dict = dict((new, old) for old, new in self._interface) 391 392 if isinstance(columns, str): 393 return interface_dict[columns] # <- EXIT! 394 395 unwrapped = (interface_dict[k] for k in columns) 396 return tuple(x for x in unwrapped if x != None) 397 398 def _unwrap_filter(self, filter_dict, interface_dict=None): 399 """Unwrap adapter *filter_dict* to reveal hidden adaptee column 400 names. An unwrapped filter cannot be created if the filter 401 specifies that a missing column equals a non-missing value--if 402 this condition occurs, a _FilterValueError is raised. 403 """ 404 if not interface_dict: 405 interface_dict = dict((new, old) for old, new in self._interface) 406 407 translated = {} 408 for k, v in filter_dict.items(): 409 tran_k = interface_dict[k] 410 if tran_k != None: 411 translated[tran_k] = v 412 else: 413 if v != self._missing: 414 raise _FilterValueError('Missing column can only be ' 415 'filtered to missing value.') 416 return translated 417 418 def _rewrap_columns(self, unwrapped_columns, rev_dict=None): 419 """Take unwrapped adaptee column names and wrap them in adapter 420 column names (specified by _interface). 421 """ 422 if not unwrapped_columns: 423 return None # <- EXIT! 424 425 if rev_dict: 426 interface_dict = dict((old, new) for new, old in rev_dict.items()) 427 else: 428 interface_dict = dict(self._interface) 429 430 if isinstance(unwrapped_columns, str): 431 return interface_dict[unwrapped_columns] 432 return tuple(interface_dict[k] for k in unwrapped_columns) 433 434 def _rebuild_compareset(self, result, rewrapped_columns, columns): 435 """Take CompareSet from unwrapped source and rebuild it to match 436 the CompareSet that would be expected from the wrapped source. 437 """ 438 normalize = lambda x: x if (isinstance(x, str) or not x) else tuple(x) 439 rewrapped_columns = normalize(rewrapped_columns) 440 columns = normalize(columns) 441 442 if rewrapped_columns == columns: 443 return result # <- EXIT! 444 445 missing = self._missing 446 def rebuild(x): 447 lookup_dict = dict(zip(rewrapped_columns, x)) 448 return tuple(lookup_dict.get(c, missing) for c in columns) 449 return CompareSet(rebuild(x) for x in result) 450 451 def _rebuild_comparedict(self, 452 result, 453 rewrapped_columns, 454 columns, 455 rewrapped_keys, 456 keys, 457 missing_col): 458 """Take CompareDict from unwrapped source and rebuild it to 459 match the CompareDict that would be expected from the wrapped 460 source. 461 """ 462 normalize = lambda x: x if (isinstance(x, str) or not x) else tuple(x) 463 rewrapped_columns = normalize(rewrapped_columns) 464 rewrapped_keys = normalize(rewrapped_keys) 465 columns = normalize(columns) 466 keys = normalize(keys) 467 468 if rewrapped_keys == keys and rewrapped_columns == columns: 469 if isinstance(result, CompareDict): 470 key_names = (keys,) if isinstance(keys, str) else keys 471 result.key_names = key_names 472 return result # <- EXIT! 473 474 try: 475 item_gen = iter(result.items()) 476 except AttributeError: 477 item_gen = [(self._missing, result)] 478 479 if rewrapped_keys != keys: 480 def rebuild_keys(k, missing): 481 if isinstance(keys, str): 482 return k 483 key_dict = dict(zip(rewrapped_keys, k)) 484 return tuple(key_dict.get(c, missing) for c in keys) 485 missing_key = self._missing 486 item_gen = ((rebuild_keys(k, missing_key), v) for k, v in item_gen) 487 488 if rewrapped_columns != columns: 489 def rebuild_values(v, missing): 490 if isinstance(columns, str): 491 return v 492 if not nonstringiter(v): 493 v = (v,) 494 value_dict = dict(zip(rewrapped_columns, v)) 495 return tuple(value_dict.get(v, missing) for v in columns) 496 item_gen = ((k, rebuild_values(v, missing_col)) for k, v in item_gen) 497 498 return CompareDict(item_gen, key_names=keys) 499 500 501######################################################################## 502# From sources/multi.py 503######################################################################## 504from .._compatibility.builtins import * 505from .._compatibility.collections import defaultdict 506from .._compatibility import itertools 507from .._compatibility import functools 508from .api07_comp import CompareDict 509from .api07_comp import CompareSet 510 511 512class MultiSource(BaseSource): 513 """ 514 MultiSource(*sources, missing='') 515 516 A wrapper class that allows multiple data sources to be treated 517 as a single, composite data source:: 518 519 subject = datatest.MultiSource( 520 datatest.CsvSource('file1.csv'), 521 datatest.CsvSource('file2.csv'), 522 datatest.CsvSource('file3.csv') 523 ) 524 525 The original sources are stored in the :attr:`__wrapped__` 526 attribute. 527 """ 528 def __init__(self, *sources, **kwd): 529 """ 530 __init__(self, *sources, missing='') 531 532 Initialize self. 533 """ 534 if not sources: 535 raise TypeError('expected 1 or more sources, got 0') 536 537 missing = kwd.pop('missing', '') # Accept as keyword-only argument. 538 539 if kwd: # Enforce keyword-only argument 540 key, _ = kwd.popitem() # behavior that works in Python 2.x. 541 msg = "__init__() got an unexpected keyword argument " + repr(key) 542 raise TypeError(msg) 543 544 if not all(isinstance(s, BaseSource) for s in sources): 545 raise TypeError('sources must be derived from BaseSource') 546 547 all_columns = [] 548 for s in sources: 549 for c in s.columns(): 550 if c not in all_columns: 551 all_columns.append(c) 552 553 normalized_sources = [] 554 for s in sources: 555 if set(s.columns()) < set(all_columns): 556 columns = s.columns() 557 make_old = lambda x: x if x in columns else None 558 interface = [(make_old(x), x) for x in all_columns] 559 s = AdapterSource(s, interface, missing) 560 normalized_sources.append(s) 561 562 self._columns = all_columns 563 self._sources = normalized_sources 564 self.__wrapped__ = sources # <- Original sources. 565 566 def __repr__(self): 567 """Return a string representation of the data source.""" 568 cls_name = self.__class__.__name__ 569 src_names = [repr(src) for src in self.__wrapped__] # Get reprs. 570 src_names = [' ' + src for src in src_names] # Prefix with 4 spaces. 571 src_names = ',\n'.join(src_names) # Join w/ comma & new-line. 572 return '{0}(\n{1}\n)'.format(cls_name, src_names) 573 574 def columns(self): 575 """Return list of column names.""" 576 return self._columns 577 578 def __iter__(self): 579 """Return iterable of dictionary rows (like csv.DictReader).""" 580 for source in self._sources: 581 for row in source.__iter__(): 582 yield row 583 584 def filter_rows(self, **kwds): 585 for source in self._sources: 586 for row in source.filter_rows(**kwds): 587 yield row 588 589 def distinct(self, columns, **kwds_filter): 590 """Return iterable of tuples containing distinct *column* 591 values. 592 """ 593 fn = lambda source: source.distinct(columns, **kwds_filter) 594 results = (fn(source) for source in self._sources) 595 results = itertools.chain(*results) 596 return CompareSet(results) 597 598 def sum(self, column, keys=None, **kwds_filter): 599 """Return sum of values in *column* grouped by *keys*.""" 600 return self._aggregate('sum', column, keys, **kwds_filter) 601 602 def count(self, column, keys=None, **kwds_filter): 603 return self._aggregate('count', column, keys, **kwds_filter) 604 605 def _aggregate(self, method, column, keys=None, **kwds_filter): 606 """Call aggregation method ('sum' or 'count'), return result.""" 607 fn = lambda src: getattr(src, method)(column, keys, **kwds_filter) 608 results = (fn(source) for source in self._sources) # Perform aggregation. 609 610 if not keys: 611 return sum(results) # <- EXIT! 612 613 total = defaultdict(lambda: 0) 614 for result in results: 615 for key, val in result.items(): 616 total[key] += val 617 return CompareDict(total, keys) 618 619 def mapreduce(self, mapper, reducer, columns, keys=None, **kwds_filter): 620 fn = lambda source: source.mapreduce(mapper, reducer, columns, keys, **kwds_filter) 621 results = (fn(source) for source in self._sources) 622 623 if not keys: 624 return functools.reduce(reducer, results) # <- EXIT! 625 626 final_result = {} 627 results = (result.items() for result in results) 628 for key, y in itertools.chain(*results): 629 if key in final_result: 630 x = final_result[key] 631 final_result[key] = reducer(x, y) 632 else: 633 final_result[key] = y 634 return CompareDict(final_result, keys) 635 636 637######################################################################## 638# From sources/sqlite.py 639######################################################################## 640import sqlite3 641from .._compatibility.builtins import * 642from .._compatibility import decimal 643from .._utils import nonstringiter 644from .api07_comp import CompareDict 645from .api07_comp import CompareSet 646 647 648sqlite3.register_adapter(decimal.Decimal, float) 649 650class SqliteBase(BaseSource): 651 """Base class four SqliteSource and CsvSource (not intended to be 652 instantiated directly). 653 """ 654 def __new__(cls, *args, **kwds): 655 if cls is SqliteBase: 656 msg = 'cannot instantiate SqliteBase directly - make a subclass' 657 raise NotImplementedError(msg) 658 return super(SqliteBase, cls).__new__(cls) 659 660 def __init__(self, connection, table): 661 """Initialize self.""" 662 self._connection = connection 663 self._table = table 664 665 def __repr__(self): 666 """Return a string representation of the data source.""" 667 cls_name = self.__class__.__name__ 668 conn_name = str(self._connection) 669 tbl_name = self._table 670 return '{0}({1}, table={2!r})'.format(cls_name, conn_name, tbl_name) 671 672 def columns(self): 673 """Return list of column names.""" 674 cursor = self._connection.cursor() 675 cursor.execute('PRAGMA table_info(' + self._table + ')') 676 return [x[1] for x in cursor.fetchall()] 677 678 def __iter__(self): 679 """Return iterable of dictionary rows (like csv.DictReader).""" 680 cursor = self._connection.cursor() 681 cursor.execute('SELECT * FROM ' + self._table) 682 683 column_names = self.columns() 684 dict_row = lambda x: dict(zip(column_names, x)) 685 return (dict_row(row) for row in cursor.fetchall()) 686 687 def filter_rows(self, **kwds): 688 if kwds: 689 cursor = self._connection.cursor() 690 cursor = self._execute_query('*', **kwds) # <- applies filter 691 column_names = self.columns() 692 dict_row = lambda row: dict(zip(column_names, row)) 693 return (dict_row(row) for row in cursor) 694 return self.__iter__() 695 696 def distinct(self, columns, **kwds_filter): 697 """Return iterable of tuples containing distinct *columns* 698 values. 699 """ 700 if not nonstringiter(columns): 701 columns = (columns,) 702 self._assert_columns_exist(columns) 703 select_clause = [self._normalize_column(x) for x in columns] 704 select_clause = ', '.join(select_clause) 705 select_clause = 'DISTINCT ' + select_clause 706 707 cursor = self._execute_query(select_clause, **kwds_filter) 708 return CompareSet(cursor) 709 710 def sum(self, column, keys=None, **kwds_filter): 711 """Returns :class:`CompareDict` containing sums of *column* 712 values grouped by *keys*. 713 """ 714 self._assert_columns_exist(column) 715 column = self._normalize_column(column) 716 sql_functions = 'SUM({0})'.format(column) 717 return self._sql_aggregate(sql_functions, keys, **kwds_filter) 718 719 def count(self, column, keys=None, **kwds_filter): 720 """Returns :class:`CompareDict` containing count of non-empty 721 *column* values grouped by *keys*. 722 """ 723 self._assert_columns_exist(column) 724 sql_function = "SUM(CASE COALESCE({0}, '') WHEN '' THEN 0 ELSE 1 END)" 725 sql_function = sql_function.format(self._normalize_column(column)) 726 return self._sql_aggregate(sql_function, keys, **kwds_filter) 727 728 def _sql_aggregate(self, sql_function, keys=None, **kwds_filter): 729 """Aggregates values using SQL function select--e.g., 730 'COUNT(*)', 'SUM(col1)', etc. 731 """ 732 # TODO: _sql_aggregate has grown messy after a handful of 733 # iterations look to refactor it in the future to improve 734 # maintainability. 735 if not nonstringiter(sql_function): 736 sql_function = (sql_function,) 737 738 if keys == None: 739 sql_function = ', '.join(sql_function) 740 cursor = self._execute_query(sql_function, **kwds_filter) 741 result = cursor.fetchone() 742 if len(result) == 1: 743 return result[0] 744 return result # <- EXIT! 745 746 if not nonstringiter(keys): 747 keys = (keys,) 748 group_clause = [self._normalize_column(x) for x in keys] 749 group_clause = ', '.join(group_clause) 750 751 select_clause = '{0}, {1}'.format(group_clause, ', '.join(sql_function)) 752 trailing_clause = 'GROUP BY ' + group_clause 753 754 cursor = self._execute_query(select_clause, trailing_clause, **kwds_filter) 755 pos = len(sql_function) 756 iterable = ((row[:-pos], getvals(row)) for row in cursor) 757 if pos > 1: 758 # Gets values by slicing (i.e., row[-pos:]). 759 iterable = ((row[:-pos], row[-pos:]) for row in cursor) 760 else: 761 # Gets value by index (i.e., row[-pos]). 762 iterable = ((row[:-pos], row[-pos]) for row in cursor) 763 return CompareDict(iterable, keys) 764 765 def mapreduce(self, mapper, reducer, columns, keys=None, **kwds_filter): 766 obj = super(SqliteBase, self) # 2.x compatible calling convention. 767 return obj.mapreduce(mapper, reducer, columns, keys, **kwds_filter) 768 # SqliteBase doesn't implement its own mapreduce() optimization. 769 # A generalized, SQL optimization could do little more than the 770 # already-optmized filter_rows() method. Since the super-class' 771 # mapreduce() already uses filter_rows() internally, a separate 772 # optimization is unnecessary. 773 774 def _execute_query(self, select_clause, trailing_clause=None, **kwds_filter): 775 """Execute query and return cursor object.""" 776 try: 777 stmnt, params = self._build_query(self._table, select_clause, **kwds_filter) 778 if trailing_clause: 779 stmnt += '\n' + trailing_clause 780 cursor = self._connection.cursor() 781 #print(stmnt, params) 782 cursor.execute(stmnt, params) 783 except Exception as e: 784 exc_cls = e.__class__ 785 msg = '%s\n query: %s\n params: %r' % (e, stmnt, params) 786 raise exc_cls(msg) 787 return cursor 788 789 @classmethod 790 def _build_query(cls, table, select_clause, **kwds_filter): 791 """Return 'SELECT' query.""" 792 query = 'SELECT ' + select_clause + ' FROM ' + table 793 where_clause, params = cls._build_where_clause(**kwds_filter) 794 if where_clause: 795 query = query + ' WHERE ' + where_clause 796 return query, params 797 798 @staticmethod 799 def _build_where_clause(**kwds_filter): 800 """Return 'WHERE' clause that implements *kwds_filter* 801 constraints. 802 """ 803 clause = [] 804 params = [] 805 items = kwds_filter.items() 806 items = sorted(items, key=lambda x: x[0]) # Ordered by key. 807 for key, val in items: 808 if nonstringiter(val): 809 clause.append(key + ' IN (%s)' % (', '.join('?' * len(val)))) 810 for x in val: 811 params.append(x) 812 else: 813 clause.append(key + '=?') 814 params.append(val) 815 816 clause = ' AND '.join(clause) if clause else '' 817 return clause, params 818 819 def create_index(self, *columns): 820 """Create an index for specified columns---can speed up testing 821 in some cases. 822 823 See :meth:`SqliteSource.create_index` for more details. 824 """ 825 self._assert_columns_exist(columns) 826 827 # Build index name. 828 whitelist = lambda col: ''.join(x for x in col if x.isalnum()) 829 idx_name = '_'.join(whitelist(col) for col in columns) 830 idx_name = 'idx_{0}_{1}'.format(self._table, idx_name) 831 832 # Build column names. 833 col_names = [self._normalize_column(x) for x in columns] 834 col_names = ', '.join(col_names) 835 836 # Prepare statement. 837 statement = 'CREATE INDEX IF NOT EXISTS {0} ON {1} ({2})' 838 statement = statement.format(idx_name, self._table, col_names) 839 840 # Create index. 841 cursor = self._connection.cursor() 842 cursor.execute(statement) 843 844 @staticmethod 845 def _normalize_column(column): 846 """Normalize value for use as SQLite column name.""" 847 if not isinstance(column, str): 848 msg = "expected column of type 'str', got {0!r} instead" 849 raise TypeError(msg.format(column.__class__.__name__)) 850 column = column.strip() 851 column = column.replace('"', '""') # Escape quotes. 852 if column == '': 853 column = '_empty_' 854 return '"' + column + '"' 855 856 857class SqliteSource(SqliteBase): 858 """Loads *table* data from given SQLite *connection*: 859 :: 860 861 conn = sqlite3.connect('mydatabase.sqlite3') 862 subject = datatest.SqliteSource(conn, 'mytable') 863 """ 864 @classmethod 865 def from_records(cls, data, columns=None): 866 """Alternate constructor to load an existing collection of 867 records into a tempoarary SQLite database. Loads *data* (an 868 iterable of lists, tuples, or dicts) into a temporary table 869 using the named *columns*:: 870 871 records = [ 872 ('a', 'x'), 873 ('b', 'y'), 874 ('c', 'z'), 875 ... 876 ] 877 subject = datatest.SqliteSource.from_records(records, ['col1', 'col2']) 878 879 The *columns* argument can be omitted if *data* is a collection 880 of dictionary or namedtuple records:: 881 882 dict_rows = [ 883 {'col1': 'a', 'col2': 'x'}, 884 {'col1': 'b', 'col2': 'y'}, 885 {'col1': 'c', 'col2': 'z'}, 886 ... 887 ] 888 subject = datatest.SqliteSource.from_records(dict_rows) 889 """ 890 connection, table = _load_temp_sqlite_table(columns, data) 891 return cls(connection, table) 892 893 def create_index(self, *columns): 894 """Create an index for specified columns---can speed up testing 895 in some cases. 896 897 Indexes should be added one-by-one to tune a test suite's 898 over-all performance. Creating several indexes before testing 899 even begins could lead to worse performance so use them with 900 discretion. 901 902 An example: If you're using "town" to group aggregation tests 903 (like ``self.assertSubjectSum('population', ['town'])``), then 904 you might be able to improve performance by adding an index for 905 the "town" column:: 906 907 subject.create_index('town') 908 909 Using two or more columns creates a multi-column index:: 910 911 subject.create_index('town', 'zipcode') 912 913 Calling the function multiple times will create multiple 914 indexes:: 915 916 subject.create_index('town') 917 subject.create_index('zipcode') 918 """ 919 # Calling super() with older convention to support Python 2.7 & 2.6. 920 super(SqliteSource, self).create_index(*columns) 921 922 923######################################################################## 924# From sources/csv.py 925######################################################################## 926import inspect 927import os 928import sys 929import warnings 930from .._compatibility.builtins import * 931 932 933class CsvSource(SqliteBase): 934 """Loads CSV data from *file* (path or file-like object): 935 :: 936 937 subject = datatest.CsvSource('mydata.csv') 938 """ 939 def __init__(self, file, encoding=None, in_memory=False, **fmtparams): 940 """Initialize self.""" 941 # The arg *in_memory* is now unused but should be kept in signature 942 # so that old code doesn't error-out. 943 944 global DEFAULT_CONNECTION 945 946 self._file_repr = repr(file) 947 948 # If *file* is relative path, uses directory of calling file as base. 949 if isinstance(file, str) and not os.path.isabs(file): 950 calling_frame = sys._getframe(1) 951 calling_file = inspect.getfile(calling_frame) 952 base_path = os.path.dirname(calling_file) 953 file = os.path.join(base_path, file) 954 file = os.path.normpath(file) 955 956 # Create temporary SQLite table object. 957 connection = DEFAULT_CONNECTION 958 cursor = connection.cursor() 959 with savepoint(cursor): 960 table = new_table_name(cursor) 961 load_csv(cursor, table, file, encoding=encoding, **fmtparams) 962 963 # Calling super() with older convention to support Python 2.7 & 2.6. 964 super(CsvSource, self).__init__(connection, table) 965 966 def __repr__(self): 967 """Return a string representation of the data source.""" 968 cls_name = self.__class__.__name__ 969 src_file = self._file_repr 970 return '{0}({1})'.format(cls_name, src_file) 971 972 973######################################################################## 974# From sources/excel.py 975######################################################################## 976 977class ExcelSource(SqliteBase): 978 """Loads first worksheet from XLSX or XLS file *path*:: 979 980 subject = datatest.ExcelSource('mydata.xlsx') 981 982 Specific worksheets can be accessed by name:: 983 984 subject = datatest.ExcelSource('mydata.xlsx', 'Sheet 2') 985 986 .. note:: 987 This data source is optional---it requires the third-party 988 library `xlrd <https://pypi.org/project/xlrd/>`_. 989 """ 990 def __init__(self, path, worksheet=None, in_memory=False): 991 """Initialize self.""" 992 try: 993 import xlrd 994 except ImportError: 995 raise ImportError( 996 "No module named 'xlrd'\n" 997 "\n" 998 "This is an optional data source that requires the " 999 "third-party library 'xlrd'." 1000 ) 1001 1002 self._file_repr = repr(path) 1003 1004 # Open Excel file and get worksheet. 1005 book = xlrd.open_workbook(path, on_demand=True) 1006 if worksheet: 1007 sheet = book.sheet_by_name(worksheet) 1008 else: 1009 sheet = book.sheet_by_index(0) 1010 1011 # Build SQLite table from records, release resources. 1012 iterrows = (sheet.row(i) for i in range(sheet.nrows)) 1013 iterrows = ([x.value for x in row] for row in iterrows) 1014 columns = next(iterrows) # <- Get header row. 1015 connection, table = _load_temp_sqlite_table(columns, iterrows) 1016 book.release_resources() 1017 1018 # Calling super() with older convention to support Python 2.7 & 2.6. 1019 super(ExcelSource, self).__init__(connection, table) 1020 1021 1022######################################################################## 1023# From sources/pandas.py 1024######################################################################## 1025import re 1026 1027 1028def _version_info(module): 1029 """Helper function returns a tuple containing the version number 1030 components for a given module. 1031 """ 1032 try: 1033 version = module.__version__ 1034 except AttributeError: 1035 version = str(module) 1036 1037 def cast_as_int(value): 1038 try: 1039 return int(value) 1040 except ValueError: 1041 return value 1042 1043 return tuple(cast_as_int(x) for x in re.split('[.+]', version)) 1044 1045 1046class PandasSource(BaseSource): 1047 """Loads pandas DataFrame as a data source: 1048 1049 .. code-block:: python 1050 1051 subject = datatest.PandasSource(df) 1052 1053 .. note:: 1054 This data source is optional---it requires the third-party 1055 library `pandas <https://pypi.org/project/pandas/>`_. 1056 """ 1057 def __init__(self, df): 1058 """Initialize self.""" 1059 self._df = df 1060 self._default_index = (df.index.names == [None]) 1061 self._pandas = __import__('pandas') 1062 1063 def __repr__(self): 1064 """Return a string representation of the data source.""" 1065 cls_name = self.__class__.__name__ 1066 hex_id = hex(id(self._df)) 1067 return "{0}(<pandas.DataFrame object at {1}>)".format(cls_name, hex_id) 1068 1069 def __iter__(self): 1070 """Return iterable of dictionary rows (like csv.DictReader).""" 1071 columns = self.columns() 1072 if self._default_index: 1073 for row in self._df.itertuples(index=False): 1074 yield dict(zip(columns, row)) 1075 else: 1076 mktup = lambda x: x if isinstance(x, tuple) else tuple([x]) 1077 flatten = lambda x: mktup(x[0]) + mktup(x[1:]) 1078 for row in self._df.itertuples(index=True): 1079 yield dict(zip(columns, flatten(row))) 1080 1081 def columns(self): 1082 """Return list of column names.""" 1083 if self._default_index: 1084 return list(self._df.columns) 1085 return list(self._df.index.names) + list(self._df.columns) 1086 1087 def count(self, column, keys=None, **kwds_filter): 1088 """Returns CompareDict containing count of non-empty *column* 1089 values grouped by *keys*. 1090 """ 1091 isnull = self._pandas.isnull 1092 mapper = lambda value: 1 if (value and not isnull(value)) else 0 1093 reducer = lambda x, y: x + y 1094 return self.mapreduce(mapper, reducer, column, keys, **kwds_filter) 1095