1# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
2
3# Copyright (C) 2003-2017 Nominum, Inc.
4#
5# Permission to use, copy, modify, and distribute this software and its
6# documentation for any purpose with or without fee is hereby granted,
7# provided that the above copyright notice and this permission notice
8# appear in all copies.
9#
10# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
11# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
13# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
16# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17
18"""Talk to a DNS server."""
19
20import contextlib
21import errno
22import os
23import select
24import socket
25import struct
26import time
27import base64
28import urllib.parse
29
30import dns.exception
31import dns.inet
32import dns.name
33import dns.message
34import dns.rcode
35import dns.rdataclass
36import dns.rdatatype
37import dns.serial
38
39try:
40    import requests
41    from requests_toolbelt.adapters.source import SourceAddressAdapter
42    from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter
43    have_doh = True
44except ImportError:  # pragma: no cover
45    have_doh = False
46
try:
    import ssl
except ImportError:  # pragma: no cover
    class ssl:    # type: ignore
        """Minimal stand-in used when the ssl module cannot be imported.

        It only provides the names this module touches, so that ``except``
        clauses and ``isinstance`` checks referencing ``ssl.*`` can be
        evaluated without raising ``AttributeError``.
        """

        class WantReadException(Exception):
            pass

        class WantWriteException(Exception):
            pass

        # The I/O helpers catch ssl.SSLWantReadError / ssl.SSLWantWriteError,
        # so the stub must expose those names as well; the older names above
        # are kept for backwards compatibility.
        SSLWantReadError = WantReadException
        SSLWantWriteError = WantWriteException

        class SSLSocket:
            pass

        # Invoked as ssl.create_default_context() on the class itself (no
        # instance), so it must not expect a ``self`` argument.
        @staticmethod
        def create_default_context(*args, **kwargs):
            raise Exception('no ssl support')

# Function used to create a socket.  Can be overridden if needed in special
# situations.
socket_factory = socket.socket
67
# Raised by _matches_destination() when a reply arrives from an address
# other than the one the query was sent to (and ignore_unexpected is false).
# NOTE(review): dns.exception.DNSException presumably uses the class
# docstring as the default message, so the docstring below is kept verbatim.
class UnexpectedSource(dns.exception.DNSException):
    """A DNS query response came from an unexpected address or port."""
70
71
# Raised (without arguments) by the query functions when the received
# message is not a response to the query that was sent, i.e. when
# q.is_response(r) is false.
# NOTE(review): the docstring presumably doubles as the default message via
# dns.exception; it is kept verbatim.
class BadResponse(dns.exception.FormError):
    """A DNS query response does not respond to the question asked."""
74
75
class TransferError(dns.exception.DNSException):
    """A zone transfer response got a non-zero rcode."""

    def __init__(self, rcode):
        # The message uses the textual rcode name for readability; the raw
        # value is kept on the instance so callers can inspect it.
        rcode_text = dns.rcode.to_text(rcode)
        super().__init__(f'Zone transfer error: {rcode_text}')
        self.rcode = rcode
83
84
# Raised by https() when ``have_doh`` is false, i.e. when importing
# ``requests`` or ``requests_toolbelt`` failed at module load time.
# NOTE(review): the docstring (including its line break) presumably serves
# as the default message via dns.exception, so it is kept verbatim.
class NoDOH(dns.exception.DNSException):
    """DNS over HTTPS (DOH) was requested but the requests module is not
    available."""
88
89
def _compute_times(timeout):
    """Return a ``(start, expiration)`` pair for an operation.

    *start* is the current ``time.time()``.  *expiration* is the absolute
    time at which the operation should give up, i.e. ``start + timeout``,
    or ``None`` when *timeout* is ``None`` (wait forever).
    """
    start = time.time()
    expiration = None if timeout is None else start + timeout
    return (start, expiration)
96
97# This module can use either poll() or select() as the "polling backend".
98#
# A backend function takes an fd, bools for readability, writability, and
# error detection, and a relative timeout in seconds (None means wait
# forever).  It returns True if any requested event occurred.
101
def _poll_for(fd, readable, writable, error, timeout):
    """Poll polling backend.

    Wait with ``select.poll()`` until *fd* is readable and/or writable (as
    requested), or reports an error condition when *error* is true.

    *timeout* is a relative timeout in seconds, or ``None`` to wait
    forever.  A timeout of ``0`` performs a non-blocking check.

    Returns ``True`` if any event was reported, ``False`` on timeout.
    """

    event_mask = 0
    if readable:
        event_mask |= select.POLLIN
    if writable:
        event_mask |= select.POLLOUT
    if error:
        event_mask |= select.POLLERR

    pollable = select.poll()
    pollable.register(fd, event_mask)

    # Compare against None explicitly: a zero timeout means "don't block",
    # whereas calling poll() without an argument blocks indefinitely.
    if timeout is not None:
        event_list = pollable.poll(timeout * 1000)  # poll() wants milliseconds
    else:
        event_list = pollable.poll()

    return bool(event_list)
