# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for testing driver specs."""

import copy
import functools
import threading


from bson import decode, encode
from bson.binary import Binary, STANDARD
from bson.codec_options import CodecOptions
from bson.int64 import Int64
from bson.py3compat import iteritems, abc, string_type, text_type
from bson.son import SON

from gridfs import GridFSBucket

from pymongo import (client_session,
                     helpers,
                     operations)
from pymongo.command_cursor import CommandCursor
from pymongo.cursor import Cursor
from pymongo.errors import (BulkWriteError,
                            OperationFailure,
                            PyMongoError)
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import _WriteResult, BulkWriteResult
from pymongo.write_concern import WriteConcern

from test import (client_context,
                  client_knobs,
                  IntegrationTest,
                  unittest)
from test.utils import (camel_to_snake,
                        camel_to_snake_args,
                        camel_to_upper_camel,
                        CompareType,
                        CMAPListener,
                        OvertCommandListener,
                        parse_spec_options,
                        parse_read_preference,
                        prepare_spec_arguments,
                        rs_client,
                        ServerAndTopologyEventListener,
                        HeartbeatEventListener)


class SpecRunnerThread(threading.Thread):
    """Worker thread that executes scheduled test operations in order.

    Operations are queued with schedule(); the thread exits when stop()
    has been called and the queue is drained, or immediately after an
    operation raises (the exception is stored in self.exc).
    """

    def __init__(self, name):
        super(SpecRunnerThread, self).__init__()
        self.name = name
        self.exc = None
        # Daemonize so a stuck runner cannot block interpreter exit.
        # Assigning .daemon replaces setDaemon(), which is deprecated
        # since Python 3.10 (behavior is identical).
        self.daemon = True
        self.cond = threading.Condition()
        self.ops = []
        self.stopped = False

    def schedule(self, work):
        """Queue a callable to be run by this thread and wake it up."""
        self.ops.append(work)
        with self.cond:
            self.cond.notify()

    def stop(self):
        """Signal the thread to exit once remaining operations finish."""
        self.stopped = True
        with self.cond:
            self.cond.notify()

    def run(self):
        # Keep draining the queue until stop() is called AND no work is
        # pending, so operations scheduled before stop() still run.
        while not self.stopped or self.ops:
            if not self.ops:
                with self.cond:
                    # Bounded wait so a notify that raced ahead of this
                    # wait cannot hang the thread forever.
                    self.cond.wait(10)
            if self.ops:
                try:
                    work = self.ops.pop(0)
                    work()
                except Exception as exc:
                    # Preserve the failure for the test to inspect, then
                    # shut the thread down.
                    self.exc = exc
                    self.stop()


