1from flask_admin.babel import lazy_gettext
2from flask_admin.model import filters
3from flask_admin.contrib.sqla import tools
4from sqlalchemy.sql import not_, or_
5import enum
6
7
class BaseSQLAFilter(filters.BaseFilter):
    """
        Base SQLAlchemy filter.

        Stores the model column the filter works on; concrete subclasses
        override ``apply`` to add their SQL condition to a query.
    """
    def __init__(self, column, name, options=None, data_type=None):
        """
            Constructor.

            :param column:
                Model field
            :param name:
                Display name
            :param options:
                Fixed set of options
            :param data_type:
                Client data type
        """
        super(BaseSQLAFilter, self).__init__(name, options, data_type)
        self.column = column

    def get_column(self, alias):
        """
            Return the column to filter on.

            Without an alias this is the stored column; with one, the
            attribute of the same key is looked up on ``alias``.
        """
        if alias is None:
            return self.column
        column_key = self.column.key
        return getattr(alias, column_key)

    def apply(self, query, value, alias=None):
        """
            Delegate to the base-class implementation (``alias`` is not
            forwarded). Subclasses replace this with a real condition.
        """
        return super(BaseSQLAFilter, self).apply(query, value)
34
35
36# Common filters
class FilterEqual(BaseSQLAFilter):
    """
        Match rows whose column equals the filter value.
    """
    def apply(self, query, value, alias=None):
        column = self.get_column(alias)
        condition = column == value
        return query.filter(condition)

    def operation(self):
        return lazy_gettext('equals')
43
44
class FilterNotEqual(BaseSQLAFilter):
    """
        Match rows whose column differs from the filter value.
    """
    def apply(self, query, value, alias=None):
        column = self.get_column(alias)
        condition = column != value
        return query.filter(condition)

    def operation(self):
        return lazy_gettext('not equal')
51
52
class FilterLike(BaseSQLAFilter):
    """
        Case-insensitive pattern match (``ILIKE``) of the column against
        the term produced by ``tools.parse_like_term``.
    """
    def apply(self, query, value, alias=None):
        pattern = tools.parse_like_term(value)
        column = self.get_column(alias)
        return query.filter(column.ilike(pattern))

    def operation(self):
        return lazy_gettext('contains')
60
61
class FilterNotLike(BaseSQLAFilter):
    """
        Negated case-insensitive pattern match (``NOT ILIKE``) using the
        term produced by ``tools.parse_like_term``.
    """
    def apply(self, query, value, alias=None):
        pattern = tools.parse_like_term(value)
        column = self.get_column(alias)
        matches = column.ilike(pattern)
        return query.filter(~matches)

    def operation(self):
        return lazy_gettext('not contains')
69
70
class FilterGreater(BaseSQLAFilter):
    """
        Match rows whose column is strictly greater than the filter value.
    """
    def apply(self, query, value, alias=None):
        column = self.get_column(alias)
        condition = column > value
        return query.filter(condition)

    def operation(self):
        return lazy_gettext('greater than')
77
78
class FilterSmaller(BaseSQLAFilter):
    """
        Match rows whose column is strictly smaller than the filter value.
    """
    def apply(self, query, value, alias=None):
        column = self.get_column(alias)
        condition = column < value
        return query.filter(condition)

    def operation(self):
        return lazy_gettext('smaller than')
85
86
class FilterEmpty(BaseSQLAFilter, filters.BaseBooleanFilter):
    """
        Filter on whether the column is NULL.

        A value of ``'1'`` keeps NULL rows; any other value keeps the
        non-NULL rows.
    """
    def apply(self, query, value, alias=None):
        column = self.get_column(alias)
        # ``== None`` / ``!= None`` are required so SQLAlchemy renders
        # IS NULL / IS NOT NULL; ``is None`` would not build an expression.
        if value == '1':
            return query.filter(column == None)  # noqa: E711
        return query.filter(column != None)  # noqa: E711

    def operation(self):
        return lazy_gettext('empty')
96
97
class FilterInList(BaseSQLAFilter):
    """
        Match rows whose column value is one of a comma-separated list.

        The client data type is always ``select2-tags``; the ``data_type``
        argument is ignored.
    """
    def __init__(self, column, name, options=None, data_type=None):
        super(FilterInList, self).__init__(column, name, options,
                                           data_type='select2-tags')

    def clean(self, value):
        # Split on commas, trim whitespace and drop blank entries.
        items = (item.strip() for item in value.split(','))
        return [item for item in items if item]

    def apply(self, query, value, alias=None):
        column = self.get_column(alias)
        return query.filter(column.in_(value))

    def operation(self):
        return lazy_gettext('in list')