122
123
def _select_for(fd, readable, writable, error, timeout):
    """Select polling backend.

    Wait with ``select.select()`` for the requested events on *fd*.
    *timeout* is a relative timeout in seconds, or ``None`` to wait
    forever.  Returns ``True`` if *fd* appeared in any of the ready lists,
    ``False`` on timeout.
    """

    read_fds = [fd] if readable else []
    write_fds = [fd] if writable else []
    error_fds = [fd] if error else []

    select_args = [read_fds, write_fds, error_fds]
    if timeout is not None:
        select_args.append(timeout)
    ready_read, ready_write, ready_error = select.select(*select_args)

    return bool(ready_read or ready_write or ready_error)
142
143
def _wait_for(fd, readable, writable, error, expiration):
    """Wait for any of the requested events on *fd* using the selected
    polling backend.

    *expiration* is an absolute time (as from ``time.time()``) or ``None``
    to wait forever.  It is converted into a relative timeout on every
    attempt.

    Returns ``True`` immediately when *fd* is an SSL socket that already
    has decrypted data pending and readability was requested.  Otherwise
    returns ``None`` once the backend reports an event.

    Raises ``dns.exception.Timeout`` if *expiration* passes first.
    """

    while True:
        if expiration is None:
            timeout = None
        else:
            timeout = expiration - time.time()
            if timeout <= 0.0:
                raise dns.exception.Timeout
        try:
            # An SSL socket can hold decrypted bytes in its own buffer that
            # poll()/select() on the OS descriptor cannot see.
            if isinstance(fd, ssl.SSLSocket) and readable and fd.pending() > 0:
                return True
            if not _polling_backend(fd, readable, writable, error, timeout):
                raise dns.exception.Timeout
        except OSError as e:  # pragma: no cover
            if e.errno != errno.EINTR:
                raise
            # Interrupted by a signal: nothing is known to be ready, so
            # recompute the remaining time and wait again rather than
            # reporting success.
            continue
        return None
166
167
def _set_polling_backend(fn):
    """Replace the module-wide polling backend used by ``_wait_for()``.

    Internal API. Do not use.  *fn* must accept
    ``(fd, readable, writable, error, timeout)``.
    """

    global _polling_backend
    _polling_backend = fn

# Prefer poll() on platforms that support it because it has no limits on the
# maximum value of a file descriptor (plus it will be more efficient for high
# values); otherwise fall back to select().
_polling_backend = _poll_for if hasattr(select, 'poll') else _select_for
182
183
def _wait_for_readable(s, expiration):
    """Wait until *s* is readable (or errored); time out at *expiration*."""
    _wait_for(s, readable=True, writable=False, error=True,
              expiration=expiration)


def _wait_for_writable(s, expiration):
    """Wait until *s* is writable (or errored); time out at *expiration*."""
    _wait_for(s, readable=False, writable=True, error=True,
              expiration=expiration)
190
191
def _addresses_equal(af, a1, a2):
    """Return True if address tuples *a1* and *a2* denote the same endpoint.

    The textual host (element 0) of each tuple is converted to binary form
    for family *af*, so different spellings of the same address compare
    equal.  All remaining elements (e.g. the port) must match exactly.
    Returns False if either host cannot be parsed.
    """
    try:
        packed_first = dns.inet.inet_pton(af, a1[0])
        packed_second = dns.inet.inet_pton(af, a2[0])
    except dns.exception.SyntaxError:
        return False
    if packed_first != packed_second:
        return False
    return a1[1:] == a2[1:]
202
203
def _matches_destination(af, from_address, destination, ignore_unexpected):
    """Decide whether *from_address* is an acceptable source for a reply
    to a query sent to *destination*.

    Any source is accepted when *destination* is falsy.  Otherwise the
    reply must come from *destination* itself, or, when *destination* is
    a multicast address, from any host whose trailing tuple fields (port,
    etc.) match.

    Returns True for an acceptable source.  For an unacceptable one,
    returns False if *ignore_unexpected* is true, and otherwise raises
    ``UnexpectedSource``.
    """
    if not destination:
        return True
    if _addresses_equal(af, from_address, destination):
        return True
    if (dns.inet.is_multicast(destination[0]) and
            from_address[1:] == destination[1:]):
        return True
    if ignore_unexpected:
        return False
    raise UnexpectedSource(f'got a response from {from_address} instead of '
                           f'{destination}')
