from collections import OrderedDict
from warnings import warn

from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.models import Count
from django.db.models.expressions import Value

from wagtail.search.backends.base import (
    BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults, FilterFieldError)
from wagtail.search.query import And, Boost, MatchAll, Not, Or, Phrase, PlainText
from wagtail.search.utils import AND, OR


# Sentinels that DatabaseSearchQueryCompiler.build_database_filter may return
# instead of a Q object. DatabaseSearchResults.get_queryset leaves the
# queryset unfiltered for MATCH_ALL and replaces it with queryset.none() for
# MATCH_NONE.
MATCH_ALL = '_ALL_'
MATCH_NONE = '_NONE_'


class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler):
    """Translate a wagtail search query tree into Django ``Q`` filters.

    Text is matched with case-insensitive ``__icontains`` lookups against the
    model's searchable fields. Boosting is not supported; a warning is issued
    when a non-default boost is encountered on a PlainText query.
    """

    DEFAULT_OPERATOR = 'and'
    # Maps a PlainText query's ``operator`` to the helper that combines the
    # per-term filters.
    OPERATORS = {
        'and': AND,
        'or': OR,
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # get_fields_names() is a generator; keep the result as a list so it
        # can be iterated once per search term.
        self.fields_names = list(self.get_fields_names())

    def get_fields_names(self):
        """Yield the names of the model fields that text queries search.

        Uses ``self.fields`` when it is set, otherwise the ``field_name`` of
        each entry of ``model.get_searchable_search_fields()``. Names for which
        ``model._meta.get_field`` raises ``FieldDoesNotExist`` are skipped.
        """
        model = self.queryset.model
        fields_names = self.fields or [field.field_name for field in
                                       model.get_searchable_search_fields()]
        # Check if the field exists (this will filter out indexed callables)
        for field_name in fields_names:
            try:
                model._meta.get_field(field_name)
            except FieldDoesNotExist:
                continue
            else:
                yield field_name

    def _process_lookup(self, field, lookup, value):
        """Return ``Q(<field attname>__<lookup>=value)`` for this model."""
        return models.Q(**{field.get_attname(self.queryset.model) + '__' + lookup: value})

    def _connect_filters(self, filters, connector, negated):
        """Combine ``filters`` using ``connector`` and optionally negate the result.

        Args:
            filters: the filter objects to combine.
            connector: ``'AND'`` or ``'OR'``.
            negated: when true, the combined filter is inverted with ``~``.

        Returns:
            The combined ``Q`` object, or ``None`` for any other connector.
        """
        if connector == 'AND':
            q = models.Q(*filters)
        elif connector == 'OR':
            q = OR([models.Q(fil) for fil in filters])
        else:
            return

        if negated:
            q = ~q

        return q

    def build_single_term_filter(self, term):
        """Return a ``Q`` matching rows where any searched field contains ``term``.

        The match is case-insensitive (``__icontains``) and the per-field
        conditions are OR-ed together. If there are no field names to search,
        the result is an empty ``Q()``, which places no restriction on the
        queryset.
        """
        term_query = models.Q()
        for field_name in self.fields_names:
            term_query |= models.Q(**{field_name + '__icontains': term})
        return term_query

    def check_boost(self, query, boost=1.0):
        """Warn when the effective boost of ``query`` is not 1.0.

        The effective boost is ``query.boost`` multiplied by the boost
        accumulated from enclosing Boost nodes.
        """
        if query.boost * boost != 1.0:
            warn('Database search backend does not support term boosting.')

    def build_database_filter(self, query, boost=1.0):
        """Recursively convert a search query node into a database filter.

        Args:
            query: a PlainText, Phrase, Boost, MatchAll, Not, And or Or node.
            boost: boost accumulated from enclosing Boost nodes. It is only
                used to emit a warning, because boosting is unsupported.

        Returns:
            A ``Q`` object, or one of the ``MATCH_ALL`` / ``MATCH_NONE``
            sentinels.

        Raises:
            NotImplementedError: for any other query node type.
        """
        if isinstance(query, PlainText):
            self.check_boost(query, boost=boost)

            operator = self.OPERATORS[query.operator]

            terms = query.query_string.split()
            if not terms:
                # A blank search string has no terms to combine. Treat it as
                # matching nothing instead of handing an empty list to the
                # AND/OR helper.
                return MATCH_NONE

            return operator([
                self.build_single_term_filter(term)
                for term in terms
            ])

        if isinstance(query, Phrase):
            # The whole phrase is matched as one substring.
            return self.build_single_term_filter(query.query_string)

        if isinstance(query, Boost):
            boost *= query.boost
            return self.build_database_filter(query.subquery, boost=boost)

        if isinstance(query, MatchAll):
            return MATCH_ALL

        if isinstance(query, Not):
            q = self.build_database_filter(query.subquery, boost=boost)

            # The sentinels are strings, so they must be inverted explicitly
            # rather than with ``~``.
            if q == MATCH_ALL:
                return MATCH_NONE

            elif q == MATCH_NONE:
                return MATCH_ALL

            else:
                return ~q

        if isinstance(query, And):
            subqueries = [
                self.build_database_filter(subquery, boost=boost)
                for subquery in query.subqueries
            ]

            # If there's a MATCH_NONE, return MATCH_NONE
            if MATCH_NONE in subqueries:
                return MATCH_NONE

            # Ignore MATCH_ALL
            subqueries = [q for q in subqueries if q != MATCH_ALL]

            if not subqueries:
                # Every operand matched everything (or there were none), so
                # the conjunction matches everything.
                return MATCH_ALL

            return AND(subqueries)

        if isinstance(query, Or):
            subqueries = [
                self.build_database_filter(subquery, boost=boost)
                for subquery in query.subqueries
            ]

            # If there's a MATCH_ALL, return MATCH_ALL
            if MATCH_ALL in subqueries:
                return MATCH_ALL

            # Ignore MATCH_NONE
            subqueries = [q for q in subqueries if q != MATCH_NONE]

            if not subqueries:
                # Every operand matched nothing (or there were none), so the
                # disjunction matches nothing.
                return MATCH_NONE

            return OR(subqueries)

        raise NotImplementedError(
            '`%s` is not supported by the database search backend.'
            % query.__class__.__name__)


class DatabaseSearchResults(BaseSearchResults):
    """Search results produced by filtering the compiler's queryset in the database."""

    supports_facet = True

    def get_queryset(self):
        """Return the filtered, de-duplicated and sliced queryset for this search.

        The ``MATCH_ALL`` sentinel leaves the queryset unfiltered, and
        ``MATCH_NONE`` yields ``queryset.none()``. The result is sliced to
        ``self.start:self.stop``.
        """
        queryset = self.query_compiler.queryset

        # Run _get_filters_from_queryset to test that no fields that are not
        # a FilterField have been used in the query.
        self.query_compiler._get_filters_from_queryset()

        q = self.query_compiler.build_database_filter(self.query_compiler.query)

        if q == MATCH_ALL:
            pass
        elif q == MATCH_NONE:
            queryset = queryset.none()
        else:
            queryset = queryset.filter(q)

        return queryset.distinct()[self.start:self.stop]

    def _do_search(self):
        """Return an iterator over the matching objects.

        When a score field is requested, it is annotated with a NULL float,
        because this backend does not compute relevance scores.
        """
        queryset = self.get_queryset()

        if self._score_field:
            queryset = queryset.annotate(**{self._score_field: Value(None, output_field=models.FloatField())})

        return queryset.iterator()

    def _do_count(self):
        """Return the number of rows in ``get_queryset()``."""
        return self.get_queryset().count()

    def facet(self, field_name):
        """Count matching rows per distinct value of ``field_name``.

        Returns:
            An OrderedDict mapping each value to its count, ordered by count
            descending.

        Raises:
            FilterFieldError: if ``field_name`` is not a filterable field of
                the model.
        """
        # Get field
        field = self.query_compiler._get_filterable_field(field_name)
        if field is None:
            raise FilterFieldError(
                'Cannot facet search results with field "' + field_name + '". Please add index.FilterField(\''
                + field_name + '\') to ' + self.query_compiler.queryset.model.__name__ + '.search_fields.',
                field_name=field_name
            )

        # NOTE(review): get_queryset() applies the start/stop slice. Django
        # presumably refuses order_by() on a sliced queryset, so faceting a
        # sliced result set may fail. Confirm the intended semantics.
        query = self.get_queryset()
        results = query.values(field_name).annotate(count=Count('pk')).order_by('-count')

        return OrderedDict([
            (result[field_name], result['count'])
            for result in results
        ])


class DatabaseSearchBackend(BaseSearchBackend):
    """Search backend that queries the database directly.

    There is no separate index, so every indexing hook below is a no-op.
    """

    query_compiler_class = DatabaseSearchQueryCompiler
    results_class = DatabaseSearchResults

    def reset_index(self):
        pass  # Not needed

    def add_type(self, model):
        pass  # Not needed

    def refresh_index(self):
        pass  # Not needed

    def add(self, obj):
        pass  # Not needed

    def add_bulk(self, model, obj_list):
        pass  # Not needed

    def delete(self, obj):
        pass  # Not needed


# Name under which this module's backend class is exposed.
SearchBackend = DatabaseSearchBackend