# Copyright 2012-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for testing pymongo."""

import contextlib
import copy
import functools
import os
import re
import shutil
import sys
import threading
import time
import warnings

from collections import defaultdict
from functools import partial

from bson import json_util, py3compat
from bson.objectid import ObjectId
from bson.py3compat import abc, iteritems, string_type
from bson.son import SON

from pymongo import (MongoClient,
                     monitoring, operations, read_preferences)
from pymongo.collection import ReturnDocument
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.hello_compat import HelloCompat
from pymongo.monitoring import _SENSITIVE_COMMANDS
from pymongo.pool import (_CancellationContext,
                          PoolOptions,
                          _PoolGeneration)
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.server_selectors import (any_server_selector,
                                      writable_server_selector)
from pymongo.server_type import SERVER_TYPE
from pymongo.write_concern import WriteConcern

from test import (client_context,
                  db_user,
                  db_pwd)

# NOTE(review): Barrier is not used in this module; presumably it is
# imported here so other test modules can get it from one place.
if sys.version_info[0] < 3:
    # Python 2.7, use our backport.
    from test.barrier import Barrier
else:
    from threading import Barrier


# A write concern requiring acknowledgement from 50 members; presumably
# larger than any test deployment, so it can never be satisfied.
IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50)


class BaseListener(object):
    """Base class for listeners that record every event in ``self.events``."""

    def __init__(self):
        self.events = []

    def reset(self):
        """Forget all recorded events."""
        self.events = []

    def add_event(self, event):
        """Record ``event``."""
        self.events.append(event)

    def event_count(self, event_type):
        """Return how many recorded events are instances of ``event_type``."""
        return len(self.events_by_type(event_type))

    def events_by_type(self, event_type):
        """Return the matching events by event class.

        event_type can be a single class or a tuple of classes.
        """
        return self.matching(lambda e: isinstance(e, event_type))

    def matching(self, matcher):
        """Return the recorded events for which ``matcher(event)`` is true."""
        # Filter a snapshot of the list, presumably so events appended
        # concurrently do not disturb the iteration.
        return [event for event in self.events[:] if matcher(event)]

    def wait_for_event(self, event, count):
        """Wait for a number of events to be published, or fail."""
        wait_until(lambda: self.event_count(event) >= count,
                   'find %s %s event(s)' % (count, event))


class CMAPListener(BaseListener, monitoring.ConnectionPoolListener):
    """Records every connection pool (CMAP) event it receives."""

    def connection_created(self, event):
        self.add_event(event)

    def connection_ready(self, event):
        self.add_event(event)

    def connection_closed(self, event):
        self.add_event(event)

    def connection_check_out_started(self, event):
        self.add_event(event)

    def connection_check_out_failed(self, event):
        self.add_event(event)

    def connection_checked_out(self, event):
        self.add_event(event)

    def connection_checked_in(self, event):
        self.add_event(event)

    def pool_created(self, event):
        self.add_event(event)

    def pool_cleared(self, event):
        self.add_event(event)

    def pool_closed(self, event):
        self.add_event(event)


class EventListener(monitoring.CommandListener):
    """Command listener that groups events under the keys 'started',
    'succeeded' and 'failed' of ``self.results``."""

    def __init__(self):
        self.results = defaultdict(list)

    def started(self, event):
        self.results['started'].append(event)

    def succeeded(self, event):
        self.results['succeeded'].append(event)

    def failed(self, event):
        self.results['failed'].append(event)

    def started_command_names(self):
        """Return list of command names started."""
        return [event.command_name for event in self.results['started']]

    def reset(self):
        """Reset the state of this listener."""
        self.results.clear()


class TopologyEventListener(monitoring.TopologyListener):
    """Topology listener that groups events under the keys 'opened',
    'description_changed' and 'closed' of ``self.results``."""

    def __init__(self):
        self.results = defaultdict(list)

    def closed(self, event):
        self.results['closed'].append(event)

    def description_changed(self, event):
        self.results['description_changed'].append(event)

    def opened(self, event):
        self.results['opened'].append(event)

    def reset(self):
        """Reset the state of this listener."""
        self.results.clear()