217
218
def _destination_and_source(where, port, source, source_port,
                            where_must_be_address=True):
    """Apply defaults and compute the address family plus low-level
    destination and source tuples suitable for connect(), sendto() or
    bind().

    *where* is normally an IP address.  When *where_must_be_address* is
    false it may be something else (e.g. a URL), in which case the
    destination is ``None`` and the family is taken from *source*, if any.

    If only *source_port* is given, the wildcard address of the family is
    used as the source address.

    Returns ``(af, destination, source)``; any element may be ``None``.
    Raises ``ValueError`` if the source and destination families differ,
    or if a source port is given but no family can be determined.
    """
    try:
        af = dns.inet.af_for_address(where)
    except Exception:
        if where_must_be_address:
            raise
        # URLs are ok so eat the exception
        af = None
        destination = None
    else:
        destination = where

    if source:
        source_af = dns.inet.af_for_address(source)
        if not af:
            # The destination family is unknown, so adopt the source's.
            af = source_af
        elif source_af != af:
            raise ValueError('different address families for source '
                             'and destination')

    if source_port and not source:
        # A port without an address: bind to the family's wildcard address.
        source = {socket.AF_INET: '0.0.0.0', socket.AF_INET6: '::'}.get(af)
        if source is None:
            raise ValueError('source_port specified but address family is '
                             'unknown')

    # Convert high-level (address, port) pairs into low-level tuples.
    if destination:
        destination = dns.inet.low_level_address_tuple((destination, port), af)
    if source:
        source = dns.inet.low_level_address_tuple((source, source_port), af)
    return (af, destination, source)
261
def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
    """Create a non-blocking socket of family *af* and kind *type*.

    The socket is bound to *source* when one is given.  When *ssl_context*
    is supplied, the socket is wrapped without performing the handshake
    (callers do that after connecting).  If any step fails, the raw socket
    is closed before the exception propagates.
    """
    sock = socket_factory(af, type)
    try:
        sock.setblocking(False)
        if source is not None:
            sock.bind(source)
        if not ssl_context:
            return sock
        return ssl_context.wrap_socket(sock, do_handshake_on_connect=False,
                                       server_hostname=server_hostname)
    except Exception:
        sock.close()
        raise
