1# Copyright 2010-present MongoDB, Inc. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15"""Test suite for pymongo, bson, and gridfs. 16""" 17 18import gc 19import os 20import socket 21import sys 22import threading 23import time 24import traceback 25import unittest 26import warnings 27 28try: 29 from xmlrunner import XMLTestRunner 30 HAVE_XML = True 31# ValueError is raised when version 3+ is installed on Jython 2.7. 32except (ImportError, ValueError): 33 HAVE_XML = False 34 35try: 36 import ipaddress 37 HAVE_IPADDRESS = True 38except ImportError: 39 HAVE_IPADDRESS = False 40 41from contextlib import contextmanager 42from functools import wraps 43from unittest import SkipTest 44 45import pymongo 46import pymongo.errors 47 48from bson.son import SON 49from pymongo import common, message 50from pymongo.common import partition_node 51from pymongo.hello_compat import HelloCompat 52from pymongo.server_api import ServerApi 53from pymongo.ssl_support import HAVE_SSL, _ssl 54from pymongo.uri_parser import parse_uri 55from test.version import Version 56 57if HAVE_SSL: 58 import ssl 59 60try: 61 # Enable the fault handler to dump the traceback of each running thread 62 # after a segfault. 63 import faulthandler 64 faulthandler.enable() 65except ImportError: 66 pass 67 68# Enable debug output for uncollectable objects. PyPy does not have set_debug. 
if hasattr(gc, 'set_debug'):
    gc.set_debug(
        gc.DEBUG_UNCOLLECTABLE |
        getattr(gc, 'DEBUG_OBJECTS', 0) |
        getattr(gc, 'DEBUG_INSTANCES', 0))

# The host and port of a single mongod or mongos, or the seed host
# for a replica set.
host = os.environ.get("DB_IP", 'localhost')
port = int(os.environ.get("DB_PORT", 27017))

# Credentials used when the server turns out to have auth enabled.
db_user = os.environ.get("DB_USER", "user")
db_pwd = os.environ.get("DB_PASSWORD", "password")

# TLS certificate locations, overridable through the environment.
CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                         'certificates')
CLIENT_PEM = os.environ.get('CLIENT_PEM',
                            os.path.join(CERT_PATH, 'client.pem'))
CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem'))

# Client options used to probe for, and then connect to, a TLS server.
TLS_OPTIONS = dict(tls=True)
if CLIENT_PEM:
    TLS_OPTIONS['tlsCertificateKeyFile'] = CLIENT_PEM
if CA_PEM:
    TLS_OPTIONS['tlsCAFile'] = CA_PEM

COMPRESSORS = os.environ.get("COMPRESSORS")
MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION")
TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER"))
SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI")
MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI")
if TEST_LOADBALANCER:
    # Remove after PYTHON-2712
    from pymongo import pool
    pool._MOCK_SERVICE_ID = True
    # With a load balancer, the seed address (and any credentials) come
    # from SINGLE_MONGOS_LB_URI rather than DB_IP/DB_PORT.
    res = parse_uri(SINGLE_MONGOS_LB_URI)
    host, port = res['nodelist'][0]
    db_user = res['username'] or db_user
    db_pwd = res['password'] or db_pwd


def is_server_resolvable():
    """Returns True if 'server' is resolvable.

    The process-wide default socket timeout is set to 1 second for the
    lookup and the previous value is restored afterwards.
    """
    socket_timeout = socket.getdefaulttimeout()
    socket.setdefaulttimeout(1)
    try:
        try:
            socket.gethostbyname('server')
            return True
        except socket.error:
            return False
    finally:
        socket.setdefaulttimeout(socket_timeout)


def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
    """Run createUser for ``user`` on ``authdb`` and return the reply.

    ``roles`` defaults to ['root']; any extra keyword arguments are added
    to the command document.
    """
    cmd = SON([('createUser', user)])
    # X509 doesn't use a password
    if pwd:
        cmd['pwd'] = pwd
    cmd['roles'] = roles or ['root']
    cmd.update(**kwargs)
    return authdb.command(cmd)


