1import datetime
2import operator
3
4from flask import request
5from flask_peewee.forms import BaseModelConverter
6from flask_peewee._compat import reduce
7from peewee import *
8from wtforms import fields
9from wtforms import form
10from wtforms import validators
11from wtforms import widgets
12
13
class QueryFilter(object):
    """
    Base class for one filter operation applied to a single model field.

    A filter pairs a ``field`` with a human-readable ``name`` and, optionally,
    a list of ``options`` (choices) the filter value may be drawn from.
    Subclasses implement :meth:`query` to build the filter expression and
    :meth:`operation` to describe the comparison.
    """
    def __init__(self, field, name, options=None):
        self.field, self.name, self.options = field, name, options

    def query(self, value):
        """Return an expression filtering ``self.field`` against ``value``."""
        raise NotImplementedError

    def operation(self):
        """Return a short human-readable description of the comparison."""
        raise NotImplementedError

    def get_options(self):
        """Return the options given at construction (``None`` if omitted)."""
        return self.options
32
33
class EqualQueryFilter(QueryFilter):
    """Match rows whose field equals the given value."""
    def query(self, value):
        """Build the expression ``field == value``."""
        return operator.eq(self.field, value)

    def operation(self):
        """Describe this filter for display."""
        return 'equal to'
40
41
class NotEqualQueryFilter(QueryFilter):
    """Match rows whose field differs from the given value."""
    def query(self, value):
        """Build the expression ``field != value``."""
        return operator.ne(self.field, value)

    def operation(self):
        """Describe this filter for display."""
        return 'not equal to'
48
49
class LessThanQueryFilter(QueryFilter):
    """Match rows whose field is strictly below the given value."""
    def query(self, value):
        """Build the expression ``field < value``."""
        return operator.lt(self.field, value)

    def operation(self):
        """Describe this filter for display."""
        return 'less than'
56
57
class LessThanEqualToQueryFilter(QueryFilter):
    """Match rows whose field is at most the given value."""
    def query(self, value):
        """Build the expression ``field <= value``."""
        return operator.le(self.field, value)

    def operation(self):
        """Describe this filter for display."""
        return 'less than or equal to'
64
65
class GreaterThanQueryFilter(QueryFilter):
    """Match rows whose field is strictly above the given value."""
    def query(self, value):
        """Build the expression ``field > value``."""
        return operator.gt(self.field, value)

    def operation(self):
        """Describe this filter for display."""
        return 'greater than'
72
73
class GreaterThanEqualToQueryFilter(QueryFilter):
    """Match rows whose field is at least the given value."""
    def query(self, value):
        """Build the expression ``field >= value``."""
        return operator.ge(self.field, value)

    def operation(self):
        """Describe this filter for display."""
        return 'greater than or equal to'
80
81
class StartsWithQueryFilter(QueryFilter):
    """
    Prefix match that lower-cases both sides: the first ``len(value)``
    characters of the field (SQL ``SUBSTR`` from position 1, wrapped in
    ``LOWER``) are compared with ``value.lower()``.
    """
    def query(self, value):
        prefix_length = len(value)
        leading_part = fn.Substr(self.field, 1, prefix_length)
        return fn.Lower(leading_part) == value.lower()

    def operation(self):
        """Describe this filter for display."""
        return 'starts with'
88
89
class ContainsQueryFilter(QueryFilter):
    """
    Substring match: applies peewee's ``**`` operator to the field with the
    pattern ``%value%``.  Wildcard characters inside ``value`` are not
    escaped.
    """
    def query(self, value):
        pattern = '%%%s%%' % value
        return self.field ** pattern

    def operation(self):
        """Describe this filter for display."""
        return 'contains'
96
97
class YearFilter(QueryFilter):
    """Compare the field's ``year`` component with an integer year."""
    def query(self, value):
        """``value`` is coerced with ``int()``; a ValueError propagates."""
        year = int(value)
        return self.field.year == year

    def operation(self):
        """Describe this filter for display."""
        return 'year equals'
105
106
class MonthFilter(QueryFilter):
    """Compare the field's ``month`` component with an integer month."""
    def query(self, value):
        """``value`` is coerced with ``int()``; a ValueError propagates."""
        month = int(value)
        return self.field.month == month

    def operation(self):
        """Describe this filter for display."""
        return 'month equals'
114
115
class WithinDaysAgoFilter(QueryFilter):
    """Match rows whose field is on or after the date ``value`` days ago."""
    def query(self, value):
        """``value`` is a day count coerced with ``int()``."""
        days = int(value)
        cutoff = datetime.date.today() - datetime.timedelta(days=days)
        return self.field >= cutoff

    def operation(self):
        """Describe this filter for display."""
        return 'within X days ago'