class WhiteListEventListener(EventListener):
    """EventListener that only records events for the given command names."""

    def __init__(self, *commands):
        self.commands = set(commands)
        super(WhiteListEventListener, self).__init__()

    def started(self, event):
        if event.command_name in self.commands:
            super(WhiteListEventListener, self).started(event)

    def succeeded(self, event):
        if event.command_name in self.commands:
            super(WhiteListEventListener, self).succeeded(event)

    def failed(self, event):
        if event.command_name in self.commands:
            super(WhiteListEventListener, self).failed(event)


class OvertCommandListener(EventListener):
    """A CommandListener that ignores sensitive commands."""

    def started(self, event):
        if event.command_name.lower() not in _SENSITIVE_COMMANDS:
            super(OvertCommandListener, self).started(event)

    def succeeded(self, event):
        if event.command_name.lower() not in _SENSITIVE_COMMANDS:
            super(OvertCommandListener, self).succeeded(event)

    def failed(self, event):
        if event.command_name.lower() not in _SENSITIVE_COMMANDS:
            super(OvertCommandListener, self).failed(event)


class _ServerEventListener(object):
    """Listens to all events."""

    def __init__(self):
        self.results = []

    def opened(self, event):
        self.results.append(event)

    def description_changed(self, event):
        self.results.append(event)

    def closed(self, event):
        self.results.append(event)

    def matching(self, matcher):
        """Return the matching events."""
        results = self.results[:]
        return [event for event in results if matcher(event)]

    def reset(self):
        """Forget all recorded events."""
        self.results = []


class ServerEventListener(_ServerEventListener,
                          monitoring.ServerListener):
    """Listens to Server events."""


class ServerAndTopologyEventListener(ServerEventListener,
                                     monitoring.TopologyListener):
    """Listens to Server and Topology events."""


class HeartbeatEventListener(BaseListener, monitoring.ServerHeartbeatListener):
    """Listens to only server heartbeat events."""

    def started(self, event):
        self.add_event(event)

    def succeeded(self, event):
        self.add_event(event)

    def failed(self, event):
        self.add_event(event)


class MockSocketInfo(object):
    """Socket stand-in returned by MockPool.get_socket; it does no I/O."""

    def __init__(self):
        self.cancel_context = _CancellationContext()
        self.more_to_come = False

    def close_socket(self, reason):
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


class MockPool(object):
    """Connection pool stand-in that never opens sockets.

    It only tracks pool generations so that staleness checks behave.
    """

    def __init__(self, address, options, handshake=True):
        # address/options/handshake are accepted but unused, presumably to
        # match the real Pool constructor signature.
        self.gen = _PoolGeneration()
        self._lock = threading.Lock()
        self.opts = PoolOptions()

    def stale_generation(self, gen, service_id):
        return self.gen.stale(gen, service_id)

    def get_socket(self, all_credentials, handler=None):
        return MockSocketInfo()

    def return_socket(self, *args, **kwargs):
        pass

    def _reset(self, service_id=None):
        with self._lock:
            self.gen.inc(service_id)

    def reset(self, service_id=None):
        """Bump the pool generation, for ``service_id`` if one is given."""
        # Forward service_id; previously it was silently dropped.
        self._reset(service_id)

    def close(self):
        self._reset()

    def update_is_writable(self, is_writable):
        pass

    def remove_stale_sockets(self, *args, **kwargs):
        pass


class ScenarioDict(dict):
    """Dict that returns {} for any unknown key, recursively."""

    def __init__(self, data):
        def convert(v):
            if isinstance(v, abc.Mapping):
                return ScenarioDict(v)
            # Strings/bytes are Sequences too; keep them as-is.
            if isinstance(v, (py3compat.string_type, bytes)):
                return v
            if isinstance(v, abc.Sequence):
                return [convert(item) for item in v]
            return v

        dict.__init__(self, [(k, convert(v)) for k, v in data.items()])

    def __getitem__(self, item):
        try:
            return dict.__getitem__(self, item)
        except KeyError:
            # Unlike a defaultdict, don't set the key, just return a dict.
            return ScenarioDict({})


class CompareType(object):
    """Class that compares equal to any object of the given type."""

    def __init__(self, type):
        self.type = type

    def __eq__(self, other):
        return isinstance(other, self.type)

    def __ne__(self, other):
        """Needed for Python 2."""
        return not self.__eq__(other)


