1# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
2
3# Copyright (C) 2003-2017 Nominum, Inc.
4#
5# Permission to use, copy, modify, and distribute this software and its
6# documentation for any purpose with or without fee is hereby granted,
7# provided that the above copyright notice and this permission notice
8# appear in all copies.
9#
10# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
11# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
13# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
16# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17
18"""Talk to a DNS server."""
19
20import socket
21import struct
22import time
23
24import dns.asyncbackend
25import dns.exception
26import dns.inet
27import dns.name
28import dns.message
29import dns.rcode
30import dns.rdataclass
31import dns.rdatatype
32
33from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \
34    UDPMode
35
36
# Short module-local alias for dns.inet.low_level_address_tuple, for brevity.
_lltuple = dns.inet.low_level_address_tuple
39
40
def _source_tuple(af, address, port):
    """Build a high level ``(address, port)`` source tuple.

    *af*, the address family; only consulted when *address* is ``None``.

    *address*, a ``str`` or ``None``, the source address.

    *port*, an ``int``, the source port.

    Returns ``None`` if neither *address* nor *port* is set.  If *address*
    is ``None`` but *port* is set, the wildcard address for *af* is used.

    Raises ``NotImplementedError`` if a wildcard address is needed and *af*
    is neither ``socket.AF_INET`` nor ``socket.AF_INET6``.
    """
    if not address and not port:
        return None
    if address is not None:
        return (address, port)
    # Only a port was requested, so bind to the family's wildcard address.
    if af == socket.AF_INET:
        return ('0.0.0.0', port)
    if af == socket.AF_INET6:
        return ('::', port)
    raise NotImplementedError(f'unknown address family {af}')
55
56
def _timeout(expiration, now=None):
    """Convert an absolute expiration time into a relative timeout.

    *expiration*, a ``float`` or ``None``, the absolute time at which the
    operation should time out.

    *now*, a ``float`` or ``None``, the time to measure from.  If ``None``,
    the default, ``time.time()`` is used.

    Returns ``None`` if *expiration* is ``None``, otherwise the number of
    seconds between *now* and *expiration*, never less than zero.
    """
    # Compare against None explicitly: a zero expiration is a time in the
    # past and must yield a zero timeout, not "no timeout".
    if expiration is None:
        return None
    if now is None:
        now = time.time()
    return max(expiration - now, 0)
64
65
async def send_udp(sock, what, destination, expiration=None):
    """Send a DNS message over the specified UDP socket.

    *sock*, a ``dns.asyncbackend.DatagramSocket``.

    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
    A ``dns.message.Message`` is converted to wire format first.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where to send the query.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
    """

    payload = what.to_wire() if isinstance(what, dns.message.Message) \
        else what
    sent_time = time.time()
    # The socket wants a relative timeout, measured from the send time.
    remaining = _timeout(expiration, sent_time)
    count = await sock.sendto(payload, destination, remaining)
    return (count, sent_time)
88
89
async def receive_udp(sock, destination=None, expiration=None,
                      ignore_unexpected=False, one_rr_per_rrset=False,
                      keyring=None, request_mac=b'', ignore_trailing=False,
                      raise_on_truncation=False):
    """Read a DNS message from a UDP socket.

    *sock*, a ``dns.asyncbackend.DatagramSocket``.

    Datagrams are read until ``_matches_destination()`` accepts the sender;
    that datagram is then parsed.

    Returns a ``(message, received_time, from_address)`` tuple.

    See :py:func:`dns.query.receive_udp()` for the documentation of the other
    parameters, exceptions, and return type of this method.
    """

    while True:
        (datagram, from_address) = await sock.recvfrom(65535,
                                                       _timeout(expiration))
        accepted = _matches_destination(sock.family, from_address,
                                        destination, ignore_unexpected)
        if accepted:
            break
    received_time = time.time()
    message = dns.message.from_wire(datagram, keyring=keyring,
                                    request_mac=request_mac,
                                    one_rr_per_rrset=one_rr_per_rrset,
                                    ignore_trailing=ignore_trailing,
                                    raise_on_truncation=raise_on_truncation)
    return (message, received_time, from_address)