110
111
class FilterNotInList(FilterInList):
    """
        Match rows whose column value is not in the cleaned value list.

        Rows where the column is NULL are kept as well.
    """
    def apply(self, query, value, alias=None):
        column = self.get_column(alias)
        # In SQL, ``x NOT IN (...)`` is NULL rather than true when x is
        # NULL, so NULL rows have to be re-added explicitly.
        outside_list = ~column.in_(value)
        is_null = column == None  # noqa: E711
        return query.filter(or_(outside_list, is_null))

    def operation(self):
        return lazy_gettext('not in list')
120
121
122# Customized type filters
class BooleanEqualFilter(FilterEqual, filters.BaseBooleanFilter):
    """
        ``equals`` filter for boolean columns.
    """


class BooleanNotEqualFilter(FilterNotEqual, filters.BaseBooleanFilter):
    """
        ``not equal`` filter for boolean columns.
    """


class IntEqualFilter(FilterEqual, filters.BaseIntFilter):
    """
        ``equals`` filter for integer columns.
    """


class IntNotEqualFilter(FilterNotEqual, filters.BaseIntFilter):
    """
        ``not equal`` filter for integer columns.
    """


class IntGreaterFilter(FilterGreater, filters.BaseIntFilter):
    """
        ``greater than`` filter for integer columns.
    """


class IntSmallerFilter(FilterSmaller, filters.BaseIntFilter):
    """
        ``smaller than`` filter for integer columns.
    """


class IntInListFilter(filters.BaseIntListFilter, FilterInList):
    """
        ``in list`` filter for integer columns.
    """


class IntNotInListFilter(filters.BaseIntListFilter, FilterNotInList):
    """
        ``not in list`` filter for integer columns.
    """
153
154
class FloatEqualFilter(FilterEqual, filters.BaseFloatFilter):
    """
        ``equals`` filter for floating-point / numeric columns.
    """


class FloatNotEqualFilter(FilterNotEqual, filters.BaseFloatFilter):
    """
        ``not equal`` filter for floating-point / numeric columns.
    """


class FloatGreaterFilter(FilterGreater, filters.BaseFloatFilter):
    """
        ``greater than`` filter for floating-point / numeric columns.
    """


class FloatSmallerFilter(FilterSmaller, filters.BaseFloatFilter):
    """
        ``smaller than`` filter for floating-point / numeric columns.
    """


class FloatInListFilter(filters.BaseFloatListFilter, FilterInList):
    """
        ``in list`` filter for floating-point / numeric columns.
    """


class FloatNotInListFilter(filters.BaseFloatListFilter, FilterNotInList):
    """
        ``not in list`` filter for floating-point / numeric columns.
    """
177
178
class DateEqualFilter(FilterEqual, filters.BaseDateFilter):
    """
        ``equals`` filter for date columns.
    """


class DateNotEqualFilter(FilterNotEqual, filters.BaseDateFilter):
    """
        ``not equal`` filter for date columns.
    """


class DateGreaterFilter(FilterGreater, filters.BaseDateFilter):
    """
        ``greater than`` filter for date columns.
    """


class DateSmallerFilter(FilterSmaller, filters.BaseDateFilter):
    """
        ``smaller than`` filter for date columns.
    """
193
194
class DateBetweenFilter(BaseSQLAFilter, filters.BaseDateBetweenFilter):
    """
        Match rows whose date column lies within an inclusive
        ``(start, end)`` range (SQL ``BETWEEN``).
    """
    def __init__(self, column, name, options=None, data_type=None):
        # The client data type is always the date range picker.
        super(DateBetweenFilter, self).__init__(
            column, name, options, data_type='daterangepicker')

    def apply(self, query, value, alias=None):
        # ``value`` must unpack into exactly two bounds.
        lower, upper = value
        column = self.get_column(alias)
        return query.filter(column.between(lower, upper))
205
206
class DateNotBetweenFilter(DateBetweenFilter):
    """
        Match rows whose date column lies outside a ``(start, end)`` range.
    """
    def apply(self, query, value, alias=None):
        lower, upper = value
        in_range = self.get_column(alias).between(lower, upper)
        # ``not_()`` is used because ``~between()`` is unavailable before
        # SQLAlchemy 1.0.0.
        return query.filter(not_(in_range))

    def operation(self):
        return lazy_gettext('not between')