class FunctionCallRecorder(object):
    """Utility class to wrap a callable and record its invocations."""

    def __init__(self, function):
        self._function = function
        self._call_list = []

    def __call__(self, *args, **kwargs):
        self._call_list.append((args, kwargs))
        return self._function(*args, **kwargs)

    def reset(self):
        """Wipes the call list."""
        self._call_list = []

    def call_list(self):
        """Returns a copy of the call list."""
        return self._call_list[:]

    @property
    def call_count(self):
        """Returns the number of times the function has been called."""
        return len(self._call_list)


def _version_tuple(version):
    """Convert a dotted version string such as "4.4.1" to (4, 4, 1)."""
    return tuple(int(elt) for elt in version.split('.'))


class TestCreator(object):
    """Class to create test cases from specifications."""

    def __init__(self, create_test, test_class, test_path):
        """Create a TestCreator object.

        :Parameters:
          - `create_test`: callback that returns a test case. The callback
            must accept the following arguments - a dictionary containing the
            entire test specification (the `scenario_def`), a dictionary
            containing the specification for which the test case will be
            generated (the `test_def`), and the name chosen for the
            generated test method (the `test_name`).
          - `test_class`: the unittest.TestCase class in which to create the
            test case.
          - `test_path`: path to the directory containing the JSON files with
            the test specifications.
        """
        self._create_test = create_test
        self._test_class = test_class
        self.test_path = test_path

    def _ensure_min_max_server_version(self, scenario_def, method):
        """Test modifier that enforces a version range for the server on a
        test case."""
        if 'minServerVersion' in scenario_def:
            min_ver = _version_tuple(scenario_def['minServerVersion'])
            method = client_context.require_version_min(*min_ver)(method)

        if 'maxServerVersion' in scenario_def:
            max_ver = _version_tuple(scenario_def['maxServerVersion'])
            method = client_context.require_version_max(*max_ver)(method)

        return method

    @staticmethod
    def valid_topology(run_on_req):
        """Return whether the connected topology matches ``run_on_req``;
        any topology matches when 'topology' is absent."""
        return client_context.is_topology_type(
            run_on_req.get('topology', ['single', 'replicaset', 'sharded',
                                        'load-balanced']))

    @staticmethod
    def min_server_version(run_on_req):
        """Return whether the server meets ``minServerVersion`` (if set)."""
        version = run_on_req.get('minServerVersion')
        if version:
            return client_context.version >= _version_tuple(version)
        return True

    @staticmethod
    def max_server_version(run_on_req):
        """Return whether the server meets ``maxServerVersion`` (if set)."""
        version = run_on_req.get('maxServerVersion')
        if version:
            return client_context.version <= _version_tuple(version)
        return True

    def should_run_on(self, scenario_def):
        """Return whether any 'runOn' requirement is satisfied (or there are
        none)."""
        run_on = scenario_def.get('runOn', [])
        if not run_on:
            # Always run these tests.
            return True

        return any(self.valid_topology(req) and
                   self.min_server_version(req) and
                   self.max_server_version(req)
                   for req in run_on)

    def ensure_run_on(self, scenario_def, method):
        """Test modifier that enforces a 'runOn' on a test case."""
        return client_context._require(
            lambda: self.should_run_on(scenario_def),
            "runOn not satisfied",
            method)

    def tests(self, scenario_def):
        """Allow CMAP spec test to override the location of test."""
        return scenario_def['tests']

    def create_tests(self):
        """Walk ``test_path`` and attach one test method per test definition
        found in each JSON file to ``test_class``."""
        for dirpath, _, filenames in os.walk(self.test_path):
            dirname = os.path.split(dirpath)[-1]

            for filename in filenames:
                with open(os.path.join(dirpath, filename)) as scenario_stream:
                    # Use tz_aware=False to match how CodecOptions decodes
                    # dates.
                    opts = json_util.JSONOptions(tz_aware=False)
                    scenario_def = ScenarioDict(
                        json_util.loads(scenario_stream.read(),
                                        json_options=opts))

                test_type = os.path.splitext(filename)[0]

                # Construct test from scenario.
                for test_def in self.tests(scenario_def):
                    test_name = 'test_%s_%s_%s' % (
                        dirname,
                        test_type.replace("-", "_").replace('.', '_'),
                        str(test_def['description'].replace(" ", "_").replace(
                            '.', '_')))

                    new_test = self._create_test(
                        scenario_def, test_def, test_name)
                    new_test = self._ensure_min_max_server_version(
                        scenario_def, new_test)
                    new_test = self.ensure_run_on(
                        scenario_def, new_test)

                    new_test.__name__ = test_name
                    setattr(self._test_class, new_test.__name__, new_test)


