1# Copyright 2020-present MongoDB, Inc. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15"""Unified test format runner. 16 17https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst 18""" 19 20import copy 21import datetime 22import functools 23import os 24import re 25import sys 26import types 27 28from bson import json_util, Code, Decimal128, DBRef, SON, Int64, MaxKey, MinKey 29from bson.binary import Binary 30from bson.objectid import ObjectId 31from bson.py3compat import abc, integer_types, iteritems, text_type, PY3 32from bson.regex import Regex, RE_TYPE 33 34from gridfs import GridFSBucket 35 36from pymongo import ASCENDING, MongoClient 37from pymongo.client_session import ClientSession, TransactionOptions, _TxnState 38from pymongo.change_stream import ChangeStream 39from pymongo.collection import Collection 40from pymongo.database import Database 41from pymongo.errors import ( 42 BulkWriteError, ConnectionFailure, ConfigurationError, InvalidOperation, 43 NotPrimaryError, PyMongoError) 44from pymongo.monitoring import ( 45 CommandFailedEvent, CommandListener, CommandStartedEvent, 46 CommandSucceededEvent, _SENSITIVE_COMMANDS, PoolCreatedEvent, 47 PoolClearedEvent, PoolClosedEvent, ConnectionCreatedEvent, 48 ConnectionReadyEvent, ConnectionClosedEvent, 49 ConnectionCheckOutStartedEvent, ConnectionCheckOutFailedEvent, 50 ConnectionCheckedOutEvent, ConnectionCheckedInEvent) 51from pymongo.read_concern import ReadConcern 52from pymongo.read_preferences import ReadPreference 53from pymongo.results import BulkWriteResult 54from pymongo.server_api import ServerApi 55from pymongo.write_concern import WriteConcern 56 57from test import client_context, unittest, IntegrationTest 58from test.utils import ( 59 camel_to_snake, get_pool, rs_or_single_client, single_client, 60 snake_to_camel, CMAPListener) 61 62from test.version import Version 63from test.utils import ( 64 camel_to_snake_args, parse_collection_options, parse_spec_options, 65 prepare_spec_arguments) 66 67 68JSON_OPTS = json_util.JSONOptions(tz_aware=False) 69 70 71def with_metaclass(meta, *bases): 72 """Create a base class with a metaclass. 73 74 Vendored from six: https://github.com/benjaminp/six/blob/master/six.py 75 """ 76 # This requires a bit of explanation: the basic idea is to make a dummy 77 # metaclass for one level of class instantiation that replaces itself with 78 # the actual metaclass. 79 class metaclass(type): 80 81 def __new__(cls, name, this_bases, d): 82 if sys.version_info[:2] >= (3, 7): 83 # This version introduced PEP 560 that requires a bit 84 # of extra care (we mimic what is done by __build_class__). 85 resolved_bases = types.resolve_bases(bases) 86 if resolved_bases is not bases: 87 d['__orig_bases__'] = bases 88 else: 89 resolved_bases = bases 90 return meta(name, resolved_bases, d) 91 92 @classmethod 93 def __prepare__(cls, name, this_bases): 94 return meta.__prepare__(name, bases) 95 return type.__new__(metaclass, 'temporary_class', (), {}) 96 97 98def is_run_on_requirement_satisfied(requirement): 99 topology_satisfied = True 100 req_topologies = requirement.get('topologies') 101 if req_topologies: 102 topology_satisfied = client_context.is_topology_type( 103 req_topologies) 104 105 server_version = Version(*client_context.version[:3]) 106 107 min_version_satisfied = True 108 req_min_server_version = requirement.get('minServerVersion') 109 if req_min_server_version: 110 min_version_satisfied = Version.from_string( 111 req_min_server_version) <= server_version 112 113 max_version_satisfied = True 114 req_max_server_version = requirement.get('maxServerVersion') 115 if req_max_server_version: 116 max_version_satisfied = Version.from_string( 117 req_max_server_version) >= server_version 118 119 params_satisfied = True 120 params = requirement.get('serverParameters') 121 if params: 122 for param, val in params.items(): 123 if param not in client_context.server_parameters: 124 params_satisfied = False 125 elif client_context.server_parameters[param] != val: 126 params_satisfied = False 127 128 auth_satisfied = True 129 req_auth = requirement.get('auth') 130 if req_auth is not None: 131 if req_auth: 132 auth_satisfied = client_context.auth_enabled 133 else: 134 auth_satisfied = not client_context.auth_enabled 135 136 return (topology_satisfied and min_version_satisfied and 137 max_version_satisfied and params_satisfied and auth_satisfied) 138 139 140def parse_collection_or_database_options(options): 141 return parse_collection_options(options) 142 143 144def parse_bulk_write_result(result): 145 upserted_ids = {str(int_idx): result.upserted_ids[int_idx] 146 for int_idx in result.upserted_ids} 147 return { 148 'deletedCount': result.deleted_count, 149 'insertedCount': result.inserted_count, 150 'matchedCount': result.matched_count, 151 'modifiedCount': result.modified_count, 152 'upsertedCount': result.upserted_count, 153 'upsertedIds': upserted_ids} 154 155 156def parse_bulk_write_error_result(error): 157 write_result = BulkWriteResult(error.details, True) 158 return parse_bulk_write_result(write_result) 159 160 161class NonLazyCursor(object): 162 """A find cursor proxy that creates the remote cursor when initialized.""" 163 def __init__(self, find_cursor): 164 self.find_cursor = find_cursor 165 # Create the server side cursor. 166 self.first_result = next(find_cursor, None) 167 168 def __iter__(self): 169 return self 170 171 def __next__(self): 172 if self.first_result is not None: 173 first = self.first_result 174 self.first_result = None 175 return first 176 return next(self.find_cursor) 177 178 next = __next__ 179 180 def close(self): 181 self.find_cursor.close() 182 183 184class EventListenerUtil(CMAPListener, CommandListener): 185 def __init__(self, observe_events, ignore_commands, 186 observe_sensitive_commands): 187 self._event_types = set(name.lower() for name in observe_events) 188 if observe_sensitive_commands: 189 self._ignore_commands = set(ignore_commands) 190 else: 191 self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands) 192 self._ignore_commands.add('configurefailpoint') 193 super(EventListenerUtil, self).__init__() 194 195 def get_events(self, event_type): 196 if event_type == 'command': 197 return [e for e in self.events if 'Command' in type(e).__name__] 198 return [e for e in self.events if 'Command' not in type(e).__name__] 199 200 def add_event(self, event): 201 if type(event).__name__.lower() in self._event_types: 202 super(EventListenerUtil, self).add_event(event) 203 204 def _command_event(self, event): 205 if event.command_name.lower() not in self._ignore_commands: 206 self.add_event(event) 207 208 def started(self, event): 209 self._command_event(event) 210 211 def succeeded(self, event): 212 self._command_event(event) 213 214 def failed(self, event): 215 self._command_event(event) 216 217 218class EntityMapUtil(object): 219 """Utility class that implements an entity map as per the unified 220 test format specification.""" 221 def __init__(self, test_class): 222 self._entities = {} 223 self._listeners = {} 224 self._session_lsids = {} 225 self.test = test_class 226 227 def __getitem__(self, item): 228 try: 229 return self._entities[item] 230 except KeyError: 231 self.test.fail('Could not find entity named %s in map' % ( 232 item,)) 233 234 def __setitem__(self, key, value): 235 if not isinstance(key, text_type): 236 self.test.fail( 237 'Expected entity name of type str, got %s' % (type(key))) 238 239 if key in self._entities: 240 self.test.fail('Entity named %s already in map' % (key,)) 241 242 self._entities[key] = value 243 244 def _create_entity(self, entity_spec): 245 if len(entity_spec) != 1: 246 self.test.fail( 247 "Entity spec %s did not contain exactly one top-level key" % ( 248 entity_spec,)) 249 250 entity_type, spec = next(iteritems(entity_spec)) 251 if entity_type == 'client': 252 kwargs = {} 253 observe_events = spec.get('observeEvents', []) 254 ignore_commands = spec.get('ignoreCommandMonitoringEvents', []) 255 observe_sensitive_commands = spec.get( 256 'observeSensitiveCommands', False) 257 # TODO: SUPPORT storeEventsAsEntities 258 if len(observe_events) or len(ignore_commands): 259 ignore_commands = [cmd.lower() for cmd in ignore_commands] 260 listener = EventListenerUtil( 261 observe_events, ignore_commands, observe_sensitive_commands) 262 self._listeners[spec['id']] = listener 263 kwargs['event_listeners'] = [listener] 264 if spec.get('useMultipleMongoses'): 265 if client_context.load_balancer: 266 kwargs['h'] = client_context.MULTI_MONGOS_LB_URI 267 elif client_context.is_mongos: 268 kwargs['h'] = client_context.mongos_seeds() 269 kwargs.update(spec.get('uriOptions', {})) 270 server_api = spec.get('serverApi') 271 if server_api: 272 kwargs['server_api'] = ServerApi( 273 server_api['version'], strict=server_api.get('strict'), 274 deprecation_errors=server_api.get('deprecationErrors')) 275 client = rs_or_single_client(**kwargs) 276 self[spec['id']] = client 277 self.test.addCleanup(client.close) 278 return 279 elif entity_type == 'database': 280 client = self[spec['client']] 281 if not isinstance(client, MongoClient): 282 self.test.fail( 283 'Expected entity %s to be of type MongoClient, got %s' % ( 284 spec['client'], type(client))) 285 options = parse_collection_or_database_options( 286 spec.get('databaseOptions', {})) 287 self[spec['id']] = client.get_database( 288 spec['databaseName'], **options) 289 return 290 elif entity_type == 'collection': 291 database = self[spec['database']] 292 if not isinstance(database, Database): 293 self.test.fail( 294 'Expected entity %s to be of type Database, got %s' % ( 295 spec['database'], type(database))) 296 options = parse_collection_or_database_options( 297 spec.get('collectionOptions', {})) 298 self[spec['id']] = database.get_collection( 299 spec['collectionName'], **options) 300 return 301 elif entity_type == 'session': 302 client = self[spec['client']] 303 if not isinstance(client, MongoClient): 304 self.test.fail( 305 'Expected entity %s to be of type MongoClient, got %s' % ( 306 spec['client'], type(client))) 307 opts = camel_to_snake_args(spec.get('sessionOptions', {})) 308 if 'default_transaction_options' in opts: 309 txn_opts = parse_spec_options( 310 opts['default_transaction_options']) 311 txn_opts = TransactionOptions(**txn_opts) 312 opts = copy.deepcopy(opts) 313 opts['default_transaction_options'] = txn_opts 314 session = client.start_session(**dict(opts)) 315 self[spec['id']] = session 316 self._session_lsids[spec['id']] = copy.deepcopy(session.session_id) 317 self.test.addCleanup(session.end_session) 318 return 319 elif entity_type == 'bucket': 320 # TODO: implement the 'bucket' entity type 321 self.test.skipTest( 322 'GridFS is not currently supported (PYTHON-2459)') 323 self.test.fail( 324 'Unable to create entity of unknown type %s' % (entity_type,)) 325 326 def create_entities_from_spec(self, entity_spec): 327 for spec in entity_spec: 328 self._create_entity(spec) 329 330 def get_listener_for_client(self, client_name): 331 client = self[client_name] 332 if not isinstance(client, MongoClient): 333 self.test.fail( 334 'Expected entity %s to be of type MongoClient, got %s' % ( 335 client_name, type(client))) 336 337 listener = self._listeners.get(client_name) 338 if not listener: 339 self.test.fail( 340 'No listeners configured for client %s' % (client_name,)) 341 342 return listener 343 344 def get_lsid_for_session(self, session_name): 345 session = self[session_name] 346 if not isinstance(session, ClientSession): 347 self.test.fail( 348 'Expected entity %s to be of type ClientSession, got %s' % ( 349 session_name, type(session))) 350 351 try: 352 return session.session_id 353 except InvalidOperation: 354 # session has been closed. 355 return self._session_lsids[session_name] 356 357 358if not PY3: 359 binary_types = (Binary,) 360 long_types = (Int64, long) 361 unicode_type = unicode 362else: 363 binary_types = (Binary, bytes) 364 long_types = (Int64,) 365 unicode_type = str 366 367 368BSON_TYPE_ALIAS_MAP = { 369 # https://docs.mongodb.com/manual/reference/operator/query/type/ 370 # https://pymongo.readthedocs.io/en/stable/api/bson/index.html 371 'double': (float,), 372 'string': (text_type,), 373 'object': (abc.Mapping,), 374 'array': (abc.MutableSequence,), 375 'binData': binary_types, 376 'undefined': (type(None),), 377 'objectId': (ObjectId,), 378 'bool': (bool,), 379 'date': (datetime.datetime,), 380 'null': (type(None),), 381 'regex': (Regex, RE_TYPE), 382 'dbPointer': (DBRef,), 383 'javascript': (unicode_type, Code), 384 'symbol': (unicode_type,), 385 'javascriptWithScope': (unicode_type, Code), 386 'int': (int,), 387 'long': long_types, 388 'decimal': (Decimal128,), 389 'maxKey': (MaxKey,), 390 'minKey': (MinKey,), 391} 392 393 394class MatchEvaluatorUtil(object): 395 """Utility class that implements methods for evaluating matches as per 396 the unified test format specification.""" 397 def __init__(self, test_class): 398 self.test = test_class 399 400 def _operation_exists(self, spec, actual, key_to_compare): 401 if spec is True: 402 self.test.assertIn(key_to_compare, actual) 403 elif spec is False: 404 self.test.assertNotIn(key_to_compare, actual) 405 else: 406 self.test.fail( 407 'Expected boolean value for $$exists operator, got %s' % ( 408 spec,)) 409 410 def __type_alias_to_type(self, alias): 411 if alias not in BSON_TYPE_ALIAS_MAP: 412 self.test.fail('Unrecognized BSON type alias %s' % (alias,)) 413 return BSON_TYPE_ALIAS_MAP[alias] 414 415 def _operation_type(self, spec, actual, key_to_compare): 416 if isinstance(spec, abc.MutableSequence): 417 permissible_types = tuple([ 418 t for alias in spec for t in self.__type_alias_to_type(alias)]) 419 else: 420 permissible_types = self.__type_alias_to_type(spec) 421 self.test.assertIsInstance( 422 actual[key_to_compare], permissible_types) 423 424 def _operation_matchesEntity(self, spec, actual, key_to_compare): 425 expected_entity = self.test.entity_map[spec] 426 self.test.assertIsInstance(expected_entity, abc.Mapping) 427 self.test.assertEqual(expected_entity, actual[key_to_compare]) 428 429 def _operation_matchesHexBytes(self, spec, actual, key_to_compare): 430 raise NotImplementedError 431 432 def _operation_unsetOrMatches(self, spec, actual, key_to_compare): 433 if key_to_compare is None and not actual: 434 # top-level document can be None when unset 435 return 436 437 if key_to_compare not in actual: 438 # we add a dummy value for the compared key to pass map size check 439 actual[key_to_compare] = 'dummyValue' 440 return 441 self.match_result(spec, actual[key_to_compare], in_recursive_call=True) 442 443 def _operation_sessionLsid(self, spec, actual, key_to_compare): 444 expected_lsid = self.test.entity_map.get_lsid_for_session(spec) 445 self.test.assertEqual(expected_lsid, actual[key_to_compare]) 446 447 def _evaluate_special_operation(self, opname, spec, actual, 448 key_to_compare): 449 method_name = '_operation_%s' % (opname.strip('$'),) 450 try: 451 method = getattr(self, method_name) 452 except AttributeError: 453 self.test.fail( 454 'Unsupported special matching operator %s' % (opname,)) 455 else: 456 method(spec, actual, key_to_compare) 457 458 def _evaluate_if_special_operation(self, expectation, actual, 459 key_to_compare=None): 460 """Returns True if a special operation is evaluated, False 461 otherwise. If the ``expectation`` map contains a single key, 462 value pair we check it for a special operation. 463 If given, ``key_to_compare`` is assumed to be the key in 464 ``expectation`` whose corresponding value needs to be 465 evaluated for a possible special operation. ``key_to_compare`` 466 is ignored when ``expectation`` has only one key.""" 467 if not isinstance(expectation, abc.Mapping): 468 return False 469 470 is_special_op, opname, spec = False, False, False 471 472 if key_to_compare is not None: 473 if key_to_compare.startswith('$$'): 474 is_special_op = True 475 opname = key_to_compare 476 spec = expectation[key_to_compare] 477 key_to_compare = None 478 else: 479 nested = expectation[key_to_compare] 480 if isinstance(nested, abc.Mapping) and len(nested) == 1: 481 opname, spec = next(iteritems(nested)) 482 if opname.startswith('$$'): 483 is_special_op = True 484 elif len(expectation) == 1: 485 opname, spec = next(iteritems(expectation)) 486 if opname.startswith('$$'): 487 is_special_op = True 488 key_to_compare = None 489 490 if is_special_op: 491 self._evaluate_special_operation( 492 opname=opname, 493 spec=spec, 494 actual=actual, 495 key_to_compare=key_to_compare) 496 return True 497 498 return False 499 500 def _match_document(self, expectation, actual, is_root): 501 if self._evaluate_if_special_operation(expectation, actual): 502 return 503 504 self.test.assertIsInstance(actual, abc.Mapping) 505 for key, value in iteritems(expectation): 506 if self._evaluate_if_special_operation(expectation, actual, key): 507 continue 508 509 self.test.assertIn(key, actual) 510 self.match_result(value, actual[key], in_recursive_call=True) 511 512 if not is_root: 513 expected_keys = set(expectation.keys()) 514 for key, value in expectation.items(): 515 if value == {'$$exists': False}: 516 expected_keys.remove(key) 517 self.test.assertEqual(expected_keys, set(actual.keys())) 518 519 def match_result(self, expectation, actual, 520 in_recursive_call=False): 521 if isinstance(expectation, abc.Mapping): 522 return self._match_document( 523 expectation, actual, is_root=not in_recursive_call) 524 525 if isinstance(expectation, abc.MutableSequence): 526 self.test.assertIsInstance(actual, abc.MutableSequence) 527 for e, a in zip(expectation, actual): 528 if isinstance(e, abc.Mapping): 529 self._match_document( 530 e, a, is_root=not in_recursive_call) 531 else: 532 self.match_result(e, a, in_recursive_call=True) 533 return 534 535 # account for flexible numerics in element-wise comparison and unicode 536 # vs str on Python 2. 537 if not (isinstance(expectation, integer_types) or 538 isinstance(expectation, float) or 539 (isinstance(expectation, unicode_type) and not PY3)): 540 self.test.assertIsInstance(actual, type(expectation)) 541 self.test.assertEqual(expectation, actual) 542 543 def assertHasServiceId(self, spec, actual): 544 if 'hasServiceId' in spec: 545 if spec.get('hasServiceId'): 546 self.test.assertIsNotNone(actual.service_id) 547 self.test.assertIsInstance(actual.service_id, ObjectId) 548 else: 549 self.test.assertIsNone(actual.service_id) 550 551 def match_event(self, event_type, expectation, actual): 552 name, spec = next(iteritems(expectation)) 553 554 # every command event has the commandName field 555 if event_type == 'command': 556 command_name = spec.get('commandName') 557 if command_name: 558 self.test.assertEqual(command_name, actual.command_name) 559 560 if name == 'commandStartedEvent': 561 self.test.assertIsInstance(actual, CommandStartedEvent) 562 command = spec.get('command') 563 database_name = spec.get('databaseName') 564 if command: 565 if actual.command_name == 'update': 566 # TODO: remove this once PYTHON-1744 is done. 567 # Add upsert and multi fields back into expectations. 568 for update in command.get('updates', []): 569 update.setdefault('upsert', False) 570 update.setdefault('multi', False) 571 self.match_result(command, actual.command) 572 if database_name: 573 self.test.assertEqual( 574 database_name, actual.database_name) 575 self.assertHasServiceId(spec, actual) 576 elif name == 'commandSucceededEvent': 577 self.test.assertIsInstance(actual, CommandSucceededEvent) 578 reply = spec.get('reply') 579 if reply: 580 self.match_result(reply, actual.reply) 581 self.assertHasServiceId(spec, actual) 582 elif name == 'commandFailedEvent': 583 self.test.assertIsInstance(actual, CommandFailedEvent) 584 self.assertHasServiceId(spec, actual) 585 elif name == 'poolCreatedEvent': 586 self.test.assertIsInstance(actual, PoolCreatedEvent) 587 elif name == 'poolReadyEvent': 588 # PyMongo 3.X does not support PoolReadyEvent. 589 assert False 590 elif name == 'poolClearedEvent': 591 self.test.assertIsInstance(actual, PoolClearedEvent) 592 self.assertHasServiceId(spec, actual) 593 elif name == 'poolClosedEvent': 594 self.test.assertIsInstance(actual, PoolClosedEvent) 595 elif name == 'connectionCreatedEvent': 596 self.test.assertIsInstance(actual, ConnectionCreatedEvent) 597 elif name == 'connectionReadyEvent': 598 self.test.assertIsInstance(actual, ConnectionReadyEvent) 599 elif name == 'connectionClosedEvent': 600 self.test.assertIsInstance(actual, ConnectionClosedEvent) 601 if 'reason' in spec: 602 self.test.assertEqual(actual.reason, spec['reason']) 603 elif name == 'connectionCheckOutStartedEvent': 604 self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent) 605 elif name == 'connectionCheckOutFailedEvent': 606 self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent) 607 if 'reason' in spec: 608 self.test.assertEqual(actual.reason, spec['reason']) 609 elif name == 'connectionCheckedOutEvent': 610 self.test.assertIsInstance(actual, ConnectionCheckedOutEvent) 611 elif name == 'connectionCheckedInEvent': 612 self.test.assertIsInstance(actual, ConnectionCheckedInEvent) 613 else: 614 self.test.fail( 615 'Unsupported event type %s' % (name,)) 616 617 618def coerce_result(opname, result): 619 """Convert a pymongo result into the spec's result format.""" 620 if hasattr(result, 'acknowledged') and not result.acknowledged: 621 return {'acknowledged': False} 622 if opname == 'bulkWrite': 623 return parse_bulk_write_result(result) 624 if opname == 'insertOne': 625 return {'insertedId': result.inserted_id} 626 if opname == 'insertMany': 627 return {idx: _id for idx, _id in enumerate(result.inserted_ids)} 628 if opname in ('deleteOne', 'deleteMany'): 629 return {'deletedCount': result.deleted_count} 630 if opname in ('updateOne', 'updateMany', 'replaceOne'): 631 return { 632 'matchedCount': result.matched_count, 633 'modifiedCount': result.modified_count, 634 'upsertedCount': 0 if result.upserted_id is None else 1, 635 } 636 return result 637 638 639class UnifiedSpecTestMixinV1(IntegrationTest): 640 """Mixin class to run test cases from test specification files. 641 642 Assumes that tests conform to the `unified test format 643 <https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst>`_. 644 645 Specification of the test suite being currently run is available as 646 a class attribute ``TEST_SPEC``. 647 """ 648 SCHEMA_VERSION = Version.from_string('1.5') 649 RUN_ON_LOAD_BALANCER = True 650 651 @staticmethod 652 def should_run_on(run_on_spec): 653 if not run_on_spec: 654 # Always run these tests. 655 return True 656 657 for req in run_on_spec: 658 if is_run_on_requirement_satisfied(req): 659 return True 660 return False 661 662 def insert_initial_data(self, initial_data): 663 for collection_data in initial_data: 664 coll_name = collection_data['collectionName'] 665 db_name = collection_data['databaseName'] 666 documents = collection_data['documents'] 667 668 coll = self.client.get_database(db_name).get_collection( 669 coll_name, write_concern=WriteConcern(w="majority")) 670 coll.drop() 671 672 if len(documents) > 0: 673 coll.insert_many(documents) 674 else: 675 # ensure collection exists 676 result = coll.insert_one({}) 677 coll.delete_one({'_id': result.inserted_id}) 678 679 @classmethod 680 def setUpClass(cls): 681 # super call creates internal client cls.client 682 super(UnifiedSpecTestMixinV1, cls).setUpClass() 683 684 # process file-level runOnRequirements 685 run_on_spec = cls.TEST_SPEC.get('runOnRequirements', []) 686 if not cls.should_run_on(run_on_spec): 687 raise unittest.SkipTest( 688 '%s runOnRequirements not satisfied' % (cls.__name__,)) 689 690 # add any special-casing for skipping tests here 691 if client_context.storage_engine == 'mmapv1': 692 if 'retryable-writes' in cls.TEST_SPEC['description']: 693 raise unittest.SkipTest( 694 "MMAPv1 does not support retryWrites=True") 695 696 def setUp(self): 697 super(UnifiedSpecTestMixinV1, self).setUp() 698 699 # process schemaVersion 700 # note: we check major schema version during class generation 701 # note: we do this here because we cannot run assertions in setUpClass 702 version = Version.from_string(self.TEST_SPEC['schemaVersion']) 703 self.assertLessEqual( 704 version, self.SCHEMA_VERSION, 705 'expected schema version %s or lower, got %s' % ( 706 self.SCHEMA_VERSION, version)) 707 708 # initialize internals 709 self.match_evaluator = MatchEvaluatorUtil(self) 710 711 def maybe_skip_test(self, spec): 712 # add any special-casing for skipping tests here 713 if client_context.storage_engine == 'mmapv1': 714 if 'Dirty explicit session is discarded' in spec['description']: 715 raise unittest.SkipTest( 716 "MMAPv1 does not support retryWrites=True") 717 elif 'Client side error in command starting transaction' in spec['description']: 718 raise unittest.SkipTest("Implement PYTHON-1894") 719 720 def process_error(self, exception, spec): 721 is_error = spec.get('isError') 722 is_client_error = spec.get('isClientError') 723 error_contains = spec.get('errorContains') 724 error_code = spec.get('errorCode') 725 error_code_name = spec.get('errorCodeName') 726 error_labels_contain = spec.get('errorLabelsContain') 727 error_labels_omit = spec.get('errorLabelsOmit') 728 expect_result = spec.get('expectResult') 729 730 if is_error: 731 # already satisfied because exception was raised 732 pass 733 734 if is_client_error: 735 # Connection errors are considered client errors. 736 if isinstance(exception, ConnectionFailure): 737 self.assertNotIsInstance(exception, NotPrimaryError) 738 elif isinstance(exception, (InvalidOperation, ConfigurationError)): 739 pass 740 else: 741 self.assertNotIsInstance(exception, PyMongoError) 742 743 if error_contains: 744 if isinstance(exception, BulkWriteError): 745 errmsg = str(exception.details).lower() 746 else: 747 errmsg = str(exception).lower() 748 self.assertIn(error_contains.lower(), errmsg) 749 750 if error_code: 751 self.assertEqual( 752 error_code, exception.details.get('code')) 753 754 if error_code_name: 755 self.assertEqual( 756 error_code_name, exception.details.get('codeName')) 757 758 if error_labels_contain: 759 labels = [err_label for err_label in error_labels_contain 760 if exception.has_error_label(err_label)] 761 self.assertEqual(labels, error_labels_contain) 762 763 if error_labels_omit: 764 for err_label in error_labels_omit: 765 if exception.has_error_label(err_label): 766 self.fail("Exception '%s' unexpectedly had label '%s'" % ( 767 exception, err_label)) 768 769 if expect_result: 770 if isinstance(exception, BulkWriteError): 771 result = parse_bulk_write_error_result( 772 exception) 773 self.match_evaluator.match_result(expect_result, result) 774 else: 775 self.fail("expectResult can only be specified with %s " 776 "exceptions" % (BulkWriteError,)) 777 778 def __raise_if_unsupported(self, opname, target, *target_types): 779 if not isinstance(target, target_types): 780 self.fail('Operation %s not supported for entity ' 781 'of type %s' % (opname, type(target))) 782 783 def __entityOperation_createChangeStream(self, target, *args, **kwargs): 784 if client_context.storage_engine == 'mmapv1': 785 self.skipTest("MMAPv1 does not support change streams") 786 self.__raise_if_unsupported( 787 'createChangeStream', target, MongoClient, Database, Collection) 788 stream = target.watch(*args, **kwargs) 789 self.addCleanup(stream.close) 790 return stream 791 792 def _clientOperation_createChangeStream(self, target, *args, **kwargs): 793 return self.__entityOperation_createChangeStream( 794 target, *args, **kwargs) 795 796 def _databaseOperation_createChangeStream(self, target, *args, **kwargs): 797 return self.__entityOperation_createChangeStream( 798 target, *args, **kwargs) 799 800 def _collectionOperation_createChangeStream(self, target, *args, **kwargs): 801 return self.__entityOperation_createChangeStream( 802 target, *args, **kwargs) 803 804 def _databaseOperation_runCommand(self, target, **kwargs): 805 self.__raise_if_unsupported('runCommand', target, Database) 806 # Ensure the first key is the command name. 807 ordered_command = SON([(kwargs.pop('command_name'), 1)]) 808 ordered_command.update(kwargs['command']) 809 kwargs['command'] = ordered_command 810 return target.command(**kwargs) 811 812 def _databaseOperation_listCollections(self, target, *args, **kwargs): 813 if 'batch_size' in kwargs: 814 kwargs['cursor'] = {'batchSize': kwargs.pop('batch_size')} 815 cursor = target.list_collections(*args, **kwargs) 816 return list(cursor) 817 818 def __entityOperation_aggregate(self, target, *args, **kwargs): 819 self.__raise_if_unsupported('aggregate', target, Database, Collection) 820 return list(target.aggregate(*args, **kwargs)) 821 822 def _databaseOperation_aggregate(self, target, *args, **kwargs): 823 return self.__entityOperation_aggregate(target, *args, **kwargs) 824 825 def _collectionOperation_aggregate(self, target, *args, **kwargs): 826 return self.__entityOperation_aggregate(target, *args, **kwargs) 827 828 def _collectionOperation_find(self, target, *args, **kwargs): 829 self.__raise_if_unsupported('find', target, Collection) 830 find_cursor = target.find(*args, **kwargs) 831 return list(find_cursor) 832 833 def _collectionOperation_createFindCursor(self, target, *args, **kwargs): 834 self.__raise_if_unsupported('find', target, Collection) 835 cursor = NonLazyCursor(target.find(*args, **kwargs)) 836 self.addCleanup(cursor.close) 837 return cursor 838 839 def _collectionOperation_listIndexes(self, target, *args, **kwargs): 840 if 'batch_size' in kwargs: 841 self.skipTest('PyMongo does not support batch_size for ' 842 'list_indexes') 843 return target.list_indexes(*args, **kwargs) 844 845 def _sessionOperation_withTransaction(self, target, *args, **kwargs): 846 if client_context.storage_engine == 'mmapv1': 847 self.skipTest('MMAPv1 does not support document-level locking') 848 self.__raise_if_unsupported('withTransaction', target, ClientSession) 849 return target.with_transaction(*args, **kwargs) 850 851 def _sessionOperation_startTransaction(self, target, *args, **kwargs): 852 if client_context.storage_engine == 'mmapv1': 853 self.skipTest('MMAPv1 does not support document-level locking') 854 self.__raise_if_unsupported('startTransaction', target, ClientSession) 855 return target.start_transaction(*args, **kwargs) 856 857 def _changeStreamOperation_iterateUntilDocumentOrError(self, target, 858 *args, **kwargs): 859 self.__raise_if_unsupported( 860 'iterateUntilDocumentOrError', target, ChangeStream) 861 return next(target) 862 863 def _cursor_iterateUntilDocumentOrError(self, target, *args, **kwargs): 864 self.__raise_if_unsupported( 865 'iterateUntilDocumentOrError', target, NonLazyCursor) 866 return next(target) 867 868 def _cursor_close(self, target, *args, **kwargs): 869 self.__raise_if_unsupported('close', target, NonLazyCursor) 870 return target.close() 871 872 def run_entity_operation(self, spec): 873 target = self.entity_map[spec['object']] 874 opname = spec['name'] 875 opargs = spec.get('arguments') 876 expect_error = spec.get('expectError') 877 save_as_entity = spec.get('saveResultAsEntity') 878 expect_result = spec.get('expectResult') 879 ignore = spec.get('ignoreResultAndError') 880 if ignore and (expect_error or save_as_entity or expect_result): 881 raise ValueError( 882 'ignoreResultAndError is incompatible with saveResultAsEntity' 883 ', expectError, and expectResult') 884 if opargs: 885 arguments = parse_spec_options(copy.deepcopy(opargs)) 886 prepare_spec_arguments(spec, arguments, camel_to_snake(opname), 887 self.entity_map, self.run_operations) 888 else: 889 arguments = tuple() 890 891 if isinstance(target, MongoClient): 892 method_name = '_clientOperation_%s' % (opname,) 893 elif isinstance(target, Database): 894 method_name = '_databaseOperation_%s' % (opname,) 895 elif isinstance(target, Collection): 896 method_name = '_collectionOperation_%s' % (opname,) 897 elif isinstance(target, ChangeStream): 898 method_name = '_changeStreamOperation_%s' % (opname,) 899 elif isinstance(target, NonLazyCursor): 900 method_name = '_cursor_%s' % (opname,) 901 elif isinstance(target, ClientSession): 902 method_name = '_sessionOperation_%s' % (opname,) 903 elif isinstance(target, GridFSBucket): 904 raise NotImplementedError 905 else: 906 method_name = 'doesNotExist' 907 908 try: 909 method = getattr(self, method_name) 910 except AttributeError: 911 try: 912 cmd = getattr(target, camel_to_snake(opname)) 913 except AttributeError: 914 self.fail('Unsupported operation %s on entity %s' % ( 915 opname, target)) 916 else: 917 cmd = functools.partial(method, target) 918 919 try: 920 result = cmd(**dict(arguments)) 921 except Exception as exc: 922 if ignore: 923 return 924 if expect_error: 925 return self.process_error(exc, expect_error) 926 raise 927 else: 928 if expect_error: 929 self.fail('Excepted error %s but "%s" succeeded: %s' % ( 930 expect_error, opname, result)) 931 932 if expect_result: 933 actual = coerce_result(opname, result) 934 self.match_evaluator.match_result(expect_result, actual) 935 936 if save_as_entity: 937 self.entity_map[save_as_entity] = result 938 939 def __set_fail_point(self, client, command_args): 940 if not client_context.test_commands_enabled: 941 self.skipTest('Test commands must be enabled') 942 943 cmd_on = SON([('configureFailPoint', 'failCommand')]) 944 cmd_on.update(command_args) 945 client.admin.command(cmd_on) 946 self.addCleanup( 947 client.admin.command, 948 'configureFailPoint', cmd_on['configureFailPoint'], mode='off') 949 950 def _testOperation_failPoint(self, spec): 951 self.__set_fail_point( 952 client=self.entity_map[spec['client']], 953 command_args=spec['failPoint']) 954 955 def _testOperation_targetedFailPoint(self, spec): 956 session = self.entity_map[spec['session']] 957 if not session._pinned_address: 958 self.fail("Cannot use targetedFailPoint operation with unpinned " 959 "session %s" % (spec['session'],)) 960 961 client = single_client('%s:%s' % session._pinned_address) 962 self.__set_fail_point( 963 client=client, command_args=spec['failPoint']) 964 self.addCleanup(client.close) 965 966 def _testOperation_assertSessionTransactionState(self, spec): 967 session = self.entity_map[spec['session']] 968 expected_state = getattr(_TxnState, spec['state'].upper()) 969 self.assertEqual(expected_state, session._transaction.state) 970 971 def _testOperation_assertSessionPinned(self, spec): 972 session = self.entity_map[spec['session']] 973 self.assertIsNotNone(session._transaction.pinned_address) 974 975 def _testOperation_assertSessionUnpinned(self, spec): 976 session = self.entity_map[spec['session']] 977 self.assertIsNone(session._pinned_address) 978 self.assertIsNone(session._transaction.pinned_address) 979 980 def __get_last_two_command_lsids(self, listener): 981 cmd_started_events = [] 982 for event in reversed(listener.events): 983 if isinstance(event, CommandStartedEvent): 984 cmd_started_events.append(event) 985 if len(cmd_started_events) < 2: 986 self.fail('Needed 2 CommandStartedEvents to compare lsids, ' 987 'got %s' % (len(cmd_started_events))) 988 return tuple([e.command['lsid'] for e in cmd_started_events][:2]) 989 990 def _testOperation_assertDifferentLsidOnLastTwoCommands(self, spec): 991 listener = self.entity_map.get_listener_for_client(spec['client']) 992 self.assertNotEqual(*self.__get_last_two_command_lsids(listener)) 993 994 def _testOperation_assertSameLsidOnLastTwoCommands(self, spec): 995 listener = self.entity_map.get_listener_for_client(spec['client']) 996 self.assertEqual(*self.__get_last_two_command_lsids(listener)) 997 998 def _testOperation_assertSessionDirty(self, spec): 999 session = self.entity_map[spec['session']] 1000 self.assertTrue(session._server_session.dirty) 1001 1002 def _testOperation_assertSessionNotDirty(self, spec): 1003 session = self.entity_map[spec['session']] 1004 return self.assertFalse(session._server_session.dirty) 1005 1006 def _testOperation_assertCollectionExists(self, spec): 1007 database_name = spec['databaseName'] 1008 collection_name = spec['collectionName'] 1009 collection_name_list = list( 1010 self.client.get_database(database_name).list_collection_names()) 1011 self.assertIn(collection_name, collection_name_list) 1012 1013 def _testOperation_assertCollectionNotExists(self, spec): 1014 database_name = spec['databaseName'] 1015 collection_name = spec['collectionName'] 1016 collection_name_list = list( 1017 self.client.get_database(database_name).list_collection_names()) 1018 self.assertNotIn(collection_name, collection_name_list) 1019 1020 def _testOperation_assertIndexExists(self, spec): 1021 collection = self.client[spec['databaseName']][spec['collectionName']] 1022 index_names = [idx['name'] for idx in collection.list_indexes()] 1023 self.assertIn(spec['indexName'], index_names) 1024 1025 def _testOperation_assertIndexNotExists(self, spec): 1026 collection = self.client[spec['databaseName']][spec['collectionName']] 1027 for index in collection.list_indexes(): 1028 self.assertNotEqual(spec['indexName'], index['name']) 1029 1030 def _testOperation_assertNumberConnectionsCheckedOut(self, spec): 1031 client = self.entity_map[spec['client']] 1032 pool = get_pool(client) 1033 self.assertEqual(spec['connections'], pool.active_sockets) 1034 1035 def run_special_operation(self, spec): 1036 opname = spec['name'] 1037 method_name = '_testOperation_%s' % (opname,) 1038 try: 1039 method = getattr(self, method_name) 1040 except AttributeError: 1041 self.fail('Unsupported special test operation %s' % (opname,)) 1042 else: 1043 method(spec['arguments']) 1044 1045 def run_operations(self, spec): 1046 for op in spec: 1047 target = op['object'] 1048 if target != 'testRunner': 1049 self.run_entity_operation(op) 1050 else: 1051 self.run_special_operation(op) 1052 1053 def check_events(self, spec): 1054 for event_spec in spec: 1055 client_name = event_spec['client'] 1056 events = event_spec['events'] 1057 # Valid types: 'command', 'cmap' 1058 event_type = event_spec.get('eventType', 'command') 1059 assert event_type in ('command', 'cmap') 1060 1061 listener = self.entity_map.get_listener_for_client(client_name) 1062 actual_events = listener.get_events(event_type) 1063 if len(events) == 0: 1064 self.assertEqual(actual_events, []) 1065 continue 1066 1067 if len(events) > len(actual_events): 1068 self.fail('Expected to see %s events, got %s' % ( 1069 len(events), len(actual_events))) 1070 1071 for idx, expected_event in enumerate(events): 1072 self.match_evaluator.match_event( 1073 event_type, expected_event, actual_events[idx]) 1074 1075 def verify_outcome(self, spec): 1076 for collection_data in spec: 1077 coll_name = collection_data['collectionName'] 1078 db_name = collection_data['databaseName'] 1079 expected_documents = collection_data['documents'] 1080 1081 coll = self.client.get_database(db_name).get_collection( 1082 coll_name, 1083 read_preference=ReadPreference.PRIMARY, 1084 read_concern=ReadConcern(level='local')) 1085 1086 if expected_documents: 1087 sorted_expected_documents = sorted( 1088 expected_documents, key=lambda doc: doc['_id']) 1089 actual_documents = list( 1090 coll.find({}, sort=[('_id', ASCENDING)])) 1091 self.assertListEqual(sorted_expected_documents, 1092 actual_documents) 1093 1094 def run_scenario(self, spec): 1095 # maybe skip test manually 1096 self.maybe_skip_test(spec) 1097 1098 # process test-level runOnRequirements 1099 run_on_spec = spec.get('runOnRequirements', []) 1100 if not self.should_run_on(run_on_spec): 1101 raise unittest.SkipTest('runOnRequirements not satisfied') 1102 1103 # process skipReason 1104 skip_reason = spec.get('skipReason', None) 1105 if skip_reason is not None: 1106 raise unittest.SkipTest('%s' % (skip_reason,)) 1107 1108 # process createEntities 1109 self.entity_map = EntityMapUtil(self) 1110 self.entity_map.create_entities_from_spec( 1111 self.TEST_SPEC.get('createEntities', [])) 1112 1113 # process initialData 1114 self.insert_initial_data(self.TEST_SPEC.get('initialData', [])) 1115 1116 # process operations 1117 self.run_operations(spec['operations']) 1118 1119 # process expectEvents 1120 self.check_events(spec.get('expectEvents', [])) 1121 1122 # process outcome 1123 self.verify_outcome(spec.get('outcome', [])) 1124 1125 1126class UnifiedSpecTestMeta(type): 1127 """Metaclass for generating test classes.""" 1128 def __init__(cls, *args, **kwargs): 1129 super(UnifiedSpecTestMeta, cls).__init__(*args, **kwargs) 1130 1131 def create_test(spec): 1132 def test_case(self): 1133 self.run_scenario(spec) 1134 return test_case 1135 1136 for test_spec in cls.TEST_SPEC['tests']: 1137 description = test_spec['description'] 1138 test_name = 'test_%s' % (description.strip('. '). 1139 replace(' ', '_').replace('.', '_'),) 1140 test_method = create_test(copy.deepcopy(test_spec)) 1141 test_method.__name__ = str(test_name) 1142 1143 for fail_pattern in cls.EXPECTED_FAILURES: 1144 if re.search(fail_pattern, description): 1145 test_method = unittest.expectedFailure(test_method) 1146 break 1147 1148 setattr(cls, test_name, test_method) 1149 1150 1151_ALL_MIXIN_CLASSES = [ 1152 UnifiedSpecTestMixinV1, 1153 # add mixin classes for new schema major versions here 1154] 1155 1156 1157_SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS = { 1158 KLASS.SCHEMA_VERSION[0]: KLASS for KLASS in _ALL_MIXIN_CLASSES} 1159 1160 1161def generate_test_classes(test_path, module=__name__, class_name_prefix='', 1162 expected_failures=[], 1163 bypass_test_generation_errors=False): 1164 """Method for generating test classes. Returns a dictionary where keys are 1165 the names of test classes and values are the test class objects.""" 1166 test_klasses = {} 1167 1168 def test_base_class_factory(test_spec): 1169 """Utility that creates the base class to use for test generation. 1170 This is needed to ensure that cls.TEST_SPEC is appropriately set when 1171 the metaclass __init__ is invoked.""" 1172 class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): 1173 TEST_SPEC = test_spec 1174 EXPECTED_FAILURES = expected_failures 1175 return SpecTestBase 1176 1177 for dirpath, _, filenames in os.walk(test_path): 1178 dirname = os.path.split(dirpath)[-1] 1179 1180 for filename in filenames: 1181 fpath = os.path.join(dirpath, filename) 1182 with open(fpath) as scenario_stream: 1183 # Use tz_aware=False to match how CodecOptions decodes 1184 # dates. 1185 opts = json_util.JSONOptions(tz_aware=False) 1186 scenario_def = json_util.loads( 1187 scenario_stream.read(), json_options=opts) 1188 1189 test_type = os.path.splitext(filename)[0] 1190 snake_class_name = 'Test%s_%s_%s' % ( 1191 class_name_prefix, dirname.replace('-', '_'), 1192 test_type.replace('-', '_').replace('.', '_')) 1193 class_name = snake_to_camel(snake_class_name) 1194 1195 try: 1196 schema_version = Version.from_string( 1197 scenario_def['schemaVersion']) 1198 mixin_class = _SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS.get( 1199 schema_version[0]) 1200 if mixin_class is None: 1201 raise ValueError( 1202 "test file '%s' has unsupported schemaVersion '%s'" % ( 1203 fpath, schema_version)) 1204 test_klasses[class_name] = type( 1205 class_name, 1206 (mixin_class, test_base_class_factory(scenario_def),), 1207 {'__module__': module}) 1208 except Exception: 1209 if bypass_test_generation_errors: 1210 continue 1211 raise 1212 1213 return test_klasses 1214