1# -*- coding: utf-8 -*-
2# This file is part of beets.
3# Copyright 2016, Adrian Sampson.
4#
5# Permission is hereby granted, free of charge, to any person obtaining
6# a copy of this software and associated documentation files (the
7# "Software"), to deal in the Software without restriction, including
8# without limitation the rights to use, copy, modify, merge, publish,
9# distribute, sublicense, and/or sell copies of the Software, and to
10# permit persons to whom the Software is furnished to do so, subject to
11# the following conditions:
12#
13# The above copyright notice and this permission notice shall be
14# included in all copies or substantial portions of the Software.
15
16"""The Query type hierarchy for DBCore.
17"""
18from __future__ import division, absolute_import, print_function
19
20import re
21from operator import mul
22from beets import util
23from datetime import datetime, timedelta
24import unicodedata
25from functools import reduce
26import six
27
if not six.PY2:
    # Python 3 has no builtin `buffer`, so alias it to `memoryview`; code
    # below can then call `buffer(...)` and `isinstance(x, buffer)` on both
    # major versions. On Python 2 the builtin `buffer` is kept because
    # sqlite won't accept memoryview there.
    buffer = memoryview
30
31
class ParsingError(ValueError):
    """Base class for every error raised when a user-supplied album or
    query specification cannot be parsed.
    """
36
37
class InvalidQueryError(ParsingError):
    """Raised for a query that is invalid as a whole.

    `query` may be a unicode string or a list of strings; a list is
    joined with single spaces before it is embedded in the message,
    which has the form ``'<query>': <explanation>``.
    """

    def __init__(self, query, explanation):
        shown = " ".join(query) if isinstance(query, list) else query
        super(InvalidQueryError, self).__init__(
            u"'{0}': {1}".format(shown, explanation))
49
50
class InvalidQueryArgumentValueError(ParsingError):
    """Raised when a single query argument cannot be converted to the
    expected kind of value.

    It is meant to be caught further up the stack, where an
    InvalidQueryError mentioning the whole query can be raised instead.
    The message reads ``'<what>' is not <expected>``, followed by
    ``: <detail>`` when a (truthy) detail is given.
    """

    def __init__(self, what, expected, detail=None):
        text = u"'{0}' is not {1}".format(what, expected)
        if not detail:
            super(InvalidQueryArgumentValueError, self).__init__(text)
            return
        super(InvalidQueryArgumentValueError, self).__init__(
            u"{0}: {1}".format(text, detail))
63
64
class Query(object):
    """Abstract base for every query that can be run against the item
    database.
    """

    def clause(self):
        """Build an SQLite expression implementing this query.

        Returns a ``(clause, subvals)`` pair: `clause` is a WHERE
        expression and `subvals` holds the values substituted for its
        ``?`` placeholders. This base version returns ``(None, ())``,
        i.e. no SQL form is available.
        """
        return (None, ())

    def match(self, item):
        """Return whether this query matches the given Item; usable to
        filter arbitrary collections of Items in Python. Subclasses must
        override this.
        """
        raise NotImplementedError

    def __repr__(self):
        return "{0.__class__.__name__}()".format(self)

    def __eq__(self, other):
        # Plain queries carry no state, so equal types mean equal queries.
        return type(other) == type(self)

    def __hash__(self):
        return 0
92
93
class FieldQuery(Query):
    """Abstract query that looks for a pattern in one specific field.

    Subclasses supply a `value_match` classmethod that decides whether a
    pattern matches a field value, and may supply `col_clause` to run the
    same test in SQLite. When `fast` is false (e.g. a flexible attribute
    is being matched) no SQL clause is produced, so matching happens in
    Python through `match`.
    """

    def __init__(self, field, pattern, fast=True):
        self.field, self.pattern, self.fast = field, pattern, fast

    def col_clause(self):
        """Return ``(clause, subvals)`` for SQLite; the base version has
        no SQL form and returns ``(None, ())``.
        """
        return None, ()

    def clause(self):
        # Slow (flexattr) queries have no SQL clause at all.
        return self.col_clause() if self.fast else (None, ())

    @classmethod
    def value_match(cls, pattern, value):
        """Return whether `value` matches `pattern` (both strings).
        Subclasses must override this.
        """
        raise NotImplementedError()

    def match(self, item):
        value = item.get(self.field)
        return self.value_match(self.pattern, value)

    def __repr__(self):
        return ("{0.__class__.__name__}({0.field!r}, {0.pattern!r}, "
                "{0.fast})".format(self))

    def __eq__(self, other):
        if not super(FieldQuery, self).__eq__(other):
            return False
        return self.field == other.field and self.pattern == other.pattern

    def __hash__(self):
        return hash((self.field, hash(self.pattern)))
