1# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
2
3# Copyright (C) 2003-2017 Nominum, Inc.
4#
5# Permission to use, copy, modify, and distribute this software and its
6# documentation for any purpose with or without fee is hereby granted,
7# provided that the above copyright notice and this permission notice
8# appear in all copies.
9#
10# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
11# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
13# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
16# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17
18"""Asynchronous DNS stub resolver."""
19
import time

import dns.asyncbackend
import dns.asyncquery
import dns.exception
import dns.inet
import dns.name
import dns.query
import dns.rdataclass
import dns.rdatatype
import dns.resolver
import dns.reversename

# import some resolver symbols for brevity
from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA
30
31
# Short module-level aliases for the async query functions, so the calls
# in Resolver.resolve() stay readable at their deep indentation.
_udp, _tcp = dns.asyncquery.udp, dns.asyncquery.tcp
35
36
class Resolver(dns.resolver.BaseResolver):
    """Asynchronous DNS stub resolver."""

    async def resolve(self, qname, rdtype=dns.rdatatype.A,
                      rdclass=dns.rdataclass.IN,
                      tcp=False, source=None, raise_on_no_answer=True,
                      source_port=0, lifetime=None, search=None,
                      backend=None):
        """Query nameservers asynchronously to find the answer to the question.

        *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
        the default, then dnspython will use the default backend.

        See :py:func:`dns.resolver.Resolver.resolve()` for the
        documentation of the other parameters, exceptions, and return
        type of this method.
        """

        resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp,
                                              raise_on_no_answer, search)
        if not backend:
            backend = dns.asyncbackend.get_default_backend()
        # A single start time is shared by every query below; each query's
        # timeout is derived from it and *lifetime*.
        start = time.time()
        # NOTE(review): neither loop has a failure exit of its own; giving up
        # relies on the _Resolution object raising (e.g. when candidate
        # names or nameservers are exhausted) -- not visible in this module.
        while True:
            (request, answer) = resolution.next_request()
            # Note we need to say "if answer is not None" and not just
            # "if answer" because answer implements __len__, and python
            # will call that.  We want to return if we have an answer
            # object, including in cases where its length is 0.
            if answer is not None:
                # cache hit!
                return answer
            done = False
            while not done:
                # The resolution object chooses the nameserver, port and
                # transport, so the *tcp* argument is deliberately rebound.
                (nameserver, port, tcp, backoff) = resolution.next_nameserver()
                if backoff:
                    await backend.sleep(backoff)
                timeout = self._compute_timeout(start, lifetime)
                try:
                    if dns.inet.is_address(nameserver):
                        if tcp:
                            response = await _tcp(request, nameserver,
                                                  timeout, port,
                                                  source, source_port,
                                                  backend=backend)
                        else:
                            # Truncation is raised as an error and handed
                            # to query_result(); presumably so the
                            # resolution can retry over TCP -- TODO confirm.
                            response = await _udp(request, nameserver,
                                                  timeout, port,
                                                  source, source_port,
                                                  raise_on_truncation=True,
                                                  backend=backend)
                    else:
                        # We don't do DoH yet.
                        raise NotImplementedError
                except Exception as ex:
                    # Every failure is reported to the resolution object,
                    # which decides whether this query name is finished.
                    (_, done) = resolution.query_result(None, ex)
                    continue
                (answer, done) = resolution.query_result(response, None)
                # As above, compare against None: an empty answer object is
                # still an answer to return.
                if answer is not None:
                    return answer

    async def resolve_address(self, ipaddr, *args, **kwargs):
        """Use an asynchronous resolver to run a reverse query for PTR
        records.

        This utilizes the resolve() method to perform a PTR lookup on the
        specified IP address.

        *ipaddr*, a ``str``, the IPv4 or IPv6 address you want to get
        the PTR record for.

        All other arguments that can be passed to the resolve() function
        except for rdtype and rdclass are also supported by this
        function.  Extra positional arguments are matched to resolve()'s
        parameters that follow *rdclass* (i.e. *tcp*, *source*, ...).

        """

        # rdtype and rdclass are passed positionally so that extra positional
        # arguments line up with resolve()'s later parameters.  Passing them
        # as keywords ahead of *args made the first extra positional argument
        # bind to rdtype as well, raising TypeError.
        return await self.resolve(dns.reversename.from_address(ipaddr),
                                  dns.rdatatype.PTR, dns.rdataclass.IN,
                                  *args, **kwargs)

    # pylint: disable=redefined-outer-name

    async def canonical_name(self, name):
        """Determine the canonical name of *name*.

        The canonical name is the name the resolver uses for queries
        after all CNAME and DNAME renamings have been applied.

        *name*, a ``dns.name.Name`` or ``str``, the query name.

        This method can raise any exception that ``resolve()`` can
        raise, other than ``dns.resolver.NoAnswer`` and
        ``dns.resolver.NXDOMAIN``.

        Returns a ``dns.name.Name``.
        """
        try:
            # raise_on_no_answer=False turns NoAnswer into an (empty)
            # answer, which still carries the canonical name.
            answer = await self.resolve(name, raise_on_no_answer=False)
            canonical_name = answer.canonical_name
        except dns.resolver.NXDOMAIN as e:
            # A nonexistent name still has a canonical name (the end of
            # the CNAME/DNAME chain that was followed).
            canonical_name = e.canonical_name
        return canonical_name
