1from collections import OrderedDict
2from warnings import warn
3
4from django.core.exceptions import FieldDoesNotExist
5from django.db import models
6from django.db.models import Count
7from django.db.models.expressions import Value
8
9from wagtail.search.backends.base import (
10    BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults, FilterFieldError)
11from wagtail.search.query import And, Boost, MatchAll, Not, Or, Phrase, PlainText
12from wagtail.search.utils import AND, OR
13
14
# Sentinels that DatabaseSearchQueryCompiler.build_database_filter returns in
# place of a ``Q`` object when a (sub)query trivially matches every row
# (MATCH_ALL) or no row at all (MATCH_NONE).
MATCH_ALL, MATCH_NONE = '_ALL_', '_NONE_'
17
18
class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler):
    """Compile a wagtail search query tree into Django ``Q`` filters.

    This backend has no full-text index: every search term is matched with a
    case-insensitive ``icontains`` lookup against each searched model field.
    ``build_database_filter`` returns either a ``models.Q`` or one of the
    module-level ``MATCH_ALL`` / ``MATCH_NONE`` sentinels. The sentinels let
    the caller short-circuit queries that trivially match every row or no row.
    """

    DEFAULT_OPERATOR = 'and'
    # Combinators used to join the per-term filters of a PlainText query,
    # keyed by ``PlainText.operator``.
    OPERATORS = {
        'and': AND,
        'or': OR,
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Resolve the searched fields once; they are reused for every term.
        self.fields_names = list(self.get_fields_names())

    def get_fields_names(self):
        """Yield the names of the model fields to search.

        Starts from ``self.fields`` when set, otherwise from the model's
        ``get_searchable_search_fields()``. Names that ``model._meta.get_field``
        does not know (e.g. indexed callables) are skipped, because they
        cannot be used in a database lookup.
        """
        model = self.queryset.model
        fields_names = self.fields or [field.field_name for field in
                                       model.get_searchable_search_fields()]
        for field_name in fields_names:
            try:
                model._meta.get_field(field_name)
            except FieldDoesNotExist:
                continue
            else:
                yield field_name

    def _process_lookup(self, field, lookup, value):
        """Build the ``Q`` for a single ``<attname>__<lookup>=value`` filter."""
        return models.Q(**{field.get_attname(self.queryset.model) + '__' + lookup: value})

    def _connect_filters(self, filters, connector, negated):
        """Join ``filters`` with ``connector`` ('AND' or 'OR'), optionally negated.

        Returns ``None`` for any other connector.
        """
        if connector == 'AND':
            q = models.Q(*filters)
        elif connector == 'OR':
            q = OR([models.Q(fil) for fil in filters])
        else:
            return

        if negated:
            q = ~q

        return q

    def build_single_term_filter(self, term):
        """Return a ``Q`` matching rows where any searched field contains ``term``.

        The match is a case-insensitive substring match (``icontains``).
        """
        term_query = models.Q()
        for field_name in self.fields_names:
            term_query |= models.Q(**{field_name + '__icontains': term})
        return term_query

    def check_boost(self, query, boost=1.0):
        """Warn when ``query`` would be boosted, because boosting is unsupported."""
        if query.boost * boost != 1.0:
            warn('Database search backend does not support term boosting.')

    def build_database_filter(self, query, boost=1.0):
        """Convert a search query node into a database filter.

        :param query: a node from ``wagtail.search.query`` (PlainText, Phrase,
            Boost, MatchAll, Not, And or Or).
        :param boost: boost accumulated from enclosing ``Boost`` nodes. It is
            only used to warn, since this backend cannot apply boosts.
        :returns: a ``models.Q``, or the ``MATCH_ALL`` / ``MATCH_NONE`` sentinel.
        :raises NotImplementedError: for any other query node type.
        """
        if isinstance(query, PlainText):
            self.check_boost(query, boost=boost)

            terms = query.query_string.split()
            if not terms:
                # No terms to look for. The AND/OR combinators reduce over
                # their operands and raise TypeError on an empty list.
                return MATCH_NONE

            operator = self.OPERATORS[query.operator]
            return operator([self.build_single_term_filter(term) for term in terms])

        if isinstance(query, Phrase):
            # The whole phrase is matched as a single substring.
            return self.build_single_term_filter(query.query_string)

        if isinstance(query, Boost):
            boost *= query.boost
            return self.build_database_filter(query.subquery, boost=boost)

        if isinstance(query, MatchAll):
            return MATCH_ALL

        if isinstance(query, Not):
            q = self.build_database_filter(query.subquery, boost=boost)

            # Sentinels cannot be negated with ``~``; swap them instead.
            if q == MATCH_ALL:
                return MATCH_NONE

            elif q == MATCH_NONE:
                return MATCH_ALL

            else:
                return ~q

        if isinstance(query, And):
            subqueries = [
                self.build_database_filter(subquery, boost=boost)
                for subquery in query.subqueries
            ]

            # If there's a MATCH_NONE, return MATCH_NONE
            if MATCH_NONE in subqueries:
                return MATCH_NONE

            # Ignore MATCH_ALL
            subqueries = [q for q in subqueries if q != MATCH_ALL]

            # Every operand matched everything (or there were none): the
            # conjunction matches everything too. AND([]) would raise.
            if not subqueries:
                return MATCH_ALL

            return AND(subqueries)

        if isinstance(query, Or):
            subqueries = [
                self.build_database_filter(subquery, boost=boost)
                for subquery in query.subqueries
            ]

            # If there's a MATCH_ALL, return MATCH_ALL
            if MATCH_ALL in subqueries:
                return MATCH_ALL

            # Ignore MATCH_NONE
            subqueries = [q for q in subqueries if q != MATCH_NONE]

            # Every operand matched nothing (or there were none): the
            # disjunction matches nothing too. OR([]) would raise.
            if not subqueries:
                return MATCH_NONE

            return OR(subqueries)

        raise NotImplementedError(
            '`%s` is not supported by the database search backend.'
            % query.__class__.__name__)
138
139
class DatabaseSearchResults(BaseSearchResults):
    """Search results evaluated with ordinary Django ORM queries."""

    supports_facet = True

    def _get_unsliced_queryset(self):
        """Return the distinct, filtered queryset before ``start:stop`` is applied."""
        queryset = self.query_compiler.queryset

        # Run _get_filters_from_queryset to test that no fields that are not
        # a FilterField have been used in the query.
        self.query_compiler._get_filters_from_queryset()

        q = self.query_compiler.build_database_filter(self.query_compiler.query)

        if q == MATCH_ALL:
            pass
        elif q == MATCH_NONE:
            queryset = queryset.none()
        else:
            queryset = queryset.filter(q)

        # distinct() because joins through related-field lookups may
        # otherwise yield duplicate rows.
        return queryset.distinct()

    def get_queryset(self):
        """Return the filtered queryset, sliced to ``self.start:self.stop``."""
        return self._get_unsliced_queryset()[self.start:self.stop]

    def _do_search(self):
        """Return an iterator over the matching objects for the current slice.

        When a score field was requested, it is annotated with ``None``
        because this backend does not compute relevance scores.
        """
        queryset = self.get_queryset()

        if self._score_field:
            queryset = queryset.annotate(**{self._score_field: Value(None, output_field=models.FloatField())})

        return queryset.iterator()

    def _do_count(self):
        """Return the number of matching objects within the current slice."""
        return self.get_queryset().count()

    def facet(self, field_name):
        """Count the matching objects per distinct value of ``field_name``.

        :returns: an ``OrderedDict`` mapping each value to its count, in
            descending count order.
        :raises FilterFieldError: if ``field_name`` is not a filterable field.
        """
        # Get field
        field = self.query_compiler._get_filterable_field(field_name)
        if field is None:
            raise FilterFieldError(
                'Cannot facet search results with field "' + field_name + '". Please add index.FilterField(\''
                + field_name + '\') to ' + self.query_compiler.queryset.model.__name__ + '.search_fields.',
                field_name=field_name
            )

        # Facets describe the whole result set, not the current page. Django
        # also refuses order_by() on a sliced queryset, so the slice applied
        # by get_queryset() must not be used here.
        query = self._get_unsliced_queryset()
        results = query.values(field_name).annotate(count=Count('pk')).order_by('-count')

        return OrderedDict([
            (result[field_name], result['count'])
            for result in results
        ])
189
190
class DatabaseSearchBackend(BaseSearchBackend):
    """Search backend that answers queries directly from the database.

    Results come from filtering the model's own queryset, so there is no
    separate index to maintain. Every index-management hook below is
    deliberately a no-op that returns ``None``.
    """

    query_compiler_class = DatabaseSearchQueryCompiler
    results_class = DatabaseSearchResults

    def reset_index(self):
        """No-op: there is no search index to reset."""

    def add_type(self, model):
        """No-op: models need no registration with this backend."""

    def refresh_index(self):
        """No-op: there is no search index to refresh."""

    def add(self, obj):
        """No-op: objects are queried straight from the database."""

    def add_bulk(self, model, obj_list):
        """No-op: objects are queried straight from the database."""

    def delete(self, obj):
        """No-op: nothing is stored outside the database."""
212
213
# Module-level alias under the generic name ``SearchBackend``. Presumably this
# is the attribute wagtail looks up when this module is configured as a search
# backend; confirm against the backend loader before renaming.
SearchBackend = DatabaseSearchBackend
215