1# Copyright (c) Twisted Matrix Laboratories.
2# See LICENSE for details.
3
4"""
5Tests for implementations of L{IHostnameResolver} and their interactions with
6reactor implementations.
7"""
8
9
from collections import defaultdict
from socket import (
    AF_INET,
    AF_INET6,
    AF_UNSPEC,
    EAI_NONAME,
    IPPROTO_TCP,
    SOCK_DGRAM,
    SOCK_STREAM,
    gaierror,
    getaddrinfo,
)
from threading import Lock, local

from zope.interface import implementer
from zope.interface.verify import verifyObject

from twisted._threads import LockWorker, Team, createMemoryWorker
from twisted.internet._resolver import (
    ComplexResolverSimplifier,
    GAIResolver,
    SimpleResolverComplexifier,
)
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.base import PluggableResolverMixin, ReactorBase
from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
    IHostnameResolver,
    IReactorPluggableNameResolver,
    IReactorPluggableResolver,
    IResolutionReceiver,
    IResolverSimple,
)
from twisted.python.threadpool import ThreadPool
from twisted.trial.unittest import SynchronousTestCase as UnitTest
45
46
class DeterministicThreadPool(ThreadPool):
    """
    A L{ThreadPool} built around a caller-supplied L{Team} instead of
    starting any threads of its own; see L{deterministicPool}.
    """

    def __init__(self, team):
        """
        Wrap C{team} in a L{ThreadPool}-compatible object.

        @param team: the L{Team} stored as this pool's C{_team}.
        """
        # ThreadPool.__init__ is deliberately not called; only the attributes
        # assigned below are initialized.
        self.min = self.max = 1
        self.name = None
        self.threads = []
        self._team = team
61
62
def deterministicPool():
    """
    Build a L{ThreadPool} whose work only runs when the test asks for it.

    @return: 2-tuple of a L{ThreadPool} and a 0-argument callable; calling
        the callable performs work that was queued on the pool.
    """
    memoryWorker, performWork = createMemoryWorker()
    team = Team(
        LockWorker(Lock(), local()),
        lambda: memoryWorker,
        lambda: None,
    )
    return DeterministicThreadPool(team), performWork
77
78
def deterministicReactorThreads():
    """
    Create a deterministic L{IReactorThreads}.

    @return: a 2-tuple consisting of an L{IReactorThreads}-like object and a
        0-argument callable that will perform one unit of work invoked via that
        object's C{callFromThread} method.
    """
    memoryWorker, performWork = createMemoryWorker()

    class CFT:
        """
        Minimal reactor stand-in exposing only C{callFromThread}, which
        queues the call on an in-memory worker instead of running it.
        """

        def callFromThread(self, f, *a, **k):
            memoryWorker.do(lambda: f(*a, **k))

    return CFT(), performWork
94
95
class FakeAddrInfoGetter:
    """
    Test object implementing getaddrinfo: it records every call it receives
    and answers from results registered with L{addResultForHost}.

    @ivar calls: L{list} of C{(host, port, family, socktype, proto, flags)}
        tuples, one per call to L{getaddrinfo}, in call order.

    @ivar results: mapping of hostname to the L{list} of
        C{(family, socktype, proto, canonname, sockaddr)} 5-tuples registered
        for that hostname.
    """

    def __init__(self):
        """
        Create a L{FakeAddrInfoGetter} with no recorded calls or results.
        """
        self.calls = []
        self.results = defaultdict(list)

    def getaddrinfo(self, host, port, family=0, socktype=0, proto=0, flags=0):
        """
        Mock for L{socket.getaddrinfo}.

        @param host: see L{socket.getaddrinfo}

        @param port: see L{socket.getaddrinfo}

        @param family: see L{socket.getaddrinfo}

        @param socktype: see L{socket.getaddrinfo}

        @param proto: see L{socket.getaddrinfo}

        @param flags: see L{socket.getaddrinfo}

        @return: a new L{list} of the 5-tuples registered for C{host} via
            L{addResultForHost}, in registration order, like the return value
            of L{socket.getaddrinfo}.  The C{family}, C{socktype}, C{proto}
            and C{flags} arguments are recorded but not used for filtering.

        @raise gaierror: with C{EAI_NONAME} when no results are registered
            for C{host}.
        """
        self.calls.append((host, port, family, socktype, proto, flags))
        # .get() avoids inserting an empty entry into the defaultdict merely
        # because an unknown host was looked up.
        results = self.results.get(host)
        if not results:
            raise gaierror(EAI_NONAME, "nodename nor servname provided, or not known")
        # Hand back a copy, as the real getaddrinfo returns a fresh list; the
        # caller must not be able to mutate the registered results.
        return list(results)

    def addResultForHost(
        self,
        host,
        sockaddr,
        family=AF_INET,
        socktype=SOCK_STREAM,
        proto=IPPROTO_TCP,
        canonname="",
    ):
        """
        Add a result for a given hostname.  When this hostname is resolved, the
        result will be a L{list} of all results C{addResultForHost} has been
        called with using that hostname so far.

        @param host: The hostname to give this result for.  This will be the
            next result from L{FakeAddrInfoGetter.getaddrinfo} when passed this
            host.
        @type host: native L{str}

        @param sockaddr: The resulting socket address; should be a 2-tuple for
            IPv4 or a 4-tuple for IPv6.

        @param family: An C{AF_*} constant that will be returned from
            C{getaddrinfo}.

        @param socktype: A C{SOCK_*} constant that will be returned from
            C{getaddrinfo}.

        @param proto: An C{IPPROTO_*} constant that will be returned from
            C{getaddrinfo}.

        @param canonname: A canonical name that will be returned from
            C{getaddrinfo}.  Defaults to the empty native string, matching the
            documented type.
        @type canonname: native L{str}
        """
        self.results[host].append((family, socktype, proto, canonname, sockaddr))
