1# -*- coding: utf-8 -*- 2# This file is part of beets. 3# Copyright 2016, Adrian Sampson. 4# 5# Permission is hereby granted, free of charge, to any person obtaining 6# a copy of this software and associated documentation files (the 7# "Software"), to deal in the Software without restriction, including 8# without limitation the rights to use, copy, modify, merge, publish, 9# distribute, sublicense, and/or sell copies of the Software, and to 10# permit persons to whom the Software is furnished to do so, subject to 11# the following conditions: 12# 13# The above copyright notice and this permission notice shall be 14# included in all copies or substantial portions of the Software. 15 16"""The Query type hierarchy for DBCore. 17""" 18from __future__ import division, absolute_import, print_function 19 20import re 21from operator import mul 22from beets import util 23from datetime import datetime, timedelta 24import unicodedata 25from functools import reduce 26import six 27 28if not six.PY2: 29 buffer = memoryview # sqlite won't accept memoryview in python 2 30 31 32class ParsingError(ValueError): 33 """Abstract class for any unparseable user-requested album/query 34 specification. 35 """ 36 37 38class InvalidQueryError(ParsingError): 39 """Represent any kind of invalid query. 40 41 The query should be a unicode string or a list, which will be space-joined. 42 """ 43 44 def __init__(self, query, explanation): 45 if isinstance(query, list): 46 query = " ".join(query) 47 message = u"'{0}': {1}".format(query, explanation) 48 super(InvalidQueryError, self).__init__(message) 49 50 51class InvalidQueryArgumentValueError(ParsingError): 52 """Represent a query argument that could not be converted as expected. 53 54 It exists to be caught in upper stack levels so a meaningful (i.e. with the 55 query) InvalidQueryError can be raised. 56 """ 57 58 def __init__(self, what, expected, detail=None): 59 message = u"'{0}' is not {1}".format(what, expected) 60 if detail: 61 message = u"{0}: {1}".format(message, detail) 62 super(InvalidQueryArgumentValueError, self).__init__(message) 63 64 65class Query(object): 66 """An abstract class representing a query into the item database. 67 """ 68 69 def clause(self): 70 """Generate an SQLite expression implementing the query. 71 72 Return (clause, subvals) where clause is a valid sqlite 73 WHERE clause implementing the query and subvals is a list of 74 items to be substituted for ?s in the clause. 75 """ 76 return None, () 77 78 def match(self, item): 79 """Check whether this query matches a given Item. Can be used to 80 perform queries on arbitrary sets of Items. 81 """ 82 raise NotImplementedError 83 84 def __repr__(self): 85 return "{0.__class__.__name__}()".format(self) 86 87 def __eq__(self, other): 88 return type(self) == type(other) 89 90 def __hash__(self): 91 return 0 92 93 94class FieldQuery(Query): 95 """An abstract query that searches in a specific field for a 96 pattern. Subclasses must provide a `value_match` class method, which 97 determines whether a certain pattern string matches a certain value 98 string. Subclasses may also provide `col_clause` to implement the 99 same matching functionality in SQLite. 100 """ 101 102 def __init__(self, field, pattern, fast=True): 103 self.field = field 104 self.pattern = pattern 105 self.fast = fast 106 107 def col_clause(self): 108 return None, () 109 110 def clause(self): 111 if self.fast: 112 return self.col_clause() 113 else: 114 # Matching a flexattr. This is a slow query. 115 return None, () 116 117 @classmethod 118 def value_match(cls, pattern, value): 119 """Determine whether the value matches the pattern. Both 120 arguments are strings. 121 """ 122 raise NotImplementedError() 123 124 def match(self, item): 125 return self.value_match(self.pattern, item.get(self.field)) 126 127 def __repr__(self): 128 return ("{0.__class__.__name__}({0.field!r}, {0.pattern!r}, " 129 "{0.fast})".format(self)) 130 131 def __eq__(self, other): 132 return super(FieldQuery, self).__eq__(other) and \ 133 self.field == other.field and self.pattern == other.pattern 134 135 def __hash__(self): 136 return hash((self.field, hash(self.pattern))) 137 138 139class MatchQuery(FieldQuery): 140 """A query that looks for exact matches in an item field.""" 141 142 def col_clause(self): 143 return self.field + " = ?", [self.pattern] 144 145 @classmethod 146 def value_match(cls, pattern, value): 147 return pattern == value 148 149 150class NoneQuery(FieldQuery): 151 """A query that checks whether a field is null.""" 152 153 def __init__(self, field, fast=True): 154 super(NoneQuery, self).__init__(field, None, fast) 155 156 def col_clause(self): 157 return self.field + " IS NULL", () 158 159 @classmethod 160 def match(cls, item): 161 try: 162 return item[cls.field] is None 163 except KeyError: 164 return True 165 166 def __repr__(self): 167 return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self) 168 169 170class StringFieldQuery(FieldQuery): 171 """A FieldQuery that converts values to strings before matching 172 them. 173 """ 174 175 @classmethod 176 def value_match(cls, pattern, value): 177 """Determine whether the value matches the pattern. The value 178 may have any type. 179 """ 180 return cls.string_match(pattern, util.as_string(value)) 181 182 @classmethod 183 def string_match(cls, pattern, value): 184 """Determine whether the value matches the pattern. Both 185 arguments are strings. Subclasses implement this method. 186 """ 187 raise NotImplementedError() 188 189 190class SubstringQuery(StringFieldQuery): 191 """A query that matches a substring in a specific item field.""" 192 193 def col_clause(self): 194 pattern = (self.pattern 195 .replace('\\', '\\\\') 196 .replace('%', '\\%') 197 .replace('_', '\\_')) 198 search = '%' + pattern + '%' 199 clause = self.field + " like ? escape '\\'" 200 subvals = [search] 201 return clause, subvals 202 203 @classmethod 204 def string_match(cls, pattern, value): 205 return pattern.lower() in value.lower() 206 207 208class RegexpQuery(StringFieldQuery): 209 """A query that matches a regular expression in a specific item 210 field. 211 212 Raises InvalidQueryError when the pattern is not a valid regular 213 expression. 214 """ 215 216 def __init__(self, field, pattern, fast=True): 217 super(RegexpQuery, self).__init__(field, pattern, fast) 218 pattern = self._normalize(pattern) 219 try: 220 self.pattern = re.compile(self.pattern) 221 except re.error as exc: 222 # Invalid regular expression. 223 raise InvalidQueryArgumentValueError(pattern, 224 u"a regular expression", 225 format(exc)) 226 227 @staticmethod 228 def _normalize(s): 229 """Normalize a Unicode string's representation (used on both 230 patterns and matched values). 231 """ 232 return unicodedata.normalize('NFC', s) 233 234 @classmethod 235 def string_match(cls, pattern, value): 236 return pattern.search(cls._normalize(value)) is not None 237 238 239class BooleanQuery(MatchQuery): 240 """Matches a boolean field. Pattern should either be a boolean or a 241 string reflecting a boolean. 242 """ 243 244 def __init__(self, field, pattern, fast=True): 245 super(BooleanQuery, self).__init__(field, pattern, fast) 246 if isinstance(pattern, six.string_types): 247 self.pattern = util.str2bool(pattern) 248 self.pattern = int(self.pattern) 249 250 251class BytesQuery(MatchQuery): 252 """Match a raw bytes field (i.e., a path). This is a necessary hack 253 to work around the `sqlite3` module's desire to treat `bytes` and 254 `unicode` equivalently in Python 2. Always use this query instead of 255 `MatchQuery` when matching on BLOB values. 256 """ 257 258 def __init__(self, field, pattern): 259 super(BytesQuery, self).__init__(field, pattern) 260 261 # Use a buffer/memoryview representation of the pattern for SQLite 262 # matching. This instructs SQLite to treat the blob as binary 263 # rather than encoded Unicode. 264 if isinstance(self.pattern, (six.text_type, bytes)): 265 if isinstance(self.pattern, six.text_type): 266 self.pattern = self.pattern.encode('utf-8') 267 self.buf_pattern = buffer(self.pattern) 268 elif isinstance(self.pattern, buffer): 269 self.buf_pattern = self.pattern 270 self.pattern = bytes(self.pattern) 271 272 def col_clause(self): 273 return self.field + " = ?", [self.buf_pattern] 274 275 276class NumericQuery(FieldQuery): 277 """Matches numeric fields. A syntax using Ruby-style range ellipses 278 (``..``) lets users specify one- or two-sided ranges. For example, 279 ``year:2001..`` finds music released since the turn of the century. 280 281 Raises InvalidQueryError when the pattern does not represent an int or 282 a float. 283 """ 284 285 def _convert(self, s): 286 """Convert a string to a numeric type (float or int). 287 288 Return None if `s` is empty. 289 Raise an InvalidQueryError if the string cannot be converted. 290 """ 291 # This is really just a bit of fun premature optimization. 292 if not s: 293 return None 294 try: 295 return int(s) 296 except ValueError: 297 try: 298 return float(s) 299 except ValueError: 300 raise InvalidQueryArgumentValueError(s, u"an int or a float") 301 302 def __init__(self, field, pattern, fast=True): 303 super(NumericQuery, self).__init__(field, pattern, fast) 304 305 parts = pattern.split('..', 1) 306 if len(parts) == 1: 307 # No range. 308 self.point = self._convert(parts[0]) 309 self.rangemin = None 310 self.rangemax = None 311 else: 312 # One- or two-sided range. 313 self.point = None 314 self.rangemin = self._convert(parts[0]) 315 self.rangemax = self._convert(parts[1]) 316 317 def match(self, item): 318 if self.field not in item: 319 return False 320 value = item[self.field] 321 if isinstance(value, six.string_types): 322 value = self._convert(value) 323 324 if self.point is not None: 325 return value == self.point 326 else: 327 if self.rangemin is not None and value < self.rangemin: 328 return False 329 if self.rangemax is not None and value > self.rangemax: 330 return False 331 return True 332 333 def col_clause(self): 334 if self.point is not None: 335 return self.field + '=?', (self.point,) 336 else: 337 if self.rangemin is not None and self.rangemax is not None: 338 return (u'{0} >= ? AND {0} <= ?'.format(self.field), 339 (self.rangemin, self.rangemax)) 340 elif self.rangemin is not None: 341 return u'{0} >= ?'.format(self.field), (self.rangemin,) 342 elif self.rangemax is not None: 343 return u'{0} <= ?'.format(self.field), (self.rangemax,) 344 else: 345 return u'1', () 346 347 348class CollectionQuery(Query): 349 """An abstract query class that aggregates other queries. Can be 350 indexed like a list to access the sub-queries. 351 """ 352 353 def __init__(self, subqueries=()): 354 self.subqueries = subqueries 355 356 # Act like a sequence. 357 358 def __len__(self): 359 return len(self.subqueries) 360 361 def __getitem__(self, key): 362 return self.subqueries[key] 363 364 def __iter__(self): 365 return iter(self.subqueries) 366 367 def __contains__(self, item): 368 return item in self.subqueries 369 370 def clause_with_joiner(self, joiner): 371 """Return a clause created by joining together the clauses of 372 all subqueries with the string joiner (padded by spaces). 373 """ 374 clause_parts = [] 375 subvals = [] 376 for subq in self.subqueries: 377 subq_clause, subq_subvals = subq.clause() 378 if not subq_clause: 379 # Fall back to slow query. 380 return None, () 381 clause_parts.append('(' + subq_clause + ')') 382 subvals += subq_subvals 383 clause = (' ' + joiner + ' ').join(clause_parts) 384 return clause, subvals 385 386 def __repr__(self): 387 return "{0.__class__.__name__}({0.subqueries!r})".format(self) 388 389 def __eq__(self, other): 390 return super(CollectionQuery, self).__eq__(other) and \ 391 self.subqueries == other.subqueries 392 393 def __hash__(self): 394 """Since subqueries are mutable, this object should not be hashable. 395 However and for conveniences purposes, it can be hashed. 396 """ 397 return reduce(mul, map(hash, self.subqueries), 1) 398 399 400class AnyFieldQuery(CollectionQuery): 401 """A query that matches if a given FieldQuery subclass matches in 402 any field. The individual field query class is provided to the 403 constructor. 404 """ 405 406 def __init__(self, pattern, fields, cls): 407 self.pattern = pattern 408 self.fields = fields 409 self.query_class = cls 410 411 subqueries = [] 412 for field in self.fields: 413 subqueries.append(cls(field, pattern, True)) 414 super(AnyFieldQuery, self).__init__(subqueries) 415 416 def clause(self): 417 return self.clause_with_joiner('or') 418 419 def match(self, item): 420 for subq in self.subqueries: 421 if subq.match(item): 422 return True 423 return False 424 425 def __repr__(self): 426 return ("{0.__class__.__name__}({0.pattern!r}, {0.fields!r}, " 427 "{0.query_class.__name__})".format(self)) 428 429 def __eq__(self, other): 430 return super(AnyFieldQuery, self).__eq__(other) and \ 431 self.query_class == other.query_class 432 433 def __hash__(self): 434 return hash((self.pattern, tuple(self.fields), self.query_class)) 435 436 437class MutableCollectionQuery(CollectionQuery): 438 """A collection query whose subqueries may be modified after the 439 query is initialized. 440 """ 441 442 def __setitem__(self, key, value): 443 self.subqueries[key] = value 444 445 def __delitem__(self, key): 446 del self.subqueries[key] 447 448 449class AndQuery(MutableCollectionQuery): 450 """A conjunction of a list of other queries.""" 451 452 def clause(self): 453 return self.clause_with_joiner('and') 454 455 def match(self, item): 456 return all([q.match(item) for q in self.subqueries]) 457 458 459class OrQuery(MutableCollectionQuery): 460 """A conjunction of a list of other queries.""" 461 462 def clause(self): 463 return self.clause_with_joiner('or') 464 465 def match(self, item): 466 return any([q.match(item) for q in self.subqueries]) 467 468 469class NotQuery(Query): 470 """A query that matches the negation of its `subquery`, as a shorcut for 471 performing `not(subquery)` without using regular expressions. 472 """ 473 474 def __init__(self, subquery): 475 self.subquery = subquery 476 477 def clause(self): 478 clause, subvals = self.subquery.clause() 479 if clause: 480 return 'not ({0})'.format(clause), subvals 481 else: 482 # If there is no clause, there is nothing to negate. All the logic 483 # is handled by match() for slow queries. 484 return clause, subvals 485 486 def match(self, item): 487 return not self.subquery.match(item) 488 489 def __repr__(self): 490 return "{0.__class__.__name__}({0.subquery!r})".format(self) 491 492 def __eq__(self, other): 493 return super(NotQuery, self).__eq__(other) and \ 494 self.subquery == other.subquery 495 496 def __hash__(self): 497 return hash(('not', hash(self.subquery))) 498 499 500class TrueQuery(Query): 501 """A query that always matches.""" 502 503 def clause(self): 504 return '1', () 505 506 def match(self, item): 507 return True 508 509 510class FalseQuery(Query): 511 """A query that never matches.""" 512 513 def clause(self): 514 return '0', () 515 516 def match(self, item): 517 return False 518 519 520# Time/date queries. 521 522def _to_epoch_time(date): 523 """Convert a `datetime` object to an integer number of seconds since 524 the (local) Unix epoch. 525 """ 526 if hasattr(date, 'timestamp'): 527 # The `timestamp` method exists on Python 3.3+. 528 return int(date.timestamp()) 529 else: 530 epoch = datetime.fromtimestamp(0) 531 delta = date - epoch 532 return int(delta.total_seconds()) 533 534 535def _parse_periods(pattern): 536 """Parse a string containing two dates separated by two dots (..). 537 Return a pair of `Period` objects. 538 """ 539 parts = pattern.split('..', 1) 540 if len(parts) == 1: 541 instant = Period.parse(parts[0]) 542 return (instant, instant) 543 else: 544 start = Period.parse(parts[0]) 545 end = Period.parse(parts[1]) 546 return (start, end) 547 548 549class Period(object): 550 """A period of time given by a date, time and precision. 551 552 Example: 2014-01-01 10:50:30 with precision 'month' represents all 553 instants of time during January 2014. 554 """ 555 556 precisions = ('year', 'month', 'day', 'hour', 'minute', 'second') 557 date_formats = ( 558 ('%Y',), # year 559 ('%Y-%m',), # month 560 ('%Y-%m-%d',), # day 561 ('%Y-%m-%dT%H', '%Y-%m-%d %H'), # hour 562 ('%Y-%m-%dT%H:%M', '%Y-%m-%d %H:%M'), # minute 563 ('%Y-%m-%dT%H:%M:%S', '%Y-%m-%d %H:%M:%S') # second 564 ) 565 relative_units = {'y': 365, 'm': 30, 'w': 7, 'd': 1} 566 relative_re = '(?P<sign>[+|-]?)(?P<quantity>[0-9]+)' + \ 567 '(?P<timespan>[y|m|w|d])' 568 569 def __init__(self, date, precision): 570 """Create a period with the given date (a `datetime` object) and 571 precision (a string, one of "year", "month", "day", "hour", "minute", 572 or "second"). 573 """ 574 if precision not in Period.precisions: 575 raise ValueError(u'Invalid precision {0}'.format(precision)) 576 self.date = date 577 self.precision = precision 578 579 @classmethod 580 def parse(cls, string): 581 """Parse a date and return a `Period` object or `None` if the 582 string is empty, or raise an InvalidQueryArgumentValueError if 583 the string cannot be parsed to a date. 584 585 The date may be absolute or relative. Absolute dates look like 586 `YYYY`, or `YYYY-MM-DD`, or `YYYY-MM-DD HH:MM:SS`, etc. Relative 587 dates have three parts: 588 589 - Optionally, a ``+`` or ``-`` sign indicating the future or the 590 past. The default is the future. 591 - A number: how much to add or subtract. 592 - A letter indicating the unit: days, weeks, months or years 593 (``d``, ``w``, ``m`` or ``y``). A "month" is exactly 30 days 594 and a "year" is exactly 365 days. 595 """ 596 597 def find_date_and_format(string): 598 for ord, format in enumerate(cls.date_formats): 599 for format_option in format: 600 try: 601 date = datetime.strptime(string, format_option) 602 return date, ord 603 except ValueError: 604 # Parsing failed. 605 pass 606 return (None, None) 607 608 if not string: 609 return None 610 611 # Check for a relative date. 612 match_dq = re.match(cls.relative_re, string) 613 if match_dq: 614 sign = match_dq.group('sign') 615 quantity = match_dq.group('quantity') 616 timespan = match_dq.group('timespan') 617 618 # Add or subtract the given amount of time from the current 619 # date. 620 multiplier = -1 if sign == '-' else 1 621 days = cls.relative_units[timespan] 622 date = datetime.now() + \ 623 timedelta(days=int(quantity) * days) * multiplier 624 return cls(date, cls.precisions[5]) 625 626 # Check for an absolute date. 627 date, ordinal = find_date_and_format(string) 628 if date is None: 629 raise InvalidQueryArgumentValueError(string, 630 'a valid date/time string') 631 precision = cls.precisions[ordinal] 632 return cls(date, precision) 633 634 def open_right_endpoint(self): 635 """Based on the precision, convert the period to a precise 636 `datetime` for use as a right endpoint in a right-open interval. 637 """ 638 precision = self.precision 639 date = self.date 640 if 'year' == self.precision: 641 return date.replace(year=date.year + 1, month=1) 642 elif 'month' == precision: 643 if (date.month < 12): 644 return date.replace(month=date.month + 1) 645 else: 646 return date.replace(year=date.year + 1, month=1) 647 elif 'day' == precision: 648 return date + timedelta(days=1) 649 elif 'hour' == precision: 650 return date + timedelta(hours=1) 651 elif 'minute' == precision: 652 return date + timedelta(minutes=1) 653 elif 'second' == precision: 654 return date + timedelta(seconds=1) 655 else: 656 raise ValueError(u'unhandled precision {0}'.format(precision)) 657 658 659class DateInterval(object): 660 """A closed-open interval of dates. 661 662 A left endpoint of None means since the beginning of time. 663 A right endpoint of None means towards infinity. 664 """ 665 666 def __init__(self, start, end): 667 if start is not None and end is not None and not start < end: 668 raise ValueError(u"start date {0} is not before end date {1}" 669 .format(start, end)) 670 self.start = start 671 self.end = end 672 673 @classmethod 674 def from_periods(cls, start, end): 675 """Create an interval with two Periods as the endpoints. 676 """ 677 end_date = end.open_right_endpoint() if end is not None else None 678 start_date = start.date if start is not None else None 679 return cls(start_date, end_date) 680 681 def contains(self, date): 682 if self.start is not None and date < self.start: 683 return False 684 if self.end is not None and date >= self.end: 685 return False 686 return True 687 688 def __str__(self): 689 return '[{0}, {1})'.format(self.start, self.end) 690 691 692class DateQuery(FieldQuery): 693 """Matches date fields stored as seconds since Unix epoch time. 694 695 Dates can be specified as ``year-month-day`` strings where only year 696 is mandatory. 697 698 The value of a date field can be matched against a date interval by 699 using an ellipsis interval syntax similar to that of NumericQuery. 700 """ 701 702 def __init__(self, field, pattern, fast=True): 703 super(DateQuery, self).__init__(field, pattern, fast) 704 start, end = _parse_periods(pattern) 705 self.interval = DateInterval.from_periods(start, end) 706 707 def match(self, item): 708 if self.field not in item: 709 return False 710 timestamp = float(item[self.field]) 711 date = datetime.fromtimestamp(timestamp) 712 return self.interval.contains(date) 713 714 _clause_tmpl = "{0} {1} ?" 715 716 def col_clause(self): 717 clause_parts = [] 718 subvals = [] 719 720 if self.interval.start: 721 clause_parts.append(self._clause_tmpl.format(self.field, ">=")) 722 subvals.append(_to_epoch_time(self.interval.start)) 723 724 if self.interval.end: 725 clause_parts.append(self._clause_tmpl.format(self.field, "<")) 726 subvals.append(_to_epoch_time(self.interval.end)) 727 728 if clause_parts: 729 # One- or two-sided interval. 730 clause = ' AND '.join(clause_parts) 731 else: 732 # Match any date. 733 clause = '1' 734 return clause, subvals 735 736 737class DurationQuery(NumericQuery): 738 """NumericQuery that allow human-friendly (M:SS) time interval formats. 739 740 Converts the range(s) to a float value, and delegates on NumericQuery. 741 742 Raises InvalidQueryError when the pattern does not represent an int, float 743 or M:SS time interval. 744 """ 745 746 def _convert(self, s): 747 """Convert a M:SS or numeric string to a float. 748 749 Return None if `s` is empty. 750 Raise an InvalidQueryError if the string cannot be converted. 751 """ 752 if not s: 753 return None 754 try: 755 return util.raw_seconds_short(s) 756 except ValueError: 757 try: 758 return float(s) 759 except ValueError: 760 raise InvalidQueryArgumentValueError( 761 s, 762 u"a M:SS string or a float") 763 764 765# Sorting. 766 767class Sort(object): 768 """An abstract class representing a sort operation for a query into 769 the item database. 770 """ 771 772 def order_clause(self): 773 """Generates a SQL fragment to be used in a ORDER BY clause, or 774 None if no fragment is used (i.e., this is a slow sort). 775 """ 776 return None 777 778 def sort(self, items): 779 """Sort the list of objects and return a list. 780 """ 781 return sorted(items) 782 783 def is_slow(self): 784 """Indicate whether this query is *slow*, meaning that it cannot 785 be executed in SQL and must be executed in Python. 786 """ 787 return False 788 789 def __hash__(self): 790 return 0 791 792 def __eq__(self, other): 793 return type(self) == type(other) 794 795 796class MultipleSort(Sort): 797 """Sort that encapsulates multiple sub-sorts. 798 """ 799 800 def __init__(self, sorts=None): 801 self.sorts = sorts or [] 802 803 def add_sort(self, sort): 804 self.sorts.append(sort) 805 806 def _sql_sorts(self): 807 """Return the list of sub-sorts for which we can be (at least 808 partially) fast. 809 810 A contiguous suffix of fast (SQL-capable) sub-sorts are 811 executable in SQL. The remaining, even if they are fast 812 independently, must be executed slowly. 813 """ 814 sql_sorts = [] 815 for sort in reversed(self.sorts): 816 if not sort.order_clause() is None: 817 sql_sorts.append(sort) 818 else: 819 break 820 sql_sorts.reverse() 821 return sql_sorts 822 823 def order_clause(self): 824 order_strings = [] 825 for sort in self._sql_sorts(): 826 order = sort.order_clause() 827 order_strings.append(order) 828 829 return ", ".join(order_strings) 830 831 def is_slow(self): 832 for sort in self.sorts: 833 if sort.is_slow(): 834 return True 835 return False 836 837 def sort(self, items): 838 slow_sorts = [] 839 switch_slow = False 840 for sort in reversed(self.sorts): 841 if switch_slow: 842 slow_sorts.append(sort) 843 elif sort.order_clause() is None: 844 switch_slow = True 845 slow_sorts.append(sort) 846 else: 847 pass 848 849 for sort in slow_sorts: 850 items = sort.sort(items) 851 return items 852 853 def __repr__(self): 854 return 'MultipleSort({!r})'.format(self.sorts) 855 856 def __hash__(self): 857 return hash(tuple(self.sorts)) 858 859 def __eq__(self, other): 860 return super(MultipleSort, self).__eq__(other) and \ 861 self.sorts == other.sorts 862 863 864class FieldSort(Sort): 865 """An abstract sort criterion that orders by a specific field (of 866 any kind). 867 """ 868 869 def __init__(self, field, ascending=True, case_insensitive=True): 870 self.field = field 871 self.ascending = ascending 872 self.case_insensitive = case_insensitive 873 874 def sort(self, objs): 875 # TODO: Conversion and null-detection here. In Python 3, 876 # comparisons with None fail. We should also support flexible 877 # attributes with different types without falling over. 878 879 def key(item): 880 field_val = item.get(self.field, '') 881 if self.case_insensitive and isinstance(field_val, six.text_type): 882 field_val = field_val.lower() 883 return field_val 884 885 return sorted(objs, key=key, reverse=not self.ascending) 886 887 def __repr__(self): 888 return '<{0}: {1}{2}>'.format( 889 type(self).__name__, 890 self.field, 891 '+' if self.ascending else '-', 892 ) 893 894 def __hash__(self): 895 return hash((self.field, self.ascending)) 896 897 def __eq__(self, other): 898 return super(FieldSort, self).__eq__(other) and \ 899 self.field == other.field and \ 900 self.ascending == other.ascending 901 902 903class FixedFieldSort(FieldSort): 904 """Sort object to sort on a fixed field. 905 """ 906 907 def order_clause(self): 908 order = "ASC" if self.ascending else "DESC" 909 if self.case_insensitive: 910 field = '(CASE ' \ 911 'WHEN TYPEOF({0})="text" THEN LOWER({0}) ' \ 912 'WHEN TYPEOF({0})="blob" THEN LOWER({0}) ' \ 913 'ELSE {0} END)'.format(self.field) 914 else: 915 field = self.field 916 return "{0} {1}".format(field, order) 917 918 919class SlowFieldSort(FieldSort): 920 """A sort criterion by some model field other than a fixed field: 921 i.e., a computed or flexible field. 922 """ 923 924 def is_slow(self): 925 return True 926 927 928class NullSort(Sort): 929 """No sorting. Leave results unsorted.""" 930 931 def sort(self, items): 932 return items 933 934 def __nonzero__(self): 935 return self.__bool__() 936 937 def __bool__(self): 938 return False 939 940 def __eq__(self, other): 941 return type(self) == type(other) or other is None 942 943 def __hash__(self): 944 return 0 945