1from contextlib import contextmanager
2
3from pymongo.read_concern import ReadConcern
4from pymongo.write_concern import WriteConcern
5
6from mongoengine.common import _import_class
7from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
8from mongoengine.pymongo_support import count_documents
9
# Public API of this module: the names exported by a star-import.
__all__ = (
    "switch_db",
    "switch_collection",
    "no_dereference",
    "no_sub_classes",
    "query_counter",
    "set_write_concern",
    "set_read_write_concern",
)
19
20
class switch_db:
    """Context manager that temporarily points a document class at another db alias.

    Example ::

        # Register connections
        register_connection('default', 'mongoenginetest')
        register_connection('testdb-1', 'mongoenginetest2')

        class Group(Document):
            name = StringField()

        Group(name='test').save()  # Saves in the default db

        with switch_db(Group, 'testdb-1') as Group:
            Group(name='hello testdb!').save()  # Saves in testdb-1
    """

    def __init__(self, cls, db_alias):
        """Remember the class, the requested alias and the state to restore.

        :param cls: the class whose registered db alias is switched
        :param db_alias: the alias of the database to use inside the context
        """
        # The current collection is fetched eagerly so it can be put back
        # unchanged when the context exits.
        original_collection = cls._get_collection()
        original_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)

        self.cls = cls
        self.collection = original_collection
        self.db_alias = db_alias
        self.ori_db_alias = original_alias

    def __enter__(self):
        """Install the new db_alias, drop the cached collection, return the class."""
        document_cls = self.cls
        document_cls._meta["db_alias"] = self.db_alias
        # Setting the cache to None means the collection is not reused from
        # the previous alias.
        document_cls._collection = None
        return document_cls

    def __exit__(self, t, value, traceback):
        """Put back the original db_alias and the original collection."""
        document_cls = self.cls
        document_cls._meta["db_alias"] = self.ori_db_alias
        document_cls._collection = self.collection
60
61
class switch_collection:
    """Context manager that temporarily points a document class at another collection.

    Example ::

        class Group(Document):
            name = StringField()

        Group(name='test').save()  # Saves in the default db

        with switch_collection(Group, 'group1') as Group:
            Group(name='hello testdb!').save()  # Saves in group1 collection
    """

    def __init__(self, cls, collection_name):
        """Remember the class, the new collection name and the state to restore.

        :param cls: the class whose collection is switched
        :param collection_name: the name of the collection to use
        """
        self.cls = cls
        # Both the collection object and the name getter are captured so
        # that __exit__ can reinstate them.
        self.ori_collection = cls._get_collection()
        self.ori_get_collection_name = cls._get_collection_name
        self.collection_name = collection_name

    def __enter__(self):
        """Override _get_collection_name, drop the cached collection, return the class."""
        manager = self

        def _get_collection_name(cls):
            # Looked up at call time, so it always reflects the manager's
            # current ``collection_name``.
            return manager.collection_name

        target = self.cls
        target._get_collection_name = classmethod(_get_collection_name)
        target._collection = None
        return target

    def __exit__(self, t, value, traceback):
        """Reinstate the original collection and collection-name getter."""
        target = self.cls
        target._collection = self.ori_collection
        target._get_collection_name = self.ori_get_collection_name