def _connection_string(h, authenticate):
    """Build a mongodb:// URI for host ``h``, adding the test credentials
    when auth is enabled and ``authenticate`` is true."""
    if h.startswith("mongodb://"):
        return h
    elif client_context.auth_enabled and authenticate:
        return "mongodb://%s:%s@%s" % (db_user, db_pwd, str(h))
    else:
        return "mongodb://%s" % (str(h),)


def _mongo_client(host, port, authenticate=True, directConnection=False,
                  **kwargs):
    """Create a new client over SSL/TLS if necessary."""
    host = host or client_context.host
    port = port or client_context.port
    client_options = client_context.default_client_options.copy()
    if client_context.replica_set_name and not directConnection:
        client_options['replicaSet'] = client_context.replica_set_name
    client_options.update(kwargs)

    client = MongoClient(_connection_string(host, authenticate), port,
                         **client_options)

    return client


def single_client_noauth(h=None, p=None, **kwargs):
    """Make a direct connection. Don't authenticate."""
    return _mongo_client(h, p, authenticate=False,
                         directConnection=True, **kwargs)


def single_client(h=None, p=None, **kwargs):
    """Make a direct connection, and authenticate if necessary."""
    return _mongo_client(h, p, directConnection=True, **kwargs)


def rs_client_noauth(h=None, p=None, **kwargs):
    """Connect to the replica set. Don't authenticate."""
    return _mongo_client(h, p, authenticate=False, **kwargs)


def rs_client(h=None, p=None, **kwargs):
    """Connect to the replica set and authenticate if necessary."""
    return _mongo_client(h, p, **kwargs)


def rs_or_single_client_noauth(h=None, p=None, **kwargs):
    """Connect to the replica set if there is one, otherwise the standalone.

    Like rs_or_single_client, but does not authenticate.
    """
    return _mongo_client(h, p, authenticate=False, **kwargs)


def rs_or_single_client(h=None, p=None, **kwargs):
    """Connect to the replica set if there is one, otherwise the standalone.

    Authenticates if necessary.
    """
    return _mongo_client(h, p, **kwargs)


def ensure_all_connected(client):
    """Ensure that the client's connection pool has socket connections to all
    members of a replica set. Raises ConfigurationError when called with a
    non-replica set client.

    Depending on the use-case, the caller may need to clear any event listeners
    that are configured on the client.

    NOTE(review): this loops without a timeout until every host listed in
    the hello reply has answered.
    """
    hello = client.admin.command(HelloCompat.LEGACY_CMD)
    if 'setName' not in hello:
        raise ConfigurationError("cluster is not a replica set")

    target_host_list = set(hello['hosts'])
    connected_host_list = set([hello['me']])
    admindb = client.get_database('admin')

    # Run legacy hello until we have connected to each host at least once.
    while connected_host_list != target_host_list:
        hello = admindb.command(HelloCompat.LEGACY_CMD,
                                read_preference=ReadPreference.SECONDARY)
        connected_host_list.update([hello["me"]])


def one(s):
    """Get one element of a set"""
    return next(iter(s))


def oid_generated_on_process(oid):
    """Makes a determination as to whether the given ObjectId was generated
    by the current process, based on the 5-byte random number in the ObjectId.
    """
    return ObjectId._random() == oid.binary[4:9]


def delay(sec):
    """Return JavaScript function source that sleeps ``sec`` seconds and then
    returns true."""
    return '''function() { sleep(%f * 1000); return true; }''' % sec


def get_command_line(client):
    """Return the server's getCmdLineOpts reply, asserting it succeeded."""
    command_line = client.admin.command('getCmdLineOpts')
    assert command_line['ok'] == 1, "getCmdLineOpts() failed"
    return command_line


def camel_to_snake(camel):
    """Convert CamelCase/camelCase to snake_case, e.g. maxTimeMS ->
    max_time_ms."""
    snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel)
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower()


def camel_to_upper_camel(camel):
    """Upper-case the first character, e.g. insertOne -> InsertOne."""
    return camel[0].upper() + camel[1:]


def camel_to_snake_args(arguments):
    """Rename every key of ``arguments`` to snake_case in place; return it."""
    for arg_name in list(arguments):
        c2s = camel_to_snake(arg_name)
        arguments[c2s] = arguments.pop(arg_name)
    return arguments


