1# Copyright 2012-present MongoDB, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Utilities for testing pymongo
16"""
17
18import contextlib
19import copy
20import functools
21import os
22import re
23import shutil
24import sys
25import threading
26import time
27import warnings
28
29from collections import defaultdict
30from functools import partial
31
32from bson import json_util, py3compat
33from bson.objectid import ObjectId
34from bson.py3compat import abc, iteritems, string_type
35from bson.son import SON
36
37from pymongo import (MongoClient,
38                     monitoring, operations, read_preferences)
39from pymongo.collection import ReturnDocument
40from pymongo.errors import ConfigurationError, OperationFailure
41from pymongo.hello_compat import HelloCompat
42from pymongo.monitoring import _SENSITIVE_COMMANDS
43from pymongo.pool import (_CancellationContext,
44                          PoolOptions,
45                          _PoolGeneration)
46from pymongo.read_concern import ReadConcern
47from pymongo.read_preferences import ReadPreference
48from pymongo.server_selectors import (any_server_selector,
49                                      writable_server_selector)
50from pymongo.server_type import SERVER_TYPE
51from pymongo.write_concern import WriteConcern
52
53from test import (client_context,
54                  db_user,
55                  db_pwd)
56
57if sys.version_info[0] < 3:
58    # Python 2.7, use our backport.
59    from test.barrier import Barrier
60else:
61    from threading import Barrier
62
63
# Write concern requiring acknowledgement from 50 members. Presumably no test
# deployment has that many members, so writes using it cannot be satisfied --
# NOTE(review): confirm this is how the tests rely on it.
IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50)
65
66
class BaseListener(object):
    """Base class for listeners that record every published event, in order."""

    def __init__(self):
        self.events = []

    def reset(self):
        """Forget every recorded event."""
        self.events = []

    def add_event(self, event):
        """Record ``event``."""
        self.events.append(event)

    def event_count(self, event_type):
        """Return how many recorded events are instances of ``event_type``."""
        matches = self.events_by_type(event_type)
        return len(matches)

    def events_by_type(self, event_type):
        """Return the recorded events that are instances of ``event_type``.

        ``event_type`` may be a single class or a tuple of classes, exactly as
        accepted by :func:`isinstance`.
        """
        def is_instance(evt):
            return isinstance(evt, event_type)
        return self.matching(is_instance)

    def matching(self, matcher):
        """Return the recorded events for which ``matcher(event)`` is truthy."""
        # Filter a copy so events appended meanwhile don't affect iteration.
        snapshot = self.events[:]
        return [evt for evt in snapshot if matcher(evt)]

    def wait_for_event(self, event, count):
        """Wait until at least ``count`` events of type ``event`` are recorded.

        Fails with AssertionError (raised by wait_until) if that never happens.
        """
        def enough():
            return self.event_count(event) >= count
        wait_until(enough, 'find %s %s event(s)' % (count, event))
95
96
class CMAPListener(BaseListener, monitoring.ConnectionPoolListener):
    """Records every connection pool (CMAP) event through add_event."""

    # Pool lifecycle.
    def pool_created(self, event):
        self.add_event(event)

    def pool_cleared(self, event):
        self.add_event(event)

    def pool_closed(self, event):
        self.add_event(event)

    # Connection lifecycle.
    def connection_created(self, event):
        self.add_event(event)

    def connection_ready(self, event):
        self.add_event(event)

    def connection_closed(self, event):
        self.add_event(event)

    # Checkout / checkin.
    def connection_check_out_started(self, event):
        self.add_event(event)

    def connection_check_out_failed(self, event):
        self.add_event(event)

    def connection_checked_out(self, event):
        self.add_event(event)

    def connection_checked_in(self, event):
        self.add_event(event)
128
129
class EventListener(monitoring.CommandListener):
    """Command listener that keeps every started/succeeded/failed event.

    ``results`` maps 'started', 'succeeded' and 'failed' to the list of
    events of that kind, in arrival order.
    """

    def __init__(self):
        self.results = defaultdict(list)

    def started(self, event):
        started_events = self.results['started']
        started_events.append(event)

    def succeeded(self, event):
        succeeded_events = self.results['succeeded']
        succeeded_events.append(event)

    def failed(self, event):
        failed_events = self.results['failed']
        failed_events.append(event)

    def started_command_names(self):
        """Return list of command names started."""
        started_events = self.results['started']
        return [evt.command_name for evt in started_events]

    def reset(self):
        """Reset the state of this listener (drop all recorded events)."""
        self.results.clear()
151
152
class TopologyEventListener(monitoring.TopologyListener):
    """Topology listener that keeps opened/description_changed/closed events.

    ``results`` maps each event kind name to its events, in arrival order.
    """

    def __init__(self):
        self.results = defaultdict(list)

    def opened(self, event):
        opened_events = self.results['opened']
        opened_events.append(event)

    def description_changed(self, event):
        changed_events = self.results['description_changed']
        changed_events.append(event)

    def closed(self, event):
        closed_events = self.results['closed']
        closed_events.append(event)

    def reset(self):
        """Reset the state of this listener (drop all recorded events)."""
        self.results.clear()
169
170
class WhiteListEventListener(EventListener):
    """EventListener that only records events for the given command names."""

    def __init__(self, *commands):
        self.commands = set(commands)
        super(WhiteListEventListener, self).__init__()

    def started(self, event):
        if event.command_name not in self.commands:
            return
        super(WhiteListEventListener, self).started(event)

    def succeeded(self, event):
        if event.command_name not in self.commands:
            return
        super(WhiteListEventListener, self).succeeded(event)

    def failed(self, event):
        if event.command_name not in self.commands:
            return
        super(WhiteListEventListener, self).failed(event)
188
189
class OvertCommandListener(EventListener):
    """A CommandListener that ignores sensitive commands.

    Events whose lower-cased command name is in
    ``pymongo.monitoring._SENSITIVE_COMMANDS`` are dropped; everything else
    is recorded by EventListener.
    """
    def started(self, event):
        if event.command_name.lower() in _SENSITIVE_COMMANDS:
            return
        super(OvertCommandListener, self).started(event)

    def succeeded(self, event):
        if event.command_name.lower() in _SENSITIVE_COMMANDS:
            return
        super(OvertCommandListener, self).succeeded(event)

    def failed(self, event):
        if event.command_name.lower() in _SENSITIVE_COMMANDS:
            return
        super(OvertCommandListener, self).failed(event)
203
204
class _ServerEventListener(object):
    """Listens to all events.

    opened, description_changed and closed events all land, in arrival
    order, in the single ``results`` list.
    """

    def __init__(self):
        self.results = []

    def opened(self, event):
        self.results.append(event)

    def description_changed(self, event):
        self.results.append(event)

    def closed(self, event):
        self.results.append(event)

    def matching(self, matcher):
        """Return the recorded events for which ``matcher(event)`` is truthy."""
        # Filter a copy so events appended meanwhile don't affect iteration.
        snapshot = self.results[:]
        return [evt for evt in snapshot if matcher(evt)]

    def reset(self):
        """Forget every recorded event."""
        self.results = []
227
228
class ServerEventListener(_ServerEventListener,
                          monitoring.ServerListener):
    """Listens to Server events.

    Recording (opened, description_changed, closed -> ``results``) comes from
    _ServerEventListener; ServerListener supplies the monitoring interface.
    """
232
233
class ServerAndTopologyEventListener(ServerEventListener,
                                     monitoring.TopologyListener):
    """Listens to Server and Topology events.

    Both kinds of event share the opened/description_changed/closed methods
    inherited from _ServerEventListener, so they are interleaved in the one
    ``results`` list.
    """
237
238
class HeartbeatEventListener(BaseListener, monitoring.ServerHeartbeatListener):
    """Listens to only server heartbeat events.

    started, succeeded and failed heartbeats are all recorded, in order,
    through BaseListener.add_event.
    """

    def failed(self, event):
        self.add_event(event)

    def started(self, event):
        self.add_event(event)

    def succeeded(self, event):
        self.add_event(event)
250
251
class MockSocketInfo(object):
    """Minimal stand-in for a pooled connection in tests.

    It carries a fresh _CancellationContext, has ``more_to_come`` set to
    False, and acts as a no-op context manager returning itself.
    """
    def __init__(self):
        self.more_to_come = False
        self.cancel_context = _CancellationContext()

    def close_socket(self, reason):
        """Do nothing; there is no real socket to close."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Return None so that exceptions raised in the body propagate."""
