1from itertools import zip_longest
2
3from django.apps import apps
4from django.db import connections
5
6from wagtail.search.index import Indexed, RelatedFields, SearchField
7
8
def get_postgresql_connections():
    """
    Return the configured database connections whose vendor is
    'postgresql', in the order ``connections.all()`` yields them.
    """
    postgresql_connections = []
    for conn in connections.all():
        if conn.vendor == 'postgresql':
            postgresql_connections.append(conn)
    return postgresql_connections
12
13
def get_descendant_models(model):
    """
    Return a set of every installed model that subclasses ``model``,
    always including ``model`` itself (even if it is not among the
    models returned by ``apps.get_models()``).
    """
    family = set()
    for candidate in apps.get_models():
        if issubclass(candidate, model):
            family.add(candidate)
    # Added explicitly so that ``model`` is present even when the app
    # registry does not list it.
    family.add(model)
    return family
22
23
def get_content_type_pk(model):
    """
    Return the primary key of the ContentType row for ``model``.
    """
    # Imported here rather than at module level because this file is
    # loaded before apps are ready.
    from django.contrib.contenttypes.models import ContentType
    content_type = ContentType.objects.get_for_model(model)
    return content_type.pk
28
29
def get_ancestors_content_types_pks(model):
    """
    Return the ContentType primary keys of the models in
    ``model._meta.get_parent_list()``; ``model`` itself is not included.
    """
    # Local import: this module is loaded before apps are ready.
    from django.contrib.contenttypes.models import ContentType
    parent_models = model._meta.get_parent_list()
    content_types_by_model = ContentType.objects.get_for_models(*parent_models)
    return [content_type.pk
            for content_type in content_types_by_model.values()]
38
39
def get_descendants_content_types_pks(model):
    """
    Return the ContentType primary keys of ``model`` and of every installed
    model that subclasses it (see get_descendant_models).
    """
    # Local import: this module is loaded before apps are ready.
    from django.contrib.contenttypes.models import ContentType
    family = get_descendant_models(model)
    content_types_by_model = ContentType.objects.get_for_models(*family)
    return [content_type.pk
            for content_type in content_types_by_model.values()]
48
49
def get_search_fields(search_fields):
    """
    Lazily flatten ``search_fields`` into its SearchField instances.

    RelatedFields entries are expanded recursively through their ``fields``
    attribute; entries of any other type are skipped.
    """
    for field in search_fields:
        if isinstance(field, SearchField):
            yield field
        elif isinstance(field, RelatedFields):
            yield from get_search_fields(field.fields)
57
58
# Weight labels, ordered from the tier for the highest boosts ('A') to the
# tier for the lowest boosts ('D'); determine_boosts_weights() pairs them
# with boosts sorted in descending order.
WEIGHTS = 'ABCD'
WEIGHTS_COUNT = len(WEIGHTS)
# These are filled when apps are ready (by set_weights(); the caller is not
# visible in this file — presumably an AppConfig.ready() hook, confirm).
# They are module-level lists mutated in place, so importers holding a
# reference see the populated contents.
# BOOSTS_WEIGHTS: (boost threshold, weight label) pairs, highest threshold
# first; read by get_weight().
BOOSTS_WEIGHTS = []
# WEIGHTS_VALUES: one normalised numeric value per label, in reverse label
# order (lowest tier first); rendered by get_sql_weights().
WEIGHTS_VALUES = []
64
65
def get_boosts():
    """
    Return the set of distinct non-None boosts declared by the search
    fields (including those nested in RelatedFields) of every installed
    model that subclasses Indexed.
    """
    candidate_boosts = (
        field.boost
        for indexed_model in apps.get_models()
        if issubclass(indexed_model, Indexed)
        for field in get_search_fields(indexed_model.get_search_fields())
    )
    return {boost for boost in candidate_boosts if boost is not None}
