1#  Licensed to Elasticsearch B.V. under one or more contributor
2#  license agreements. See the NOTICE file distributed with
3#  this work for additional information regarding copyright
4#  ownership. Elasticsearch B.V. licenses this file to you under
5#  the Apache License, Version 2.0 (the "License"); you may
6#  not use this file except in compliance with the License.
7#  You may obtain a copy of the License at
8#
9# 	http://www.apache.org/licenses/LICENSE-2.0
10#
11#  Unless required by applicable law or agreed to in writing,
12#  software distributed under the License is distributed on an
13#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14#  KIND, either express or implied.  See the License for the
15#  specific language governing permissions and limitations
16#  under the License.
17
18from datetime import timedelta, datetime
19from six import iteritems, itervalues
20
21from .search import Search
22from .aggs import A
23from .utils import AttrDict
24from .response import Response
25from .query import Terms, Nested, Range, MatchAll
26
# Names exported by ``from <module> import *``. Other module-level names
# (e.g. ``Facet``, ``FacetedResponse``) remain importable explicitly.
__all__ = [
    "FacetedSearch", "HistogramFacet", "TermsFacet",
    "DateHistogramFacet", "RangeFacet", "NestedFacet",
]
35
36
class Facet(object):
    """
    A facet on faceted search. Wraps an aggregation and provides functionality
    to create a filter for selected values and return a list of facet values
    from the result of the aggregation.

    Subclasses set ``agg_type`` and typically override ``get_value_filter``
    (or ``add_filter`` directly) to turn selected values into queries.
    """

    agg_type = None

    def __init__(self, metric=None, metric_sort="desc", **kwargs):
        """
        :arg metric: optional aggregation attached to every bucket under the
            name ``"metric"``; when set, its ``value`` is reported instead of
            ``doc_count``
        :arg metric_sort: when both this and ``metric`` are truthy, the
            aggregation is given ``order={"metric": metric_sort}``
        :arg kwargs: parameters forwarded to the aggregation
        """
        self.filter_values = ()
        self._metric = metric
        self._params = kwargs
        if metric and metric_sort:
            self._params["order"] = {"metric": metric_sort}

    def get_aggregation(self):
        """
        Build the aggregation object for this facet, including the optional
        ``"metric"`` sub-aggregation.
        """
        aggregation = A(self.agg_type, **self._params)
        if self._metric:
            aggregation.metric("metric", self._metric)
        return aggregation

    def add_filter(self, filter_values):
        """
        Combine the per-value filters of ``filter_values`` with ``|``.

        Returns ``None`` when ``filter_values`` is empty.
        """
        if not filter_values:
            return None

        combined = self.get_value_filter(filter_values[0])
        for extra_value in filter_values[1:]:
            combined |= self.get_value_filter(extra_value)
        return combined

    def get_value_filter(self, filter_value):
        """
        Construct a filter for a single selected value.

        The base implementation returns ``None``; subclasses override it.
        """
        return None

    def is_filtered(self, key, filter_values):
        """
        Tell whether ``key`` is among the currently selected values.
        """
        return key in filter_values

    def get_value(self, bucket):
        """
        Return the value representing ``bucket``; its ``"key"`` by default.
        """
        return bucket["key"]

    def get_metric(self, bucket):
        """
        Return the number reported for ``bucket``: the ``"metric"``
        sub-aggregation value when a metric was configured, otherwise the
        bucket's ``doc_count``.
        """
        if not self._metric:
            return bucket["doc_count"]
        return bucket["metric"]["value"]

    def get_values(self, data, filter_values):
        """
        Turn the raw bucket data into a list of ``(value, metric, selected)``
        tuples, one per bucket in ``data.buckets``, where ``selected`` tells
        whether the value is part of ``filter_values``.
        """
        rows = []
        for bucket in data.buckets:
            value = self.get_value(bucket)
            metric = self.get_metric(bucket)
            selected = self.is_filtered(value, filter_values)
            rows.append((value, metric, selected))
        return rows
113
114
class TermsFacet(Facet):
    """
    Facet over a ``terms`` aggregation.
    """

    agg_type = "terms"

    def add_filter(self, filter_values):
        """
        Build one ``terms`` query over all selected values rather than OR-ing
        one ``term`` query per value. Returns ``None`` when nothing is
        selected.
        """
        if not filter_values:
            return None
        field_name = self._params["field"]
        return Terms(_expand__to_dot=False, **{field_name: filter_values})
