1# Copyright 2010-present MongoDB, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Test suite for pymongo, bson, and gridfs.
16"""
17
18import gc
19import os
20import socket
21import sys
22import threading
23import time
24import traceback
25import unittest
26import warnings
27
28try:
29    from xmlrunner import XMLTestRunner
30    HAVE_XML = True
31# ValueError is raised when version 3+ is installed on Jython 2.7.
32except (ImportError, ValueError):
33    HAVE_XML = False
34
35try:
36    import ipaddress
37    HAVE_IPADDRESS = True
38except ImportError:
39    HAVE_IPADDRESS = False
40
41from contextlib import contextmanager
42from functools import wraps
43from unittest import SkipTest
44
45import pymongo
46import pymongo.errors
47
48from bson.son import SON
49from pymongo import common, message
50from pymongo.common import partition_node
51from pymongo.hello_compat import HelloCompat
52from pymongo.server_api import ServerApi
53from pymongo.ssl_support import HAVE_SSL, _ssl
54from pymongo.uri_parser import parse_uri
55from test.version import Version
56
57if HAVE_SSL:
58    import ssl
59
# Dump the traceback of every running thread if the interpreter segfaults,
# provided the faulthandler module is available.
try:
    import faulthandler
except ImportError:
    pass
else:
    faulthandler.enable()

# Report objects the garbage collector cannot free. PyPy has no
# gc.set_debug, and the DEBUG_OBJECTS / DEBUG_INSTANCES flags are absent on
# some interpreters, in which case they contribute nothing (0) to the mask.
if hasattr(gc, 'set_debug'):
    gc.set_debug(gc.DEBUG_UNCOLLECTABLE
                 | getattr(gc, 'DEBUG_OBJECTS', 0)
                 | getattr(gc, 'DEBUG_INSTANCES', 0))
74
# Address of the deployment under test: a single mongod or mongos, or the
# seed host of a replica set. Override with DB_IP / DB_PORT.
host = os.environ.get("DB_IP", 'localhost')
port = int(os.environ.get("DB_PORT", 27017))

# Credentials used when the server requires authentication.
db_user = os.environ.get("DB_USER", "user")
db_pwd = os.environ.get("DB_PASSWORD", "password")

# TLS material defaults to the certificates directory next to this file.
# CLIENT_PEM / CA_PEM override the paths; an empty value drops the option.
CERT_PATH = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), 'certificates')
CLIENT_PEM = os.environ.get(
    'CLIENT_PEM', os.path.join(CERT_PATH, 'client.pem'))
CA_PEM = os.environ.get(
    'CA_PEM', os.path.join(CERT_PATH, 'ca.pem'))

# Extra MongoClient keyword arguments used when connecting over TLS.
TLS_OPTIONS = {'tls': True}
TLS_OPTIONS.update(
    (option, path)
    for option, path in (('tlsCertificateKeyFile', CLIENT_PEM),
                         ('tlsCAFile', CA_PEM))
    if path)

# Optional test-matrix settings supplied through the environment.
COMPRESSORS = os.environ.get("COMPRESSORS")
MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION")
TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER"))
SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI")
MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI")

if TEST_LOADBALANCER:
    # Remove after PYTHON-2712
    from pymongo import pool
    pool._MOCK_SERVICE_ID = True
    # Take host/port from the single-mongos load balancer URI; credentials
    # embedded in that URI win over the environment defaults above.
    res = parse_uri(SINGLE_MONGOS_LB_URI)
    host, port = res['nodelist'][0]
    db_user = res['username'] or db_user
    db_pwd = res['password'] or db_pwd
108
109
def is_server_resolvable():
    """Return True if the hostname 'server' resolves via DNS/hosts lookup.

    The process-wide default socket timeout is set to 1 second for the
    lookup and always restored to its previous value afterwards.
    """
    previous_timeout = socket.getdefaulttimeout()
    socket.setdefaulttimeout(1)
    try:
        socket.gethostbyname('server')
    except socket.error:
        return False
    else:
        return True
    finally:
        socket.setdefaulttimeout(previous_timeout)
122
123
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
    """Run ``createUser`` for *user* against *authdb* and return the reply.

    The command document is ordered ``createUser``, ``pwd`` (only when
    *pwd* is truthy), ``roles`` (``['root']`` when *roles* is falsy), then
    any extra keyword arguments in the order given.
    """
    fields = [('createUser', user)]
    if pwd:
        # X509 doesn't use a password
        fields.append(('pwd', pwd))
    fields.append(('roles', roles or ['root']))
    fields.extend(kwargs.items())
    return authdb.command(SON(fields))
132
133
class client_knobs(object):
    """Temporarily override timing settings in ``pymongo.common``.

    Use as a context manager (``with client_knobs(...) as knobs:``) or call
    :meth:`enable` and :meth:`disable` explicitly. Any argument left as
    ``None`` keeps the current value of that setting.

    :Parameters:
      - `heartbeat_frequency`: value for ``common.HEARTBEAT_FREQUENCY``.
      - `min_heartbeat_interval`: value for ``common.MIN_HEARTBEAT_INTERVAL``.
      - `kill_cursor_frequency`: value for ``common.KILL_CURSOR_FREQUENCY``.
      - `events_queue_frequency`: value for ``common.EVENTS_QUEUE_FREQUENCY``.
    """

    def __init__(
            self,
            heartbeat_frequency=None,
            min_heartbeat_interval=None,
            kill_cursor_frequency=None,
            events_queue_frequency=None):
        self.heartbeat_frequency = heartbeat_frequency
        self.min_heartbeat_interval = min_heartbeat_interval
        self.kill_cursor_frequency = kill_cursor_frequency
        self.events_queue_frequency = events_queue_frequency

        # Values saved by enable() and restored by disable().
        self.old_heartbeat_frequency = None
        self.old_min_heartbeat_interval = None
        self.old_kill_cursor_frequency = None
        self.old_events_queue_frequency = None
        # Nothing is overridden until enable() runs. If this started as True,
        # __del__ on a never-enabled instance would report a spurious error
        # and call disable(), writing the None placeholders above into common.
        self._enabled = False
        self._stack = None

    def enable(self):
        """Save the current settings, then apply every non-None override."""
        self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
        self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL
        self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY
        self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY

        if self.heartbeat_frequency is not None:
            common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency

        if self.min_heartbeat_interval is not None:
            common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval

        if self.kill_cursor_frequency is not None:
            common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency

        if self.events_queue_frequency is not None:
            common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency
        self._enabled = True
        # Store the allocation traceback to catch non-disabled client_knobs.
        self._stack = ''.join(traceback.format_stack())

    def __enter__(self):
        """Enable the overrides and return this object."""
        self.enable()
        return self

    def disable(self):
        """Restore the settings saved by the most recent :meth:`enable`."""
        common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
        common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
        common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency
        common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency
        self._enabled = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the saved settings; exceptions are not suppressed."""
        self.disable()

    def __del__(self):
        """Restore settings and complain if the knobs were left enabled.

        Python does not propagate exceptions raised from ``__del__``; it
        reports them on stderr, which is what surfaces the leaked knobs.
        """
        if self._enabled:
            msg = (
                'ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY=%s, '
                'MIN_HEARTBEAT_INTERVAL=%s, KILL_CURSOR_FREQUENCY=%s, '
                'EVENTS_QUEUE_FREQUENCY=%s, stack:\n%s' % (
                    common.HEARTBEAT_FREQUENCY,
                    common.MIN_HEARTBEAT_INTERVAL,
                    common.KILL_CURSOR_FREQUENCY,
                    common.EVENTS_QUEUE_FREQUENCY,
                    self._stack))
            self.disable()
            raise Exception(msg)