class client_knobs(object):
    """Temporarily override timing constants in ``pymongo.common``.

    Use as a context manager, or call enable() and disable() explicitly.
    Arguments left as None keep the current module value.
    """

    def __init__(
            self,
            heartbeat_frequency=None,
            min_heartbeat_interval=None,
            kill_cursor_frequency=None,
            events_queue_frequency=None):
        self.heartbeat_frequency = heartbeat_frequency
        self.min_heartbeat_interval = min_heartbeat_interval
        self.kill_cursor_frequency = kill_cursor_frequency
        self.events_queue_frequency = events_queue_frequency

        self.old_heartbeat_frequency = None
        self.old_min_heartbeat_interval = None
        self.old_kill_cursor_frequency = None
        self.old_events_queue_frequency = None
        # Only enable() flips this to True. Starting at True made __del__ of
        # a never-enabled instance call disable(), which overwrote the
        # pymongo.common constants with the None placeholders above.
        self._enabled = False
        self._stack = None

    def enable(self):
        """Save the current constants, then apply every non-None override."""
        self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
        self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL
        self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY
        self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY

        if self.heartbeat_frequency is not None:
            common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency

        if self.min_heartbeat_interval is not None:
            common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval

        if self.kill_cursor_frequency is not None:
            common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency

        if self.events_queue_frequency is not None:
            common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency
        self._enabled = True
        # Store the allocation traceback to catch non-disabled client_knobs.
        self._stack = ''.join(traceback.format_stack())

    def __enter__(self):
        self.enable()
        return self

    def disable(self):
        """Restore the constants saved by the last enable() call."""
        common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
        common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
        common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency
        common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency
        self._enabled = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.disable()

    def __del__(self):
        # An instance that was enabled but never disabled would leak its
        # overrides into later tests: restore the saved values, then raise
        # so the problem and the enabling stack get reported (Python prints,
        # rather than propagates, exceptions raised in __del__).
        if self._enabled:
            msg = (
                'ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY=%s, '
                'MIN_HEARTBEAT_INTERVAL=%s, KILL_CURSOR_FREQUENCY=%s, '
                'EVENTS_QUEUE_FREQUENCY=%s, stack:\n%s' % (
                    common.HEARTBEAT_FREQUENCY,
                    common.MIN_HEARTBEAT_INTERVAL,
                    common.KILL_CURSOR_FREQUENCY,
                    common.EVENTS_QUEUE_FREQUENCY,
                    self._stack))
            self.disable()
            raise Exception(msg)


def _all_users(db):
    """Return the set of user names reported by usersInfo on ``db``."""
    return set(u['user'] for u in db.command('usersInfo').get('users', []))