215
216
class DateTimeEqualFilter(FilterEqual, filters.BaseDateTimeFilter):
    """
        ``equals`` filter for datetime columns.
    """


class DateTimeNotEqualFilter(FilterNotEqual, filters.BaseDateTimeFilter):
    """
        ``not equal`` filter for datetime columns.
    """


class DateTimeGreaterFilter(FilterGreater, filters.BaseDateTimeFilter):
    """
        ``greater than`` filter for datetime columns.
    """


class DateTimeSmallerFilter(FilterSmaller, filters.BaseDateTimeFilter):
    """
        ``smaller than`` filter for datetime columns.
    """
231
232
class DateTimeBetweenFilter(BaseSQLAFilter, filters.BaseDateTimeBetweenFilter):
    """
        Match rows whose datetime column lies within an inclusive
        ``(start, end)`` range (SQL ``BETWEEN``).
    """
    def __init__(self, column, name, options=None, data_type=None):
        # The client data type is always the datetime range picker.
        super(DateTimeBetweenFilter, self).__init__(
            column, name, options, data_type='datetimerangepicker')

    def apply(self, query, value, alias=None):
        # ``value`` must unpack into exactly two bounds.
        lower, upper = value
        column = self.get_column(alias)
        return query.filter(column.between(lower, upper))
243
244
class DateTimeNotBetweenFilter(DateTimeBetweenFilter):
    """
        Match rows whose datetime column lies outside a ``(start, end)``
        range.
    """
    def apply(self, query, value, alias=None):
        lower, upper = value
        in_range = self.get_column(alias).between(lower, upper)
        return query.filter(not_(in_range))

    def operation(self):
        return lazy_gettext('not between')
252
253
class TimeEqualFilter(FilterEqual, filters.BaseTimeFilter):
    """
        ``equals`` filter for time columns.
    """


class TimeNotEqualFilter(FilterNotEqual, filters.BaseTimeFilter):
    """
        ``not equal`` filter for time columns.
    """


class TimeGreaterFilter(FilterGreater, filters.BaseTimeFilter):
    """
        ``greater than`` filter for time columns.
    """


class TimeSmallerFilter(FilterSmaller, filters.BaseTimeFilter):
    """
        ``smaller than`` filter for time columns.
    """
268
269
class TimeBetweenFilter(BaseSQLAFilter, filters.BaseTimeBetweenFilter):
    """
        Match rows whose time column lies within an inclusive
        ``(start, end)`` range (SQL ``BETWEEN``).
    """
    def __init__(self, column, name, options=None, data_type=None):
        # The client data type is always the time range picker.
        super(TimeBetweenFilter, self).__init__(
            column, name, options, data_type='timerangepicker')

    def apply(self, query, value, alias=None):
        # ``value`` must unpack into exactly two bounds.
        lower, upper = value
        column = self.get_column(alias)
        return query.filter(column.between(lower, upper))
280
281
class TimeNotBetweenFilter(TimeBetweenFilter):
    """
        Match rows whose time column lies outside a ``(start, end)`` range.
    """
    def apply(self, query, value, alias=None):
        lower, upper = value
        in_range = self.get_column(alias).between(lower, upper)
        return query.filter(not_(in_range))

    def operation(self):
        return lazy_gettext('not between')
289
290
class EnumEqualFilter(FilterEqual):
    """
        ``equals`` filter for enum columns.

        When ``enum_class`` is given, the raw filter value is converted
        to a member of that class; otherwise the inherited cleaning is used.
    """
    def __init__(self, column, name, options=None, enum_class=None, **kwargs):
        self.enum_class = enum_class
        super(EnumEqualFilter, self).__init__(column, name, options, **kwargs)

    def clean(self, value):
        enum_class = self.enum_class
        if enum_class is not None:
            return enum_class(value)
        return super(EnumEqualFilter, self).clean(value)
300
301
class EnumFilterNotEqual(FilterNotEqual):
    """
        ``not equal`` filter for enum columns.

        When ``enum_class`` is given, the raw filter value is converted
        to a member of that class; otherwise the inherited cleaning is used.
    """
    def __init__(self, column, name, options=None, enum_class=None, **kwargs):
        self.enum_class = enum_class
        super(EnumFilterNotEqual, self).__init__(column, name, options, **kwargs)

    def clean(self, value):
        enum_class = self.enum_class
        if enum_class is not None:
            return enum_class(value)
        return super(EnumFilterNotEqual, self).clean(value)
