import enum

from flask_admin.babel import lazy_gettext
from flask_admin.model import filters
from flask_admin.contrib.sqla import tools
from sqlalchemy.sql import not_, or_


class BaseSQLAFilter(filters.BaseFilter):
    """
    Base SQLAlchemy filter.

    Stores the model column the filter operates on and resolves it against an
    optional alias when the filter is applied.
    """
    def __init__(self, column, name, options=None, data_type=None):
        """
        Constructor.

        :param column:
            Model field
        :param name:
            Display name
        :param options:
            Fixed set of options
        :param data_type:
            Client data type
        """
        super(BaseSQLAFilter, self).__init__(name, options, data_type)

        self.column = column

    def get_column(self, alias):
        """
        Return the column expression to filter on.

        :param alias:
            Optional aliased entity. When given, the attribute named
            ``self.column.key`` is looked up on it; otherwise ``self.column``
            is used as-is.
        """
        return self.column if alias is None else getattr(alias, self.column.key)

    def apply(self, query, value, alias=None):
        # Concrete filters override this; the generic base implementation
        # does not accept ``alias``, so it is not forwarded.
        return super(BaseSQLAFilter, self).apply(query, value)


# Common filters
class FilterEqual(BaseSQLAFilter):
    """Keep rows whose column equals the value."""
    def apply(self, query, value, alias=None):
        return query.filter(self.get_column(alias) == value)

    def operation(self):
        return lazy_gettext('equals')


class FilterNotEqual(BaseSQLAFilter):
    """Keep rows whose column differs from the value."""
    def apply(self, query, value, alias=None):
        return query.filter(self.get_column(alias) != value)

    def operation(self):
        return lazy_gettext('not equal')


class FilterLike(BaseSQLAFilter):
    """Case-insensitive "contains" filter (ILIKE on a term built by ``tools.parse_like_term``)."""
    def apply(self, query, value, alias=None):
        stmt = tools.parse_like_term(value)
        return query.filter(self.get_column(alias).ilike(stmt))

    def operation(self):
        return lazy_gettext('contains')


class FilterNotLike(BaseSQLAFilter):
    """Negated case-insensitive "contains" filter (NOT ILIKE)."""
    def apply(self, query, value, alias=None):
        stmt = tools.parse_like_term(value)
        return query.filter(~self.get_column(alias).ilike(stmt))

    def operation(self):
        return lazy_gettext('not contains')


class FilterGreater(BaseSQLAFilter):
    """Keep rows whose column is strictly greater than the value."""
    def apply(self, query, value, alias=None):
        return query.filter(self.get_column(alias) > value)

    def operation(self):
        return lazy_gettext('greater than')


class FilterSmaller(BaseSQLAFilter):
    """Keep rows whose column is strictly smaller than the value."""
    def apply(self, query, value, alias=None):
        return query.filter(self.get_column(alias) < value)

    def operation(self):
        return lazy_gettext('smaller than')


class FilterEmpty(BaseSQLAFilter, filters.BaseBooleanFilter):
    """
    NULL check: value ``'1'`` keeps rows where the column IS NULL, any other
    value keeps rows where it IS NOT NULL.
    """
    def apply(self, query, value, alias=None):
        # ``== None`` / ``!= None`` are required so SQLAlchemy emits IS [NOT] NULL.
        if value == '1':
            return query.filter(self.get_column(alias) == None)  # noqa: E711
        else:
            return query.filter(self.get_column(alias) != None)  # noqa: E711

    def operation(self):
        return lazy_gettext('empty')


class FilterInList(BaseSQLAFilter):
    """
    IN filter over a comma-separated list of values.

    The client data type is always ``'select2-tags'``; the ``data_type``
    argument is accepted for signature compatibility but ignored.
    """
    def __init__(self, column, name, options=None, data_type=None):
        super(FilterInList, self).__init__(column, name, options, data_type='select2-tags')

    def clean(self, value):
        # Split on commas, trim whitespace and drop empty items.
        return [v.strip() for v in value.split(',') if v.strip()]

    def apply(self, query, value, alias=None):
        return query.filter(self.get_column(alias).in_(value))

    def operation(self):
        return lazy_gettext('in list')