124
125
class OlderThanDaysAgoFilter(QueryFilter):
    """Match rows whose field is strictly before the date ``value`` days ago."""
    def query(self, value):
        """``value`` is a day count coerced with ``int()``."""
        days = int(value)
        cutoff = datetime.date.today() - datetime.timedelta(days=days)
        return self.field < cutoff

    def operation(self):
        """Describe this filter for display."""
        return 'older than X days ago'
134
135
class FilterMapping(object):
    """
    Map a peewee field to a list of valid query filters for that field.

    Each tuple attribute (``string``, ``numeric``, ...) lists the QueryFilter
    classes offered for one category of field.  :meth:`get_field_types` maps
    peewee field classes to a category name, and :meth:`convert` instantiates
    that category's filters for a given field.
    """
    string = (
        EqualQueryFilter, NotEqualQueryFilter, StartsWithQueryFilter,
        ContainsQueryFilter)
    numeric = (
        EqualQueryFilter, NotEqualQueryFilter, LessThanQueryFilter,
        GreaterThanQueryFilter, LessThanEqualToQueryFilter,
        GreaterThanEqualToQueryFilter)
    datetime_date = (numeric + (
        WithinDaysAgoFilter, OlderThanDaysAgoFilter, YearFilter, MonthFilter))
    foreign_key = (EqualQueryFilter, NotEqualQueryFilter)
    boolean = (EqualQueryFilter, NotEqualQueryFilter)

    # Options offered for boolean fields, as (value, label) pairs -- the same
    # shape as the ``field.choices`` lists the other converters pass along.
    # The submitted value is '1' for True and '' for False.
    boolean_choices = (('1', 'True'), ('', 'False'))

    def get_field_types(self):
        """Return a dict mapping peewee field classes to category names."""
        return {
            CharField: 'string',
            TextField: 'string',
            DateTimeField: 'datetime_date',
            DateField: 'datetime_date',
            TimeField: 'numeric',
            IntegerField: 'numeric',
            BigIntegerField: 'numeric',
            FloatField: 'numeric',
            DoubleField: 'numeric',
            DecimalField: 'numeric',
            BooleanField: 'boolean',
            AutoField: 'numeric',
            ForeignKeyField: 'foreign_key',
        }

    def convert(self, field):
        """
        Return the list of QueryFilter instances available for ``field``.

        The field's class hierarchy is walked in MRO order, so a subclass of a
        mapped field type uses its nearest mapped ancestor's category; fields
        matching nothing fall back to the numeric filters.
        """
        mapping = self.get_field_types()

        for klass in type(field).__mro__:
            if klass in mapping:
                mapping_fn = getattr(self, 'convert_%s' % mapping[klass])
                return mapping_fn(field)

        # fall back to numeric
        return self.convert_numeric(field)

    def convert_string(self, field):
        return [f(field, field.verbose_name, field.choices) for f in self.string]

    def convert_numeric(self, field):
        return [f(field, field.verbose_name, field.choices) for f in self.numeric]

    def convert_datetime_date(self, field):
        return [f(field, field.verbose_name, field.choices) for f in self.datetime_date]

    def convert_boolean(self, field):
        # each filter gets its own list of (value, label) pairs
        return [
            f(field, field.verbose_name, list(self.boolean_choices))
            for f in self.boolean
        ]

    def convert_foreign_key(self, field):
        return [f(field, field.verbose_name, field.choices) for f in self.foreign_key]
195
196
class FieldTreeNode(object):
    """
    One node of the tree produced by ``make_field_tree``.

    ``model`` is the model for this node, ``fields`` the field objects
    selected on it, and ``children`` a dict of child nodes keyed by name.
    """
    def __init__(self, model, fields, children=None):
        self.model = model
        self.fields = fields
        # None or any falsy value (e.g. an empty dict) gets a fresh dict
        self.children = children if children else {}