265
266
class MockPool(object):
    """Stand-in for a connection pool that tracks generations but opens nothing.

    ``address``, ``options`` and ``handshake`` are accepted for signature
    compatibility and ignored.
    """
    def __init__(self, address, options, handshake=True):
        self.gen = _PoolGeneration()
        self._lock = threading.Lock()
        self.opts = PoolOptions()

    def stale_generation(self, gen, service_id):
        """Delegate to ``self.gen.stale(gen, service_id)``."""
        return self.gen.stale(gen, service_id)

    def get_socket(self, all_credentials, handler=None):
        """Return a new MockSocketInfo; credentials and handler are ignored."""
        return MockSocketInfo()

    def return_socket(self, *args, **kwargs):
        """Accept and discard a returned socket."""

    def _reset(self, service_id=None):
        """Increment the generation for ``service_id`` while holding the lock."""
        with self._lock:
            self.gen.inc(service_id)

    def reset(self, service_id=None):
        """Reset the pool generation for ``service_id``.

        ``service_id`` is forwarded to _PoolGeneration.inc, so only that
        service's generation is affected when one is given.
        """
        # NOTE(review): presumably inc(None) bumps every generation, which is
        # what happened for all resets before service_id was forwarded.
        self._reset(service_id=service_id)

    def close(self):
        """Reset every generation (no service_id)."""
        self._reset()

    def update_is_writable(self, is_writable):
        """Ignore writability changes."""

    def remove_stale_sockets(self, *args, **kwargs):
        """Nothing to remove; the mock holds no sockets."""
297
298
class ScenarioDict(dict):
    """Dict that returns {} for any unknown key, recursively.

    Nested mappings become ScenarioDicts, and non-string sequences become
    lists whose elements are converted the same way.
    """
    def __init__(self, data):
        def _wrap(value):
            if isinstance(value, abc.Mapping):
                return ScenarioDict(value)
            # Strings and bytes are Sequences too; leave them untouched.
            is_text = isinstance(value, (py3compat.string_type, bytes))
            if not is_text and isinstance(value, abc.Sequence):
                return [_wrap(elem) for elem in value]
            return value

        pairs = [(key, _wrap(value)) for key, value in data.items()]
        dict.__init__(self, pairs)

    def __getitem__(self, item):
        if item not in self:
            # Unlike a defaultdict, don't store the key; just hand back {}.
            return ScenarioDict({})
        return dict.__getitem__(self, item)
319
320
class CompareType(object):
    """Compares equal to any instance of ``type`` (handy inside assertEqual)."""
    def __init__(self, type):
        self.type = type

    def __eq__(self, other):
        expected = self.type
        return isinstance(other, expected)

    def __ne__(self, other):
        """Python 2 doesn't derive != from ==, so define it explicitly."""
        equal = self.__eq__(other)
        return not equal
332
333
class FunctionCallRecorder(object):
    """Wrap a callable and remember the arguments of every invocation."""
    def __init__(self, function):
        self._call_list = []
        self._function = function

    def __call__(self, *args, **kwargs):
        """Record ``(args, kwargs)``, then delegate to the wrapped callable."""
        record = (args, kwargs)
        self._call_list.append(record)
        return self._function(*args, **kwargs)

    def reset(self):
        """Wipes the call list."""
        self._call_list = []

    def call_list(self):
        """Return a shallow copy of the recorded ``(args, kwargs)`` pairs."""
        return list(self._call_list)

    @property
    def call_count(self):
        """Number of times the wrapped function has been called."""
        return len(self._call_list)