124
125
class RangeFacet(Facet):
    """
    Facet over a ``range`` aggregation.

    ``ranges`` is a sequence of ``(key, (from, to))`` pairs; either bound may
    be ``None`` to leave that side open. Selecting a key filters on
    ``from <= field < to``.
    """

    agg_type = "range"

    def _range_to_dict(self, named_range):
        """
        Convert one ``(key, (from, to))`` pair into the dict form expected by
        the range aggregation, omitting ``None`` bounds.
        """
        key, bounds = named_range
        out = {"key": key}
        if bounds[0] is not None:
            out["from"] = bounds[0]
        if bounds[1] is not None:
            out["to"] = bounds[1]
        return out

    def __init__(self, ranges, **kwargs):
        """
        :arg ranges: iterable of ``(key, (from, to))`` pairs
        :arg kwargs: forwarded to :class:`Facet` (e.g. ``field``)
        """
        super(RangeFacet, self).__init__(**kwargs)
        # ``ranges`` is read twice below (aggregation params and the key ->
        # bounds lookup); materialise it so a one-shot iterator such as a
        # generator is not exhausted by the first pass.
        ranges = list(ranges)
        self._params["ranges"] = [self._range_to_dict(r) for r in ranges]
        self._params["keyed"] = False
        self._ranges = dict(ranges)

    def get_value_filter(self, filter_value):
        """
        Range filter for the range named ``filter_value``.

        :raises KeyError: if ``filter_value`` is not one of the configured keys
        """
        low, high = self._ranges[filter_value]
        limits = {}
        if low is not None:
            limits["gte"] = low
        if high is not None:
            limits["lt"] = high

        return Range(_expand__to_dot=False, **{self._params["field"]: limits})
153
154
class HistogramFacet(Facet):
    """
    Facet over a numeric ``histogram`` aggregation.
    """

    agg_type = "histogram"

    def get_value_filter(self, filter_value):
        """
        Filter on the bucket starting at ``filter_value``: the half-open range
        ``[filter_value, filter_value + interval)``.
        """
        field_name = self._params["field"]
        upper_bound = filter_value + self._params["interval"]
        bounds = {"gte": filter_value, "lt": upper_bound}
        return Range(_expand__to_dot=False, **{field_name: bounds})
168
169
class DateHistogramFacet(Facet):
    """
    Facet over a ``date_histogram`` aggregation.

    Bucket keys are exposed as ``datetime`` objects. Selecting a bucket filters
    on the half-open range ``[bucket start, next bucket start)``.
    """

    agg_type = "date_histogram"

    # Interval -> function mapping the start ``d`` of a bucket to the start of
    # the following bucket (the exclusive upper bound of the range filter).
    DATE_INTERVALS = {
        "year": lambda d: d.replace(year=d.year + 1, month=1, day=1),
        # Computed arithmetically so it is right for any day of the month;
        # adding a fixed 32 days and snapping to day 1 skips a month for
        # late-month dates (e.g. Jan 31 -> Mar 1).
        "month": lambda d: d.replace(
            year=d.year + d.month // 12, month=d.month % 12 + 1, day=1
        ),
        "week": lambda d: d + timedelta(days=7),
        "day": lambda d: d + timedelta(days=1),
        "hour": lambda d: d + timedelta(hours=1),
        "minute": lambda d: d + timedelta(minutes=1),
    }
    # Single-unit spellings used by calendar_interval / fixed_interval.
    DATE_INTERVALS["1y"] = DATE_INTERVALS["year"]
    DATE_INTERVALS["1M"] = DATE_INTERVALS["month"]
    DATE_INTERVALS["1w"] = DATE_INTERVALS["week"]
    DATE_INTERVALS["1d"] = DATE_INTERVALS["day"]
    DATE_INTERVALS["1h"] = DATE_INTERVALS["hour"]
    DATE_INTERVALS["1m"] = DATE_INTERVALS["minute"]

    def __init__(self, **kwargs):
        """
        :arg kwargs: forwarded to :class:`Facet`; ``min_doc_count`` defaults
            to 0 unless given
        """
        kwargs.setdefault("min_doc_count", 0)
        super(DateHistogramFacet, self).__init__(**kwargs)

    def get_value(self, bucket):
        """
        Return the bucket key as a ``datetime``.

        Non-datetime keys are treated as epoch milliseconds and converted with
        ``datetime.utcfromtimestamp``. The bucket itself is not modified.
        """
        key = bucket["key"]
        if isinstance(key, datetime):
            return key
        # Elasticsearch returns key=None instead of 0 for date 1970-01-01,
        # so treat it as 0 to avoid a TypeError in int().
        if key is None:
            key = 0
        # Preserve milliseconds in the datetime
        return datetime.utcfromtimestamp(int(key) / 1000.0)

    def _interval(self):
        """
        Return the configured interval, preferring ``calendar_interval`` and
        ``fixed_interval`` over the legacy ``interval`` parameter.

        :raises KeyError: if none of these parameters was given
        """
        for param in ("calendar_interval", "fixed_interval"):
            if param in self._params:
                return self._params[param]
        return self._params["interval"]

    def get_value_filter(self, filter_value):
        """
        Range filter selecting the bucket that starts at ``filter_value``.

        :raises KeyError: if the interval is not a key of ``DATE_INTERVALS``
        """
        return Range(
            _expand__to_dot=False,
            **{
                self._params["field"]: {
                    "gte": filter_value,
                    "lt": self.DATE_INTERVALS[self._interval()](filter_value),
                }
            }
        )