276
def https(q, where, timeout=None, port=443, source=None, source_port=0,
          one_rr_per_rrset=False, ignore_trailing=False,
          session=None, path='/dns-query', post=True,
          bootstrap_address=None, verify=True):
    """Return the response obtained after sending a query via DNS-over-HTTPS.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str``, the nameserver IP address or the full URL. If an IP
    address is given, the URL will be constructed using the following schema:
    https://<IP-address>:<port>/<path>.

    *timeout*, a ``float`` or ``None``, the number of seconds to
    wait before the query times out. If ``None``, the default, wait forever.

    *port*, a ``int``, the port to send the query to. The default is 443.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *session*, a ``requests.session.Session``.  If provided, the session to use
    to send the queries.

    *path*, a ``str``. If *where* is an IP address, then *path* will be used to
    construct the URL to send the DNS query to.

    *post*, a ``bool``. If ``True``, the default, POST method will be used.

    *bootstrap_address*, a ``str``, the IP address to use to bypass the
    system's DNS resolver.  Only used when *where* is a URL; the URL's host
    is replaced by this address and sent as the ``Host`` header instead.

    *verify*, a ``bool`` or ``str``, passed through to ``requests``.  If
    ``True``, the default, the server certificate is verified; if ``False``,
    it is not; if a ``str``, it is the path to a CA certificate file or
    directory to verify against.

    Returns a ``dns.message.Message``.  Its ``time`` attribute is the
    elapsed request time in seconds, as a ``float``.

    Raises ``NoDOH`` if the required HTTP libraries are unavailable,
    ``ValueError`` if the server answers with a non-2xx status, and
    ``BadResponse`` if the reply does not answer *q*.
    """

    if not have_doh:
        raise NoDOH  # pragma: no cover

    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(where, port,
                                                        source, source_port,
                                                        False)
    transport_adapter = None
    headers = {
        "accept": "application/dns-message"
    }
    try:
        where_af = dns.inet.af_for_address(where)
        if where_af == socket.AF_INET:
            url = 'https://{}:{}{}'.format(where, port, path)
        elif where_af == socket.AF_INET6:
            url = 'https://[{}]:{}{}'.format(where, port, path)
    except ValueError:
        if bootstrap_address is not None:
            split_url = urllib.parse.urlsplit(where)
            headers['Host'] = split_url.hostname
            # Rewrite only the host inside the URL's authority.  A plain
            # str.replace() over the whole URL could also hit the scheme or
            # path, and would miss an uppercase host entirely because
            # urlsplit() lowercases .hostname.
            host = bootstrap_address
            if ':' in host and not host.startswith('['):
                # IPv6 literals must be bracketed inside a URL.
                host = '[{}]'.format(host)
            if split_url.port is not None:
                host = '{}:{}'.format(host, split_url.port)
            userinfo, at, _ = split_url.netloc.rpartition('@')
            url = urllib.parse.urlunsplit(
                split_url._replace(netloc=userinfo + at + host))
            transport_adapter = HostHeaderSSLAdapter()
        else:
            url = where
    if source is not None:
        # set source port and source address
        # NOTE(review): this replaces any HostHeaderSSLAdapter chosen above,
        # so *bootstrap_address* and *source* cannot currently be combined.
        transport_adapter = SourceAddressAdapter(source)

    with contextlib.ExitStack() as stack:
        if not session:
            session = stack.enter_context(requests.sessions.Session())

        if transport_adapter:
            # NOTE(review): this mounts the adapter on a caller-supplied
            # session too, which persists after this call returns.
            session.mount(url, transport_adapter)

        # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
        # GET and POST examples
        if post:
            headers.update({
                "content-type": "application/dns-message",
                "content-length": str(len(wire))
            })
            response = session.post(url, headers=headers, data=wire,
                                    timeout=timeout, verify=verify)
        else:
            wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
            response = session.get(url, headers=headers,
                                   timeout=timeout, verify=verify,
                                   params={"dns": wire})

    # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
    # status codes
    if response.status_code < 200 or response.status_code > 299:
        raise ValueError('{} responded with status code {}'
                         '\nResponse body: {}'.format(where,
                                                      response.status_code,
                                                      response.content))
    r = dns.message.from_wire(response.content,
                              keyring=q.keyring,
                              request_mac=q.request_mac,
                              one_rr_per_rrset=one_rr_per_rrset,
                              ignore_trailing=ignore_trailing)
    # requests reports the elapsed time as a datetime.timedelta; convert it
    # to float seconds so r.time matches udp(), tcp() and tls().
    r.time = response.elapsed.total_seconds()
    if not q.is_response(r):
        raise BadResponse
    return r
389
def send_udp(sock, what, destination, expiration=None):
    """Send a DNS message to the specified UDP socket.

    *sock*, a ``socket``.

    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
    A message object is rendered to wire format first.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where to send the query.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of the number of bytes sent and the
    time at which the send was issued.
    """

    wire = what.to_wire() if isinstance(what, dns.message.Message) else what
    _wait_for_writable(sock, expiration)
    sent_at = time.time()
    sent_count = sock.sendto(wire, destination)
    return (sent_count, sent_at)
413
414
def receive_udp(sock, destination=None, expiration=None,
                ignore_unexpected=False, one_rr_per_rrset=False,
                keyring=None, request_mac=b'', ignore_trailing=False,
                raise_on_truncation=False):
    """Read a DNS message from a UDP socket.

    *sock*, a ``socket``.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where the message is expected to arrive from.
    When receiving a response, this is where the associated query was sent.
    Datagrams from other sources are skipped or rejected according to
    *ignore_unexpected*.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    *ignore_unexpected*, a ``bool``.  If ``True``, silently discard
    datagrams from unexpected sources; otherwise raise ``UnexpectedSource``.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *request_mac*, a ``bytes``, the MAC of the request (for TSIG).

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
    the TC bit is set.

    Raises if the message is malformed, if network errors occur, or if
    there is a timeout.

    If *destination* is truthy, returns a ``(dns.message.Message, float)``
    tuple of the received message and the receive time.  Otherwise returns
    a ``(dns.message.Message, float, tuple)`` tuple that also carries the
    address the message arrived from.
    """

    while True:
        _wait_for_readable(sock, expiration)
        wire, from_address = sock.recvfrom(65535)
        if _matches_destination(sock.family, from_address, destination,
                                ignore_unexpected):
            break
    received_time = time.time()
    message = dns.message.from_wire(wire,
                                    keyring=keyring,
                                    request_mac=request_mac,
                                    one_rr_per_rrset=one_rr_per_rrset,
                                    ignore_trailing=ignore_trailing,
                                    raise_on_truncation=raise_on_truncation)
    if not destination:
        return (message, received_time, from_address)
    return (message, received_time)