356
357
class TestCreator(object):
    """Class to create test cases from specifications."""
    def __init__(self, create_test, test_class, test_path):
        """Create a TestCreator object.

        :Parameters:
          - `create_test`: callback that returns a test case. The callback
            must accept the following arguments - a dictionary containing the
            entire test specification (the `scenario_def`), a dictionary
            containing the specification for which the test case will be
            generated (the `test_def`), and the name the generated test
            method will be given (the `test_name`).
          - `test_class`: the unittest.TestCase class in which to create the
            test case.
          - `test_path`: path to the directory containing the JSON files with
            the test specifications.
        """
        self._create_test = create_test
        self._test_class = test_class
        self.test_path = test_path

    @staticmethod
    def _parse_version(version):
        """Turn a dotted version string such as '4.2.1' into (4, 2, 1)."""
        return tuple(int(part) for part in version.split('.'))

    def _ensure_min_max_server_version(self, scenario_def, method):
        """Test modifier that enforces a version range for the server on a
        test case.

        Wraps ``method`` with client_context.require_version_min/max for the
        scenario's top-level ``minServerVersion`` / ``maxServerVersion``.
        """
        if 'minServerVersion' in scenario_def:
            min_ver = self._parse_version(scenario_def['minServerVersion'])
            method = client_context.require_version_min(*min_ver)(method)

        if 'maxServerVersion' in scenario_def:
            max_ver = self._parse_version(scenario_def['maxServerVersion'])
            method = client_context.require_version_max(*max_ver)(method)

        return method

    @staticmethod
    def valid_topology(run_on_req):
        """Return whether client_context's topology is allowed by the
        requirement's ``topology`` list (every topology when absent)."""
        return client_context.is_topology_type(
            run_on_req.get('topology', ['single', 'replicaset', 'sharded',
                                        'load-balanced']))

    @staticmethod
    def min_server_version(run_on_req):
        """Return whether the server satisfies ``minServerVersion`` (if any)."""
        version = run_on_req.get('minServerVersion')
        if version:
            min_ver = TestCreator._parse_version(version)
            return client_context.version >= min_ver
        return True

    @staticmethod
    def max_server_version(run_on_req):
        """Return whether the server satisfies ``maxServerVersion`` (if any)."""
        version = run_on_req.get('maxServerVersion')
        if version:
            max_ver = TestCreator._parse_version(version)
            return client_context.version <= max_ver
        return True

    def should_run_on(self, scenario_def):
        """Return True if there is no ``runOn`` list, or if any one of its
        requirements matches the topology and server version."""
        run_on = scenario_def.get('runOn', [])
        if not run_on:
            # Always run these tests.
            return True

        for req in run_on:
            if (self.valid_topology(req) and
                    self.min_server_version(req) and
                    self.max_server_version(req)):
                return True
        return False

    def ensure_run_on(self, scenario_def, method):
        """Test modifier that enforces a 'runOn' on a test case."""
        return client_context._require(
            lambda: self.should_run_on(scenario_def),
            "runOn not satisfied",
            method)

    def tests(self, scenario_def):
        """Allow CMAP spec test to override the location of test."""
        return scenario_def['tests']

    def create_tests(self):
        """Generate test methods from every spec file under ``test_path``.

        Every file found (regardless of extension) is parsed as extended JSON.
        For each test in it, a method named
        ``test_<dirname>_<file stem>_<description>`` is set on ``test_class``
        (file stem: '-' and '.' -> '_'; description: ' ' and '.' -> '_'),
        wrapped with the scenario's server-version and runOn requirements.
        """
        for dirpath, _, filenames in os.walk(self.test_path):
            dirname = os.path.split(dirpath)[-1]

            for filename in filenames:
                with open(os.path.join(dirpath, filename)) as scenario_stream:
                    # Use tz_aware=False to match how CodecOptions decodes
                    # dates.
                    opts = json_util.JSONOptions(tz_aware=False)
                    scenario_def = ScenarioDict(
                        json_util.loads(scenario_stream.read(),
                                        json_options=opts))

                test_type = os.path.splitext(filename)[0]

                # Construct test from scenario.
                for test_def in self.tests(scenario_def):
                    test_name = 'test_%s_%s_%s' % (
                        dirname,
                        test_type.replace("-", "_").replace('.', '_'),
                        str(test_def['description'].replace(" ", "_").replace(
                            '.', '_')))

                    new_test = self._create_test(
                        scenario_def, test_def, test_name)
                    new_test = self._ensure_min_max_server_version(
                        scenario_def, new_test)
                    new_test = self.ensure_run_on(
                        scenario_def, new_test)

                    new_test.__name__ = test_name
                    setattr(self._test_class, new_test.__name__, new_test)
475
476
def _connection_string(h, authenticate):
    """Build a MongoDB URI for the host string ``h``.

    ``h`` is returned unchanged if it already starts with ``mongodb://``.
    Otherwise db_user/db_pwd are embedded only when the test deployment has
    auth enabled and ``authenticate`` is true.
    """
    if h.startswith("mongodb://"):
        return h
    if not (client_context.auth_enabled and authenticate):
        return "mongodb://%s" % (str(h),)
    return "mongodb://%s:%s@%s" % (db_user, db_pwd, str(h))