137
138
class MatchQuery(FieldQuery):
    """Matches items whose field value equals the pattern exactly."""

    def col_clause(self):
        equality = self.field + " = ?"
        return equality, [self.pattern]

    @classmethod
    def value_match(cls, pattern, value):
        # Plain equality, pattern on the left.
        return pattern == value
148
149
class NoneQuery(FieldQuery):
    """A query that checks whether a field is null: the item either lacks
    the field or its value is None.
    """

    def __init__(self, field, fast=True):
        super(NoneQuery, self).__init__(field, None, fast)

    def col_clause(self):
        return self.field + " IS NULL", ()

    def match(self, item):
        """Return True when `item` has no such field or its value is None.

        This must be an instance method: `field` is set per instance in
        `__init__`, so a classmethod reading `cls.field` would raise
        AttributeError on every call.
        """
        try:
            return item[self.field] is None
        except KeyError:
            # A missing field counts as null.
            return True

    def __repr__(self):
        return "{0.__class__.__name__}({0.field!r}, {0.fast})".format(self)
168
169
class StringFieldQuery(FieldQuery):
    """Base for field queries that compare textual representations:
    field values are converted to strings before matching.
    """

    @classmethod
    def value_match(cls, pattern, value):
        """Coerce `value` (of any type) to a string with `util.as_string`
        and delegate the comparison to `string_match`.
        """
        text = util.as_string(value)
        return cls.string_match(pattern, text)

    @classmethod
    def string_match(cls, pattern, value):
        """Return whether the string `value` matches `pattern`.
        Subclasses must override this.
        """
        raise NotImplementedError()
188
189
class SubstringQuery(StringFieldQuery):
    """Matches items whose field contains the pattern as a substring. The
    Python fallback compares case-insensitively.
    """

    def col_clause(self):
        # Escape the LIKE escape character and wildcards so the pattern is
        # taken literally. The backslash goes first so the escapes added
        # for '%' and '_' are not themselves doubled.
        escaped = self.pattern
        for char, replacement in (('\\', '\\\\'), ('%', '\\%'),
                                  ('_', '\\_')):
            escaped = escaped.replace(char, replacement)
        clause = self.field + " like ? escape '\\'"
        return clause, ['%' + escaped + '%']

    @classmethod
    def string_match(cls, pattern, value):
        needle = pattern.lower()
        return needle in value.lower()
206
207
class RegexpQuery(StringFieldQuery):
    """A query that matches a regular expression in a specific item
    field.

    Both the pattern and the searched values are NFC-normalized, so text
    differing only in Unicode composition still matches.

    Raises InvalidQueryArgumentValueError when the pattern is not a valid
    regular expression.
    """

    def __init__(self, field, pattern, fast=True):
        super(RegexpQuery, self).__init__(field, pattern, fast)
        # Compile the *normalized* pattern: `string_match` normalizes the
        # values it searches, so the pattern must be in the same form.
        pattern = self._normalize(pattern)
        try:
            self.pattern = re.compile(pattern)
        except re.error as exc:
            # Invalid regular expression.
            raise InvalidQueryArgumentValueError(pattern,
                                                 u"a regular expression",
                                                 format(exc))

    @staticmethod
    def _normalize(s):
        """Normalize a Unicode string's representation to NFC (used on
        both patterns and matched values).
        """
        return unicodedata.normalize('NFC', s)

    @classmethod
    def string_match(cls, pattern, value):
        """Return whether the compiled regex `pattern` finds a match
        anywhere in the normalized `value`.
        """
        return pattern.search(cls._normalize(value)) is not None
237
238
class BooleanQuery(MatchQuery):
    """Exact match on a boolean field. The pattern may be a boolean or a
    string spelling one; either way it is stored as the integer 0 or 1.
    """

    def __init__(self, field, pattern, fast=True):
        super(BooleanQuery, self).__init__(field, pattern, fast)
        value = self.pattern
        if isinstance(value, six.string_types):
            value = util.str2bool(value)
        self.pattern = int(value)