200
201
202def _all_users(db):
203    return set(u['user'] for u in db.command('usersInfo').get('users', []))
204
205
class ClientContext(object):
    """Shared, lazily-initialized view of the MongoDB deployment under test.

    Nothing is contacted until :meth:`init`, which every ``require_*``
    decorator calls before running the wrapped test. Initialization records
    the topology, server version, auth/TLS configuration and feature
    support that the ``require_*`` decorators and ``supports_*`` properties
    consult.
    """

    MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI

    def __init__(self):
        """Set placeholder state; connecting is deferred to :meth:`init`.

        The only network activity here is the hostname lookup performed by
        :func:`is_server_resolvable`.
        """
        self.connection_attempts = []
        self.connected = False
        self.w = None
        self.nodes = set()
        self.replica_set_name = None
        self.cmd_line = None
        self.server_status = None
        self.version = Version(-1)  # Needs to be comparable with Version
        self.auth_enabled = False
        self.test_commands_enabled = False
        self.server_parameters = None
        self.is_mongos = False
        self.mongoses = []
        self.is_rs = False
        self.has_ipv6 = False
        self.tls = False
        self.ssl_certfile = False
        self.server_is_resolvable = is_server_resolvable()
        self.default_client_options = {}
        self.sessions_enabled = False
        self.client = None
        self.conn_lock = threading.Lock()
        self.is_data_lake = False
        self.load_balancer = TEST_LOADBALANCER
        if self.load_balancer:
            self.default_client_options["loadBalanced"] = True
        if COMPRESSORS:
            self.default_client_options["compressors"] = COMPRESSORS
        if MONGODB_API_VERSION:
            server_api = ServerApi(MONGODB_API_VERSION)
            self.default_client_options["server_api"] = server_api

    @property
    def hello(self):
        """Run the legacy hello command on the admin database."""
        return self.client.admin.command(HelloCompat.LEGACY_CMD)

    def _replace_client(self, new_client):
        """Install *new_client* as ``self.client``, closing the old one.

        Initialization repeatedly supersedes its client with a better
        configured one (credentials, replicaSet); closing the superseded
        client keeps it from being leaked.
        """
        old_client = self.client
        self.client = new_client
        if old_client is not None and old_client is not new_client:
            old_client.close()

    def _connect(self, host, port, **kwargs):
        """Return a MongoClient for host:port, or None if it is unreachable.

        Reachability is probed with a ping on a short-timeout client that is
        always closed. If the probe does not raise ConnectionFailure, a new
        client with the default timeout is returned (an OperationFailure
        from the ping is only logged). Each outcome is appended to
        ``connection_attempts``. ``default_client_options`` are merged into
        *kwargs* and take precedence on conflicting keys.
        """
        # Jython takes a long time to connect.
        if sys.platform.startswith('java'):
            timeout_ms = 10000
        else:
            timeout_ms = 5000
        kwargs.update(self.default_client_options)
        client = pymongo.MongoClient(
            host, port, serverSelectionTimeoutMS=timeout_ms, **kwargs)
        try:
            try:
                client.admin.command('ping')  # Can we connect?
            except pymongo.errors.OperationFailure as exc:
                # SERVER-32063
                self.connection_attempts.append(
                    'connected client %r, but hello failed: %s' % (
                        client, exc))
            else:
                self.connection_attempts.append(
                    'successfully connected client %r' % (client,))
            # If connected, then return client with default timeout
            return pymongo.MongoClient(host, port, **kwargs)
        except pymongo.errors.ConnectionFailure as exc:
            self.connection_attempts.append(
                'failed to connect client %r: %s' % (client, exc))
            return None
        finally:
            client.close()

    def _init_client(self):
        """Connect to the deployment and record its characteristics."""
        self.client = self._connect(host, port)

        if self.client is not None:
            # Return early when connected to dataLake as mongohoused does not
            # support the getCmdLineOpts command and is tested without TLS.
            build_info = self.client.admin.command('buildInfo')
            if 'dataLake' in build_info:
                self.is_data_lake = True
                self.auth_enabled = True
                self._replace_client(self._connect(
                    host, port, username=db_user, password=db_pwd))
                self.connected = True
                return

        if HAVE_SSL and not self.client:
            # Is MongoDB configured for SSL?
            self.client = self._connect(host, port, **TLS_OPTIONS)
            if self.client:
                self.tls = True
                self.default_client_options.update(TLS_OPTIONS)
                self.ssl_certfile = True

        if self.client:
            self.connected = True

            try:
                self.cmd_line = self.client.admin.command('getCmdLineOpts')
            except pymongo.errors.OperationFailure as e:
                msg = e.details.get('errmsg', '')
                if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
                    # Unauthorized.
                    self.auth_enabled = True
                else:
                    raise
            else:
                self.auth_enabled = self._server_started_with_auth()

            if self.auth_enabled:
                # See if db_user already exists.
                if not self._check_user_provided():
                    _create_user(self.client.admin, db_user, db_pwd)

                self._replace_client(self._connect(
                    host, port, username=db_user, password=db_pwd,
                    replicaSet=self.replica_set_name,
                    **self.default_client_options))

                # May not have this if OperationFailure was raised earlier.
                self.cmd_line = self.client.admin.command('getCmdLineOpts')

            self.server_status = self.client.admin.command('serverStatus')
            if self.storage_engine == "mmapv1":
                # MMAPv1 does not support retryWrites=True.
                self.default_client_options['retryWrites'] = False

            hello = self.hello
            self.sessions_enabled = 'logicalSessionTimeoutMinutes' in hello

            if 'setName' in hello:
                self.replica_set_name = str(hello['setName'])
                self.is_rs = True
                if self.auth_enabled:
                    # It doesn't matter which member we use as the seed here.
                    self._replace_client(pymongo.MongoClient(
                        host,
                        port,
                        username=db_user,
                        password=db_pwd,
                        replicaSet=self.replica_set_name,
                        **self.default_client_options))
                else:
                    self._replace_client(pymongo.MongoClient(
                        host,
                        port,
                        replicaSet=self.replica_set_name,
                        **self.default_client_options))

                # Get the authoritative hello result from the primary.
                hello = self.hello
                self.nodes = set(
                    partition_node(node.lower())
                    for field in ('hosts', 'passives', 'arbiters')
                    for node in hello.get(field, []))
            else:
                self.nodes = set([(host, port)])
            self.w = len(hello.get("hosts", [])) or 1
            self.version = Version.from_client(self.client)
            self.server_parameters = self.client.admin.command(
                'getParameter', '*')

            if 'enableTestCommands=1' in self.cmd_line['argv']:
                self.test_commands_enabled = True
            elif 'parsed' in self.cmd_line:
                # setParameter may be a list of 'name=value' strings or a
                # dict; only call .get() on the dict form so a list without
                # the flag doesn't raise AttributeError.
                params = self.cmd_line['parsed'].get('setParameter', {})
                if 'enableTestCommands=1' in params:
                    self.test_commands_enabled = True
                elif (isinstance(params, dict) and
                      params.get('enableTestCommands') == '1'):
                    self.test_commands_enabled = True

            self.is_mongos = (self.hello.get('msg') == 'isdbgrid')
            self.has_ipv6 = self._server_started_with_ipv6()
            if self.is_mongos:
                # Check for another mongos on the next port.
                address = self.client.address
                next_address = address[0], address[1] + 1
                self.mongoses.append(address)
                mongos_client = self._connect(*next_address,
                                              **self.default_client_options)
                if mongos_client:
                    try:
                        reply = mongos_client.admin.command(
                            HelloCompat.LEGACY_CMD)
                        if reply.get('msg') == 'isdbgrid':
                            self.mongoses.append(next_address)
                    finally:
                        # The probe client is only needed for this check.
                        mongos_client.close()

    def init(self):
        """Initialize the client once; later calls do nothing.

        Skipped when a client already exists or a connection attempt has
        already been recorded. Serialized by ``conn_lock``.
        """
        with self.conn_lock:
            if not self.client and not self.connection_attempts:
                self._init_client()

    def connection_attempt_info(self):
        """Return the recorded connection attempts, one per line."""
        return '\n'.join(self.connection_attempts)

    @property
    def host(self):
        """Primary's host for a replica set when known, else the seed host."""
        if self.is_rs:
            primary = self.client.primary
            return str(primary[0]) if primary is not None else host
        return host

    @property
    def port(self):
        """Primary's port for a replica set when known, else the seed port."""
        if self.is_rs:
            primary = self.client.primary
            return primary[1] if primary is not None else port
        return port

    @property
    def pair(self):
        """The current host and port formatted as ``'host:port'``."""
        return "%s:%d" % (self.host, self.port)

    @property
    def has_secondaries(self):
        """True if there is a client and it currently sees secondaries."""
        if not self.client:
            return False
        return bool(len(self.client.secondaries))

    @property
    def storage_engine(self):
        """Storage engine name from serverStatus, or None if unavailable."""
        try:
            return self.server_status.get("storageEngine", {}).get("name")
        except AttributeError:
            # Raised if self.server_status is None.
            return None

    def _check_user_provided(self):
        """Return True if db_user/db_password is already an admin user."""
        client = pymongo.MongoClient(
            host, port,
            username=db_user,
            password=db_pwd,
            serverSelectionTimeoutMS=100,
            **self.default_client_options)

        try:
            return db_user in _all_users(client.admin)
        except pymongo.errors.OperationFailure as e:
            msg = e.details.get('errmsg', '')
            if e.code == 18 or 'auth fails' in msg:
                # Auth failed.
                return False
            else:
                raise
        finally:
            # This client exists only for this check.
            client.close()

    def _server_started_with_auth(self):
        """Infer from getCmdLineOpts whether the server enforces auth."""
        # MongoDB >= 2.0
        if 'parsed' in self.cmd_line:
            parsed = self.cmd_line['parsed']
            # MongoDB >= 2.6
            if 'security' in parsed:
                security = parsed['security']
                # >= rc3
                if 'authorization' in security:
                    return security['authorization'] == 'enabled'
                # < rc3
                return (security.get('auth', False) or
                        bool(security.get('keyFile')))
            return parsed.get('auth', False) or bool(parsed.get('keyFile'))
        # Legacy
        argv = self.cmd_line['argv']
        return '--auth' in argv or '--keyFile' in argv

    def _server_started_with_ipv6(self):
        """True if the server enabled IPv6 and an IPv6 route to it exists."""
        if not socket.has_ipv6:
            return False

        if 'parsed' in self.cmd_line:
            if not self.cmd_line['parsed'].get('net', {}).get('ipv6'):
                return False
        else:
            if '--ipv6' not in self.cmd_line['argv']:
                return False

        # The server was started with --ipv6. Is there an IPv6 route to it?
        try:
            for info in socket.getaddrinfo(self.host, self.port):
                if info[0] == socket.AF_INET6:
                    return True
        except socket.error:
            pass

        return False

    def _require(self, condition, msg, func=None):
        """Build a decorator that runs a test only if *condition()* is true.

        The wrapper calls :meth:`init`, then raises SkipTest when the
        deployment is unreachable, or with *msg* when *condition()* is
        false. If *func* is given it is returned already wrapped; otherwise
        the decorator itself is returned.
        """
        def make_wrapper(f):
            @wraps(f)
            def wrap(*args, **kwargs):
                self.init()
                # Always raise SkipTest if we can't connect to MongoDB
                if not self.connected:
                    raise SkipTest(
                        "Cannot connect to MongoDB on %s" % (self.pair,))
                if condition():
                    return f(*args, **kwargs)
                raise SkipTest(msg)
            return wrap

        if func is None:
            return make_wrapper
        return make_wrapper(func)

    def create_user(self, dbname, user, pwd=None, roles=None, **kwargs):
        """Create *user* on *dbname* with write concern ``w=self.w``."""
        kwargs['writeConcern'] = {'w': self.w}
        return _create_user(self.client[dbname], user, pwd, roles, **kwargs)

    def drop_user(self, dbname, user):
        """Drop *user* from *dbname* with write concern ``w=self.w``."""
        self.client[dbname].command(
            'dropUser', user, writeConcern={'w': self.w})

    def require_connection(self, func):
        """Run a test only if we can connect to MongoDB."""
        return self._require(
            lambda: True,  # _require checks if we're connected
            "Cannot connect to MongoDB on %s" % (self.pair,),
            func=func)

    def require_data_lake(self, func):
        """Run a test only if we are connected to Atlas Data Lake."""
        return self._require(
            lambda: self.is_data_lake,
            "Not connected to Atlas Data Lake on %s" % (self.pair,),
            func=func)

    def require_no_mmap(self, func):
        """Run a test only if the server is not using the MMAPv1 storage
        engine. Only works for standalone and replica sets; tests are
        run regardless of storage engine on sharded clusters. """
        def is_not_mmap():
            if self.is_mongos:
                return True
            return self.storage_engine != 'mmapv1'

        return self._require(
            is_not_mmap, "Storage engine must not be MMAPv1", func=func)

    def require_version_min(self, *ver):
        """Run a test only if the server version is at least ``version``."""
        other_version = Version(*ver)
        return self._require(lambda: self.version >= other_version,
                             "Server version must be at least %s"
                             % str(other_version))

    def require_version_max(self, *ver):
        """Run a test only if the server version is at most ``version``."""
        other_version = Version(*ver)
        return self._require(lambda: self.version <= other_version,
                             "Server version must be at most %s"
                             % str(other_version))

    def require_auth(self, func):
        """Run a test only if the server is running with auth enabled."""
        return self.check_auth_with_sharding(
            self._require(lambda: self.auth_enabled,
                          "Authentication is not enabled on the server",
                          func=func))

    def require_no_auth(self, func):
        """Run a test only if the server is running without auth enabled."""
        return self._require(lambda: not self.auth_enabled,
                             "Authentication must not be enabled on the server",
                             func=func)

    def require_replica_set(self, func):
        """Run a test only if the client is connected to a replica set."""
        return self._require(lambda: self.is_rs,
                             "Not connected to a replica set",
                             func=func)

    def require_secondaries_count(self, count):
        """Run a test only if the client is connected to a replica set that has
        at least `count` secondaries.
        """
        def sec_count():
            return 0 if not self.client else len(self.client.secondaries)
        return self._require(lambda: sec_count() >= count,
                             "Not enough secondaries available")

    @property
    def supports_secondary_read_pref(self):
        """True if secondaries are visible, or the first shard of a sharded
        cluster has more than one member."""
        if self.has_secondaries:
            return True
        if self.is_mongos:
            shard = self.client.config.shards.find_one()['host']
            num_members = shard.count(',') + 1
            return num_members > 1
        return False

    def require_secondary_read_pref(self):
        """Run a test only if the client is connected to a cluster that
        supports secondary read preference
        """
        return self._require(lambda: self.supports_secondary_read_pref,
                             "This cluster does not support secondary read "
                             "preference")

    def require_no_replica_set(self, func):
        """Run a test if the client is *not* connected to a replica set."""
        return self._require(
            lambda: not self.is_rs,
            "Connected to a replica set, not a standalone mongod",
            func=func)

    def require_ipv6(self, func):
        """Run a test only if the client can connect to a server via IPv6."""
        return self._require(lambda: self.has_ipv6,
                             "No IPv6",
                             func=func)

    def require_no_mongos(self, func):
        """Run a test only if the client is not connected to a mongos."""
        return self._require(lambda: not self.is_mongos,
                             "Must be connected to a mongod, not a mongos",
                             func=func)

    def require_mongos(self, func):
        """Run a test only if the client is connected to a mongos."""
        return self._require(lambda: self.is_mongos,
                             "Must be connected to a mongos",
                             func=func)

    def require_multiple_mongoses(self, func):
        """Run a test only if the client is connected to a sharded cluster
        that has 2 mongos nodes."""
        return self._require(lambda: len(self.mongoses) > 1,
                             "Must have multiple mongoses available",
                             func=func)

    def require_standalone(self, func):
        """Run a test only if the client is connected to a standalone."""
        return self._require(lambda: not (self.is_mongos or self.is_rs),
                             "Must be connected to a standalone",
                             func=func)

    def require_no_standalone(self, func):
        """Run a test only if the client is not connected to a standalone."""
        return self._require(lambda: self.is_mongos or self.is_rs,
                             "Must be connected to a replica set or mongos",
                             func=func)

    def require_load_balancer(self, func):
        """Run a test only if the client is connected to a load balancer."""
        return self._require(lambda: self.load_balancer,
                             "Must be connected to a load balancer",
                             func=func)

    def require_no_load_balancer(self, func):
        """Run a test only if the client is not connected to a load balancer.
        """
        return self._require(lambda: not self.load_balancer,
                             "Must not be connected to a load balancer",
                             func=func)

    def check_auth_with_sharding(self, func):
        """Skip a test when connected to mongos < 2.0 and running with auth."""
        def auth_with_sharding_supported():
            return not (self.auth_enabled and
                        self.is_mongos and self.version < (2,))
        return self._require(auth_with_sharding_supported,
                             "Auth with sharding requires MongoDB >= 2.0.0",
                             func=func)

    def is_topology_type(self, topologies):
        """Return True if the deployment matches any name in *topologies*.

        Valid names are 'single', 'replicaset', 'sharded',
        'sharded-replicaset' and 'load-balanced'; anything else raises
        AssertionError. When testing through a load balancer only
        'load-balanced' can match.
        """
        unknown = set(topologies) - {'single', 'replicaset', 'sharded',
                                     'sharded-replicaset', 'load-balanced'}
        if unknown:
            raise AssertionError('Unknown topologies: %r' % (unknown,))
        if self.load_balancer:
            return 'load-balanced' in topologies
        if 'single' in topologies and not (self.is_mongos or self.is_rs):
            return True
        if 'replicaset' in topologies and self.is_rs:
            return True
        if 'sharded' in topologies and self.is_mongos:
            return True
        if 'sharded-replicaset' in topologies and self.is_mongos:
            # Every shard must be replica-set backed. For a 3-member
            # RS-backed sharded cluster, shard['host'] will be
            # 'replicaName/ip1:port1,ip2:port2,ip3:port3'; otherwise it
            # will be 'ip1:port1'.
            return all('/' in shard['host']
                       for shard in self.client.config.shards.find())
        return False

    def require_cluster_type(self, topologies=None):
        """Run a test only if the client is connected to a cluster that
        conforms to one of the specified topologies. Acceptable topologies
        are 'single', 'replicaset', 'sharded', 'sharded-replicaset' and
        'load-balanced'. *topologies* may be any iterable of those names;
        with none given every test is skipped."""
        if topologies is None:
            topologies = []

        def _is_valid_topology():
            return self.is_topology_type(topologies)
        # Wrap in a 1-tuple so a tuple of names formats instead of raising.
        return self._require(
            _is_valid_topology,
            "Cluster type not in %s" % (topologies,))

    def require_test_commands(self, func):
        """Run a test only if the server has test commands enabled."""
        return self._require(lambda: self.test_commands_enabled,
                             "Test commands must be enabled",
                             func=func)

    def require_failCommand_fail_point(self, func):
        """Run a test only if the server supports the failCommand fail
        point."""
        return self._require(lambda: self.supports_failCommand_fail_point,
                             "failCommand fail point must be supported",
                             func=func)

    def require_failCommand_appName(self, func):
        """Run a test only if the server supports the failCommand appName."""
        # SERVER-47195
        return self._require(lambda: (self.test_commands_enabled and
                                      self.version >= (4, 4, -1)),
                             "failCommand appName must be supported",
                             func=func)

    def require_tls(self, func):
        """Run a test only if the client can connect over TLS."""
        return self._require(lambda: self.tls,
                             "Must be able to connect via TLS",
                             func=func)

    def require_no_tls(self, func):
        """Run a test only if the client connects without TLS."""
        return self._require(lambda: not self.tls,
                             "Must be able to connect without TLS",
                             func=func)

    def require_ssl_certfile(self, func):
        """Run a test only if the client can connect with ssl_certfile."""
        return self._require(lambda: self.ssl_certfile,
                             "Must be able to connect with ssl_certfile",
                             func=func)

    def require_server_resolvable(self, func):
        """Run a test only if the hostname 'server' is resolvable."""
        return self._require(lambda: self.server_is_resolvable,
                             "No hosts entry for 'server'. Cannot validate "
                             "hostname in the certificate",
                             func=func)

    def require_sessions(self, func):
        """Run a test only if the deployment supports sessions."""
        return self._require(lambda: self.sessions_enabled,
                             "Sessions not supported",
                             func=func)

    def supports_transactions(self):
        """True if the server version and topology might allow transactions.

        Replica sets need 4.0+, mongos needs 4.1.8+, and MMAPv1 never does.
        """
        if self.storage_engine == 'mmapv1':
            return False

        if self.version.at_least(4, 1, 8):
            return self.is_mongos or self.is_rs

        if self.version.at_least(4, 0):
            return self.is_rs

        return False

    def require_transactions(self, func):
        """Run a test only if the deployment might support transactions.

        *Might* because this does not test the storage engine or FCV.
        """
        return self._require(self.supports_transactions,
                             "Transactions are not supported",
                             func=func)

    def require_no_api_version(self, func):
        """Skip this test when testing with requireApiVersion."""
        return self._require(lambda: not MONGODB_API_VERSION,
                             "This test does not work with requireApiVersion",
                             func=func)

    def mongos_seeds(self):
        """Return the discovered mongos addresses as 'host:port,...'."""
        return ','.join('%s:%s' % address for address in self.mongoses)

    @property
    def supports_reindex(self):
        """Does the connected server support reindex?"""
        return not ((self.version.at_least(4, 1, 0) and self.is_mongos) or
                    (self.version.at_least(4, 5, 0) and (
                            self.is_mongos or self.is_rs)))

    @property
    def supports_getpreverror(self):
        """Does the connected server support getpreverror?"""
        return not (self.version.at_least(4, 1, 0) or self.is_mongos)

    @property
    def supports_failCommand_fail_point(self):
        """Does the server support the failCommand fail point?"""
        if self.is_mongos:
            return (self.version.at_least(4, 1, 5) and
                    self.test_commands_enabled)
        else:
            return (self.version.at_least(4, 0) and
                    self.test_commands_enabled)

    @property
    def requires_hint_with_min_max_queries(self):
        """Does the server require a hint with min/max queries."""
        # Changed in SERVER-39567.
        return self.version.at_least(4, 1, 10)