484
485
def _mongo_client(host, port, authenticate=True, directConnection=False,
                  **kwargs):
    """Create a new client over SSL/TLS if necessary.

    ``host``/``port`` fall back to client_context's. Options start from a copy
    of client_context.default_client_options; ``replicaSet`` is added when the
    deployment is a replica set and ``directConnection`` is false, and
    ``kwargs`` override everything.
    """
    if not host:
        host = client_context.host
    if not port:
        port = client_context.port
    options = client_context.default_client_options.copy()
    if client_context.replica_set_name and not directConnection:
        options['replicaSet'] = client_context.replica_set_name
    options.update(kwargs)

    uri = _connection_string(host, authenticate)
    return MongoClient(uri, port, **options)
500
501
def single_client_noauth(h=None, p=None, **kwargs):
    """Make a direct connection. Don't authenticate.

    Extra keyword arguments are passed on as MongoClient options.
    """
    return _mongo_client(h, p, directConnection=True, authenticate=False,
                         **kwargs)
506
507
def single_client(h=None, p=None, **kwargs):
    """Make a direct connection, and authenticate if necessary.

    Extra keyword arguments are passed on as MongoClient options.
    """
    client = _mongo_client(h, p, directConnection=True, **kwargs)
    return client
511
512
def rs_client_noauth(h=None, p=None, **kwargs):
    """Connect to the replica set. Don't authenticate.

    Extra keyword arguments are passed on as MongoClient options.
    """
    client = _mongo_client(h, p, authenticate=False, **kwargs)
    return client
516
517
def rs_client(h=None, p=None, **kwargs):
    """Connect to the replica set and authenticate if necessary.

    Extra keyword arguments are passed on as MongoClient options.
    """
    client = _mongo_client(h, p, **kwargs)
    return client
521
522
def rs_or_single_client_noauth(h=None, p=None, **kwargs):
    """Connect to the replica set if there is one, otherwise the standalone.

    Like rs_or_single_client, but does not authenticate. Extra keyword
    arguments are passed on as MongoClient options.
    """
    client = _mongo_client(h, p, authenticate=False, **kwargs)
    return client
529
530
def rs_or_single_client(h=None, p=None, **kwargs):
    """Connect to the replica set if there is one, otherwise the standalone.

    Authenticates if necessary. Extra keyword arguments are passed on as
    MongoClient options.
    """
    client = _mongo_client(h, p, **kwargs)
    return client
537
538
def ensure_all_connected(client, timeout=60):
    """Ensure that the client's connection pool has socket connections to all
    members of a replica set. Raises ConfigurationError when called with a
    non-replica set client.

    Depending on the use-case, the caller may need to clear any event listeners
    that are configured on the client.

    :Parameters:
      - `client`: a MongoClient connected to a replica set.
      - `timeout` (optional): seconds to keep issuing secondary-preferred
        legacy hello commands before giving up with an AssertionError.
    """
    hello = client.admin.command(HelloCompat.LEGACY_CMD)
    if 'setName' not in hello:
        raise ConfigurationError("cluster is not a replica set")

    target_host_list = set(hello['hosts'])
    connected_host_list = set([hello['me']])
    admindb = client.get_database('admin')

    deadline = time.time() + timeout
    # Run legacy hello until we have connected to each host at least once.
    # A subset test (not equality) means an unexpected 'me' value can't keep
    # the loop running forever, and the deadline bounds unreachable members.
    while not target_host_list.issubset(connected_host_list):
        if time.time() > deadline:
            raise AssertionError(
                "Failed to connect to all replica set members within %s "
                "seconds: connected to %s, expected %s" % (
                    timeout, sorted(connected_host_list),
                    sorted(target_host_list)))
        hello = admindb.command(HelloCompat.LEGACY_CMD,
                                read_preference=ReadPreference.SECONDARY)
        connected_host_list.update([hello["me"]])
560
561
def one(s):
    """Get one element of a set (or any iterable).

    Returns the first element the iterator yields; raises StopIteration if
    ``s`` is empty.
    """
    iterator = iter(s)
    return next(iterator)
565
566
def oid_generated_on_process(oid):
    """Return whether ``oid`` was generated by the current process.

    Compares bytes 4..8 of the ObjectId (its 5-byte random value) with
    ObjectId._random().
    """
    process_random = ObjectId._random()
    return oid.binary[4:9] == process_random
572
573
def delay(sec):
    """Return JavaScript source for a function that sleeps ``sec`` seconds.

    The function returns true after sleeping; ``sec`` is rendered with %f.
    """
    template = '''function() { sleep(%f * 1000); return true; }'''
    return template % sec
576
577
def get_command_line(client):
    """Return the server's ``getCmdLineOpts`` reply.

    :Raises:
      AssertionError if the reply's ``ok`` field is not 1.
    """
    command_line = client.admin.command('getCmdLineOpts')
    # Raise explicitly rather than with ``assert`` so the check is not
    # stripped when Python runs with -O.
    if command_line['ok'] != 1:
        raise AssertionError("getCmdLineOpts() failed")
    return command_line
582
583
def camel_to_snake(camel):
    """Convert CamelCase / lowerCamelCase to snake_case.

    E.g. 'maxTimeMS' -> 'max_time_ms', 'HTTPResponse' -> 'http_response'.
    """
    # First split before each Capitalised word, then between a lowercase
    # letter/digit and the capital that follows it (handles acronyms).
    words_split = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel)
    fully_split = re.sub('([a-z0-9])([A-Z])', r'\1_\2', words_split)
    return fully_split.lower()
588
589
def camel_to_upper_camel(camel):
    """Capitalise the first character: 'lowerCamel' -> 'LowerCamel'.

    An empty string is returned unchanged (slicing avoids the IndexError
    that ``camel[0]`` would raise).
    """
    return camel[:1].upper() + camel[1:]
592
593
def camel_to_snake_args(arguments):
    """Rename every key of ``arguments`` in place from camelCase to snake_case.

    Returns the same, mutated mapping for convenience.
    """
    # Snapshot the keys first because the mapping is modified in the loop.
    original_names = list(arguments)
    for name in original_names:
        snake_name = camel_to_snake(name)
        value = arguments.pop(name)
        arguments[snake_name] = value
    return arguments