class ClientContext(object):
    """Lazily connects to the test deployment and records its properties.

    The ``require_*`` methods build decorators that raise SkipTest when the
    connected deployment does not meet a test's requirements.
    """
    MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI

    def __init__(self):
        """Set the default, unconnected state; init() does the connecting."""
        self.connection_attempts = []
        self.connected = False
        self.w = None
        self.nodes = set()
        self.replica_set_name = None
        self.cmd_line = None
        self.server_status = None
        self.version = Version(-1)  # Needs to be comparable with Version
        self.auth_enabled = False
        self.test_commands_enabled = False
        self.server_parameters = None
        self.is_mongos = False
        self.mongoses = []
        self.is_rs = False
        self.has_ipv6 = False
        self.tls = False
        self.ssl_certfile = False
        self.server_is_resolvable = is_server_resolvable()
        self.default_client_options = {}
        self.sessions_enabled = False
        self.client = None
        self.conn_lock = threading.Lock()
        self.is_data_lake = False
        self.load_balancer = TEST_LOADBALANCER
        if self.load_balancer:
            self.default_client_options["loadBalanced"] = True
        if COMPRESSORS:
            self.default_client_options["compressors"] = COMPRESSORS
        if MONGODB_API_VERSION:
            server_api = ServerApi(MONGODB_API_VERSION)
            self.default_client_options["server_api"] = server_api

    @property
    def hello(self):
        """Result of the legacy hello command on the admin database."""
        return self.client.admin.command(HelloCompat.LEGACY_CMD)

    def _connect(self, host, port, **kwargs):
        """Probe ``host:port`` and return a new MongoClient, or None.

        The probe client uses a short server selection timeout and is
        always closed; the returned client uses the default timeout.
        ``self.default_client_options`` override entries in ``kwargs``.
        Every attempt is recorded in ``self.connection_attempts``.
        """
        # Jython takes a long time to connect.
        if sys.platform.startswith('java'):
            timeout_ms = 10000
        else:
            timeout_ms = 5000
        kwargs.update(self.default_client_options)
        client = pymongo.MongoClient(
            host, port, serverSelectionTimeoutMS=timeout_ms, **kwargs)
        try:
            try:
                client.admin.command('ping')  # Can we connect?
            except pymongo.errors.OperationFailure as exc:
                # SERVER-32063: the server answered, so we are connected.
                self.connection_attempts.append(
                    'connected client %r, but hello failed: %s' % (
                        client, exc))
            else:
                self.connection_attempts.append(
                    'successfully connected client %r' % (client,))
            # If connected, then return client with default timeout
            return pymongo.MongoClient(host, port, **kwargs)
        except pymongo.errors.ConnectionFailure as exc:
            self.connection_attempts.append(
                'failed to connect client %r: %s' % (client, exc))
            return None
        finally:
            client.close()

    def _init_client(self):
        """Connect to the deployment and populate the topology attributes.

        Tries a plain connection first, then TLS_OPTIONS. Reconnects with
        db_user/db_pwd when auth is enabled, creating that user if needed.
        """
        self.client = self._connect(host, port)

        if self.client is not None:
            # Return early when connected to dataLake as mongohoused does not
            # support the getCmdLineOpts command and is tested without TLS.
            build_info = self.client.admin.command('buildInfo')
            if 'dataLake' in build_info:
                self.is_data_lake = True
                self.auth_enabled = True
                self.client = self._connect(
                    host, port, username=db_user, password=db_pwd)
                self.connected = True
                return

        if HAVE_SSL and not self.client:
            # Is MongoDB configured for SSL?
            self.client = self._connect(host, port, **TLS_OPTIONS)
            if self.client:
                self.tls = True
                self.default_client_options.update(TLS_OPTIONS)
                self.ssl_certfile = True

        if self.client:
            self.connected = True

            try:
                self.cmd_line = self.client.admin.command('getCmdLineOpts')
            except pymongo.errors.OperationFailure as e:
                msg = e.details.get('errmsg', '')
                if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
                    # Unauthorized.
                    self.auth_enabled = True
                else:
                    raise
            else:
                self.auth_enabled = self._server_started_with_auth()

            if self.auth_enabled:
                # See if db_user already exists.
                if not self._check_user_provided():
                    _create_user(self.client.admin, db_user, db_pwd)

                self.client = self._connect(
                    host, port, username=db_user, password=db_pwd,
                    replicaSet=self.replica_set_name,
                    **self.default_client_options)

                # May not have this if OperationFailure was raised earlier.
                self.cmd_line = self.client.admin.command('getCmdLineOpts')

            self.server_status = self.client.admin.command('serverStatus')
            if self.storage_engine == "mmapv1":
                # MMAPv1 does not support retryWrites=True.
                self.default_client_options['retryWrites'] = False

            hello = self.hello
            self.sessions_enabled = 'logicalSessionTimeoutMinutes' in hello

            if 'setName' in hello:
                self.replica_set_name = str(hello['setName'])
                self.is_rs = True
                if self.auth_enabled:
                    # It doesn't matter which member we use as the seed here.
                    self.client = pymongo.MongoClient(
                        host,
                        port,
                        username=db_user,
                        password=db_pwd,
                        replicaSet=self.replica_set_name,
                        **self.default_client_options)
                else:
                    self.client = pymongo.MongoClient(
                        host,
                        port,
                        replicaSet=self.replica_set_name,
                        **self.default_client_options)

                # Get the authoritative hello result from the primary.
                hello = self.hello
                nodes = [partition_node(node.lower())
                         for node in hello.get('hosts', [])]
                nodes.extend([partition_node(node.lower())
                              for node in hello.get('passives', [])])
                nodes.extend([partition_node(node.lower())
                              for node in hello.get('arbiters', [])])
                self.nodes = set(nodes)
            else:
                self.nodes = set([(host, port)])
            self.w = len(hello.get("hosts", [])) or 1
            self.version = Version.from_client(self.client)
            self.server_parameters = self.client.admin.command(
                'getParameter', '*')

            # enableTestCommands may appear in argv, as a list entry of
            # parsed setParameter, or as a key of a parsed setParameter dict.
            if 'enableTestCommands=1' in self.cmd_line['argv']:
                self.test_commands_enabled = True
            elif 'parsed' in self.cmd_line:
                params = self.cmd_line['parsed'].get('setParameter', [])
                if 'enableTestCommands=1' in params:
                    self.test_commands_enabled = True
                else:
                    params = self.cmd_line['parsed'].get('setParameter', {})
                    if params.get('enableTestCommands') == '1':
                        self.test_commands_enabled = True

            self.is_mongos = (self.hello.get('msg') == 'isdbgrid')
            self.has_ipv6 = self._server_started_with_ipv6()
            if self.is_mongos:
                # Check for another mongos on the next port.
                address = self.client.address
                next_address = address[0], address[1] + 1
                self.mongoses.append(address)
                mongos_client = self._connect(*next_address,
                                              **self.default_client_options)
                if mongos_client:
                    # The probe client is only needed for this check; close
                    # it rather than leaving it open for the whole run.
                    try:
                        hello = mongos_client.admin.command(
                            HelloCompat.LEGACY_CMD)
                        if hello.get('msg') == 'isdbgrid':
                            self.mongoses.append(next_address)
                    finally:
                        mongos_client.close()

    def init(self):
        """Connect once; later calls do nothing, even after a failed try."""
        with self.conn_lock:
            if not self.client and not self.connection_attempts:
                self._init_client()

    def connection_attempt_info(self):
        """Return the recorded connection attempts, one per line."""
        return '\n'.join(self.connection_attempts)

    @property
    def host(self):
        """The primary's host for a replica set, otherwise the seed host."""
        if self.is_rs:
            primary = self.client.primary
            return str(primary[0]) if primary is not None else host
        return host

    @property
    def port(self):
        """The primary's port for a replica set, otherwise the seed port."""
        if self.is_rs:
            primary = self.client.primary
            return primary[1] if primary is not None else port
        return port

    @property
    def pair(self):
        """``host:port`` of the current host and port properties."""
        return "%s:%d" % (self.host, self.port)

    @property
    def has_secondaries(self):
        """True if the connected client currently sees any secondaries."""
        if not self.client:
            return False
        return bool(len(self.client.secondaries))

    @property
    def storage_engine(self):
        """Storage engine name from serverStatus, or None if unknown."""
        try:
            return self.server_status.get("storageEngine", {}).get("name")
        except AttributeError:
            # Raised if self.server_status is None.
            return None

    def _check_user_provided(self):
        """Return True if db_user/db_password is already an admin user."""
        client = pymongo.MongoClient(
            host, port,
            username=db_user,
            password=db_pwd,
            serverSelectionTimeoutMS=100,
            **self.default_client_options)

        try:
            return db_user in _all_users(client.admin)
        except pymongo.errors.OperationFailure as e:
            msg = e.details.get('errmsg', '')
            if e.code == 18 or 'auth fails' in msg:
                # Auth failed.
                return False
            else:
                raise
        finally:
            # This client exists only for the check above; don't leak it.
            client.close()

    def _server_started_with_auth(self):
        """Infer from getCmdLineOpts whether auth or a keyFile is enabled."""
        # MongoDB >= 2.0
        if 'parsed' in self.cmd_line:
            parsed = self.cmd_line['parsed']
            # MongoDB >= 2.6
            if 'security' in parsed:
                security = parsed['security']
                # >= rc3
                if 'authorization' in security:
                    return security['authorization'] == 'enabled'
                # < rc3
                return (security.get('auth', False) or
                        bool(security.get('keyFile')))
            return parsed.get('auth', False) or bool(parsed.get('keyFile'))
        # Legacy
        argv = self.cmd_line['argv']
        return '--auth' in argv or '--keyFile' in argv

    def _server_started_with_ipv6(self):
        """True if the server enables IPv6 and resolves to an IPv6 address."""
        if not socket.has_ipv6:
            return False

        if 'parsed' in self.cmd_line:
            if not self.cmd_line['parsed'].get('net', {}).get('ipv6'):
                return False
        else:
            if '--ipv6' not in self.cmd_line['argv']:
                return False

        # The server was started with --ipv6. Is there an IPv6 route to it?
        try:
            for info in socket.getaddrinfo(self.host, self.port):
                if info[0] == socket.AF_INET6:
                    return True
        except socket.error:
            pass

        return False

    def _require(self, condition, msg, func=None):
        """Build a decorator that runs a test only when ``condition()``.

        The wrapper calls init() first and always skips when no connection
        could be made. With ``func`` given, return the wrapped function;
        otherwise return a decorator.
        """
        def make_wrapper(f):
            @wraps(f)
            def wrap(*args, **kwargs):
                self.init()
                # Always raise SkipTest if we can't connect to MongoDB
                if not self.connected:
                    raise SkipTest(
                        "Cannot connect to MongoDB on %s" % (self.pair,))
                if condition():
                    return f(*args, **kwargs)
                raise SkipTest(msg)
            return wrap

        if func is None:
            def decorate(f):
                return make_wrapper(f)
            return decorate
        return make_wrapper(func)

    def create_user(self, dbname, user, pwd=None, roles=None, **kwargs):
        """Create a user on ``dbname`` using writeConcern w=self.w."""
        kwargs['writeConcern'] = {'w': self.w}
        return _create_user(self.client[dbname], user, pwd, roles, **kwargs)

    def drop_user(self, dbname, user):
        """Drop a user from ``dbname`` using writeConcern w=self.w."""
        self.client[dbname].command(
            'dropUser', user, writeConcern={'w': self.w})

    def require_connection(self, func):
        """Run a test only if we can connect to MongoDB."""
        return self._require(
            lambda: True,  # _require checks if we're connected
            "Cannot connect to MongoDB on %s" % (self.pair,),
            func=func)

    def require_data_lake(self, func):
        """Run a test only if we are connected to Atlas Data Lake."""
        return self._require(
            lambda: self.is_data_lake,
            "Not connected to Atlas Data Lake on %s" % (self.pair,),
            func=func)

    def require_no_mmap(self, func):
        """Run a test only if the server is not using the MMAPv1 storage
        engine. Only works for standalone and replica sets; tests are
        run regardless of storage engine on sharded clusters. """
        def is_not_mmap():
            if self.is_mongos:
                return True
            return self.storage_engine != 'mmapv1'

        return self._require(
            is_not_mmap, "Storage engine must not be MMAPv1", func=func)

    def require_version_min(self, *ver):
        """Run a test only if the server version is at least ``version``."""
        other_version = Version(*ver)
        return self._require(lambda: self.version >= other_version,
                             "Server version must be at least %s"
                             % str(other_version))

    def require_version_max(self, *ver):
        """Run a test only if the server version is at most ``version``."""
        other_version = Version(*ver)
        return self._require(lambda: self.version <= other_version,
                             "Server version must be at most %s"
                             % str(other_version))

    def require_auth(self, func):
        """Run a test only if the server is running with auth enabled."""
        return self.check_auth_with_sharding(
            self._require(lambda: self.auth_enabled,
                          "Authentication is not enabled on the server",
                          func=func))

    def require_no_auth(self, func):
        """Run a test only if the server is running without auth enabled."""
        return self._require(lambda: not self.auth_enabled,
                             "Authentication must not be enabled on the server",
                             func=func)

    def require_replica_set(self, func):
        """Run a test only if the client is connected to a replica set."""
        return self._require(lambda: self.is_rs,
                             "Not connected to a replica set",
                             func=func)

    def require_secondaries_count(self, count):
        """Run a test only if the client is connected to a replica set that has
        at least `count` secondaries.
        """
        def sec_count():
            return 0 if not self.client else len(self.client.secondaries)
        return self._require(lambda: sec_count() >= count,
                             "Not enough secondaries available")

    @property
    def supports_secondary_read_pref(self):
        """True if there are secondaries, or if connected to mongos and the
        first shard listed in config.shards has more than one member."""
        if self.has_secondaries:
            return True
        if self.is_mongos:
            shard = self.client.config.shards.find_one()['host']
            num_members = shard.count(',') + 1
            return num_members > 1
        return False

    def require_secondary_read_pref(self):
        """Run a test only if the client is connected to a cluster that
        supports secondary read preference
        """
        return self._require(lambda: self.supports_secondary_read_pref,
                             "This cluster does not support secondary read "
                             "preference")

    def require_no_replica_set(self, func):
        """Run a test if the client is *not* connected to a replica set."""
        return self._require(
            lambda: not self.is_rs,
            "Connected to a replica set, not a standalone mongod",
            func=func)

    def require_ipv6(self, func):
        """Run a test only if the client can connect to a server via IPv6."""
        return self._require(lambda: self.has_ipv6,
                             "No IPv6",
                             func=func)

    def require_no_mongos(self, func):
        """Run a test only if the client is not connected to a mongos."""
        return self._require(lambda: not self.is_mongos,
                             "Must be connected to a mongod, not a mongos",
                             func=func)

    def require_mongos(self, func):
        """Run a test only if the client is connected to a mongos."""
        return self._require(lambda: self.is_mongos,
                             "Must be connected to a mongos",
                             func=func)

    def require_multiple_mongoses(self, func):
        """Run a test only if the client is connected to a sharded cluster
        that has at least 2 mongos nodes."""
        return self._require(lambda: len(self.mongoses) > 1,
                             "Must have multiple mongoses available",
                             func=func)

    def require_standalone(self, func):
        """Run a test only if the client is connected to a standalone."""
        return self._require(lambda: not (self.is_mongos or self.is_rs),
                             "Must be connected to a standalone",
                             func=func)

    def require_no_standalone(self, func):
        """Run a test only if the client is not connected to a standalone."""
        return self._require(lambda: self.is_mongos or self.is_rs,
                             "Must be connected to a replica set or mongos",
                             func=func)

    def require_load_balancer(self, func):
        """Run a test only if the client is connected to a load balancer."""
        return self._require(lambda: self.load_balancer,
                             "Must be connected to a load balancer",
                             func=func)

    def require_no_load_balancer(self, func):
        """Run a test only if the client is not connected to a load balancer.
        """
        return self._require(lambda: not self.load_balancer,
                             "Must not be connected to a load balancer",
                             func=func)

    def check_auth_with_sharding(self, func):
        """Skip a test when connected to mongos < 2.0 and running with auth."""
        def condition():
            return not (self.auth_enabled and
                        self.is_mongos and self.version < (2,))
        return self._require(condition,
                             "Auth with sharding requires MongoDB >= 2.0.0",
                             func=func)

    def is_topology_type(self, topologies):
        """Return True if the deployment matches one of ``topologies``.

        Raises AssertionError for any name other than 'single',
        'replicaset', 'sharded', 'sharded-replicaset' and 'load-balanced'.
        """
        unknown = set(topologies) - {'single', 'replicaset', 'sharded',
                                     'sharded-replicaset', 'load-balanced'}
        if unknown:
            raise AssertionError('Unknown topologies: %r' % (unknown,))
        if self.load_balancer:
            # A load-balanced deployment only matches 'load-balanced'.
            return 'load-balanced' in topologies
        if 'single' in topologies and not (self.is_mongos or self.is_rs):
            return True
        if 'replicaset' in topologies and self.is_rs:
            return True
        if 'sharded' in topologies and self.is_mongos:
            return True
        if 'sharded-replicaset' in topologies and self.is_mongos:
            shards = list(self.client.config.shards.find())
            for shard in shards:
                # For a 3-member RS-backed sharded cluster, shard['host']
                # will be 'replicaName/ip1:port1,ip2:port2,ip3:port3'
                # Otherwise it will be 'ip1:port1'
                if '/' not in shard['host']:
                    return False
            return True
        return False

    def require_cluster_type(self, topologies=None):
        """Run a test only if the client is connected to a cluster that
        conforms to one of the specified topologies. Acceptable topologies
        are 'single', 'replicaset', 'sharded', 'sharded-replicaset' and
        'load-balanced'."""
        if topologies is None:
            topologies = []

        def _is_valid_topology():
            return self.is_topology_type(topologies)
        # Wrap in a 1-tuple so a tuple argument is formatted as one value.
        return self._require(
            _is_valid_topology,
            "Cluster type not in %s" % (topologies,))

    def require_test_commands(self, func):
        """Run a test only if the server has test commands enabled."""
        return self._require(lambda: self.test_commands_enabled,
                             "Test commands must be enabled",
                             func=func)

    def require_failCommand_fail_point(self, func):
        """Run a test only if the server supports the failCommand fail
        point."""
        return self._require(lambda: self.supports_failCommand_fail_point,
                             "failCommand fail point must be supported",
                             func=func)

    def require_failCommand_appName(self, func):
        """Run a test only if the server supports the failCommand appName."""
        # SERVER-47195
        return self._require(lambda: (self.test_commands_enabled and
                                      self.version >= (4, 4, -1)),
                             "failCommand appName must be supported",
                             func=func)

    def require_tls(self, func):
        """Run a test only if the client can connect over TLS."""
        return self._require(lambda: self.tls,
                             "Must be able to connect via TLS",
                             func=func)

    def require_no_tls(self, func):
        """Run a test only if the client connects without TLS."""
        return self._require(lambda: not self.tls,
                             "Must be able to connect without TLS",
                             func=func)

    def require_ssl_certfile(self, func):
        """Run a test only if the client can connect with ssl_certfile."""
        return self._require(lambda: self.ssl_certfile,
                             "Must be able to connect with ssl_certfile",
                             func=func)

    def require_server_resolvable(self, func):
        """Run a test only if the hostname 'server' is resolvable."""
        return self._require(lambda: self.server_is_resolvable,
                             "No hosts entry for 'server'. Cannot validate "
                             "hostname in the certificate",
                             func=func)

    def require_sessions(self, func):
        """Run a test only if the deployment supports sessions."""
        return self._require(lambda: self.sessions_enabled,
                             "Sessions not supported",
                             func=func)

    def supports_transactions(self):
        """True for non-MMAPv1 replica sets on 4.0+, or mongos on 4.1.8+."""
        if self.storage_engine == 'mmapv1':
            return False

        if self.version.at_least(4, 1, 8):
            return self.is_mongos or self.is_rs

        if self.version.at_least(4, 0):
            return self.is_rs

        return False

    def require_transactions(self, func):
        """Run a test only if the deployment might support transactions.

        *Might* because this does not test the storage engine or FCV.
        """
        return self._require(self.supports_transactions,
                             "Transactions are not supported",
                             func=func)

    def require_no_api_version(self, func):
        """Skip this test when testing with requireApiVersion."""
        return self._require(lambda: not MONGODB_API_VERSION,
                             "This test does not work with requireApiVersion",
                             func=func)

    def mongos_seeds(self):
        """Comma-separated ``host:port`` list of the discovered mongoses."""
        return ','.join('%s:%s' % address for address in self.mongoses)

    @property
    def supports_reindex(self):
        """Does the connected server support reindex?"""
        return not ((self.version.at_least(4, 1, 0) and self.is_mongos) or
                    (self.version.at_least(4, 5, 0) and (
                        self.is_mongos or self.is_rs)))

    @property
    def supports_getpreverror(self):
        """Does the connected server support getpreverror?"""
        return not (self.version.at_least(4, 1, 0) or self.is_mongos)

    @property
    def supports_failCommand_fail_point(self):
        """Does the server support the failCommand fail point?"""
        if self.is_mongos:
            return (self.version.at_least(4, 1, 5) and
                    self.test_commands_enabled)
        else:
            return (self.version.at_least(4, 0) and
                    self.test_commands_enabled)

    @property
    def requires_hint_with_min_max_queries(self):
        """Does the server require a hint with min/max queries."""
        # Changed in SERVER-39567.
        return self.version.at_least(4, 1, 10)