114
async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
              ignore_unexpected=False, one_rr_per_rrset=False,
              ignore_trailing=False, raise_on_truncation=False, sock=None,
              backend=None):
    """Send a query via UDP and return the response.

    *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
    the socket to use for the query.  If ``None``, the default, a
    socket is created and closed again before returning.  Note that if a
    socket is provided, the *source*, *source_port*, and *backend* are
    ignored.

    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
    the default, then dnspython will use the default backend.

    Raises ``BadResponse`` if the received message is not a response to *q*.

    See :py:func:`dns.query.udp()` for the documentation of the other
    parameters, exceptions, and return type of this method.
    """
    request = q.to_wire()
    (begin_time, expiration) = _compute_times(timeout)
    # Set only when this function creates the socket, so we know to close it.
    owned = None
    # After 3.6 is no longer supported, this can use an AsyncExitStack.
    try:
        af = dns.inet.af_for_address(where)
        destination = _lltuple((where, port), af)
        if sock:
            channel = sock
        else:
            if not backend:
                backend = dns.asyncbackend.get_default_backend()
            local = _source_tuple(af, source, source_port)
            owned = await backend.make_socket(af, socket.SOCK_DGRAM, 0, local)
            channel = owned
        await send_udp(channel, request, destination, expiration)
        (response, received_time, _) = await receive_udp(
            channel, destination, expiration,
            ignore_unexpected=ignore_unexpected,
            one_rr_per_rrset=one_rr_per_rrset,
            keyring=q.keyring, request_mac=q.mac,
            ignore_trailing=ignore_trailing,
            raise_on_truncation=raise_on_truncation)
        response.time = received_time - begin_time
        if not q.is_response(response):
            raise BadResponse
        return response
    finally:
        if owned:
            await owned.close()
160
async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
                            source_port=0, ignore_unexpected=False,
                            one_rr_per_rrset=False, ignore_trailing=False,
                            udp_sock=None, tcp_sock=None, backend=None):
    """Query via UDP first, retrying via TCP if the UDP response is
    truncated.

    *udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
    the socket to use for the UDP query.  If ``None``, the default, a
    socket is created.  Note that if a socket is provided the *source*,
    *source_port*, and *backend* are ignored for the UDP query.

    *tcp_sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the
    socket to use for the TCP query.  If ``None``, the default, a
    socket is created.  Note that if a socket is provided *where*,
    *source*, *source_port*, and *backend* are ignored for the TCP query.

    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
    the default, then dnspython will use the default backend.

    Returns a ``(response, used_tcp)`` tuple, where *used_tcp* is ``True``
    if and only if the TCP fallback was taken.

    See :py:func:`dns.query.udp_with_fallback()` for the documentation
    of the other parameters, exceptions, and return type of this
    method.
    """
    try:
        # Ask udp() to raise on truncation so that we can fall back.
        answer = await udp(q, where, timeout, port, source, source_port,
                           ignore_unexpected=ignore_unexpected,
                           one_rr_per_rrset=one_rr_per_rrset,
                           ignore_trailing=ignore_trailing,
                           raise_on_truncation=True,
                           sock=udp_sock, backend=backend)
    except dns.message.Truncated:
        answer = await tcp(q, where, timeout, port, source, source_port,
                           one_rr_per_rrset=one_rr_per_rrset,
                           ignore_trailing=ignore_trailing,
                           sock=tcp_sock, backend=backend)
        return (answer, True)
    return (answer, False)
195
196
async def send_tcp(sock, what, expiration=None):
    """Send a DNS message over the specified TCP socket.

    *sock*, a ``dns.asyncbackend.StreamSocket``.

    The message is sent with a two-byte, network-order length prefix.
    Returns a tuple of the number of bytes in the prefixed message and the
    time at which sending started.

    See :py:func:`dns.query.send_tcp()` for the documentation of the other
    parameters, exceptions, and return type of this method.
    """

    payload = what.to_wire() if isinstance(what, dns.message.Message) \
        else what
    # Building one combined buffer costs a copy, but it avoids writev() and
    # avoids a short write of just the length getting pushed onto the net.
    framed = struct.pack("!H", len(payload)) + payload
    sent_time = time.time()
    await sock.sendall(framed, _timeout(expiration, sent_time))
    return (len(framed), sent_time)
216
217
async def _read_exactly(sock, count, expiration):
    """Read exactly *count* bytes from the stream *sock*.

    Keeps calling ``sock.recv()`` until the requested amount has arrived.

    Returns the bytes read.

    Raises ``EOFError`` if the stream ends before *count* bytes were read.
    """
    pieces = []
    remaining = count
    while remaining > 0:
        piece = await sock.recv(remaining, _timeout(expiration))
        if piece == b'':
            raise EOFError
        remaining -= len(piece)
        pieces.append(piece)
    return b''.join(pieces)
230
231
async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
                      keyring=None, request_mac=b'', ignore_trailing=False):
    """Read a DNS message from a TCP socket.

    *sock*, a ``dns.asyncbackend.StreamSocket``.

    Reads the two-byte, network-order length prefix, then that many bytes
    of message, and parses the result.

    Returns a ``(message, received_time)`` tuple.

    See :py:func:`dns.query.receive_tcp()` for the documentation of the other
    parameters, exceptions, and return type of this method.
    """

    prefix = await _read_exactly(sock, 2, expiration)
    length = struct.unpack("!H", prefix)[0]
    body = await _read_exactly(sock, length, expiration)
    received_time = time.time()
    message = dns.message.from_wire(body, keyring=keyring,
                                    request_mac=request_mac,
                                    one_rr_per_rrset=one_rr_per_rrset,
                                    ignore_trailing=ignore_trailing)
    return (message, received_time)