75
76
def determine_boosts_weights(boosts=()):
    """
    Map search-field boosts onto the weight labels in WEIGHTS.

    :param boosts: iterable of numeric boost values. When empty (the
        default), the boosts of every indexed model are collected with
        get_boosts().
    :returns: a list of ``(boost_threshold, weight_label)`` pairs, one per
        label in WEIGHTS, ordered from the highest threshold to the lowest.

    With at most WEIGHTS_COUNT boosts, each boost gets its own label and any
    remaining labels use ``min(lowest_boost, 0)`` as their threshold. With
    more boosts, the thresholds are spread evenly from the highest boost
    down to the lowest one. With no boosts at all, every label gets a
    threshold of 0.
    """
    if not boosts:
        boosts = get_boosts()
    boosts = sorted(boosts, reverse=True)
    if not boosts:
        # Nothing is boosted anywhere; indexing ``boosts[-1]`` below would
        # raise IndexError. Give every label a neutral threshold of 0.
        return [(0, weight) for weight in WEIGHTS]
    min_boost = boosts[-1]
    if len(boosts) <= WEIGHTS_COUNT:
        # Labels left over after the real boosts are padded with a threshold
        # no higher than 0.
        return list(zip_longest(boosts, WEIGHTS,
                                fillvalue=min(min_boost, 0)))
    max_boost = boosts[0]
    boost_step = (max_boost - min_boost) / (WEIGHTS_COUNT - 1)
    return [(max_boost - (i * boost_step), weight)
            for i, weight in enumerate(WEIGHTS)]
88
89
def set_weights():
    """
    Populate BOOSTS_WEIGHTS and WEIGHTS_VALUES from determine_boosts_weights().

    BOOSTS_WEIGHTS receives the ``(boost_threshold, weight_label)`` pairs.
    WEIGHTS_VALUES receives the thresholds shifted to be non-negative,
    divided by the largest shifted value, and listed in reverse label order
    (lowest tier first).

    Both lists are replaced in place, so references held by other modules
    stay valid and calling this more than once does not accumulate
    duplicate entries.
    """
    boosts_weights = determine_boosts_weights()
    # Slice assignment keeps the module-level list objects (other modules
    # may have imported them) and replaces rather than appends.
    BOOSTS_WEIGHTS[:] = boosts_weights
    weights = [boost for boost, _label in boosts_weights]
    min_weight = min(weights)
    if min_weight <= 0:
        if min_weight == 0:
            # Shift by a little extra so the lowest tier stays non-zero.
            min_weight = -0.1
        # NOTE(review): for a negative minimum the lowest tier is shifted to
        # exactly 0, unlike the == 0 case above — confirm this is intended.
        weights = [w - min_weight for w in weights]
    max_weight = max(weights)
    if max_weight == 0:
        # All thresholds were the same negative number, so the shift made
        # every weight 0. Treat the tiers as equal rather than dividing by
        # zero.
        values = [1.0] * len(weights)
    else:
        values = [w / max_weight for w in reversed(weights)]
    WEIGHTS_VALUES[:] = values
101
102
def get_weight(boost):
    """
    Return the weight label for a search-field boost.

    ``None`` maps to the lowest label in WEIGHTS. Otherwise the first entry
    of BOOSTS_WEIGHTS whose threshold is <= ``boost`` wins. A boost below
    every threshold gets the label of the last entry, or the lowest label in
    WEIGHTS when BOOSTS_WEIGHTS has not been populated yet.
    """
    if boost is None:
        return WEIGHTS[-1]
    for max_boost, weight in BOOSTS_WEIGHTS:
        if boost >= max_boost:
            return weight
    # Below every threshold. Don't rely on the leaked loop variable: it is
    # unbound (UnboundLocalError) when BOOSTS_WEIGHTS is empty, e.g. before
    # set_weights() has run.
    if BOOSTS_WEIGHTS:
        return BOOSTS_WEIGHTS[-1][1]
    return WEIGHTS[-1]
110
111
def get_sql_weights():
    """
    Render WEIGHTS_VALUES as a brace-delimited, comma-separated literal
    (PostgreSQL array syntax), e.g. ``'{0.25,0.5,1.0}'``.
    """
    items = ','.join(str(value) for value in WEIGHTS_VALUES)
    return '{%s}' % items
114