249
250
class BytesQuery(MatchQuery):
    """Match a raw bytes field (i.e., a path).

    This works around the `sqlite3` module treating `bytes` and `unicode`
    alike in Python 2. Use it instead of `MatchQuery` whenever the column
    holds BLOB values.
    """

    def __init__(self, field, pattern):
        super(BytesQuery, self).__init__(field, pattern)

        # SQLite receives a buffer/memoryview of the pattern so it
        # compares the blob as binary rather than as encoded Unicode;
        # `self.pattern` itself is kept as bytes.
        value = self.pattern
        if isinstance(value, six.text_type):
            value = value.encode('utf-8')
            self.pattern = value
            self.buf_pattern = buffer(value)
        elif isinstance(value, bytes):
            self.buf_pattern = buffer(value)
        elif isinstance(value, buffer):
            self.buf_pattern = value
            self.pattern = bytes(value)

    def col_clause(self):
        return self.field + " = ?", [self.buf_pattern]
274
275
class NumericQuery(FieldQuery):
    """Matches numeric fields, either at a single value or within a range.

    Ranges use Ruby-style ellipses (``..``) and may be one- or two-sided;
    e.g. ``year:2001..`` finds music released since the turn of the
    century.

    Raises InvalidQueryArgumentValueError when a bound is neither an int
    nor a float.
    """

    def _convert(self, s):
        """Turn a string into an int, or failing that a float.

        Returns None for empty input and raises
        InvalidQueryArgumentValueError when neither conversion works.
        """
        if not s:
            # Empty bound: nothing to convert (a bit of fun premature
            # optimization).
            return None
        try:
            return int(s)
        except ValueError:
            try:
                return float(s)
            except ValueError:
                raise InvalidQueryArgumentValueError(s, u"an int or a float")

    def __init__(self, field, pattern, fast=True):
        super(NumericQuery, self).__init__(field, pattern, fast)

        bounds = pattern.split('..', 1)
        if len(bounds) > 1:
            # A one- or two-sided range.
            self.point = None
            self.rangemin = self._convert(bounds[0])
            self.rangemax = self._convert(bounds[1])
        else:
            # A single exact value.
            self.point = self._convert(bounds[0])
            self.rangemin = self.rangemax = None

    def match(self, item):
        if self.field not in item:
            return False
        value = item[self.field]
        if isinstance(value, six.string_types):
            value = self._convert(value)

        if self.point is not None:
            return value == self.point
        if self.rangemin is not None and value < self.rangemin:
            return False
        return self.rangemax is None or not value > self.rangemax

    def col_clause(self):
        if self.point is not None:
            return self.field + '=?', (self.point,)
        low, high = self.rangemin, self.rangemax
        if low is not None and high is not None:
            return (u'{0} >= ? AND {0} <= ?'.format(self.field),
                    (low, high))
        if low is not None:
            return u'{0} >= ?'.format(self.field), (low,)
        if high is not None:
            return u'{0} <= ?'.format(self.field), (high,)
        # Fully open range: everything matches.
        return u'1', ()
346
347
class CollectionQuery(Query):
    """An abstract query class that aggregates other queries. Can be
    indexed like a list to access the sub-queries.
    """

    def __init__(self, subqueries=()):
        self.subqueries = subqueries

    # Act like a sequence.

    def __len__(self):
        return len(self.subqueries)

    def __getitem__(self, key):
        return self.subqueries[key]

    def __iter__(self):
        return iter(self.subqueries)

    def __contains__(self, item):
        return item in self.subqueries

    def clause_with_joiner(self, joiner):
        """Return a clause created by joining together the clauses of
        all subqueries with the string joiner (padded by spaces).

        If any subquery has no SQL clause, the whole collection falls
        back to a slow query and ``(None, ())`` is returned.
        """
        clause_parts = []
        subvals = []
        for subq in self.subqueries:
            subq_clause, subq_subvals = subq.clause()
            if not subq_clause:
                # Fall back to slow query.
                return None, ()
            clause_parts.append('(' + subq_clause + ')')
            subvals += subq_subvals
        clause = (' ' + joiner + ' ').join(clause_parts)
        return clause, subvals

    def __repr__(self):
        return "{0.__class__.__name__}({0.subqueries!r})".format(self)

    def __eq__(self, other):
        return super(CollectionQuery, self).__eq__(other) and \
            self.subqueries == other.subqueries

    def __hash__(self):
        """Hash the ordered subqueries.

        Because subqueries are mutable, this object should strictly not be
        hashable; a hash is provided for convenience. Hashing the ordered
        tuple agrees with `__eq__` (which compares subqueries in order). A
        product of subquery hashes would collapse to 0 whenever a single
        subquery hashes to 0, as the base `Query.__hash__` does.
        """
        return hash(tuple(self.subqueries))
