1from __future__ import absolute_import
2from __future__ import unicode_literals
3
4from collections import OrderedDict
5import copy
6
7from django.db.models.constants import LOOKUP_SEP
8from django.utils import six
9
10from django_filters import filterset, rest_framework
11from django_filters.utils import get_model_field
12
13from . import filters
14from . import utils
15
16
class FilterSetMetaclass(filterset.FilterSetMetaclass):
    """
    Metaclass that expands declared ``AutoFilter``s into concrete filters.

    Once the parent metaclass has built the class, every declared
    ``filters.AutoFilter`` is expanded by temporarily pointing the class at a
    scratch ``_meta`` (whose ``fields`` lists only that filter's model field
    and lookups) and calling ``get_filters()``. The generated filters are
    merged into ``base_filters`` under the parameter name that was used on
    the filterset.
    """

    def __new__(cls, name, bases, attrs):
        new_class = super(FilterSetMetaclass, cls).__new__(cls, name, bases, attrs)

        # Remember the real objects so they can be put back when we are done.
        original_meta = new_class._meta
        original_declared = new_class.declared_filters

        # Scratch copies that may be mutated freely during generation.
        scratch_meta = copy.deepcopy(original_meta)
        scratch_declared = original_declared.copy()

        # Without a model there is nothing to generate auto filters from.
        if not scratch_meta.model:
            return new_class

        # Collect the auto filters. While doing so, drop them from the scratch
        # declared filters so that generation *does* overwrite them.
        # RelatedFilter is the exception and must *not* be overwritten, so it
        # stays declared.
        auto_filters = OrderedDict()
        for param, f in six.iteritems(original_declared):
            if not isinstance(f, filters.AutoFilter):
                continue

            auto_filters[param] = f
            if not isinstance(f, filters.RelatedFilter):
                del scratch_declared[param]

        for param, auto_f in six.iteritems(auto_filters):
            # Restrict generation to this filter's model field and lookups.
            scratch_meta.fields = {auto_f.name: auto_f.lookups or []}

            # Patch the class, then let get_filters() do the generation.
            new_class._meta = scratch_meta
            new_class.declared_filters = scratch_declared
            generated_filters = new_class.get_filters()

            # get_filters() builds param names from the model field name;
            # swap the first occurrence of the field name for the parameter
            # name used on the filterset.
            # NOTE(review): the replacement is applied to *every* name that
            # get_filters() returns. If that includes filters unrelated to
            # this field, names merely containing the field name would be
            # altered too - confirm against get_filters().
            renamed_filters = OrderedDict()
            for gen_param, gen_f in six.iteritems(generated_filters):
                renamed_filters[gen_param.replace(auto_f.name, param, 1)] = gen_f

            new_class.base_filters.update(renamed_filters)

        # Undo the patching performed inside the loop.
        new_class._meta = original_meta
        new_class.declared_filters = original_declared

        return new_class

    @property
    def related_filters(self):
        """
        ``OrderedDict`` of the ``RelatedFilter``s found in ``base_filters``.

        The result is cached on the class in ``_related_filters``.
        """
        # Look in the class' own __dict__ rather than using hasattr: a cache
        # that lives on a parent class must not be picked up. eg, the classes
        # returned by FilterSet.get_subset([...]) should not share the cache.
        if '_related_filters' in self.__dict__:
            return self._related_filters

        related = OrderedDict()
        for filter_name, f in six.iteritems(self.base_filters):
            if isinstance(f, filters.RelatedFilter):
                related[filter_name] = f

        self._related_filters = related
        return related