102
103
class no_dereference:
    """no_dereference context manager.

    Turns off all dereferencing in Documents for the duration of the context
    manager::

        with no_dereference(Group) as Group:
            Group.objects.find()

    Exceptions raised inside the ``with`` body propagate to the caller, and
    the ``_auto_dereference`` value each field had on entry is restored on
    exit, so nested contexts on the same class behave correctly.
    """

    def __init__(self, cls):
        """Construct the no_dereference context manager.

        :param cls: the class to turn dereferencing off on
        """
        self.cls = cls

        ReferenceField = _import_class("ReferenceField")
        GenericReferenceField = _import_class("GenericReferenceField")
        ComplexBaseField = _import_class("ComplexBaseField")

        # Names of the fields whose dereferencing is toggled by this manager.
        self.deref_fields = [
            k
            for k, v in self.cls._fields.items()
            if isinstance(v, (ReferenceField, GenericReferenceField, ComplexBaseField))
        ]
        # Filled in by __enter__ with the value each field had on entry.
        self._previous_values = {}

    def __enter__(self):
        """Disable ``_auto_dereference`` on the tracked fields and return the class."""
        fields = self.cls._fields
        # Snapshot first so __exit__ can restore exactly what was there,
        # rather than assuming every field started out as True.
        self._previous_values = {
            name: getattr(fields[name], "_auto_dereference", True)
            for name in self.deref_fields
        }
        for name in self.deref_fields:
            fields[name]._auto_dereference = False
        return self.cls

    def __exit__(self, t, value, traceback):
        """Restore the ``_auto_dereference`` values recorded on entry.

        Returns ``None`` so that an exception raised inside the ``with``
        block is never suppressed.
        """
        previous = getattr(self, "_previous_values", {})
        fields = self.cls._fields
        for name in self.deref_fields:
            # Fall back to True (the historical behaviour) when no snapshot
            # exists, e.g. __exit__ called without a matching __enter__.
            fields[name]._auto_dereference = previous.get(name, True)
142
143
class no_sub_classes:
    """no_sub_classes context manager.

    Restricts the class to itself, excluding sub (inherited) classes, for
    the duration of the context::

        with no_sub_classes(Group) as Group:
            Group.objects.find()
    """

    def __init__(self, cls):
        """Remember the class whose ``_subclasses`` will be narrowed.

        :param cls: the class on which sub class querying is turned off
        """
        self.cls = cls
        # Populated on entry with the value to put back on exit.
        self.cls_initial_subclasses = None

    def __enter__(self):
        """Save ``_subclasses``, replace it with just this class, return the class."""
        target = self.cls
        self.cls_initial_subclasses = target._subclasses
        only_self = (target._class_name,)
        target._subclasses = only_self
        return target

    def __exit__(self, t, value, traceback):
        """Put back the ``_subclasses`` value saved on entry."""
        saved = self.cls_initial_subclasses
        self.cls._subclasses = saved