class SpecRunner(IntegrationTest):
    """Base class that runs JSON driver-spec scenario tests."""

    @classmethod
    def setUpClass(cls):
        super(SpecRunner, cls).setUpClass()
        cls.mongos_clients = []

        # Speed up the tests by decreasing the heartbeat frequency.
        cls.knobs = client_knobs(heartbeat_frequency=0.1,
                                 min_heartbeat_interval=0.1)
        cls.knobs.enable()

    @classmethod
    def tearDownClass(cls):
        cls.knobs.disable()
        super(SpecRunner, cls).tearDownClass()

    def setUp(self):
        super(SpecRunner, self).setUp()
        self.targets = {}
        self.listener = None
        self.pool_listener = None
        self.server_listener = None
        # Show full diffs on assertion failures.
        self.maxDiff = None

    def _set_fail_point(self, client, command_args):
        """Configure the failCommand fail point on a single client."""
        cmd = SON([('configureFailPoint', 'failCommand')])
        cmd.update(command_args)
        client.admin.command(cmd)

    def set_fail_point(self, command_args):
        """Configure the failCommand fail point on every known mongos,
        or on the default client when not running against mongos."""
        cmd = SON([('configureFailPoint', 'failCommand')])
        cmd.update(command_args)
        clients = self.mongos_clients if self.mongos_clients else [self.client]
        for client in clients:
            self._set_fail_point(client, cmd)

    def targeted_fail_point(self, session, fail_point):
        """Run the targetedFailPoint test operation.

        Enable the fail point on the session's pinned mongos.
        """
        clients = {c.address: c for c in self.mongos_clients}
        client = clients[session._pinned_address]
        self._set_fail_point(client, fail_point)
        self.addCleanup(self.set_fail_point, {'mode': 'off'})

    def assert_session_pinned(self, session):
        """Run the assertSessionPinned test operation.

        Assert that the given session is pinned.
        """
        self.assertIsNotNone(session._transaction.pinned_address)

    def assert_session_unpinned(self, session):
        """Run the assertSessionUnpinned test operation.

        Assert that the given session is not pinned.
        """
        self.assertIsNone(session._pinned_address)
        self.assertIsNone(session._transaction.pinned_address)

    def assert_collection_exists(self, database, collection):
        """Run the assertCollectionExists test operation."""
        db = self.client[database]
        self.assertIn(collection, db.list_collection_names())

    def assert_collection_not_exists(self, database, collection):
        """Run the assertCollectionNotExists test operation."""
        db = self.client[database]
        self.assertNotIn(collection, db.list_collection_names())

    def assert_index_exists(self, database, collection, index):
        """Run the assertIndexExists test operation."""
        coll = self.client[database][collection]
        self.assertIn(index, [doc['name'] for doc in coll.list_indexes()])

    def assert_index_not_exists(self, database, collection, index):
        """Run the assertIndexNotExists test operation."""
        coll = self.client[database][collection]
        self.assertNotIn(index, [doc['name'] for doc in coll.list_indexes()])

    def assertErrorLabelsContain(self, exc, expected_labels):
        """Assert the exception carries every expected error label."""
        labels = [label for label in expected_labels
                  if exc.has_error_label(label)]
        self.assertEqual(labels, expected_labels)

    def assertErrorLabelsOmit(self, exc, omit_labels):
        """Assert the exception carries none of the given error labels."""
        for label in omit_labels:
            self.assertFalse(
                exc.has_error_label(label),
                msg='error labels should not contain %s' % (label,))

    def kill_all_sessions(self):
        """Kill all server sessions on every known client."""
        clients = self.mongos_clients if self.mongos_clients else [self.client]
        for client in clients:
            try:
                client.admin.command('killAllSessions', [])
            except OperationFailure:
                # "operation was interrupted" by killing the command's
                # own session.
                pass

    def check_command_result(self, expected_result, result):
        """Compare a runCommand reply against the expected document."""
        # Only compare the keys in the expected result.
        filtered_result = {}
        for key in expected_result:
            try:
                filtered_result[key] = result[key]
            except KeyError:
                pass
        self.assertEqual(filtered_result, expected_result)

    # TODO: factor the following function with test_crud.py.
    def check_result(self, expected_result, result):
        """Compare an operation's result against the spec's expectation."""
        if isinstance(result, _WriteResult):
            for res in expected_result:
                prop = camel_to_snake(res)
                # SPEC-869: Only BulkWriteResult has upserted_count.
                if (prop == "upserted_count"
                        and not isinstance(result, BulkWriteResult)):
                    if result.upserted_id is not None:
                        upserted_count = 1
                    else:
                        upserted_count = 0
                    self.assertEqual(upserted_count, expected_result[res], prop)
                elif prop == "inserted_ids":
                    # BulkWriteResult does not have inserted_ids.
                    if isinstance(result, BulkWriteResult):
                        self.assertEqual(len(expected_result[res]),
                                         result.inserted_count)
                    else:
                        # InsertManyResult may be compared to [id1] from the
                        # crud spec or {"0": id1} from the retryable write spec.
                        ids = expected_result[res]
                        if isinstance(ids, dict):
                            ids = [ids[str(i)] for i in range(len(ids))]
                        self.assertEqual(ids, result.inserted_ids, prop)
                elif prop == "upserted_ids":
                    # Convert indexes from strings to integers.
                    ids = expected_result[res]
                    expected_ids = {}
                    for str_index in ids:
                        expected_ids[int(str_index)] = ids[str_index]
                    self.assertEqual(expected_ids, result.upserted_ids, prop)
                else:
                    self.assertEqual(
                        getattr(result, prop), expected_result[res], prop)

            return True
        else:
            self.assertEqual(result, expected_result)

    def get_object_name(self, op):
        """Allow subclasses to override handling of 'object'

        Transaction spec says 'object' is required.
        """
        return op['object']

    @staticmethod
    def parse_options(opts):
        """Parse spec option names/values into driver keyword arguments."""
        return parse_spec_options(opts)

    def run_operation(self, sessions, collection, operation):
        """Run one spec test operation and return its result."""
        original_collection = collection
        name = camel_to_snake(operation['name'])
        # Map spec operation names onto driver method names.
        if name == 'run_command':
            name = 'command'
        elif name == 'download_by_name':
            name = 'open_download_stream_by_name'
        elif name == 'download':
            name = 'open_download_stream'

        database = collection.database
        collection = database.get_collection(collection.name)
        if 'collectionOptions' in operation:
            collection = collection.with_options(
                **self.parse_options(operation['collectionOptions']))

        object_name = self.get_object_name(operation)
        if object_name == 'gridfsbucket':
            # Only create the GridFSBucket when we need it (for the gridfs
            # retryable reads tests).
            obj = GridFSBucket(
                database, bucket_name=collection.name,
                disable_md5=True)
        else:
            objects = {
                'client': database.client,
                'database': database,
                'collection': collection,
                'testRunner': self
            }
            objects.update(sessions)
            obj = objects[object_name]

        # Combine arguments with options and handle special cases.
        arguments = operation.get('arguments', {})
        arguments.update(arguments.pop("options", {}))
        self.parse_options(arguments)

        cmd = getattr(obj, name)

        with_txn_callback = functools.partial(
            self.run_operations, sessions, original_collection,
            in_with_transaction=True)
        prepare_spec_arguments(operation, arguments, name, sessions,
                               with_txn_callback)

        if name == 'run_on_thread':
            args = {'sessions': sessions, 'collection': collection}
            args.update(arguments)
            arguments = args
        result = cmd(**dict(arguments))

        # Cleanup open change stream cursors.
        if name == "watch":
            self.addCleanup(result.close)

        if name == "aggregate":
            if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
                # Read from the primary to ensure causal consistency.
                out = collection.database.get_collection(
                    arguments["pipeline"][-1]["$out"],
                    read_preference=ReadPreference.PRIMARY)
                return out.find()
        if name == "map_reduce":
            if isinstance(result, dict) and 'results' in result:
                return result['results']
        if 'download' in name:
            result = Binary(result.read())

        if isinstance(result, Cursor) or isinstance(result, CommandCursor):
            return list(result)

        return result

    def allowable_errors(self, op):
        """Allow encryption spec to override expected error classes."""
        return (PyMongoError,)

    def _run_op(self, sessions, collection, op, in_with_transaction):
        """Run a single operation, checking expected results or errors."""
        expected_result = op.get('result')
        if expect_error(op):
            with self.assertRaises(self.allowable_errors(op),
                                   msg=op['name']) as context:
                self.run_operation(sessions, collection, op.copy())

            if expect_error_message(expected_result):
                if isinstance(context.exception, BulkWriteError):
                    errmsg = str(context.exception.details).lower()
                else:
                    errmsg = str(context.exception).lower()
                self.assertIn(expected_result['errorContains'].lower(),
                              errmsg)
            if expect_error_code(expected_result):
                self.assertEqual(expected_result['errorCodeName'],
                                 context.exception.details.get('codeName'))
            if expect_error_labels_contain(expected_result):
                self.assertErrorLabelsContain(
                    context.exception,
                    expected_result['errorLabelsContain'])
            if expect_error_labels_omit(expected_result):
                self.assertErrorLabelsOmit(
                    context.exception,
                    expected_result['errorLabelsOmit'])

            # Reraise the exception if we're in the with_transaction
            # callback.
            if in_with_transaction:
                raise context.exception
        else:
            result = self.run_operation(sessions, collection, op.copy())
            if 'result' in op:
                if op['name'] == 'runCommand':
                    self.check_command_result(expected_result, result)
                else:
                    self.check_result(expected_result, result)

    def run_operations(self, sessions, collection, ops,
                       in_with_transaction=False):
        """Run each spec operation in order."""
        for op in ops:
            self._run_op(sessions, collection, op, in_with_transaction)

    # TODO: factor with test_command_monitoring.py
    def check_events(self, test, listener, session_ids):
        """Compare the commands started by the test against expectations."""
        res = listener.results
        if not len(test['expectations']):
            return

        # Give a nicer message when there are missing or extra events
        cmds = decode_raw([event.command for event in res['started']])
        self.assertEqual(
            len(res['started']), len(test['expectations']), cmds)
        for i, expectation in enumerate(test['expectations']):
            event_type = next(iter(expectation))
            event = res['started'][i]

            # The tests substitute 42 for any number other than 0.
            if (event.command_name == 'getMore'
                    and event.command['getMore']):
                event.command['getMore'] = Int64(42)
            elif event.command_name == 'killCursors':
                event.command['cursors'] = [Int64(42)]
            elif event.command_name == 'update':
                # TODO: remove this once PYTHON-1744 is done.
                # Add upsert and multi fields back into expectations.
                updates = expectation[event_type]['command']['updates']
                for update in updates:
                    update.setdefault('upsert', False)
                    update.setdefault('multi', False)

            # Replace afterClusterTime: 42 with actual afterClusterTime.
            expected_cmd = expectation[event_type]['command']
            expected_read_concern = expected_cmd.get('readConcern')
            if expected_read_concern is not None:
                time = expected_read_concern.get('afterClusterTime')
                if time == 42:
                    actual_time = event.command.get(
                        'readConcern', {}).get('afterClusterTime')
                    if actual_time is not None:
                        expected_read_concern['afterClusterTime'] = actual_time

            recovery_token = expected_cmd.get('recoveryToken')
            if recovery_token == 42:
                expected_cmd['recoveryToken'] = CompareType(dict)

            # Replace lsid with a name like "session0" to match test.
            if 'lsid' in event.command:
                for name, lsid in session_ids.items():
                    if event.command['lsid'] == lsid:
                        event.command['lsid'] = name
                        break

            for attr, expected in expectation[event_type].items():
                actual = getattr(event, attr)
                expected = wrap_types(expected)
                if isinstance(expected, dict):
                    for key, val in expected.items():
                        if val is None:
                            if key in actual:
                                self.fail("Unexpected key [%s] in %r" % (
                                    key, actual))
                        elif key not in actual:
                            self.fail("Expected key [%s] in %r" % (
                                key, actual))
                        else:
                            self.assertEqual(val, decode_raw(actual[key]),
                                             "Key [%s] in %s" % (key, actual))
                else:
                    self.assertEqual(actual, expected)

    def maybe_skip_scenario(self, test):
        """Skip the test when the scenario declares a skipReason."""
        if test.get('skipReason'):
            self.skipTest(test.get('skipReason'))

    def get_scenario_db_name(self, scenario_def):
        """Allow subclasses to override a test's database name."""
        return scenario_def['database_name']

    def get_scenario_coll_name(self, scenario_def):
        """Allow subclasses to override a test's collection name."""
        return scenario_def['collection_name']

    def get_outcome_coll_name(self, outcome, collection):
        """Allow subclasses to override outcome collection."""
        return collection.name

    def run_test_ops(self, sessions, collection, test):
        """Added to allow retryable writes spec to override a test's
        operation."""
        self.run_operations(sessions, collection, test['operations'])

    def parse_client_options(self, opts):
        """Allow encryption spec to override a clientOptions parsing."""
        # Convert test['clientOptions'] to dict to avoid a Jython bug using
        # "**" with ScenarioDict.
        return dict(opts)

    def setup_scenario(self, scenario_def):
        """Allow specs to override a test's setup."""
        db_name = self.get_scenario_db_name(scenario_def)
        coll_name = self.get_scenario_coll_name(scenario_def)
        # Drop and recreate the collection with majority write concern so
        # the initial data is visible to all test operations.
        db = client_context.client.get_database(
            db_name, write_concern=WriteConcern(w='majority'))
        coll = db[coll_name]
        coll.drop()
        db.create_collection(coll_name)
        if scenario_def['data']:
            # Load data.
            coll.insert_many(scenario_def['data'])

    def run_scenario(self, scenario_def, test):
        """Run one spec scenario test: set up data and fail points, build a
        fresh client and sessions, run the test operations, then verify the
        captured command events and the collection's final state."""
        self.maybe_skip_scenario(test)

        # Kill all sessions before and after each test to prevent an open
        # transaction (from a test failure) from blocking collection/database
        # operations during test set up and tear down.
        self.kill_all_sessions()
        self.addCleanup(self.kill_all_sessions)
        self.setup_scenario(scenario_def)
        database_name = self.get_scenario_db_name(scenario_def)
        collection_name = self.get_scenario_coll_name(scenario_def)
        # SPEC-1245 workaround StaleDbVersion on distinct
        for c in self.mongos_clients:
            c[database_name][collection_name].distinct("x")

        # Configure the fail point before creating the client.
        if 'failPoint' in test:
            fp = test['failPoint']
            self.set_fail_point(fp)
            self.addCleanup(self.set_fail_point, {
                'configureFailPoint': fp['configureFailPoint'], 'mode': 'off'})

        listener = OvertCommandListener()
        pool_listener = CMAPListener()
        server_listener = ServerAndTopologyEventListener()
        # Create a new client, to avoid interference from pooled sessions.
        client_options = self.parse_client_options(test['clientOptions'])
        # MMAPv1 does not support retryable writes.
        if (client_options.get('retryWrites') is True and
                client_context.storage_engine == 'mmapv1'):
            self.skipTest("MMAPv1 does not support retryWrites=True")
        use_multi_mongos = test['useMultipleMongoses']
        host = None
        if use_multi_mongos:
            if client_context.load_balancer:
                host = client_context.MULTI_MONGOS_LB_URI
            elif client_context.is_mongos:
                host = client_context.mongos_seeds()
        client = rs_client(
            h=host,
            event_listeners=[listener, pool_listener, server_listener],
            **client_options)
        self.scenario_client = client
        self.listener = listener
        self.pool_listener = pool_listener
        self.server_listener = server_listener
        # Close the client explicitly to avoid having too many threads open.
        self.addCleanup(client.close)

        # Create session0 and session1.
        sessions = {}
        session_ids = {}
        for i in range(2):
            # Don't attempt to create sessions if they are not supported by
            # the running server version.
            if not client_context.sessions_enabled:
                break
            session_name = 'session%d' % i
            opts = camel_to_snake_args(test['sessionOptions'][session_name])
            if 'default_transaction_options' in opts:
                txn_opts = self.parse_options(
                    opts['default_transaction_options'])
                txn_opts = client_session.TransactionOptions(**txn_opts)
                opts['default_transaction_options'] = txn_opts

            s = client.start_session(**dict(opts))

            sessions[session_name] = s
            # Store lsid so we can access it after end_session, in check_events.
            session_ids[session_name] = s.session_id

        self.addCleanup(end_sessions, sessions)

        collection = client[database_name][collection_name]
        self.run_test_ops(sessions, collection, test)

        # End sessions before checking events so implicit session cleanup
        # does not interleave with the expected command stream.
        end_sessions(sessions)

        self.check_events(test, listener, session_ids)

        # Disable fail points.
        if 'failPoint' in test:
            fp = test['failPoint']
            self.set_fail_point({
                'configureFailPoint': fp['configureFailPoint'], 'mode': 'off'})

        # Assert final state is expected.
        outcome = test['outcome']
        expected_c = outcome.get('collection')
        if expected_c is not None:
            outcome_coll_name = self.get_outcome_coll_name(
                outcome, collection)

            # Read from the primary with local read concern to ensure causal
            # consistency.
            outcome_coll = client_context.client[
                collection.database.name].get_collection(
                    outcome_coll_name,
                    read_preference=ReadPreference.PRIMARY,
                    read_concern=ReadConcern('local'))
            actual_data = list(outcome_coll.find(sort=[('_id', 1)]))

            # The expected data needs to be the left hand side here otherwise
            # CompareType(Binary) doesn't work.
            self.assertEqual(wrap_types(expected_c['data']), actual_data)