170
171
@implementer(IResolutionReceiver)
class ResultHolder:
    """
    An L{IResolutionReceiver} that records every notification it is given so
    that tests can inspect them afterwards.

    @ivar _started: becomes C{True} once C{resolutionBegan} is called.

    @ivar _ended: becomes C{True} once C{resolutionComplete} is called.
    """

    _started = False
    _ended = False

    def __init__(self, testCase):
        """
        Build a L{ResultHolder}.

        @param testCase: the L{UnitTest} that owns this receiver.
        """
        self._testCase = testCase

    def resolutionBegan(self, hostResolution):
        """
        Remember the resolution in progress, reset the collected addresses,
        and note that resolution has started.

        @param hostResolution: see L{IResolutionReceiver}
        """
        self._resolution = hostResolution
        self._addresses = []
        self._started = True

    def addressResolved(self, address):
        """
        Collect one resolved address.

        @param address: see L{IResolutionReceiver}
        """
        collected = self._addresses
        collected.append(address)

    def resolutionComplete(self):
        """
        Note that no further addresses will be delivered.
        """
        self._ended = True
210
211
class HelperTests(UnitTest):
    """
    Sanity checks for error handling in the helpers defined in this module.
    """

    def test_logErrorsInThreads(self):
        """
        An exception raised by work run through a L{DeterministicThreadPool}
        is logged.
        """
        self.pool, self.doThreadWork = deterministicPool()

        def explode():
            return 1 / 0

        self.pool.callInThread(explode)
        self.doThreadWork()
        logged = self.flushLoggedErrors(ZeroDivisionError)
        self.assertEqual(len(logged), 1)
