1# Copyright 2019-present MongoDB, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Test retryable reads spec."""
16
17import os
18import sys
19
20sys.path[0:0] = [""]
21
22from pymongo.mongo_client import MongoClient
23from pymongo.write_concern import WriteConcern
24
25from test import unittest, client_context, PyMongoTestCase
26from test.utils import TestCreator
27from test.utils_spec_runner import SpecRunner
28
29
# Location of JSON test specifications: the 'retryable_reads' directory
# that sits beside this module (after resolving symlinks in its path).
_TEST_PATH = os.path.join(
    os.path.split(os.path.realpath(__file__))[0],
    'retryable_reads')
33
34
class TestClientOptions(PyMongoTestCase):
    """Check how the retryReads option is exposed as MongoClient.retry_reads.

    Every client is built with connect=False, so no server is required.
    assertIs is used (rather than assertEqual) so the tests also pin that
    retry_reads is a real bool, not merely a value that compares equal to one.
    """

    def test_default(self):
        # No retryReads option given: retry_reads is expected to be True.
        client = MongoClient(connect=False)
        self.assertIs(client.retry_reads, True)

    def test_kwargs(self):
        # retryReads supplied as a keyword argument.
        client = MongoClient(retryReads=True, connect=False)
        self.assertIs(client.retry_reads, True)
        client = MongoClient(retryReads=False, connect=False)
        self.assertIs(client.retry_reads, False)

    def test_uri(self):
        # retryReads supplied as a connection string option.
        client = MongoClient('mongodb://h/?retryReads=true', connect=False)
        self.assertIs(client.retry_reads, True)
        client = MongoClient('mongodb://h/?retryReads=false', connect=False)
        self.assertIs(client.retry_reads, False)
51
52
class TestSpec(SpecRunner):
    """Run the retryable reads JSON spec tests, including GridFS scenarios."""

    RUN_ON_LOAD_BALANCER = True

    @classmethod
    @client_context.require_failCommand_fail_point
    # TODO: remove this once PYTHON-1948 is done.
    @client_context.require_no_mmap
    def setUpClass(cls):
        super(TestSpec, cls).setUpClass()

    def maybe_skip_scenario(self, test):
        """Skip spec tests PyMongo cannot run.

        Tests whose description mentions a helper PyMongo does not provide
        are skipped, as are change stream tests on the MMAPv1 storage engine.
        """
        super(TestSpec, self).maybe_skip_scenario(test)
        description = test['description'].lower()
        skip_names = [
            'listCollectionObjects', 'listIndexNames', 'listDatabaseObjects']
        for name in skip_names:
            if name.lower() in description:
                self.skipTest('PyMongo does not support %s' % (name,))

        # Skip changeStream related tests on MMAPv1.
        # NOTE(review): presumably redundant while setUpClass carries
        # require_no_mmap; kept for when that TODO decorator is removed.
        test_name = self.id().rsplit('.', 1)[-1]
        if ('changestream' in test_name.lower() and
                client_context.storage_engine == 'mmapv1'):
            self.skipTest("MMAPv1 does not support change streams.")

    def get_scenario_coll_name(self, scenario_def):
        """Override a test's collection name to support GridFS tests."""
        if 'bucket_name' in scenario_def:
            return scenario_def['bucket_name']
        return super(TestSpec, self).get_scenario_coll_name(scenario_def)

    def setup_scenario(self, scenario_def):
        """Override a test's setup to support GridFS tests.

        For GridFS scenarios (those with a 'bucket_name'), the scenario
        database is dropped and then seeded from the scenario's 'data'
        mapping of collection name (e.g. 'fs.files', 'fs.chunks') to a list
        of documents, using a w='majority' write concern. A missing or empty
        'data' mapping, or an empty document list, simply inserts nothing.
        All other scenarios use the default SpecRunner setup.
        """
        if 'bucket_name' not in scenario_def:
            super(TestSpec, self).setup_scenario(scenario_def)
            return

        db_name = self.get_scenario_db_name(scenario_def)
        db = client_context.client.get_database(
            db_name, write_concern=WriteConcern(w='majority'))
        # Start every GridFS scenario from an empty database.
        client_context.client.drop_database(db_name)
        data = scenario_def.get('data') or {}
        # Load data. insert_many() rejects an empty document list, so
        # collections with no documents are skipped.
        for coll_name, documents in data.items():
            if documents:
                db[coll_name].insert_many(documents)
98
99
def create_test(scenario_def, test, name):
    """Return a test method that runs ``test`` from ``scenario_def``.

    The returned method is wrapped with
    ``client_context.require_test_commands``. ``name`` is not used here;
    presumably it is part of the callback signature TestCreator invokes.
    """
    def run_scenario(self):
        self.run_scenario(scenario_def, test)

    # Equivalent to decorating run_scenario with @require_test_commands.
    return client_context.require_test_commands(run_scenario)
106
107
# Generate the spec tests at import time: TestCreator is given the
# create_test factory, the TestSpec class, and the JSON spec directory.
# NOTE(review): presumably this attaches one generated test method to
# TestSpec per spec test found under _TEST_PATH — confirm in test.utils.
test_creator = TestCreator(create_test, TestSpec, _TEST_PATH)
test_creator.create_tests()

# Allow running this module directly as a script.
if __name__ == "__main__":
    unittest.main()
113