def snake_to_camel(snake):
    """Convert snake_case to lowerCamelCase."""
    return re.sub(r'_([a-z])', lambda m: m.group(1).upper(), snake)


def parse_collection_options(opts):
    """Convert spec readPreference/writeConcern/readConcern options in
    ``opts`` (in place) into PyMongo keyword arguments; return ``opts``."""
    if 'readPreference' in opts:
        opts['read_preference'] = parse_read_preference(
            opts.pop('readPreference'))

    if 'writeConcern' in opts:
        opts['write_concern'] = WriteConcern(
            **dict(opts.pop('writeConcern')))

    if 'readConcern' in opts:
        opts['read_concern'] = ReadConcern(
            **dict(opts.pop('readConcern')))
    return opts


def server_started_with_option(client, cmdline_opt, config_opt):
    """Check if the server was started with a particular option.

    :Parameters:
      - `cmdline_opt`: The command line option (i.e. --nojournal)
      - `config_opt`: The config file option (i.e. nojournal)
    """
    command_line = get_command_line(client)
    if 'parsed' in command_line:
        parsed = command_line['parsed']
        if config_opt in parsed:
            return parsed[config_opt]
    argv = command_line['argv']
    return cmdline_opt in argv


def server_started_with_auth(client):
    """Return whether the server appears to require authentication.

    An authorization failure from getCmdLineOpts itself counts as "yes".
    """
    try:
        command_line = get_command_line(client)
    except OperationFailure as e:
        msg = e.details.get('errmsg', '')
        if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
            # Unauthorized.
            return True
        raise

    # MongoDB >= 2.0
    if 'parsed' in command_line:
        parsed = command_line['parsed']
        # MongoDB >= 2.6
        if 'security' in parsed:
            security = parsed['security']
            # >= rc3
            if 'authorization' in security:
                return security['authorization'] == 'enabled'
            # < rc3
            return security.get('auth', False) or bool(security.get('keyFile'))
        return parsed.get('auth', False) or bool(parsed.get('keyFile'))
    # Legacy
    argv = command_line['argv']
    return '--auth' in argv or '--keyFile' in argv


def server_started_with_nojournal(client):
    """Return whether the server was started with journaling disabled."""
    command_line = get_command_line(client)

    # MongoDB 2.6.
    if 'parsed' in command_line:
        parsed = command_line['parsed']
        if 'storage' in parsed:
            storage = parsed['storage']
            if 'journal' in storage:
                return not storage['journal']['enabled']

    return server_started_with_option(client, '--nojournal', 'nojournal')


def drop_collections(db):
    """Drop all non-system collections in ``db``."""
    # The negative lookahead excludes names starting with "system.".
    for coll in db.list_collection_names(
            filter={"name": {"$regex": r"^(?!system\.)"}}):
        db.drop_collection(coll)


def remove_all_users(db):
    """Drop every user defined on ``db``."""
    db.command("dropAllUsersFromDatabase", 1,
               writeConcern={"w": client_context.w})


def joinall(threads):
    """Join threads with a 5-minute timeout, assert joins succeeded"""
    for t in threads:
        t.join(300)
        assert not t.is_alive(), "Thread %s hung" % t


def connected(client):
    """Convenience to wait for a newly-constructed client to connect."""
    with warnings.catch_warnings():
        # Ignore warning that "ping" is always routed to primary even
        # if client's read preference isn't PRIMARY.
        warnings.simplefilter("ignore", UserWarning)
        client.admin.command('ping')  # Force connection.

    return client


def wait_until(predicate, success_description, timeout=10):
    """Wait up to 10 seconds (by default) for predicate to be true.

    E.g.:

        wait_until(lambda: client.primary == ('a', 1),
                   'connect to the primary')

    If the lambda-expression isn't true after 10 seconds, we raise
    AssertionError("Didn't ever connect to the primary").

    Returns the predicate's first true value.
    """
    start = time.time()
    # Poll roughly 100 times over the timeout, but at least every 100ms.
    interval = min(float(timeout)/100, 0.1)
    while True:
        retval = predicate()
        if retval:
            return retval

        if time.time() - start > timeout:
            raise AssertionError("Didn't ever %s" % success_description)

        time.sleep(interval)