# Reusable client context
client_context = ClientContext()


def sanitize_cmd(cmd):
    """Return a copy of ``cmd`` without driver-added fields, for comparison."""
    cp = cmd.copy()
    cp.pop('$clusterTime', None)
    cp.pop('$db', None)
    cp.pop('$readPreference', None)
    cp.pop('lsid', None)
    if MONGODB_API_VERSION:
        # Versioned api parameters
        cp.pop('apiVersion', None)
    # OP_MSG encoding may move the payload type one field to the
    # end of the command. Do the same here. An empty command has no
    # name; None is not a _FIELD_MAP key, so it falls into the KeyError path.
    name = next(iter(cp), None)
    try:
        identifier = message._FIELD_MAP[name]
        docs = cp.pop(identifier)
        cp[identifier] = docs
    except KeyError:
        pass
    return cp


def sanitize_reply(reply):
    """Return a copy of ``reply`` without $clusterTime and operationTime."""
    cp = reply.copy()
    cp.pop('$clusterTime', None)
    cp.pop('operationTime', None)
    return cp


class PyMongoTestCase(unittest.TestCase):
    """TestCase with command/reply comparison helpers and fail points."""

    def assertEqualCommand(self, expected, actual, msg=None):
        self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)

    def assertEqualReply(self, expected, actual, msg=None):
        self.assertEqual(sanitize_reply(expected), sanitize_reply(actual), msg)

    @contextmanager
    def fail_point(self, command_args):
        """Enable the failCommand fail point for the duration of the block."""
        cmd_on = SON([('configureFailPoint', 'failCommand')])
        cmd_on.update(command_args)
        client_context.client.admin.command(cmd_on)
        try:
            yield
        finally:
            client_context.client.admin.command(
                'configureFailPoint', cmd_on['configureFailPoint'], mode='off')


