1#
2# This file is a part of DNSViz, a tool suite for DNS/DNSSEC monitoring,
3# analysis, and visualization.
4# Created by Casey Deccio (casey@deccio.net)
5#
6# Copyright 2014-2016 VeriSign, Inc.
7#
8# Copyright 2016-2021 Casey Deccio
9#
10# DNSViz is free software; you can redistribute it and/or modify
11# it under the terms of the GNU General Public License as published by
12# the Free Software Foundation; either version 2 of the License, or
13# (at your option) any later version.
14#
15# DNSViz is distributed in the hope that it will be useful,
16# but WITHOUT ANY WARRANTY; without even the implied warranty of
17# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18# GNU General Public License for more details.
19#
20# You should have received a copy of the GNU General Public License along
21# with DNSViz.  If not, see <http://www.gnu.org/licenses/>.
22#
23
24from __future__ import unicode_literals
25
26import bisect
27import io
28import math
29import random
30import threading
31import time
32
33from .config import RESOLV_CONF
34from . import query
35from .ipaddr import *
36from . import response as Response
37from . import transport
38from . import util
39
40import dns.rdataclass, dns.exception, dns.message, dns.rcode, dns.resolver
41
# Upper bound on the number of lookups DNSAnswer performs while following
# CNAME records in the answer section of a response.
MAX_CNAME_REDIRECTION = 20
43
class ResolvConfError(Exception):
    '''Raised when a resolver configuration file cannot be opened or lists
    no usable name servers.'''
46
_r = None
def get_standard_resolver():
    '''Return the module-wide resolver built from RESOLV_CONF with
    query.StandardRecursiveQuery, creating it on first use.'''
    global _r
    if _r is not None:
        # already built by an earlier call
        return _r
    _r = Resolver.from_file(RESOLV_CONF, query.StandardRecursiveQuery)
    return _r
53
_rd = None
def get_dnssec_resolver():
    '''Return the module-wide resolver built from RESOLV_CONF with
    query.RecursiveDNSSECQuery, creating it on first use.'''
    global _rd
    if _rd is not None:
        # already built by an earlier call
        return _rd
    _rd = Resolver.from_file(RESOLV_CONF, query.RecursiveDNSSECQuery)
    return _rd
60
class DNSAnswer:
    '''The outcome of a DNS query: the full response message, the server it
    is attributed to, and the RRset found for the requested name and type
    (after following CNAME records present in the answer section).'''

    def __init__(self, qname, rdtype, response, server):
        '''Locate the RRset for qname/rdtype (class IN) in the answer section
        of response.

        Raises dns.resolver.NXDOMAIN if the response rcode is NXDOMAIN, and
        dns.resolver.NoAnswer if no matching RRset is found.'''
        self.response = response
        self.server = server
        self.rrset = None

        self._handle_nxdomain(response)

        # Walk the CNAME chain inside the answer section; the iteration count
        # is bounded so that a CNAME loop cannot keep us here forever.
        target = qname
        for _ in range(MAX_CNAME_REDIRECTION):
            try:
                self.rrset = response.find_rrset(response.answer, target, dns.rdataclass.IN, rdtype)
            except KeyError:
                pass
            else:
                # found the RRset that was asked for
                break

            try:
                alias = response.find_rrset(response.answer, target, dns.rdataclass.IN, dns.rdatatype.CNAME)
                target = alias[0].target
            except KeyError:
                # neither the requested type nor a CNAME exists for this name
                break

        self._handle_noanswer()

    def _handle_nxdomain(self, response):
        '''Raise dns.resolver.NXDOMAIN if response carries that rcode.'''
        if response.rcode() != dns.rcode.NXDOMAIN:
            return
        raise dns.resolver.NXDOMAIN()

    def _handle_noanswer(self):
        '''Raise dns.resolver.NoAnswer if no RRset was found.'''
        if self.rrset is not None:
            return
        raise dns.resolver.NoAnswer()
96
class DNSAnswerNoAnswerAllowed(DNSAnswer):
    '''A DNSAnswer whose no-answer check does nothing, so an instance can be
    created even when no matching RRset was found (rrset is then None).'''

    def _handle_noanswer(self):
        # deliberately a no-op: a missing RRset is acceptable here
        return None
