1import copy
2from collections import OrderedDict
3
4from django import forms
5from django.db import models
6from django.db.models.constants import LOOKUP_SEP
7from django.db.models.fields.related import (
8    ManyToManyRel,
9    ManyToOneRel,
10    OneToOneRel
11)
12
13from .conf import settings
14from .constants import ALL_FIELDS
15from .filters import (
16    BaseInFilter,
17    BaseRangeFilter,
18    BooleanFilter,
19    CharFilter,
20    ChoiceFilter,
21    DateFilter,
22    DateTimeFilter,
23    DurationFilter,
24    Filter,
25    ModelChoiceFilter,
26    ModelMultipleChoiceFilter,
27    NumberFilter,
28    TimeFilter,
29    UUIDFilter
30)
31from .utils import (
32    get_all_model_fields,
33    get_model_field,
34    resolve_field,
35    try_dbfield
36)
37
38
def remote_queryset(field):
    """
    Return the default-manager queryset for the model on the far side of
    `field`, which may be a `RelatedField` or a `ForeignObjectRel`.

    When `field` provides `get_limit_choices_to()`, its result is applied via
    `complex_filter`; otherwise every row of the related model is returned.
    """
    related = field.related_model

    if hasattr(field, 'get_limit_choices_to'):
        limits = field.get_limit_choices_to()
        return related._default_manager.complex_filter(limits)

    # Reverse relationships carry no choice limits.
    return related._default_manager.all()
52
53
class FilterSetOptions:
    """
    Normalised view of a filterset's inner `Meta` class.

    Each option is read from `options` (usually the `Meta` class, or `None`)
    and falls back to a default when absent:
    `model`, `fields`, `exclude` -> None; `filter_overrides` -> a new empty
    dict; `form` -> `django.forms.Form`.
    """

    def __init__(self, options=None):
        def read(attr, fallback=None):
            # `getattr` tolerates `options is None` as well as missing attrs.
            return getattr(options, attr, fallback)

        self.model = read('model')
        self.fields = read('fields')
        self.exclude = read('exclude')

        self.filter_overrides = read('filter_overrides', {})

        self.form = read('form', forms.Form)
63
64
class FilterSetMetaclass(type):
    """
    Metaclass for filtersets.

    At class creation it moves `Filter` attributes into `declared_filters`,
    attaches a `FilterSetOptions` built from the inner `Meta` as `_meta`, and
    computes `base_filters` via the class's `get_filters()`.
    """

    def __new__(cls, name, bases, attrs):
        # Must run before `type.__new__`: it pops the Filter instances out of
        # the namespace so they do not become plain class attributes.
        attrs['declared_filters'] = cls.get_declared_filters(bases, attrs)

        new_class = super().__new__(cls, name, bases, attrs)
        meta_options = getattr(new_class, 'Meta', None)
        new_class._meta = FilterSetOptions(meta_options)
        new_class.base_filters = new_class.get_filters()

        # TODO: remove assertion in 2.1
        assert not hasattr(new_class, 'filter_for_reverse_field'), (
            "`%(cls)s.filter_for_reverse_field` has been removed. "
            "`%(cls)s.filter_for_field` now generates filters for reverse fields. "
            "See: https://django-filter.readthedocs.io/en/main/guide/migration.html"
            % {'cls': new_class.__name__}
        )

        return new_class

    @classmethod
    def get_declared_filters(cls, bases, attrs):
        """
        Pop `Filter` instances out of the namespace `attrs` and merge them with
        the `declared_filters` of `bases`.

        Returns an OrderedDict holding, first, the inherited filters and then
        this class's own filters sorted by `creation_counter` (presumably
        their declaration order).

        Inheritance rules:
        - The first base (in `bases` order) that defines a name wins.
        - A name still present in `attrs` after the pop (i.e. a non-filter
          attribute) blocks inheriting a base filter of that name.
        - A filter redeclared here keeps the inherited position but takes this
          class's value.
        """
        own = []
        for attr_name, value in list(attrs.items()):
            if isinstance(value, Filter):
                own.append((attr_name, attrs.pop(attr_name)))

        # Default the `filter.field_name` to the attribute name on the filterset
        for attr_name, filter_obj in own:
            if getattr(filter_obj, 'field_name', None) is None:
                filter_obj.field_name = attr_name

        own.sort(key=lambda pair: pair[1].creation_counter)

        # Names already claimed: remaining class attributes, then each
        # inherited filter as it is accepted, so e.g. for class C(A, B) a
        # 'field' from A shadows the one from B.
        blocked = set(attrs)
        inherited = []
        for base in bases:
            if not hasattr(base, 'declared_filters'):
                continue
            for inherited_name, inherited_filter in base.declared_filters.items():
                if inherited_name not in blocked:
                    blocked.add(inherited_name)
                    inherited.append((inherited_name, inherited_filter))

        return OrderedDict(inherited + own)