230
231
class HostnameResolutionTests(UnitTest):
    """
    Tests for hostname resolution.
    """

    def setUp(self):
        """
        Set up a L{GAIResolver}.
        """
        self.pool, self.doThreadWork = deterministicPool()
        self.reactor, self.doReactorWork = deterministicReactorThreads()
        self.getter = FakeAddrInfoGetter()
        self.resolver = GAIResolver(
            self.reactor, lambda: self.pool, self.getter.getaddrinfo
        )

    def test_resolveOneHost(self):
        """
        Resolving an individual hostname that results in one address from
        getaddrinfo results in a single call each to C{resolutionBegan},
        C{addressResolved}, and C{resolutionComplete}.
        """
        receiver = ResultHolder(self)
        self.getter.addResultForHost("sample.example.com", ("4.3.2.1", 0))
        resolution = self.resolver.resolveHostName(receiver, "sample.example.com")
        self.assertIs(receiver._resolution, resolution)
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, False)
        self.doThreadWork()
        self.doReactorWork()
        self.assertEqual(receiver._ended, True)
        self.assertEqual(receiver._addresses, [IPv4Address("TCP", "4.3.2.1", 0)])

    def test_resolveOneIPv6Host(self):
        """
        Resolving an individual hostname that results in one address from
        getaddrinfo results in a single call each to C{resolutionBegan},
        C{addressResolved}, and C{resolutionComplete}; C{addressResolved} will
        receive an L{IPv6Address}.
        """
        receiver = ResultHolder(self)
        flowInfo = 1
        scopeID = 2
        self.getter.addResultForHost(
            "sample.example.com", ("::1", 0, flowInfo, scopeID), family=AF_INET6
        )
        resolution = self.resolver.resolveHostName(receiver, "sample.example.com")
        self.assertIs(receiver._resolution, resolution)
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, False)
        self.doThreadWork()
        self.doReactorWork()
        self.assertEqual(receiver._ended, True)
        self.assertEqual(
            receiver._addresses, [IPv6Address("TCP", "::1", 0, flowInfo, scopeID)]
        )

    def test_gaierror(self):
        """
        Resolving a hostname that results in C{getaddrinfo} raising a
        L{gaierror} will result in the L{IResolutionReceiver} receiving a call
        to C{resolutionComplete} with no C{addressResolved} calls in between;
        no failure is logged.
        """
        receiver = ResultHolder(self)
        resolution = self.resolver.resolveHostName(receiver, "sample.example.com")
        self.assertIs(receiver._resolution, resolution)
        self.doThreadWork()
        self.doReactorWork()
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, True)
        self.assertEqual(receiver._addresses, [])

    def _resolveOnlyTest(self, addrTypes, expectedAF):
        """
        Verify that the given set of address types results in the given C{AF_}
        constant being passed to C{getaddrinfo}.

        May be called several times within one test; each call checks only
        the C{getaddrinfo} call made by its own resolution.

        @param addrTypes: iterable of L{IAddress} implementers, or L{None}

        @param expectedAF: an C{AF_*} constant
        """
        receiver = ResultHolder(self)
        callsBefore = len(self.getter.calls)
        resolution = self.resolver.resolveHostName(
            receiver, "sample.example.com", addressTypes=addrTypes
        )
        self.assertIs(receiver._resolution, resolution)
        self.doThreadWork()
        self.doReactorWork()
        # Check the call made by *this* resolution.  Always reading calls[0]
        # would make every invocation after the first in a test re-check the
        # first call's family and so verify nothing.
        self.assertEqual(len(self.getter.calls), callsBefore + 1)
        host, port, family, socktype, proto, flags = self.getter.calls[-1]
        self.assertEqual(family, expectedAF)

    def test_resolveOnlyIPv4(self):
        """
        When passed an C{addressTypes} parameter containing only
        L{IPv4Address}, L{GAIResolver} will pass C{AF_INET} to C{getaddrinfo}.
        """
        self._resolveOnlyTest([IPv4Address], AF_INET)

    def test_resolveOnlyIPv6(self):
        """
        When passed an C{addressTypes} parameter containing only
        L{IPv6Address}, L{GAIResolver} will pass C{AF_INET6} to C{getaddrinfo}.
        """
        self._resolveOnlyTest([IPv6Address], AF_INET6)

    def test_resolveBoth(self):
        """
        When passed an C{addressTypes} parameter containing both L{IPv4Address}
        and L{IPv6Address} (or the default of C{None}, which carries the same
        meaning), L{GAIResolver} will pass C{AF_UNSPEC} to C{getaddrinfo}.
        """
        self._resolveOnlyTest([IPv4Address, IPv6Address], AF_UNSPEC)
        self._resolveOnlyTest(None, AF_UNSPEC)

    def test_transportSemanticsToSocketType(self):
        """
        When passed a C{transportSemantics} parameter, C{'TCP'} (the value
        present in L{IPv4Address.type} to indicate a stream transport) maps to
        C{SOCK_STREAM} and C{'UDP'} maps to C{SOCK_DGRAM}.
        """
        receiver = ResultHolder(self)
        self.resolver.resolveHostName(receiver, "example.com", transportSemantics="TCP")
        receiver2 = ResultHolder(self)
        self.resolver.resolveHostName(
            receiver2, "example.com", transportSemantics="UDP"
        )
        self.doThreadWork()
        self.doReactorWork()
        self.doThreadWork()
        self.doReactorWork()
        host, port, family, socktypeT, proto, flags = self.getter.calls[0]
        host, port, family, socktypeU, proto, flags = self.getter.calls[1]
        self.assertEqual(socktypeT, SOCK_STREAM)
        self.assertEqual(socktypeU, SOCK_DGRAM)

    def test_socketTypeToAddressType(self):
        """
        When L{GAIResolver} receives a C{SOCK_STREAM} result from
        C{getaddrinfo}, it returns a C{'TCP'} L{IPv4Address} or L{IPv6Address};
        if it receives C{SOCK_DGRAM} then it returns a C{'UDP'} type of same.
        """
        receiver = ResultHolder(self)
        flowInfo = 1
        scopeID = 2
        for socktype in SOCK_STREAM, SOCK_DGRAM:
            self.getter.addResultForHost(
                "example.com",
                ("::1", 0, flowInfo, scopeID),
                family=AF_INET6,
                socktype=socktype,
            )
            self.getter.addResultForHost(
                "example.com", ("127.0.0.3", 0), family=AF_INET, socktype=socktype
            )
        self.resolver.resolveHostName(receiver, "example.com")
        self.doThreadWork()
        self.doReactorWork()
        # Results were registered IPv6-first for each socket type, so that is
        # the order in which the addresses are delivered.
        stream6, stream4, dgram6, dgram4 = receiver._addresses
        self.assertIsInstance(stream6, IPv6Address)
        self.assertIsInstance(stream4, IPv4Address)
        self.assertIsInstance(dgram6, IPv6Address)
        self.assertIsInstance(dgram4, IPv4Address)
        self.assertEqual(stream4.type, "TCP")
        self.assertEqual(stream6.type, "TCP")
        self.assertEqual(dgram4.type, "UDP")
        self.assertEqual(dgram6.type, "UDP")