599
600
def snake_to_camel(snake):
    """Convert snake_case to lowerCamelCase, e.g. 'max_time_ms' -> 'maxTimeMs'.

    Only an underscore followed by a lowercase ASCII letter is folded; other
    underscores are kept.
    """
    def _capitalise(match):
        return match.group(1).upper()
    return re.sub(r'_([a-z])', _capitalise, snake)
604
605
def parse_collection_options(opts):
    """Convert spec-test collection options in place to PyMongo kwargs.

    readPreference -> read_preference, writeConcern -> write_concern and
    readConcern -> read_concern, each turned into the matching PyMongo
    object. Other keys are left alone; the same dict is returned.
    """
    # Lambdas keep name lookups lazy, so absent keys touch nothing.
    conversions = (
        ('readPreference', 'read_preference',
         lambda doc: parse_read_preference(doc)),
        ('writeConcern', 'write_concern',
         lambda doc: WriteConcern(**dict(doc))),
        ('readConcern', 'read_concern',
         lambda doc: ReadConcern(**dict(doc))),
    )
    for spec_name, python_name, convert in conversions:
        if spec_name in opts:
            opts[python_name] = convert(opts.pop(spec_name))
    return opts
619
620
def server_started_with_option(client, cmdline_opt, config_opt):
    """Check if the server was started with a particular option.

    The ``parsed`` section of getCmdLineOpts is consulted first; if it lacks
    ``config_opt``, the raw ``argv`` list is searched for ``cmdline_opt``.

    :Parameters:
      - `client`: a MongoClient connected to the server to inspect.
      - `cmdline_opt`: The command line option (e.g. --nojournal)
      - `config_opt`: The config file option (e.g. nojournal)
    """
    opts = get_command_line(client)
    if 'parsed' in opts and config_opt in opts['parsed']:
        return opts['parsed'][config_opt]
    return cmdline_opt in opts['argv']
635
636
def server_started_with_auth(client):
    """Return True if the server appears to have authentication enabled.

    Being refused getCmdLineOpts (code 13, or an errmsg mentioning
    'unauthorized' / 'login') counts as proof that auth is on. Otherwise the
    parsed options are inspected, falling back to --auth / --keyFile in argv.
    """
    try:
        command_line = get_command_line(client)
    except OperationFailure as e:
        # Guard against ``details`` being None (presumably possible when the
        # error wasn't built from a server reply) instead of crashing with
        # AttributeError.
        msg = (e.details or {}).get('errmsg', '')
        if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
            # Unauthorized.
            return True
        raise

    # MongoDB >= 2.0
    if 'parsed' in command_line:
        parsed = command_line['parsed']
        # MongoDB >= 2.6
        if 'security' in parsed:
            security = parsed['security']
            # >= rc3
            if 'authorization' in security:
                return security['authorization'] == 'enabled'
            # < rc3
            return security.get('auth', False) or bool(security.get('keyFile'))
        return parsed.get('auth', False) or bool(parsed.get('keyFile'))
    # Legacy
    argv = command_line['argv']
    return '--auth' in argv or '--keyFile' in argv
662
663
def server_started_with_nojournal(client):
    """Return True if the server was started with journaling disabled.

    Uses ``parsed.storage.journal.enabled`` from getCmdLineOpts when present
    (MongoDB 2.6 layout); otherwise defers to server_started_with_option with
    --nojournal / nojournal.
    """
    command_line = get_command_line(client)

    parsed = command_line['parsed'] if 'parsed' in command_line else {}
    storage = parsed['storage'] if 'storage' in parsed else {}
    if 'journal' in storage:
        return not storage['journal']['enabled']

    return server_started_with_option(client, '--nojournal', 'nojournal')
676
677
def drop_collections(db):
    """Drop every collection in ``db`` whose name doesn't start with 'system.'."""
    # Negative look-ahead: match any name that does not begin with "system.".
    non_system = {"name": {"$regex": r"^(?!system\.)"}}
    for name in db.list_collection_names(filter=non_system):
        db.drop_collection(name)
683
684
def remove_all_users(db):
    """Run dropAllUsersFromDatabase on ``db`` with w=client_context.w."""
    write_concern = {"w": client_context.w}
    db.command("dropAllUsersFromDatabase", 1, writeConcern=write_concern)
688
689
def joinall(threads):
    """Join threads with a 5-minute timeout each; assert that none hung."""
    five_minutes = 300
    for thread in threads:
        thread.join(five_minutes)
        assert not thread.is_alive(), "Thread %s hung" % thread
695
696
def connected(client):
    """Convenience to wait for a newly-constructed client to connect.

    Sends ``ping`` to the admin database (forcing a connection) and returns
    ``client`` itself.
    """
    with warnings.catch_warnings():
        # "ping" always goes to the primary even if the client's read
        # preference isn't PRIMARY; silence the UserWarning about that.
        warnings.simplefilter("ignore", UserWarning)
        admin = client.admin
        admin.command('ping')
    return client
706
707
def wait_until(predicate, success_description, timeout=10):
    """Wait up to ``timeout`` seconds (10 by default) for predicate to be true.

    E.g.:

        wait_until(lambda: client.primary == ('a', 1),
                   'connect to the primary')

    If the lambda-expression isn't true after ``timeout`` seconds, we raise
    AssertionError("Didn't ever connect to the primary").

    Returns the predicate's first true value.
    """
    # Prefer a monotonic clock so wall-clock adjustments (e.g. NTP steps)
    # can't cut the wait short or stretch it; Python 2 has no time.monotonic.
    clock = getattr(time, 'monotonic', time.time)
    start = clock()
    # Poll about 100 times over the timeout, but at least every 0.1s.
    interval = min(float(timeout)/100, 0.1)
    while True:
        retval = predicate()
        if retval:
            return retval

        if clock() - start > timeout:
            raise AssertionError("Didn't ever %s" % success_description)

        time.sleep(interval)