819
820
# Reusable client context: one module-level ClientContext shared by the
# require_* decorators, the test base classes, and setup()/teardown() below.
client_context = ClientContext()
823
824
def sanitize_cmd(cmd):
    """Return a normalized copy of *cmd* suitable for equality comparisons.

    Removes '$clusterTime', '$db', '$readPreference' and 'lsid' (plus
    'apiVersion' when MONGODB_API_VERSION is set). Then, if the command name
    (the first key) has an entry in message._FIELD_MAP and that field is
    present, the field is moved to the end. *cmd* itself is not modified.
    """
    sanitized = cmd.copy()
    for field in ('$clusterTime', '$db', '$readPreference', 'lsid'):
        sanitized.pop(field, None)
    if MONGODB_API_VERSION:
        # Versioned api parameters
        sanitized.pop('apiVersion', None)
    # OP_MSG encoding may move the payload type one field to the
    # end of the command. Do the same here.
    try:
        payload_field = message._FIELD_MAP[next(iter(sanitized))]
        sanitized[payload_field] = sanitized.pop(payload_field)
    except KeyError:
        # Either the command has no payload field mapping, or the mapped
        # field is absent from this command: nothing to move.
        pass
    return sanitized
844
845
def sanitize_reply(reply):
    """Return a copy of *reply* without its '$clusterTime'/'operationTime' keys.

    Used so replies can be compared independently of those fields; *reply*
    itself is left untouched.
    """
    sanitized = reply.copy()
    for field in ('$clusterTime', 'operationTime'):
        sanitized.pop(field, None)
    return sanitized
