1# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
2
3# Copyright (C) 2003-2017 Nominum, Inc.
4#
5# Permission to use, copy, modify, and distribute this software and its
6# documentation for any purpose with or without fee is hereby granted,
7# provided that the above copyright notice and this permission notice
8# appear in all copies.
9#
10# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
11# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
13# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
16# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17
18"""Talk to a DNS server."""
19
20from __future__ import generators
21
22import errno
23import select
24import socket
25import struct
26import sys
27import time
28
29import dns.exception
30import dns.inet
31import dns.name
32import dns.message
33import dns.rcode
34import dns.rdataclass
35import dns.rdatatype
36from ._compat import long, string_types, PY3
37
# The exception type that select() and poll() raise on failure: the builtin
# OSError on Python 3, the module-specific select.error on Python 2.
select_error = OSError if PY3 else select.error

# Factory used to create every socket in this module; it is called as
# socket_factory(family, type, proto).  It may be rebound when special
# situations require a different way of creating sockets.
socket_factory = socket.socket
46
class UnexpectedSource(dns.exception.DNSException):
    """A DNS query response came from an unexpected address or port."""
    # Raised by receive_udp() when a datagram arrives from an address other
    # than the query's destination and ignore_unexpected is False.
    # NOTE(review): the docstring text is kept verbatim on purpose --
    # DNSException presumably uses __doc__ as the default message; confirm
    # in dns.exception before editing it.
49
50
class BadResponse(dns.exception.FormError):
    """A DNS query response does not respond to the question asked."""
    # Raised by udp() and tcp() when q.is_response() rejects the message
    # that was received.
    # NOTE(review): the docstring text is kept verbatim on purpose --
    # DNSException presumably uses __doc__ as the default message; confirm
    # in dns.exception before editing it.
53
54
class TransferError(dns.exception.DNSException):
    """A zone transfer response got a non-zero rcode."""

    def __init__(self, rcode):
        # Render the rcode as text for the human-readable message, and keep
        # the raw value on the instance so handlers can inspect it.
        rcode_text = dns.rcode.to_text(rcode)
        super(TransferError, self).__init__(
            'Zone transfer error: %s' % rcode_text)
        self.rcode = rcode
62
63
64def _compute_expiration(timeout):
65    if timeout is None:
66        return None
67    else:
68        return time.time() + timeout
69
70# This module can use either poll() or select() as the "polling backend".
71#
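# A backend function takes an fd, bools for readability, writability, and
# error detection, and a relative timeout in seconds (``None`` means wait
# forever).  It returns True if any of the requested events is ready.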
74
75def _poll_for(fd, readable, writable, error, timeout):
76    """Poll polling backend."""
77
78    event_mask = 0
79    if readable:
80        event_mask |= select.POLLIN
81    if writable:
82        event_mask |= select.POLLOUT
83    if error:
84        event_mask |= select.POLLERR
85
86    pollable = select.poll()
87    pollable.register(fd, event_mask)
88
89    if timeout:
90        event_list = pollable.poll(long(timeout * 1000))
91    else:
92        event_list = pollable.poll()
93
94    return bool(event_list)
95
96
97def _select_for(fd, readable, writable, error, timeout):
98    """Select polling backend."""
99
100    rset, wset, xset = [], [], []
101
102    if readable:
103        rset = [fd]
104    if writable:
105        wset = [fd]
106    if error:
107        xset = [fd]
108
109    if timeout is None:
110        (rcount, wcount, xcount) = select.select(rset, wset, xset)
111    else:
112        (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout)
113
114    return bool((rcount or wcount or xcount))
115
116
def _wait_for(fd, readable, writable, error, expiration):
    """Wait until *fd* is ready for any of the requested events.

    Uses the currently selected polling backend.  The absolute *expiration*
    (a ``time.time()`` value, or ``None`` for no limit) is converted into a
    relative timeout before every backend call, so a wait interrupted by a
    signal resumes with only the remaining time.

    Raises ``dns.exception.Timeout`` if *expiration* passes before *fd*
    becomes ready; other polling errors are re-raised unchanged.
    """

    while True:
        if expiration is None:
            timeout = None
        else:
            timeout = expiration - time.time()
            if timeout <= 0.0:
                raise dns.exception.Timeout
        try:
            ready = _polling_backend(fd, readable, writable, error, timeout)
        except select_error as e:
            # A signal interrupted the wait: retry instead of reporting the
            # fd as ready (the caller would then hit EAGAIN on recv/send).
            if e.args and e.args[0] == errno.EINTR:
                continue
            # Bare raise keeps the original traceback on Python 2 as well.
            raise
        if not ready:
            raise dns.exception.Timeout
        return
137
138
def _set_polling_backend(fn):
    """Install *fn* as the module-wide polling backend.

    Internal API. Do not use.  *fn* must accept the same arguments as
    ``_poll_for`` and ``_select_for``.
    """
    global _polling_backend
    _polling_backend = fn


# Prefer poll() on platforms that support it because it has no limits on
# the maximum value of a file descriptor (plus it will be more efficient
# for high values); otherwise fall back to select().
_set_polling_backend(_poll_for if hasattr(select, 'poll') else _select_for)
153
154
def _wait_for_readable(s, expiration):
    """Block until *s* is readable or reports an error.

    Raises ``dns.exception.Timeout`` once *expiration* has passed.
    """
    _wait_for(s, readable=True, writable=False, error=True,
              expiration=expiration)
157
158
def _wait_for_writable(s, expiration):
    """Block until *s* is writable or reports an error.

    Raises ``dns.exception.Timeout`` once *expiration* has passed.
    """
    _wait_for(s, readable=False, writable=True, error=True,
              expiration=expiration)
161
162
def _addresses_equal(af, a1, a2):
    """Return ``True`` if socket-address tuples *a1* and *a2* match.

    The host parts (element 0) are compared in binary form via
    ``dns.inet.inet_pton`` so that different textual spellings of the same
    address compare equal; all remaining elements (the port, plus flow info
    and scope id for IPv6) must be identical.  If either host fails to
    parse (``dns.exception.SyntaxError``), the result is ``False``.
    """
    try:
        same_host = (dns.inet.inet_pton(af, a1[0]) ==
                     dns.inet.inet_pton(af, a2[0]))
    except dns.exception.SyntaxError:
        return False
    return same_host and a1[1:] == a2[1:]
173
174
def _destination_and_source(af, where, port, source, source_port):
    """Apply defaults and compute destination and source address tuples.

    The tuples are suitable for use in ``connect()``, ``sendto()``, or
    ``bind()``.  When *af* is ``None`` it is inferred from *where*, falling
    back to AF_INET if inference fails.  The returned *source* stays
    ``None`` unless a source address or a non-zero *source_port* was
    given; a missing source address becomes the family's wildcard address.

    Returns an ``(af, destination, source)`` tuple.

    Raises ``ValueError`` if *af* is neither AF_INET nor AF_INET6.
    """
    if af is None:
        try:
            af = dns.inet.af_for_address(where)
        except Exception:
            # *where* could not be classified; assume IPv4.
            af = dns.inet.AF_INET
    if af == dns.inet.AF_INET:
        destination = (where, port)
        if source is not None or source_port != 0:
            if source is None:
                source = '0.0.0.0'
            source = (source, source_port)
    elif af == dns.inet.AF_INET6:
        destination = (where, port, 0, 0)
        if source is not None or source_port != 0:
            if source is None:
                source = '::'
            source = (source, source_port, 0, 0)
    else:
        # Without this branch 'destination' would be unbound below and the
        # caller would see an unhelpful UnboundLocalError.
        raise ValueError('unsupported address family %r' % (af,))
    return (af, destination, source)
196
197
def send_udp(sock, what, destination, expiration=None):
    """Send a DNS message to the specified UDP socket.

    *sock*, a ``socket``.

    *what*, a ``binary`` or ``dns.message.Message``, the message to send;
    a message object is converted to wire format first.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where to send the query.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
    """
    if isinstance(what, dns.message.Message):
        payload = what.to_wire()
    else:
        payload = what
    _wait_for_writable(sock, expiration)
    # The send time is taken once the socket is writable, right before the
    # datagram is handed to the kernel.
    started_at = time.time()
    byte_count = sock.sendto(payload, destination)
    return (byte_count, started_at)
221
222
def receive_udp(sock, destination, expiration=None,
                ignore_unexpected=False, one_rr_per_rrset=False,
                keyring=None, request_mac=b'', ignore_trailing=False):
    """Read a DNS message from a UDP socket.

    *sock*, a ``socket``.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where the associated query was sent.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
    unexpected sources.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *request_mac*, a ``binary``, the MAC of the request (for TSIG).

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    Raises if the message is malformed, if network errors occur, or if
    there is a timeout.  Raises ``UnexpectedSource`` if a datagram arrives
    from an address other than *destination* and *ignore_unexpected* is
    ``False``.

    Returns a ``(dns.message.Message, float)`` tuple of the received
    message and the time it was received.
    """

    while True:
        _wait_for_readable(sock, expiration)
        (wire, from_address) = sock.recvfrom(65535)
        # Accept the datagram if it came from the queried address, or, for
        # a multicast query, from any host with the expected port (and the
        # remaining address-tuple fields).
        if _addresses_equal(sock.family, from_address, destination) or \
           (dns.inet.is_multicast(destination[0]) and
            from_address[1:] == destination[1:]):
            break
        # Otherwise either fail loudly or drop it and keep waiting.
        if not ignore_unexpected:
            raise UnexpectedSource('got a response from '
                                   '%s instead of %s' % (from_address,
                                                         destination))
    received_time = time.time()
    r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
                              one_rr_per_rrset=one_rr_per_rrset,
                              ignore_trailing=ignore_trailing)
    return (r, received_time)
273
def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
        ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False):
    """Return the response obtained after sending a query via UDP.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``text`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before
    the query times out.  If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to.  The default is 53.

    *af*, an ``int``, the address family to use.  The default is ``None``,
    which causes the address family to use to be inferred from the form of
    *where*.  If the inference attempt fails, AF_INET is used.  This
    parameter is historical; you need never set it.

    *source*, a ``text`` containing an IPv4 or IPv6 address, specifying
    the source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
    unexpected sources.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    Raises ``BadResponse`` if the reply does not answer *q*.

    Returns a ``dns.message.Message``; its ``time`` attribute is set to the
    number of seconds between sending the query and receiving the reply.
    """

    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(af, where, port,
                                                        source, source_port)
    sock = socket_factory(af, socket.SOCK_DGRAM, 0)
    try:
        expiration = _compute_expiration(timeout)
        sock.setblocking(0)
        if source is not None:
            sock.bind(source)
        (_, sent_time) = send_udp(sock, wire, destination, expiration)
        (response, received_time) = receive_udp(sock, destination,
                                                expiration,
                                                ignore_unexpected,
                                                one_rr_per_rrset,
                                                q.keyring, q.mac,
                                                ignore_trailing)
    finally:
        sock.close()
    # Both timestamps are always set once the exchange above succeeded.
    response.time = received_time - sent_time
    if not q.is_response(response):
        raise BadResponse
    return response
336
337
def _net_read(sock, count, expiration):
    """Read the specified number of bytes from sock.  Keep trying until we
    either get the desired amount, or we hit EOF.

    Returns exactly *count* bytes.

    Raises ``EOFError`` if the peer closes the connection before *count*
    bytes have arrived.  A Timeout exception will be raised if the
    operation is not completed by the expiration time.
    """
    chunks = []
    remaining = count
    while remaining > 0:
        _wait_for_readable(sock, expiration)
        chunk = sock.recv(remaining)
        if not chunk:
            raise EOFError
        remaining -= len(chunk)
        chunks.append(chunk)
    # Join once at the end rather than re-concatenating after every recv(),
    # which would copy all previously received data each time.
    return b''.join(chunks)
353
354
def _net_write(sock, data, expiration):
    """Write all of *data* to the socket, looping over partial sends.

    A Timeout exception will be raised if the operation is not completed
    by the expiration time.
    """
    total = len(data)
    sent = 0
    while sent < total:
        _wait_for_writable(sock, expiration)
        # send() may accept only part of the buffer; resume after it.
        sent += sock.send(data[sent:])
365
366
def send_tcp(sock, what, expiration=None):
    """Send a DNS message to the specified TCP socket.

    The message is framed with a two-byte, big-endian length prefix.

    *sock*, a ``socket``.

    *what*, a ``binary`` or ``dns.message.Message``, the message to send.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
    """

    if isinstance(what, dns.message.Message):
        what = what.to_wire()
    # Copying the wire data into one framed buffer is inefficient, but lets
    # us avoid writev() or a short write of just the prefix going onto the
    # net.
    framed = struct.pack("!H", len(what)) + what
    # The send time is taken only once the socket is writable.
    _wait_for_writable(sock, expiration)
    sent_time = time.time()
    _net_write(sock, framed, expiration)
    return (len(framed), sent_time)
392
def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
                keyring=None, request_mac=b'', ignore_trailing=False):
    """Read a DNS message from a TCP socket.

    The message is expected to be preceded by a two-byte, big-endian
    length prefix.

    *sock*, a ``socket``.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *request_mac*, a ``binary``, the MAC of the request (for TSIG).

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    Raises if the message is malformed, if network errors occur, or if
    there is a timeout.  ``EOFError`` is raised if the connection closes
    before the whole message has been read.

    Returns a ``(dns.message.Message, float)`` tuple of the received
    message and the time it was received.
    """

    ldata = _net_read(sock, 2, expiration)
    (length,) = struct.unpack("!H", ldata)
    wire = _net_read(sock, length, expiration)
    received_time = time.time()
    r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
                              one_rr_per_rrset=one_rr_per_rrset,
                              ignore_trailing=ignore_trailing)
    return (r, received_time)
427
428def _connect(s, address):
429    try:
430        s.connect(address)
431    except socket.error:
432        (ty, v) = sys.exc_info()[:2]
433
434        if hasattr(v, 'errno'):
435            v_err = v.errno
436        else:
437            v_err = v[0]
438        if v_err not in [errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY]:
439            raise v
440
441
def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
        one_rr_per_rrset=False, ignore_trailing=False):
    """Return the response obtained after sending a query via TCP.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``text`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before
    the query times out.  If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to.  The default is 53.

    *af*, an ``int``, the address family to use.  The default is ``None``,
    which causes the address family to use to be inferred from the form of
    *where*.  If the inference attempt fails, AF_INET is used.  This
    parameter is historical; you need never set it.

    *source*, a ``text`` containing an IPv4 or IPv6 address, specifying
    the source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    Raises ``BadResponse`` if the reply does not answer *q*.

    Returns a ``dns.message.Message``; its ``time`` attribute is set to the
    number of seconds from just before connecting until the reply arrived.
    """

    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(af, where, port,
                                                        source, source_port)
    sock = socket_factory(af, socket.SOCK_STREAM, 0)
    try:
        expiration = _compute_expiration(timeout)
        sock.setblocking(0)
        begin_time = time.time()
        if source is not None:
            sock.bind(source)
        _connect(sock, destination)
        send_tcp(sock, wire, expiration)
        (response, received_time) = receive_tcp(sock, expiration,
                                                one_rr_per_rrset,
                                                q.keyring, q.mac,
                                                ignore_trailing)
    finally:
        sock.close()
    # Both timestamps are always set once the exchange above succeeded.
    response.time = received_time - begin_time
    if not q.is_response(response):
        raise BadResponse
    return response
502
503
def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
        timeout=None, port=53, keyring=None, keyname=None, relativize=True,
        af=None, lifetime=None, source=None, source_port=0, serial=0,
        use_udp=False, keyalgorithm=dns.tsig.default_algorithm):
    """Return a generator for the responses to a zone transfer.

    *where*, a ``text`` containing an IPv4 or IPv6 address, where
    to send the message.

    *zone*, a ``dns.name.Name`` or ``text``, the name of the zone to transfer.

    *rdtype*, an ``int`` or ``text``, the type of zone transfer.  The
    default is ``dns.rdatatype.AXFR``.  ``dns.rdatatype.IXFR`` can be
    used to do an incremental transfer instead.

    *rdclass*, an ``int`` or ``text``, the class of the zone transfer.
    The default is ``dns.rdataclass.IN``.

    *timeout*, a ``float``, the number of seconds to wait for each
    response message.  If None, the default, wait forever.  The wait for
    a single message never extends past the *lifetime* deadline.

    *port*, an ``int``, the port to send the message to.  The default is 53.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *keyname*, a ``dns.name.Name`` or ``text``, the name of the TSIG
    key to use.

    *relativize*, a ``bool``.  If ``True``, all names in the zone will be
    relativized to the zone origin.  It is essential that the
    relativize setting matches the one specified to
    ``dns.zone.from_xfr()`` if using this generator to make a zone.

    *af*, an ``int``, the address family to use.  The default is ``None``,
    which causes the address family to use to be inferred from the form of
    *where*.  If the inference attempt fails, AF_INET is used.  This
    parameter is historical; you need never set it.

    *lifetime*, a ``float``, the total number of seconds to spend
    doing the transfer.  If ``None``, the default, then there is no
    limit on the time the transfer may take.

    *source*, a ``text`` containing an IPv4 or IPv6 address, specifying
    the source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *serial*, an ``int``, the SOA serial number to use as the base for
    an IXFR diff sequence (only meaningful if *rdtype* is
    ``dns.rdatatype.IXFR``).

    *use_udp*, a ``bool``.  If ``True``, use UDP (only meaningful for IXFR;
    ``ValueError`` is raised for any other *rdtype*).

    *keyalgorithm*, a ``dns.name.Name`` or ``text``, the TSIG algorithm to use.

    Raises on errors, and so does the generator: ``TransferError`` if a
    response has a non-NOERROR rcode, ``dns.exception.FormError`` if the
    transfer is malformed, and ``dns.exception.Timeout`` on timeouts.

    Returns a generator of ``dns.message.Message`` objects.  The socket is
    closed when the transfer finishes, when an error is raised, or when the
    generator is closed before completion.
    """

    if isinstance(zone, string_types):
        zone = dns.name.from_text(zone)
    if isinstance(rdtype, string_types):
        rdtype = dns.rdatatype.from_text(rdtype)
    q = dns.message.make_query(zone, rdtype, rdclass)
    if rdtype == dns.rdatatype.IXFR:
        rrset = dns.rrset.from_text(zone, 0, 'IN', 'SOA',
                                    '. . %u 0 0 0 0' % serial)
        q.authority.append(rrset)
    if keyring is not None:
        q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(af, where, port,
                                                        source, source_port)
    if use_udp:
        if rdtype != dns.rdatatype.IXFR:
            raise ValueError('cannot do a UDP AXFR')
        s = socket_factory(af, socket.SOCK_DGRAM, 0)
    else:
        s = socket_factory(af, socket.SOCK_STREAM, 0)
    # Everything after socket creation runs under try/finally so the socket
    # is closed on completion, on any error, and when the caller abandons
    # the generator (close() raises GeneratorExit at the yield below).
    try:
        s.setblocking(0)
        if source is not None:
            s.bind(source)
        expiration = _compute_expiration(lifetime)
        _connect(s, destination)
        if use_udp:
            _wait_for_writable(s, expiration)
            s.send(wire)
        else:
            tcpmsg = struct.pack("!H", len(wire)) + wire
            _net_write(s, tcpmsg, expiration)
        done = False
        delete_mode = True
        expecting_SOA = False
        soa_rrset = None
        if relativize:
            origin = zone
            oname = dns.name.empty
        else:
            origin = None
            oname = zone
        tsig_ctx = None
        first = True
        while not done:
            # Per-message deadline: *timeout* from now, but never later than
            # the overall *lifetime* deadline.  Either one may be None, so
            # only compare when both are set.
            mexpiration = _compute_expiration(timeout)
            if mexpiration is None or \
                    (expiration is not None and mexpiration > expiration):
                mexpiration = expiration
            if use_udp:
                _wait_for_readable(s, mexpiration)
                (wire, _) = s.recvfrom(65535)
            else:
                ldata = _net_read(s, 2, mexpiration)
                (length,) = struct.unpack("!H", ldata)
                wire = _net_read(s, length, mexpiration)
            # rdtype may be downgraded to AXFR below, so recompute each time.
            is_ixfr = (rdtype == dns.rdatatype.IXFR)
            r = dns.message.from_wire(wire, keyring=q.keyring,
                                      request_mac=q.mac, xfr=True,
                                      origin=origin, tsig_ctx=tsig_ctx,
                                      multi=True, first=first,
                                      one_rr_per_rrset=is_ixfr)
            rcode = r.rcode()
            if rcode != dns.rcode.NOERROR:
                raise TransferError(rcode)
            tsig_ctx = r.tsig_ctx
            first = False
            answer_index = 0
            if soa_rrset is None:
                if not r.answer or r.answer[0].name != oname:
                    raise dns.exception.FormError(
                        "No answer or RRset not for qname")
                rrset = r.answer[0]
                if rrset.rdtype != dns.rdatatype.SOA:
                    raise dns.exception.FormError("first RRset is not an SOA")
                answer_index = 1
                soa_rrset = rrset.copy()
                if rdtype == dns.rdatatype.IXFR:
                    if soa_rrset[0].serial <= serial:
                        #
                        # We're already up-to-date.
                        #
                        done = True
                    else:
                        expecting_SOA = True
            #
            # Process SOAs in the answer section (other than the initial
            # SOA in the first message).
            #
            for rrset in r.answer[answer_index:]:
                if done:
                    raise dns.exception.FormError("answers after final SOA")
                if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
                    if expecting_SOA:
                        if rrset[0].serial != serial:
                            raise dns.exception.FormError(
                                "IXFR base serial mismatch")
                        expecting_SOA = False
                    elif rdtype == dns.rdatatype.IXFR:
                        delete_mode = not delete_mode
                    #
                    # If this SOA RRset is equal to the first we saw then
                    # we're finished. If this is an IXFR we also check that
                    # we're seeing the record in the expected part of the
                    # response.
                    #
                    if rrset == soa_rrset and \
                            (rdtype == dns.rdatatype.AXFR or
                             (rdtype == dns.rdatatype.IXFR and delete_mode)):
                        done = True
                elif expecting_SOA:
                    #
                    # We made an IXFR request and are expecting another
                    # SOA RR, but saw something else, so this must be an
                    # AXFR response.
                    #
                    rdtype = dns.rdatatype.AXFR
                    expecting_SOA = False
            if done and q.keyring and not r.had_tsig:
                raise dns.exception.FormError("missing TSIG")
            yield r
    finally:
        s.close()
684