def repl_set_step_down(client, **kwargs):
    """Run replSetStepDown, first unfreezing a secondary with replSetFreeze."""
    cmd = SON([('replSetStepDown', 1)])
    cmd.update(kwargs)

    # Unfreeze a secondary to ensure a speedy election.
    client.admin.command(
        'replSetFreeze', 0, read_preference=ReadPreference.SECONDARY)
    client.admin.command(cmd)


def is_mongos(client):
    """Return whether ``client`` is connected to a mongos."""
    res = client.admin.command(HelloCompat.LEGACY_CMD)
    return res.get('msg', '') == 'isdbgrid'


def assertRaisesExactly(cls, fn, *args, **kwargs):
    """
    Unlike the standard assertRaises, this checks that a function raises a
    specific class of exception, and not a subclass. E.g., check that
    MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect.
    """
    try:
        fn(*args, **kwargs)
    except Exception as e:
        assert e.__class__ == cls, "got %s, expected %s" % (
            e.__class__.__name__, cls.__name__)
    else:
        raise AssertionError("%s not raised" % cls)


@contextlib.contextmanager
def _ignore_deprecations():
    """Suppress DeprecationWarning for the duration of the block."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        yield


def ignore_deprecations(wrapped=None):
    """A context manager or a decorator."""
    if wrapped:
        @functools.wraps(wrapped)
        def wrapper(*args, **kwargs):
            with _ignore_deprecations():
                return wrapped(*args, **kwargs)

        return wrapper

    else:
        return _ignore_deprecations()


class DeprecationFilter(object):
    """Apply ``action`` to DeprecationWarning from construction until
    stop() is called."""

    def __init__(self, action="ignore"):
        """Start filtering deprecations."""
        self.warn_context = warnings.catch_warnings()
        self.warn_context.__enter__()
        warnings.simplefilter(action, DeprecationWarning)

    def stop(self):
        """Stop filtering deprecations."""
        self.warn_context.__exit__()
        self.warn_context = None


def get_pool(client):
    """Get the standalone, primary, or mongos pool."""
    topology = client._get_topology()
    server = topology.select_server(writable_server_selector)
    return server.pool


def get_pools(client):
    """Get all pools."""
    return [
        server.pool for server in
        client._get_topology().select_servers(any_server_selector)]


# Constants for run_threads and lazy_client_trial.
NTRIALS = 5
NTHREADS = 10


def run_threads(collection, target):
    """Run a target function in many threads.

    target is a function taking a Collection and an integer.
    """
    threads = []
    for i in range(NTHREADS):
        bound_target = partial(target, collection, i)
        threads.append(threading.Thread(target=bound_target))

    for t in threads:
        t.start()

    for t in threads:
        t.join(60)
        assert not t.is_alive()


@contextlib.contextmanager
def frequent_thread_switches():
    """Make concurrency bugs more likely to manifest."""
    interval = None
    if not sys.platform.startswith('java'):
        if hasattr(sys, 'getswitchinterval'):
            interval = sys.getswitchinterval()
            sys.setswitchinterval(1e-6)
        else:
            # Python 2 API.
            interval = sys.getcheckinterval()
            sys.setcheckinterval(1)

    try:
        yield
    finally:
        if not sys.platform.startswith('java'):
            if hasattr(sys, 'setswitchinterval'):
                sys.setswitchinterval(interval)
            else:
                sys.setcheckinterval(interval)


def lazy_client_trial(reset, target, test, get_client):
    """Test concurrent operations on a lazily-connecting client.

    `reset` takes a collection and resets it for the next trial.

    `target` takes a lazily-connecting collection and an index from
    0 to NTHREADS, and performs some operation, e.g. an insert.

    `test` takes the lazily-connecting collection and asserts a
    post-condition to prove `target` succeeded.
    """
    collection = client_context.client.pymongo_test.test

    with frequent_thread_switches():
        for _ in range(NTRIALS):
            reset(collection)
            lazy_client = get_client()
            lazy_collection = lazy_client.pymongo_test.test
            run_threads(lazy_collection, target)
            test(lazy_collection)


def gevent_monkey_patched():
    """Check if gevent's monkey patching is active."""
    # In Python 3.6 importing gevent.socket raises an ImportWarning.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ImportWarning)
        try:
            import socket
            import gevent.socket
            return socket.socket is gevent.socket.socket
        except ImportError:
            return False