311
312
class EnumFilterEmpty(FilterEmpty):
    """
        ``empty`` (NULL / NOT NULL) filter for enum columns.

        ``enum_class`` is accepted and stored for signature compatibility
        with the other enum filters; this class does not use it itself.
    """
    def __init__(self, column, name, options=None, enum_class=None, **kwargs):
        self.enum_class = enum_class
        super(EnumFilterEmpty, self).__init__(
            column, name, options, **kwargs)
317
318
class EnumFilterInList(FilterInList):
    """
        ``in list`` filter for enum columns.

        Each cleaned list item is converted to an ``enum_class`` member
        when ``enum_class`` is given.
    """
    def __init__(self, column, name, options=None, enum_class=None, **kwargs):
        self.enum_class = enum_class
        super(EnumFilterInList, self).__init__(column, name, options, **kwargs)

    def clean(self, value):
        raw_items = super(EnumFilterInList, self).clean(value)
        if self.enum_class is None:
            return raw_items
        return [self.enum_class(item) for item in raw_items]
329
330
class EnumFilterNotInList(FilterNotInList):
    """
        ``not in list`` filter for enum columns.

        Each cleaned list item is converted to an ``enum_class`` member
        when ``enum_class`` is given.
    """
    def __init__(self, column, name, options=None, enum_class=None, **kwargs):
        self.enum_class = enum_class
        super(EnumFilterNotInList, self).__init__(column, name, options, **kwargs)

    def clean(self, value):
        raw_items = super(EnumFilterNotInList, self).clean(value)
        if self.enum_class is None:
            return raw_items
        return [self.enum_class(item) for item in raw_items]
341
342
class ChoiceTypeEqualFilter(FilterEqual):
    """
        ``equals`` filter for ``ChoiceType`` columns.

        The filter value is compared with the choice labels (or with the
        member names when the choices are an ``enum.Enum`` class). The
        stored key of the first exact match is used in the SQL condition;
        if nothing matches, no rows are returned.
    """
    def __init__(self, column, name, options=None, **kwargs):
        super(ChoiceTypeEqualFilter, self).__init__(column, name, options, **kwargs)

    def apply(self, query, user_query, alias=None):
        column = self.get_column(alias)
        choices = column.type.choices
        # Sentinel distinguishes "no match" from a matched key that is
        # falsy (e.g. an IntEnum value of 0 or an empty-string key).
        missing = object()
        # find the first choice whose label/name matches exactly
        if isinstance(choices, enum.EnumMeta):
            choice_key = next((choice.value for choice in choices
                               if choice.name == user_query), missing)
        else:
            choice_key = next((key for key, label in choices
                               if label == user_query), missing)
        if choice_key is not missing:
            return query.filter(column == choice_key)
        # Unknown label: match no rows.
        return query.filter(column.in_([]))
365
366
class ChoiceTypeNotEqualFilter(FilterNotEqual):
    """
        ``not equal`` filter for ``ChoiceType`` columns.

        The filter value is compared with the choice labels (or with the
        member names when the choices are an ``enum.Enum`` class). Rows
        whose stored key differs from the first exact match are kept,
        including NULL rows. If nothing matches, the query is returned
        unchanged.
    """
    def __init__(self, column, name, options=None, **kwargs):
        super(ChoiceTypeNotEqualFilter, self).__init__(column, name, options, **kwargs)

    def apply(self, query, user_query, alias=None):
        column = self.get_column(alias)
        choices = column.type.choices
        # Sentinel distinguishes "no match" from a matched key that is
        # falsy (e.g. an IntEnum value of 0 or an empty-string key).
        missing = object()
        # find the first choice whose label/name matches exactly
        if isinstance(choices, enum.EnumMeta):
            choice_key = next((choice.value for choice in choices
                               if choice.name == user_query), missing)
        else:
            choice_key = next((key for key, label in choices
                               if label == user_query), missing)
        if choice_key is missing:
            # Nothing to exclude.
            return query
        # != can exclude NULL values, so "or_ == None" needed to be added
        return query.filter(or_(column != choice_key, column == None))  # noqa: E711