851
852
class PyMongoTestCase(unittest.TestCase):
    """TestCase with command/reply comparison helpers and a fail point helper."""

    def assertEqualCommand(self, expected, actual, msg=None):
        """Assert *expected* == *actual* after both pass through sanitize_cmd."""
        lhs = sanitize_cmd(expected)
        rhs = sanitize_cmd(actual)
        self.assertEqual(lhs, rhs, msg)

    def assertEqualReply(self, expected, actual, msg=None):
        """Assert *expected* == *actual* after both pass through sanitize_reply."""
        lhs = sanitize_reply(expected)
        rhs = sanitize_reply(actual)
        self.assertEqual(lhs, rhs, msg)

    @contextmanager
    def fail_point(self, command_args):
        """Enable a fail point on the shared client for the duration of a with-block.

        The configureFailPoint command starts as ``failCommand`` and is then
        updated with *command_args*. On exit (normal or exceptional) the fail
        point named in that command is set to ``mode='off'``.
        """
        enable_cmd = SON([('configureFailPoint', 'failCommand')])
        enable_cmd.update(command_args)
        client_context.client.admin.command(enable_cmd)
        try:
            yield
        finally:
            fail_point_name = enable_cmd['configureFailPoint']
            client_context.client.admin.command(
                'configureFailPoint', fail_point_name, mode='off')
