1from bson import DBRef, SON
2
3from mongoengine.base import (
4    BaseDict,
5    BaseList,
6    EmbeddedDocumentList,
7    TopLevelDocumentMetaclass,
8    get_document,
9)
10from mongoengine.base.datastructures import LazyReference
11from mongoengine.connection import get_db
12from mongoengine.document import Document, EmbeddedDocument
13from mongoengine.fields import DictField, ListField, MapField, ReferenceField
14from mongoengine.queryset import QuerySet
15
16
17class DeReference:
18    def __call__(self, items, max_depth=1, instance=None, name=None):
19        """
20        Cheaply dereferences the items to a set depth.
21        Also handles the conversion of complex data types.
22
23        :param items: The iterable (dict, list, queryset) to be dereferenced.
24        :param max_depth: The maximum depth to recurse to
25        :param instance: The owning instance used for tracking changes by
26            :class:`~mongoengine.base.ComplexBaseField`
27        :param name: The name of the field, used for tracking changes by
28            :class:`~mongoengine.base.ComplexBaseField`
29        :param get: A boolean determining if being called by __get__
30        """
31        if items is None or isinstance(items, str):
32            return items
33
34        # cheapest way to convert a queryset to a list
35        # list(queryset) uses a count() query to determine length
36        if isinstance(items, QuerySet):
37            items = [i for i in items]
38
39        self.max_depth = max_depth
40        doc_type = None
41
42        if instance and isinstance(
43            instance, (Document, EmbeddedDocument, TopLevelDocumentMetaclass)
44        ):
45            doc_type = instance._fields.get(name)
46            while hasattr(doc_type, "field"):
47                doc_type = doc_type.field
48
49            if isinstance(doc_type, ReferenceField):
50                field = doc_type
51                doc_type = doc_type.document_type
52                is_list = not hasattr(items, "items")
53
54                if is_list and all([i.__class__ == doc_type for i in items]):
55                    return items
56                elif not is_list and all(
57                    [i.__class__ == doc_type for i in items.values()]
58                ):
59                    return items
60                elif not field.dbref:
61                    # We must turn the ObjectIds into DBRefs
62
63                    # Recursively dig into the sub items of a list/dict
64                    # to turn the ObjectIds into DBRefs
65                    def _get_items_from_list(items):
66                        new_items = []
67                        for v in items:
68                            value = v
69                            if isinstance(v, dict):
70                                value = _get_items_from_dict(v)
71                            elif isinstance(v, list):
72                                value = _get_items_from_list(v)
73                            elif not isinstance(v, (DBRef, Document)):
74                                value = field.to_python(v)
75                            new_items.append(value)
76                        return new_items
77
78                    def _get_items_from_dict(items):
79                        new_items = {}
80                        for k, v in items.items():
81                            value = v
82                            if isinstance(v, list):
83                                value = _get_items_from_list(v)
84                            elif isinstance(v, dict):
85                                value = _get_items_from_dict(v)
86                            elif not isinstance(v, (DBRef, Document)):
87                                value = field.to_python(v)
88                            new_items[k] = value
89                        return new_items
90
91                    if not hasattr(items, "items"):
92                        items = _get_items_from_list(items)
93                    else:
94                        items = _get_items_from_dict(items)
95
96        self.reference_map = self._find_references(items)
97        self.object_map = self._fetch_objects(doc_type=doc_type)
98        return self._attach_objects(items, 0, instance, name)
99
100    def _find_references(self, items, depth=0):
101        """
102        Recursively finds all db references to be dereferenced
103
104        :param items: The iterable (dict, list, queryset)
105        :param depth: The current depth of recursion
106        """
107        reference_map = {}
108        if not items or depth >= self.max_depth:
109            return reference_map
110
111        # Determine the iterator to use
112        if isinstance(items, dict):
113            iterator = items.values()
114        else:
115            iterator = items
116
117        # Recursively find dbreferences
118        depth += 1
119        for item in iterator:
120            if isinstance(item, (Document, EmbeddedDocument)):
121                for field_name, field in item._fields.items():
122                    v = item._data.get(field_name, None)
123                    if isinstance(v, LazyReference):
124                        # LazyReference inherits DBRef but should not be dereferenced here !
125                        continue
126                    elif isinstance(v, DBRef):
127                        reference_map.setdefault(field.document_type, set()).add(v.id)
128                    elif isinstance(v, (dict, SON)) and "_ref" in v:
129                        reference_map.setdefault(get_document(v["_cls"]), set()).add(
130                            v["_ref"].id
131                        )
132                    elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
133                        field_cls = getattr(
134                            getattr(field, "field", None), "document_type", None
135                        )
136                        references = self._find_references(v, depth)
137                        for key, refs in references.items():
138                            if isinstance(
139                                field_cls, (Document, TopLevelDocumentMetaclass)
140                            ):
141                                key = field_cls
142                            reference_map.setdefault(key, set()).update(refs)
143            elif isinstance(item, LazyReference):
144                # LazyReference inherits DBRef but should not be dereferenced here !
145                continue
146            elif isinstance(item, DBRef):
147                reference_map.setdefault(item.collection, set()).add(item.id)
148            elif isinstance(item, (dict, SON)) and "_ref" in item:
149                reference_map.setdefault(get_document(item["_cls"]), set()).add(
150                    item["_ref"].id
151                )
152            elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
153                references = self._find_references(item, depth - 1)
154                for key, refs in references.items():
155                    reference_map.setdefault(key, set()).update(refs)
156
157        return reference_map
158
159    def _fetch_objects(self, doc_type=None):
160        """Fetch all references and convert to their document objects
161        """
162        object_map = {}
163        for collection, dbrefs in self.reference_map.items():
164
165            # we use getattr instead of hasattr because hasattr swallows any exception under python2
166            # so it could hide nasty things without raising exceptions (cfr bug #1688))
167            ref_document_cls_exists = getattr(collection, "objects", None) is not None
168
169            if ref_document_cls_exists:
170                col_name = collection._get_collection_name()
171                refs = [
172                    dbref for dbref in dbrefs if (col_name, dbref) not in object_map
173                ]
174                references = collection.objects.in_bulk(refs)
175                for key, doc in references.items():
176                    object_map[(col_name, key)] = doc
177            else:  # Generic reference: use the refs data to convert to document
178                if isinstance(doc_type, (ListField, DictField, MapField)):
179                    continue
180
181                refs = [
182                    dbref for dbref in dbrefs if (collection, dbref) not in object_map
183                ]
184
185                if doc_type:
186                    references = doc_type._get_db()[collection].find(
187                        {"_id": {"$in": refs}}
188                    )
189                    for ref in references:
190                        doc = doc_type._from_son(ref)
191                        object_map[(collection, doc.id)] = doc
192                else:
193                    references = get_db()[collection].find({"_id": {"$in": refs}})
194                    for ref in references:
195                        if "_cls" in ref:
196                            doc = get_document(ref["_cls"])._from_son(ref)
197                        elif doc_type is None:
198                            doc = get_document(
199                                "".join(x.capitalize() for x in collection.split("_"))
200                            )._from_son(ref)
201                        else:
202                            doc = doc_type._from_son(ref)
203                        object_map[(collection, doc.id)] = doc
204        return object_map
205
206    def _attach_objects(self, items, depth=0, instance=None, name=None):
207        """
208        Recursively finds all db references to be dereferenced
209
210        :param items: The iterable (dict, list, queryset)
211        :param depth: The current depth of recursion
212        :param instance: The owning instance used for tracking changes by
213            :class:`~mongoengine.base.ComplexBaseField`
214        :param name: The name of the field, used for tracking changes by
215            :class:`~mongoengine.base.ComplexBaseField`
216        """
217        if not items:
218            if isinstance(items, (BaseDict, BaseList)):
219                return items
220
221            if instance:
222                if isinstance(items, dict):
223                    return BaseDict(items, instance, name)
224                else:
225                    return BaseList(items, instance, name)
226
227        if isinstance(items, (dict, SON)):
228            if "_ref" in items:
229                return self.object_map.get(
230                    (items["_ref"].collection, items["_ref"].id), items
231                )
232            elif "_cls" in items:
233                doc = get_document(items["_cls"])._from_son(items)
234                _cls = doc._data.pop("_cls", None)
235                del items["_cls"]
236                doc._data = self._attach_objects(doc._data, depth, doc, None)
237                if _cls is not None:
238                    doc._data["_cls"] = _cls
239                return doc
240
241        if not hasattr(items, "items"):
242            is_list = True
243            list_type = BaseList
244            if isinstance(items, EmbeddedDocumentList):
245                list_type = EmbeddedDocumentList
246            as_tuple = isinstance(items, tuple)
247            iterator = enumerate(items)
248            data = []
249        else:
250            is_list = False
251            iterator = items.items()
252            data = {}
253
254        depth += 1
255        for k, v in iterator:
256            if is_list:
257                data.append(v)
258            else:
259                data[k] = v
260
261            if k in self.object_map and not is_list:
262                data[k] = self.object_map[k]
263            elif isinstance(v, (Document, EmbeddedDocument)):
264                for field_name in v._fields:
265                    v = data[k]._data.get(field_name, None)
266                    if isinstance(v, DBRef):
267                        data[k]._data[field_name] = self.object_map.get(
268                            (v.collection, v.id), v
269                        )
270                    elif isinstance(v, (dict, SON)) and "_ref" in v:
271                        data[k]._data[field_name] = self.object_map.get(
272                            (v["_ref"].collection, v["_ref"].id), v
273                        )
274                    elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
275                        item_name = "{}.{}.{}".format(name, k, field_name)
276                        data[k]._data[field_name] = self._attach_objects(
277                            v, depth, instance=instance, name=item_name
278                        )
279            elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
280                item_name = "{}.{}".format(name, k) if name else name
281                data[k] = self._attach_objects(
282                    v, depth - 1, instance=instance, name=item_name
283                )
284            elif isinstance(v, DBRef) and hasattr(v, "id"):
285                data[k] = self.object_map.get((v.collection, v.id), v)
286
287        if instance and name:
288            if is_list:
289                return tuple(data) if as_tuple else list_type(data, instance, name)
290            return BaseDict(data, instance, name)
291        depth += 1
292        return data
293