def eventlet_monkey_patched():
    """Check if eventlet's monkey patching is active."""
    try:
        import threading
        import eventlet
        return (threading.current_thread.__module__ ==
                'eventlet.green.threading')
    except ImportError:
        return False


def is_greenthread_patched():
    """Return whether gevent or eventlet monkey patching is active."""
    return gevent_monkey_patched() or eventlet_monkey_patched()


def cdecimal_patched():
    """Check if Python 2.7 cdecimal patching is active."""
    try:
        import decimal
        import cdecimal
        return decimal is cdecimal
    except ImportError:
        return False


def _set_stop_repl_producer(client, mode):
    """Set the 'stopReplProducer' fail point to ``mode`` on every secondary
    of ``client``, over a direct connection that is closed afterwards."""
    for host, port in client.secondaries:
        secondary = single_client(host, port)
        try:
            secondary.admin.command('configureFailPoint', 'stopReplProducer',
                                    mode=mode)
        finally:
            # Previously these per-secondary clients were never closed.
            secondary.close()


def disable_replication(client):
    """Disable replication on all secondaries, requires MongoDB 3.2."""
    _set_stop_repl_producer(client, 'alwaysOn')


def enable_replication(client):
    """Enable replication on all secondaries, requires MongoDB 3.2."""
    _set_stop_repl_producer(client, 'off')


class ExceptionCatchingThread(threading.Thread):
    """A thread that stores any exception encountered from run()."""

    def __init__(self, *args, **kwargs):
        self.exc = None
        super(ExceptionCatchingThread, self).__init__(*args, **kwargs)

    def run(self):
        try:
            super(ExceptionCatchingThread, self).run()
        except BaseException as exc:
            # Keep the exception for the caller, then re-raise so the
            # thread still reports it.
            self.exc = exc
            raise


def parse_read_preference(pref):
    """Build a PyMongo read preference from a spec readPreference document."""
    # Make first letter lowercase to match read_pref's modes.
    mode_string = pref.get('mode', 'primary')
    mode_string = mode_string[:1].lower() + mode_string[1:]
    mode = read_preferences.read_pref_mode_from_name(mode_string)
    max_staleness = pref.get('maxStalenessSeconds', -1)
    tag_sets = pref.get('tag_sets')
    return read_preferences.make_read_preference(
        mode, tag_sets=tag_sets, max_staleness=max_staleness)


def server_name_to_type(name):
    """Convert a ServerType name to the corresponding value. For SDAM tests."""
    # Special case, some tests in the spec include the PossiblePrimary
    # type, but only single-threaded drivers need that type. We call
    # possible primaries Unknown.
    if name == 'PossiblePrimary':
        return SERVER_TYPE.Unknown
    return getattr(SERVER_TYPE, name)


def cat_files(dest, *sources):
    """Cat multiple files into dest."""
    with open(dest, 'wb') as fdst:
        for src in sources:
            with open(src, 'rb') as fsrc:
                shutil.copyfileobj(fsrc, fdst)


@contextlib.contextmanager
def assertion_context(msg):
    """A context manager that adds info to an assertion failure."""
    try:
        yield
    except AssertionError as exc:
        msg = '%s (%s)' % (exc, msg)
        py3compat.reraise(type(exc), msg, sys.exc_info()[2])


def _format_hint(hint):
    """Return an index-name hint unchanged; convert an index-spec document
    into a list of (key, value) pairs."""
    if not isinstance(hint, string_type):
        hint = list(iteritems(hint))
    return hint


def parse_spec_options(opts):
    """Convert spec-test options (in place) into PyMongo keyword arguments
    and return them as a plain dict."""
    parse_collection_options(opts)

    if 'maxTimeMS' in opts:
        opts['max_time_ms'] = opts.pop('maxTimeMS')

    if 'maxCommitTimeMS' in opts:
        opts['max_commit_time_ms'] = opts.pop('maxCommitTimeMS')

    if 'hint' in opts:
        opts['hint'] = _format_hint(opts.pop('hint'))

    # Properly format 'hint' arguments for the Bulk API tests.
    if 'requests' in opts:
        reqs = opts.pop('requests')
        for req in reqs:
            if 'name' in req:
                # CRUD v2 format
                args = req.pop('arguments', {})
                if 'hint' in args:
                    args['hint'] = _format_hint(args.pop('hint'))
                req['arguments'] = args
            else:
                # Unified test format
                _, model_spec = next(iteritems(req))
                if 'hint' in model_spec:
                    model_spec['hint'] = _format_hint(model_spec.pop('hint'))
        opts['requests'] = reqs

    return dict(opts)