class IntegrationTest(PyMongoTestCase):
    """Base class for TestCases that need a connection to MongoDB to pass."""

    @classmethod
    @client_context.require_connection
    def setUpClass(cls):
        if (client_context.load_balancer and
                not getattr(cls, 'RUN_ON_LOAD_BALANCER', False)):
            raise SkipTest('this test does not support load balancers')
        cls.client = client_context.client
        cls.db = cls.client.pymongo_test
        if client_context.auth_enabled:
            cls.credentials = {'username': db_user, 'password': db_pwd}
        else:
            cls.credentials = {}

    def patch_system_certs(self, ca_certs):
        """Point SSL_CERT_FILE at ``ca_certs`` until the test finishes."""
        patcher = SystemCertsPatcher(ca_certs)
        self.addCleanup(patcher.disable)


# Use assertRaisesRegex if available, otherwise use Python 2.7's
# deprecated assertRaisesRegexp, with a 'p'.
if not hasattr(unittest.TestCase, 'assertRaisesRegex'):
    unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp


class MockClientTest(unittest.TestCase):
    """Base class for TestCases that use MockClient.

    This class is *not* an IntegrationTest: if properly written, MockClient
    tests do not require a running server.

    The class temporarily overrides HEARTBEAT_FREQUENCY to speed up tests.
    """

    # MockClients tests that use replicaSet, directConnection=True, pass
    # multiple seed addresses, or wait for heartbeat events are incompatible
    # with loadBalanced=True.
    @classmethod
    @client_context.require_no_load_balancer
    def setUpClass(cls):
        pass

    def setUp(self):
        super(MockClientTest, self).setUp()

        self.client_knobs = client_knobs(
            heartbeat_frequency=0.001,
            min_heartbeat_interval=0.001)

        self.client_knobs.enable()

    def tearDown(self):
        self.client_knobs.disable()
        super(MockClientTest, self).tearDown()