205
206
class NestedFacet(Facet):
    """
    Facet wrapping another facet inside a ``nested`` aggregation on ``path``.

    Values and filters are delegated to the wrapped facet; its filter is
    wrapped in a ``Nested`` query on the same path.
    """

    agg_type = "nested"

    def __init__(self, path, nested_facet):
        self._path = path
        self._inner = nested_facet
        inner_agg = nested_facet.get_aggregation()
        super(NestedFacet, self).__init__(path=path, aggs={"inner": inner_agg})

    def get_values(self, data, filter_values):
        # The wrapped facet's aggregation results sit under the "inner" sub-agg.
        inner_data = data.inner
        return self._inner.get_values(inner_data, filter_values)

    def add_filter(self, filter_values):
        inner_query = self._inner.add_filter(filter_values)
        return Nested(path=self._path, query=inner_query) if inner_query else None
224
225
class FacetedResponse(Response):
    """
    Response returned by :meth:`FacetedSearch.execute`; exposes the query text
    and the computed facet values.
    """

    @property
    def query_string(self):
        """The text the originating FacetedSearch was created with."""
        return self._faceted_search._query

    @property
    def facets(self):
        """
        Mapping of facet name to the list returned by that facet's
        ``get_values``, computed on first access and cached afterwards.
        """
        if hasattr(self, "_facets"):
            return self._facets

        computed = AttrDict({})
        # Uses the __setattr__ that follows AttrDict in the MRO, skipping
        # AttrDict's own override -- presumably so the cache is stored as a
        # plain attribute rather than as response data (confirm vs AttrDict).
        super(AttrDict, self).__setattr__("_facets", computed)
        for name, facet in iteritems(self._faceted_search.facets):
            # Each facet's aggregation is nested under a "_filter_<name>" agg.
            filter_bucket = getattr(self.aggregations, "_filter_" + name)
            computed[name] = facet.get_values(
                getattr(filter_bucket, name),
                self._faceted_search.filter_values.get(name, ()),
            )
        return self._facets
