import datetime
import operator

from flask import request
from flask_peewee.forms import BaseModelConverter
from flask_peewee._compat import reduce
from peewee import *
from wtforms import fields
from wtforms import form
from wtforms import validators
from wtforms import widgets


class QueryFilter(object):
    """
    Basic class representing a named field (with or without a list of options)
    and an operation against a given value
    """
    def __init__(self, field, name, options=None):
        self.field = field
        self.name = name
        self.options = options

    def query(self, value):
        """Return the expression comparing ``self.field`` with ``value``."""
        raise NotImplementedError

    def operation(self):
        """Return the human-readable label of this filter's operation."""
        raise NotImplementedError

    def get_options(self):
        """Return the options this filter was constructed with (may be None)."""
        return self.options


class EqualQueryFilter(QueryFilter):
    """Filter matching rows where the field equals the value."""

    def query(self, value):
        return self.field == value

    def operation(self):
        return 'equal to'


class NotEqualQueryFilter(QueryFilter):
    """Filter matching rows where the field differs from the value."""

    def query(self, value):
        return self.field != value

    def operation(self):
        return 'not equal to'


class LessThanQueryFilter(QueryFilter):
    """Filter matching rows where the field is strictly below the value."""

    def query(self, value):
        return self.field < value

    def operation(self):
        return 'less than'


class LessThanEqualToQueryFilter(QueryFilter):
    """Filter matching rows where the field is below or equal to the value."""

    def query(self, value):
        return self.field <= value

    def operation(self):
        return 'less than or equal to'


class GreaterThanQueryFilter(QueryFilter):
    """Filter matching rows where the field is strictly above the value."""

    def query(self, value):
        return self.field > value

    def operation(self):
        return 'greater than'


class GreaterThanEqualToQueryFilter(QueryFilter):
    """Filter matching rows where the field is above or equal to the value."""

    def query(self, value):
        return self.field >= value

    def operation(self):
        return 'greater than or equal to'


class StartsWithQueryFilter(QueryFilter):
    """Case-insensitive prefix match on the field."""

    def query(self, value):
        # Compare the lower-cased leading len(value) characters of the
        # column against the lower-cased value.
        return fn.Lower(fn.Substr(self.field, 1, len(value))) == value.lower()

    def operation(self):
        return 'starts with'


class ContainsQueryFilter(QueryFilter):
    """Substring match built from the field's ``**`` operator."""

    def query(self, value):
        # The value is wrapped in '%' wildcards on both sides.
        return self.field ** ('%%%s%%' % value)

    def operation(self):
        return 'contains'


class YearFilter(QueryFilter):
    """Filter matching rows whose field's year equals the (integer) value."""

    def query(self, value):
        value = int(value)
        return self.field.year == value

    def operation(self):
        return 'year equals'


class MonthFilter(QueryFilter):
    """Filter matching rows whose field's month equals the (integer) value."""

    def query(self, value):
        value = int(value)
        return self.field.month == value

    def operation(self):
        return 'month equals'


class WithinDaysAgoFilter(QueryFilter):
    """Filter matching rows dated on or after ``today - value`` days."""

    def query(self, value):
        value = int(value)
        return self.field >= (
            datetime.date.today() - datetime.timedelta(days=value))

    def operation(self):
        return 'within X days ago'


class OlderThanDaysAgoFilter(QueryFilter):
    """Filter matching rows dated strictly before ``today - value`` days."""

    def query(self, value):
        value = int(value)
        return self.field < (
            datetime.date.today() - datetime.timedelta(days=value))

    def operation(self):
        return 'older than X days ago'


