1"""
2Customize the behavior of a fixture by allowing special code to be
3executed before or after each test, and before or after each suite.
4"""
5
6from __future__ import absolute_import
7
8import os
9import sys
10
11import bson
12import pymongo
13
14from . import fixtures
15from . import testcases
16from .. import errors
17from .. import logging
18from .. import utils
19
20
21def make_custom_behavior(class_name, *args, **kwargs):
22    """
23    Factory function for creating CustomBehavior instances.
24    """
25
26    if class_name not in _CUSTOM_BEHAVIORS:
27        raise ValueError("Unknown custom behavior class '%s'" % (class_name))
28    return _CUSTOM_BEHAVIORS[class_name](*args, **kwargs)
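
# Example usage (hypothetical; 'logger', 'fixture', and 'test_report' would be
# supplied by the test runner). The factory looks the hook up by class name in
# _CUSTOM_BEHAVIORS and forwards any extra arguments to its constructor:
#
#     hook = make_custom_behavior("CleanEveryN", logger, fixture, n=5)
#     hook.before_suite(test_report)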


class CustomBehavior(object):
    """
    The common interface all CustomBehaviors will inherit from.
    """

    @staticmethod
    def start_dynamic_test(test_case, test_report):
        """
        If a CustomBehavior wants to add a test case that will show up
        in the test report, it should use this method to add it to the
        report, since we will need to count it as a dynamic test to get
        the stats in the summary information right.
        """
        test_report.startTest(test_case, dynamic=True)

    def __init__(self, logger, fixture):
        """
        Initializes the CustomBehavior with the specified fixture.
        """

        if not isinstance(logger, logging.Logger):
            raise TypeError("logger must be a Logger instance")

        self.logger = logger
        self.fixture = fixture

    def before_suite(self, test_report):
        """
        The test runner calls this exactly once before it starts
        running the suite.
        """
        pass

    def after_suite(self, test_report):
        """
        The test runner calls this exactly once after all tests have
        finished executing. Be sure to reset the behavior back to its
        original state so that it can be run again.
        """
        pass

    def before_test(self, test_report):
        """
        Each test will call this before it executes.

        Raises a TestFailure if the test should be marked as a failure,
        or a ServerFailure if the fixture exits uncleanly or
        unexpectedly.
        """
        pass

    def after_test(self, test_report):
        """
        Each test will call this after it executes.

        Raises a TestFailure if the test should be marked as a failure,
        or a ServerFailure if the fixture exits uncleanly or
        unexpectedly.
        """
        pass

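# A minimal sketch of a custom hook (illustrative only; a real hook must also
# be registered in _CUSTOM_BEHAVIORS at the bottom of this module so that
# make_custom_behavior() can find it):
#
#     class LogTestBoundaries(CustomBehavior):
#         def before_test(self, test_report):
#             self.logger.info("A test is about to run.")
#
#         def after_test(self, test_report):
#             self.logger.info("A test just finished.")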

class CleanEveryN(CustomBehavior):
    """
    Restarts the fixture after it has run 'n' tests.
    On mongod-related fixtures, this will clear the dbpath.
    """

    DEFAULT_N = 20

    def __init__(self, logger, fixture, n=DEFAULT_N):
        CustomBehavior.__init__(self, logger, fixture)

        # Try to isolate what test triggers the leak by restarting the fixture each time.
        if "detect_leaks=1" in os.getenv("ASAN_OPTIONS", ""):
            self.logger.info("ASAN_OPTIONS environment variable set to detect leaks, so restarting"
                             " the fixture after each test instead of after every %d.", n)
            n = 1

        self.n = n
        self.tests_run = 0

    def after_test(self, test_report):
        self.tests_run += 1
        if self.tests_run >= self.n:
            self.logger.info("%d tests have been run against the fixture, stopping it...",
                             self.tests_run)
            self.tests_run = 0

            teardown_success = self.fixture.teardown()
            self.logger.info("Starting the fixture back up again...")
            self.fixture.setup()
            self.fixture.await_ready()

            # Raise this after calling setup in case --continueOnFailure was specified.
            if not teardown_success:
                raise errors.TestFailure("%s did not exit cleanly" % (self.fixture))

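# Example (hypothetical): restart the fixture after every 5 tests rather than
# the default 20:
#
#     hook = CleanEveryN(logger, fixture, n=5)
#
# When ASAN_OPTIONS contains detect_leaks=1, the hook overrides 'n' to 1 so
# that a reported leak can be attributed to a single test.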

class CheckReplDBHash(CustomBehavior):
    """
    Waits for replication after each test, then checks that the dbhashes
    of all databases other than "local" match on the primary and all of
    the secondaries. If any dbhashes do not match, logs information
    about what was different (e.g. different numbers of collections,
    missing documents in a collection, mismatching documents, etc.).

    Compatible only with ReplFixture subclasses.
    """

    def __init__(self, logger, fixture):
        if not isinstance(fixture, fixtures.ReplFixture):
            raise TypeError("%s does not support replication" % (fixture.__class__.__name__))

        CustomBehavior.__init__(self, logger, fixture)

        self.test_case = testcases.TestCase(self.logger, "Hook", "#dbhash#")

        self.started = False

    def after_test(self, test_report):
        """
        After each test, check that the dbhash of the test database is
        the same on all nodes in the replica set or master/slave
        fixture.
        """

        try:
            if not self.started:
                CustomBehavior.start_dynamic_test(self.test_case, test_report)
                self.started = True

            # Wait until all operations have replicated.
            self.fixture.await_repl()

            success = True
            sb = []  # String builder.

            primary = self.fixture.get_primary()
            primary_conn = utils.new_mongo_client(port=primary.port)

            for secondary in self.fixture.get_secondaries():
                read_preference = pymongo.ReadPreference.SECONDARY
                secondary_conn = utils.new_mongo_client(port=secondary.port,
                                                        read_preference=read_preference)
                # Skip arbiters.
                if secondary_conn.admin.command("isMaster").get("arbiterOnly", False):
                    continue

                all_matched = CheckReplDBHash._check_all_db_hashes(primary_conn,
                                                                   secondary_conn,
                                                                   sb)
                if not all_matched:
                    sb.insert(0,
                              "One or more databases were different between the primary on port %d"
                              " and the secondary on port %d:"
                              % (primary.port, secondary.port))

                success = all_matched and success

            if not success:
                # Adding failures to a TestReport requires traceback information, so we raise
                # a 'self.test_case.failureException' that we will catch ourselves.
                self.test_case.logger.info("\n    ".join(sb))
                raise self.test_case.failureException("The dbhashes did not match")
        except self.test_case.failureException as err:
            self.test_case.logger.exception("The dbhashes did not match.")
            self.test_case.return_code = 1
            test_report.addFailure(self.test_case, sys.exc_info())
            test_report.stopTest(self.test_case)
            raise errors.ServerFailure(err.args[0])
        except pymongo.errors.WTimeoutError:
            self.test_case.logger.exception("Awaiting replication timed out.")
            self.test_case.return_code = 2
            test_report.addError(self.test_case, sys.exc_info())
            test_report.stopTest(self.test_case)
            raise errors.StopExecution("Awaiting replication timed out")

    def after_suite(self, test_report):
        """
        If we get to this point, the #dbhash# test must have been
        successful, so add it to the test report.
        """

        if self.started:
            self.test_case.logger.info("The dbhashes matched for all tests.")
            self.test_case.return_code = 0
            test_report.addSuccess(self.test_case)
            # TestReport.stopTest() has already been called if there was a failure.
            test_report.stopTest(self.test_case)

        self.started = False

    @staticmethod
    def _check_all_db_hashes(primary_conn, secondary_conn, sb):
        """
        Returns true if for each non-local database, the dbhash command
        returns the same MD5 hash on the primary as it does on the
        secondary. Returns false otherwise.

        Logs a message describing the differences if any database's
        dbhash did not match.
        """

        # Overview of how we'll check that everything replicated correctly between these two nodes:
        #
        # - Check whether they have the same databases.
        #     - If not, log which databases are missing where, and dump the contents of any that are
        #       missing.
        #
        # - Check whether each database besides "local" gives the same md5 field as the result of
        #   running the dbhash command.
        #     - If not, check whether they have the same collections.
        #         - If not, log which collections are missing where, and dump the contents of any
        #           that are missing.
        #     - If so, check that the hash of each non-capped collection matches.
        #         - If any do not match, log the diff of the collection between the two nodes.

        success = True

        if not CheckReplDBHash._check_dbs_present(primary_conn, secondary_conn, sb):
            return False

        for db_name in primary_conn.database_names():
            if db_name == "local":
                continue  # We don't expect this to match across different nodes.

            matched = CheckReplDBHash._check_db_hash(primary_conn, secondary_conn, db_name, sb)
            success = matched and success

        return success

    @staticmethod
    def _check_dbs_present(primary_conn, secondary_conn, sb):
        """
        Returns true if the set of databases on the primary matches the
        set of databases on the secondary, ignoring any database that
        contains only system collections. Returns false otherwise.
        """

        success = True
        primary_dbs = primary_conn.database_names()

        # Can't run database_names() on secondary, so instead use the listDatabases command.
        # TODO: Use database_names() once PYTHON-921 is resolved.
        list_db_output = secondary_conn.admin.command("listDatabases")
        secondary_dbs = [db["name"] for db in list_db_output["databases"]]

        # A difference in databases is not considered an error when the database only
        # contains system collections. Such a difference is only logged when other
        # errors are encountered, i.e. when success is False.
        missing_on_primary, missing_on_secondary = CheckReplDBHash._check_difference(
            set(primary_dbs), set(secondary_dbs), "database")

        for missing_db in missing_on_secondary:
            db = primary_conn[missing_db]
            coll_names = db.collection_names()
            non_system_colls = [name for name in coll_names if not name.startswith("system.")]

            # It is only an error if there are any non-system collections in the database;
            # otherwise it is not well defined whether they should exist or not.
            if non_system_colls:
                sb.append("Database %s present on primary but not on secondary." % (missing_db))
                CheckReplDBHash._dump_all_collections(db, non_system_colls, sb)
                success = False

        for missing_db in missing_on_primary:
            db = secondary_conn[missing_db]

            # Can't run collection_names() on secondary, so instead use the listCollections command.
            # TODO: Always use collection_names() once PYTHON-921 is resolved. Then much of the
            # logic that is duplicated here can be consolidated.
            list_coll_output = db.command("listCollections")["cursor"]["firstBatch"]
            coll_names = [coll["name"] for coll in list_coll_output]
            non_system_colls = [name for name in coll_names if not name.startswith("system.")]

            # It is only an error if there are any non-system collections in the database;
            # otherwise it is not well defined whether they should exist or not.
            if non_system_colls:
                sb.append("Database %s present on secondary but not on primary." % (missing_db))
                CheckReplDBHash._dump_all_collections(db, non_system_colls, sb)
                success = False

        return success

    @staticmethod
    def _check_db_hash(primary_conn, secondary_conn, db_name, sb):
        """
        Returns true if the dbhash for 'db_name' matches on the primary
        and the secondary, and false otherwise.

        Appends a message to 'sb' describing the differences if the
        dbhashes do not match.
        """

        primary_hash = primary_conn[db_name].command("dbhash")
        secondary_hash = secondary_conn[db_name].command("dbhash")

        if primary_hash["md5"] == secondary_hash["md5"]:
            return True

        success = CheckReplDBHash._check_dbs_eq(
            primary_conn, secondary_conn, primary_hash, secondary_hash, db_name, sb)

        if not success:
            sb.append("Database %s has a different hash on the primary and the secondary"
                      " ([ %s ] != [ %s ]):"
                      % (db_name, primary_hash["md5"], secondary_hash["md5"]))

        return success

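    # For reference, the methods above and below rely on the "dbhash" command
    # response having (at least) this shape, per its usage in this class:
    #
    #     {"md5": "<md5 over the whole database>",
    #      "collections": {"<collName>": "<md5 over that collection>", ...}}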
    @staticmethod
    def _check_dbs_eq(primary_conn, secondary_conn, primary_hash, secondary_hash, db_name, sb):
        """
        Returns true if all non-capped collections had the same hash in
        the dbhash response, and false otherwise.

        Appends information to 'sb' about the differences between the
        'db_name' database on the primary and the 'db_name' database on
        the secondary, if any.
        """

        success = True

        primary_db = primary_conn[db_name]
        secondary_db = secondary_conn[db_name]

        primary_coll_hashes = primary_hash["collections"]
        secondary_coll_hashes = secondary_hash["collections"]

        primary_coll_names = set(primary_coll_hashes.keys())
        secondary_coll_names = set(secondary_coll_hashes.keys())

        missing_on_primary, missing_on_secondary = CheckReplDBHash._check_difference(
            primary_coll_names, secondary_coll_names, "collection", sb=sb)

        if missing_on_primary or missing_on_secondary:
            # 'sb' already describes which collections are missing where.
            for coll_name in missing_on_primary:
                CheckReplDBHash._dump_all_documents(secondary_db, coll_name, sb)
            for coll_name in missing_on_secondary:
                CheckReplDBHash._dump_all_documents(primary_db, coll_name, sb)
            return False

        for coll_name in primary_coll_names & secondary_coll_names:
            primary_coll_hash = primary_coll_hashes[coll_name]
            secondary_coll_hash = secondary_coll_hashes[coll_name]

            if primary_coll_hash == secondary_coll_hash:
                continue

            # Ignore capped collections because they are not expected to match on all nodes.
            if primary_db.command({"collStats": coll_name})["capped"]:
                # Still fail if the collection is not capped on the secondary.
                if not secondary_db.command({"collStats": coll_name})["capped"]:
                    success = False
                    sb.append("%s.%s collection is capped on primary but not on secondary."
                              % (primary_db.name, coll_name))
                else:
                    sb.append("%s.%s collection is capped, ignoring."
                              % (primary_db.name, coll_name))
                continue
            # Still fail if the collection is capped on the secondary, but not on the primary.
            elif secondary_db.command({"collStats": coll_name})["capped"]:
                success = False
                sb.append("%s.%s collection is capped on secondary but not on primary."
                          % (primary_db.name, coll_name))
                continue

            success = False
            sb.append("Collection %s.%s has a different hash on the primary and the secondary"
                      " ([ %s ] != [ %s ]):"
                      % (db_name, coll_name, primary_coll_hash, secondary_coll_hash))
            CheckReplDBHash._check_colls_eq(primary_db, secondary_db, coll_name, sb)

        if success:
            sb.append("All collections that were expected to match did.")
        return success

    @staticmethod
    def _check_colls_eq(primary_db, secondary_db, coll_name, sb):
        """
        Appends information to 'sb' about the differences between the
        'coll_name' collection on the primary and the 'coll_name'
        collection on the secondary, if any.
        """

        codec_options = bson.CodecOptions(document_class=TypeSensitiveSON)

        primary_coll = primary_db.get_collection(coll_name, codec_options=codec_options)
        secondary_coll = secondary_db.get_collection(coll_name, codec_options=codec_options)

        primary_docs = CheckReplDBHash._extract_documents(primary_coll)
        secondary_docs = CheckReplDBHash._extract_documents(secondary_coll)

        CheckReplDBHash._get_collection_diff(primary_docs, secondary_docs, sb)

    @staticmethod
    def _extract_documents(collection):
        """
        Returns a list of all documents in the collection, sorted by
        their _id.
        """

        return [doc for doc in collection.find().sort("_id", pymongo.ASCENDING)]

    @staticmethod
    def _get_collection_diff(primary_docs, secondary_docs, sb):
        """
        Appends information to 'sb' about whether the documents in
        'primary_docs' exactly match the documents in 'secondary_docs',
        and about any documents that were mismatched or missing.
        """

        matched = True

        # These need to be lists instead of sets because documents aren't hashable.
        missing_on_primary = []
        missing_on_secondary = []

        p_idx = 0  # Keep track of our position in 'primary_docs'.
        s_idx = 0  # Keep track of our position in 'secondary_docs'.

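        # A short worked example of the walk below (illustrative): with primary
        # _ids [1, 2, 4] and secondary _ids [1, 3, 4], the loop matches _id 1 on
        # both sides, sees 2 < 3 and records the primary's doc 2 as missing on
        # the secondary, sees 4 > 3 and records the secondary's doc 3 as missing
        # on the primary, and finally matches _id 4.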
        while p_idx < len(primary_docs) and s_idx < len(secondary_docs):
            primary_doc = primary_docs[p_idx]
            secondary_doc = secondary_docs[s_idx]

            if primary_doc == secondary_doc:
                p_idx += 1
                s_idx += 1
                continue

            # We have mismatching documents.
            matched = False

            if primary_doc["_id"] == secondary_doc["_id"]:
                sb.append("Mismatching document:")
                sb.append("    primary:   %s" % (primary_doc))
                sb.append("    secondary: %s" % (secondary_doc))
                p_idx += 1
                s_idx += 1

            # One node was missing a document. Since the documents are sorted by _id, the doc with
            # the smaller _id was the one that was skipped.
            elif primary_doc["_id"] < secondary_doc["_id"]:
                missing_on_secondary.append(primary_doc)

                # Only move past the doc that we know was skipped.
                p_idx += 1

            else:  # primary_doc["_id"] > secondary_doc["_id"]
                missing_on_primary.append(secondary_doc)

                # Only move past the doc that we know was skipped.
                s_idx += 1

        # Check if there are any unmatched documents left.
        while p_idx < len(primary_docs):
            matched = False
            missing_on_secondary.append(primary_docs[p_idx])
            p_idx += 1
        while s_idx < len(secondary_docs):
            matched = False
            missing_on_primary.append(secondary_docs[s_idx])
            s_idx += 1

        if not matched:
            CheckReplDBHash._append_differences(
                missing_on_primary, missing_on_secondary, "document", sb)
        else:
            sb.append("All documents matched.")

    @staticmethod
    def _check_difference(primary_set, secondary_set, item_type_name, sb=None):
        """
        Returns a pair of sets (missing_on_primary, missing_on_secondary)
        describing the items that are present on only one of the two
        nodes. The sets contain information about the primary and
        secondary, respectively, e.g. the database names that exist on
        each node.

        Appends information about anything that differed to 'sb', if
        provided.
        """

        missing_on_primary = secondary_set - primary_set
        missing_on_secondary = primary_set - secondary_set

        if sb is not None:
            CheckReplDBHash._append_differences(
                missing_on_primary, missing_on_secondary, item_type_name, sb)

        return (missing_on_primary, missing_on_secondary)

    @staticmethod
    def _append_differences(missing_on_primary, missing_on_secondary, item_type_name, sb):
        """
        Given two iterables representing items that were missing on the
        primary or the secondary respectively, appends information
        about which items were missing to 'sb', if any.
        """

        if missing_on_primary:
            sb.append("The following %ss were present on the secondary, but not on the"
                      " primary:" % (item_type_name))
            for item in missing_on_primary:
                sb.append(str(item))

        if missing_on_secondary:
            sb.append("The following %ss were present on the primary, but not on the"
                      " secondary:" % (item_type_name))
            for item in missing_on_secondary:
                sb.append(str(item))

    @staticmethod
    def _dump_all_collections(database, coll_names, sb):
        """
        Appends the contents of each of the collections in 'coll_names'
        to 'sb'.
        """

        if coll_names:
            sb.append("Database %s contains the following collections: %s"
                      % (database.name, coll_names))
            for coll_name in coll_names:
                CheckReplDBHash._dump_all_documents(database, coll_name, sb)
        else:
            sb.append("No collections in database %s." % (database.name))

    @staticmethod
    def _dump_all_documents(database, coll_name, sb):
        """
        Appends the contents of 'coll_name' to 'sb'.
        """

        docs = CheckReplDBHash._extract_documents(database[coll_name])
        if docs:
            sb.append("Documents in %s.%s:" % (database.name, coll_name))
            for doc in docs:
                sb.append("    %s" % (doc))
        else:
            sb.append("No documents in %s.%s." % (database.name, coll_name))


class TypeSensitiveSON(bson.SON):
    """
    Extends bson.SON to perform additional type-checking of document values
    to differentiate BSON types.
    """

    def items_with_types(self):
        """
        Returns a list of triples. Each triple consists of a field name, a
        field value, and a field type for each field in the document.
        """

        return [(key, self[key], type(self[key])) for key in self]

    def __eq__(self, other):
        """
        Comparison to another TypeSensitiveSON is order-sensitive and
        type-sensitive, while comparison to a regular dictionary ignores order
        and type mismatches.
        """

        if isinstance(other, TypeSensitiveSON):
            return (len(self) == len(other) and
                    self.items_with_types() == other.items_with_types())

        raise TypeError("TypeSensitiveSON objects cannot be compared to other types")
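
# Illustrative note (not used by the hooks directly): decoding documents with
# document_class=TypeSensitiveSON, as _check_colls_eq() does above, makes
# equality checks distinguish BSON types that a plain dict comparison would
# conflate. For example, with two hypothetical documents:
#
#     doc_a = TypeSensitiveSON([("count", 1)])    # int
#     doc_b = TypeSensitiveSON([("count", 1.0)])  # float
#
# doc_a == doc_b evaluates to False because the value types differ, whereas
# dict(doc_a) == dict(doc_b) would evaluate to True.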

class ValidateCollections(CustomBehavior):
    """
    Runs full validation (db.collection.validate(true)) on all collections
    in all databases on every standalone or primary mongod. If validation
    fails (i.e. validate.valid is false), the validate response is logged.

    Compatible with all fixture types.
    """

    DEFAULT_FULL = True
    DEFAULT_SCANDATA = True

    def __init__(self, logger, fixture, full=DEFAULT_FULL, scandata=DEFAULT_SCANDATA):
        CustomBehavior.__init__(self, logger, fixture)

        if not isinstance(full, bool):
            raise TypeError("Fixture option full is not specified as type bool")

        if not isinstance(scandata, bool):
            raise TypeError("Fixture option scandata is not specified as type bool")

        self.test_case = testcases.TestCase(self.logger, "Hook", "#validate#")
        self.started = False
        self.full = full
        self.scandata = scandata

    def after_test(self, test_report):
        """
        After each test, run a full validation on all collections.
        """

        try:
            if not self.started:
                CustomBehavior.start_dynamic_test(self.test_case, test_report)
                self.started = True

            sb = []  # String builder.

            # The self.fixture.port can be used for client connection to a
            # standalone mongod, a replica-set primary, or mongos.
            # TODO: Run collection validation on all nodes in a replica-set.
            port = self.fixture.port
            conn = utils.new_mongo_client(port=port)

            success = ValidateCollections._check_all_collections(
                conn, sb, self.full, self.scandata)

            if not success:
                # Adding failures to a TestReport requires traceback information, so we raise
                # a 'self.test_case.failureException' that we will catch ourselves.
                self.test_case.logger.info("\n    ".join(sb))
                raise self.test_case.failureException("Collection validation failed")
        except self.test_case.failureException as err:
            self.test_case.logger.exception("Collection validation failed")
            self.test_case.return_code = 1
            test_report.addFailure(self.test_case, sys.exc_info())
            test_report.stopTest(self.test_case)
            raise errors.ServerFailure(err.args[0])

    def after_suite(self, test_report):
        """
        If we get to this point, the #validate# test must have been
        successful, so add it to the test report.
        """

        if self.started:
            self.test_case.logger.info("Collection validation passed for all tests.")
            self.test_case.return_code = 0
            test_report.addSuccess(self.test_case)
            # TestReport.stopTest() has already been called if there was a failure.
            test_report.stopTest(self.test_case)

        self.started = False

    @staticmethod
    def _check_all_collections(conn, sb, full, scandata):
        """
        Returns true if validate_collection succeeds for every collection
        in every database. Returns false otherwise.

        Logs a message if any database's collection fails validate_collection.
        """

        success = True

        for db_name in conn.database_names():
            for coll_name in conn[db_name].collection_names():
                try:
                    conn[db_name].validate_collection(coll_name, full=full, scandata=scandata)
                except pymongo.errors.CollectionInvalid as err:
                    sb.append("Database %s, collection %s failed to validate:\n%s"
                              % (db_name, coll_name, err.args[0]))
                    success = False
        return success


_CUSTOM_BEHAVIORS = {
    "CleanEveryN": CleanEveryN,
    "CheckReplDBHash": CheckReplDBHash,
    "ValidateCollections": ValidateCollections,
}