# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run the SRV support tests."""

import functools
import sys

from time import sleep

sys.path[0:0] = [""]

import pymongo

from pymongo import common
from pymongo.errors import ConfigurationError
from pymongo.srv_resolver import _HAVE_DNSPYTHON
from pymongo.mongo_client import MongoClient
from test import client_knobs, unittest
from test.utils import wait_until, FunctionCallRecorder


# Base unit of time (in seconds) used for patched TTLs, rescan intervals,
# heartbeats and the waits in the assertion helpers below.
WAIT_TIME = 0.1


class _MethodCallRecorder(FunctionCallRecorder):
    """A FunctionCallRecorder that binds like a method.

    A plain callable object stored as a class attribute is not a
    descriptor, so calling it through an instance would NOT pass the
    instance as the first argument. Implementing ``__get__`` restores
    normal method binding for instance access, while class access still
    returns the recorder itself so its ``call_count`` can be inspected.
    """

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        return functools.partial(self, obj)


class SrvPollingKnobs(object):
    """Context manager that temporarily patches SRV polling behaviour.

    On enter it replaces ``common.MIN_SRV_RESCAN_INTERVAL`` (optionally)
    and ``_SrvResolver.get_hosts_and_min_ttl``; on exit both are restored.

    :Parameters:
      - `ttl_time`: if not None, overrides the TTL returned by the
        patched resolver method.
      - `min_srv_rescan_interval`: if not None, temporary value for
        ``common.MIN_SRV_RESCAN_INTERVAL``.
      - `dns_resolver_nodelist_response`: if not None, a zero-argument
        callable whose return value replaces the resolved node list (it
        may raise instead, to simulate a DNS failure).
      - `count_resolver_calls`: if True, the patched method records its
        invocations; read them via
        ``_SrvResolver.get_hosts_and_min_ttl.call_count``.
    """

    def __init__(self, ttl_time=None, min_srv_rescan_interval=None,
                 dns_resolver_nodelist_response=None,
                 count_resolver_calls=False):
        self.ttl_time = ttl_time
        self.min_srv_rescan_interval = min_srv_rescan_interval
        self.dns_resolver_nodelist_response = dns_resolver_nodelist_response
        self.count_resolver_calls = count_resolver_calls

        # Originals captured by enable() and put back by disable().
        self.old_min_srv_rescan_interval = None
        self.old_dns_resolver_response = None

    def enable(self):
        """Save the current settings and install the patches."""
        self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
        self.old_dns_resolver_response = \
            pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl

        if self.min_srv_rescan_interval is not None:
            common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval

        def mock_get_hosts_and_min_ttl(resolver, *args):
            # Run the previously installed implementation first (the real
            # resolver, or an outer SrvPollingKnobs patch), then override
            # its results as configured.
            nodes, ttl = self.old_dns_resolver_response(resolver, *args)
            if self.dns_resolver_nodelist_response is not None:
                nodes = self.dns_resolver_nodelist_response()
            if self.ttl_time is not None:
                ttl = self.ttl_time
            return nodes, ttl

        if self.count_resolver_calls:
            # Must bind like a method, otherwise the mock is invoked
            # without `resolver` and fails before doing anything useful.
            patch_func = _MethodCallRecorder(mock_get_hosts_and_min_ttl)
        else:
            patch_func = mock_get_hosts_and_min_ttl

        pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func

    def __enter__(self):
        self.enable()
        return self

    def disable(self):
        """Restore the settings saved by enable()."""
        common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval
        pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = \
            self.old_dns_resolver_response

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.disable()