476
def udp(q, where, timeout=None, port=53, source=None, source_port=0,
        ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
        raise_on_truncation=False, sock=None):
    """Return the response obtained after sending a query via UDP.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before
    the query times out.  If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to.  The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
    unexpected sources.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
    the TC bit is set.

    *sock*, a ``socket.socket``, or ``None``, the socket to use for the
    query.  If ``None``, the default, a socket is created (and closed when
    the query finishes).  Note that if a socket is provided, it must be a
    nonblocking datagram socket, and the *source* and *source_port* are
    ignored.

    Returns a ``dns.message.Message`` whose ``time`` attribute is the
    elapsed time in seconds.  Raises ``BadResponse`` if the reply does not
    answer *q*.
    """

    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(where, port,
                                                        source, source_port)
    (begin_time, expiration) = _compute_times(timeout)
    with contextlib.ExitStack() as stack:
        if not sock:
            # Only a socket we create here is registered for cleanup.
            sock = stack.enter_context(
                _make_socket(af, socket.SOCK_DGRAM, source))
        send_udp(sock, wire, destination, expiration)
        (response, received_time) = receive_udp(
            sock, destination, expiration,
            ignore_unexpected=ignore_unexpected,
            one_rr_per_rrset=one_rr_per_rrset,
            keyring=q.keyring,
            request_mac=q.mac,
            ignore_trailing=ignore_trailing,
            raise_on_truncation=raise_on_truncation)
        response.time = received_time - begin_time
        if not q.is_response(response):
            raise BadResponse
        return response
536
def udp_with_fallback(q, where, timeout=None, port=53, source=None,
                      source_port=0, ignore_unexpected=False,
                      one_rr_per_rrset=False, ignore_trailing=False,
                      udp_sock=None, tcp_sock=None):
    """Return the response to the query, trying UDP first and falling back
    to TCP if UDP results in a truncated response.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before
    the query times out.  If ``None``, the default, wait forever.  The
    timeout applies separately to the UDP attempt and to any TCP fallback.

    *port*, an ``int``, the port to send the message to.  The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
    unexpected sources (UDP only).

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
    UDP query.  If ``None``, the default, a socket is created.  Note that
    if a socket is provided, it must be a nonblocking datagram socket,
    and the *source* and *source_port* are ignored for the UDP query.

    *tcp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
    TCP query.  If ``None``, the default, a socket is created.  Note that
    if a socket is provided, it must be a nonblocking connected stream
    socket, and *where*, *source* and *source_port* are ignored for the TCP
    query.

    Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
    if and only if TCP was used.
    """
    try:
        response = udp(q, where, timeout=timeout, port=port, source=source,
                       source_port=source_port,
                       ignore_unexpected=ignore_unexpected,
                       one_rr_per_rrset=one_rr_per_rrset,
                       ignore_trailing=ignore_trailing,
                       raise_on_truncation=True, sock=udp_sock)
    except dns.message.Truncated:
        # The UDP answer had TC set; retry the whole query over TCP.
        response = tcp(q, where, timeout=timeout, port=port, source=source,
                       source_port=source_port,
                       one_rr_per_rrset=one_rr_per_rrset,
                       ignore_trailing=ignore_trailing, sock=tcp_sock)
        return (response, True)
    return (response, False)
592
def _net_read(sock, count, expiration):
    """Read exactly *count* bytes from *sock* and return them as ``bytes``.

    Keeps reading until the desired amount has arrived.  Raises
    ``EOFError`` if the peer closes the connection first.  A Timeout
    exception will be raised if the operation is not completed by the
    expiration time (an absolute time, or ``None`` for no limit).
    """
    buf = bytearray()
    while count > 0:
        _wait_for_readable(sock, expiration)
        try:
            chunk = sock.recv(count)
        except ssl.SSLWantReadError:  # pragma: no cover
            continue
        except ssl.SSLWantWriteError:  # pragma: no cover
            _wait_for_writable(sock, expiration)
            continue
        if chunk == b'':
            raise EOFError
        count -= len(chunk)
        # Append into a bytearray: "s = s + chunk" on bytes copies the
        # whole accumulated buffer on every short read (quadratic).
        buf += chunk
    return bytes(buf)
614
615
def _net_write(sock, data, expiration):
    """Write all of *data* to *sock*, retrying after partial sends.

    SSL sockets that need to read (e.g. during renegotiation) are waited on
    for readability before retrying.  A Timeout exception will be raised if
    the operation is not completed by the expiration time.
    """
    total = len(data)
    offset = 0
    while offset < total:
        _wait_for_writable(sock, expiration)
        try:
            offset += sock.send(data[offset:])
        except ssl.SSLWantReadError:  # pragma: no cover
            _wait_for_readable(sock, expiration)
        except ssl.SSLWantWriteError:  # pragma: no cover
            pass