class FilterNotInList(FilterInList):
    """NOT IN filter over a comma-separated list; NULL rows are kept."""
    def apply(self, query, value, alias=None):
        # NOT IN can exclude NULL values, so "or_ == None" needed to be added
        column = self.get_column(alias)
        return query.filter(or_(~column.in_(value), column == None))  # noqa: E711

    def operation(self):
        return lazy_gettext('not in list')


# Customized type filters
# Each combines a comparison filter above with a flask_admin base filter that
# provides type-specific value cleaning/validation.
class BooleanEqualFilter(FilterEqual, filters.BaseBooleanFilter):
    pass


class BooleanNotEqualFilter(FilterNotEqual, filters.BaseBooleanFilter):
    pass


class IntEqualFilter(FilterEqual, filters.BaseIntFilter):
    pass


class IntNotEqualFilter(FilterNotEqual, filters.BaseIntFilter):
    pass


class IntGreaterFilter(FilterGreater, filters.BaseIntFilter):
    pass


class IntSmallerFilter(FilterSmaller, filters.BaseIntFilter):
    pass


class IntInListFilter(filters.BaseIntListFilter, FilterInList):
    pass


class IntNotInListFilter(filters.BaseIntListFilter, FilterNotInList):
    pass


class FloatEqualFilter(FilterEqual, filters.BaseFloatFilter):
    pass


class FloatNotEqualFilter(FilterNotEqual, filters.BaseFloatFilter):
    pass


class FloatGreaterFilter(FilterGreater, filters.BaseFloatFilter):
    pass


class FloatSmallerFilter(FilterSmaller, filters.BaseFloatFilter):
    pass


class FloatInListFilter(filters.BaseFloatListFilter, FilterInList):
    pass


class FloatNotInListFilter(filters.BaseFloatListFilter, FilterNotInList):
    pass


class DateEqualFilter(FilterEqual, filters.BaseDateFilter):
    pass


class DateNotEqualFilter(FilterNotEqual, filters.BaseDateFilter):
    pass


class DateGreaterFilter(FilterGreater, filters.BaseDateFilter):
    pass


class DateSmallerFilter(FilterSmaller, filters.BaseDateFilter):
    pass


class DateBetweenFilter(BaseSQLAFilter, filters.BaseDateBetweenFilter):
    """BETWEEN filter for dates; ``value`` is a ``(start, end)`` pair."""
    def __init__(self, column, name, options=None, data_type=None):
        super(DateBetweenFilter, self).__init__(column,
                                                name,
                                                options,
                                                data_type='daterangepicker')

    def apply(self, query, value, alias=None):
        start, end = value
        return query.filter(self.get_column(alias).between(start, end))


class DateNotBetweenFilter(DateBetweenFilter):
    """Negated BETWEEN filter for dates."""
    def apply(self, query, value, alias=None):
        start, end = value
        # ~between() isn't possible until sqlalchemy 1.0.0
        return query.filter(not_(self.get_column(alias).between(start, end)))

    def operation(self):
        return lazy_gettext('not between')


class DateTimeEqualFilter(FilterEqual, filters.BaseDateTimeFilter):
    pass


class DateTimeNotEqualFilter(FilterNotEqual, filters.BaseDateTimeFilter):
    pass


class DateTimeGreaterFilter(FilterGreater, filters.BaseDateTimeFilter):
    pass


class DateTimeSmallerFilter(FilterSmaller, filters.BaseDateTimeFilter):
    pass


class DateTimeBetweenFilter(BaseSQLAFilter, filters.BaseDateTimeBetweenFilter):
    """BETWEEN filter for datetimes; ``value`` is a ``(start, end)`` pair."""
    def __init__(self, column, name, options=None, data_type=None):
        super(DateTimeBetweenFilter, self).__init__(column,
                                                    name,
                                                    options,
                                                    data_type='datetimerangepicker')

    def apply(self, query, value, alias=None):
        start, end = value
        return query.filter(self.get_column(alias).between(start, end))


class DateTimeNotBetweenFilter(DateTimeBetweenFilter):
    """Negated BETWEEN filter for datetimes."""
    def apply(self, query, value, alias=None):
        start, end = value
        return query.filter(not_(self.get_column(alias).between(start, end)))

    def operation(self):
        return lazy_gettext('not between')


class TimeEqualFilter(FilterEqual, filters.BaseTimeFilter):
    pass


class TimeNotEqualFilter(FilterNotEqual, filters.BaseTimeFilter):
    pass