202
203
def make_field_tree(model, fields, exclude, force_recursion=False, seen=None):
    """
    Build a FieldTreeNode describing which fields of ``model`` -- and, through
    its foreign keys, of related models -- are selected.

    :param model: model class to inspect.
    :param fields: field names to include, or ``None`` for every field.
        A name of the form ``"fk__name"`` selects ``name`` on the model
        related through the foreign key ``fk``.
    :param exclude: field names to skip, using the same ``"fk__name"`` syntax.
    :param force_recursion: when ``fields`` is explicit but names nothing on a
        related model, include all of that related model's fields.
    :param seen: set of foreign key fields already traversed.  It is shared
        across the recursion so each foreign key is followed at most once,
        which also stops cycles (e.g. a self-referencing foreign key).
    :returns: FieldTreeNode for ``model`` whose children are keyed by foreign
        key name.
    """
    no_explicit_fields = fields is None  # assume we want all of them
    if no_explicit_fields:
        fields = model._meta.sorted_field_names
    exclude = exclude or []
    if seen is None:
        seen = set()

    model_fields = []
    children = {}

    for field_obj in model._meta.sorted_fields:
        if field_obj.name in exclude or field_obj in seen:
            continue

        if field_obj.name in fields:
            model_fields.append(field_obj)

        if isinstance(field_obj, ForeignKeyField):
            seen.add(field_obj)
            # Strip only the leading "fk__" prefix; str.replace would also
            # remove later occurrences and mangle names like "a__ba__x".
            rel_prefix = '%s__' % field_obj.name
            if no_explicit_fields:
                rel_fields = None
            else:
                rel_fields = [
                    rf[len(rel_prefix):] for rf in fields if rf.startswith(rel_prefix)
                ]
                if not rel_fields and force_recursion:
                    rel_fields = None

            rel_exclude = [
                rx[len(rel_prefix):] for rx in exclude if rx.startswith(rel_prefix)
            ]
            children[field_obj.name] = make_field_tree(
                field_obj.rel_model, rel_fields, rel_exclude, force_recursion, seen)

    return FieldTreeNode(model, model_fields, children)
240
241
class SmallSelectWidget(widgets.Select):
    """Select widget that always renders with the CSS class ``span2``."""
    def __call__(self, field, **kwargs):
        render_kwargs = dict(kwargs)
        # replaces any 'class' the caller passed in
        render_kwargs['class'] = 'span2'
        return super(SmallSelectWidget, self).__call__(field, **render_kwargs)