103
class Resolver:
    '''A simple stub DNS resolver.

    Queries are sent to a fixed list of servers using the supplied query
    class.  A server that returns an unusable response for a query is not
    tried again for that query.'''

    def __init__(self, servers, query_cls, timeout=1.0, max_attempts=5, lifetime=15.0, shuffle=False, client_ipv4=None, client_ipv6=None, port=53, transport_manager=None, th_factories=None):
        '''Create a resolver.

        servers -- sequence of server addresses to query
        query_cls -- class instantiated for every query that is sent
        timeout -- upper bound on the timeout given to each individual query
        max_attempts -- maximum number of passes through the server list for
            each query tuple, or None for no limit
        lifetime -- overall bound, in seconds, on one query_multiple() call,
            or None for no limit
        shuffle -- if true, the servers are tried in a random order
        client_ipv4, client_ipv6, port -- passed through to query_cls
        transport_manager, th_factories -- passed through when the queries
            are executed

        Raises ValueError if both lifetime and max_attempts are None.'''
        if lifetime is None and max_attempts is None:
            raise ValueError("At least one of lifetime or max_attempts must be specified for a Resolver instance.")

        self._servers = servers
        self._query_cls = query_cls
        self._timeout = timeout
        self._max_attempts = max_attempts
        self._lifetime = lifetime
        self._shuffle = shuffle
        self._client_ipv4 = client_ipv4
        self._client_ipv6 = client_ipv6
        self._port = port
        self._transport_manager = transport_manager
        self._th_factories = th_factories

    @classmethod
    def from_file(cls, resolv_conf, query_cls, **kwargs):
        '''Build a resolver from the "nameserver" lines of the file named by
        resolv_conf.  Any kwargs are passed on to the constructor.

        Addresses that IPAddr rejects with ValueError are skipped.  Raises
        ResolvConfError if the file cannot be opened or yields no servers.'''
        servers = []
        try:
            with io.open(resolv_conf, 'r', encoding='utf-8') as f:
                for line in f:
                    words = line.split()
                    if len(words) > 1 and words[0] == 'nameserver':
                        try:
                            servers.append(IPAddr(words[1]))
                        except ValueError:
                            # best effort: ignore addresses that don't parse
                            pass
        except IOError as e:
            raise ResolvConfError('Unable to open %s: %s' % (resolv_conf, str(e)))
        if not servers:
            raise ResolvConfError('No servers found in %s' % (resolv_conf))
        # instantiate cls (rather than Resolver) so that subclasses get an
        # instance of their own type
        return cls(servers, query_cls, **kwargs)

    def query(self, qname, rdtype, rdclass=dns.rdataclass.IN, accept_first_response=False, continue_on_servfail=True):
        '''Issue a single query and return its (server, response) pair, as
        produced by query_multiple().'''
        return list(self.query_multiple((qname, rdtype, rdclass), accept_first_response=accept_first_response, continue_on_servfail=continue_on_servfail).values())[0]

    def query_for_answer(self, qname, rdtype, rdclass=dns.rdataclass.IN, allow_noanswer=False):
        '''Issue a single query and return the resulting DNSAnswer.

        If query_multiple_for_answer() produced an exception instance for the
        query instead of an answer, that exception is raised.'''
        answer = list(self.query_multiple_for_answer((qname, rdtype, rdclass), allow_noanswer=allow_noanswer).values())[0]
        if isinstance(answer, DNSAnswer):
            return answer
        else:
            raise answer

    def query_multiple_for_answer(self, *query_tuples, **kwargs):
        '''Issue the queries described by the (qname, rdtype, rdclass) tuples
        and return a dict mapping each tuple to either a DNSAnswer or an
        exception instance describing why no answer is available.

        The only keyword argument used is allow_noanswer; when it is true,
        DNSAnswerNoAnswerAllowed is used in place of DNSAnswer.'''
        if kwargs.pop('allow_noanswer', False):
            answer_cls = DNSAnswerNoAnswerAllowed
        else:
            answer_cls = DNSAnswer

        responses = self.query_multiple(*query_tuples, accept_first_response=False, continue_on_servfail=True)

        answers = {}
        for query_tuple, (server, response) in responses.items():
            # no servers were queried
            if response is None:
                answers[query_tuple] = dns.resolver.NoNameservers()
            # response was valid
            elif response.is_complete_response() and response.is_valid_response():
                try:
                    answers[query_tuple] = answer_cls(query_tuple[0], query_tuple[1], response.message, server)
                except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN) as e:
                    answers[query_tuple] = e
            # response was timeout or network error
            elif response.error in (query.RESPONSE_ERROR_TIMEOUT, query.RESPONSE_ERROR_NETWORK_ERROR):
                answers[query_tuple] = dns.exception.Timeout()
            # there was a response, but it was invalid for some reason
            else:
                answers[query_tuple] = dns.resolver.NoNameservers()

        return answers

    def query_multiple(self, *query_tuples, **kwargs):
        '''Issue the queries described by the (qname, rdtype, rdclass) tuples
        and return a dict mapping each tuple to a (server, response) pair.

        Both members of the pair are None if no query could be sent for the
        tuple; response alone is None if the last server tried produced no
        response at all.

        Keyword arguments:
        accept_first_response -- if true, accept the first response message
            received even if it is incomplete or invalid (default False)
        continue_on_servfail -- if false, accept a SERVFAIL response rather
            than moving on to the next server (default True)'''
        valid_servers = {}
        responses = {}
        last_responses = {}
        attempts = {}

        accept_first_response = kwargs.get('accept_first_response', False)
        continue_on_servfail = kwargs.get('continue_on_servfail', True)

        query_tuples = set(query_tuples)
        for query_tuple in query_tuples:
            attempts[query_tuple] = 0
            valid_servers[query_tuple] = set(self._servers)

        if self._shuffle:
            # copy into a list so that shuffling works for any sequence type
            # and never reorders the sequence held by the instance
            servers = list(self._servers)
            random.shuffle(servers)
        else:
            servers = self._servers

        tuples_to_query = query_tuples.difference(last_responses)
        start = time.time()
        while tuples_to_query and (self._lifetime is None or time.time() - start < self._lifetime):
            now = time.time()
            queries = {}
            for query_tuple in tuples_to_query:
                # every server has been ruled out for this tuple, so settle
                # for whatever was received last (if anything)
                if not valid_servers[query_tuple]:
                    last_responses[query_tuple] = responses.get(query_tuple, (None, None))
                    continue

                while query_tuple not in queries:
                    cycle_num, server_index = divmod(attempts[query_tuple], len(servers))
                    # if we've exceeded our maximum attempts, then break out;
                    # max_attempts of None means that only lifetime limits us
                    if self._max_attempts is not None and cycle_num >= self._max_attempts:
                        last_responses[query_tuple] = responses.get(query_tuple, (None, None))
                        break

                    server = servers[server_index]
                    if server in valid_servers[query_tuple]:
                        if self._lifetime is not None:
                            # never wait beyond the end of the overall lifetime
                            timeout = min(self._timeout, max((start + self._lifetime) - now, 0))
                        else:
                            timeout = self._timeout
                        q = self._query_cls(query_tuple[0], query_tuple[1], query_tuple[2], server, None, client_ipv4=self._client_ipv4, client_ipv6=self._client_ipv6, port=self._port, query_timeout=timeout, max_attempts=1)
                        queries[query_tuple] = q

                    attempts[query_tuple] += 1

            query.ExecutableDNSQuery.execute_queries(*list(queries.values()), tm=self._transport_manager, th_factories=self._th_factories)

            for query_tuple, q in queries.items():
                # no response means we didn't even try because we don't have
                # proper connectivity
                if not q.responses:
                    server = list(q.servers)[0]
                    valid_servers[query_tuple].remove(server)
                    if not valid_servers[query_tuple]:
                        last_responses[query_tuple] = server, None
                    continue

                server, client_response = list(q.responses.items())[0]
                client, response = list(client_response.items())[0]
                responses[query_tuple] = (server, response)
                # if we received a complete message with an acceptable rcode,
                # then accept it as the last response
                if response.is_complete_response() and response.is_valid_response():
                    last_responses[query_tuple] = responses[query_tuple]
                # if we received a message that was incomplete (i.e.,
                # truncated), had an invalid rcode, was malformed, or was
                # otherwise invalid, then accept the response (if directed),
                # and invalidate the server
                elif response.message is not None or \
                        response.error not in (query.RESPONSE_ERROR_TIMEOUT, query.RESPONSE_ERROR_NETWORK_ERROR):
                    # accept_first_response is true, then accept the response
                    if accept_first_response:
                        last_responses[query_tuple] = responses[query_tuple]
                    # if the response was SERVFAIL, and we were not directed to
                    # continue, then accept the response
                    elif response.message is not None and \
                            response.message.rcode() == dns.rcode.SERVFAIL and not continue_on_servfail:
                        last_responses[query_tuple] = responses[query_tuple]
                    valid_servers[query_tuple].remove(server)

            tuples_to_query = query_tuples.difference(last_responses)

        # the lifetime ran out with tuples still pending: report the most
        # recent response for each, or (None, None) if none was ever recorded
        for query_tuple in tuples_to_query:
            last_responses[query_tuple] = responses.get(query_tuple, (None, None))

        return last_responses