398
399
class AnyFieldQuery(CollectionQuery):
    """Matches when a given FieldQuery subclass matches in at least one of
    several fields.

    The constructor takes the pattern, the fields to search and the field
    query class; one fast subquery is built per field.
    """

    def __init__(self, pattern, fields, cls):
        self.pattern = pattern
        self.fields = fields
        self.query_class = cls
        super(AnyFieldQuery, self).__init__(
            [cls(field, pattern, True) for field in self.fields])

    def clause(self):
        return self.clause_with_joiner('or')

    def match(self, item):
        # Stops at the first field that matches.
        return any(subq.match(item) for subq in self.subqueries)

    def __repr__(self):
        return ("{0.__class__.__name__}({0.pattern!r}, {0.fields!r}, "
                "{0.query_class.__name__})".format(self))

    def __eq__(self, other):
        if not super(AnyFieldQuery, self).__eq__(other):
            return False
        return self.query_class == other.query_class

    def __hash__(self):
        return hash((self.pattern, tuple(self.fields), self.query_class))
435
436
class MutableCollectionQuery(CollectionQuery):
    """A CollectionQuery whose subqueries can be replaced or removed by
    index after construction.
    """

    def __setitem__(self, key, value):
        subqueries = self.subqueries
        subqueries[key] = value

    def __delitem__(self, key):
        subqueries = self.subqueries
        del subqueries[key]
447
448
class AndQuery(MutableCollectionQuery):
    """A conjunction of a list of other queries: it matches only when
    every subquery matches.
    """

    def clause(self):
        return self.clause_with_joiner('and')

    def match(self, item):
        # A generator lets `all` stop at the first non-matching subquery
        # instead of evaluating every subquery up front.
        return all(q.match(item) for q in self.subqueries)
457
458
class OrQuery(MutableCollectionQuery):
    """A disjunction of a list of other queries: it matches when at least
    one subquery matches.
    """

    def clause(self):
        return self.clause_with_joiner('or')

    def match(self, item):
        # A generator lets `any` stop at the first matching subquery
        # instead of evaluating every subquery up front.
        return any(q.match(item) for q in self.subqueries)
467
468
class NotQuery(Query):
    """Matches exactly the items that its `subquery` does not match. It is
    a shortcut for ``not(subquery)`` that needs no regular expressions.
    """

    def __init__(self, subquery):
        self.subquery = subquery

    def clause(self):
        inner, subvals = self.subquery.clause()
        if not inner:
            # No SQL clause means a slow subquery: there is nothing to
            # negate here, and match() handles the logic in Python.
            return inner, subvals
        return 'not ({0})'.format(inner), subvals

    def match(self, item):
        return not self.subquery.match(item)

    def __repr__(self):
        return "{0.__class__.__name__}({0.subquery!r})".format(self)

    def __eq__(self, other):
        if not super(NotQuery, self).__eq__(other):
            return False
        return self.subquery == other.subquery

    def __hash__(self):
        return hash(('not', hash(self.subquery)))
498
499
class TrueQuery(Query):
    """Matches every item; its SQL form is the constant ``1``."""

    def clause(self):
        return ('1', ())

    def match(self, item):
        # Unconditional match.
        return True
508
509
class FalseQuery(Query):
    """Matches no item; its SQL form is the constant ``0``."""

    def clause(self):
        return ('0', ())

    def match(self, item):
        # Unconditional rejection.
        return False
518
519
520# Time/date queries.
521
def _to_epoch_time(date):
    """Return the whole number of seconds between the (local) Unix epoch
    and the `datetime` object `date`.
    """
    if not hasattr(date, 'timestamp'):
        # Before Python 3.3 there is no `timestamp` method, so measure the
        # offset from the local epoch by hand.
        offset = date - datetime.fromtimestamp(0)
        return int(offset.total_seconds())
    return int(date.timestamp())