246
247
class FilterForm(object):
    """
    Build a WTForms filter form for ``model`` and apply submitted filters to a
    query.

    For every selected field two form fields are generated: an operation
    select named ``field_operation_prefix + name`` whose values are indexes
    into that field's list of query filters, and a value field named
    ``field_value_prefix + name``.  Fields of related models live in a
    sub-form named ``field_relation_prefix + <foreign key name>``, so their
    request arguments look like ``fr_user-fo_username``.
    """
    base_class = form.Form
    separator = '-'
    field_operation_prefix = 'fo_'
    field_value_prefix = 'fv_'
    field_relation_prefix = 'fr_'

    def __init__(self, model, model_converter, filter_mapping, fields=None, exclude=None):
        """
        :param model: model class being filtered.
        :param model_converter: object whose ``convert(model, field, None)``
            returns a ``(name, form_field)`` pair for a model field.
        :param filter_mapping: object whose ``convert(field)`` returns the list
            of QueryFilter instances available for a field.
        :param fields: field names to expose (``None`` for all); related
            fields use ``"fk__name"`` syntax.
        :param exclude: field names to hide, same syntax as ``fields``.
        """
        self.model = model
        self.model_converter = model_converter
        self.filter_mapping = filter_mapping

        # convert fields and exclude into a tree
        self._field_tree = make_field_tree(model, fields, exclude)

        self._query_filters = self.load_query_filters()

    def load_query_filters(self):
        """Map every field in the field tree to its list of QueryFilter objects."""
        query_filters = {}
        queue = [self._field_tree]

        # breadth-first walk over the tree of related models
        while queue:
            curr = queue.pop(0)
            for field in curr.fields:
                query_filters[field] = self.filter_mapping.convert(field)
            queue.extend(curr.children.values())

        return query_filters

    def get_operation_field(self, field):
        """Return a select whose choice values are indexes into the field's filters."""
        choices = []
        for i, query_filter in enumerate(self._query_filters[field]):
            choices.append((str(i), query_filter.operation()))

        return fields.SelectField(choices=choices, validators=[validators.Optional()], widget=SmallSelectWidget())

    def get_field_default(self, field):
        """Default for a value input: now/today/midnight for temporal fields, else the field's default."""
        if isinstance(field, DateTimeField):
            return datetime.datetime.now()
        elif isinstance(field, DateField):
            return datetime.date.today()
        elif isinstance(field, TimeField):
            return datetime.time(0, 0)
        return field.default

    def get_value_field(self, field):
        """Return the (optional) form field used to enter a filter value."""
        _, form_field = self.model_converter.convert(field.model, field, None)

        form_field.kwargs['default'] = self.get_field_default(field)
        form_field.kwargs['validators'] = [validators.Optional()]
        return form_field

    def get_field_dict(self, node=None, prefix=None):
        """
        Return the form-field dict for ``node`` (the root tree by default).

        Related models become nested ``FormField`` sub-forms.  ``prefix`` is
        accepted for backward compatibility and is not used.
        """
        field_dict = {}
        node = node or self._field_tree

        for field in node.fields:
            op_field = self.get_operation_field(field)
            val_field = self.get_value_field(field)
            field_dict['%s%s' % (self.field_operation_prefix, field.name)] = op_field
            field_dict['%s%s' % (self.field_value_prefix, field.name)] = val_field

        for child_name, child_node in node.children.items():
            child_fd = self.get_field_dict(child_node, child_name)
            field_dict['%s%s' % (self.field_relation_prefix, child_name)] = fields.FormField(
                self.get_form(child_fd),
                separator=self.separator,
            )

        return field_dict

    def get_form(self, field_dict):
        """Create a form class (subclass of ``base_class``) from ``field_dict``."""
        return type(
            self.model.__name__ + 'FilterForm',
            (self.base_class, ),
            field_dict,
        )

    def parse_query_filters(self):
        """
        Collect the filters present in ``request.args``.

        Returns a dict mapping each model field to a list of tuples
        ``(operation values, value values, models to join, join fields,
        operation arg name, value arg name)``.
        """
        # reconstruct the "select" and "value" fields we are searching for in the
        # arguments from the request by depth-first searching the field tree --
        # basically what we should have at the end is the field we're querying,
        # the type of query (QueryFilter), the value requested, and the path we
        # took to get there (joins)
        accum = {}

        def _dfs(node, prefix, models, join_columns):
            for field in node.fields:
                # e.g. prefix '' -> 'fo_name'; prefix 'fr_user-' -> 'fr_user-fo_name'
                qf_select = self.field_operation_prefix.join((prefix, field.name))
                qf_value = self.field_value_prefix.join((prefix, field.name))

                if qf_select in request.args and qf_value in request.args:
                    accum.setdefault(field, []).append((
                        request.args.getlist(qf_select),
                        request.args.getlist(qf_value),
                        models,
                        join_columns,
                        qf_select,
                        qf_value,
                    ))

            for child_prefix, child in node.children.items():
                new_prefix = prefix + self.field_relation_prefix + child_prefix + self.separator
                model_copy = list(models) + [child.model]
                join_copy = list(join_columns) + [node.model._meta.fields[child_prefix]]
                _dfs(child, new_prefix, model_copy, join_copy)

        _dfs(self._field_tree, '', [], [])

        return accum

    def process_request(self, query):
        """
        Apply the filters submitted in ``request.args`` to ``query``.

        Several filters on the same field/path are OR-ed together; different
        fields are combined with successive ``where`` calls.  Entries whose
        operation index is not a valid integer index into the field's filters,
        or whose value cannot be converted by the field or the filter, are
        skipped instead of raising, since they come from the request.

        :returns: ``(form, query, cleaned)`` where ``cleaned`` lists the
            accepted filters as ``(operation arg, index, value arg, raw value)``.
        """
        field_dict = self.get_field_dict()
        FormClass = self.get_form(field_dict)

        filter_form = FormClass(request.args)
        query_filters = self.parse_query_filters()
        cleaned = []

        for field, filters in query_filters.items():
            available = self._query_filters[field]
            for (filter_idx_list, filter_value_list, path, join_path, qf_s, qf_v) in filters:
                q_objects = []
                for filter_idx, filter_value in zip(filter_idx_list, filter_value_list):
                    try:
                        idx = int(filter_idx)
                    except ValueError:
                        continue
                    # reject out-of-range and negative indexes (the latter
                    # would otherwise silently select a different filter)
                    if not 0 <= idx < len(available):
                        continue
                    try:
                        expression = available[idx].query(field.db_value(filter_value))
                    except (ValueError, TypeError, ArithmeticError):
                        # unparsable value, e.g. "abc" for an integer or an
                        # overflowing day count
                        continue
                    cleaned.append((qf_s, idx, qf_v, filter_value))
                    q_objects.append(expression)

                if not q_objects:
                    # nothing valid to apply; don't add joins that would
                    # change the result set on their own
                    continue

                query = query.switch(self.model)
                for join, model in zip(join_path, path):
                    query = query.join(model, on=join)

                query = query.where(reduce(operator.or_, q_objects))

        return filter_form, query, cleaned
384
385
class FilterModelConverter(BaseModelConverter):
    """
    Model converter for filter forms: after the base initialisation it maps
    ``TextField`` to ``fields.TextField`` and ``DateTimeField`` to
    ``fields.DateTimeField``.
    """
    def __init__(self, *args, **kwargs):
        super(FilterModelConverter, self).__init__(*args, **kwargs)
        overrides = {
            TextField: fields.TextField,
            DateTimeField: fields.DateTimeField,
        }
        # work on a copy so a ``defaults`` mapping possibly shared with other
        # converters is left untouched
        merged = dict(self.defaults)
        merged.update(overrides)
        self.defaults = merged
392