114
115
def _single_relation_extra(f):
    # Extra kwargs for a forward single-valued relation (ForeignKey /
    # OneToOneField): related queryset, the remote field used as the choice
    # value, and a null-choice label when the model field is nullable.
    return {
        'queryset': remote_queryset(f),
        'to_field_name': f.remote_field.field_name,
        'null_label': settings.NULL_CHOICE_LABEL if f.null else None,
    }


def _reverse_one_to_one_extra(f):
    # Extra kwargs for a reverse OneToOneRel: related queryset plus a
    # null-choice label when `f.null` is truthy.
    return {
        'queryset': remote_queryset(f),
        'null_label': settings.NULL_CHOICE_LABEL if f.null else None,
    }


def _multi_relation_extra(f):
    # Extra kwargs for multi-valued relations: only the related queryset.
    return {'queryset': remote_queryset(f)}


# Default mapping of model field class -> filter generation spec. Each spec has
# a 'filter_class' and optionally an 'extra' callable that receives the model
# field and returns additional filter kwargs (consumed by
# `BaseFilterSet.filter_for_lookup`).
FILTER_FOR_DBFIELD_DEFAULTS = {
    # Scalar fields
    models.AutoField: {'filter_class': NumberFilter},
    models.CharField: {'filter_class': CharFilter},
    models.TextField: {'filter_class': CharFilter},
    models.BooleanField: {'filter_class': BooleanFilter},
    models.DateField: {'filter_class': DateFilter},
    models.DateTimeField: {'filter_class': DateTimeFilter},
    models.TimeField: {'filter_class': TimeFilter},
    models.DurationField: {'filter_class': DurationFilter},
    models.DecimalField: {'filter_class': NumberFilter},
    models.SmallIntegerField: {'filter_class': NumberFilter},
    models.IntegerField: {'filter_class': NumberFilter},
    models.PositiveIntegerField: {'filter_class': NumberFilter},
    models.PositiveSmallIntegerField: {'filter_class': NumberFilter},
    models.FloatField: {'filter_class': NumberFilter},
    models.NullBooleanField: {'filter_class': BooleanFilter},
    models.SlugField: {'filter_class': CharFilter},
    models.EmailField: {'filter_class': CharFilter},
    models.FilePathField: {'filter_class': CharFilter},
    models.URLField: {'filter_class': CharFilter},
    models.GenericIPAddressField: {'filter_class': CharFilter},
    models.CommaSeparatedIntegerField: {'filter_class': CharFilter},
    models.UUIDField: {'filter_class': UUIDFilter},

    # Forward relationships
    models.OneToOneField: {
        'filter_class': ModelChoiceFilter,
        'extra': _single_relation_extra,
    },
    models.ForeignKey: {
        'filter_class': ModelChoiceFilter,
        'extra': _single_relation_extra,
    },
    models.ManyToManyField: {
        'filter_class': ModelMultipleChoiceFilter,
        'extra': _multi_relation_extra,
    },

    # Reverse relationships
    OneToOneRel: {
        'filter_class': ModelChoiceFilter,
        'extra': _reverse_one_to_one_extra,
    },
    ManyToOneRel: {
        'filter_class': ModelMultipleChoiceFilter,
        'extra': _multi_relation_extra,
    },
    ManyToManyRel: {
        'filter_class': ModelMultipleChoiceFilter,
        'extra': _multi_relation_extra,
    },
}
185
186
class BaseFilterSet:
    """
    Shared implementation for filtersets.

    Binds submitted `data` to a form generated from the filters, validates it,
    and applies each filter to a queryset. The class attributes `_meta`,
    `declared_filters` and `base_filters` are expected to be supplied by
    `FilterSetMetaclass` (as done for `FilterSet`).
    """

    # Model-field-class -> filter spec mapping; `Meta.filter_overrides`
    # entries take precedence (see `filter_for_lookup`).
    FILTER_DEFAULTS = FILTER_FOR_DBFIELD_DEFAULTS

    def __init__(self, data=None, queryset=None, *, request=None, prefix=None):
        """
        :param data: submitted values; `None` leaves the filterset unbound.
        :param queryset: base queryset; defaults to
            `Meta.model._default_manager.all()`.
        :param request: optional request, stored as `self.request`.
        :param prefix: optional form prefix passed to the generated form.
        """
        if queryset is None:
            queryset = self._meta.model._default_manager.all()
        model = queryset.model

        self.is_bound = data is not None
        self.data = data or {}
        self.queryset = queryset
        self.request = request
        self.form_prefix = prefix

        # Deep-copy so per-instance changes (model/parent below) never touch
        # the class-level `base_filters` shared by every instance.
        self.filters = copy.deepcopy(self.base_filters)

        # propagate the model and filterset to the filters
        for filter_ in self.filters.values():
            filter_.model = model
            filter_.parent = self

    def is_valid(self):
        """
        Return True if the underlying form has no errors, or False otherwise.

        An unbound filterset (created with `data=None`) is never valid.
        """
        return self.is_bound and self.form.is_valid()

    @property
    def errors(self):
        """
        Return an ErrorDict for the data provided for the underlying form.
        """
        return self.form.errors

    def filter_queryset(self, queryset):
        """
        Filter the queryset with the underlying form's `cleaned_data`. You must
        call `is_valid()` or `errors` before calling this method.

        This method should be overridden if additional filtering needs to be
        applied to the queryset before it is cached.

        Each filter receives the output of the previous one, so filters are
        applied cumulatively in `cleaned_data` order.
        """
        for name, value in self.form.cleaned_data.items():
            queryset = self.filters[name].filter(queryset, value)
            assert isinstance(queryset, models.QuerySet), (
                "Expected '%s.%s' to return a QuerySet, but got a %s instead."
                % (type(self).__name__, name, type(queryset).__name__)
            )
        return queryset

    @property
    def qs(self):
        """
        The resulting queryset, computed on first access and cached.

        Unbound filtersets return `self.queryset.all()` without filtering.
        """
        if not hasattr(self, '_qs'):
            qs = self.queryset.all()
            if self.is_bound:
                # ensure form validation before filtering (populates
                # `cleaned_data`, which `filter_queryset` reads)
                self.errors
                qs = self.filter_queryset(qs)
            self._qs = qs
        return self._qs

    def get_form_class(self):
        """
        Returns a django Form suitable of validating the filterset data.

        The class subclasses `Meta.form` and has one form field per filter,
        taken from each filter's `.field`, in filter order.

        This method should be overridden if the form class needs to be
        customized relative to the filterset instance.
        """
        fields = OrderedDict([
            (name, filter_.field)
            for name, filter_ in self.filters.items()])

        return type('%sForm' % self.__class__.__name__,
                    (self._meta.form,), fields)

    @property
    def form(self):
        """
        The form instance for this filterset, created lazily and cached.

        It is bound to `self.data` only when the filterset is bound.
        """
        if not hasattr(self, '_form'):
            Form = self.get_form_class()
            if self.is_bound:
                self._form = Form(self.data, prefix=self.form_prefix)
            else:
                self._form = Form(prefix=self.form_prefix)
        return self._form

    @classmethod
    def get_fields(cls):
        """
        Resolve the 'fields' argument that should be used for generating filters on the
        filterset. This is 'Meta.fields' sans the fields in 'Meta.exclude'.

        Returns an OrderedDict mapping field name -> list of lookup expressions.
        A list/tuple `Meta.fields` gets `[settings.DEFAULT_LOOKUP_EXPR]` for
        each name; a dict `Meta.fields` keeps its own lookup lists.

        Raises AssertionError when neither 'Meta.fields' nor 'Meta.exclude'
        is set.
        """
        model = cls._meta.model
        fields = cls._meta.fields
        exclude = cls._meta.exclude

        assert not (fields is None and exclude is None), \
            "Setting 'Meta.model' without either 'Meta.fields' or 'Meta.exclude' " \
            "has been deprecated since 0.15.0 and is now disallowed. Add an explicit " \
            "'Meta.fields' or 'Meta.exclude' to the %s class." % cls.__name__

        # Setting exclude with no fields implies all other fields.
        if exclude is not None and fields is None:
            fields = ALL_FIELDS

        # Resolve ALL_FIELDS into all fields for the filterset's model.
        if fields == ALL_FIELDS:
            fields = get_all_model_fields(model)

        # Remove excluded fields
        exclude = exclude or []
        if not isinstance(fields, dict):
            fields = [(f, [settings.DEFAULT_LOOKUP_EXPR]) for f in fields if f not in exclude]
        else:
            fields = [(f, lookups) for f, lookups in fields.items() if f not in exclude]

        return OrderedDict(fields)

    @classmethod
    def get_filter_name(cls, field_name, lookup_expr):
        """
        Combine a field name and lookup expression into a usable filter name.
        Exact lookups are the implicit default, so "exact" is stripped from the
        end of the filter name.
        """
        filter_name = LOOKUP_SEP.join([field_name, lookup_expr])

        # This also works with transformed exact lookups, such as 'date__exact'
        _default_expr = LOOKUP_SEP + settings.DEFAULT_LOOKUP_EXPR
        if filter_name.endswith(_default_expr):
            filter_name = filter_name[:-len(_default_expr)]

        return filter_name

    @classmethod
    def get_filters(cls):
        """
        Get all filters for the filterset. This is the combination of declared and
        generated filters.

        Raises TypeError when 'Meta.fields' names something that is not a
        model field (declared filter names are tolerated only when
        'Meta.fields' is a list/tuple).
        """

        # No model specified - skip filter generation
        if not cls._meta.model:
            return cls.declared_filters.copy()

        # Determine the filters that should be included on the filterset.
        filters = OrderedDict()
        fields = cls.get_fields()
        undefined = []

        for field_name, lookups in fields.items():
            field = get_model_field(cls._meta.model, field_name)

            # Record names that don't resolve to a model field; they are
            # reported as an error below unless they name a declared filter.
            if field is None:
                undefined.append(field_name)

            for lookup_expr in lookups:
                filter_name = cls.get_filter_name(field_name, lookup_expr)

                # If the filter is explicitly declared on the class, skip generation
                if filter_name in cls.declared_filters:
                    filters[filter_name] = cls.declared_filters[filter_name]
                    continue

                if field is not None:
                    filters[filter_name] = cls.filter_for_field(field, field_name, lookup_expr)

        # Allow Meta.fields to contain declared filters *only* when a list/tuple
        if isinstance(cls._meta.fields, (list, tuple)):
            undefined = [f for f in undefined if f not in cls.declared_filters]

        if undefined:
            raise TypeError(
                "'Meta.fields' must not contain non-model field names: %s"
                % ', '.join(undefined)
            )

        # Add in declared filters. This is necessary since we don't enforce adding
        # declared filters to the 'Meta.fields' option
        filters.update(cls.declared_filters)
        return filters

    @classmethod
    def filter_for_field(cls, field, field_name, lookup_expr=None):
        """
        Build a filter instance for the model `field`.

        :param field: the model field (or relation) to filter on.
        :param field_name: becomes the filter's `field_name`.
        :param lookup_expr: lookup to use; defaults to
            `settings.DEFAULT_LOOKUP_EXPR`.
        :raises AssertionError: if no filter class is registered for the
            field type that `lookup_expr` resolves to.
        """
        if lookup_expr is None:
            lookup_expr = settings.DEFAULT_LOOKUP_EXPR
        field, lookup_type = resolve_field(field, lookup_expr)

        default = {
            'field_name': field_name,
            'lookup_expr': lookup_expr,
        }

        filter_class, params = cls.filter_for_lookup(field, lookup_type)
        default.update(params)

        assert filter_class is not None, (
            "%s resolved field '%s' with '%s' lookup to an unrecognized field "
            "type %s. Try adding an override to 'Meta.filter_overrides'. See: "
            "https://django-filter.readthedocs.io/en/main/ref/filterset.html"
            "#customise-filter-generation-with-filter-overrides"
        ) % (cls.__name__, field_name, lookup_expr, field.__class__.__name__)

        return filter_class(**default)

    @classmethod
    def filter_for_lookup(cls, field, lookup_type):
        """
        Return a `(filter_class, params)` pair for `field` and `lookup_type`.

        The spec is looked up in `FILTER_DEFAULTS` updated with
        `Meta.filter_overrides`; `(None, {})` is returned when no filter class
        is registered. Special cases: an 'exact' lookup on a field with
        `choices` yields `ChoiceFilter`; 'isnull' uses the `BooleanField`
        spec; 'in' and 'range' combine the class with `BaseInFilter` /
        `BaseRangeFilter`.
        """
        DEFAULTS = dict(cls.FILTER_DEFAULTS)
        if hasattr(cls, '_meta'):
            DEFAULTS.update(cls._meta.filter_overrides)

        data = try_dbfield(DEFAULTS.get, field.__class__) or {}
        filter_class = data.get('filter_class')
        params = data.get('extra', lambda field: {})(field)

        # if there is no filter class, exit early
        if not filter_class:
            return None, {}

        # perform lookup specific checks
        if lookup_type == 'exact' and getattr(field, 'choices', None):
            return ChoiceFilter, {'choices': field.choices}

        if lookup_type == 'isnull':
            # `or {}` mirrors the lookup above: a FILTER_DEFAULTS without a
            # BooleanField entry yields (None, {}), letting filter_for_field
            # report the missing class instead of raising AttributeError here.
            data = try_dbfield(DEFAULTS.get, models.BooleanField) or {}

            filter_class = data.get('filter_class')
            params = data.get('extra', lambda field: {})(field)
            return filter_class, params

        if lookup_type == 'in':
            class ConcreteInFilter(BaseInFilter, filter_class):
                pass
            ConcreteInFilter.__name__ = cls._csv_filter_class_name(
                filter_class, lookup_type
            )

            return ConcreteInFilter, params

        if lookup_type == 'range':
            class ConcreteRangeFilter(BaseRangeFilter, filter_class):
                pass
            ConcreteRangeFilter.__name__ = cls._csv_filter_class_name(
                filter_class, lookup_type
            )

            return ConcreteRangeFilter, params

        return filter_class, params

    @classmethod
    def _csv_filter_class_name(cls, filter_class, lookup_type):
        """
        Generate a suitable class name for a concrete filter class. This is not
        completely reliable, as not all filter class names are of the format
        <Type>Filter.

        ex::

            FilterSet._csv_filter_class_name(DateTimeFilter, 'in')

            returns 'DateTimeInFilter'

        """
        # DateTimeFilter => DateTime
        type_name = filter_class.__name__
        if type_name.endswith('Filter'):
            type_name = type_name[:-len('Filter')]

        # in => In
        lookup_name = lookup_type.capitalize()

        # DateTimeInFilter
        return '%s%sFilter' % (type_name, lookup_name)
