1# Copyright 2020-present MongoDB, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Unified test format runner.
16
17https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst
18"""
19
20import copy
21import datetime
22import functools
23import os
24import re
25import sys
26import types
27
28from bson import json_util, Code, Decimal128, DBRef, SON, Int64, MaxKey, MinKey
29from bson.binary import Binary
30from bson.objectid import ObjectId
31from bson.py3compat import abc, integer_types, iteritems, text_type, PY3
32from bson.regex import Regex, RE_TYPE
33
34from gridfs import GridFSBucket
35
36from pymongo import ASCENDING, MongoClient
37from pymongo.client_session import ClientSession, TransactionOptions, _TxnState
38from pymongo.change_stream import ChangeStream
39from pymongo.collection import Collection
40from pymongo.database import Database
41from pymongo.errors import (
42    BulkWriteError, ConnectionFailure, ConfigurationError, InvalidOperation,
43    NotPrimaryError, PyMongoError)
44from pymongo.monitoring import (
45    CommandFailedEvent, CommandListener, CommandStartedEvent,
46    CommandSucceededEvent, _SENSITIVE_COMMANDS, PoolCreatedEvent,
47    PoolClearedEvent, PoolClosedEvent, ConnectionCreatedEvent,
48    ConnectionReadyEvent, ConnectionClosedEvent,
49    ConnectionCheckOutStartedEvent, ConnectionCheckOutFailedEvent,
50    ConnectionCheckedOutEvent, ConnectionCheckedInEvent)
51from pymongo.read_concern import ReadConcern
52from pymongo.read_preferences import ReadPreference
53from pymongo.results import BulkWriteResult
54from pymongo.server_api import ServerApi
55from pymongo.write_concern import WriteConcern
56
57from test import client_context, unittest, IntegrationTest
58from test.utils import (
59    camel_to_snake, get_pool, rs_or_single_client, single_client,
60    snake_to_camel, CMAPListener)
61
62from test.version import Version
63from test.utils import (
64    camel_to_snake_args, parse_collection_options, parse_spec_options,
65    prepare_spec_arguments)
66
67
68JSON_OPTS = json_util.JSONOptions(tz_aware=False)
69
70
71def with_metaclass(meta, *bases):
72    """Create a base class with a metaclass.
73
74    Vendored from six: https://github.com/benjaminp/six/blob/master/six.py
75    """
76    # This requires a bit of explanation: the basic idea is to make a dummy
77    # metaclass for one level of class instantiation that replaces itself with
78    # the actual metaclass.
79    class metaclass(type):
80
81        def __new__(cls, name, this_bases, d):
82            if sys.version_info[:2] >= (3, 7):
83                # This version introduced PEP 560 that requires a bit
84                # of extra care (we mimic what is done by __build_class__).
85                resolved_bases = types.resolve_bases(bases)
86                if resolved_bases is not bases:
87                    d['__orig_bases__'] = bases
88            else:
89                resolved_bases = bases
90            return meta(name, resolved_bases, d)
91
92        @classmethod
93        def __prepare__(cls, name, this_bases):
94            return meta.__prepare__(name, bases)
95    return type.__new__(metaclass, 'temporary_class', (), {})
96
97
98def is_run_on_requirement_satisfied(requirement):
99    topology_satisfied = True
100    req_topologies = requirement.get('topologies')
101    if req_topologies:
102        topology_satisfied = client_context.is_topology_type(
103            req_topologies)
104
105    server_version = Version(*client_context.version[:3])
106
107    min_version_satisfied = True
108    req_min_server_version = requirement.get('minServerVersion')
109    if req_min_server_version:
110        min_version_satisfied = Version.from_string(
111            req_min_server_version) <= server_version
112
113    max_version_satisfied = True
114    req_max_server_version = requirement.get('maxServerVersion')
115    if req_max_server_version:
116        max_version_satisfied = Version.from_string(
117            req_max_server_version) >= server_version
118
119    params_satisfied = True
120    params = requirement.get('serverParameters')
121    if params:
122        for param, val in params.items():
123            if param not in client_context.server_parameters:
124                params_satisfied = False
125            elif client_context.server_parameters[param] != val:
126                params_satisfied = False
127
128    auth_satisfied = True
129    req_auth = requirement.get('auth')
130    if req_auth is not None:
131        if req_auth:
132            auth_satisfied = client_context.auth_enabled
133        else:
134            auth_satisfied = not client_context.auth_enabled
135
136    return (topology_satisfied and min_version_satisfied and
137            max_version_satisfied and params_satisfied and auth_satisfied)
138
139
140def parse_collection_or_database_options(options):
141    return parse_collection_options(options)
142
143
144def parse_bulk_write_result(result):
145    upserted_ids = {str(int_idx): result.upserted_ids[int_idx]
146                    for int_idx in result.upserted_ids}
147    return {
148        'deletedCount': result.deleted_count,
149        'insertedCount': result.inserted_count,
150        'matchedCount': result.matched_count,
151        'modifiedCount': result.modified_count,
152        'upsertedCount': result.upserted_count,
153        'upsertedIds': upserted_ids}
154
155
156def parse_bulk_write_error_result(error):
157    write_result = BulkWriteResult(error.details, True)
158    return parse_bulk_write_result(write_result)
159
160
161class NonLazyCursor(object):
162    """A find cursor proxy that creates the remote cursor when initialized."""
163    def __init__(self, find_cursor):
164        self.find_cursor = find_cursor
165        # Create the server side cursor.
166        self.first_result = next(find_cursor, None)
167
168    def __iter__(self):
169        return self
170
171    def __next__(self):
172        if self.first_result is not None:
173            first = self.first_result
174            self.first_result = None
175            return first
176        return next(self.find_cursor)
177
178    next = __next__
179
180    def close(self):
181        self.find_cursor.close()
182
183
184class EventListenerUtil(CMAPListener, CommandListener):
185    def __init__(self, observe_events, ignore_commands,
186                 observe_sensitive_commands):
187        self._event_types = set(name.lower() for name in observe_events)
188        if observe_sensitive_commands:
189            self._ignore_commands = set(ignore_commands)
190        else:
191            self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands)
192            self._ignore_commands.add('configurefailpoint')
193        super(EventListenerUtil, self).__init__()
194
195    def get_events(self, event_type):
196        if event_type == 'command':
197            return [e for e in self.events if 'Command' in type(e).__name__]
198        return [e for e in self.events if 'Command' not in type(e).__name__]
199
200    def add_event(self, event):
201        if type(event).__name__.lower() in self._event_types:
202            super(EventListenerUtil, self).add_event(event)
203
204    def _command_event(self, event):
205        if event.command_name.lower() not in self._ignore_commands:
206            self.add_event(event)
207
208    def started(self, event):
209        self._command_event(event)
210
211    def succeeded(self, event):
212        self._command_event(event)
213
214    def failed(self, event):
215        self._command_event(event)
216
217
218class EntityMapUtil(object):
219    """Utility class that implements an entity map as per the unified
220    test format specification."""
221    def __init__(self, test_class):
222        self._entities = {}
223        self._listeners = {}
224        self._session_lsids = {}
225        self.test = test_class
226
227    def __getitem__(self, item):
228        try:
229            return self._entities[item]
230        except KeyError:
231            self.test.fail('Could not find entity named %s in map' % (
232                item,))
233
234    def __setitem__(self, key, value):
235        if not isinstance(key, text_type):
236            self.test.fail(
237                'Expected entity name of type str, got %s' % (type(key)))
238
239        if key in self._entities:
240            self.test.fail('Entity named %s already in map' % (key,))
241
242        self._entities[key] = value
243
244    def _create_entity(self, entity_spec):
245        if len(entity_spec) != 1:
246            self.test.fail(
247                "Entity spec %s did not contain exactly one top-level key" % (
248                    entity_spec,))
249
250        entity_type, spec = next(iteritems(entity_spec))
251        if entity_type == 'client':
252            kwargs = {}
253            observe_events = spec.get('observeEvents', [])
254            ignore_commands = spec.get('ignoreCommandMonitoringEvents', [])
255            observe_sensitive_commands = spec.get(
256                'observeSensitiveCommands', False)
257            # TODO: SUPPORT storeEventsAsEntities
258            if len(observe_events) or len(ignore_commands):
259                ignore_commands = [cmd.lower() for cmd in ignore_commands]
260                listener = EventListenerUtil(
261                    observe_events, ignore_commands, observe_sensitive_commands)
262                self._listeners[spec['id']] = listener
263                kwargs['event_listeners'] = [listener]
264            if spec.get('useMultipleMongoses'):
265                if client_context.load_balancer:
266                    kwargs['h'] = client_context.MULTI_MONGOS_LB_URI
267                elif client_context.is_mongos:
268                    kwargs['h'] = client_context.mongos_seeds()
269            kwargs.update(spec.get('uriOptions', {}))
270            server_api = spec.get('serverApi')
271            if server_api:
272                kwargs['server_api'] = ServerApi(
273                    server_api['version'], strict=server_api.get('strict'),
274                    deprecation_errors=server_api.get('deprecationErrors'))
275            client = rs_or_single_client(**kwargs)
276            self[spec['id']] = client
277            self.test.addCleanup(client.close)
278            return
279        elif entity_type == 'database':
280            client = self[spec['client']]
281            if not isinstance(client, MongoClient):
282                self.test.fail(
283                    'Expected entity %s to be of type MongoClient, got %s' % (
284                        spec['client'], type(client)))
285            options = parse_collection_or_database_options(
286                spec.get('databaseOptions', {}))
287            self[spec['id']] = client.get_database(
288                spec['databaseName'], **options)
289            return
290        elif entity_type == 'collection':
291            database = self[spec['database']]
292            if not isinstance(database, Database):
293                self.test.fail(
294                    'Expected entity %s to be of type Database, got %s' % (
295                        spec['database'], type(database)))
296            options = parse_collection_or_database_options(
297                spec.get('collectionOptions', {}))
298            self[spec['id']] = database.get_collection(
299                spec['collectionName'], **options)
300            return
301        elif entity_type == 'session':
302            client = self[spec['client']]
303            if not isinstance(client, MongoClient):
304                self.test.fail(
305                    'Expected entity %s to be of type MongoClient, got %s' % (
306                        spec['client'], type(client)))
307            opts = camel_to_snake_args(spec.get('sessionOptions', {}))
308            if 'default_transaction_options' in opts:
309                txn_opts = parse_spec_options(
310                    opts['default_transaction_options'])
311                txn_opts = TransactionOptions(**txn_opts)
312                opts = copy.deepcopy(opts)
313                opts['default_transaction_options'] = txn_opts
314            session = client.start_session(**dict(opts))
315            self[spec['id']] = session
316            self._session_lsids[spec['id']] = copy.deepcopy(session.session_id)
317            self.test.addCleanup(session.end_session)
318            return
319        elif entity_type == 'bucket':
320            # TODO: implement the 'bucket' entity type
321            self.test.skipTest(
322                'GridFS is not currently supported (PYTHON-2459)')
323        self.test.fail(
324            'Unable to create entity of unknown type %s' % (entity_type,))
325
326    def create_entities_from_spec(self, entity_spec):
327        for spec in entity_spec:
328            self._create_entity(spec)
329
330    def get_listener_for_client(self, client_name):
331        client = self[client_name]
332        if not isinstance(client, MongoClient):
333            self.test.fail(
334                'Expected entity %s to be of type MongoClient, got %s' % (
335                        client_name, type(client)))
336
337        listener = self._listeners.get(client_name)
338        if not listener:
339            self.test.fail(
340                'No listeners configured for client %s' % (client_name,))
341
342        return listener
343
344    def get_lsid_for_session(self, session_name):
345        session = self[session_name]
346        if not isinstance(session, ClientSession):
347            self.test.fail(
348                'Expected entity %s to be of type ClientSession, got %s' % (
349                    session_name, type(session)))
350
351        try:
352            return session.session_id
353        except InvalidOperation:
354            # session has been closed.
355            return self._session_lsids[session_name]
356
357
358if not PY3:
359    binary_types = (Binary,)
360    long_types = (Int64, long)
361    unicode_type = unicode
362else:
363    binary_types = (Binary, bytes)
364    long_types = (Int64,)
365    unicode_type = str
366
367
368BSON_TYPE_ALIAS_MAP = {
369    # https://docs.mongodb.com/manual/reference/operator/query/type/
370    # https://pymongo.readthedocs.io/en/stable/api/bson/index.html
371    'double': (float,),
372    'string': (text_type,),
373    'object': (abc.Mapping,),
374    'array': (abc.MutableSequence,),
375    'binData': binary_types,
376    'undefined': (type(None),),
377    'objectId': (ObjectId,),
378    'bool': (bool,),
379    'date': (datetime.datetime,),
380    'null': (type(None),),
381    'regex': (Regex, RE_TYPE),
382    'dbPointer': (DBRef,),
383    'javascript': (unicode_type, Code),
384    'symbol': (unicode_type,),
385    'javascriptWithScope': (unicode_type, Code),
386    'int': (int,),
387    'long': long_types,
388    'decimal': (Decimal128,),
389    'maxKey': (MaxKey,),
390    'minKey': (MinKey,),
391}
392
393
394class MatchEvaluatorUtil(object):
395    """Utility class that implements methods for evaluating matches as per
396    the unified test format specification."""
397    def __init__(self, test_class):
398        self.test = test_class
399
400    def _operation_exists(self, spec, actual, key_to_compare):
401        if spec is True:
402            self.test.assertIn(key_to_compare, actual)
403        elif spec is False:
404            self.test.assertNotIn(key_to_compare, actual)
405        else:
406            self.test.fail(
407                'Expected boolean value for $$exists operator, got %s' % (
408                    spec,))
409
410    def __type_alias_to_type(self, alias):
411        if alias not in BSON_TYPE_ALIAS_MAP:
412            self.test.fail('Unrecognized BSON type alias %s' % (alias,))
413        return BSON_TYPE_ALIAS_MAP[alias]
414
415    def _operation_type(self, spec, actual, key_to_compare):
416        if isinstance(spec, abc.MutableSequence):
417            permissible_types = tuple([
418                t for alias in spec for t in self.__type_alias_to_type(alias)])
419        else:
420            permissible_types = self.__type_alias_to_type(spec)
421        self.test.assertIsInstance(
422            actual[key_to_compare], permissible_types)
423
424    def _operation_matchesEntity(self, spec, actual, key_to_compare):
425        expected_entity = self.test.entity_map[spec]
426        self.test.assertIsInstance(expected_entity, abc.Mapping)
427        self.test.assertEqual(expected_entity, actual[key_to_compare])
428
429    def _operation_matchesHexBytes(self, spec, actual, key_to_compare):
430        raise NotImplementedError
431
432    def _operation_unsetOrMatches(self, spec, actual, key_to_compare):
433        if key_to_compare is None and not actual:
434            # top-level document can be None when unset
435            return
436
437        if key_to_compare not in actual:
438            # we add a dummy value for the compared key to pass map size check
439            actual[key_to_compare] = 'dummyValue'
440            return
441        self.match_result(spec, actual[key_to_compare], in_recursive_call=True)
442
443    def _operation_sessionLsid(self, spec, actual, key_to_compare):
444        expected_lsid = self.test.entity_map.get_lsid_for_session(spec)
445        self.test.assertEqual(expected_lsid, actual[key_to_compare])
446
447    def _evaluate_special_operation(self, opname, spec, actual,
448                                    key_to_compare):
449        method_name = '_operation_%s' % (opname.strip('$'),)
450        try:
451            method = getattr(self, method_name)
452        except AttributeError:
453            self.test.fail(
454                'Unsupported special matching operator %s' % (opname,))
455        else:
456            method(spec, actual, key_to_compare)
457
458    def _evaluate_if_special_operation(self, expectation, actual,
459                                       key_to_compare=None):
460        """Returns True if a special operation is evaluated, False
461        otherwise. If the ``expectation`` map contains a single key,
462        value pair we check it for a special operation.
463        If given, ``key_to_compare`` is assumed to be the key in
464        ``expectation`` whose corresponding value needs to be
465        evaluated for a possible special operation. ``key_to_compare``
466        is ignored when ``expectation`` has only one key."""
467        if not isinstance(expectation, abc.Mapping):
468            return False
469
470        is_special_op, opname, spec = False, False, False
471
472        if key_to_compare is not None:
473            if key_to_compare.startswith('$$'):
474                is_special_op = True
475                opname = key_to_compare
476                spec = expectation[key_to_compare]
477                key_to_compare = None
478            else:
479                nested = expectation[key_to_compare]
480                if isinstance(nested, abc.Mapping) and len(nested) == 1:
481                    opname, spec = next(iteritems(nested))
482                    if opname.startswith('$$'):
483                        is_special_op = True
484        elif len(expectation) == 1:
485            opname, spec = next(iteritems(expectation))
486            if opname.startswith('$$'):
487                is_special_op = True
488                key_to_compare = None
489
490        if is_special_op:
491            self._evaluate_special_operation(
492                opname=opname,
493                spec=spec,
494                actual=actual,
495                key_to_compare=key_to_compare)
496            return True
497
498        return False
499
500    def _match_document(self, expectation, actual, is_root):
501        if self._evaluate_if_special_operation(expectation, actual):
502            return
503
504        self.test.assertIsInstance(actual, abc.Mapping)
505        for key, value in iteritems(expectation):
506            if self._evaluate_if_special_operation(expectation, actual, key):
507                continue
508
509            self.test.assertIn(key, actual)
510            self.match_result(value, actual[key], in_recursive_call=True)
511
512        if not is_root:
513            expected_keys = set(expectation.keys())
514            for key, value in expectation.items():
515                if value == {'$$exists': False}:
516                    expected_keys.remove(key)
517            self.test.assertEqual(expected_keys, set(actual.keys()))
518
519    def match_result(self, expectation, actual,
520                     in_recursive_call=False):
521        if isinstance(expectation, abc.Mapping):
522            return self._match_document(
523                expectation, actual, is_root=not in_recursive_call)
524
525        if isinstance(expectation, abc.MutableSequence):
526            self.test.assertIsInstance(actual, abc.MutableSequence)
527            for e, a in zip(expectation, actual):
528                if isinstance(e, abc.Mapping):
529                    self._match_document(
530                        e, a, is_root=not in_recursive_call)
531                else:
532                    self.match_result(e, a, in_recursive_call=True)
533                return
534
535        # account for flexible numerics in element-wise comparison and unicode
536        # vs str on Python 2.
537        if not (isinstance(expectation, integer_types) or
538                isinstance(expectation, float) or
539                (isinstance(expectation, unicode_type) and not PY3)):
540            self.test.assertIsInstance(actual, type(expectation))
541        self.test.assertEqual(expectation, actual)
542
543    def assertHasServiceId(self, spec, actual):
544        if 'hasServiceId' in spec:
545            if spec.get('hasServiceId'):
546                self.test.assertIsNotNone(actual.service_id)
547                self.test.assertIsInstance(actual.service_id, ObjectId)
548            else:
549                self.test.assertIsNone(actual.service_id)
550
551    def match_event(self, event_type, expectation, actual):
552        name, spec = next(iteritems(expectation))
553
554        # every command event has the commandName field
555        if event_type == 'command':
556            command_name = spec.get('commandName')
557            if command_name:
558                self.test.assertEqual(command_name, actual.command_name)
559
560        if name == 'commandStartedEvent':
561            self.test.assertIsInstance(actual, CommandStartedEvent)
562            command = spec.get('command')
563            database_name = spec.get('databaseName')
564            if command:
565                if actual.command_name == 'update':
566                    # TODO: remove this once PYTHON-1744 is done.
567                    # Add upsert and multi fields back into expectations.
568                    for update in command.get('updates', []):
569                        update.setdefault('upsert', False)
570                        update.setdefault('multi', False)
571                self.match_result(command, actual.command)
572            if database_name:
573                self.test.assertEqual(
574                    database_name, actual.database_name)
575            self.assertHasServiceId(spec, actual)
576        elif name == 'commandSucceededEvent':
577            self.test.assertIsInstance(actual, CommandSucceededEvent)
578            reply = spec.get('reply')
579            if reply:
580                self.match_result(reply, actual.reply)
581            self.assertHasServiceId(spec, actual)
582        elif name == 'commandFailedEvent':
583            self.test.assertIsInstance(actual, CommandFailedEvent)
584            self.assertHasServiceId(spec, actual)
585        elif name == 'poolCreatedEvent':
586            self.test.assertIsInstance(actual, PoolCreatedEvent)
587        elif name == 'poolReadyEvent':
588            # PyMongo 3.X does not support PoolReadyEvent.
589            assert False
590        elif name == 'poolClearedEvent':
591            self.test.assertIsInstance(actual, PoolClearedEvent)
592            self.assertHasServiceId(spec, actual)
593        elif name == 'poolClosedEvent':
594            self.test.assertIsInstance(actual, PoolClosedEvent)
595        elif name == 'connectionCreatedEvent':
596            self.test.assertIsInstance(actual, ConnectionCreatedEvent)
597        elif name == 'connectionReadyEvent':
598            self.test.assertIsInstance(actual, ConnectionReadyEvent)
599        elif name == 'connectionClosedEvent':
600            self.test.assertIsInstance(actual, ConnectionClosedEvent)
601            if 'reason' in spec:
602                self.test.assertEqual(actual.reason, spec['reason'])
603        elif name == 'connectionCheckOutStartedEvent':
604            self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent)
605        elif name == 'connectionCheckOutFailedEvent':
606            self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent)
607            if 'reason' in spec:
608                self.test.assertEqual(actual.reason, spec['reason'])
609        elif name == 'connectionCheckedOutEvent':
610            self.test.assertIsInstance(actual, ConnectionCheckedOutEvent)
611        elif name == 'connectionCheckedInEvent':
612            self.test.assertIsInstance(actual, ConnectionCheckedInEvent)
613        else:
614            self.test.fail(
615                'Unsupported event type %s' % (name,))
616
617
618def coerce_result(opname, result):
619    """Convert a pymongo result into the spec's result format."""
620    if hasattr(result, 'acknowledged') and not result.acknowledged:
621        return {'acknowledged': False}
622    if opname == 'bulkWrite':
623        return parse_bulk_write_result(result)
624    if opname == 'insertOne':
625        return {'insertedId': result.inserted_id}
626    if opname == 'insertMany':
627        return {idx: _id for idx, _id in enumerate(result.inserted_ids)}
628    if opname in ('deleteOne', 'deleteMany'):
629        return {'deletedCount': result.deleted_count}
630    if opname in ('updateOne', 'updateMany', 'replaceOne'):
631        return {
632            'matchedCount': result.matched_count,
633            'modifiedCount': result.modified_count,
634            'upsertedCount': 0 if result.upserted_id is None else 1,
635        }
636    return result
637
638
639class UnifiedSpecTestMixinV1(IntegrationTest):
640    """Mixin class to run test cases from test specification files.
641
642    Assumes that tests conform to the `unified test format
643    <https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst>`_.
644
645    Specification of the test suite being currently run is available as
646    a class attribute ``TEST_SPEC``.
647    """
648    SCHEMA_VERSION = Version.from_string('1.5')
649    RUN_ON_LOAD_BALANCER = True
650
651    @staticmethod
652    def should_run_on(run_on_spec):
653        if not run_on_spec:
654            # Always run these tests.
655            return True
656
657        for req in run_on_spec:
658            if is_run_on_requirement_satisfied(req):
659                return True
660        return False
661
662    def insert_initial_data(self, initial_data):
663        for collection_data in initial_data:
664            coll_name = collection_data['collectionName']
665            db_name = collection_data['databaseName']
666            documents = collection_data['documents']
667
668            coll = self.client.get_database(db_name).get_collection(
669                coll_name, write_concern=WriteConcern(w="majority"))
670            coll.drop()
671
672            if len(documents) > 0:
673                coll.insert_many(documents)
674            else:
675                # ensure collection exists
676                result = coll.insert_one({})
677                coll.delete_one({'_id': result.inserted_id})
678
679    @classmethod
680    def setUpClass(cls):
681        # super call creates internal client cls.client
682        super(UnifiedSpecTestMixinV1, cls).setUpClass()
683
684        # process file-level runOnRequirements
685        run_on_spec = cls.TEST_SPEC.get('runOnRequirements', [])
686        if not cls.should_run_on(run_on_spec):
687            raise unittest.SkipTest(
688                '%s runOnRequirements not satisfied' % (cls.__name__,))
689
690        # add any special-casing for skipping tests here
691        if client_context.storage_engine == 'mmapv1':
692            if 'retryable-writes' in cls.TEST_SPEC['description']:
693                raise unittest.SkipTest(
694                    "MMAPv1 does not support retryWrites=True")
695
696    def setUp(self):
697        super(UnifiedSpecTestMixinV1, self).setUp()
698
699        # process schemaVersion
700        # note: we check major schema version during class generation
701        # note: we do this here because we cannot run assertions in setUpClass
702        version = Version.from_string(self.TEST_SPEC['schemaVersion'])
703        self.assertLessEqual(
704            version, self.SCHEMA_VERSION,
705            'expected schema version %s or lower, got %s' % (
706                self.SCHEMA_VERSION, version))
707
708        # initialize internals
709        self.match_evaluator = MatchEvaluatorUtil(self)
710
711    def maybe_skip_test(self, spec):
712        # add any special-casing for skipping tests here
713        if client_context.storage_engine == 'mmapv1':
714            if 'Dirty explicit session is discarded' in spec['description']:
715                raise unittest.SkipTest(
716                    "MMAPv1 does not support retryWrites=True")
717        elif 'Client side error in command starting transaction' in spec['description']:
718            raise unittest.SkipTest("Implement PYTHON-1894")
719
720    def process_error(self, exception, spec):
721        is_error = spec.get('isError')
722        is_client_error = spec.get('isClientError')
723        error_contains = spec.get('errorContains')
724        error_code = spec.get('errorCode')
725        error_code_name = spec.get('errorCodeName')
726        error_labels_contain = spec.get('errorLabelsContain')
727        error_labels_omit = spec.get('errorLabelsOmit')
728        expect_result = spec.get('expectResult')
729
730        if is_error:
731            # already satisfied because exception was raised
732            pass
733
734        if is_client_error:
735            # Connection errors are considered client errors.
736            if isinstance(exception, ConnectionFailure):
737                self.assertNotIsInstance(exception, NotPrimaryError)
738            elif isinstance(exception, (InvalidOperation, ConfigurationError)):
739                pass
740            else:
741                self.assertNotIsInstance(exception, PyMongoError)
742
743        if error_contains:
744            if isinstance(exception, BulkWriteError):
745                errmsg = str(exception.details).lower()
746            else:
747                errmsg = str(exception).lower()
748            self.assertIn(error_contains.lower(), errmsg)
749
750        if error_code:
751            self.assertEqual(
752                error_code, exception.details.get('code'))
753
754        if error_code_name:
755            self.assertEqual(
756                error_code_name, exception.details.get('codeName'))
757
758        if error_labels_contain:
759            labels = [err_label for err_label in error_labels_contain
760                      if exception.has_error_label(err_label)]
761            self.assertEqual(labels, error_labels_contain)
762
763        if error_labels_omit:
764            for err_label in error_labels_omit:
765                if exception.has_error_label(err_label):
766                    self.fail("Exception '%s' unexpectedly had label '%s'" % (
767                        exception, err_label))
768
769        if expect_result:
770            if isinstance(exception, BulkWriteError):
771                result = parse_bulk_write_error_result(
772                    exception)
773                self.match_evaluator.match_result(expect_result, result)
774            else:
775                self.fail("expectResult can only be specified with %s "
776                          "exceptions" % (BulkWriteError,))
777
778    def __raise_if_unsupported(self, opname, target, *target_types):
779        if not isinstance(target, target_types):
780            self.fail('Operation %s not supported for entity '
781                      'of type %s' % (opname, type(target)))
782
783    def __entityOperation_createChangeStream(self, target, *args, **kwargs):
784        if client_context.storage_engine == 'mmapv1':
785            self.skipTest("MMAPv1 does not support change streams")
786        self.__raise_if_unsupported(
787            'createChangeStream', target, MongoClient, Database, Collection)
788        stream = target.watch(*args, **kwargs)
789        self.addCleanup(stream.close)
790        return stream
791
792    def _clientOperation_createChangeStream(self, target, *args, **kwargs):
793        return self.__entityOperation_createChangeStream(
794            target, *args, **kwargs)
795
796    def _databaseOperation_createChangeStream(self, target, *args, **kwargs):
797        return self.__entityOperation_createChangeStream(
798            target, *args, **kwargs)
799
800    def _collectionOperation_createChangeStream(self, target, *args, **kwargs):
801        return self.__entityOperation_createChangeStream(
802            target, *args, **kwargs)
803
804    def _databaseOperation_runCommand(self, target, **kwargs):
805        self.__raise_if_unsupported('runCommand', target, Database)
806        # Ensure the first key is the command name.
807        ordered_command = SON([(kwargs.pop('command_name'), 1)])
808        ordered_command.update(kwargs['command'])
809        kwargs['command'] = ordered_command
810        return target.command(**kwargs)
811
812    def _databaseOperation_listCollections(self, target, *args, **kwargs):
813        if 'batch_size' in kwargs:
814            kwargs['cursor'] = {'batchSize': kwargs.pop('batch_size')}
815        cursor = target.list_collections(*args, **kwargs)
816        return list(cursor)
817
818    def __entityOperation_aggregate(self, target, *args, **kwargs):
819        self.__raise_if_unsupported('aggregate', target, Database, Collection)
820        return list(target.aggregate(*args, **kwargs))
821
822    def _databaseOperation_aggregate(self, target, *args, **kwargs):
823        return self.__entityOperation_aggregate(target, *args, **kwargs)
824
825    def _collectionOperation_aggregate(self, target, *args, **kwargs):
826        return self.__entityOperation_aggregate(target, *args, **kwargs)
827
828    def _collectionOperation_find(self, target, *args, **kwargs):
829        self.__raise_if_unsupported('find', target, Collection)
830        find_cursor = target.find(*args, **kwargs)
831        return list(find_cursor)
832
833    def _collectionOperation_createFindCursor(self, target, *args, **kwargs):
834        self.__raise_if_unsupported('find', target, Collection)
835        cursor = NonLazyCursor(target.find(*args, **kwargs))
836        self.addCleanup(cursor.close)
837        return cursor
838
839    def _collectionOperation_listIndexes(self, target, *args, **kwargs):
840        if 'batch_size' in kwargs:
841            self.skipTest('PyMongo does not support batch_size for '
842                          'list_indexes')
843        return target.list_indexes(*args, **kwargs)
844
845    def _sessionOperation_withTransaction(self, target, *args, **kwargs):
846        if client_context.storage_engine == 'mmapv1':
847            self.skipTest('MMAPv1 does not support document-level locking')
848        self.__raise_if_unsupported('withTransaction', target, ClientSession)
849        return target.with_transaction(*args, **kwargs)
850
851    def _sessionOperation_startTransaction(self, target, *args, **kwargs):
852        if client_context.storage_engine == 'mmapv1':
853            self.skipTest('MMAPv1 does not support document-level locking')
854        self.__raise_if_unsupported('startTransaction', target, ClientSession)
855        return target.start_transaction(*args, **kwargs)
856
857    def _changeStreamOperation_iterateUntilDocumentOrError(self, target,
858                                                           *args, **kwargs):
859        self.__raise_if_unsupported(
860            'iterateUntilDocumentOrError', target, ChangeStream)
861        return next(target)
862
863    def _cursor_iterateUntilDocumentOrError(self, target, *args, **kwargs):
864        self.__raise_if_unsupported(
865            'iterateUntilDocumentOrError', target, NonLazyCursor)
866        return next(target)
867
868    def _cursor_close(self, target, *args, **kwargs):
869        self.__raise_if_unsupported('close', target, NonLazyCursor)
870        return target.close()
871
872    def run_entity_operation(self, spec):
873        target = self.entity_map[spec['object']]
874        opname = spec['name']
875        opargs = spec.get('arguments')
876        expect_error = spec.get('expectError')
877        save_as_entity = spec.get('saveResultAsEntity')
878        expect_result = spec.get('expectResult')
879        ignore = spec.get('ignoreResultAndError')
880        if ignore and (expect_error or save_as_entity or expect_result):
881            raise ValueError(
882                'ignoreResultAndError is incompatible with saveResultAsEntity'
883                ', expectError, and expectResult')
884        if opargs:
885            arguments = parse_spec_options(copy.deepcopy(opargs))
886            prepare_spec_arguments(spec, arguments, camel_to_snake(opname),
887                                   self.entity_map, self.run_operations)
888        else:
889            arguments = tuple()
890
891        if isinstance(target, MongoClient):
892            method_name = '_clientOperation_%s' % (opname,)
893        elif isinstance(target, Database):
894            method_name = '_databaseOperation_%s' % (opname,)
895        elif isinstance(target, Collection):
896            method_name = '_collectionOperation_%s' % (opname,)
897        elif isinstance(target, ChangeStream):
898            method_name = '_changeStreamOperation_%s' % (opname,)
899        elif isinstance(target, NonLazyCursor):
900            method_name = '_cursor_%s' % (opname,)
901        elif isinstance(target, ClientSession):
902            method_name = '_sessionOperation_%s' % (opname,)
903        elif isinstance(target, GridFSBucket):
904            raise NotImplementedError
905        else:
906            method_name = 'doesNotExist'
907
908        try:
909            method = getattr(self, method_name)
910        except AttributeError:
911            try:
912                cmd = getattr(target, camel_to_snake(opname))
913            except AttributeError:
914                self.fail('Unsupported operation %s on entity %s' % (
915                    opname, target))
916        else:
917            cmd = functools.partial(method, target)
918
919        try:
920            result = cmd(**dict(arguments))
921        except Exception as exc:
922            if ignore:
923                return
924            if expect_error:
925                return self.process_error(exc, expect_error)
926            raise
927        else:
928            if expect_error:
929                self.fail('Excepted error %s but "%s" succeeded: %s' % (
930                    expect_error, opname, result))
931
932        if expect_result:
933            actual = coerce_result(opname, result)
934            self.match_evaluator.match_result(expect_result, actual)
935
936        if save_as_entity:
937            self.entity_map[save_as_entity] = result
938
939    def __set_fail_point(self, client, command_args):
940        if not client_context.test_commands_enabled:
941            self.skipTest('Test commands must be enabled')
942
943        cmd_on = SON([('configureFailPoint', 'failCommand')])
944        cmd_on.update(command_args)
945        client.admin.command(cmd_on)
946        self.addCleanup(
947            client.admin.command,
948            'configureFailPoint', cmd_on['configureFailPoint'], mode='off')
949
950    def _testOperation_failPoint(self, spec):
951        self.__set_fail_point(
952            client=self.entity_map[spec['client']],
953            command_args=spec['failPoint'])
954
955    def _testOperation_targetedFailPoint(self, spec):
956        session = self.entity_map[spec['session']]
957        if not session._pinned_address:
958            self.fail("Cannot use targetedFailPoint operation with unpinned "
959                      "session %s" % (spec['session'],))
960
961        client = single_client('%s:%s' % session._pinned_address)
962        self.__set_fail_point(
963            client=client, command_args=spec['failPoint'])
964        self.addCleanup(client.close)
965
966    def _testOperation_assertSessionTransactionState(self, spec):
967        session = self.entity_map[spec['session']]
968        expected_state = getattr(_TxnState, spec['state'].upper())
969        self.assertEqual(expected_state, session._transaction.state)
970
971    def _testOperation_assertSessionPinned(self, spec):
972        session = self.entity_map[spec['session']]
973        self.assertIsNotNone(session._transaction.pinned_address)
974
975    def _testOperation_assertSessionUnpinned(self, spec):
976        session = self.entity_map[spec['session']]
977        self.assertIsNone(session._pinned_address)
978        self.assertIsNone(session._transaction.pinned_address)
979
980    def __get_last_two_command_lsids(self, listener):
981        cmd_started_events = []
982        for event in reversed(listener.events):
983            if isinstance(event, CommandStartedEvent):
984                cmd_started_events.append(event)
985        if len(cmd_started_events) < 2:
986            self.fail('Needed 2 CommandStartedEvents to compare lsids, '
987                      'got %s' % (len(cmd_started_events)))
988        return tuple([e.command['lsid'] for e in cmd_started_events][:2])
989
990    def _testOperation_assertDifferentLsidOnLastTwoCommands(self, spec):
991        listener = self.entity_map.get_listener_for_client(spec['client'])
992        self.assertNotEqual(*self.__get_last_two_command_lsids(listener))
993
994    def _testOperation_assertSameLsidOnLastTwoCommands(self, spec):
995        listener = self.entity_map.get_listener_for_client(spec['client'])
996        self.assertEqual(*self.__get_last_two_command_lsids(listener))
997
998    def _testOperation_assertSessionDirty(self, spec):
999        session = self.entity_map[spec['session']]
1000        self.assertTrue(session._server_session.dirty)
1001
1002    def _testOperation_assertSessionNotDirty(self, spec):
1003        session = self.entity_map[spec['session']]
1004        return self.assertFalse(session._server_session.dirty)
1005
1006    def _testOperation_assertCollectionExists(self, spec):
1007        database_name = spec['databaseName']
1008        collection_name = spec['collectionName']
1009        collection_name_list = list(
1010            self.client.get_database(database_name).list_collection_names())
1011        self.assertIn(collection_name, collection_name_list)
1012
1013    def _testOperation_assertCollectionNotExists(self, spec):
1014        database_name = spec['databaseName']
1015        collection_name = spec['collectionName']
1016        collection_name_list = list(
1017            self.client.get_database(database_name).list_collection_names())
1018        self.assertNotIn(collection_name, collection_name_list)
1019
1020    def _testOperation_assertIndexExists(self, spec):
1021        collection = self.client[spec['databaseName']][spec['collectionName']]
1022        index_names = [idx['name'] for idx in collection.list_indexes()]
1023        self.assertIn(spec['indexName'], index_names)
1024
1025    def _testOperation_assertIndexNotExists(self, spec):
1026        collection = self.client[spec['databaseName']][spec['collectionName']]
1027        for index in collection.list_indexes():
1028            self.assertNotEqual(spec['indexName'], index['name'])
1029
1030    def _testOperation_assertNumberConnectionsCheckedOut(self, spec):
1031        client = self.entity_map[spec['client']]
1032        pool = get_pool(client)
1033        self.assertEqual(spec['connections'], pool.active_sockets)
1034
1035    def run_special_operation(self, spec):
1036        opname = spec['name']
1037        method_name = '_testOperation_%s' % (opname,)
1038        try:
1039            method = getattr(self, method_name)
1040        except AttributeError:
1041            self.fail('Unsupported special test operation %s' % (opname,))
1042        else:
1043            method(spec['arguments'])
1044
1045    def run_operations(self, spec):
1046        for op in spec:
1047            target = op['object']
1048            if target != 'testRunner':
1049                self.run_entity_operation(op)
1050            else:
1051                self.run_special_operation(op)
1052
1053    def check_events(self, spec):
1054        for event_spec in spec:
1055            client_name = event_spec['client']
1056            events = event_spec['events']
1057            # Valid types: 'command', 'cmap'
1058            event_type = event_spec.get('eventType', 'command')
1059            assert event_type in ('command', 'cmap')
1060
1061            listener = self.entity_map.get_listener_for_client(client_name)
1062            actual_events = listener.get_events(event_type)
1063            if len(events) == 0:
1064                self.assertEqual(actual_events, [])
1065                continue
1066
1067            if len(events) > len(actual_events):
1068                self.fail('Expected to see %s events, got %s' % (
1069                    len(events), len(actual_events)))
1070
1071            for idx, expected_event in enumerate(events):
1072                self.match_evaluator.match_event(
1073                    event_type, expected_event, actual_events[idx])
1074
1075    def verify_outcome(self, spec):
1076        for collection_data in spec:
1077            coll_name = collection_data['collectionName']
1078            db_name = collection_data['databaseName']
1079            expected_documents = collection_data['documents']
1080
1081            coll = self.client.get_database(db_name).get_collection(
1082                coll_name,
1083                read_preference=ReadPreference.PRIMARY,
1084                read_concern=ReadConcern(level='local'))
1085
1086            if expected_documents:
1087                sorted_expected_documents = sorted(
1088                    expected_documents, key=lambda doc: doc['_id'])
1089                actual_documents = list(
1090                    coll.find({}, sort=[('_id', ASCENDING)]))
1091                self.assertListEqual(sorted_expected_documents,
1092                                     actual_documents)
1093
1094    def run_scenario(self, spec):
1095        # maybe skip test manually
1096        self.maybe_skip_test(spec)
1097
1098        # process test-level runOnRequirements
1099        run_on_spec = spec.get('runOnRequirements', [])
1100        if not self.should_run_on(run_on_spec):
1101            raise unittest.SkipTest('runOnRequirements not satisfied')
1102
1103        # process skipReason
1104        skip_reason = spec.get('skipReason', None)
1105        if skip_reason is not None:
1106            raise unittest.SkipTest('%s' % (skip_reason,))
1107
1108        # process createEntities
1109        self.entity_map = EntityMapUtil(self)
1110        self.entity_map.create_entities_from_spec(
1111            self.TEST_SPEC.get('createEntities', []))
1112
1113        # process initialData
1114        self.insert_initial_data(self.TEST_SPEC.get('initialData', []))
1115
1116        # process operations
1117        self.run_operations(spec['operations'])
1118
1119        # process expectEvents
1120        self.check_events(spec.get('expectEvents', []))
1121
1122        # process outcome
1123        self.verify_outcome(spec.get('outcome', []))
1124
1125
1126class UnifiedSpecTestMeta(type):
1127    """Metaclass for generating test classes."""
1128    def __init__(cls, *args, **kwargs):
1129        super(UnifiedSpecTestMeta, cls).__init__(*args, **kwargs)
1130
1131        def create_test(spec):
1132            def test_case(self):
1133                self.run_scenario(spec)
1134            return test_case
1135
1136        for test_spec in cls.TEST_SPEC['tests']:
1137            description = test_spec['description']
1138            test_name = 'test_%s' % (description.strip('. ').
1139                                     replace(' ', '_').replace('.', '_'),)
1140            test_method = create_test(copy.deepcopy(test_spec))
1141            test_method.__name__ = str(test_name)
1142
1143            for fail_pattern in cls.EXPECTED_FAILURES:
1144                if re.search(fail_pattern, description):
1145                    test_method = unittest.expectedFailure(test_method)
1146                    break
1147
1148            setattr(cls, test_name, test_method)
1149
1150
1151_ALL_MIXIN_CLASSES = [
1152    UnifiedSpecTestMixinV1,
1153    # add mixin classes for new schema major versions here
1154]
1155
1156
1157_SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS = {
1158    KLASS.SCHEMA_VERSION[0]: KLASS for KLASS in _ALL_MIXIN_CLASSES}
1159
1160
1161def generate_test_classes(test_path, module=__name__, class_name_prefix='',
1162                          expected_failures=[],
1163                          bypass_test_generation_errors=False):
1164    """Method for generating test classes. Returns a dictionary where keys are
1165    the names of test classes and values are the test class objects."""
1166    test_klasses = {}
1167
1168    def test_base_class_factory(test_spec):
1169        """Utility that creates the base class to use for test generation.
1170        This is needed to ensure that cls.TEST_SPEC is appropriately set when
1171        the metaclass __init__ is invoked."""
1172        class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)):
1173            TEST_SPEC = test_spec
1174            EXPECTED_FAILURES = expected_failures
1175        return SpecTestBase
1176
1177    for dirpath, _, filenames in os.walk(test_path):
1178        dirname = os.path.split(dirpath)[-1]
1179
1180        for filename in filenames:
1181            fpath = os.path.join(dirpath, filename)
1182            with open(fpath) as scenario_stream:
1183                # Use tz_aware=False to match how CodecOptions decodes
1184                # dates.
1185                opts = json_util.JSONOptions(tz_aware=False)
1186                scenario_def = json_util.loads(
1187                    scenario_stream.read(), json_options=opts)
1188
1189            test_type = os.path.splitext(filename)[0]
1190            snake_class_name = 'Test%s_%s_%s' % (
1191                class_name_prefix, dirname.replace('-', '_'),
1192                test_type.replace('-', '_').replace('.', '_'))
1193            class_name = snake_to_camel(snake_class_name)
1194
1195            try:
1196                schema_version = Version.from_string(
1197                    scenario_def['schemaVersion'])
1198                mixin_class = _SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS.get(
1199                    schema_version[0])
1200                if mixin_class is None:
1201                    raise ValueError(
1202                        "test file '%s' has unsupported schemaVersion '%s'" % (
1203                            fpath, schema_version))
1204                test_klasses[class_name] = type(
1205                    class_name,
1206                    (mixin_class, test_base_class_factory(scenario_def),),
1207                    {'__module__': module})
1208            except Exception:
1209                if bypass_test_generation_errors:
1210                    continue
1211                raise
1212
1213    return test_klasses
1214