632
633
def send_tcp(sock, what, expiration=None):
    """Send a DNS message to the specified TCP socket.

    *sock*, a ``socket``.

    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
    A message object is rendered to wire format first.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of bytes sent (including the 2-byte
    length prefix) and the sent time.
    """

    if isinstance(what, dns.message.Message):
        what = what.to_wire()
    # DNS over TCP prefixes each message with its 16-bit big-endian length.
    # Copying into one buffer is inefficient, but it avoids writev() or a
    # short write of just the prefix being pushed onto the net.
    framed = struct.pack("!H", len(what)) + what
    _wait_for_writable(sock, expiration)
    sent_time = time.time()
    _net_write(sock, framed, expiration)
    return (len(framed), sent_time)
659
def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
                keyring=None, request_mac=b'', ignore_trailing=False):
    """Read a DNS message from a TCP socket.

    *sock*, a ``socket``.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *request_mac*, a ``bytes``, the MAC of the request (for TSIG).

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    Raises if the message is malformed, if network errors occur, or if
    there is a timeout.

    Returns a ``(dns.message.Message, float)`` tuple of the received message
    and the received time.
    """

    # Each message is preceded by a 16-bit big-endian length.
    (length,) = struct.unpack("!H", _net_read(sock, 2, expiration))
    wire = _net_read(sock, length, expiration)
    received_time = time.time()
    message = dns.message.from_wire(wire,
                                    keyring=keyring,
                                    request_mac=request_mac,
                                    one_rr_per_rrset=one_rr_per_rrset,
                                    ignore_trailing=ignore_trailing)
    return (message, received_time)
695
def _connect(s, address, expiration):
    """Connect the non-blocking socket *s* to *address*.

    If the connection is still in progress, wait for writability until
    *expiration* and then read the final status from ``SO_ERROR``.
    Raises ``OSError`` if the connection attempt failed.
    """
    err = s.connect_ex(address)
    if err in (errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY):
        _wait_for_writable(s, expiration)
        err = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
    if err:
        raise OSError(err, os.strerror(err))
705
706
def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
        one_rr_per_rrset=False, ignore_trailing=False, sock=None):
    """Return the response obtained after sending a query via TCP.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before
    the query times out.  If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to.  The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    *sock*, a ``socket.socket``, or ``None``, the socket to use for the
    query.  If ``None``, the default, a socket is created, connected, and
    closed when the query finishes.  Note that if a socket is provided, it
    must be a nonblocking connected stream socket, and *where*, *port*,
    *source* and *source_port* are ignored.

    Returns a ``dns.message.Message`` whose ``time`` attribute is the
    elapsed time in seconds.  Raises ``BadResponse`` if the reply does not
    answer *q*.
    """

    wire = q.to_wire()
    (begin_time, expiration) = _compute_times(timeout)
    with contextlib.ExitStack() as stack:
        if not sock:
            (af, destination, source) = _destination_and_source(where, port,
                                                                source,
                                                                source_port)
            s = stack.enter_context(_make_socket(af, socket.SOCK_STREAM,
                                                 source))
            _connect(s, destination, expiration)
        else:
            # getpeername() raises for an unconnected socket.  An
            # unconnected socket never becomes writable, so without this
            # check the polling in send_tcp() would time out or hang.
            sock.getpeername()
            s = sock
        send_tcp(s, wire, expiration)
        (response, received_time) = receive_tcp(s, expiration,
                                                one_rr_per_rrset, q.keyring,
                                                q.mac, ignore_trailing)
        response.time = received_time - begin_time
        if not q.is_response(response):
            raise BadResponse
        return response
765
766
def _tls_handshake(s, expiration):
    """Drive the TLS handshake on the non-blocking SSL socket *s* to
    completion.

    Whenever the SSL layer asks to read or write, wait for the socket to
    become readable or writable (until *expiration*) and retry.
    """
    finished = False
    while not finished:
        try:
            s.do_handshake()
        except ssl.SSLWantReadError:
            _wait_for_readable(s, expiration)
        except ssl.SSLWantWriteError:  # pragma: no cover
            _wait_for_writable(s, expiration)
        else:
            finished = True