145
146
# Module-wide Resolver used by the convenience functions in this module;
# created lazily by get_default_resolver(), replaced by
# reset_default_resolver().
default_resolver = None


def get_default_resolver():
    """Return the module's default asynchronous resolver.

    If none exists yet, one is created first via
    ``reset_default_resolver()``.
    """
    if default_resolver is not None:
        return default_resolver
    reset_default_resolver()
    return default_resolver
155
156
def reset_default_resolver():
    """Replace the default asynchronous resolver with a newly built one.

    Note that the resolver configuration (e.g. /etc/resolv.conf on UNIX
    systems) will be re-read immediately.
    """
    global default_resolver
    replacement = Resolver()
    default_resolver = replacement
166
167
async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
                  tcp=False, source=None, raise_on_no_answer=True,
                  source_port=0, lifetime=None, search=None, backend=None):
    """Asynchronously resolve *qname* using the default resolver.

    This is a thin convenience wrapper around the module's default
    ``Resolver``; every argument is forwarded unchanged.

    See :py:func:`dns.asyncresolver.Resolver.resolve` for more
    information on the parameters.
    """
    resolver = get_default_resolver()
    return await resolver.resolve(qname, rdtype=rdtype, rdclass=rdclass,
                                  tcp=tcp, source=source,
                                  raise_on_no_answer=raise_on_no_answer,
                                  source_port=source_port,
                                  lifetime=lifetime, search=search,
                                  backend=backend)
184
185
async def resolve_address(ipaddr, *args, **kwargs):
    """Run a reverse (PTR) lookup for *ipaddr* with the default resolver.

    See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more
    information on the parameters.
    """
    resolver = get_default_resolver()
    return await resolver.resolve_address(ipaddr, *args, **kwargs)
194
async def canonical_name(name):
    """Determine the canonical name of *name* with the default resolver.

    See :py:func:`dns.asyncresolver.Resolver.canonical_name` for more
    information on the parameters and possible exceptions.
    """
    resolver = get_default_resolver()
    return await resolver.canonical_name(name)
203
async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
                        resolver=None, backend=None, lifetime=None):
    """Find the name of the zone which contains the specified name.

    Starting at *name*, SOA queries are made for each successively
    shorter ancestor until one returns an SOA owned by the queried name.

    *name*, an absolute ``dns.name.Name`` or a ``str``.  A string is
    parsed relative to the root.

    *rdclass*, an ``int``, the query class.

    *tcp*, a ``bool``.  If ``True``, use TCP to make the queries.

    *resolver*, a ``dns.asyncresolver.Resolver`` or ``None``.  If
    ``None``, the default asynchronous resolver is used.

    *backend*, a ``dns.asyncbackend.Backend`` or ``None``, passed through
    to ``Resolver.resolve()``.

    *lifetime*, a ``float`` or ``None``.  If not ``None``, the total
    number of seconds allowed for all of the SOA queries combined.  Each
    query is given only the time remaining in that budget.

    Raises ``dns.resolver.NotAbsolute`` if *name* is not absolute, and
    ``dns.resolver.NoRootSOA`` if no SOA is found even at the root.
    Other exceptions raised by ``Resolver.resolve()`` propagate.

    See :py:func:`dns.resolver.zone_for_name` for the synchronous
    equivalent.
    """

    if isinstance(name, str):
        name = dns.name.from_text(name, dns.name.root)
    if resolver is None:
        resolver = get_default_resolver()
    if not name.is_absolute():
        raise NotAbsolute(name)
    expiration = None if lifetime is None else time.time() + lifetime
    while True:
        query_kwargs = {'backend': backend}
        if expiration is not None:
            # Only forward lifetime when the caller set one, so resolvers
            # are called exactly as before by default.  A spent budget is
            # passed as 0, which presumably makes resolve() raise its
            # timeout error -- TODO confirm against _compute_timeout().
            query_kwargs['lifetime'] = max(0.0, expiration - time.time())
        try:
            answer = await resolver.resolve(name, dns.rdatatype.SOA, rdclass,
                                            tcp, **query_kwargs)
            if answer.rrset.name == name:
                return name
            # otherwise we were CNAMEd or DNAMEd and need to look higher
        except (NXDOMAIN, NoAnswer):
            pass
        try:
            name = name.parent()
        except dns.name.NoParent:  # pragma: no cover
            raise NoRootSOA
231