533
534
def _parse_periods(pattern):
    """Split `pattern` on its first ``..`` and parse each side into a
    `Period`, returning a ``(start, end)`` pair.

    Without ``..`` the single period is used as both start and end. An
    empty side parses to None.
    """
    left, separator, right = pattern.partition('..')
    if not separator:
        instant = Period.parse(left)
        return (instant, instant)
    return (Period.parse(left), Period.parse(right))
547
548
class Period(object):
    """A period of time given by a date, time and precision.

    Example: 2014-01-01 10:50:30 with precision 'month' represents all
    instants of time during January 2014.
    """

    precisions = ('year', 'month', 'day', 'hour', 'minute', 'second')
    # strptime formats accepted for each precision, in the same order as
    # `precisions`: the index of the group that parses a string selects
    # the precision.
    date_formats = (
        ('%Y',),  # year
        ('%Y-%m',),  # month
        ('%Y-%m-%d',),  # day
        ('%Y-%m-%dT%H', '%Y-%m-%d %H'),  # hour
        ('%Y-%m-%dT%H:%M', '%Y-%m-%d %H:%M'),  # minute
        ('%Y-%m-%dT%H:%M:%S', '%Y-%m-%d %H:%M:%S')  # second
    )
    # Length in days of each relative-date unit letter.
    relative_units = {'y': 365, 'm': 30, 'w': 7, 'd': 1}
    # Optional sign, a count, then a unit letter. Inside a character class
    # `|` is a literal, so the classes list only the allowed characters
    # (otherwise "3|" matched and then failed the `relative_units` lookup).
    relative_re = '(?P<sign>[+-]?)(?P<quantity>[0-9]+)' + \
        '(?P<timespan>[ymwd])'

    def __init__(self, date, precision):
        """Create a period with the given date (a `datetime` object) and
        precision (a string, one of "year", "month", "day", "hour", "minute",
        or "second").

        Raises ValueError for an unknown precision.
        """
        if precision not in Period.precisions:
            raise ValueError(u'Invalid precision {0}'.format(precision))
        self.date = date
        self.precision = precision

    @classmethod
    def parse(cls, string):
        """Parse a date and return a `Period` object or `None` if the
        string is empty, or raise an InvalidQueryArgumentValueError if
        the string cannot be parsed to a date.

        The date may be absolute or relative. Absolute dates look like
        `YYYY`, or `YYYY-MM-DD`, or `YYYY-MM-DD HH:MM:SS`, etc. Relative
        dates consist of exactly three parts:

        - Optionally, a ``+`` or ``-`` sign indicating the future or the
          past. The default is the future.
        - A number: how much to add or subtract.
        - A letter indicating the unit: days, weeks, months or years
          (``d``, ``w``, ``m`` or ``y``). A "month" is exactly 30 days
          and a "year" is exactly 365 days.

        Relative dates are relative to the current local time and get
        'second' precision.
        """

        def find_date_and_format(string):
            """Return ``(datetime, precision index)`` for the first format
            in `date_formats` that parses `string`, else ``(None, None)``.
            """
            for ordinal, format_options in enumerate(cls.date_formats):
                for format_option in format_options:
                    try:
                        date = datetime.strptime(string, format_option)
                    except ValueError:
                        # Parsing failed; try the next format.
                        continue
                    return date, ordinal
            return (None, None)

        if not string:
            return None

        # Check for a relative date. Anchoring at the end requires the
        # whole string to be relative, so "3dx" is rejected instead of
        # being silently read as "3d".
        match_dq = re.match(cls.relative_re + r'\Z', string)
        if match_dq:
            sign = match_dq.group('sign')
            quantity = match_dq.group('quantity')
            timespan = match_dq.group('timespan')

            # Add or subtract the given amount of time from the current
            # date.
            multiplier = -1 if sign == '-' else 1
            days = cls.relative_units[timespan]
            date = datetime.now() + \
                timedelta(days=int(quantity) * days) * multiplier
            return cls(date, cls.precisions[5])  # 'second'

        # Check for an absolute date.
        date, ordinal = find_date_and_format(string)
        if date is None:
            raise InvalidQueryArgumentValueError(string,
                                                 'a valid date/time string')
        return cls(date, cls.precisions[ordinal])

    def open_right_endpoint(self):
        """Based on the precision, convert the period to a precise
        `datetime` for use as a right endpoint in a right-open interval:
        the date advanced by one unit of its precision.

        Raises ValueError for an unhandled precision.
        """
        precision = self.precision
        date = self.date
        if 'year' == precision:
            return date.replace(year=date.year + 1, month=1)
        elif 'month' == precision:
            if date.month < 12:
                return date.replace(month=date.month + 1)
            # December rolls over into January of the next year.
            return date.replace(year=date.year + 1, month=1)
        elif 'day' == precision:
            return date + timedelta(days=1)
        elif 'hour' == precision:
            return date + timedelta(hours=1)
        elif 'minute' == precision:
            return date + timedelta(minutes=1)
        elif 'second' == precision:
            return date + timedelta(seconds=1)
        else:
            raise ValueError(u'unhandled precision {0}'.format(precision))