732
733
def repl_set_step_down(client, **kwargs):
    """Run replSetStepDown, first unfreezing a secondary with replSetFreeze.

    ``kwargs`` are appended to the replSetStepDown command document.
    """
    step_down = SON([('replSetStepDown', 1)])
    step_down.update(kwargs)

    # Unfreeze a secondary (replSetFreeze: 0) to ensure a speedy election.
    client.admin.command('replSetFreeze', 0,
                         read_preference=ReadPreference.SECONDARY)
    client.admin.command(step_down)
743
def is_mongos(client):
    """Return True if the legacy hello reply's ``msg`` is 'isdbgrid'."""
    reply = client.admin.command(HelloCompat.LEGACY_CMD)
    return 'isdbgrid' == reply.get('msg', '')
747
748
def assertRaisesExactly(cls, fn, *args, **kwargs):
    """
    Unlike the standard assertRaises, this checks that a function raises a
    specific class of exception, and not a subclass. E.g., check that
    MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect.

    :Raises:
      AssertionError if ``fn(*args, **kwargs)`` raises nothing, or raises an
      exception whose class is not exactly ``cls``.
    """
    try:
        fn(*args, **kwargs)
    except Exception as e:
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if e.__class__ != cls:
            raise AssertionError("got %s, expected %s" % (
                e.__class__.__name__, cls.__name__))
    else:
        raise AssertionError("%s not raised" % cls)
762
763
@contextlib.contextmanager
def _ignore_deprecations():
    """Silence DeprecationWarning for the body of the ``with`` statement.

    catch_warnings restores the previous warning filters on exit.
    """
    saved_state = warnings.catch_warnings()
    with saved_state:
        warnings.simplefilter("ignore", DeprecationWarning)
        yield
769
770
def ignore_deprecations(wrapped=None):
    """A context manager or a decorator that silences DeprecationWarning.

    ``@ignore_deprecations`` makes every call of the decorated function run
    with deprecations ignored; ``with ignore_deprecations():`` ignores them
    for the body of the ``with`` statement.
    """
    if not wrapped:
        return _ignore_deprecations()

    @functools.wraps(wrapped)
    def wrapper(*args, **kwargs):
        with _ignore_deprecations():
            return wrapped(*args, **kwargs)

    return wrapper
783
784
class DeprecationFilter(object):
    """Apply a DeprecationWarning filter from construction until stop().

    The previous warning filters are restored by stop(); calling stop() a
    second time fails because ``warn_context`` is then None.
    """

    def __init__(self, action="ignore"):
        """Start filtering deprecations with the warnings ``action`` given."""
        self.warn_context = warnings.catch_warnings()
        self.warn_context.__enter__()
        warnings.simplefilter(action, DeprecationWarning)

    def stop(self):
        """Stop filtering deprecations."""
        context = self.warn_context
        context.__exit__()
        self.warn_context = None
797
798
def get_pool(client):
    """Get the standalone, primary, or mongos pool.

    That is, the pool of the server chosen by writable_server_selector.
    """
    selected = client._get_topology().select_server(writable_server_selector)
    return selected.pool
804
805
def get_pools(client):
    """Get all pools (one per server matched by any_server_selector)."""
    servers = client._get_topology().select_servers(any_server_selector)
    return [srv.pool for srv in servers]
811
812
# Constants for run_threads and lazy_client_trial.
NTRIALS = 5  # fresh lazy clients exercised per lazy_client_trial call
NTHREADS = 10  # concurrent threads started by each run_threads call
816
817
def run_threads(collection, target):
    """Run a target function in many threads.

    target is a function taking a Collection and an integer; it is called as
    ``target(collection, i)`` for each i in range(NTHREADS). All threads are
    started before any is joined, each join waits up to 60 seconds, and an
    AssertionError is raised if a thread is still alive afterwards.
    """
    threads = [threading.Thread(target=partial(target, collection, index))
               for index in range(NTHREADS)]

    for thread in threads:
        thread.start()

    for thread in threads:
        thread.join(60)
        assert not thread.is_alive()
834
835
@contextlib.contextmanager
def frequent_thread_switches():
    """Make concurrency bugs more likely to manifest.

    Temporarily sets the switch interval to 1e-6 seconds (or, on Pythons
    without it, the check interval to 1) and restores the previous value on
    exit. Nothing is changed on Jython.
    """
    on_jython = sys.platform.startswith('java')
    saved = None
    if not on_jython:
        if hasattr(sys, 'getswitchinterval'):
            saved = sys.getswitchinterval()
            sys.setswitchinterval(1e-6)
        else:
            saved = sys.getcheckinterval()
            sys.setcheckinterval(1)

    try:
        yield
    finally:
        if not on_jython:
            if hasattr(sys, 'setswitchinterval'):
                restore = sys.setswitchinterval
            else:
                restore = sys.setcheckinterval
            restore(saved)
856
857
def lazy_client_trial(reset, target, test, get_client):
    """Test concurrent operations on a lazily-connecting client.

    `reset` takes a collection and resets it for the next trial.

    `target` takes a lazily-connecting collection and an index from
    0 to NTHREADS, and performs some operation, e.g. an insert.

    `test` takes the lazily-connecting collection and asserts a
    post-condition to prove `target` succeeded.

    `get_client` returns a new client; it is called once per trial, and
    NTRIALS trials run under frequent_thread_switches().
    """
    collection = client_context.client.pymongo_test.test

    with frequent_thread_switches():
        for _ in range(NTRIALS):
            reset(collection)
            lazy_collection = get_client().pymongo_test.test
            run_threads(lazy_collection, target)
            test(lazy_collection)
878
879
def gevent_monkey_patched():
    """Check if gevent's monkey patching is active.

    True when ``socket.socket`` is gevent's socket class; False when it isn't
    or when gevent cannot be imported.
    """
    # In Python 3.6 importing gevent.socket raises an ImportWarning.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ImportWarning)
        try:
            import socket
            import gevent.socket
        except ImportError:
            return False
        return socket.socket is gevent.socket.socket