870
871
class IntegrationTest(PyMongoTestCase):
    """Base class for TestCases that need a connection to MongoDB to pass."""

    @classmethod
    @client_context.require_connection
    def setUpClass(cls):
        # Under a load balancer, only classes that opt in via
        # RUN_ON_LOAD_BALANCER are run.
        if client_context.load_balancer:
            if not getattr(cls, 'RUN_ON_LOAD_BALANCER', False):
                raise SkipTest('this test does not support load balancers')
        cls.client = client_context.client
        cls.db = cls.client.pymongo_test
        cls.credentials = (
            {'username': db_user, 'password': db_pwd}
            if client_context.auth_enabled else {})

    def patch_system_certs(self, ca_certs):
        """Set SSL_CERT_FILE to *ca_certs*, restoring it when the test ends."""
        self.addCleanup(SystemCertsPatcher(ca_certs).disable)
891
892
# Python 2.7's unittest only has the deprecated assertRaisesRegexp (with a
# 'p'); expose it under the modern name when assertRaisesRegex is missing.
if not hasattr(unittest.TestCase, 'assertRaisesRegex'):
    setattr(unittest.TestCase, 'assertRaisesRegex',
            unittest.TestCase.assertRaisesRegexp)
897
898
class MockClientTest(unittest.TestCase):
    """Base class for TestCases that use MockClient.

    This class is *not* an IntegrationTest: if properly written, MockClient
    tests do not require a running server.

    While each test runs, heartbeat_frequency and min_heartbeat_interval are
    lowered to 0.001 via client_knobs to speed tests up.
    """

    # MockClients tests that use replicaSet, directConnection=True, pass
    # multiple seed addresses, or wait for heartbeat events are incompatible
    # with loadBalanced=True.
    @classmethod
    @client_context.require_no_load_balancer
    def setUpClass(cls):
        pass

    def setUp(self):
        super(MockClientTest, self).setUp()
        knobs = client_knobs(min_heartbeat_interval=0.001,
                             heartbeat_frequency=0.001)
        self.client_knobs = knobs
        knobs.enable()

    def tearDown(self):
        # Restore the knobs before the parent teardown runs.
        self.client_knobs.disable()
        super(MockClientTest, self).tearDown()