250
251
async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
              one_rr_per_rrset=False, ignore_trailing=False, sock=None,
              backend=None):
    """Send a query via TCP and return the response.

    *sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the
    socket to use for the query.  If ``None``, the default, a socket
    is created and closed again before returning.  Note that if a socket
    is provided *where*, *port*, *source*, *source_port*, and *backend*
    are ignored.

    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
    the default, then dnspython will use the default backend.

    Raises ``BadResponse`` if the received message is not a response to *q*.

    See :py:func:`dns.query.tcp()` for the documentation of the other
    parameters, exceptions, and return type of this method.
    """

    request = q.to_wire()
    (begin_time, expiration) = _compute_times(timeout)
    # Set only when this function creates the socket, so we know to close it.
    owned = None
    # After 3.6 is no longer supported, this can use an AsyncExitStack.
    try:
        if sock:
            # Check up front that the caller's socket is connected; an
            # unconnected socket is not writable, so the polling in
            # send_tcp() would time out or hang forever.
            await sock.getpeername()
            stream = sock
        else:
            # The source and destination here are plain (address, port)
            # pairs, not the family-dependent tuples that low-level socket
            # code takes.
            af = dns.inet.af_for_address(where)
            local = _source_tuple(af, source, source_port)
            remote = (where, port)
            if not backend:
                backend = dns.asyncbackend.get_default_backend()
            owned = await backend.make_socket(af, socket.SOCK_STREAM, 0,
                                              local, remote, timeout)
            stream = owned
        await send_tcp(stream, request, expiration)
        (response, received_time) = await receive_tcp(
            stream, expiration,
            one_rr_per_rrset=one_rr_per_rrset,
            keyring=q.keyring, request_mac=q.mac,
            ignore_trailing=ignore_trailing)
        response.time = received_time - begin_time
        if not q.is_response(response):
            raise BadResponse
        return response
    finally:
        if owned:
            await owned.close()
302
async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
              one_rr_per_rrset=False, ignore_trailing=False, sock=None,
              backend=None, ssl_context=None, server_hostname=None):
    """Return the response obtained after sending a query via TLS.

    *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
    to use for the query.  If ``None``, the default, a socket is
    created.  Note that if a socket is provided, it must be a
    connected SSL stream socket, and *where*, *port*,
    *source*, *source_port*, *backend*, *ssl_context*, and *server_hostname*
    are ignored.

    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
    the default, then dnspython will use the default backend.

    *ssl_context*, an ``ssl.SSLContext`` or ``None``, the context to use
    when a socket is created.  If ``None``, the default, a context is made
    with ``ssl.create_default_context()``, and hostname checking is turned
    off on it when *server_hostname* is also ``None``.  A caller-supplied
    context is used unchanged.

    *server_hostname*, a ``str`` or ``None``, passed to the backend along
    with the context when a socket is created.

    See :py:func:`dns.query.tls()` for the documentation of the other
    parameters, exceptions, and return type of this method.
    """
    # After 3.6 is no longer supported, this can use an AsyncExitStack.
    (begin_time, expiration) = _compute_times(timeout)
    if not sock:
        # Only build a context when the caller did not provide one.  A
        # caller-supplied context (and server_hostname) must reach
        # make_socket() untouched; discarding it would silently drop the
        # caller's TLS configuration.
        if ssl_context is None:
            ssl_context = ssl.create_default_context()
            if server_hostname is None:
                # There is no name to verify the certificate against.
                ssl_context.check_hostname = False
        af = dns.inet.af_for_address(where)
        stuple = _source_tuple(af, source, source_port)
        dtuple = (where, port)
        if not backend:
            backend = dns.asyncbackend.get_default_backend()
        s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
                                      dtuple, timeout, ssl_context,
                                      server_hostname)
    else:
        s = sock
    try:
        # Connecting may have used part of the budget; give tcp() only
        # what remains.
        timeout = _timeout(expiration)
        response = await tcp(q, where, timeout, port, source, source_port,
                             one_rr_per_rrset, ignore_trailing, s, backend)
        end_time = time.time()
        # Report the total elapsed time, including connection setup.
        response.time = end_time - begin_time
        return response
    finally:
        # Close only a socket created here, never the caller's.
        if not sock and s:
            await s.close()