460
461
class FilterSet(BaseFilterSet, metaclass=FilterSetMetaclass):
    """
    Public base class for declaring filtersets.

    Subclasses declare `Filter` attributes and/or an inner `Meta` class
    (options read: `model`, `fields`, `exclude`, `filter_overrides`, `form`);
    `FilterSetMetaclass` turns these into the class's `base_filters`.
    """
464
465
def filterset_factory(model, fields=ALL_FIELDS, *, filterset=FilterSet):
    """
    Create a filterset class for `model`.

    :param model: the Django model class to filter.
    :param fields: value used as `Meta.fields`; the default `ALL_FIELDS`
        expands to every model field when filters are generated.
    :param filterset: base class for the generated filterset (default
        `FilterSet`). If it defines an inner `Meta`, the generated `Meta`
        subclasses it, so options such as `filter_overrides` or `form` are
        inherited while `model` and `fields` are overridden.
    :return: a new class named `<ModelObjectName>FilterSet`.
    """
    meta_attrs = {'model': model, 'fields': fields}
    meta_bases = (filterset.Meta,) if hasattr(filterset, 'Meta') else (object,)
    meta = type('Meta', meta_bases, meta_attrs)

    class_name = '%sFilterSet' % model._meta.object_name
    # Build through the base's metaclass so declared/generated filters are
    # collected exactly as for a hand-written subclass.
    return type(filterset)(class_name, (filterset,), {'Meta': meta})
471