class FilterMapping(object):
    """
    Map a peewee field to a list of valid query filters for that field
    """
    string = (
        EqualQueryFilter, NotEqualQueryFilter, StartsWithQueryFilter,
        ContainsQueryFilter)
    numeric = (
        EqualQueryFilter, NotEqualQueryFilter, LessThanQueryFilter,
        GreaterThanQueryFilter, LessThanEqualToQueryFilter,
        GreaterThanEqualToQueryFilter)
    datetime_date = (numeric + (
        WithinDaysAgoFilter, OlderThanDaysAgoFilter, YearFilter, MonthFilter))
    foreign_key = (EqualQueryFilter, NotEqualQueryFilter)
    boolean = (EqualQueryFilter, NotEqualQueryFilter)

    def get_field_types(self):
        """Return a dict mapping peewee field classes to a filter-group name.

        Each name corresponds to a ``convert_<name>`` method on this class.
        """
        return {
            CharField: 'string',
            TextField: 'string',
            DateTimeField: 'datetime_date',
            DateField: 'datetime_date',
            TimeField: 'numeric',
            IntegerField: 'numeric',
            BigIntegerField: 'numeric',
            FloatField: 'numeric',
            DoubleField: 'numeric',
            DecimalField: 'numeric',
            BooleanField: 'boolean',
            AutoField: 'numeric',
            ForeignKeyField: 'foreign_key',
        }

    def convert(self, field):
        """Return the list of QueryFilter instances applicable to ``field``.

        The field's class hierarchy is walked in MRO order so that subclasses
        of a mapped field type pick up their parent's filters.
        """
        mapping = self.get_field_types()

        for klass in type(field).__mro__:
            if klass in mapping:
                mapping_fn = getattr(self, 'convert_%s' % mapping[klass])
                return mapping_fn(field)

        # fall back to numeric
        return self.convert_numeric(field)

    def convert_string(self, field):
        return [f(field, field.verbose_name, field.choices) for f in self.string]

    def convert_numeric(self, field):
        return [f(field, field.verbose_name, field.choices) for f in self.numeric]

    def convert_datetime_date(self, field):
        return [f(field, field.verbose_name, field.choices) for f in self.datetime_date]

    def convert_boolean(self, field):
        # NOTE(review): this is a list holding a single 4-tuple rather than a
        # list of pairs; it is left unchanged because the consumer of
        # get_options() is not visible here -- confirm the intended shape.
        boolean_choices = [('True', '1', 'False', '')]
        return [f(field, field.verbose_name, boolean_choices) for f in self.boolean]

    def convert_foreign_key(self, field):
        return [f(field, field.verbose_name, field.choices) for f in self.foreign_key]


class FieldTreeNode(object):
    """A model, the fields selected on it, and child nodes keyed by FK name."""

    def __init__(self, model, fields, children=None):
        self.model = model
        self.fields = fields
        self.children = children or {}


def make_field_tree(model, fields, exclude, force_recursion=False, seen=None):
    """Build a FieldTreeNode tree for ``model``, following foreign keys.

    :param model: model class whose ``_meta.sorted_fields`` are inspected.
    :param fields: iterable of field names to include, where names for
        related models are written ``<fk_name>__<field>``; ``None`` selects
        every field on every model reached.
    :param exclude: iterable of field names to skip, using the same
        ``<fk_name>__<field>`` convention; a falsy value means no exclusions.
    :param force_recursion: when explicit ``fields`` name nothing on a
        related model, include all of that model's fields instead of none.
    :param seen: set of foreign key fields already traversed; they are
        skipped so each foreign key is followed at most once.
    :returns: the FieldTreeNode for ``model``.
    """
    no_explicit_fields = fields is None  # assume we want all of them
    if no_explicit_fields:
        fields = model._meta.sorted_field_names
    exclude = exclude or []
    seen = seen or set()

    model_fields = []
    children = {}

    for field_obj in model._meta.sorted_fields:
        if field_obj.name in exclude or field_obj in seen:
            continue

        if field_obj.name in fields:
            model_fields.append(field_obj)

        if isinstance(field_obj, ForeignKeyField):
            seen.add(field_obj)
            rel_prefix = '%s__' % field_obj.name
            strip = len(rel_prefix)

            # Only the leading "<fk>__" is removed. Using str.replace here
            # would also eat later occurrences, turning "user__user__name"
            # into "name" instead of "user__name".
            if no_explicit_fields:
                rel_fields = None
            else:
                rel_fields = [
                    rf[strip:] for rf in fields if rf.startswith(rel_prefix)
                ]
                if not rel_fields and force_recursion:
                    rel_fields = None

            rel_exclude = [
                rx[strip:] for rx in exclude if rx.startswith(rel_prefix)
            ]
            children[field_obj.name] = make_field_tree(
                field_obj.rel_model, rel_fields, rel_exclude,
                force_recursion, seen)

    return FieldTreeNode(model, model_fields, children)


