# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for testing driver specs."""

import copy
import functools
import threading


from bson import decode, encode
from bson.binary import Binary, STANDARD
from bson.codec_options import CodecOptions
from bson.int64 import Int64
from bson.py3compat import iteritems, abc, string_type, text_type
from bson.son import SON

from gridfs import GridFSBucket

from pymongo import (client_session,
                     helpers,
                     operations)
from pymongo.command_cursor import CommandCursor
from pymongo.cursor import Cursor
from pymongo.errors import (BulkWriteError,
                            OperationFailure,
                            PyMongoError)
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import _WriteResult, BulkWriteResult
from pymongo.write_concern import WriteConcern

from test import (client_context,
                  client_knobs,
                  IntegrationTest,
                  unittest)
from test.utils import (camel_to_snake,
                        camel_to_snake_args,
                        camel_to_upper_camel,
                        CompareType,
                        CMAPListener,
                        OvertCommandListener,
                        parse_spec_options,
                        parse_read_preference,
                        prepare_spec_arguments,
                        rs_client,
                        ServerAndTopologyEventListener,
                        HeartbeatEventListener)


class SpecRunnerThread(threading.Thread):
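    """A daemon thread that runs scheduled test operations in order.

    Work queued with schedule() runs on this thread; the first
    exception raised stops the thread and is saved in self.exc.
    """
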
    def __init__(self, name):
        super(SpecRunnerThread, self).__init__()
        self.name = name
        self.exc = None
        self.daemon = True
        self.cond = threading.Condition()
        self.ops = []
        self.stopped = False

    def schedule(self, work):
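        """Queue a callable to run on this thread."""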
        self.ops.append(work)
        with self.cond:
            self.cond.notify()

    def stop(self):
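        """Signal this thread to exit after running any queued work."""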
        self.stopped = True
        with self.cond:
            self.cond.notify()

    def run(self):
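        """Run queued operations until stopped or an operation raises."""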
        while not self.stopped or self.ops:
            if not self.ops:
                with self.cond:
                    self.cond.wait(10)
            if self.ops:
                try:
                    work = self.ops.pop(0)
                    work()
                except Exception as exc:
                    self.exc = exc
                    self.stop()


class SpecRunner(IntegrationTest):
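    """Base class for the data-driven spec test runners.

    Subclasses customize behavior by overriding hooks such as
    get_scenario_db_name, get_scenario_coll_name, setup_scenario,
    run_test_ops, parse_client_options, and allowable_errors.
    A minimal subclass might look like this (illustrative sketch only)::

        class TestMySpec(SpecRunner):
            def get_scenario_db_name(self, scenario_def):
                return 'my_spec_db'
    """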

    @classmethod
    def setUpClass(cls):
        super(SpecRunner, cls).setUpClass()
        cls.mongos_clients = []

        # Speed up the tests by decreasing the heartbeat frequency.
        cls.knobs = client_knobs(heartbeat_frequency=0.1,
                                 min_heartbeat_interval=0.1)
        cls.knobs.enable()

    @classmethod
    def tearDownClass(cls):
        cls.knobs.disable()
        super(SpecRunner, cls).tearDownClass()

    def setUp(self):
        super(SpecRunner, self).setUp()
        self.targets = {}
        self.listener = None
        self.pool_listener = None
        self.server_listener = None
        self.maxDiff = None

    def _set_fail_point(self, client, command_args):
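        """Run the configureFailPoint command on the given client."""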
        cmd = SON([('configureFailPoint', 'failCommand')])
        cmd.update(command_args)
        client.admin.command(cmd)

    def set_fail_point(self, command_args):
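        """Enable the fail point on all mongos clients (or self.client)."""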
        cmd = SON([('configureFailPoint', 'failCommand')])
        cmd.update(command_args)
        clients = self.mongos_clients if self.mongos_clients else [self.client]
        for client in clients:
            self._set_fail_point(client, cmd)

    def targeted_fail_point(self, session, fail_point):
        """Run the targetedFailPoint test operation.

        Enable the fail point on the session's pinned mongos.
        """
        clients = {c.address: c for c in self.mongos_clients}
        client = clients[session._pinned_address]
        self._set_fail_point(client, fail_point)
        self.addCleanup(self.set_fail_point, {'mode': 'off'})

    def assert_session_pinned(self, session):
        """Run the assertSessionPinned test operation.

        Assert that the given session is pinned.
        """
        self.assertIsNotNone(session._transaction.pinned_address)

    def assert_session_unpinned(self, session):
        """Run the assertSessionUnpinned test operation.

        Assert that the given session is not pinned.
        """
        self.assertIsNone(session._pinned_address)
        self.assertIsNone(session._transaction.pinned_address)

    def assert_collection_exists(self, database, collection):
        """Run the assertCollectionExists test operation."""
        db = self.client[database]
        self.assertIn(collection, db.list_collection_names())

    def assert_collection_not_exists(self, database, collection):
        """Run the assertCollectionNotExists test operation."""
        db = self.client[database]
        self.assertNotIn(collection, db.list_collection_names())

    def assert_index_exists(self, database, collection, index):
        """Run the assertIndexExists test operation."""
        coll = self.client[database][collection]
        self.assertIn(index, [doc['name'] for doc in coll.list_indexes()])

    def assert_index_not_exists(self, database, collection, index):
        """Run the assertIndexNotExists test operation."""
        coll = self.client[database][collection]
        self.assertNotIn(index, [doc['name'] for doc in coll.list_indexes()])

    def assertErrorLabelsContain(self, exc, expected_labels):
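        """Assert that the exception has all of the expected error labels."""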
        labels = [l for l in expected_labels if exc.has_error_label(l)]
        self.assertEqual(labels, expected_labels)

    def assertErrorLabelsOmit(self, exc, omit_labels):
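        """Assert that the exception has none of the given error labels."""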
        for label in omit_labels:
            self.assertFalse(
                exc.has_error_label(label),
                msg='error labels should not contain %s' % (label,))

    def kill_all_sessions(self):
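        """Kill all server sessions to abort any leftover transactions."""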
        clients = self.mongos_clients if self.mongos_clients else [self.client]
        for client in clients:
            try:
                client.admin.command('killAllSessions', [])
            except OperationFailure:
                # "operation was interrupted" by killing the command's
                # own session.
                pass

    def check_command_result(self, expected_result, result):
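        """Assert that a runCommand result matches the expected result."""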
        # Only compare the keys in the expected result.
        filtered_result = {}
        for key in expected_result:
            try:
                filtered_result[key] = result[key]
            except KeyError:
                pass
        self.assertEqual(filtered_result, expected_result)

    # TODO: factor the following function with test_crud.py.
    def check_result(self, expected_result, result):
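        """Assert that a CRUD operation result matches the expected result."""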
        if isinstance(result, _WriteResult):
            for res in expected_result:
                prop = camel_to_snake(res)
                # SPEC-869: Only BulkWriteResult has upserted_count.
                if (prop == "upserted_count"
                        and not isinstance(result, BulkWriteResult)):
                    if result.upserted_id is not None:
                        upserted_count = 1
                    else:
                        upserted_count = 0
                    self.assertEqual(upserted_count, expected_result[res], prop)
                elif prop == "inserted_ids":
                    # BulkWriteResult does not have inserted_ids.
                    if isinstance(result, BulkWriteResult):
                        self.assertEqual(len(expected_result[res]),
                                         result.inserted_count)
                    else:
                        # InsertManyResult may be compared to [id1] from the
                        # crud spec or {"0": id1} from the retryable write spec.
                        ids = expected_result[res]
                        if isinstance(ids, dict):
                            ids = [ids[str(i)] for i in range(len(ids))]
                        self.assertEqual(ids, result.inserted_ids, prop)
                elif prop == "upserted_ids":
                    # Convert indexes from strings to integers.
                    ids = expected_result[res]
                    expected_ids = {}
                    for str_index in ids:
                        expected_ids[int(str_index)] = ids[str_index]
                    self.assertEqual(expected_ids, result.upserted_ids, prop)
                else:
                    self.assertEqual(
                        getattr(result, prop), expected_result[res], prop)

            return True
        else:
            self.assertEqual(result, expected_result)

    def get_object_name(self, op):
        """Allow subclasses to override handling of 'object'

        Transaction spec says 'object' is required.
        """
        return op['object']

    @staticmethod
    def parse_options(opts):
        return parse_spec_options(opts)

    def run_operation(self, sessions, collection, operation):
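        """Run a single test operation and return its result.

        Translates the operation's camelCase name and arguments into
        the corresponding PyMongo call on the target object (client,
        database, collection, session, GridFSBucket, or test runner).
        """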
        original_collection = collection
        name = camel_to_snake(operation['name'])
        if name == 'run_command':
            name = 'command'
        elif name == 'download_by_name':
            name = 'open_download_stream_by_name'
        elif name == 'download':
            name = 'open_download_stream'

        database = collection.database
        collection = database.get_collection(collection.name)
        if 'collectionOptions' in operation:
            collection = collection.with_options(
                **self.parse_options(operation['collectionOptions']))

        object_name = self.get_object_name(operation)
        if object_name == 'gridfsbucket':
            # Only create the GridFSBucket when we need it (for the gridfs
            # retryable reads tests).
            obj = GridFSBucket(
                database, bucket_name=collection.name,
                disable_md5=True)
        else:
            objects = {
                'client': database.client,
                'database': database,
                'collection': collection,
                'testRunner': self
            }
            objects.update(sessions)
            obj = objects[object_name]

        # Combine arguments with options and handle special cases.
        arguments = operation.get('arguments', {})
        arguments.update(arguments.pop("options", {}))
        self.parse_options(arguments)

        cmd = getattr(obj, name)

        with_txn_callback = functools.partial(
            self.run_operations, sessions, original_collection,
            in_with_transaction=True)
        prepare_spec_arguments(operation, arguments, name, sessions,
                               with_txn_callback)

        if name == 'run_on_thread':
            args = {'sessions': sessions, 'collection': collection}
            args.update(arguments)
            arguments = args
        result = cmd(**dict(arguments))

        # Clean up open change stream cursors.
        if name == "watch":
            self.addCleanup(result.close)

        if name == "aggregate":
            if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
                # Read from the primary to ensure causal consistency.
                out = collection.database.get_collection(
                    arguments["pipeline"][-1]["$out"],
                    read_preference=ReadPreference.PRIMARY)
                return out.find()
        if name == "map_reduce":
            if isinstance(result, dict) and 'results' in result:
                return result['results']
        if 'download' in name:
            result = Binary(result.read())

        if isinstance(result, (Cursor, CommandCursor)):
            return list(result)

        return result

    def allowable_errors(self, op):
        """Allow encryption spec to override expected error classes."""
        return (PyMongoError,)

    def _run_op(self, sessions, collection, op, in_with_transaction):
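        """Run one operation and assert the expected result or error."""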
        expected_result = op.get('result')
        if expect_error(op):
            with self.assertRaises(self.allowable_errors(op),
                                   msg=op['name']) as context:
                self.run_operation(sessions, collection, op.copy())

            if expect_error_message(expected_result):
                if isinstance(context.exception, BulkWriteError):
                    errmsg = str(context.exception.details).lower()
                else:
                    errmsg = str(context.exception).lower()
                self.assertIn(expected_result['errorContains'].lower(),
                              errmsg)
            if expect_error_code(expected_result):
                self.assertEqual(expected_result['errorCodeName'],
                                 context.exception.details.get('codeName'))
            if expect_error_labels_contain(expected_result):
                self.assertErrorLabelsContain(
                    context.exception,
                    expected_result['errorLabelsContain'])
            if expect_error_labels_omit(expected_result):
                self.assertErrorLabelsOmit(
                    context.exception,
                    expected_result['errorLabelsOmit'])

            # Reraise the exception if we're in the with_transaction
            # callback.
            if in_with_transaction:
                raise context.exception
        else:
            result = self.run_operation(sessions, collection, op.copy())
            if 'result' in op:
                if op['name'] == 'runCommand':
                    self.check_command_result(expected_result, result)
                else:
                    self.check_result(expected_result, result)

    def run_operations(self, sessions, collection, ops,
                       in_with_transaction=False):
        for op in ops:
            self._run_op(sessions, collection, op, in_with_transaction)

    # TODO: factor with test_command_monitoring.py
    def check_events(self, test, listener, session_ids):
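        """Assert that the expected command-started events were emitted."""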
        res = listener.results
        if not test['expectations']:
            return

        # Give a nicer message when there are missing or extra events.
        cmds = decode_raw([event.command for event in res['started']])
        self.assertEqual(
            len(res['started']), len(test['expectations']), cmds)
        for i, expectation in enumerate(test['expectations']):
            event_type = next(iter(expectation))
            event = res['started'][i]

            # The tests substitute 42 for any number other than 0.
            if (event.command_name == 'getMore'
                    and event.command['getMore']):
                event.command['getMore'] = Int64(42)
            elif event.command_name == 'killCursors':
                event.command['cursors'] = [Int64(42)]
            elif event.command_name == 'update':
                # TODO: remove this once PYTHON-1744 is done.
                # Add upsert and multi fields back into expectations.
                updates = expectation[event_type]['command']['updates']
                for update in updates:
                    update.setdefault('upsert', False)
                    update.setdefault('multi', False)

            # Replace afterClusterTime: 42 with actual afterClusterTime.
            expected_cmd = expectation[event_type]['command']
            expected_read_concern = expected_cmd.get('readConcern')
            if expected_read_concern is not None:
                time = expected_read_concern.get('afterClusterTime')
                if time == 42:
                    actual_time = event.command.get(
                        'readConcern', {}).get('afterClusterTime')
                    if actual_time is not None:
                        expected_read_concern['afterClusterTime'] = actual_time

            recovery_token = expected_cmd.get('recoveryToken')
            if recovery_token == 42:
                expected_cmd['recoveryToken'] = CompareType(dict)

            # Replace lsid with a name like "session0" to match test.
            if 'lsid' in event.command:
                for name, lsid in session_ids.items():
                    if event.command['lsid'] == lsid:
                        event.command['lsid'] = name
                        break

            for attr, expected in expectation[event_type].items():
                actual = getattr(event, attr)
                expected = wrap_types(expected)
                if isinstance(expected, dict):
                    for key, val in expected.items():
                        if val is None:
                            if key in actual:
                                self.fail("Unexpected key [%s] in %r" % (
                                    key, actual))
                        elif key not in actual:
                            self.fail("Expected key [%s] in %r" % (
                                key, actual))
                        else:
                            self.assertEqual(val, decode_raw(actual[key]),
                                             "Key [%s] in %s" % (key, actual))
                else:
                    self.assertEqual(actual, expected)

    def maybe_skip_scenario(self, test):
        if test.get('skipReason'):
            self.skipTest(test.get('skipReason'))

    def get_scenario_db_name(self, scenario_def):
        """Allow subclasses to override a test's database name."""
        return scenario_def['database_name']

    def get_scenario_coll_name(self, scenario_def):
        """Allow subclasses to override a test's collection name."""
        return scenario_def['collection_name']

    def get_outcome_coll_name(self, outcome, collection):
        """Allow subclasses to override outcome collection."""
        return collection.name

    def run_test_ops(self, sessions, collection, test):
        """Allow the retryable writes spec to override a test's operation."""
        self.run_operations(sessions, collection, test['operations'])

    def parse_client_options(self, opts):
        """Allow the encryption spec to override clientOptions parsing."""
        # Convert test['clientOptions'] to dict to avoid a Jython bug using
        # "**" with ScenarioDict.
        return dict(opts)

    def setup_scenario(self, scenario_def):
        """Allow specs to override a test's setup."""
        db_name = self.get_scenario_db_name(scenario_def)
        coll_name = self.get_scenario_coll_name(scenario_def)
        db = client_context.client.get_database(
            db_name, write_concern=WriteConcern(w='majority'))
        coll = db[coll_name]
        coll.drop()
        db.create_collection(coll_name)
        if scenario_def['data']:
            # Load data.
            coll.insert_many(scenario_def['data'])

    def run_scenario(self, scenario_def, test):
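        """Run a single test from a spec scenario definition."""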
        self.maybe_skip_scenario(test)

        # Kill all sessions before and after each test to prevent an open
        # transaction (from a test failure) from blocking collection/database
        # operations during test set up and tear down.
        self.kill_all_sessions()
        self.addCleanup(self.kill_all_sessions)
        self.setup_scenario(scenario_def)
        database_name = self.get_scenario_db_name(scenario_def)
        collection_name = self.get_scenario_coll_name(scenario_def)
        # SPEC-1245: work around StaleDbVersion errors on distinct.
        for c in self.mongos_clients:
            c[database_name][collection_name].distinct("x")

        # Configure the fail point before creating the client.
        if 'failPoint' in test:
            fp = test['failPoint']
            self.set_fail_point(fp)
            self.addCleanup(self.set_fail_point, {
                'configureFailPoint': fp['configureFailPoint'], 'mode': 'off'})

        listener = OvertCommandListener()
        pool_listener = CMAPListener()
        server_listener = ServerAndTopologyEventListener()
        # Create a new client to avoid interference from pooled sessions.
        client_options = self.parse_client_options(test['clientOptions'])
        # MMAPv1 does not support retryable writes.
        if (client_options.get('retryWrites') is True and
                client_context.storage_engine == 'mmapv1'):
            self.skipTest("MMAPv1 does not support retryWrites=True")
        use_multi_mongos = test['useMultipleMongoses']
        host = None
        if use_multi_mongos:
            if client_context.load_balancer:
                host = client_context.MULTI_MONGOS_LB_URI
            elif client_context.is_mongos:
                host = client_context.mongos_seeds()
        client = rs_client(
            h=host,
            event_listeners=[listener, pool_listener, server_listener],
            **client_options)
        self.scenario_client = client
        self.listener = listener
        self.pool_listener = pool_listener
        self.server_listener = server_listener
        # Close the client explicitly to avoid having too many threads open.
        self.addCleanup(client.close)

        # Create session0 and session1.
        sessions = {}
        session_ids = {}
        for i in range(2):
            # Don't attempt to create sessions if they are not supported by
            # the running server version.
            if not client_context.sessions_enabled:
                break
            session_name = 'session%d' % i
            opts = camel_to_snake_args(test['sessionOptions'][session_name])
            if 'default_transaction_options' in opts:
                txn_opts = self.parse_options(
                    opts['default_transaction_options'])
                txn_opts = client_session.TransactionOptions(**txn_opts)
                opts['default_transaction_options'] = txn_opts

            s = client.start_session(**dict(opts))

            sessions[session_name] = s
            # Store lsid so we can access it after end_session, in check_events.
            session_ids[session_name] = s.session_id

        self.addCleanup(end_sessions, sessions)

        collection = client[database_name][collection_name]
        self.run_test_ops(sessions, collection, test)

        end_sessions(sessions)

        self.check_events(test, listener, session_ids)

        # Disable fail points.
        if 'failPoint' in test:
            fp = test['failPoint']
            self.set_fail_point({
                'configureFailPoint': fp['configureFailPoint'], 'mode': 'off'})

        # Assert final state is expected.
        outcome = test['outcome']
        expected_c = outcome.get('collection')
        if expected_c is not None:
            outcome_coll_name = self.get_outcome_coll_name(
                outcome, collection)

            # Read from the primary with local read concern to ensure causal
            # consistency.
            outcome_coll = client_context.client[
                collection.database.name].get_collection(
                outcome_coll_name,
                read_preference=ReadPreference.PRIMARY,
                read_concern=ReadConcern('local'))
            actual_data = list(outcome_coll.find(sort=[('_id', 1)]))

            # The expected data needs to be the left hand side here; otherwise
            # CompareType(Binary) doesn't work.
            self.assertEqual(wrap_types(expected_c['data']), actual_data)


def expect_any_error(op):
    if isinstance(op, dict):
        return op.get('error')

    return False


def expect_error_message(expected_result):
    if isinstance(expected_result, dict):
        return isinstance(expected_result['errorContains'], text_type)

    return False


def expect_error_code(expected_result):
    if isinstance(expected_result, dict):
        return expected_result['errorCodeName']

    return False


def expect_error_labels_contain(expected_result):
    if isinstance(expected_result, dict):
        return expected_result['errorLabelsContain']

    return False


def expect_error_labels_omit(expected_result):
    if isinstance(expected_result, dict):
        return expected_result['errorLabelsOmit']

    return False


def expect_error(op):
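    """Return true if the operation is expected to raise an error."""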
    expected_result = op.get('result')
    return (expect_any_error(op)
            or expect_error_message(expected_result)
            or expect_error_code(expected_result)
            or expect_error_labels_contain(expected_result)
            or expect_error_labels_omit(expected_result))


def end_sessions(sessions):
    for s in sessions.values():
        # Aborts the transaction if it's open.
        s.end_session()


OPTS = CodecOptions(document_class=dict, uuid_representation=STANDARD)


def decode_raw(val):
    """Decode RawBSONDocuments in the given container."""
    if isinstance(val, (list, abc.Mapping)):
        return decode(encode({'v': val}, codec_options=OPTS), OPTS)['v']
    return val


TYPES = {
    'binData': Binary,
    'long': Int64,
}


def wrap_types(val):
    """Support $$type assertion in command results."""
    if isinstance(val, list):
        return [wrap_types(v) for v in val]
    if isinstance(val, abc.Mapping):
        typ = val.get('$$type')
        if typ:
            return CompareType(TYPES[typ])
        d = {}
        for key in val:
            d[key] = wrap_types(val[key])
        return d
    return val