1# Copyright 2019-present MongoDB, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Run the SRV support tests."""
16
17import sys
18
19from time import sleep
20
21sys.path[0:0] = [""]
22
23import pymongo
24
25from pymongo import common
26from pymongo.errors import ConfigurationError
27from pymongo.srv_resolver import _HAVE_DNSPYTHON
28from pymongo.mongo_client import MongoClient
29from test import client_knobs, unittest
30from test.utils import wait_until, FunctionCallRecorder
31
32
# Base time unit, in seconds, for the shortened timeouts used by these tests
# (heartbeat frequency, SRV rescan interval, patched DNS TTL). Test waits and
# sleeps are expressed as multiples of it.
WAIT_TIME = 0.1
34
35
class SrvPollingKnobs(object):
    """Temporarily patch SRV polling settings and DNS resolver output.

    While enabled, ``common.MIN_SRV_RESCAN_INTERVAL`` is optionally
    overridden and ``_SrvResolver.get_hosts_and_min_ttl`` is replaced by a
    wrapper around the previously installed implementation, whose node list
    and/or TTL can be substituted. Instances may be nested: an inner instance
    wraps whatever the outer one installed and restores it on exit.

    :Parameters:
      - `ttl_time`: if not None, TTL returned instead of the resolved one.
      - `min_srv_rescan_interval`: if not None, temporary value for
        ``common.MIN_SRV_RESCAN_INTERVAL``.
      - `dns_resolver_nodelist_response`: if not None, a zero-argument
        callable whose return value replaces the resolved node list (it may
        also raise to simulate a resolver failure).
      - `count_resolver_calls`: if True, the installed patch exposes a
        ``call_count`` attribute counting how often it was invoked.
    """

    def __init__(self, ttl_time=None, min_srv_rescan_interval=None,
                 dns_resolver_nodelist_response=None,
                 count_resolver_calls=False):
        self.ttl_time = ttl_time
        self.min_srv_rescan_interval = min_srv_rescan_interval
        self.dns_resolver_nodelist_response = dns_resolver_nodelist_response
        self.count_resolver_calls = count_resolver_calls

        # Saved by enable() and restored by disable().
        self.old_min_srv_rescan_interval = None
        self.old_dns_resolver_response = None

    def enable(self):
        """Save the current settings and install the patches."""
        self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
        self.old_dns_resolver_response = \
            pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl

        if self.min_srv_rescan_interval is not None:
            common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval

        def mock_get_hosts_and_min_ttl(resolver, *args):
            # Run the previously installed implementation first so that the
            # real (or outer-patched) values serve as defaults.
            nodes, ttl = self.old_dns_resolver_response(resolver)
            if self.dns_resolver_nodelist_response is not None:
                nodes = self.dns_resolver_nodelist_response()
            if self.ttl_time is not None:
                ttl = self.ttl_time
            return nodes, ttl

        if self.count_resolver_calls:
            # Install a plain function rather than a callable object: a
            # callable instance stored on a class is not bound like a method
            # (unless it implements __get__), so ``resolver`` would never be
            # passed through and every patched call would fail with a
            # TypeError instead of exercising the mock above.
            def patch_func(resolver, *args):
                # Count before delegating so calls that raise still count.
                patch_func.call_count += 1
                return mock_get_hosts_and_min_ttl(resolver, *args)
            patch_func.call_count = 0
        else:
            patch_func = mock_get_hosts_and_min_ttl

        pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func

    def __enter__(self):
        self.enable()
        return self

    def disable(self):
        """Restore the settings saved by the matching enable() call."""
        common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval
        pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = \
            self.old_dns_resolver_response

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.disable()
81
82
class TestSrvPolling(unittest.TestCase):

    # Node list the test SRV record is expected to resolve to initially.
    BASE_SRV_RESPONSE = [
        ("localhost.test.build.10gen.cc", 27017),
        ("localhost.test.build.10gen.cc", 27018)]

    CONNECTION_STRING = "mongodb+srv://test1.test.build.10gen.cc"

    def setUp(self):
        if not _HAVE_DNSPYTHON:
            raise unittest.SkipTest("SRV polling tests require the dnspython "
                                    "module")
        # Patch timeouts to ensure short rescan SRV interval.
        self.client_knobs = client_knobs(
            heartbeat_frequency=WAIT_TIME, min_heartbeat_interval=WAIT_TIME,
            events_queue_frequency=WAIT_TIME)
        self.client_knobs.enable()

    def tearDown(self):
        self.client_knobs.disable()

    def get_nodelist(self, client):
        """Return the addresses in the client's current topology."""
        return client._topology.description.server_descriptions().keys()

    def assert_nodelist_change(self, expected_nodelist, client):
        """Wait until the client's topology contains exactly the nodes in
        expected_nodelist, using a timeout of 100 * WAIT_TIME seconds.
        """
        def predicate():
            return set(expected_nodelist) == set(self.get_nodelist(client))
        wait_until(predicate, "see expected nodelist", timeout=100*WAIT_TIME)

    def assert_nodelist_nochange(self, expected_nodelist, client):
        """After sleeping WAIT_TIME * 10 seconds, check that the client's
        topology still contains exactly the nodes in expected_nodelist and
        that the patched resolver was called at least once.

        The resolver must have been patched with
        ``SrvPollingKnobs(count_resolver_calls=True)`` so that the patch
        exposes ``call_count``.
        """
        sleep(WAIT_TIME*10)
        nodelist = self.get_nodelist(client)
        if set(expected_nodelist) != set(nodelist):
            msg = "Client nodelist %s changed unexpectedly (expected %s)"
            # fail() raises by itself; no ``raise`` needed.
            self.fail(msg % (nodelist, expected_nodelist))
        self.assertGreaterEqual(
            pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count,
            1, "resolver was never called")
        return True

    def run_scenario(self, dns_response, expect_change):
        """Start a client, patch the polled DNS response, and check whether
        the client's nodelist follows it.

        :Parameters:
          - `dns_response`: node list to return from DNS polling, or a
            zero-argument callable producing it (which may raise).
          - `expect_change`: if True, wait for the nodelist to become
            `dns_response`; otherwise check it stays BASE_SRV_RESPONSE.
        """
        if callable(dns_response):
            dns_resolver_response = dns_response
        else:
            def dns_resolver_response():
                return dns_response

        if expect_change:
            assertion_method = self.assert_nodelist_change
            count_resolver_calls = False
            expected_response = dns_response
        else:
            assertion_method = self.assert_nodelist_nochange
            count_resolver_calls = True
            expected_response = self.BASE_SRV_RESPONSE

        # Patch timeouts to ensure short test running times.
        with SrvPollingKnobs(
                ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
            client = MongoClient(self.CONNECTION_STRING)
            # Always close the client. The resolver patch is installed on the
            # resolver class, so a client left running (e.g. from an earlier
            # iteration of test_dns_failures) could keep invoking a later
            # scenario's patch and satisfy its call-count check.
            try:
                self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client)
                # Patch list of hosts returned by DNS query.
                with SrvPollingKnobs(
                        dns_resolver_nodelist_response=dns_resolver_response,
                        count_resolver_calls=count_resolver_calls):
                    assertion_method(expected_response, client)
            finally:
                client.close()

    def test_addition(self):
        response = self.BASE_SRV_RESPONSE[:]
        response.append(
            ("localhost.test.build.10gen.cc", 27019))
        self.run_scenario(response, True)

    def test_removal(self):
        response = self.BASE_SRV_RESPONSE[:]
        response.remove(
            ("localhost.test.build.10gen.cc", 27018))
        self.run_scenario(response, True)

    def test_replace_one(self):
        response = self.BASE_SRV_RESPONSE[:]
        response.remove(
            ("localhost.test.build.10gen.cc", 27018))
        response.append(
            ("localhost.test.build.10gen.cc", 27019))
        self.run_scenario(response, True)

    def test_replace_both_with_one(self):
        response = [("localhost.test.build.10gen.cc", 27019)]
        self.run_scenario(response, True)

    def test_replace_both_with_two(self):
        response = [("localhost.test.build.10gen.cc", 27019),
                    ("localhost.test.build.10gen.cc", 27020)]
        self.run_scenario(response, True)

    def test_dns_failures(self):
        from dns import exception
        for exc in (exception.FormError, exception.TooBig, exception.Timeout):
            # The callback reads ``exc`` late, which is safe only because
            # run_scenario removes the patch before the next iteration.
            def response_callback(*args):
                raise exc("DNS Failure!")
            self.run_scenario(response_callback, False)

    def test_dns_record_lookup_empty(self):
        response = []
        self.run_scenario(response, False)

    def _test_recover_from_initial(self, initial_callback):
        """Check that an invalid polled seedlist leaves the nodelist
        unchanged, and that a later valid response is then picked up.
        """
        # Construct a valid final response callback distinct from base.
        response_final = self.BASE_SRV_RESPONSE[:]
        response_final.pop()

        def final_callback():
            return response_final

        with SrvPollingKnobs(
                ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME,
                dns_resolver_nodelist_response=initial_callback,
                count_resolver_calls=True):
            # The client's initial nodelist does not come from the patched
            # get_hosts_and_min_ttl, so it still matches BASE_SRV_RESPONSE.
            client = MongoClient(self.CONNECTION_STRING)
            self.addCleanup(client.close)
            # Invalid DNS resolver response should not change nodelist.
            self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client)

        with SrvPollingKnobs(
                ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME,
                dns_resolver_nodelist_response=final_callback):
            # Nodelist should reflect new valid DNS resolver response.
            self.assert_nodelist_change(response_final, client)

    def test_recover_from_initially_empty_seedlist(self):
        def empty_seedlist():
            return []
        self._test_recover_from_initial(empty_seedlist)

    def test_recover_from_initially_erroring_seedlist(self):
        def erroring_seedlist():
            raise ConfigurationError
        self._test_recover_from_initial(erroring_seedlist)
232
233
if __name__ == "__main__":
    # Executing this module directly runs all of its test cases.
    unittest.main()
236