170
171
class query_counter:
    """Context manager reporting the number of queries issued inside it.

    On entry the database `profiling_level` is raised so that every query is
    logged and the db.system.profile collection is reset; the count is then
    derived from the number of new entries in that collection.

    Intended for debugging. The counter is global to the database, so queries
    issued by other threads/processes can interfere with it.

    Be aware that:
    - Iterating over a large amount of documents (>101) makes pymongo issue `getmore` queries to fetch the next batch of
        documents (https://docs.mongodb.com/manual/tutorial/iterate-a-cursor/#cursor-batches)
    - Some queries are ignored by default by the counter (killcursors, db.system.indexes)
    """

    def __init__(self, alias=DEFAULT_CONNECTION_NAME):
        """Bind the counter to the database registered under ``alias``.

        :param alias: the connection alias of the database to profile
        """
        self.db = get_db(alias=alias)
        self.initial_profiling_level = None
        # Number of profile entries caused by the counter's own count queries.
        self._ctx_query_counter = 0

        indexes_namespace = "%s.system.indexes" % self.db.name
        # Filter applied when counting profile entries.
        self._ignored_query = {
            "ns": {"$ne": indexes_namespace},
            "op": {"$ne": "killcursors"},  # MONGODB < 3.2
            "command.killCursors": {"$exists": False},  # MONGODB >= 3.2
        }

    def _turn_on_profiling(self):
        """Record the current profiling level, reset the profile log, enable level 2."""
        database = self.db
        self.initial_profiling_level = database.profiling_level()
        # Profiling is switched off while the profile collection is dropped.
        database.set_profiling_level(0)
        database.system.profile.drop()
        database.set_profiling_level(2)

    def _resets_profiling(self):
        """Put the profiling level back to the value recorded on entry."""
        level = self.initial_profiling_level
        self.db.set_profiling_level(level)

    def __enter__(self):
        """Start profiling and return the counter itself."""
        self._turn_on_profiling()
        return self

    def __exit__(self, t, value, traceback):
        """Restore the initial profiling level."""
        self._resets_profiling()

    def __eq__(self, value):
        """Compare ``value`` with the current query count (issues one count query)."""
        current = self._get_count()
        return value == current

    def __ne__(self, value):
        """Negation of ``__eq__``."""
        is_equal = self.__eq__(value)
        return not is_equal

    def __lt__(self, value):
        """Return whether the current query count is below ``value``."""
        current = self._get_count()
        return current < value

    def __le__(self, value):
        """Return whether the current query count is at most ``value``."""
        current = self._get_count()
        return current <= value

    def __gt__(self, value):
        """Return whether the current query count is above ``value``."""
        current = self._get_count()
        return current > value

    def __ge__(self, value):
        """Return whether the current query count is at least ``value``."""
        current = self._get_count()
        return current >= value

    def __int__(self):
        """Return the current query count."""
        current = self._get_count()
        return current

    def __repr__(self):
        """repr query_counter as the number of queries."""
        current = self._get_count()
        return "%s" % current

    def _get_count(self):
        """Return the number of profiled queries, excluding the counter's own.

        The entries in db.system.profile matching the ignore filter are
        counted, and the count queries previously issued by this object are
        subtracted. Every call issues one such query itself, so the internal
        tally is bumped afterwards to balance it on the next call.
        """
        profiled = count_documents(self.db.system.profile, self._ignored_query)
        own_queries = self._ctx_query_counter
        # Account for the query we just issued to gather the information.
        self._ctx_query_counter = own_queries + 1
        return profiled - own_queries
254
255
@contextmanager
def set_write_concern(collection, write_concerns):
    """Yield ``collection`` reconfigured with an overridden write concern.

    The collection's current write concern document is used as the base and
    ``write_concerns`` is layered on top of it.

    :param collection: the collection whose write concern is overridden
    :param write_concerns: mapping of write concern options to apply, or
        ``None`` to keep the collection's current write concern (same
        convention as ``set_read_write_concern``)
    :return: a context manager yielding the result of
        ``collection.with_options(write_concern=...)``
    """
    combined_concerns = dict(collection.write_concern.document.items())

    # Accept None like set_read_write_concern does, instead of failing
    # with a TypeError inside dict.update().
    if write_concerns is not None:
        combined_concerns.update(write_concerns)

    yield collection.with_options(write_concern=WriteConcern(**combined_concerns))
261
262
@contextmanager
def set_read_write_concern(collection, write_concerns, read_concerns):
    """Yield ``collection`` reconfigured with overridden write and read concerns.

    Each of the collection's current concern documents is used as the base
    and the matching override mapping is layered on top of it.

    :param collection: the collection whose concerns are overridden
    :param write_concerns: mapping of write concern options, or ``None`` to
        keep the collection's current write concern
    :param read_concerns: mapping of read concern options, or ``None`` to
        keep the collection's current read concern
    :return: a context manager yielding the result of
        ``collection.with_options(write_concern=..., read_concern=...)``
    """

    def _overlay(base_document, overrides):
        # Copy the current settings, then apply the overrides when given.
        merged = dict(base_document.items())
        if overrides is not None:
            merged.update(overrides)
        return merged

    write_options = _overlay(collection.write_concern.document, write_concerns)
    read_options = _overlay(collection.read_concern.document, read_concerns)

    reconfigured = collection.with_options(
        write_concern=WriteConcern(**write_options),
        read_concern=ReadConcern(**read_options),
    )
    yield reconfigured
279