390
391
class ChoiceTypeLikeFilter(FilterLike):
    """
        ``contains`` filter for ``ChoiceType`` columns.

        Keeps rows whose stored key belongs to a choice whose label (or
        enum member name) contains the search term, case-insensitively.
        An empty search term leaves the query unchanged. A non-empty term
        that matches no label yields no rows, as the plain ``contains``
        filter would.
    """
    def __init__(self, column, name, options=None, **kwargs):
        super(ChoiceTypeLikeFilter, self).__init__(column, name, options, **kwargs)

    def apply(self, query, user_query, alias=None):
        column = self.get_column(alias)
        if not user_query:
            return query
        needle = user_query.lower()
        choices = column.type.choices
        # collect the keys of every choice whose label/name contains the term
        if isinstance(choices, enum.EnumMeta):
            choice_keys = [choice.value for choice in choices
                           if needle in choice.name.lower()]
        else:
            choice_keys = [key for key, label in choices
                           if needle in label.lower()]
        if choice_keys:
            return query.filter(column.in_(choice_keys))
        # No label contains the term, so no row can match; returning the
        # unfiltered query here would wrongly show every row.
        return query.filter(column.in_([]))
413
414
class ChoiceTypeNotLikeFilter(FilterNotLike):
    """
        ``not contains`` filter for ``ChoiceType`` columns.

        Excludes rows whose stored key belongs to a choice whose label (or
        enum member name) contains the search term, case-insensitively;
        NULL rows are kept. If the term is empty or matches no label, the
        query is returned unchanged.
    """
    def __init__(self, column, name, options=None, **kwargs):
        super(ChoiceTypeNotLikeFilter, self).__init__(column, name, options, **kwargs)

    def apply(self, query, user_query, alias=None):
        column = self.get_column(alias)
        excluded_keys = []
        if user_query:
            needle = user_query.lower()
            choices = column.type.choices
            # collect the keys of every choice whose label/name contains the term
            if isinstance(choices, enum.EnumMeta):
                excluded_keys = [choice.value for choice in choices
                                 if needle in choice.name.lower()]
            else:
                excluded_keys = [key for key, label in choices
                                 if needle in label.lower()]
        if not excluded_keys:
            return query
        # NOT IN is NULL (not true) for NULL columns, so keep NULL rows explicitly.
        return query.filter(or_(column.notin_(excluded_keys), column == None))  # noqa: E711
437
438
class UuidFilterEqual(FilterEqual, filters.BaseUuidFilter):
    """
        ``equals`` filter for UUID columns.
    """


class UuidFilterNotEqual(FilterNotEqual, filters.BaseUuidFilter):
    """
        ``not equal`` filter for UUID columns.
    """


class UuidFilterInList(filters.BaseUuidListFilter, FilterInList):
    """
        ``in list`` filter for UUID columns.
    """


class UuidFilterNotInList(filters.BaseUuidListFilter, FilterNotInList):
    """
        ``not in list`` filter for UUID columns.
    """