class SmallSelectWidget(widgets.Select):
    """Select widget that always renders with the CSS class ``span2``."""

    def __call__(self, field, **kwargs):
        kwargs['class'] = 'span2'
        return super(SmallSelectWidget, self).__call__(field, **kwargs)


class FilterForm(object):
    """Build a filter form for a model and apply submitted filters to a query.

    Request arguments are named ``fo_<field>`` (index of the chosen
    operation) and ``fv_<field>`` (value to compare against). Fields on
    related models are prefixed with ``fr_<fk_name>`` plus ``separator`` for
    each foreign key followed.
    """
    base_class = form.Form
    separator = '-'
    field_operation_prefix = 'fo_'
    field_value_prefix = 'fv_'
    field_relation_prefix = 'fr_'

    def __init__(self, model, model_converter, filter_mapping, fields=None, exclude=None):
        self.model = model
        self.model_converter = model_converter
        self.filter_mapping = filter_mapping

        # convert fields and exclude into a tree
        self._field_tree = make_field_tree(model, fields, exclude)

        self._query_filters = self.load_query_filters()

    def load_query_filters(self):
        """Return a dict mapping every field in the tree to its QueryFilters.

        The tree is walked breadth-first starting at the root model.
        """
        query_filters = {}
        queue = [self._field_tree]

        while queue:
            curr = queue.pop(0)
            for field in curr.fields:
                query_filters[field] = self.filter_mapping.convert(field)
            queue.extend(curr.children.values())

        return query_filters

    def get_operation_field(self, field):
        """Return a SelectField listing the operations available for ``field``.

        Each choice value is the string index of the filter within
        ``self._query_filters[field]``.
        """
        choices = [
            (str(i), query_filter.operation())
            for i, query_filter in enumerate(self._query_filters[field])
        ]

        return fields.SelectField(
            choices=choices,
            validators=[validators.Optional()],
            widget=SmallSelectWidget())

    def get_field_default(self, field):
        """Return the default value shown in the value input for ``field``.

        Date/time fields default to the current moment, today, or midnight;
        any other field uses its own ``default``.
        """
        if isinstance(field, DateTimeField):
            return datetime.datetime.now()
        elif isinstance(field, DateField):
            return datetime.date.today()
        elif isinstance(field, TimeField):
            return datetime.time(0, 0)
        return field.default

    def get_value_field(self, field):
        """Return the form field used to enter the value for ``field``.

        The converted form field is made optional and given the default from
        ``get_field_default``.
        """
        field_name, form_field = self.model_converter.convert(field.model, field, None)

        form_field.kwargs['default'] = self.get_field_default(field)
        form_field.kwargs['validators'] = [validators.Optional()]
        return form_field

    def get_field_dict(self, node=None, prefix=None):
        """Return a dict of form fields for ``node`` and all of its children.

        :param node: FieldTreeNode to build from; defaults to the root tree.
        :param prefix: accepted for the recursive call but not read here.
        :returns: dict with an operation field and a value field per model
            field, plus one nested ``FormField`` per child relation.
        """
        field_dict = {}
        node = node or self._field_tree

        for field in node.fields:
            op_field = self.get_operation_field(field)
            val_field = self.get_value_field(field)
            field_dict['%s%s' % (self.field_operation_prefix, field.name)] = op_field
            field_dict['%s%s' % (self.field_value_prefix, field.name)] = val_field

        # Distinct loop names so the ``node``/``prefix`` parameters are not
        # rebound while their children are being iterated.
        for child_name, child_node in node.children.items():
            child_fd = self.get_field_dict(child_node, child_name)
            field_dict['%s%s' % (self.field_relation_prefix, child_name)] = fields.FormField(
                self.get_form(child_fd),
                separator=self.separator,
            )

        return field_dict

    def get_form(self, field_dict):
        """Create a form class named ``<Model>FilterForm`` from ``field_dict``."""
        return type(
            self.model.__name__ + 'FilterForm',
            (self.base_class, ),
            field_dict,
        )

    def parse_query_filters(self):
        """Collect the filters present in ``request.args``.

        :returns: dict mapping each filtered field to a list of tuples
            ``(operation_indexes, values, models, join_columns,
            operation_arg_name, value_arg_name)``.
        """
        # reconstruct the "select" and "value" fields we are searching for in the
        # arguments from the request by depth-first searching the field tree --
        # basically what we should have at the end is the field we're querying,
        # the type of query (QueryFilter), the value requested, and the path we
        # took to get there (joins)
        accum = {}

        def _dfs(node, prefix, models, join_columns):
            for field in node.fields:
                qf_select = prefix + self.field_operation_prefix + field.name
                qf_value = prefix + self.field_value_prefix + field.name

                if qf_select in request.args and qf_value in request.args:
                    accum.setdefault(field, []).append((
                        request.args.getlist(qf_select),
                        request.args.getlist(qf_value),
                        models,
                        join_columns,
                        qf_select,
                        qf_value,
                    ))

            for child_prefix, child in node.children.items():
                new_prefix = prefix + self.field_relation_prefix + child_prefix + self.separator
                model_copy = list(models) + [child.model]
                join_copy = list(join_columns) + [node.model._meta.fields[child_prefix]]
                _dfs(child, new_prefix, model_copy, join_copy)

        _dfs(self._field_tree, '', [], [])

        return accum

    def process_request(self, query):
        """Apply the filters found in the current request to ``query``.

        Values submitted for the same field are OR-ed together; each
        successive ``where`` call narrows the query further.

        Operation indexes come straight from the request, so any index that
        is not an integer within the range of the field's available filters
        is ignored instead of raising or silently selecting a filter by
        negative index.

        :param query: the select query to filter.
        :returns: tuple ``(form, query, cleaned)`` where ``cleaned`` is a
            list of ``(operation_arg_name, operation_index, value_arg_name,
            value)`` for every filter that was applied.
        """
        field_dict = self.get_field_dict()
        FormClass = self.get_form(field_dict)

        form = FormClass(request.args)
        query_filters = self.parse_query_filters()
        cleaned = []

        for field, filters in query_filters.items():
            available = self._query_filters[field]

            for (filter_idx_list, filter_value_list, path, join_path, qf_s, qf_v) in filters:
                q_objects = []
                for filter_idx, filter_value in zip(filter_idx_list, filter_value_list):
                    try:
                        idx = int(filter_idx)
                    except (TypeError, ValueError):
                        continue
                    if not 0 <= idx < len(available):
                        continue

                    cleaned.append((qf_s, idx, qf_v, filter_value))
                    query_filter = available[idx]
                    q_objects.append(query_filter.query(field.db_value(filter_value)))

                if not q_objects:
                    # Nothing valid was submitted for this field: leave the
                    # query alone (reduce() on an empty list would raise).
                    continue

                query = query.switch(self.model)
                for join, model in zip(join_path, path):
                    query = query.join(model, on=join)

                query = query.where(reduce(operator.or_, q_objects))

        return form, query, cleaned


class FilterModelConverter(BaseModelConverter):
    """Model converter overriding the form fields used for two field types."""

    def __init__(self, *args, **kwargs):
        super(FilterModelConverter, self).__init__(*args, **kwargs)
        # Copy before editing so the mapping inherited from the base
        # converter is not mutated.
        self.defaults = dict(self.defaults)
        # ``wtforms.fields.TextField`` is absent from newer WTForms releases;
        # use it when available and fall back to ``StringField`` otherwise.
        text_form_field = getattr(fields, 'TextField', None) or fields.StringField
        self.defaults[TextField] = text_form_field
        self.defaults[DateTimeField] = fields.DateTimeField