def setup():
    """Connect the shared client_context and show every warning."""
    client_context.init()
    warnings.resetwarnings()
    warnings.simplefilter("always")


def _get_executors(topology):
    """Collect the non-None background executors owned by ``topology``."""
    executors = []
    for server in topology._servers.values():
        # Some MockMonitor do not have an _executor.
        if hasattr(server._monitor, '_executor'):
            executors.append(server._monitor._executor)
        if hasattr(server._monitor, '_rtt_monitor'):
            executors.append(server._monitor._rtt_monitor._executor)
    executors.append(topology._Topology__events_executor)
    if topology._srv_monitor:
        executors.append(topology._srv_monitor._executor)

    return [e for e in executors if e is not None]


def all_executors_stopped(topology):
    """Return True if every executor is stopped; otherwise print details."""
    running = [e for e in _get_executors(topology) if not e._stopped]
    if running:
        print('    Topology %s has THREADS RUNNING: %s, created at: %s' % (
            topology, running, topology._settings._stack))
        return False
    return True


def print_unclosed_clients():
    """Print each live Topology that still has running executors."""
    from pymongo.topology import Topology
    processed = set()
    # Call collect to manually cleanup any would-be gc'd clients to avoid
    # false positives.
    gc.collect()
    for obj in gc.get_objects():
        try:
            if isinstance(obj, Topology):
                # Avoid printing the same Topology multiple times.
                if obj._topology_id in processed:
                    continue
                all_executors_stopped(obj)
                processed.add(obj._topology_id)
        except ReferenceError:
            pass


