1from dataclasses import dataclass, field
2from typing import Dict, List, Set, TYPE_CHECKING, Tuple, Type, Union
3
4from ormar.queryset.utils import get_relationship_alias_model_and_str
5
6if TYPE_CHECKING:  # pragma: no cover
7    from ormar import Model
8
9
@dataclass
class Excludable:
    """
    Class that keeps sets of fields to exclude and include for a single
    model/alias combination.

    An ``Ellipsis`` (``...``) stored in either set acts as a wildcard that
    matches every key (see ``is_included`` / ``is_excluded``).
    """

    include: Set = field(default_factory=set)
    exclude: Set = field(default_factory=set)

    def get_copy(self) -> "Excludable":
        """
        Return copy of self to avoid in place modifications.

        The include/exclude sets are copied (shallowly), so mutating the
        copy's sets does not affect the original.

        :return: copy of self with copied sets
        :rtype: ormar.models.excludable.Excludable
        """
        _copy = self.__class__()
        _copy.include = set(self.include)
        _copy.exclude = set(self.exclude)
        return _copy

    def set_values(self, value: Set, is_exclude: bool) -> None:
        """
        Appends the data to the include or exclude set (in place).

        :param value: set of values to add
        :type value: set
        :param is_exclude: flag if values are to be excluded (True) or
            included (False)
        :type is_exclude: bool
        """
        target = self.exclude if is_exclude else self.include
        target.update(value)

    def is_included(self, key: str) -> bool:
        """
        Check if field is included.

        An empty include set means "no restriction", so every key is included.
        Otherwise the key must be in the set, or the set must contain ``...``.

        :param key: key to check
        :type key: str
        :return: result of the check
        :rtype: bool
        """
        return (... in self.include or key in self.include) if self.include else True

    def is_excluded(self, key: str) -> bool:
        """
        Check if field is excluded.

        An empty exclude set means nothing is excluded. Otherwise the key must
        be in the set, or the set must contain ``...``.

        :param key: key to check
        :type key: str
        :return: result of the check
        :rtype: bool
        """
        return (... in self.exclude or key in self.exclude) if self.exclude else False