class TimeGreaterFilter(FilterGreater, filters.BaseTimeFilter):
    pass


class TimeSmallerFilter(FilterSmaller, filters.BaseTimeFilter):
    pass


class TimeBetweenFilter(BaseSQLAFilter, filters.BaseTimeBetweenFilter):
    """BETWEEN filter for times; ``value`` is a ``(start, end)`` pair."""
    def __init__(self, column, name, options=None, data_type=None):
        super(TimeBetweenFilter, self).__init__(column,
                                                name,
                                                options,
                                                data_type='timerangepicker')

    def apply(self, query, value, alias=None):
        start, end = value
        return query.filter(self.get_column(alias).between(start, end))


class TimeNotBetweenFilter(TimeBetweenFilter):
    """Negated BETWEEN filter for times."""
    def apply(self, query, value, alias=None):
        start, end = value
        return query.filter(not_(self.get_column(alias).between(start, end)))

    def operation(self):
        return lazy_gettext('not between')


# Enum filters
# When ``enum_class`` is given, cleaned values are converted to members of
# that enum class via ``enum_class(value)``.
class EnumEqualFilter(FilterEqual):
    def __init__(self, column, name, options=None, enum_class=None, **kwargs):
        self.enum_class = enum_class
        super(EnumEqualFilter, self).__init__(column, name, options, **kwargs)

    def clean(self, value):
        if self.enum_class is None:
            return super(EnumEqualFilter, self).clean(value)
        return self.enum_class(value)


class EnumFilterNotEqual(FilterNotEqual):
    def __init__(self, column, name, options=None, enum_class=None, **kwargs):
        self.enum_class = enum_class
        super(EnumFilterNotEqual, self).__init__(column, name, options, **kwargs)

    def clean(self, value):
        if self.enum_class is None:
            return super(EnumFilterNotEqual, self).clean(value)
        return self.enum_class(value)


class EnumFilterEmpty(FilterEmpty):
    def __init__(self, column, name, options=None, enum_class=None, **kwargs):
        # Stored for interface parity with the other enum filters; the NULL
        # check itself does not use it.
        self.enum_class = enum_class
        super(EnumFilterEmpty, self).__init__(column, name, options, **kwargs)


class EnumFilterInList(FilterInList):
    def __init__(self, column, name, options=None, enum_class=None, **kwargs):
        self.enum_class = enum_class
        super(EnumFilterInList, self).__init__(column, name, options, **kwargs)

    def clean(self, value):
        values = super(EnumFilterInList, self).clean(value)
        if self.enum_class is not None:
            values = [self.enum_class(val) for val in values]
        return values


class EnumFilterNotInList(FilterNotInList):
    def __init__(self, column, name, options=None, enum_class=None, **kwargs):
        self.enum_class = enum_class
        super(EnumFilterNotInList, self).__init__(column, name, options, **kwargs)

    def clean(self, value):
        values = super(EnumFilterNotInList, self).clean(value)
        if self.enum_class is not None:
            values = [self.enum_class(val) for val in values]
        return values


# ChoiceType filters
#
# ``column.type.choices`` is either an ``enum.Enum`` subclass (members are
# matched by ``name`` and stored by ``value``) or an iterable of
# ``(stored_value, label)`` pairs (matched by ``label``).

# Sentinel for "no choice label matched". A truthiness / ``None`` test cannot
# be used, because falsy stored values such as 0 or '' are legitimate matches.
_NO_CHOICE = object()


def _iter_choice_pairs(column):
    """Yield ``(stored_value, label)`` for each choice of a ChoiceType column."""
    choices = column.type.choices
    if isinstance(choices, enum.EnumMeta):
        for choice in choices:
            yield choice.value, choice.name
    else:
        for key, label in choices:
            yield key, label


def _exact_choice_key(column, user_query):
    """
    Return the stored value of the first choice whose label equals
    ``user_query``, or ``_NO_CHOICE`` if no label matches.
    """
    for key, label in _iter_choice_pairs(column):
        if label == user_query:
            return key
    return _NO_CHOICE


def _matching_choice_keys(column, user_query):
    """
    Return the stored values of all choices whose label contains
    ``user_query`` (case-insensitive).
    """
    needle = user_query.lower()
    return [key for key, label in _iter_choice_pairs(column)
            if needle in label.lower()]