def expect_any_error(op):
    """Return truthy when the operation declares a bare 'error' expectation."""
    if isinstance(op, dict):
        return op.get('error')

    return False


def expect_error_message(expected_result):
    """Return True when the expected result specifies errorContains.

    NOTE(review): indexes 'errorContains' without .get(); presumably
    expected_result is a ScenarioDict-style mapping that tolerates missing
    keys — confirm against the scenario loader.
    """
    if isinstance(expected_result, dict):
        return isinstance(expected_result['errorContains'], text_type)

    return False


def expect_error_code(expected_result):
    """Return the expected errorCodeName (truthy) when one is specified."""
    if isinstance(expected_result, dict):
        return expected_result['errorCodeName']

    return False


def expect_error_labels_contain(expected_result):
    """Return the expected errorLabelsContain list when one is specified."""
    if isinstance(expected_result, dict):
        return expected_result['errorLabelsContain']

    return False


def expect_error_labels_omit(expected_result):
    """Return the expected errorLabelsOmit list when one is specified."""
    if isinstance(expected_result, dict):
        return expected_result['errorLabelsOmit']

    return False


def expect_error(op):
    """Return truthy when the operation expects any kind of error."""
    expected_result = op.get('result')
    return (expect_any_error(op) or
            expect_error_message(expected_result)
            or expect_error_code(expected_result)
            or expect_error_labels_contain(expected_result)
            or expect_error_labels_omit(expected_result))


def end_sessions(sessions):
    """End every session in the given name->session mapping."""
    for s in sessions.values():
        # Aborts the transaction if it's open.
        s.end_session()


# Codec options used when round-tripping documents through BSON in
# decode_raw(): plain dicts, standard UUID representation.
OPTS = CodecOptions(document_class=dict, uuid_representation=STANDARD)


def decode_raw(val):
    """Decode RawBSONDocuments in the given container."""
    if isinstance(val, (list, abc.Mapping)):
        # Round-trip through BSON to turn any RawBSONDocuments into dicts.
        return decode(encode({'v': val}, codec_options=OPTS), OPTS)['v']
    return val


# Mapping from spec $$type names to the Python types they assert.
TYPES = {
    'binData': Binary,
    'long': Int64,
}


def wrap_types(val):
    """Support $$type assertion in command results."""
    if isinstance(val, list):
        return [wrap_types(v) for v in val]
    if isinstance(val, abc.Mapping):
        typ = val.get('$$type')
        if typ:
            # {'$$type': name} matches any value of the named type.
            return CompareType(TYPES[typ])
        d = {}
        for key in val:
            d[key] = wrap_types(val[key])
        return d
    return val