def teardown():
    """Fail on uncollectable garbage, drop test databases, close the client,
    and report clients that were left open."""
    garbage = []
    for g in gc.garbage:
        garbage.append('GARBAGE: %r' % (g,))
        garbage.append('  gc.get_referents: %r' % (gc.get_referents(g),))
        garbage.append('  gc.get_referrers: %r' % (gc.get_referrers(g),))
    if garbage:
        # Raise explicitly: an ``assert`` statement is stripped under -O.
        raise AssertionError('\n'.join(garbage))
    c = client_context.client
    if c:
        if not client_context.is_data_lake:
            c.drop_database("pymongo-pooling-tests")
            c.drop_database("pymongo_test")
            c.drop_database("pymongo_test1")
            c.drop_database("pymongo_test2")
            c.drop_database("pymongo_test_mike")
            c.drop_database("pymongo_test_bernie")
        c.close()

    # Jython does not support gc.get_objects.
    if not sys.platform.startswith('java'):
        print_unclosed_clients()


class PymongoTestRunner(unittest.TextTestRunner):
    """TextTestRunner that wraps the run in setup() and teardown()."""

    def run(self, test):
        setup()
        result = super(PymongoTestRunner, self).run(test)
        teardown()
        return result


if HAVE_XML:
    class PymongoXMLTestRunner(XMLTestRunner):
        """XMLTestRunner that wraps the run in setup() and teardown()."""

        def run(self, test):
            setup()
            result = super(PymongoXMLTestRunner, self).run(test)
            teardown()
            return result


