from itertools import zip_longest

from django.apps import apps
from django.db import connections

from wagtail.search.index import Indexed, RelatedFields, SearchField


def get_postgresql_connections():
    """Return the configured Django database connections whose vendor is 'postgresql'."""
    all_connections = connections.all()
    return [conn for conn in all_connections if conn.vendor == 'postgresql']


def get_descendant_models(model):
    """
    Return, as a set, every registered model that subclasses ``model``.

    ``model`` itself is always part of the result, even when the app registry
    does not list it.
    """
    subclasses = {
        candidate for candidate in apps.get_models()
        if issubclass(candidate, model)
    }
    return subclasses | {model}


def get_content_type_pk(model):
    """Return the primary key of the ContentType row for ``model``."""
    # Imported here rather than at module level because this module is
    # loaded before the app registry is ready.
    from django.contrib.contenttypes.models import ContentType
    content_type = ContentType.objects.get_for_model(model)
    return content_type.pk


def get_ancestors_content_types_pks(model):
    """
    Return the ContentType primary keys of the ancestors of ``model``.

    Ancestors are taken from ``model._meta.get_parent_list()``; ``model``
    itself is not included.
    """
    from django.contrib.contenttypes.models import ContentType
    ancestors = model._meta.get_parent_list()
    by_model = ContentType.objects.get_for_models(*ancestors)
    return [ct.pk for ct in by_model.values()]


def get_descendants_content_types_pks(model):
    """
    Return the ContentType primary keys of ``model`` and of every registered
    model that subclasses it (see get_descendant_models).
    """
    from django.contrib.contenttypes.models import ContentType
    descendants = get_descendant_models(model)
    by_model = ContentType.objects.get_for_models(*descendants)
    return [ct.pk for ct in by_model.values()]


def get_search_fields(search_fields):
    """
    Yield every SearchField found in ``search_fields``.

    RelatedFields entries are expanded recursively through their ``fields``
    attribute; entries of any other type are skipped.
    """
    for field in search_fields:
        if isinstance(field, SearchField):
            yield field
        elif isinstance(field, RelatedFields):
            yield from get_search_fields(field.fields)


# Weight labels, ordered from the one given to the highest boost ('A') to the
# one used for the lowest / missing boost ('D').
WEIGHTS = 'ABCD'
WEIGHTS_COUNT = len(WEIGHTS)
# These are filled in by set_weights() once apps are ready.
# (boost_threshold, weight_label) pairs, highest threshold first.
BOOSTS_WEIGHTS = []
# Numeric weights in reverse label order (lowest label first), scaled to (0, 1].
WEIGHTS_VALUES = []


def get_boosts():
    """
    Return the set of distinct, non-None ``boost`` values declared by the
    SearchFields (including those nested in RelatedFields) of every
    registered model that subclasses Indexed.
    """
    boosts = set()
    for model in apps.get_models():
        if not issubclass(model, Indexed):
            continue
        for search_field in get_search_fields(model.get_search_fields()):
            if search_field.boost is not None:
                boosts.add(search_field.boost)
    return boosts


def determine_boosts_weights(boosts=()):
    """
    Map boost values onto the weight labels in WEIGHTS.

    :param boosts: boost values to map; when empty, the boosts declared on
        the registered models (see get_boosts) are used instead.
    :return: a list of ``(boost_threshold, weight_label)`` pairs, one per
        label in WEIGHTS, ordered from the highest threshold to the lowest.

    With at most WEIGHTS_COUNT boosts, each boost gets its own label and the
    remaining labels share the threshold ``min(lowest_boost, 0)``. With more
    boosts than labels, the thresholds are spread evenly between the highest
    and the lowest boost. If no boost is available at all, a single neutral
    boost of 0 is used, so every label gets the same threshold.
    """
    if not boosts:
        boosts = get_boosts()
    if not boosts:
        # Nothing declares a boost: avoid an IndexError on boosts[-1] below.
        boosts = (0,)
    boosts = sorted(boosts, reverse=True)
    min_boost = boosts[-1]
    if len(boosts) <= WEIGHTS_COUNT:
        # One label per boost; leftover labels are padded with the fillvalue.
        return list(zip_longest(boosts, WEIGHTS,
                                fillvalue=min(min_boost, 0)))
    # More boosts than labels: evenly spaced thresholds from max to min.
    max_boost = boosts[0]
    boost_step = (max_boost - min_boost) / (WEIGHTS_COUNT - 1)
    return [(max_boost - (i * boost_step), weight)
            for i, weight in enumerate(WEIGHTS)]


def set_weights():
    """
    Fill BOOSTS_WEIGHTS and WEIGHTS_VALUES from the declared boosts.

    BOOSTS_WEIGHTS receives the ``(boost_threshold, label)`` table built by
    determine_boosts_weights(). WEIGHTS_VALUES receives the thresholds
    rescaled into (0, 1], in reverse label order (lowest label first), which
    is the order get_sql_weights() emits them in.

    Both lists are refilled in place rather than rebound, so references to
    them held elsewhere stay valid. Previous contents are discarded, so a
    repeated call does not accumulate duplicate entries.
    """
    BOOSTS_WEIGHTS[:] = determine_boosts_weights()
    weights = [threshold for threshold, _label in BOOSTS_WEIGHTS]
    min_weight = min(weights)
    if min_weight <= 0:
        # Shift so the smallest value becomes 0.1, never 0. This keeps every
        # weight strictly positive and guarantees max_weight > 0 below, even
        # when all thresholds are equal and negative.
        weights = [w - (min_weight - 0.1) for w in weights]
    max_weight = max(weights)
    WEIGHTS_VALUES[:] = [w / max_weight for w in reversed(weights)]


def get_weight(boost):
    """
    Return the weight label for a search field with the given ``boost``.

    ``None`` maps to the lowest label. Otherwise the label of the first
    BOOSTS_WEIGHTS entry whose threshold ``boost`` reaches is returned.
    A boost below every threshold, or any boost while the table is still
    empty, falls back to the lowest label.
    """
    if boost is None:
        return WEIGHTS[-1]
    for threshold, weight in BOOSTS_WEIGHTS:
        if boost >= threshold:
            return weight
    # determine_boosts_weights() always ends the table with WEIGHTS[-1], so
    # this matches its last entry without relying on a leaked loop variable.
    return WEIGHTS[-1]


def get_sql_weights():
    """
    Return WEIGHTS_VALUES formatted as a SQL array literal,
    e.g. '{0.1,0.2,0.5,1.0}'.
    """
    return '{' + ','.join(map(str, WEIGHTS_VALUES)) + '}'