# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Tests for implementations of L{IHostnameResolver} and their interactions with
reactor implementations.
"""


from collections import defaultdict
from socket import (
    AF_INET,
    AF_INET6,
    AF_UNSPEC,
    EAI_NONAME,
    IPPROTO_TCP,
    SOCK_DGRAM,
    SOCK_STREAM,
    gaierror,
    getaddrinfo,
)
from threading import Lock, local

from zope.interface import implementer
from zope.interface.verify import verifyObject

from twisted._threads import LockWorker, Team, createMemoryWorker
from twisted.internet._resolver import (
    ComplexResolverSimplifier,
    GAIResolver,
    SimpleResolverComplexifier,
)
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.base import PluggableResolverMixin, ReactorBase
from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
    IHostnameResolver,
    IReactorPluggableNameResolver,
    IResolutionReceiver,
    IResolverSimple,
)
from twisted.python.threadpool import ThreadPool
from twisted.trial.unittest import SynchronousTestCase as UnitTest


class DeterministicThreadPool(ThreadPool):
    """
    Create a deterministic L{ThreadPool} object.
    """

    def __init__(self, team):
        """
        Create a L{DeterministicThreadPool} from a L{Team}.

        @param team: The L{Team} that will actually perform the work submitted
            to this pool.
        """
        # Deliberately does not call ThreadPool.__init__; only the attributes
        # below are set up, and all work is routed through the given team.
        self.min = 1
        self.max = 1
        self.name = None
        self.threads = []
        self._team = team


def deterministicPool():
    """
    Create a deterministic threadpool.

    @return: 2-tuple of L{ThreadPool}, 0-argument C{work} callable; when
        C{work} is called, do the work.
    """
    worker, doer = createMemoryWorker()
    return (
        DeterministicThreadPool(
            Team(LockWorker(Lock(), local()), (lambda: worker), lambda: None)
        ),
        doer,
    )


def deterministicReactorThreads():
    """
    Create a deterministic L{IReactorThreads}

    @return: a 2-tuple consisting of an L{IReactorThreads}-like object and a
        0-argument callable that will perform one unit of work invoked via that
        object's C{callFromThread} method.
    """
    worker, doer = createMemoryWorker()

    class CFT:
        def callFromThread(self, f, *a, **k):
            # Queue the call on the memory worker; it runs only when the
            # returned C{doer} is invoked.
            worker.do(lambda: f(*a, **k))

    return CFT(), doer


class FakeAddrInfoGetter:
    """
    Test object implementing getaddrinfo.
    """

    def __init__(self):
        """
        Create a L{FakeAddrInfoGetter}.
        """
        # Every getaddrinfo call, recorded as a 6-tuple of its arguments.
        self.calls = []
        # Canned getaddrinfo results, keyed by hostname.
        self.results = defaultdict(list)

    def getaddrinfo(self, host, port, family=0, socktype=0, proto=0, flags=0):
        """
        Mock for L{socket.getaddrinfo}.

        @param host: see L{socket.getaddrinfo}

        @param port: see L{socket.getaddrinfo}

        @param family: see L{socket.getaddrinfo}

        @param socktype: see L{socket.getaddrinfo}

        @param proto: see L{socket.getaddrinfo}

        @param flags: see L{socket.getaddrinfo}

        @return: the L{list} of 5-tuples registered for C{host} via
            L{FakeAddrInfoGetter.addResultForHost}, in registration order.

        @raise gaierror: with C{EAI_NONAME} if no result was registered for
            C{host}.
        """
        self.calls.append((host, port, family, socktype, proto, flags))
        results = self.results[host]
        if results:
            return results
        else:
            raise gaierror(EAI_NONAME, "nodename nor servname provided, or not known")

    def addResultForHost(
        self,
        host,
        sockaddr,
        family=AF_INET,
        socktype=SOCK_STREAM,
        proto=IPPROTO_TCP,
        canonname="",
    ):
        """
        Add a result for a given hostname.  When this hostname is resolved, the
        result will be a L{list} of all results C{addResultForHost} has been
        called with using that hostname so far.

        @param host: The hostname to give this result for.  This will be the
            next result from L{FakeAddrInfoGetter.getaddrinfo} when passed this
            host.

        @param sockaddr: The resulting socket address; should be a 2-tuple for
            IPv4 or a 4-tuple for IPv6.

        @param family: An C{AF_*} constant that will be returned from
            C{getaddrinfo}.

        @param socktype: A C{SOCK_*} constant that will be returned from
            C{getaddrinfo}.

        @param proto: An C{IPPROTO_*} constant that will be returned from
            C{getaddrinfo}.

        @param canonname: A canonical name that will be returned from
            C{getaddrinfo}.
        @type canonname: native L{str}
        """
        self.results[host].append((family, socktype, proto, canonname, sockaddr))


@implementer(IResolutionReceiver)
class ResultHolder:
    """
    A resolution receiver which holds onto the results it received.
    """

    _started = False
    _ended = False

    def __init__(self, testCase):
        """
        Create a L{ResultHolder} with a L{UnitTest}.
        """
        self._testCase = testCase

    def resolutionBegan(self, hostResolution):
        """
        Hostname resolution began.

        @param hostResolution: see L{IResolutionReceiver}
        """
        self._started = True
        self._resolution = hostResolution
        self._addresses = []

    def addressResolved(self, address):
        """
        An address was resolved.

        @param address: see L{IResolutionReceiver}
        """
        self._addresses.append(address)

    def resolutionComplete(self):
        """
        Hostname resolution is complete.
        """
        self._ended = True


class HelperTests(UnitTest):
    """
    Tests for error cases of helpers used in this module.
    """

    def test_logErrorsInThreads(self):
        """
        L{DeterministicThreadPool} will log any exceptions that its "thread"
        workers encounter.
        """
        self.pool, self.doThreadWork = deterministicPool()

        def divideByZero():
            return 1 / 0

        self.pool.callInThread(divideByZero)
        self.doThreadWork()
        self.assertEqual(len(self.flushLoggedErrors(ZeroDivisionError)), 1)


class HostnameResolutionTests(UnitTest):
    """
    Tests for hostname resolution.
    """

    def setUp(self):
        """
        Set up a L{GAIResolver}.
        """
        self.pool, self.doThreadWork = deterministicPool()
        self.reactor, self.doReactorWork = deterministicReactorThreads()
        self.getter = FakeAddrInfoGetter()
        self.resolver = GAIResolver(
            self.reactor, lambda: self.pool, self.getter.getaddrinfo
        )

    def test_resolveOneHost(self):
        """
        Resolving an individual hostname that results in one address from
        getaddrinfo results in a single call each to C{resolutionBegan},
        C{addressResolved}, and C{resolutionComplete}.
        """
        receiver = ResultHolder(self)
        self.getter.addResultForHost("sample.example.com", ("4.3.2.1", 0))
        resolution = self.resolver.resolveHostName(receiver, "sample.example.com")
        self.assertIs(receiver._resolution, resolution)
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, False)
        self.doThreadWork()
        self.doReactorWork()
        self.assertEqual(receiver._ended, True)
        self.assertEqual(receiver._addresses, [IPv4Address("TCP", "4.3.2.1", 0)])

    def test_resolveOneIPv6Host(self):
        """
        Resolving an individual hostname that results in one address from
        getaddrinfo results in a single call each to C{resolutionBegan},
        C{addressResolved}, and C{resolutionComplete}; C{addressResolved} will
        receive an L{IPv6Address}.
        """
        receiver = ResultHolder(self)
        flowInfo = 1
        scopeID = 2
        self.getter.addResultForHost(
            "sample.example.com", ("::1", 0, flowInfo, scopeID), family=AF_INET6
        )
        resolution = self.resolver.resolveHostName(receiver, "sample.example.com")
        self.assertIs(receiver._resolution, resolution)
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, False)
        self.doThreadWork()
        self.doReactorWork()
        self.assertEqual(receiver._ended, True)
        self.assertEqual(
            receiver._addresses, [IPv6Address("TCP", "::1", 0, flowInfo, scopeID)]
        )

    def test_gaierror(self):
        """
        Resolving a hostname that results in C{getaddrinfo} raising a
        L{gaierror} will result in the L{IResolutionReceiver} receiving a call
        to C{resolutionComplete} with no C{addressResolved} calls in between;
        no failure is logged.
        """
        receiver = ResultHolder(self)
        resolution = self.resolver.resolveHostName(receiver, "sample.example.com")
        self.assertIs(receiver._resolution, resolution)
        self.doThreadWork()
        self.doReactorWork()
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, True)
        self.assertEqual(receiver._addresses, [])

    def _resolveOnlyTest(self, addrTypes, expectedAF):
        """
        Verify that the given set of address types results in the given C{AF_}
        constant being passed to C{getaddrinfo}.

        This may be called more than once per test; each call checks only the
        C{getaddrinfo} invocation that it triggered itself.

        @param addrTypes: iterable of L{IAddress} implementers

        @param expectedAF: an C{AF_*} constant
        """
        callsBefore = len(self.getter.calls)
        receiver = ResultHolder(self)
        resolution = self.resolver.resolveHostName(
            receiver, "sample.example.com", addressTypes=addrTypes
        )
        self.assertIs(receiver._resolution, resolution)
        self.doThreadWork()
        self.doReactorWork()
        # Exactly one new getaddrinfo call should have been made; inspect that
        # one rather than calls[0], which belongs to an earlier invocation
        # when this helper is used repeatedly within one test.
        self.assertEqual(len(self.getter.calls), callsBefore + 1)
        host, port, family, socktype, proto, flags = self.getter.calls[-1]
        self.assertEqual(family, expectedAF)

    def test_resolveOnlyIPv4(self):
        """
        When passed an C{addressTypes} parameter containing only
        L{IPv4Address}, L{GAIResolver} will pass C{AF_INET} to C{getaddrinfo}.
        """
        self._resolveOnlyTest([IPv4Address], AF_INET)

    def test_resolveOnlyIPv6(self):
        """
        When passed an C{addressTypes} parameter containing only
        L{IPv6Address}, L{GAIResolver} will pass C{AF_INET6} to C{getaddrinfo}.
        """
        self._resolveOnlyTest([IPv6Address], AF_INET6)

    def test_resolveBoth(self):
        """
        When passed an C{addressTypes} parameter containing both L{IPv4Address}
        and L{IPv6Address} (or the default of C{None}, which carries the same
        meaning), L{GAIResolver} will pass C{AF_UNSPEC} to C{getaddrinfo}.
        """
        self._resolveOnlyTest([IPv4Address, IPv6Address], AF_UNSPEC)
        self._resolveOnlyTest(None, AF_UNSPEC)

    def test_transportSemanticsToSocketType(self):
        """
        When passed a C{transportSemantics} parameter, C{'TCP'} (the value
        present in L{IPv4Address.type} to indicate a stream transport) maps to
        C{SOCK_STREAM} and C{'UDP'} maps to C{SOCK_DGRAM}.
        """
        receiver = ResultHolder(self)
        self.resolver.resolveHostName(receiver, "example.com", transportSemantics="TCP")
        receiver2 = ResultHolder(self)
        self.resolver.resolveHostName(
            receiver2, "example.com", transportSemantics="UDP"
        )
        self.doThreadWork()
        self.doReactorWork()
        self.doThreadWork()
        self.doReactorWork()
        host, port, family, socktypeT, proto, flags = self.getter.calls[0]
        host, port, family, socktypeU, proto, flags = self.getter.calls[1]
        self.assertEqual(socktypeT, SOCK_STREAM)
        self.assertEqual(socktypeU, SOCK_DGRAM)

    def test_socketTypeToAddressType(self):
        """
        When L{GAIResolver} receives a C{SOCK_STREAM} result from
        C{getaddrinfo}, it returns a C{'TCP'} L{IPv4Address} or L{IPv6Address};
        if it receives C{SOCK_DGRAM} then it returns a C{'UDP'} type of same.
        """
        receiver = ResultHolder(self)
        flowInfo = 1
        scopeID = 2
        for socktype in SOCK_STREAM, SOCK_DGRAM:
            # For each socket type an IPv6 result is registered before an IPv4
            # one; the fake getter returns them in registration order.
            self.getter.addResultForHost(
                "example.com",
                ("::1", 0, flowInfo, scopeID),
                family=AF_INET6,
                socktype=socktype,
            )
            self.getter.addResultForHost(
                "example.com", ("127.0.0.3", 0), family=AF_INET, socktype=socktype
            )
        self.resolver.resolveHostName(receiver, "example.com")
        self.doThreadWork()
        self.doReactorWork()
        stream6, stream4, dgram6, dgram4 = receiver._addresses
        self.assertIsInstance(stream6, IPv6Address)
        self.assertIsInstance(stream4, IPv4Address)
        self.assertIsInstance(dgram6, IPv6Address)
        self.assertIsInstance(dgram4, IPv4Address)
        self.assertEqual(stream4.type, "TCP")
        self.assertEqual(stream6.type, "TCP")
        self.assertEqual(dgram4.type, "UDP")
        self.assertEqual(dgram6.type, "UDP")


@implementer(IResolverSimple)
class SillyResolverSimple:
    """
    Trivial implementation of L{IResolverSimple}
    """

    def __init__(self):
        """
        Create a L{SillyResolverSimple} with a queue of requests it is working
        on.
        """
        self._requests = []

    def getHostByName(self, name, timeout=()):
        """
        Implement L{IResolverSimple.getHostByName}.

        The returned L{Deferred} is never fired by this object; tests fire it
        by reaching into C{_requests}.

        @param name: see L{IResolverSimple.getHostByName}.

        @param timeout: see L{IResolverSimple.getHostByName}.

        @return: see L{IResolverSimple.getHostByName}.
        """
        self._requests.append(Deferred())
        return self._requests[-1]


class LegacyCompatibilityTests(UnitTest):
    """
    Older applications may supply an object to the reactor via
    C{installResolver} that only provides L{IResolverSimple}.
    L{SimpleResolverComplexifier} is a wrapper for an L{IResolverSimple}.
    """

    def test_success(self):
        """
        L{SimpleResolverComplexifier} translates C{resolveHostName} into
        L{IResolutionReceiver.addressResolved}.
        """
        simple = SillyResolverSimple()
        complex = SimpleResolverComplexifier(simple)
        receiver = ResultHolder(self)
        self.assertEqual(receiver._started, False)
        complex.resolveHostName(receiver, "example.com")
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, False)
        self.assertEqual(receiver._addresses, [])
        simple._requests[0].callback("192.168.1.1")
        self.assertEqual(receiver._addresses, [IPv4Address("TCP", "192.168.1.1", 0)])
        self.assertEqual(receiver._ended, True)

    def test_failure(self):
        """
        L{SimpleResolverComplexifier} translates a known error result from
        L{IResolverSimple.getHostByName} into an empty result.
        """
        simple = SillyResolverSimple()
        complex = SimpleResolverComplexifier(simple)
        receiver = ResultHolder(self)
        self.assertEqual(receiver._started, False)
        complex.resolveHostName(receiver, "example.com")
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, False)
        self.assertEqual(receiver._addresses, [])
        simple._requests[0].errback(DNSLookupError("nope"))
        self.assertEqual(receiver._ended, True)
        self.assertEqual(receiver._addresses, [])

    def test_error(self):
        """
        L{SimpleResolverComplexifier} translates an unknown error result from
        L{IResolverSimple.getHostByName} into an empty result and a logged
        error.
        """
        simple = SillyResolverSimple()
        complex = SimpleResolverComplexifier(simple)
        receiver = ResultHolder(self)
        self.assertEqual(receiver._started, False)
        complex.resolveHostName(receiver, "example.com")
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, False)
        self.assertEqual(receiver._addresses, [])
        simple._requests[0].errback(ZeroDivisionError("zow"))
        self.assertEqual(len(self.flushLoggedErrors(ZeroDivisionError)), 1)
        self.assertEqual(receiver._ended, True)
        self.assertEqual(receiver._addresses, [])

    def test_simplifier(self):
        """
        L{ComplexResolverSimplifier} translates an L{IHostnameResolver} into an
        L{IResolverSimple} for applications that still expect the old
        interfaces to be in place.
        """
        self.pool, self.doThreadWork = deterministicPool()
        self.reactor, self.doReactorWork = deterministicReactorThreads()
        self.getter = FakeAddrInfoGetter()
        self.resolver = GAIResolver(
            self.reactor, lambda: self.pool, self.getter.getaddrinfo
        )
        simpleResolver = ComplexResolverSimplifier(self.resolver)
        self.getter.addResultForHost("example.com", ("192.168.3.4", 4321))
        success = simpleResolver.getHostByName("example.com")
        failure = simpleResolver.getHostByName("nx.example.com")
        # Two lookups were queued, so run the thread and reactor work twice.
        self.doThreadWork()
        self.doReactorWork()
        self.doThreadWork()
        self.doReactorWork()
        self.assertEqual(self.failureResultOf(failure).type, DNSLookupError)
        self.assertEqual(self.successResultOf(success), "192.168.3.4")

    def test_portNumber(self):
        """
        L{SimpleResolverComplexifier} preserves the C{port} argument passed to
        C{resolveHostName} in its returned addresses.
        """
        simple = SillyResolverSimple()
        complex = SimpleResolverComplexifier(simple)
        receiver = ResultHolder(self)
        complex.resolveHostName(receiver, "example.com", 4321)
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, False)
        self.assertEqual(receiver._addresses, [])
        simple._requests[0].callback("192.168.1.1")
        self.assertEqual(receiver._addresses, [IPv4Address("TCP", "192.168.1.1", 4321)])
        self.assertEqual(receiver._ended, True)


class JustEnoughReactor(ReactorBase):
    """
    Just enough subclass implementation to be a valid L{ReactorBase} subclass.
    """

    def installWaker(self):
        """
        Do nothing.
        """


class ReactorInstallationTests(UnitTest):
    """
    Tests for installing old and new resolvers onto a
    L{PluggableResolverMixin} and L{ReactorBase} (from which all of Twisted's
    reactor implementations derive).
    """

    def test_interfaceCompliance(self):
        """
        L{PluggableResolverMixin} (and its subclasses) implement
        L{IReactorPluggableNameResolver}; its C{resolver} attribute provides
        L{IResolverSimple} and its C{nameResolver} attribute provides
        L{IHostnameResolver}.
        """
        reactor = PluggableResolverMixin()
        verifyObject(IReactorPluggableNameResolver, reactor)
        verifyObject(IResolverSimple, reactor.resolver)
        verifyObject(IHostnameResolver, reactor.nameResolver)

    def test_installingOldStyleResolver(self):
        """
        L{PluggableResolverMixin} will wrap an L{IResolverSimple} in a
        complexifier.
        """
        reactor = PluggableResolverMixin()
        it = SillyResolverSimple()
        verifyObject(IResolverSimple, reactor.installResolver(it))
        self.assertIsInstance(reactor.nameResolver, SimpleResolverComplexifier)
        self.assertIs(reactor.nameResolver._simpleResolver, it)

    def test_defaultToGAIResolver(self):
        """
        L{ReactorBase} defaults to using a L{GAIResolver}.
        """
        reactor = JustEnoughReactor()
        self.assertIsInstance(reactor.nameResolver, GAIResolver)
        self.assertIs(reactor.nameResolver._getaddrinfo, getaddrinfo)
        self.assertIsInstance(reactor.resolver, ComplexResolverSimplifier)
        self.assertIs(reactor.nameResolver._reactor, reactor)
        self.assertIs(reactor.resolver._nameResolver, reactor.nameResolver)