351
def _serial_from_query(query):
    """Return the starting serial of a caller-supplied transfer *query*.

    Returns ``None`` unless the question type is IXFR.  For an IXFR, returns
    the serial of the SOA found in the authority section.

    Raises ``ValueError`` if an IXFR query has no such SOA.
    """
    question = query.question[0]
    if question.rdtype != dns.rdatatype.IXFR:
        return None
    try:
        soa = query.find_rrset(query.authority, question.name,
                               question.rdclass, dns.rdatatype.SOA)
    except KeyError:
        raise ValueError('IXFR query has no SOA in the authority '
                         'section') from None
    return soa[0].serial


async def inbound_xfr(where, txn_manager, query=None,
                      port=53, timeout=None, lifetime=None, source=None,
                      source_port=0, udp_mode=UDPMode.NEVER,
                      backend=None):
    """Conduct an inbound transfer and apply it via a transaction from the
    txn_manager.

    *query*, the transfer query to send, or ``None``.  If ``None``, the
    default, a query is built with ``dns.xfr.make_query(txn_manager)``.

    *timeout*, a ``float`` or ``None``, the time allowed for each received
    message; *lifetime*, a ``float`` or ``None``, the time allowed for the
    whole transfer.  The effective per-message deadline never extends past
    the lifetime deadline.

    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
    the default, then dnspython will use the default backend.

    Raises ``dns.exception.FormError`` if the query has a keyring but the
    final message processed carried no TSIG.

    See :py:func:`dns.query.inbound_xfr()` for the documentation of
    the other parameters, exceptions, and return type of this method.
    """
    if query is None:
        (query, serial) = dns.xfr.make_query(txn_manager)
    else:
        # A caller-supplied query still needs a starting serial; without
        # this branch ``serial`` would be unbound further down.
        serial = _serial_from_query(query)
    rdtype = query.question[0].rdtype
    is_ixfr = rdtype == dns.rdatatype.IXFR
    origin = txn_manager.from_wire_origin()
    wire = query.to_wire()
    af = dns.inet.af_for_address(where)
    stuple = _source_tuple(af, source, source_port)
    dtuple = (where, port)
    (_, expiration) = _compute_times(lifetime)
    if not backend:
        backend = dns.asyncbackend.get_default_backend()
    retry = True
    while retry:
        retry = False
        # UDP is only attempted for IXFR, and only if the caller allows it.
        if is_ixfr and udp_mode != UDPMode.NEVER:
            sock_type = socket.SOCK_DGRAM
            is_udp = True
        else:
            sock_type = socket.SOCK_STREAM
            is_udp = False
        s = await backend.make_socket(af, sock_type, 0, stuple, dtuple,
                                      _timeout(expiration))
        async with s:
            if is_udp:
                await s.sendto(wire, dtuple, _timeout(expiration))
            else:
                tcpmsg = struct.pack("!H", len(wire)) + wire
                # sendall() takes a relative timeout, not an absolute
                # expiration time.
                await s.sendall(tcpmsg, _timeout(expiration))
            with dns.xfr.Inbound(txn_manager, rdtype, serial,
                                 is_udp) as inbound:
                done = False
                tsig_ctx = None
                while not done:
                    # Per-message deadline, capped by the overall lifetime.
                    (_, mexpiration) = _compute_times(timeout)
                    if mexpiration is None or \
                       (expiration is not None and mexpiration > expiration):
                        mexpiration = expiration
                    if is_udp:
                        destination = _lltuple((where, port), af)
                        while True:
                            # Use a separate name so the *timeout*
                            # parameter is not overwritten; it is needed
                            # again for the next message's deadline.
                            recv_timeout = _timeout(mexpiration)
                            (rwire, from_address) = await s.recvfrom(
                                65535, recv_timeout)
                            if _matches_destination(af, from_address,
                                                    destination, True):
                                break
                    else:
                        ldata = await _read_exactly(s, 2, mexpiration)
                        (l,) = struct.unpack("!H", ldata)
                        rwire = await _read_exactly(s, l, mexpiration)
                    r = dns.message.from_wire(rwire, keyring=query.keyring,
                                              request_mac=query.mac, xfr=True,
                                              origin=origin, tsig_ctx=tsig_ctx,
                                              multi=(not is_udp),
                                              one_rr_per_rrset=is_ixfr)
                    try:
                        done = inbound.process_message(r)
                    except dns.xfr.UseTCP:
                        assert is_udp  # should not happen if we used TCP!
                        if udp_mode == UDPMode.ONLY:
                            raise
                        # Abandon this attempt and start over using TCP.
                        done = True
                        retry = True
                        udp_mode = UDPMode.NEVER
                        continue
                    tsig_ctx = r.tsig_ctx
                if not retry and query.keyring and not r.had_tsig:
                    raise dns.exception.FormError("missing TSIG")
435