241
242
class FacetedSearch(object):
    """
    Abstraction for creating faceted navigation searches that takes care of
    composing the queries, aggregations and filters as needed as well as
    presenting the results in an easy-to-consume fashion::

        class BlogSearch(FacetedSearch):
            index = 'blogs'
            doc_types = [Blog, Post]
            fields = ['title^5', 'category', 'description', 'body']

            facets = {
                'type': TermsFacet(field='_type'),
                'category': TermsFacet(field='category'),
                'weekly_posts': DateHistogramFacet(field='published_from', interval='week')
            }

            def search(self):
                ' Override search to add your own filters '
                s = super(BlogSearch, self).search()
                return s.filter('term', published=True)

        # when using:
        blog_search = BlogSearch("web framework", filters={"category": "python"})

        # supports pagination
        blog_search[10:20]

        response = blog_search.execute()

        # easy access to aggregation results:
        for category, hit_count, is_selected in response.facets.category:
            print(
                "Category %s has %d hits%s." % (
                    category,
                    hit_count,
                    ' and is chosen' if is_selected else ''
                )
            )

    """

    index = None
    doc_types = None
    fields = None
    facets = {}
    using = "default"

    def __init__(self, query=None, filters=None, sort=()):
        """
        :arg query: the text to search for
        :arg filters: facet values to filter, as a mapping of facet name to a
            single value or a list/tuple of values (``None`` means no filters)
        :arg sort: sort information to be passed to :class:`~elasticsearch_dsl.Search`
        """
        self._query = query
        self._filters = {}
        self._sort = sort
        self.filter_values = {}
        # ``None`` default instead of a shared mutable ``{}``.
        for name, value in iteritems(filters or {}):
            self.add_filter(name, value)

        self._s = self.build_search()

    def count(self):
        """Return the count reported by the underlying search."""
        return self._s.count()

    def __getitem__(self, k):
        """
        Slice the underlying search (pagination). Modifies this instance in
        place and returns it.
        """
        self._s = self._s[k]
        return self

    def __iter__(self):
        """Iterate over the underlying search."""
        return iter(self._s)

    def add_filter(self, name, filter_values):
        """
        Add a filter for a facet.

        :arg name: name of the facet (a key of ``facets``)
        :arg filter_values: a single value or a list/tuple of values; ``None``
            is ignored
        :raises KeyError: if ``name`` is not a configured facet

        The search is built once in ``__init__``; filters added afterwards
        only take effect once ``build_search`` is run again.
        """
        # normalize the value into a list
        if not isinstance(filter_values, (tuple, list)):
            if filter_values is None:
                return
            filter_values = [filter_values]

        # Resolve the facet first so an unknown name raises KeyError without
        # leaving a half-recorded entry in ``filter_values``.
        facet = self.facets[name]

        # remember the filter values for use in FacetedResponse
        self.filter_values[name] = filter_values

        # get the filter from the facet
        f = facet.add_filter(filter_values)
        if f is None:
            # No effective filter (e.g. empty selection): drop any filter
            # previously registered for this facet so ``_filters`` stays
            # consistent with ``filter_values``.
            self._filters.pop(name, None)
            return

        self._filters[name] = f

    def search(self):
        """
        Returns the base Search object to which the facets are added.

        You can customize the query by overriding this method and returning a
        modified search object.
        """
        s = Search(doc_type=self.doc_types, index=self.index, using=self.using)
        return s.response_class(FacetedResponse)

    def query(self, search, query):
        """
        Add query part to ``search``: a ``multi_match`` over ``fields`` (or
        over the default fields when ``fields`` is not set). Returns ``search``
        unchanged when ``query`` is empty.

        Override this if you wish to customize the query used.
        """
        if query:
            if self.fields:
                return search.query("multi_match", fields=self.fields, query=query)
            else:
                return search.query("multi_match", query=query)
        return search

    def aggregate(self, search):
        """
        Add aggregations representing the facets selected, including potential
        filters.

        Each facet's aggregation is wrapped in a ``"_filter_<name>"`` filter
        aggregation combining the filters of all *other* facets, so a facet's
        own selection does not narrow its own counts. ``search.aggs`` is
        modified in place.
        """
        for f, facet in iteritems(self.facets):
            agg = facet.get_aggregation()
            agg_filter = MatchAll()
            for field, filter in iteritems(self._filters):
                if f == field:
                    continue
                agg_filter &= filter
            search.aggs.bucket("_filter_" + f, "filter", filter=agg_filter).bucket(
                f, agg
            )

    def filter(self, search):
        """
        Add a ``post_filter`` to the search request narrowing the results based
        on the facet filters (AND-ed together). Returns ``search`` unchanged
        when no filters are active.
        """
        if not self._filters:
            return search

        post_filter = MatchAll()
        for f in itervalues(self._filters):
            post_filter &= f
        return search.post_filter(post_filter)

    def highlight(self, search):
        """
        Add highlighting for all the fields in ``fields``, stripping any
        ``^boost`` suffix from the field names.
        """
        return search.highlight(
            *(f if "^" not in f else f.split("^", 1)[0] for f in self.fields)
        )

    def sort(self, search):
        """
        Add sorting information to the request.
        """
        if self._sort:
            search = search.sort(*self._sort)
        return search

    def build_search(self):
        """
        Construct the ``Search`` object.
        """
        s = self.search()
        s = self.query(s, self._query)
        s = self.filter(s)
        if self.fields:
            s = self.highlight(s)
        s = self.sort(s)
        self.aggregate(s)
        return s

    def execute(self):
        """
        Execute the search and return the response, with a back-reference to
        this instance attached as ``_faceted_search`` (read by
        ``FacetedResponse``).
        """
        r = self._s.execute()
        r._faceted_search = self
        return r
426