class ChoiceTypeEqualFilter(FilterEqual):
    """Equality filter matching the user input exactly against choice labels."""
    def __init__(self, column, name, options=None, **kwargs):
        super(ChoiceTypeEqualFilter, self).__init__(column, name, options, **kwargs)

    def apply(self, query, user_query, alias=None):
        column = self.get_column(alias)
        choice_key = _exact_choice_key(column, user_query)
        if choice_key is _NO_CHOICE:
            # No label matches: yield an empty result rather than all rows.
            return query.filter(column.in_([]))
        return query.filter(column == choice_key)


class ChoiceTypeNotEqualFilter(FilterNotEqual):
    """Inequality filter matching the user input exactly against choice labels."""
    def __init__(self, column, name, options=None, **kwargs):
        super(ChoiceTypeNotEqualFilter, self).__init__(column, name, options, **kwargs)

    def apply(self, query, user_query, alias=None):
        column = self.get_column(alias)
        choice_key = _exact_choice_key(column, user_query)
        if choice_key is _NO_CHOICE:
            # Nothing to exclude.
            return query
        # != can exclude NULL values, so "or_ == None" needed to be added
        return query.filter(or_(column != choice_key, column == None))  # noqa: E711


class ChoiceTypeLikeFilter(FilterLike):
    """"Contains" filter matching the user input against choice labels."""
    def __init__(self, column, name, options=None, **kwargs):
        super(ChoiceTypeLikeFilter, self).__init__(column, name, options, **kwargs)

    def apply(self, query, user_query, alias=None):
        if not user_query:
            # An empty search term does not restrict the query.
            return query
        column = self.get_column(alias)
        choice_keys = _matching_choice_keys(column, user_query)
        if not choice_keys:
            # No label contains the term, so no row can match.
            return query.filter(column.in_([]))
        return query.filter(column.in_(choice_keys))


class ChoiceTypeNotLikeFilter(FilterNotLike):
    """"Does not contain" filter matching the user input against choice labels."""
    def __init__(self, column, name, options=None, **kwargs):
        super(ChoiceTypeNotLikeFilter, self).__init__(column, name, options, **kwargs)

    def apply(self, query, user_query, alias=None):
        if not user_query:
            return query
        column = self.get_column(alias)
        choice_keys = _matching_choice_keys(column, user_query)
        if not choice_keys:
            # Nothing to exclude.
            return query
        # NOT IN can exclude NULL values, so "or_ == None" needed to be added
        return query.filter(or_(~column.in_(choice_keys), column == None))  # noqa: E711


class UuidFilterEqual(FilterEqual, filters.BaseUuidFilter):
    pass


class UuidFilterNotEqual(FilterNotEqual, filters.BaseUuidFilter):
    pass


class UuidFilterInList(filters.BaseUuidListFilter, FilterInList):
    pass


class UuidFilterNotInList(filters.BaseUuidListFilter, FilterNotInList):
    pass