891
892
def eventlet_monkey_patched():
    """Check if eventlet's monkey patching is active.

    True when ``threading.current_thread`` comes from eventlet's green
    threading module; False otherwise or when eventlet isn't importable.
    """
    try:
        import threading
        import eventlet  # noqa: F401 -- only checks eventlet is installed
    except ImportError:
        return False
    patched_module = 'eventlet.green.threading'
    return threading.current_thread.__module__ == patched_module
902
903
def is_greenthread_patched():
    """Return True if either gevent or eventlet monkey patching is active."""
    if gevent_monkey_patched():
        return True
    return eventlet_monkey_patched()
906
907
def cdecimal_patched():
    """Check if Python 2.7 cdecimal patching is active.

    True only when the imported ``decimal`` module *is* ``cdecimal``; False
    otherwise or when either import fails.
    """
    try:
        import decimal
        import cdecimal
    except ImportError:
        return False
    return decimal is cdecimal
916
917
def disable_replication(client):
    """Disable replication on all secondaries, requires MongoDB 3.2.

    Turns the ``stopReplProducer`` fail point ``alwaysOn`` on every member in
    ``client.secondaries`` through a direct connection. Each connection is
    closed afterwards so that one client per call is not leaked.
    """
    for host, port in client.secondaries:
        secondary = single_client(host, port)
        try:
            secondary.admin.command('configureFailPoint', 'stopReplProducer',
                                    mode='alwaysOn')
        finally:
            secondary.close()
924
925
def enable_replication(client):
    """Enable replication on all secondaries, requires MongoDB 3.2.

    Turns the ``stopReplProducer`` fail point ``off`` on every member in
    ``client.secondaries`` through a direct connection. Each connection is
    closed afterwards so that one client per call is not leaked.
    """
    for host, port in client.secondaries:
        secondary = single_client(host, port)
        try:
            secondary.admin.command('configureFailPoint', 'stopReplProducer',
                                    mode='off')
        finally:
            secondary.close()
932
933
class ExceptionCatchingThread(threading.Thread):
    """A thread that stores any exception encountered from run().

    The exception is kept in ``self.exc`` (None if run() succeeded) and is
    still re-raised, so normal thread error reporting happens as well.
    """
    def __init__(self, *args, **kwargs):
        self.exc = None
        super(ExceptionCatchingThread, self).__init__(*args, **kwargs)

    def run(self):
        parent_run = super(ExceptionCatchingThread, self).run
        try:
            parent_run()
        except BaseException as error:
            self.exc = error
            raise
946
947
def parse_read_preference(pref):
    """Build a PyMongo read preference from a spec-test document ``pref``.

    ``mode`` defaults to 'primary' and may be capitalised (e.g. 'Secondary');
    ``maxStalenessSeconds`` defaults to -1 and ``tag_sets`` to None.
    """
    raw_mode = pref.get('mode', 'primary')
    # Make first letter lowercase to match read_pref's modes.
    mode = read_preferences.read_pref_mode_from_name(
        raw_mode[:1].lower() + raw_mode[1:])
    max_staleness = pref.get('maxStalenessSeconds', -1)
    return read_preferences.make_read_preference(
        mode,
        tag_sets=pref.get('tag_sets'),
        max_staleness=max_staleness)
957
958
def server_name_to_type(name):
    """Map a ServerType name to its SERVER_TYPE value. For SDAM tests."""
    # Some spec tests use PossiblePrimary, which only single-threaded
    # drivers need; this driver reports such servers as Unknown.
    if name != 'PossiblePrimary':
        return getattr(SERVER_TYPE, name)
    return SERVER_TYPE.Unknown
967
968
def cat_files(dest, *sources):
    """Write the bytes of each file in ``sources``, in order, to ``dest``.

    ``dest`` is opened in ``'wb'`` mode, so any existing content is
    replaced. Calling with no sources leaves ``dest`` empty.
    """
    with open(dest, 'wb') as sink:
        for path in sources:
            with open(path, 'rb') as source_file:
                # Stream the copy rather than reading the whole file at once.
                shutil.copyfileobj(source_file, sink)
975
976
@contextlib.contextmanager
def assertion_context(msg):
    """Context manager that appends ``msg`` to any AssertionError raised inside.

    The failure is re-raised through ``py3compat.reraise`` as the same
    exception type, with message ``'<original> (<msg>)'`` and the current
    traceback.
    """
    try:
        yield
    except AssertionError as failure:
        annotated = '%s (%s)' % (failure, msg)
        trace = sys.exc_info()[2]
        py3compat.reraise(type(failure), annotated, trace)
985
986
def _normalize_hint(hint):
    """Return ``hint`` in the form PyMongo accepts.

    String hints are returned unchanged. Any other hint is converted to a
    list of its ``(key, value)`` items.
    """
    if isinstance(hint, string_type):
        return hint
    return list(iteritems(hint))