928
929
def setup():
    """Suite-wide setup: initialize client_context and show every warning."""
    client_context.init()
    # Drop any existing filters, then report every warning occurrence.
    warnings.resetwarnings()
    warnings.simplefilter(action="always")
934
935
936def _get_executors(topology):
937    executors = []
938    for server in topology._servers.values():
939        # Some MockMonitor do not have an _executor.
940        if hasattr(server._monitor, '_executor'):
941            executors.append(server._monitor._executor)
942        if hasattr(server._monitor, '_rtt_monitor'):
943            executors.append(server._monitor._rtt_monitor._executor)
944    executors.append(topology._Topology__events_executor)
945    if topology._srv_monitor:
946        executors.append(topology._srv_monitor._executor)
947
948    return [e for e in executors if e is not None]
949
950
def all_executors_stopped(topology):
    """Return True if every executor of *topology* has ``_stopped`` set.

    Otherwise print the still-running executors together with
    ``topology._settings._stack`` and return False.
    """
    still_running = [ex for ex in _get_executors(topology) if not ex._stopped]
    if not still_running:
        return True
    print('  Topology %s has THREADS RUNNING: %s, created at: %s' % (
        topology, still_running, topology._settings._stack))
    return False
958
959
def print_unclosed_clients():
    """Report every live Topology that still has running executors.

    Each Topology is checked once (keyed by ``_topology_id``) via
    all_executors_stopped, which prints details for running executors.
    """
    from pymongo.topology import Topology
    seen_ids = set()
    # Call collect to manually cleanup any would-be gc'd clients to avoid
    # false positives.
    gc.collect()
    for candidate in gc.get_objects():
        try:
            if not isinstance(candidate, Topology):
                continue
            topology_id = candidate._topology_id
            # Avoid printing the same Topology multiple times.
            if topology_id in seen_ids:
                continue
            all_executors_stopped(candidate)
            seen_ids.add(topology_id)
        except ReferenceError:
            # Dead weak proxies raise on attribute access; skip them.
            pass