def test_cases(suite):
    """Iterator over all TestCases within a TestSuite."""
    for suite_or_case in suite._tests:
        if isinstance(suite_or_case, unittest.TestCase):
            # unittest.TestCase
            yield suite_or_case
        else:
            # unittest.TestSuite
            for case in test_cases(suite_or_case):
                yield case


# Helper method to workaround https://bugs.python.org/issue21724
def clear_warning_registry():
    """Clear the __warningregistry__ for all modules."""
    for module in list(sys.modules.values()):
        if hasattr(module, "__warningregistry__"):
            setattr(module, "__warningregistry__", {})


class SystemCertsPatcher(object):
    """Point OpenSSL at ``ca_certs`` through the SSL_CERT_FILE variable.

    Raises SkipTest where CA certificates cannot be loaded this way.
    disable() restores the previous value (or removes the variable).
    """

    def __init__(self, ca_certs):
        # Without SSL support the ``ssl`` module is never imported above.
        if not HAVE_SSL:
            raise SkipTest("No SSL support.")
        if sys.version_info < (2, 7, 9):
            raise SkipTest("Can't load system CA certificates.")
        if (ssl.OPENSSL_VERSION.lower().startswith('libressl') and
                sys.platform == 'darwin' and not _ssl.IS_PYOPENSSL):
            raise SkipTest(
                "LibreSSL on OSX doesn't support setting CA certificates "
                "using SSL_CERT_FILE environment variable.")
        self.original_certs = os.environ.get('SSL_CERT_FILE')
        # Tell OpenSSL where CA certificates live.
        os.environ['SSL_CERT_FILE'] = ca_certs

    def disable(self):
        """Restore SSL_CERT_FILE; safe if it was already removed."""
        if self.original_certs is None:
            os.environ.pop('SSL_CERT_FILE', None)
        else:
            os.environ['SSL_CERT_FILE'] = self.original_certs