def parse_spec_options(opts):
    """Convert spec-test option names and values to PyMongo keyword arguments.

    ``opts`` is modified in place. Each recognized camelCase key is popped
    and stored under its snake_case name, converting the value where needed.
    Any ``hint`` inside the bulk ``requests`` entries is normalized in place.
    Other keys are left unchanged.

    :Returns: a new ``dict`` copy of the updated ``opts``.
    """
    if 'readPreference' in opts:
        opts['read_preference'] = parse_read_preference(
            opts.pop('readPreference'))

    if 'writeConcern' in opts:
        opts['write_concern'] = WriteConcern(
            **dict(opts.pop('writeConcern')))

    if 'readConcern' in opts:
        opts['read_concern'] = ReadConcern(
            **dict(opts.pop('readConcern')))

    if 'maxTimeMS' in opts:
        opts['max_time_ms'] = opts.pop('maxTimeMS')

    if 'maxCommitTimeMS' in opts:
        opts['max_commit_time_ms'] = opts.pop('maxCommitTimeMS')

    if 'hint' in opts:
        opts['hint'] = _normalize_hint(opts.pop('hint'))

    # Properly format 'hint' arguments for the Bulk API tests.
    if 'requests' in opts:
        reqs = opts.pop('requests')
        for req in reqs:
            if 'name' in req:
                # CRUD v2 format: {"name": ..., "arguments": {...}}.
                # A missing "arguments" entry becomes an empty dict.
                args = req.pop('arguments', {})
                if 'hint' in args:
                    args['hint'] = _normalize_hint(args.pop('hint'))
                req['arguments'] = args
            else:
                # Unified test format: {"<modelName>": {...}}; only the
                # first entry's value is inspected.
                _, model_spec = next(iteritems(req))
                if 'hint' in model_spec:
                    model_spec['hint'] = _normalize_hint(
                        model_spec.pop('hint'))
        opts['requests'] = reqs

    return dict(opts)
1036
1037
def prepare_spec_arguments(spec, arguments, opname, entity_map,
                           with_txn_callback):
    """Rewrite a spec operation's ``arguments`` in place for a PyMongo call.

    Argument names are generally converted from camelCase to snake_case.
    The branches below handle arguments that need a different name, need
    their value converted, or must stay camelCase.

    :Parameters:
      - `spec`: the operation document. Only read for ``command`` and
        ``run_admin_command`` operations, where ``spec['command_name']``
        becomes the first key of the command.
      - `arguments`: dict of the operation's arguments; mutated in place.
      - `opname`: snake_case name of the method being invoked.
      - `entity_map`: mapping used to resolve a ``session`` argument to the
        object stored under that value.
      - `with_txn_callback`: called with a deep copy of the callback's
        operations each time a ``with_transaction`` callback runs.
    """
    # Iterate over a snapshot of the keys because the loop renames them.
    for arg_name in list(arguments):
        c2s = camel_to_snake(arg_name)
        # PyMongo accepts sort as list of tuples. Deliberately not part of
        # the elif chain below: "sort" still falls through to the generic
        # rename in the final else branch.
        if arg_name == "sort":
            sort_dict = arguments[arg_name]
            arguments[arg_name] = list(iteritems(sort_dict))
        # PyMongo names this argument "key" rather than "fieldName".
        if arg_name == "fieldName":
            arguments["key"] = arguments.pop(arg_name)
        # aggregate takes these as camelCase kwargs (while find uses
        # batch_size), so leave them as-is.
        elif (arg_name in ("batchSize", "allowDiskUse") and
              opname == "aggregate"):
            continue
        # The spec gives returnDocument as a string naming a ReturnDocument
        # attribute; look it up case-insensitively via upper().
        elif arg_name == "returnDocument":
            arguments[c2s] = getattr(ReturnDocument,
                                     arguments.pop(arg_name).upper())
        elif c2s == "requests":
            # Parse each request into a bulk write model.
            requests = []
            for request in arguments["requests"]:
                if 'name' in request:
                    # CRUD v2 format: {"name": ..., "arguments": {...}}
                    bulk_model = camel_to_upper_camel(request["name"])
                    bulk_class = getattr(operations, bulk_model)
                    bulk_arguments = camel_to_snake_args(request["arguments"])
                else:
                    # Unified test format: {"<modelName>": {...}}. Bind the
                    # model's document to its own name so the ``spec``
                    # parameter (the operation document) is not clobbered.
                    bulk_model, model_spec = next(iteritems(request))
                    bulk_class = getattr(operations,
                                         camel_to_upper_camel(bulk_model))
                    bulk_arguments = camel_to_snake_args(model_spec)
                requests.append(bulk_class(**dict(bulk_arguments)))
            arguments["requests"] = requests
        elif arg_name == "session":
            # Resolve the session reference through entity_map.
            arguments['session'] = entity_map[arguments['session']]
        elif (opname in ('command', 'run_admin_command') and
              arg_name == 'command'):
            # Ensure the first key is the command name.
            ordered_command = SON([(spec['command_name'], 1)])
            ordered_command.update(arguments['command'])
            arguments['command'] = ordered_command
        elif opname == 'open_download_stream' and arg_name == 'id':
            arguments['file_id'] = arguments.pop(arg_name)
        elif opname != 'find' and c2s == 'max_time_ms':
            # find is the only method that accepts snake_case max_time_ms.
            # All other methods take kwargs which must use the server's
            # camelCase maxTimeMS. See PYTHON-1855. Pop by arg_name so this
            # works whether the key arrived as max_time_ms or maxTimeMS.
            arguments['maxTimeMS'] = arguments.pop(arg_name)
        elif opname == 'with_transaction' and arg_name == 'callback':
            if 'operations' in arguments[arg_name]:
                # CRUD v2 format
                callback_ops = arguments[arg_name]['operations']
            else:
                # Unified test format
                callback_ops = arguments[arg_name]
            # Deep-copy on every invocation so the stored operations are not
            # mutated by a run.
            arguments['callback'] = lambda _: with_txn_callback(
                copy.deepcopy(callback_ops))
        elif opname == 'drop_collection' and arg_name == 'collection':
            arguments['name_or_collection'] = arguments.pop(arg_name)
        elif opname == 'create_collection':
            if arg_name == 'collection':
                arguments['name'] = arguments.pop(arg_name)
            # Any other arguments to create_collection are passed through
            # **kwargs.
        elif opname == 'create_index' and arg_name == 'keys':
            arguments['keys'] = list(arguments.pop(arg_name).items())
        elif opname == 'drop_index' and arg_name == 'name':
            arguments['index_or_name'] = arguments.pop(arg_name)
        else:
            arguments[c2s] = arguments.pop(arg_name)
1109