# Base SQLA filter field converter
class FilterConverter(filters.BaseFilterConverter):
    """
    Map SQLAlchemy column type names to lists of filter instances.

    Each ``conv_*`` method is registered for the type names passed to
    ``filters.convert``; ``convert`` looks the lower-cased type name up in
    ``self.converters``.
    """
    strings = (FilterLike, FilterNotLike, FilterEqual, FilterNotEqual,
               FilterEmpty, FilterInList, FilterNotInList)
    string_key_filters = (FilterEqual, FilterNotEqual, FilterEmpty, FilterInList, FilterNotInList)
    int_filters = (IntEqualFilter, IntNotEqualFilter, IntGreaterFilter,
                   IntSmallerFilter, FilterEmpty, IntInListFilter,
                   IntNotInListFilter)
    float_filters = (FloatEqualFilter, FloatNotEqualFilter, FloatGreaterFilter,
                     FloatSmallerFilter, FilterEmpty, FloatInListFilter,
                     FloatNotInListFilter)
    bool_filters = (BooleanEqualFilter, BooleanNotEqualFilter)
    # NOTE: this name shadows the ``enum`` module only inside the class body;
    # methods and module-level helpers still see the module.
    enum = (EnumEqualFilter, EnumFilterNotEqual, EnumFilterEmpty, EnumFilterInList,
            EnumFilterNotInList)
    date_filters = (DateEqualFilter, DateNotEqualFilter, DateGreaterFilter,
                    DateSmallerFilter, DateBetweenFilter, DateNotBetweenFilter,
                    FilterEmpty)
    datetime_filters = (DateTimeEqualFilter, DateTimeNotEqualFilter,
                        DateTimeGreaterFilter, DateTimeSmallerFilter,
                        DateTimeBetweenFilter, DateTimeNotBetweenFilter,
                        FilterEmpty)
    time_filters = (TimeEqualFilter, TimeNotEqualFilter, TimeGreaterFilter,
                    TimeSmallerFilter, TimeBetweenFilter, TimeNotBetweenFilter,
                    FilterEmpty)
    choice_type_filters = (ChoiceTypeEqualFilter, ChoiceTypeNotEqualFilter,
                           ChoiceTypeLikeFilter, ChoiceTypeNotLikeFilter, FilterEmpty)
    uuid_filters = (UuidFilterEqual, UuidFilterNotEqual, FilterEmpty,
                    UuidFilterInList, UuidFilterNotInList)
    arrow_type_filters = (DateTimeGreaterFilter, DateTimeSmallerFilter, FilterEmpty)

    def convert(self, type_name, column, name, **kwargs):
        """
        Return the filters for ``type_name``, or ``None`` if the type has no
        registered converter.
        """
        filter_name = type_name.lower()

        if filter_name in self.converters:
            return self.converters[filter_name](column, name, **kwargs)

        return None

    @filters.convert('string', 'char', 'unicode', 'varchar', 'tinytext',
                     'text', 'mediumtext', 'longtext', 'unicodetext',
                     'nchar', 'nvarchar', 'ntext', 'citext', 'emailtype',
                     'URLType', 'IPAddressType')
    def conv_string(self, column, name, **kwargs):
        return [f(column, name, **kwargs) for f in self.strings]

    @filters.convert('UUIDType', 'ColorType', 'TimezoneType', 'CurrencyType')
    def conv_string_keys(self, column, name, **kwargs):
        return [f(column, name, **kwargs) for f in self.string_key_filters]

    @filters.convert('boolean', 'tinyint')
    def conv_bool(self, column, name, **kwargs):
        return [f(column, name, **kwargs) for f in self.bool_filters]

    @filters.convert('int', 'integer', 'smallinteger', 'smallint',
                     'biginteger', 'bigint', 'mediumint')
    def conv_int(self, column, name, **kwargs):
        return [f(column, name, **kwargs) for f in self.int_filters]

    @filters.convert('float', 'real', 'decimal', 'numeric', 'double_precision', 'double')
    def conv_float(self, column, name, **kwargs):
        return [f(column, name, **kwargs) for f in self.float_filters]

    @filters.convert('date')
    def conv_date(self, column, name, **kwargs):
        return [f(column, name, **kwargs) for f in self.date_filters]

    @filters.convert('datetime', 'datetime2', 'timestamp', 'smalldatetime')
    def conv_datetime(self, column, name, **kwargs):
        return [f(column, name, **kwargs) for f in self.datetime_filters]

    @filters.convert('time')
    def conv_time(self, column, name, **kwargs):
        return [f(column, name, **kwargs) for f in self.time_filters]

    @filters.convert('ChoiceType')
    def conv_sqla_utils_choice(self, column, name, **kwargs):
        return [f(column, name, **kwargs) for f in self.choice_type_filters]

    @filters.convert('ArrowType')
    def conv_sqla_utils_arrow(self, column, name, **kwargs):
        return [f(column, name, **kwargs) for f in self.arrow_type_filters]

    @filters.convert('enum')
    def conv_enum(self, column, name, options=None, **kwargs):
        """
        Build enum filters. When no ``options`` are given, they default to
        ``(v, v)`` pairs from ``column.type.enums``. If ``sqlalchemy_enum34``
        is installed and the column uses its ``EnumType``, the enum class is
        passed to the filters so cleaned values become enum members.
        """
        if not options:
            options = [
                (v, v)
                for v in column.type.enums
            ]
        try:
            from sqlalchemy_enum34 import EnumType
        except ImportError:
            pass
        else:
            if isinstance(column.type, EnumType):
                kwargs['enum_class'] = column.type._enum_class

        return [f(column, name, options, **kwargs) for f in self.enum]

    @filters.convert('uuid')
    def conv_uuid(self, column, name, **kwargs):
        return [f(column, name, **kwargs) for f in self.uuid_filters]