976
977
def teardown():
    """Suite-wide teardown.

    Fails if ``gc.garbage`` is non-empty, drops the test databases (unless
    ``client_context.is_data_lake``), closes the shared client and, except on
    Jython, reports topologies with executors still running.

    Raises:
        AssertionError: if ``gc.garbage`` contains any objects.
    """
    garbage = []
    for g in gc.garbage:
        garbage.append('GARBAGE: %r' % (g,))
        garbage.append('  gc.get_referents: %r' % (gc.get_referents(g),))
        garbage.append('  gc.get_referrers: %r' % (gc.get_referrers(g),))
    if garbage:
        # Raise explicitly: a bare ``assert`` statement is stripped when
        # Python runs with -O, which would silently skip this check.
        raise AssertionError('\n'.join(garbage))
    c = client_context.client
    if c:
        if not client_context.is_data_lake:
            for db_name in ("pymongo-pooling-tests",
                            "pymongo_test",
                            "pymongo_test1",
                            "pymongo_test2",
                            "pymongo_test_mike",
                            "pymongo_test_bernie"):
                c.drop_database(db_name)
        c.close()

    # Jython does not support gc.get_objects.
    if not sys.platform.startswith('java'):
        print_unclosed_clients()
1000
1001
class PymongoTestRunner(unittest.TextTestRunner):
    """TextTestRunner that calls setup() before and teardown() after a run."""

    def run(self, test):
        setup()
        outcome = super(PymongoTestRunner, self).run(test)
        teardown()
        return outcome