275
class CacheEntry:
    '''One cached lookup result together with its bookkeeping fields.'''

    def __init__(self, rrset, source, expiration, rcode, soa_rrset):
        # the cached RRset (may be None) and the source it was learned from
        self.rrset, self.source = rrset, source
        # the time after which the entry is no longer valid
        self.expiration = expiration
        # the response code and the SOA RRset (may be None) stored with it
        self.rcode, self.soa_rrset = rcode, soa_rrset
283
class ServFail(Exception):
    '''Raised when iterative resolution cannot be completed.'''
286
287class FullResolver:
288    '''A full iterative DNS resolver, following hints.'''
289
    # Identifiers for where a cached entry was learned from.  cache_put()
    # replaces an existing entry only when the new entry's source value is
    # strictly lower than that of the entry already cached.
    SRC_PRIMARY_ZONE = 0
    SRC_SECONDARY_ZONE = 1
    SRC_AUTH_ANS = 2
    SRC_AUTH_AUTH = 3
    SRC_GLUE_PRIMARY_ZONE = 4
    SRC_GLUE_SECONDARY_ZONE = 5
    SRC_NONAUTH_ANS = 6
    SRC_ADDITIONAL = 7
    # NOTE(review): shares the value 7 with SRC_ADDITIONAL -- confirm that
    # this is intentional
    SRC_NONAUTH_AUTH = 7

    # floor applied to TTLs in cache_put() (before any max_ttl cap)
    MIN_TTL = 60
    # _query() raises ServFail once its level argument exceeds this value
    MAX_CHAIN = 20

    # transport handler factory used when __init__() is given no th_factories
    default_th_factory = transport.DNSQueryTransportHandlerDNSFactory()