453
454
455# Base SQLA filter field converter
class FilterConverter(filters.BaseFilterConverter):
    """
        Map SQLAlchemy column type names to the filters offered for them.

        Each ``conv_*`` method is registered through ``filters.convert``
        for one or more type names and returns one instance of every
        filter class in the matching tuple below.
    """
    # Filter classes offered per column type family.
    strings = (FilterLike, FilterNotLike, FilterEqual, FilterNotEqual,
               FilterEmpty, FilterInList, FilterNotInList)
    string_key_filters = (FilterEqual, FilterNotEqual, FilterEmpty, FilterInList, FilterNotInList)
    int_filters = (IntEqualFilter, IntNotEqualFilter, IntGreaterFilter,
                   IntSmallerFilter, FilterEmpty, IntInListFilter,
                   IntNotInListFilter)
    float_filters = (FloatEqualFilter, FloatNotEqualFilter, FloatGreaterFilter,
                     FloatSmallerFilter, FilterEmpty, FloatInListFilter,
                     FloatNotInListFilter)
    bool_filters = (BooleanEqualFilter, BooleanNotEqualFilter)
    enum = (EnumEqualFilter, EnumFilterNotEqual, EnumFilterEmpty, EnumFilterInList,
            EnumFilterNotInList)
    date_filters = (DateEqualFilter, DateNotEqualFilter, DateGreaterFilter,
                    DateSmallerFilter, DateBetweenFilter, DateNotBetweenFilter,
                    FilterEmpty)
    datetime_filters = (DateTimeEqualFilter, DateTimeNotEqualFilter,
                        DateTimeGreaterFilter, DateTimeSmallerFilter,
                        DateTimeBetweenFilter, DateTimeNotBetweenFilter,
                        FilterEmpty)
    time_filters = (TimeEqualFilter, TimeNotEqualFilter, TimeGreaterFilter,
                    TimeSmallerFilter, TimeBetweenFilter, TimeNotBetweenFilter,
                    FilterEmpty)
    choice_type_filters = (ChoiceTypeEqualFilter, ChoiceTypeNotEqualFilter,
                           ChoiceTypeLikeFilter, ChoiceTypeNotLikeFilter, FilterEmpty)
    uuid_filters = (UuidFilterEqual, UuidFilterNotEqual, FilterEmpty,
                    UuidFilterInList, UuidFilterNotInList)
    arrow_type_filters = (DateTimeGreaterFilter, DateTimeSmallerFilter, FilterEmpty)

    def _instantiate(self, filter_classes, column, name, *args, **kwargs):
        """
            Build one filter per class in ``filter_classes``, passing the
            same constructor arguments to each.
        """
        return [filter_class(column, name, *args, **kwargs)
                for filter_class in filter_classes]

    def convert(self, type_name, column, name, **kwargs):
        """
            Return the filters for ``type_name`` (looked up lower-cased),
            or ``None`` when no converter is registered for it.
        """
        converter = self.converters.get(type_name.lower())
        if converter is None:
            return None
        return converter(column, name, **kwargs)

    @filters.convert('string', 'char', 'unicode', 'varchar', 'tinytext',
                     'text', 'mediumtext', 'longtext', 'unicodetext',
                     'nchar', 'nvarchar', 'ntext', 'citext', 'emailtype',
                     'URLType', 'IPAddressType')
    def conv_string(self, column, name, **kwargs):
        return self._instantiate(self.strings, column, name, **kwargs)

    @filters.convert('UUIDType', 'ColorType', 'TimezoneType', 'CurrencyType')
    def conv_string_keys(self, column, name, **kwargs):
        return self._instantiate(self.string_key_filters, column, name, **kwargs)

    @filters.convert('boolean', 'tinyint')
    def conv_bool(self, column, name, **kwargs):
        return self._instantiate(self.bool_filters, column, name, **kwargs)

    @filters.convert('int', 'integer', 'smallinteger', 'smallint',
                     'biginteger', 'bigint', 'mediumint')
    def conv_int(self, column, name, **kwargs):
        return self._instantiate(self.int_filters, column, name, **kwargs)

    @filters.convert('float', 'real', 'decimal', 'numeric', 'double_precision', 'double')
    def conv_float(self, column, name, **kwargs):
        return self._instantiate(self.float_filters, column, name, **kwargs)

    @filters.convert('date')
    def conv_date(self, column, name, **kwargs):
        return self._instantiate(self.date_filters, column, name, **kwargs)

    @filters.convert('datetime', 'datetime2', 'timestamp', 'smalldatetime')
    def conv_datetime(self, column, name, **kwargs):
        return self._instantiate(self.datetime_filters, column, name, **kwargs)

    @filters.convert('time')
    def conv_time(self, column, name, **kwargs):
        return self._instantiate(self.time_filters, column, name, **kwargs)

    @filters.convert('ChoiceType')
    def conv_sqla_utils_choice(self, column, name, **kwargs):
        return self._instantiate(self.choice_type_filters, column, name, **kwargs)

    @filters.convert('ArrowType')
    def conv_sqla_utils_arrow(self, column, name, **kwargs):
        return self._instantiate(self.arrow_type_filters, column, name, **kwargs)

    @filters.convert('enum')
    def conv_enum(self, column, name, options=None, **kwargs):
        # Default the options to the column's enum values (value, label).
        if not options:
            options = [(item, item) for item in column.type.enums]

        # sqlalchemy_enum34 is optional; when present and the column uses
        # its EnumType, hand the Python enum class to the filters.
        try:
            from sqlalchemy_enum34 import EnumType
        except ImportError:
            EnumType = None
        if EnumType is not None and isinstance(column.type, EnumType):
            kwargs['enum_class'] = column.type._enum_class

        return self._instantiate(self.enum, column, name, options, **kwargs)

    @filters.convert('uuid')
    def conv_uuid(self, column, name, **kwargs):
        return self._instantiate(self.uuid_filters, column, name, **kwargs)
557