class TestSrvPolling(unittest.TestCase):

    BASE_SRV_RESPONSE = [
        ("localhost.test.build.10gen.cc", 27017),
        ("localhost.test.build.10gen.cc", 27018)]

    CONNECTION_STRING = "mongodb+srv://test1.test.build.10gen.cc"

    def setUp(self):
        if not _HAVE_DNSPYTHON:
            raise unittest.SkipTest("SRV polling tests require the dnspython "
                                    "module")
        # Patch timeouts to ensure short rescan SRV interval.
        self.client_knobs = client_knobs(
            heartbeat_frequency=WAIT_TIME, min_heartbeat_interval=WAIT_TIME,
            events_queue_frequency=WAIT_TIME)
        self.client_knobs.enable()

    def tearDown(self):
        self.client_knobs.disable()

    def get_nodelist(self, client):
        """Return the server addresses in the client's topology."""
        return client._topology.description.server_descriptions().keys()

    def assert_nodelist_change(self, expected_nodelist, client):
        """Wait (up to 100 * WAIT_TIME seconds) until the client's topology
        contains exactly the nodes in expected_nodelist.
        """
        def predicate():
            return set(expected_nodelist) == set(self.get_nodelist(client))
        wait_until(predicate, "see expected nodelist", timeout=100*WAIT_TIME)

    def assert_nodelist_nochange(self, expected_nodelist, client):
        """Check that the client's topology still matches expected_nodelist.

        The node list is compared once, after sleeping for WAIT_TIME * 10
        seconds. Also checks that the patched resolver was called at least
        once, which requires it to have been installed with
        ``count_resolver_calls=True``.
        """
        sleep(WAIT_TIME*10)
        nodelist = self.get_nodelist(client)
        if set(expected_nodelist) != set(nodelist):
            msg = "Client nodelist %s changed unexpectedly (expected %s)"
            self.fail(msg % (nodelist, expected_nodelist))
        self.assertGreaterEqual(
            pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count,
            1, "resolver was never called")
        return True

    def run_scenario(self, dns_response, expect_change):
        """Connect, then patch the DNS node list and check the topology.

        :Parameters:
          - `dns_response`: the node list the patched resolver should
            return, or a zero-argument callable that returns (or raises
            instead of returning) it.
          - `expect_change`: if True the topology must converge on
            `dns_response`; otherwise it must stay at BASE_SRV_RESPONSE.
        """
        if callable(dns_response):
            dns_resolver_response = dns_response
        else:
            def dns_resolver_response():
                return dns_response

        if expect_change:
            assertion_method = self.assert_nodelist_change
            count_resolver_calls = False
            expected_response = dns_response
        else:
            assertion_method = self.assert_nodelist_nochange
            count_resolver_calls = True
            expected_response = self.BASE_SRV_RESPONSE

        # Patch timeouts to ensure short test running times.
        with SrvPollingKnobs(
                ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
            client = MongoClient(self.CONNECTION_STRING)
            # Release the client's resources when the test finishes.
            self.addCleanup(client.close)
            self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client)
            # Patch list of hosts returned by DNS query.
            with SrvPollingKnobs(
                    dns_resolver_nodelist_response=dns_resolver_response,
                    count_resolver_calls=count_resolver_calls):
                assertion_method(expected_response, client)

    def test_addition(self):
        response = self.BASE_SRV_RESPONSE[:]
        response.append(
            ("localhost.test.build.10gen.cc", 27019))
        self.run_scenario(response, True)

    def test_removal(self):
        response = self.BASE_SRV_RESPONSE[:]
        response.remove(
            ("localhost.test.build.10gen.cc", 27018))
        self.run_scenario(response, True)

    def test_replace_one(self):
        response = self.BASE_SRV_RESPONSE[:]
        response.remove(
            ("localhost.test.build.10gen.cc", 27018))
        response.append(
            ("localhost.test.build.10gen.cc", 27019))
        self.run_scenario(response, True)

    def test_replace_both_with_one(self):
        response = [("localhost.test.build.10gen.cc", 27019)]
        self.run_scenario(response, True)

    def test_replace_both_with_two(self):
        response = [("localhost.test.build.10gen.cc", 27019),
                    ("localhost.test.build.10gen.cc", 27020)]
        self.run_scenario(response, True)

    def test_dns_failures(self):
        from dns import exception
        for exc in (exception.FormError, exception.TooBig, exception.Timeout):
            # The callback closes over the loop variable; that is safe
            # because run_scenario consumes it within this iteration.
            def response_callback(*args):
                raise exc("DNS Failure!")
            self.run_scenario(response_callback, False)

    def test_dns_record_lookup_empty(self):
        response = []
        self.run_scenario(response, False)

    def _test_recover_from_initial(self, initial_callback):
        """Check that the topology ignores the invalid DNS responses
        produced by initial_callback, then follows a valid one."""
        # Construct a valid final response callback distinct from base.
        response_final = self.BASE_SRV_RESPONSE[:]
        response_final.pop()

        def final_callback():
            return response_final

        with SrvPollingKnobs(
                ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME,
                dns_resolver_nodelist_response=initial_callback,
                count_resolver_calls=True):
            # Client uses unpatched method to get initial nodelist
            client = MongoClient(self.CONNECTION_STRING)
            # Release the client's resources when the test finishes.
            self.addCleanup(client.close)
            # Invalid DNS resolver response should not change nodelist.
            self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client)

        with SrvPollingKnobs(
                ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME,
                dns_resolver_nodelist_response=final_callback):
            # Nodelist should reflect new valid DNS resolver response.
            self.assert_nodelist_change(response_final, client)

    def test_recover_from_initially_empty_seedlist(self):
        def empty_seedlist():
            return []
        self._test_recover_from_initial(empty_seedlist)

    def test_recover_from_initially_erroring_seedlist(self):
        def erroring_seedlist():
            raise ConfigurationError
        self._test_recover_from_initial(erroring_seedlist)


if __name__ == '__main__':
    unittest.main()