304
305    def __init__(self, hints=util.get_root_hints(), query_cls=(query.QuickDNSSECQuery, query.DiagnosticQuery), client_ipv4=None, client_ipv6=None, odd_ports=None, cookie_standin=None, transport_manager=None, th_factories=None, max_ttl=None):
306
307        self._hints = hints
308        self._query_cls = query_cls
309        self._client_ipv4 = client_ipv4
310        self._client_ipv6 = client_ipv6
311        if odd_ports is None:
312            odd_ports = {}
313        self._odd_ports = odd_ports
314        self._transport_manager = transport_manager
315        if th_factories is None:
316            self._th_factories = (self.default_th_factory,)
317        else:
318            self._th_factories = th_factories
319        self.allow_loopback_query = not bool([x for x in self._th_factories if not x.cls.allow_loopback_query])
320        self.allow_private_query = not bool([x for x in self._th_factories if not x.cls.allow_private_query])
321
322        self._max_ttl = max_ttl
323
324        self._cookie_standin = cookie_standin
325        self._cookie_jar = {}
326        self._cache = {}
327        self._expirations = []
328        self._cache_lock = threading.Lock()
329
330    def _allow_server(self, server):
331        if not self.allow_loopback_query and (LOOPBACK_IPV4_RE.search(server) is not None or server == LOOPBACK_IPV6):
332            return False
333        if not self.allow_private_query and (RFC_1918_RE.search(server) is not None or LINK_LOCAL_RE.search(server) is not None or UNIQ_LOCAL_RE.search(server) is not None):
334            return False
335        if ZERO_SLASH8_RE.search(server) is not None:
336            return False
337        return True
338
339    def flush_cache(self):
340        with self._cache_lock:
341            self._cache = {}
342            self._expirations = []
343
344    def expire_cache(self):
345        t = time.time()
346
347        with self._cache_lock:
348            if self._expirations and self._expirations[0][0] > t:
349                return
350
351            future_index = bisect.bisect_right(self._expirations, (t, None))
352            for i in range(future_index):
353                cache_key = self._expirations[i][1]
354                del self._cache[cache_key]
355            self._expirations = self._expirations[future_index:]
356
357    def cache_put(self, name, rdtype, rrset, source, rcode, soa_rrset, ttl):
358        t = time.time()
359
360        if rrset is not None:
361            ttl = max(rrset.ttl, self.MIN_TTL)
362        elif soa_rrset is not None:
363            ttl = max(min(soa_rrset.ttl, soa_rrset[0].minimum), self.MIN_TTL)
364        elif ttl is not None:
365            ttl = max(ttl, self.MIN_TTL)
366        else:
367            ttl = self.MIN_TTL
368
369        if self._max_ttl is not None and ttl > self._max_ttl:
370            ttl = self._max_ttl
371
372        expiration = math.ceil(t) + ttl
373
374        key = (name, rdtype)
375        new_entry = CacheEntry(rrset, source, expiration, rcode, soa_rrset)
376
377        with self._cache_lock:
378            try:
379                old_entry = self._cache[key]
380            except KeyError:
381                pass
382            else:
383                if new_entry.source >= old_entry.source:
384                    return
385
386                # remove the old entry from expirations
387                old_index = bisect.bisect_left(self._expirations, (old_entry.expiration, key))
388                old_key = self._expirations.pop(old_index)[1]
389                assert old_key == key, "Old key doesn't match new key!"
390
391            self._cache[key] = new_entry
392            bisect.insort(self._expirations, (expiration, key))
393
394    def cache_get(self, name, rdtype):
395        try:
396            entry = self._cache[(name, rdtype)]
397        except KeyError:
398            return None
399        else:
400            t = time.time()
401            ttl = max(0, int(entry.expiration - t))
402
403            if entry.rrset is not None:
404                entry.rrset.update_ttl(ttl)
405            if entry.soa_rrset is not None:
406                entry.soa_rrset.update_ttl(ttl)
407
408            return entry
409
410    def cache_dump(self):
411        keys = self._cache.keys()
412        keys.sort()
413
414        t = time.time()
415        for key in keys:
416            entry = self._cache[key]
417
418    def query(self, qname, rdtype, rdclass=dns.rdataclass.IN):
419        msg = dns.message.make_response(dns.message.make_query(qname, rdtype), True)
420        try:
421            l = self._query(qname, rdtype, rdclass, 0, self.SRC_NONAUTH_ANS)
422        except ServFail:
423            msg.set_rcode(dns.rcode.SERVFAIL)
424        else:
425            msg.set_rcode(l[-1])
426            for rrset in l[:-1]:
427                if rrset is not None:
428                    new_rrset = msg.find_rrset(msg.answer, rrset.name, rrset.rdclass, rrset.rdtype, create=True)
429                    new_rrset.update(rrset)
430        return msg, None
431
432    def query_for_answer(self, qname, rdtype, rdclass=dns.rdataclass.IN, allow_noanswer=False):
433        response, server = self.query(qname, rdtype, rdclass)
434        if response.rcode() == dns.rcode.SERVFAIL:
435            raise dns.resolver.NoNameservers()
436        if allow_noanswer:
437            answer_cls = DNSAnswerNoAnswerAllowed
438        else:
439            answer_cls = DNSAnswer
440        return answer_cls(qname, rdtype, response, server)
441
442    def query_multiple_for_answer(self, *query_tuples, **kwargs):
443        allow_noanswer = kwargs.pop('allow_noanswer', False)
444        answers = {}
445        for query_tuple in query_tuples:
446            try:
447                answers[query_tuple] = self.query_for_answer(query_tuple[0], query_tuple[1], query_tuple[2], allow_noanswer=allow_noanswer)
448            except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN, dns.resolver.NoNameservers) as e:
449                answers[query_tuple] = e
450        return answers
451
452    def query_multiple(self, *query_tuples, **kwargs):
453        responses = {}
454        for query_tuple in query_tuples:
455            responses[query_tuple] = self.query(query_tuple[0], query_tuple[1], query_tuple[2])
456        return responses
457
458    def _get_answer(self, qname, rdtype, rdclass, max_source):
459        # first check cache for answer
460        entry = self.cache_get(qname, rdtype)
461        if entry is not None and entry.source <= max_source:
462            return [entry.rrset, entry.rcode]
463
464        # check hints, if allowed
465        if self.SRC_ADDITIONAL <= max_source and (qname, rdtype) in self._hints:
466            return [self._hints[(qname, rdtype)], dns.rcode.NOERROR]
467
468        return None
469
470    def _query(self, qname, rdtype, rdclass, level, max_source, starting_domain=None):
471        self.expire_cache()
472
473        # check for max chain length
474        if level > self.MAX_CHAIN:
475            raise ServFail('SERVFAIL - resolution chain too long')
476
477        ans = self._get_answer(qname, rdtype, rdclass, max_source)
478        if ans:
479            return ans
480
481        # next check cache for alias
482        ans = self._get_answer(qname, dns.rdatatype.CNAME, rdclass, max_source)
483        if ans:
484            return [ans[0]] + self._query(ans[0][0].target, rdtype, rdclass, level + 1, max_source)
485
486        # now check for closest enclosing NS, DNAME, or hint
487        closest_zone = qname
488
489        # when rdtype is DS, start at the parent
490        if rdtype == dns.rdatatype.DS and qname != dns.name.root:
491            closest_zone = qname.parent()
492        elif starting_domain is not None:
493            assert qname.is_subdomain(starting_domain), 'qname must be a subdomain of starting_domain'
494            closest_zone = starting_domain
495
496        ns_names = {}
497
498        # iterative resolution is necessary, so find the closest zone ancestor or DNAME
499        while True:
500            # if we are a proper superdomain, then look for DNAME
501            if closest_zone != qname:
502                entry = self.cache_get(closest_zone, dns.rdatatype.DNAME)
503                if entry is not None and entry.rrset is not None:
504                    cname_rrset = Response.cname_from_dname(qname, entry.rrset)
505                    return [entry.rrset, cname_rrset] + self._query(cname_rrset[0].target, rdtype, rdclass, level + 1, max_source)
506
507            # look for NS records in cache
508            ans = self._get_answer(closest_zone, dns.rdatatype.NS, rdclass, self.SRC_ADDITIONAL)
509            if ans and ans[0] is not None:
510                ns_rrset = ans[0]
511                for ns_rdata in ans[0]:
512                    addrs = set()
513                    for a_rdtype in dns.rdatatype.A, dns.rdatatype.AAAA:
514                        ans1 = self._get_answer(ns_rdata.target, a_rdtype, rdclass, self.SRC_ADDITIONAL)
515                        if ans1 and ans1[0]:
516                            for a_rdata in ans1[0]:
517                                addrs.add(IPAddr(a_rdata.address))
518                    if addrs:
519                        ns_names[ns_rdata.target] = addrs
520                    else:
521                        ns_names[ns_rdata.target] = None
522
523            # if there were NS records associated with the names, then
524            # no need to continue
525            if ns_names:
526                break
527
528            # otherwise, continue upwards until some are found
529            try:
530                closest_zone = closest_zone.parent()
531            except dns.name.NoParent:
532                raise ServFail('SERVFAIL - no NS RRs at root')
533
534        ret = None
535        soa_rrset = None
536        rcode = None
537
538        # iterate, following referrals down the namespace tree
539        while True:
540            bailiwick = ns_rrset.name
541            is_referral = False
542
543            # query names first for which there are addresses
544            ns_names_with_addresses = [n for n in ns_names if ns_names[n] is not None]
545            random.shuffle(ns_names_with_addresses)
546            ns_names_without_addresses = list(set(ns_names).difference(ns_names_with_addresses))
547            random.shuffle(ns_names_without_addresses)
548            all_ns_names = ns_names_with_addresses + ns_names_without_addresses
549            previous_valid_answer = set()
550
551            for query_cls in self._query_cls:
552                # query each server until we get a match
553                for ns_name in all_ns_names:
554                    is_referral = False
555                    if ns_names[ns_name] is None:
556                        # first get the addresses associated with each name
557                        ns_names[ns_name] = set()
558                        for a_rdtype in dns.rdatatype.A, dns.rdatatype.AAAA:
559                            if ns_name.is_subdomain(bailiwick):
560                                if bailiwick == dns.name.root:
561                                    sd = bailiwick
562                                else:
563                                    sd = bailiwick.parent()
564                            else:
565                                sd = None
566                            try:
567                                a_rrset = self._query(ns_name, a_rdtype, dns.rdataclass.IN, level + 1, self.SRC_ADDITIONAL, starting_domain=sd)[-2]
568                            except ServFail:
569                                a_rrset = None
570                            if a_rrset is not None:
571                                for rdata in a_rrset:
572                                    ns_names[ns_name].add(IPAddr(rdata.address))
573
574                    for server in ns_names[ns_name].difference(previous_valid_answer):
575                        # server disallowed by policy
576                        if not self._allow_server(server):
577                            continue
578
579                        q = query_cls(qname, rdtype, rdclass, (server,), bailiwick, self._client_ipv4, self._client_ipv6, self._odd_ports.get((bailiwick, server), 53), cookie_jar=self._cookie_jar, cookie_standin=self._cookie_standin)
580                        q.execute(tm=self._transport_manager, th_factories=self._th_factories)
581                        is_referral = False
582
583                        if not q.responses:
584                            # No network connectivity
585                            continue
586
587                        server1, client_response = list(q.responses.items())[0]
588                        client, response = list(client_response.items())[0]
589
590                        server_cookie = response.get_server_cookie()
591                        if server_cookie is not None:
592                            self._cookie_jar[server1] = server_cookie
593
594                        if not (response.is_valid_response() and response.is_complete_response()):
595                            continue
596
597                        previous_valid_answer.add(server)
598
599                        soa_rrset = None
600                        rcode = response.message.rcode()
601
602                        # response is acceptable
603                        try:
604                            # first check for exact match
605                            ret = [[x for x in response.message.answer if x.name == qname and x.rdtype == rdtype and x.rdclass == rdclass][0]]
606                        except IndexError:
607                            try:
608                                # now look for DNAME
609                                dname_rrset = [x for x in response.message.answer if qname.is_subdomain(x.name) and qname != x.name and x.rdtype == dns.rdatatype.DNAME and x.rdclass == rdclass][0]
610                            except IndexError:
611                                try:
612                                    # now look for CNAME
613                                    cname_rrset = [x for x in response.message.answer if x.name == qname and x.rdtype == dns.rdatatype.CNAME and x.rdclass == rdclass][0]
614                                except IndexError:
615                                    ret = [None]
616                                    # no answer
617                                    try:
618                                        soa_rrset = [x for x in response.message.authority if qname.is_subdomain(x.name) and x.rdtype == dns.rdatatype.SOA][0]
619                                    except IndexError:
620                                        pass
621                                # cache the NS RRset
622                                else:
623                                    cname_rrset = [x for x in response.message.answer if x.name == qname and x.rdtype == dns.rdatatype.CNAME and x.rdclass == rdclass][0]
624                                    ret = [cname_rrset]
625                            else:
626                                # handle DNAME: return the DNAME, CNAME and (recursively) its chain
627                                cname_rrset = Response.cname_from_dname(qname, dname_rrset)
628                                ret = [dname_rrset, cname_rrset]
629
630                        if response.is_referral(qname, rdtype, rdclass, bailiwick):
631                            is_referral = True
632                            a_rrsets = {}
633                            min_ttl = None
634                            ret = None
635
636                            # if response is referral, then we follow it
637                            ns_rrset = [x for x in response.message.authority if qname.is_subdomain(x.name) and x.rdtype == dns.rdatatype.NS][0]
638                            ns_names = response.ns_ip_mapping_from_additional(ns_rrset.name, bailiwick)
639                            for ns_name in ns_names:
640                                if not ns_names[ns_name]:
641                                    ns_names[ns_name] = None
642                                else: # name is in bailiwick
643                                    for a_rdtype in (dns.rdatatype.A, dns.rdatatype.AAAA):
644                                        try:
645                                            a_rrsets[a_rdtype] = response.message.find_rrset(response.message.additional, ns_name, a_rdtype, dns.rdataclass.IN)
646                                        except KeyError:
647                                            pass
648                                        else:
649                                            if min_ttl is None or a_rrsets[a_rdtype].ttl < min_ttl:
650                                                min_ttl = a_rrsets[a_rdtype].ttl
651
652                                    for a_rdtype in (dns.rdatatype.A, dns.rdatatype.AAAA):
653                                        if a_rdtype in a_rrsets:
654                                            a_rrsets[a_rdtype].update_ttl(min_ttl)
655                                            self.cache_put(ns_name, a_rdtype, a_rrsets[a_rdtype], self.SRC_ADDITIONAL, dns.rcode.NOERROR, None, None)
656                                        else:
657                                            self.cache_put(ns_name, a_rdtype, None, self.SRC_ADDITIONAL, dns.rcode.NOERROR, None, min_ttl)
658
659                            if min_ttl is not None:
660                                ns_rrset.update_ttl(min_ttl)
661
662                            # cache the NS RRset
663                            self.cache_put(ns_rrset.name, dns.rdatatype.NS, ns_rrset, self.SRC_NONAUTH_AUTH, rcode, None, None)
664                            break
665
666                        elif response.is_authoritative():
667                            terminal = True
668                            a_rrsets = {}
669                            min_ttl = None
670
671                            # if response is authoritative (and not a referral), then we return it
672                            try:
673                                ns_rrset = [x for x in  response.message.answer + response.message.authority if qname.is_subdomain(x.name) and x.rdtype == dns.rdatatype.NS][0]
674                            except IndexError:
675                                pass
676                            else:
677
678                                ns_names = response.ns_ip_mapping_from_additional(ns_rrset.name, bailiwick)
679                                for ns_name in ns_names:
680                                    if not ns_names[ns_name]:
681                                        ns_names[ns_name] = None
682                                    else: # name is in bailiwick
683                                        for a_rdtype in (dns.rdatatype.A, dns.rdatatype.AAAA):
684                                            try:
685                                                a_rrsets[a_rdtype] = response.message.find_rrset(response.message.additional, ns_name, a_rdtype, dns.rdataclass.IN)
686                                            except KeyError:
687                                                pass
688                                            else:
689                                                if min_ttl is None or a_rrsets[a_rdtype].ttl < min_ttl:
690                                                    min_ttl = a_rrsets[a_rdtype].ttl
691
692                                        for a_rdtype in (dns.rdatatype.A, dns.rdatatype.AAAA):
693                                            if a_rdtype in a_rrsets:
694                                                a_rrsets[a_rdtype].update_ttl(min_ttl)
695                                                self.cache_put(ns_name, a_rdtype, a_rrsets[a_rdtype], self.SRC_ADDITIONAL, dns.rcode.NOERROR, None, None)
696                                            else:
697                                                self.cache_put(ns_name, a_rdtype, None, self.SRC_ADDITIONAL, dns.rcode.NOERROR, None, min_ttl)
698
699                                if min_ttl is not None:
700                                    ns_rrset.update_ttl(min_ttl)
701
702                                self.cache_put(ns_rrset.name, dns.rdatatype.NS, ns_rrset, self.SRC_AUTH_AUTH, rcode, None, None)
703
704                            if ret[-1] == None:
705                                self.cache_put(qname, rdtype, None, self.SRC_AUTH_ANS, rcode, soa_rrset, None)
706
707                            else:
708                                for rrset in ret:
709                                    self.cache_put(rrset.name, rrset.rdtype, rrset, self.SRC_AUTH_ANS, rcode, None, None)
710
711                                if ret[-1].rdtype == dns.rdatatype.CNAME:
712                                    ret += self._query(ret[-1][0].target, rdtype, rdclass, level + 1, self.SRC_NONAUTH_ANS)
713                                    terminal = False
714
715                            if terminal:
716                                ret.append(rcode)
717                            return ret
718
719                    # if referral, then break
720                    if is_referral:
721                        break
722
723                # if referral, then break
724                if is_referral:
725                    break
726
727            # if not referral, then we're done iterating
728            if not is_referral:
729                break
730
731            # if we were only to ask the parent, then we're done
732            if starting_domain is not None:
733                break
734
735            # otherwise continue onward, looking for an authoritative answer
736
737        # return non-authoritative answer
738        if ret is not None:
739            terminal = True
740
741            if ret[-1] == None:
742                self.cache_put(qname, rdtype, None, self.SRC_NONAUTH_ANS, rcode, soa_rrset, None)
743
744            else:
745                for rrset in ret:
746                    self.cache_put(rrset.name, rrset.rdtype, rrset, self.SRC_NONAUTH_ANS, rcode, None, None)
747
748                if ret[-1].rdtype == dns.rdatatype.CNAME:
749                    ret += self._query(ret[-1][0].target, rdtype, rdclass, level + 1, self.SRC_NONAUTH_ANS)
750                    terminal = False
751
752            if terminal:
753                ret.append(rcode)
754            return ret
755
756        raise ServFail('SERVFAIL - no valid responses')
757
class PrivateFullResolver(FullResolver):
    '''A FullResolver whose default_th_factory class attribute is a
    DNSQueryTransportHandlerDNSPrivateFactory instance.'''

    # The factory is instantiated once, when the class body is evaluated, and
    # is therefore shared by every instance of this class.
    # NOTE(review): presumably this replaces the transport handler factory
    # that FullResolver uses by default -- confirm against FullResolver.
    default_th_factory = transport.DNSQueryTransportHandlerDNSPrivateFactory()