def prepare_spec_arguments(spec, arguments, opname, entity_map,
                           with_txn_callback):
    """Rewrite spec-test ``arguments`` (in place) into the keyword arguments
    the PyMongo method ``opname`` expects.

    :Parameters:
      - `spec`: the operation spec; ``spec['command_name']`` is read for
        'command'/'run_admin_command' operations.
      - `arguments`: the operation's arguments, modified in place.
      - `opname`: the snake_case PyMongo method name.
      - `entity_map`: mapping used to resolve a 'session' argument.
      - `with_txn_callback`: called with a deep copy of the callback
        operations when building a 'with_transaction' callback.
    """
    for arg_name in list(arguments):
        c2s = camel_to_snake(arg_name)
        # PyMongo accepts sort as list of tuples.
        if arg_name == "sort":
            sort_dict = arguments[arg_name]
            arguments[arg_name] = list(iteritems(sort_dict))
        # PyMongo names this argument "key" instead of "fieldName".
        if arg_name == "fieldName":
            arguments["key"] = arguments.pop(arg_name)
        # Aggregate uses "batchSize", while find uses batch_size.
        elif ((arg_name == "batchSize" or arg_name == "allowDiskUse") and
              opname == "aggregate"):
            continue
        # Convert the spec's string into the matching ReturnDocument value.
        elif arg_name == "returnDocument":
            arguments[c2s] = getattr(ReturnDocument,
                                     arguments.pop(arg_name).upper())
        elif c2s == "requests":
            # Parse each request into a bulk write model.
            requests = []
            for request in arguments["requests"]:
                if 'name' in request:
                    # CRUD v2 format
                    bulk_model = camel_to_upper_camel(request["name"])
                    bulk_class = getattr(operations, bulk_model)
                    bulk_arguments = camel_to_snake_args(request["arguments"])
                else:
                    # Unified test format. Use a distinct local so the
                    # ``spec`` parameter is not clobbered.
                    bulk_model, model_spec = next(iteritems(request))
                    bulk_class = getattr(operations,
                                         camel_to_upper_camel(bulk_model))
                    bulk_arguments = camel_to_snake_args(model_spec)
                requests.append(bulk_class(**dict(bulk_arguments)))
            arguments["requests"] = requests
        elif arg_name == "session":
            arguments['session'] = entity_map[arguments['session']]
        elif (opname in ('command', 'run_admin_command') and
              arg_name == 'command'):
            # Ensure the first key is the command name.
            ordered_command = SON([(spec['command_name'], 1)])
            ordered_command.update(arguments['command'])
            arguments['command'] = ordered_command
        elif opname == 'open_download_stream' and arg_name == 'id':
            arguments['file_id'] = arguments.pop(arg_name)
        elif opname != 'find' and c2s == 'max_time_ms':
            # find is the only method that accepts snake_case max_time_ms.
            # All other methods take kwargs which must use the server's
            # camelCase maxTimeMS. See PYTHON-1855.
            # Pop the actual key: it may be spelled 'maxTimeMS' or
            # 'max_time_ms'.
            arguments['maxTimeMS'] = arguments.pop(arg_name)
        elif opname == 'with_transaction' and arg_name == 'callback':
            if 'operations' in arguments[arg_name]:
                # CRUD v2 format
                callback_ops = arguments[arg_name]['operations']
            else:
                # Unified test format
                callback_ops = arguments[arg_name]
            # Deep-copy per invocation so each run gets fresh operations.
            arguments['callback'] = lambda _: with_txn_callback(
                copy.deepcopy(callback_ops))
        elif opname == 'drop_collection' and arg_name == 'collection':
            arguments['name_or_collection'] = arguments.pop(arg_name)
        elif opname == 'create_collection':
            if arg_name == 'collection':
                arguments['name'] = arguments.pop(arg_name)
            # Any other arguments to create_collection are passed through
            # **kwargs.
        elif opname == 'create_index' and arg_name == 'keys':
            arguments['keys'] = list(arguments.pop(arg_name).items())
        elif opname == 'drop_index' and arg_name == 'name':
            arguments['index_or_name'] = arguments.pop(arg_name)
        else:
            arguments[c2s] = arguments.pop(arg_name)