395
396
@implementer(IResolverSimple)
class SillyResolverSimple:
    """
    Minimal L{IResolverSimple} whose lookups never finish on their own: each
    request's L{Deferred} is kept in C{_requests} for the test to fire.
    """

    def __init__(self):
        """
        Start with no outstanding requests.
        """
        self._requests = []

    def getHostByName(self, name, timeout=()):
        """
        Implement L{IResolverSimple.getHostByName} by queueing an unfired
        L{Deferred}.

        @param name: see L{IResolverSimple.getHostByName}.

        @param timeout: see L{IResolverSimple.getHostByName}.

        @return: see L{IResolverSimple.getHostByName}.
        """
        pending = Deferred()
        self._requests.append(pending)
        return pending
422
423
class LegacyCompatibilityTests(UnitTest):
    """
    Older applications may supply an object to the reactor via
    C{installResolver} that only provides L{IResolverSimple}.
    L{SimpleResolverComplexifier} is a wrapper for an L{IResolverSimple}.
    """

    def test_success(self):
        """
        L{SimpleResolverComplexifier} translates C{resolveHostName} into
        L{IResolutionReceiver.addressResolved}.
        """
        simple = SillyResolverSimple()
        complexResolver = SimpleResolverComplexifier(simple)
        receiver = ResultHolder(self)
        self.assertEqual(receiver._started, False)
        complexResolver.resolveHostName(receiver, "example.com")
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, False)
        self.assertEqual(receiver._addresses, [])
        simple._requests[0].callback("192.168.1.1")
        self.assertEqual(receiver._addresses, [IPv4Address("TCP", "192.168.1.1", 0)])
        self.assertEqual(receiver._ended, True)

    def test_failure(self):
        """
        L{SimpleResolverComplexifier} translates a known error result
        (L{DNSLookupError}) from L{IResolverSimple.getHostByName} into an empty
        result.
        """
        simple = SillyResolverSimple()
        complexResolver = SimpleResolverComplexifier(simple)
        receiver = ResultHolder(self)
        self.assertEqual(receiver._started, False)
        complexResolver.resolveHostName(receiver, "example.com")
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, False)
        self.assertEqual(receiver._addresses, [])
        simple._requests[0].errback(DNSLookupError("nope"))
        self.assertEqual(receiver._ended, True)
        self.assertEqual(receiver._addresses, [])

    def test_error(self):
        """
        L{SimpleResolverComplexifier} translates an unknown error result from
        L{IResolverSimple.getHostByName} into an empty result and a logged
        error.
        """
        simple = SillyResolverSimple()
        complexResolver = SimpleResolverComplexifier(simple)
        receiver = ResultHolder(self)
        self.assertEqual(receiver._started, False)
        complexResolver.resolveHostName(receiver, "example.com")
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, False)
        self.assertEqual(receiver._addresses, [])
        simple._requests[0].errback(ZeroDivisionError("zow"))
        self.assertEqual(len(self.flushLoggedErrors(ZeroDivisionError)), 1)
        self.assertEqual(receiver._ended, True)
        self.assertEqual(receiver._addresses, [])

    def test_simplifier(self):
        """
        L{ComplexResolverSimplifier} translates an L{IHostnameResolver} into an
        L{IResolverSimple} for applications that still expect the old
        interfaces to be in place.
        """
        self.pool, self.doThreadWork = deterministicPool()
        self.reactor, self.doReactorWork = deterministicReactorThreads()
        self.getter = FakeAddrInfoGetter()
        self.resolver = GAIResolver(
            self.reactor, lambda: self.pool, self.getter.getaddrinfo
        )
        simpleResolver = ComplexResolverSimplifier(self.resolver)
        self.getter.addResultForHost("example.com", ("192.168.3.4", 4321))
        success = simpleResolver.getHostByName("example.com")
        failure = simpleResolver.getHostByName("nx.example.com")
        self.doThreadWork()
        self.doReactorWork()
        self.doThreadWork()
        self.doReactorWork()
        self.assertEqual(self.failureResultOf(failure).type, DNSLookupError)
        self.assertEqual(self.successResultOf(success), "192.168.3.4")

    def test_portNumber(self):
        """
        L{SimpleResolverComplexifier} preserves the C{port} argument passed to
        C{resolveHostName} in its returned addresses.
        """
        simple = SillyResolverSimple()
        complexResolver = SimpleResolverComplexifier(simple)
        receiver = ResultHolder(self)
        complexResolver.resolveHostName(receiver, "example.com", 4321)
        self.assertEqual(receiver._started, True)
        self.assertEqual(receiver._ended, False)
        self.assertEqual(receiver._addresses, [])
        simple._requests[0].callback("192.168.1.1")
        self.assertEqual(receiver._addresses, [IPv4Address("TCP", "192.168.1.1", 4321)])
        self.assertEqual(receiver._ended, True)