657
658
class DateInterval(object):
    """A half-open ``[start, end)`` range of datetimes.

    A start of None leaves the range unbounded in the past; an end of
    None leaves it unbounded in the future.
    """

    def __init__(self, start, end):
        bounded = start is not None and end is not None
        if bounded and not start < end:
            raise ValueError(u"start date {0} is not before end date {1}"
                             .format(start, end))
        self.start = start
        self.end = end

    @classmethod
    def from_periods(cls, start, end):
        """Build an interval from two Periods (either may be None).

        The start Period contributes its own date; the end Period
        contributes its open right endpoint.
        """
        end_date = None if end is None else end.open_right_endpoint()
        start_date = None if start is None else start.date
        return cls(start_date, end_date)

    def contains(self, date):
        """Return whether `date` lies inside the interval."""
        if self.start is not None and date < self.start:
            return False
        return self.end is None or not date >= self.end

    def __str__(self):
        return '[{0}, {1})'.format(self.start, self.end)
690
691
class DateQuery(FieldQuery):
    """Matches date fields stored as seconds since the Unix epoch.

    The pattern uses ``year-month-day`` dates (only the year is required)
    and may give an interval with the same ``..`` syntax as NumericQuery.
    """

    def __init__(self, field, pattern, fast=True):
        super(DateQuery, self).__init__(field, pattern, fast)
        self.interval = DateInterval.from_periods(*_parse_periods(pattern))

    def match(self, item):
        if self.field not in item:
            return False
        moment = datetime.fromtimestamp(float(item[self.field]))
        return self.interval.contains(moment)

    _clause_tmpl = "{0} {1} ?"

    def col_clause(self):
        bounds = ((">=", self.interval.start), ("<", self.interval.end))
        clause_parts = []
        subvals = []
        for operator_, moment in bounds:
            if moment:
                clause_parts.append(
                    self._clause_tmpl.format(self.field, operator_))
                subvals.append(_to_epoch_time(moment))

        if not clause_parts:
            # Unbounded on both sides: any date matches.
            return '1', subvals
        # One- or two-sided interval.
        return ' AND '.join(clause_parts), subvals
735
736
class DurationQuery(NumericQuery):
    """A NumericQuery that also accepts human-friendly ``M:SS`` durations.

    Each bound is converted to a number, and the matching itself is left
    to NumericQuery.

    Raises InvalidQueryArgumentValueError when a bound is neither an
    ``M:SS`` string nor a number.
    """

    def _convert(self, s):
        """Turn an ``M:SS`` string (via `util.raw_seconds_short`) or a
        plain numeric string into a number.

        Returns None for empty input and raises
        InvalidQueryArgumentValueError when neither form parses.
        """
        if not s:
            return None
        try:
            seconds = util.raw_seconds_short(s)
        except ValueError:
            try:
                seconds = float(s)
            except ValueError:
                raise InvalidQueryArgumentValueError(
                    s,
                    u"a M:SS string or a float")
        return seconds
763
764
765# Sorting.
766
class Sort(object):
    """Abstract base for a sort operation applied to the results of a
    query into the item database.
    """

    def order_clause(self):
        """Return an SQL fragment for an ORDER BY clause, or None when the
        sort cannot be expressed in SQL (a slow sort).
        """
        return None

    def sort(self, items):
        """Return a new list holding `items` in their natural order."""
        ordered = list(items)
        ordered.sort()
        return ordered

    def is_slow(self):
        """Return whether this sort is *slow*, i.e. cannot run in SQL and
        has to be performed in Python.
        """
        return False

    def __hash__(self):
        return 0

    def __eq__(self, other):
        return type(other) == type(self)