776
777
def tls(q, where, timeout=None, port=853, source=None, source_port=0,
        one_rr_per_rrset=False, ignore_trailing=False, sock=None,
        ssl_context=None, server_hostname=None):
    """Return the response obtained after sending a query via TLS.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before
    the query times out.  If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to.  The default is 853.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    *sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for
    the query.  If ``None``, the default, a socket is created.  Note
    that if a socket is provided, it must be a nonblocking connected
    SSL stream socket, and *where*, *port*, *source*, *source_port*,
    and *ssl_context* are ignored.

    *ssl_context*, an ``ssl.SSLContext``, the context to use when
    establishing a TLS connection.  If ``None``, the default, one is created
    with the default configuration.

    *server_hostname*, a ``str`` containing the server's hostname.  The
    default is ``None``, which means that no hostname is known, and if an
    SSL context is created, hostname checking will be disabled.

    Returns a ``dns.message.Message`` whose ``time`` attribute is the
    elapsed time in seconds.
    """

    if sock:
        # An already-established SSL socket needs no special TLS handling;
        # just run the query over it as plain TCP framing.
        return tcp(q, where, timeout=timeout, port=port, source=source,
                   source_port=source_port,
                   one_rr_per_rrset=one_rr_per_rrset,
                   ignore_trailing=ignore_trailing, sock=sock)

    wire = q.to_wire()
    (begin_time, expiration) = _compute_times(timeout)
    (af, destination, source) = _destination_and_source(where, port,
                                                        source, source_port)
    if ssl_context is None:
        ssl_context = ssl.create_default_context()
        if server_hostname is None:
            # No name to check the certificate against.
            ssl_context.check_hostname = False

    with _make_socket(af, socket.SOCK_STREAM, source, ssl_context=ssl_context,
                      server_hostname=server_hostname) as s:
        _connect(s, destination, expiration)
        _tls_handshake(s, expiration)
        send_tcp(s, wire, expiration)
        (response, received_time) = receive_tcp(s, expiration,
                                                one_rr_per_rrset, q.keyring,
                                                q.mac, ignore_trailing)
        response.time = received_time - begin_time
        if not q.is_response(response):
            raise BadResponse
        return response