522
523
class JustEnoughReactor(ReactorBase):
    """
    The smallest L{ReactorBase} subclass that can be instantiated in tests;
    it overrides only C{installWaker}.
    """

    def installWaker(self):
        """
        Intentionally a no-op: these tests never need to wake the reactor.
        """
        return None
533
534
class ReactorInstallationTests(UnitTest):
    """
    Tests for installing old and new resolvers onto a
    L{PluggableResolverMixin} and L{ReactorBase} (from which all of Twisted's
    reactor implementations derive).
    """

    def test_interfaceCompliance(self):
        """
        L{PluggableResolverMixin} (and its subclasses) implement both
        L{IReactorPluggableNameResolver} and L{IReactorPluggableResolver}, and
        expose an L{IResolverSimple} as C{resolver} and an
        L{IHostnameResolver} as C{nameResolver}.
        """
        reactor = PluggableResolverMixin()
        verifyObject(IReactorPluggableNameResolver, reactor)
        # The docstring promises both pluggable-resolver interfaces, so verify
        # the legacy one as well.
        verifyObject(IReactorPluggableResolver, reactor)
        verifyObject(IResolverSimple, reactor.resolver)
        verifyObject(IHostnameResolver, reactor.nameResolver)

    def test_installingOldStyleResolver(self):
        """
        L{PluggableResolverMixin} will wrap an L{IResolverSimple} in a
        complexifier.
        """
        reactor = PluggableResolverMixin()
        it = SillyResolverSimple()
        verifyObject(IResolverSimple, reactor.installResolver(it))
        self.assertIsInstance(reactor.nameResolver, SimpleResolverComplexifier)
        self.assertIs(reactor.nameResolver._simpleResolver, it)

    def test_defaultToGAIResolver(self):
        """
        L{ReactorBase} defaults to using a L{GAIResolver}.
        """
        reactor = JustEnoughReactor()
        self.assertIsInstance(reactor.nameResolver, GAIResolver)
        self.assertIs(reactor.nameResolver._getaddrinfo, getaddrinfo)
        self.assertIsInstance(reactor.resolver, ComplexResolverSimplifier)
        self.assertIs(reactor.nameResolver._reactor, reactor)
        self.assertIs(reactor.resolver._nameResolver, reactor.nameResolver)
573