1008
1009
if HAVE_XML:
    class PymongoXMLTestRunner(XMLTestRunner):
        """XMLTestRunner that calls setup() before and teardown() after a run.

        Only defined when xmlrunner could be imported (HAVE_XML).
        """

        def run(self, test):
            setup()
            outcome = super(PymongoXMLTestRunner, self).run(test)
            teardown()
            return outcome
1017
1018
def test_cases(suite):
    """Iterator over all TestCases within a TestSuite.

    Walks nested suites depth-first, yielding cases in their original order.
    Anything that is not a unittest.TestCase is treated as a nested suite.
    """
    # Explicit stack of iterators instead of recursion (keeps Python 2
    # compatibility, where ``yield from`` is unavailable).
    pending = [iter(suite._tests)]
    while pending:
        try:
            item = next(pending[-1])
        except StopIteration:
            pending.pop()
            continue
        if isinstance(item, unittest.TestCase):
            yield item
        else:
            pending.append(iter(item._tests))
1029
1030
# Helper method to workaround https://bugs.python.org/issue21724
def clear_warning_registry():
    """Reset ``__warningregistry__`` to an empty dict on every loaded module
    that has one."""
    # Snapshot the modules: sys.modules may change while we iterate.
    for mod in list(sys.modules.values()):
        if hasattr(mod, "__warningregistry__"):
            setattr(mod, "__warningregistry__", {})
1037
1038
class SystemCertsPatcher(object):
    """Temporarily point OpenSSL at a CA bundle via ``SSL_CERT_FILE``.

    Construction remembers the current ``SSL_CERT_FILE`` value (or its
    absence) and sets it to *ca_certs*; :meth:`disable` restores the
    remembered state.

    Raises:
        SkipTest: on Python < 2.7.9, or with LibreSSL on macOS when not
            using PyOpenSSL, where this approach cannot be used.
    """

    def __init__(self, ca_certs):
        if sys.version_info < (2, 7, 9):
            raise SkipTest("Can't load system CA certificates.")
        if (ssl.OPENSSL_VERSION.lower().startswith('libressl') and
                sys.platform == 'darwin' and not _ssl.IS_PYOPENSSL):
            raise SkipTest(
                "LibreSSL on OSX doesn't support setting CA certificates "
                "using SSL_CERT_FILE environment variable.")
        self.original_certs = os.environ.get('SSL_CERT_FILE')
        # Tell OpenSSL where CA certificates live.
        os.environ['SSL_CERT_FILE'] = ca_certs

    def disable(self):
        """Restore ``SSL_CERT_FILE`` to its state before this patcher ran."""
        if self.original_certs is None:
            # The variable may already have been removed (e.g. by the test
            # itself); cleanup must not fail with KeyError in that case.
            os.environ.pop('SSL_CERT_FILE', None)
        else:
            os.environ['SSL_CERT_FILE'] = self.original_certs
1057