850
851
def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
        timeout=None, port=53, keyring=None, keyname=None, relativize=True,
        lifetime=None, source=None, source_port=0, serial=0,
        use_udp=False, keyalgorithm=dns.tsig.default_algorithm):
    """Return a generator for the responses to a zone transfer.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *zone*, a ``dns.name.Name`` or ``str``, the name of the zone to transfer.

    *rdtype*, an ``int`` or ``str``, the type of zone transfer.  The
    default is ``dns.rdatatype.AXFR``.  ``dns.rdatatype.IXFR`` can be
    used to do an incremental transfer instead.

    *rdclass*, an ``int`` or ``str``, the class of the zone transfer.
    The default is ``dns.rdataclass.IN``.  For IXFR, the SOA placed in
    the query's authority section uses this class as well.

    *timeout*, a ``float``, the number of seconds to wait for each
    response message (over either TCP or UDP).  If ``None``, the default,
    wait forever.  The per-message wait is never allowed to extend past
    the overall *lifetime* deadline.

    *port*, an ``int``, the port to send the message to.  The default is 53.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *keyname*, a ``dns.name.Name`` or ``str``, the name of the TSIG
    key to use.

    *relativize*, a ``bool``.  If ``True``, all names in the zone will be
    relativized to the zone origin.  It is essential that the
    relativize setting matches the one specified to
    ``dns.zone.from_xfr()`` if using this generator to make a zone.

    *lifetime*, a ``float``, the total number of seconds to spend
    doing the transfer.  If ``None``, the default, then there is no
    limit on the time the transfer may take.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *serial*, an ``int``, the SOA serial number to use as the base for
    an IXFR diff sequence (only meaningful if *rdtype* is
    ``dns.rdatatype.IXFR``).

    *use_udp*, a ``bool``.  If ``True``, use UDP (only meaningful for IXFR).

    *keyalgorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use.

    Raises ``ValueError`` if *use_udp* is ``True`` and *rdtype* is not
    IXFR.  The generator raises ``TransferError`` if a response has an
    rcode other than NOERROR, and ``dns.exception.FormError`` if the
    response stream is malformed (no leading SOA for the zone, IXFR base
    serial mismatch, answers after the final SOA, or a missing TSIG when
    a keyring was supplied).  Errors from the underlying network helpers
    (e.g. timeouts) also propagate from the generator.

    Returns a generator of ``dns.message.Message`` objects.
    """

    if isinstance(zone, str):
        zone = dns.name.from_text(zone)
    rdtype = dns.rdatatype.RdataType.make(rdtype)
    q = dns.message.make_query(zone, rdtype, rdclass)
    if rdtype == dns.rdatatype.IXFR:
        # The IXFR request carries the client's current SOA in the
        # authority section; only its serial is significant.  Use the
        # same class as the question rather than assuming IN.
        rrset = dns.rrset.from_text(zone, 0, rdclass, 'SOA',
                                    '. . %u 0 0 0 0' % serial)
        q.authority.append(rrset)
    if keyring is not None:
        q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(where, port,
                                                        source, source_port)
    if use_udp and rdtype != dns.rdatatype.IXFR:
        raise ValueError('cannot do a UDP AXFR')
    sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
    with _make_socket(af, sock_type, source) as s:
        # Deadline for the whole transfer (None means unlimited).
        (_, expiration) = _compute_times(lifetime)
        _connect(s, destination, expiration)
        if use_udp:
            _wait_for_writable(s, expiration)
            s.send(wire)
        else:
            # DNS over TCP prefixes each message with a 2-byte length.
            tcpmsg = struct.pack("!H", len(wire)) + wire
            _net_write(s, tcpmsg, expiration)
        done = False
        delete_mode = True
        expecting_SOA = False
        soa_rrset = None
        if relativize:
            origin = zone
            oname = dns.name.empty
        else:
            origin = None
            oname = zone
        tsig_ctx = None
        while not done:
            # Per-message deadline, clamped so it never exceeds the
            # overall transfer deadline.
            (_, mexpiration) = _compute_times(timeout)
            if mexpiration is None or \
               (expiration is not None and mexpiration > expiration):
                mexpiration = expiration
            if use_udp:
                # Use the per-message deadline here too; waiting on the
                # overall deadline would ignore *timeout* for UDP.
                _wait_for_readable(s, mexpiration)
                (wire, _) = s.recvfrom(65535)
            else:
                ldata = _net_read(s, 2, mexpiration)
                (msg_len,) = struct.unpack("!H", ldata)
                wire = _net_read(s, msg_len, mexpiration)
            # Evaluated per message because rdtype may be downgraded to
            # AXFR below if the server answers an IXFR with a full zone.
            is_ixfr = (rdtype == dns.rdatatype.IXFR)
            r = dns.message.from_wire(wire, keyring=q.keyring,
                                      request_mac=q.mac, xfr=True,
                                      origin=origin, tsig_ctx=tsig_ctx,
                                      multi=True, one_rr_per_rrset=is_ixfr)
            rcode = r.rcode()
            if rcode != dns.rcode.NOERROR:
                raise TransferError(rcode)
            tsig_ctx = r.tsig_ctx
            answer_index = 0
            if soa_rrset is None:
                # The first message must begin with the zone's SOA.
                if not r.answer or r.answer[0].name != oname:
                    raise dns.exception.FormError(
                        "No answer or RRset not for qname")
                rrset = r.answer[0]
                if rrset.rdtype != dns.rdatatype.SOA:
                    raise dns.exception.FormError("first RRset is not an SOA")
                answer_index = 1
                soa_rrset = rrset.copy()
                if rdtype == dns.rdatatype.IXFR:
                    if dns.serial.Serial(soa_rrset[0].serial) <= serial:
                        #
                        # We're already up-to-date.
                        #
                        done = True
                    else:
                        expecting_SOA = True
            #
            # Process SOAs in the answer section (other than the initial
            # SOA in the first message).
            #
            for rrset in r.answer[answer_index:]:
                if done:
                    raise dns.exception.FormError("answers after final SOA")
                if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
                    if expecting_SOA:
                        if rrset[0].serial != serial:
                            raise dns.exception.FormError(
                                "IXFR base serial mismatch")
                        expecting_SOA = False
                    elif rdtype == dns.rdatatype.IXFR:
                        # Each SOA in an IXFR diff sequence toggles between
                        # the "deleted" and "added" halves of a diff.
                        delete_mode = not delete_mode
                    #
                    # If this SOA RRset is equal to the first we saw then we're
                    # finished. If this is an IXFR we also check that we're
                    # seeing the record in the expected part of the response.
                    #
                    if rrset == soa_rrset and \
                            (rdtype == dns.rdatatype.AXFR or
                             (rdtype == dns.rdatatype.IXFR and delete_mode)):
                        done = True
                elif expecting_SOA:
                    #
                    # We made an IXFR request and are expecting another
                    # SOA RR, but saw something else, so this must be an
                    # AXFR response.
                    #
                    rdtype = dns.rdatatype.AXFR
                    expecting_SOA = False
            if done and q.keyring and not r.had_tsig:
                raise dns.exception.FormError("missing TSIG")
            yield r
1019