760
def main():
    '''Command-line entry point: resolve a single name/type and print the
    result.

    Usage: <prog> <name> <type> [<server>...]

    If one or more server addresses are given, they are queried using
    query.StandardRecursiveQuery; otherwise the standard resolver (built from
    RESOLV_CONF) is used.  The responding server, the response, its wire
    length, and the answer RRset are written to standard output.

    Exits with status 1, after printing a usage message to standard error, if
    an unknown option is supplied or fewer than two positional arguments are
    given.'''

    import sys
    import getopt

    def usage():
        sys.stderr.write('Usage: %s <name> <type> [<server>...]\n' % (sys.argv[0]))
        sys.exit(1)

    try:
        # No options are defined; getopt is used only to reject unknown
        # options and to strip an optional "--" separator.
        _, args = getopt.getopt(sys.argv[1:], '')
    except getopt.error:
        usage()

    if len(args) < 2:
        usage()

    # Take the servers from the parsed positional arguments rather than from
    # sys.argv directly, so that the indexes stay aligned with <name> and
    # <type> even when "--" precedes the positional arguments.
    servers = args[2:]
    if servers:
        r = Resolver([IPAddr(x) for x in servers], query.StandardRecursiveQuery)
    else:
        r = get_standard_resolver()
    a = r.query_for_answer(dns.name.from_text(args[0]), dns.rdatatype.from_text(args[1]))

    print('Response for %s/%s:' % (args[0], args[1]))
    print('   from %s: %s (%d bytes)' % (a.server, repr(a.response), len(a.response.to_wire())))
    print('   answer:\n      %s' % (a.rrset))
787
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
790