794
795
class MultipleSort(Sort):
    """A sort made of several sub-sorts, most significant first."""

    def __init__(self, sorts=None):
        # Any falsy argument (including an empty list) gets a fresh list.
        self.sorts = sorts if sorts else []

    def add_sort(self, sort):
        self.sorts.append(sort)

    def _sql_sorts(self):
        """Return the sub-sorts that can run in SQL.

        Only the trailing run of sub-sorts that each have an SQL clause
        qualifies. Any fast sub-sort before a slow one must also be
        applied in Python.
        """
        fast_suffix = []
        for sort in reversed(self.sorts):
            if sort.order_clause() is None:
                break
            fast_suffix.append(sort)
        return fast_suffix[::-1]

    def order_clause(self):
        return ", ".join(sort.order_clause() for sort in self._sql_sorts())

    def is_slow(self):
        return any(sort.is_slow() for sort in self.sorts)

    def sort(self, items):
        # Walk from least to most significant. Once the first sub-sort
        # without an SQL clause is met, it and every more significant
        # sub-sort are applied in Python. Each pass is stable, so the
        # SQL ordering survives as the tie-breaker.
        slow_sorts = []
        for sort in reversed(self.sorts):
            if slow_sorts or sort.order_clause() is None:
                slow_sorts.append(sort)

        for sort in slow_sorts:
            items = sort.sort(items)
        return items

    def __repr__(self):
        return 'MultipleSort({!r})'.format(self.sorts)

    def __hash__(self):
        return hash(tuple(self.sorts))

    def __eq__(self, other):
        if not super(MultipleSort, self).__eq__(other):
            return False
        return self.sorts == other.sorts
862
863
class FieldSort(Sort):
    """Abstract sort criterion ordering by one field, of any kind."""

    def __init__(self, field, ascending=True, case_insensitive=True):
        self.field = field
        self.ascending = ascending
        self.case_insensitive = case_insensitive

    def sort(self, objs):
        # TODO: Conversion and null-detection here. In Python 3,
        # comparisons with None fail. We should also support flexible
        # attributes with different types without falling over.

        def sort_key(obj):
            # Missing fields sort as the empty string; text is lowered
            # when the sort is case-insensitive.
            value = obj.get(self.field, '')
            if self.case_insensitive and isinstance(value, six.text_type):
                return value.lower()
            return value

        descending = not self.ascending
        return sorted(objs, key=sort_key, reverse=descending)

    def __repr__(self):
        direction = '+' if self.ascending else '-'
        return '<{0}: {1}{2}>'.format(
            type(self).__name__,
            self.field,
            direction,
        )

    def __hash__(self):
        return hash((self.field, self.ascending))

    def __eq__(self, other):
        if not super(FieldSort, self).__eq__(other):
            return False
        return self.field == other.field and \
            self.ascending == other.ascending
901
902
class FixedFieldSort(FieldSort):
    """Sort object to sort on a fixed field; the ordering is expressed as
    an SQL ORDER BY fragment.
    """

    def order_clause(self):
        """Return ``"<expr> ASC"`` or ``"<expr> DESC"`` for this field.

        When case-insensitive, text and blob values are compared through
        LOWER(); other types are compared as stored.
        """
        order = "ASC" if self.ascending else "DESC"
        if self.case_insensitive:
            # SQL string literals take single quotes. Double quotes denote
            # identifiers: SQLite treats "text"/"blob" as strings only as a
            # legacy fallback when no column has that name, and builds
            # with that quirk disabled reject them outright.
            field = ('(CASE '
                     "WHEN TYPEOF({0})='text' THEN LOWER({0}) "
                     "WHEN TYPEOF({0})='blob' THEN LOWER({0}) "
                     'ELSE {0} END)').format(self.field)
        else:
            field = self.field
        return "{0} {1}".format(field, order)
917
918
class SlowFieldSort(FieldSort):
    """Sorts by a field that is not a fixed field (a computed or flexible
    field), so the sort is always performed in Python.
    """

    def is_slow(self):
        # This kind of sort can never be expressed in SQL.
        return True
926
927
class NullSort(Sort):
    """A sort that does nothing: items are returned in their existing
    order. It is false in a boolean context and compares equal to None.
    """

    def sort(self, items):
        return items

    def __nonzero__(self):
        # Python 2 truth hook; defers to __bool__ so overrides apply.
        return self.__bool__()

    def __bool__(self):
        return False

    def __eq__(self, other):
        if type(self) == type(other):
            return True
        return other is None

    def __hash__(self):
        return 0
945