63
64
class ExcludableItems:
    """
    Keeps a dictionary of Excludables by alias + model_name keys
    to allow quick lookup by nested models without the need to traverse
    deeply nested dictionaries and pass include/exclude around.

    Keys have the form ``"<alias>_<model_name>"`` for models reached through
    a relation alias, or just ``"<model_name>"`` when there is no alias.
    Model names are always taken with ``get_name(lower=True)`` so that keys
    written during ``build`` match keys read by ``get``.
    """

    def __init__(self) -> None:
        self.items: Dict[str, "Excludable"] = dict()

    @staticmethod
    def _key(model_name: str, alias: str = "") -> str:
        """
        Build the lookup key used in ``self.items`` (single source of truth).

        :param model_name: name of the model
        :type model_name: str
        :param alias: table alias from relation manager, empty if none
        :type alias: str
        :return: ``"<alias>_<model_name>"`` if alias is set, else model_name
        :rtype: str
        """
        return f"{alias}_{model_name}" if alias else model_name

    @classmethod
    def from_excludable(cls, other: "ExcludableItems") -> "ExcludableItems":
        """
        Copy passed ExcludableItems to avoid inplace modifications.

        Every contained Excludable is copied via ``get_copy``.

        :param other: other excludable items to be copied
        :type other: ormar.models.excludable.ExcludableItems
        :return: copy of other
        :rtype: ormar.models.excludable.ExcludableItems
        """
        new_excludable = cls()
        new_excludable.items.update(
            (key, value.get_copy()) for key, value in other.items.items()
        )
        return new_excludable

    def include_entry_count(self) -> int:
        """
        Returns count of include items inside (summed over all Excludables).

        :return: total number of entries in all include sets
        :rtype: int
        """
        return sum(len(excludable.include) for excludable in self.items.values())

    def get(self, model_cls: Type["Model"], alias: str = "") -> "Excludable":
        """
        Return Excludable for given model and alias.

        If no entry exists yet, an empty Excludable is created, stored and
        returned.

        :param model_cls: target model to check
        :type model_cls: ormar.models.metaclass.ModelMetaclass
        :param alias: table alias from relation manager
        :type alias: str
        :return: Excludable for given model and alias
        :rtype: ormar.models.excludable.Excludable
        """
        key = self._key(model_cls.get_name(lower=True), alias)
        excludable = self.items.get(key)
        if excludable is None:
            excludable = Excludable()
            self.items[key] = excludable
        return excludable

    def build(
        self,
        items: Union[List[str], str, Tuple[str], Set[str], Dict],
        model_cls: Type["Model"],
        is_exclude: bool = False,
    ) -> None:
        """
        Receives one of the supported types of items and parses them to
        achieve an end situation with one excludable per alias/model in relation.

        Each excludable has two sets of values - one to include, one to exclude.

        A single string is treated as a one-element collection. In the
        collection form, names containing ``"__"`` are treated as
        ``relation__...__field`` paths; the rest belong to ``model_cls``.

        :param items: values to be included or excluded
        :type items: Union[List[str], str, Tuple[str], Set[str], Dict]
        :param model_cls: source model from which relations are constructed
        :type model_cls: ormar.models.metaclass.ModelMetaclass
        :param is_exclude: flag if items should be included or excluded
        :type is_exclude: bool
        """
        if isinstance(items, str):
            items = {items}

        if isinstance(items, dict):
            self._traverse_dict(
                values=items,
                source_model=model_cls,
                model_cls=model_cls,
                is_exclude=is_exclude,
            )
            return

        all_items = set(items)
        nested_items = {item for item in all_items if "__" in item}
        own_items = all_items - nested_items
        # Always registers an entry for model_cls, even if own_items is empty.
        self._set_excludes(
            items=own_items,
            model_name=model_cls.get_name(lower=True),
            is_exclude=is_exclude,
        )
        if nested_items:
            self._traverse_list(
                values=nested_items, model_cls=model_cls, is_exclude=is_exclude
            )

    def _set_excludes(
        self, items: Set, model_name: str, is_exclude: bool, alias: str = ""
    ) -> None:
        """
        Sets set of values to be included or excluded for given key and model.

        Creates the Excludable for the key if it does not exist yet.

        :param items: items to include/exclude
        :type items: set
        :param model_name: name of model to construct key
        :type model_name: str
        :param is_exclude: flag if values should be included or excluded
        :type is_exclude: bool
        :param alias: table alias prefix of the relation, empty for no alias
        :type alias: str
        """
        key = self._key(model_name, alias)
        excludable = self.items.get(key)
        if excludable is None:
            excludable = Excludable()
            self.items[key] = excludable
        excludable.set_values(value=items, is_exclude=is_exclude)

    def _traverse_dict(  # noqa: CFQ002
        self,
        values: Dict,
        source_model: Type["Model"],
        model_cls: Type["Model"],
        is_exclude: bool,
        related_items: List = None,
        alias: str = "",
    ) -> None:
        """
        Goes through dict of nested values and construct/update Excludables.

        For every ``key: value`` pair in ``values``:
        * ``value is ...`` - ``key`` is a field of ``model_cls``;
        * ``value`` is a set - fields of the related model reached via ``key``;
        * otherwise ``value`` is a nested dict for the relation ``key``.

        Relation paths are always resolved from ``source_model`` using the
        chain of parent keys plus the current key.

        :param values: items to include/exclude
        :type values: Dict
        :param source_model: source model from which relations are constructed
        :type source_model: ormar.models.metaclass.ModelMetaclass
        :param model_cls: model from which current relation is constructed
        :type model_cls: ormar.models.metaclass.ModelMetaclass
        :param is_exclude: flag if values should be included or excluded
        :type is_exclude: bool
        :param related_items: list of names of related fields chain
        :type related_items: List
        :param alias: alias of relation
        :type alias: str
        """
        self_fields = set()
        parent_path = list(related_items) if related_items else []
        for key, value in values.items():
            if value is ...:
                self_fields.add(key)
                continue
            # Build a fresh path per key: sibling relations must not leak into
            # each other's path (previously the shared list was appended to).
            relation_path = parent_path + [key]
            (
                table_prefix,
                target_model,
                _,
                _,
            ) = get_relationship_alias_model_and_str(
                source_model=source_model, related_parts=relation_path
            )
            if isinstance(value, set):
                self._set_excludes(
                    items=value,
                    model_name=target_model.get_name(lower=True),
                    is_exclude=is_exclude,
                    alias=table_prefix,
                )
            else:
                # nested dict
                self._traverse_dict(
                    values=value,
                    source_model=source_model,
                    model_cls=target_model,
                    is_exclude=is_exclude,
                    related_items=relation_path,
                    alias=table_prefix,
                )
        if self_fields:
            self._set_excludes(
                items=self_fields,
                model_name=model_cls.get_name(lower=True),
                is_exclude=is_exclude,
                alias=alias,
            )

    def _traverse_list(
        self, values: Set[str], model_cls: Type["Model"], is_exclude: bool
    ) -> None:
        """
        Goes through list of values and construct/update Excludables.

        Each value is a ``relation__...__field`` path; everything before the
        last ``"__"`` is the relation chain and the last part is the field.

        :param values: items to include/exclude
        :type values: set
        :param model_cls: model from which current relation is constructed
        :type model_cls: ormar.models.metaclass.ModelMetaclass
        :param is_exclude: flag if values should be included or excluded
        :type is_exclude: bool
        """
        # here we have only nested related keys
        for key in values:
            *related_items, field_name = key.split("__")
            (table_prefix, target_model, _, _) = get_relationship_alias_model_and_str(
                source_model=model_cls, related_parts=related_items
            )
            self._set_excludes(
                items={field_name},
                model_name=target_model.get_name(lower=True),
                is_exclude=is_exclude,
                alias=table_prefix,
            )
281