72
73
class FilterSet(six.with_metaclass(FilterSetMetaclass, rest_framework.FilterSet)):
    """
    FilterSet that understands ``RelatedFilter`` traversal and exclusion params.

    In addition to plain filter names, request data may contain:

    * ``<filter>!`` - the same filter with its ``exclude`` flag inverted.
    * ``<related>__<param>`` - ``<param>`` is handed to the filterset of the
      ``RelatedFilter`` named ``<related>``.
    """

    # Cache of generated subset classes. The dict is shared by every subclass
    # that does not define its own, so cache_get/cache_set namespace the
    # entries by class.
    _subset_cache = {}

    @classmethod
    def get_fields(cls):
        """
        Return the parent's fields mapping with ``filters.ALL_LOOKUPS``
        entries replaced by ``utils.lookups_for_field`` of the model field.
        """
        fields = super(FilterSet, cls).get_fields()

        # Only values of existing keys are reassigned, so the mapping keeps
        # its size while it is being iterated.
        for name, lookups in six.iteritems(fields):
            if lookups == filters.ALL_LOOKUPS:
                field = get_model_field(cls._meta.model, name)
                fields[name] = utils.lookups_for_field(field)

        return fields

    def expand_filters(self):
        """
        Build a set of filters based on the requested data. The resulting set
        will walk `RelatedFilter`s to recursively build the set of filters.

        Returns an ``OrderedDict`` mapping request parameter names to filter
        instances. Exclusion params (``name!``) map to a copy of the base
        filter with ``exclude`` inverted; filters coming from related
        filtersets have both their key and their ``name`` prefixed with the
        relationship.
        """
        # build param data for related filters: {rel: {param: value}}
        related_data = OrderedDict(
            [(name, OrderedDict()) for name in self.__class__.related_filters]
        )
        for param, value in six.iteritems(self.data):
            filter_name, related_param = self.get_related_filter_param(param)

            # skip non lookup/related keys
            if filter_name is None:
                continue

            if filter_name in related_data:
                related_data[filter_name][related_param] = value

        # build the compiled set of all filters
        requested_filters = OrderedDict()
        for filter_name, f in six.iteritems(self.filters):
            exclude_name = '%s!' % filter_name

            # Add plain lookup filters if match. ie, `username__icontains`
            if filter_name in self.data:
                requested_filters[filter_name] = f

            # include exclusion keys
            if exclude_name in self.data:
                # deepcopy the *base* filter to prevent copying of model & parent
                f_copy = copy.deepcopy(self.base_filters[filter_name])
                f_copy.parent = f.parent
                f_copy.model = f.model
                f_copy.exclude = not f.exclude

                requested_filters[exclude_name] = f_copy

            # include filters from related subsets
            if isinstance(f, filters.RelatedFilter) and filter_name in related_data:
                subset_data = related_data[filter_name]
                subset_class = f.filterset.get_subset(subset_data)
                filterset = subset_class(data=subset_data, request=self.request)

                # modify filter names to account for relationship
                for related_name, related_f in six.iteritems(filterset.expand_filters()):
                    related_name = LOOKUP_SEP.join([filter_name, related_name])
                    related_f.name = LOOKUP_SEP.join([f.name, related_f.name])
                    requested_filters[related_name] = related_f

        return requested_filters

    @classmethod
    def get_param_filter_name(cls, param):
        """
        Get the filter name for the request data parameter.

        Returns ``None`` when the parameter does not correspond to any filter.

        ex::

            # regular attribute filters
            name = FilterSet.get_param_filter_name('email')
            assert name == 'email'

            # exclusion filters
            name = FilterSet.get_param_filter_name('email!')
            assert name == 'email'

            # related filters
            name = FilterSet.get_param_filter_name('author__email')
            assert name == 'author'

        """
        # Attempt to match against filters with lookups first. (username__endswith)
        if param in cls.base_filters:
            return param

        # Attempt to match against exclusion filters. `endswith` is used
        # instead of indexing so that an empty param (eg, the remainder of
        # `author__`) does not raise an IndexError.
        if param.endswith('!') and param[:-1] in cls.base_filters:
            return param[:-1]

        # Fallback to matching against relationships. (author__username__endswith).
        # The related name is `None` when nothing matches.
        related_name, _ = cls.get_related_filter_param(param)
        return related_name

    @classmethod
    def get_related_filter_param(cls, param):
        """
        Get a tuple of (filter name, related param).

        Returns ``(None, None)`` when ``param`` does not start with the name
        of a related filter followed by ``LOOKUP_SEP``.

        ex::

            name, param = FilterSet.get_related_filter_param('author__email__foobar')
            assert name == 'author'
            assert param == 'email__foobar'

            name, param = FilterSet.get_related_filter_param('author')
            assert name is None
            assert param is None

        """
        # prefer more specific filters. eg, `note__author` over `note`.
        for name in sorted(cls.related_filters, reverse=True):
            # we need to match against '__' to prevent eager matching against
            # like names. eg, note vs note2.
            prefix = "%s%s" % (name, LOOKUP_SEP)
            if param.startswith(prefix):
                # strip name + LOOKUP_SEP from param
                return name, param[len(prefix):]

        # not a related param
        return None, None

    @classmethod
    def get_subset(cls, params):
        """
        Returns a FilterSubset class that contains the subset of filters
        specified in the requested `params`. This is useful for creating
        FilterSets that traverse relationships, as it helps to minimize
        the deepcopy overhead incurred when instantiating the FilterSet.

        Generated classes are cached, so repeated calls on the same class
        with an equivalent set of filter names return the same subset class.
        """
        # Determine names of filters from query params and remove empty values.
        # param names that traverse relations are translated to just the local
        # filter names. eg, `author__username` => `author`. Empty values are
        # removed, as they indicate an unknown field eg, author__foobar__isnull
        filter_names = [cls.get_param_filter_name(param) for param in params]
        filter_names = [f for f in filter_names if f is not None]

        # attempt to retrieve related filterset subset from the cache
        key = cls.cache_key(filter_names)
        subset_class = cls.cache_get(key)
        if subset_class is not None:
            return subset_class

        # no cached subset: derive base_filters and create a new subset class
        requested_names = frozenset(filter_names)

        class FilterSubsetMetaclass(type(cls)):
            def __new__(mcs, name, bases, attrs):
                new_class = super(FilterSubsetMetaclass, mcs).__new__(mcs, name, bases, attrs)
                # keep only the requested filters, in their original order
                new_class.base_filters = OrderedDict([
                    (param, f)
                    for param, f in six.iteritems(new_class.base_filters)
                    if param in requested_names
                ])
                return new_class

        class FilterSubset(six.with_metaclass(FilterSubsetMetaclass, cls)):
            pass

        FilterSubset.__name__ = str('%sSubset' % (cls.__name__, ))
        cls.cache_set(key, FilterSubset)
        return FilterSubset

    @classmethod
    def cache_key(cls, filter_names):
        """
        Build the cache key for a subset: the class name plus the sorted,
        ``-`` joined filter names.
        """
        return '%sSubset-%s' % (cls.__name__, '-'.join(sorted(filter_names)), )

    @classmethod
    def cache_get(cls, key):
        """Return the cached subset class for ``key``, or ``None``."""
        # `cache_key` only contains the class *name*. Entries are stored
        # under (cls, key) so that distinct classes sharing a name (and
        # sharing `_subset_cache`) never see each other's subsets.
        return cls._subset_cache.get((cls, key))

    @classmethod
    def cache_set(cls, key, value):
        """Store the subset class ``value`` for ``key``."""
        # namespaced by class, mirroring `cache_get`
        cls._subset_cache[(cls, key)] = value

    @property
    def qs(self):
        """
        The parent's ``qs``, evaluated with only the requested filters.

        ``self.filters`` is swapped for the result of ``expand_filters()``
        while the parent property runs and is always restored afterwards,
        even if the parent raises.
        """
        available_filters = self.filters
        requested_filters = self.expand_filters()

        self.filters = requested_filters
        try:
            return super(FilterSet, self).qs
        finally:
            self.filters = available_filters
270