# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

"""Talk to a DNS server."""

import contextlib
import errno
import os
import select
import socket
import struct
import time
import base64
import urllib.parse

import dns.exception
import dns.inet
import dns.name
import dns.message
import dns.rcode
import dns.rdataclass
import dns.rdatatype
# dns.rrset and dns.tsig are referenced directly by xfr(); import them
# explicitly instead of relying on them being loaded as a side effect of
# the other imports.
import dns.rrset
import dns.serial
import dns.tsig

# DNS-over-HTTPS support is optional and needs the third-party requests
# and requests_toolbelt packages.
try:
    import requests
    from requests_toolbelt.adapters.source import SourceAddressAdapter
    from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter
    have_doh = True
except ImportError:  # pragma: no cover
    have_doh = False

try:
    import ssl
except ImportError:  # pragma: no cover
    class ssl:  # type: ignore
        """Minimal stand-in used when the ssl module is unavailable.

        It supplies the attribute names this module looks up on ``ssl``
        (in ``except`` clauses, ``isinstance`` checks and TLS setup) so
        those lookups do not themselves fail with ``AttributeError``.
        """

        class WantReadException(Exception):
            pass

        class WantWriteException(Exception):
            pass

        # The I/O helpers catch ssl.SSLWantReadError / ssl.SSLWantWriteError;
        # an except clause naming a missing attribute would raise
        # AttributeError and mask the real exception.
        SSLWantReadError = WantReadException
        SSLWantWriteError = WantWriteException

        class SSLSocket:
            pass

        # A staticmethod, so that calling ssl.create_default_context()
        # raises the intended error rather than a TypeError about a
        # missing 'self' argument.
        @staticmethod
        def create_default_context(*args, **kwargs):
            raise Exception('no ssl support')

# Function used to create a socket. Can be overridden if needed in special
# situations.
socket_factory = socket.socket


class UnexpectedSource(dns.exception.DNSException):
    """A DNS query response came from an unexpected address or port."""


class BadResponse(dns.exception.FormError):
    """A DNS query response does not respond to the question asked."""


class TransferError(dns.exception.DNSException):
    """A zone transfer response got a non-zero rcode."""

    def __init__(self, rcode):
        # Render the rcode's textual name into the message, and keep the
        # numeric value available on the instance for callers.
        rcode_text = dns.rcode.to_text(rcode)
        super().__init__('Zone transfer error: %s' % rcode_text)
        self.rcode = rcode


class NoDOH(dns.exception.DNSException):
    """DNS over HTTPS (DOH) was requested but the requests module is not
    available."""


def _compute_times(timeout):
    """Turn a relative *timeout* into a ``(start, expiration)`` pair.

    *start* is the current ``time.time()``.  *expiration* is the absolute
    deadline ``start + timeout``, or ``None`` when *timeout* is ``None``
    (meaning there is no deadline).
    """
    start = time.time()
    expiration = None if timeout is None else start + timeout
    return (start, expiration)

# This module can use either poll() or select() as the "polling backend".
#
# A backend function takes an fd, bools for readability, writability, and
# error detection, and a timeout.
def _poll_for(fd, readable, writable, error, timeout):
    """Poll polling backend.

    Waits with ``select.poll()`` until *fd* is ready for one of the
    requested events.

    *readable*, *writable*, *error*, ``bool``s selecting ``POLLIN``,
    ``POLLOUT`` and ``POLLERR`` respectively.

    *timeout*, a ``float`` number of seconds, or ``None`` to wait forever.

    Returns ``True`` if any event was reported, ``False`` on timeout.
    """

    event_mask = 0
    if readable:
        event_mask |= select.POLLIN
    if writable:
        event_mask |= select.POLLOUT
    if error:
        event_mask |= select.POLLERR

    pollable = select.poll()
    pollable.register(fd, event_mask)

    # Compare with None (not truthiness) so that a zero timeout means "do
    # not block", as in _select_for(), instead of "block forever".
    if timeout is None:
        event_list = pollable.poll()
    else:
        event_list = pollable.poll(timeout * 1000)

    return bool(event_list)


def _select_for(fd, readable, writable, error, timeout):
    """Select polling backend.

    Same contract as ``_poll_for()``, implemented with ``select.select()``.
    """

    rset, wset, xset = [], [], []

    if readable:
        rset = [fd]
    if writable:
        wset = [fd]
    if error:
        xset = [fd]

    if timeout is None:
        (rcount, wcount, xcount) = select.select(rset, wset, xset)
    else:
        (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout)

    return bool((rcount or wcount or xcount))


def _wait_for(fd, readable, writable, error, expiration):
    # Use the selected polling backend to wait for any of the specified
    # events.  An "expiration" absolute time is converted into a relative
    # timeout.  Raises dns.exception.Timeout if the expiration is reached
    # before any requested event occurs.

    while True:
        if expiration is None:
            timeout = None
        else:
            timeout = expiration - time.time()
            if timeout <= 0.0:
                raise dns.exception.Timeout
        try:
            # An SSL socket may already hold decrypted bytes (pending())
            # that polling the underlying fd would not report.
            if isinstance(fd, ssl.SSLSocket) and readable and fd.pending() > 0:
                return True
            if not _polling_backend(fd, readable, writable, error, timeout):
                raise dns.exception.Timeout
        except OSError as e:  # pragma: no cover
            if e.args[0] != errno.EINTR:
                raise e
            # Interrupted by a signal: recompute the remaining timeout and
            # wait again instead of reporting the fd as ready.
            continue
        return


def _set_polling_backend(fn):
    # Internal API. Do not use.

    global _polling_backend

    _polling_backend = fn

if hasattr(select, 'poll'):
    # Prefer poll() on platforms that support it because it has no
    # limits on the maximum value of a file descriptor (plus it will
    # be more efficient for high values).
    _polling_backend = _poll_for
else:
    _polling_backend = _select_for  # pragma: no cover


def _wait_for_readable(s, expiration):
    _wait_for(s, True, False, True, expiration)


def _wait_for_writable(s, expiration):
    _wait_for(s, False, True, True, expiration)


def _addresses_equal(af, a1, a2):
    # Convert the first value of the tuple, which is a textual format
    # address into binary form, so that we are not confused by different
    # textual representations of the same address
    try:
        n1 = dns.inet.inet_pton(af, a1[0])
        n2 = dns.inet.inet_pton(af, a2[0])
    except dns.exception.SyntaxError:
        return False
    return n1 == n2 and a1[1:] == a2[1:]


def _matches_destination(af, from_address, destination, ignore_unexpected):
    # Check that from_address is appropriate for a response to a query
    # sent to destination.  Returns True if it matches (or there is no
    # destination to check), False if it does not and ignore_unexpected is
    # set, and raises UnexpectedSource otherwise.
    if not destination:
        return True
    if _addresses_equal(af, from_address, destination) or \
       (dns.inet.is_multicast(destination[0]) and
        from_address[1:] == destination[1:]):
        return True
    elif ignore_unexpected:
        return False
    raise UnexpectedSource(f'got a response from {from_address} instead of '
                           f'{destination}')


def _destination_and_source(where, port, source, source_port,
                            where_must_be_address=True):
    # Apply defaults and compute destination and source tuples
    # suitable for use in connect(), sendto(), or bind().
    af = None
    destination = None
    try:
        af = dns.inet.af_for_address(where)
        destination = where
    except Exception:
        if where_must_be_address:
            raise
        # URLs are ok so eat the exception
    if source:
        saf = dns.inet.af_for_address(source)
        if af:
            # We know the destination af, so source had better agree!
            if saf != af:
                raise ValueError('different address families for source ' +
                                 'and destination')
        else:
            # We didn't know the destination af, but we know the source,
            # so that's our af.
            af = saf
    if source_port and not source:
        # Caller has specified a source_port but not an address, so we
        # need to return a source, and we need to use the appropriate
        # wildcard address as the address.
        if af == socket.AF_INET:
            source = '0.0.0.0'
        elif af == socket.AF_INET6:
            source = '::'
        else:
            raise ValueError('source_port specified but address family is '
                             'unknown')
    # Convert high-level (address, port) tuples into low-level address
    # tuples.
    if destination:
        destination = dns.inet.low_level_address_tuple((destination, port), af)
    if source:
        source = dns.inet.low_level_address_tuple((source, source_port), af)
    return (af, destination, source)


def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
    # Create a nonblocking socket from socket_factory, optionally bound to
    # *source* and wrapped with *ssl_context* (handshake deferred to the
    # caller).  The raw socket is closed if any setup step fails.
    s = socket_factory(af, type)
    try:
        s.setblocking(False)
        if source is not None:
            s.bind(source)
        if ssl_context:
            return ssl_context.wrap_socket(s, do_handshake_on_connect=False,
                                           server_hostname=server_hostname)
        else:
            return s
    except Exception:
        s.close()
        raise


def https(q, where, timeout=None, port=443, source=None, source_port=0,
          one_rr_per_rrset=False, ignore_trailing=False,
          session=None, path='/dns-query', post=True,
          bootstrap_address=None, verify=True):
    """Return the response obtained after sending a query via DNS-over-HTTPS.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str``, the nameserver IP address or the full URL. If an IP
    address is given, the URL will be constructed using the following schema:
    https://<IP-address>:<port>/<path>.

    *timeout*, a ``float`` or ``None``, the number of seconds to
    wait before the query times out. If ``None``, the default, wait forever.

    *port*, a ``int``, the port to send the query to. The default is 443.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *session*, a ``requests.sessions.Session``. If provided, the session to
    use to send the queries.

    *path*, a ``str``. If *where* is an IP address, then *path* will be used to
    construct the URL to send the DNS query to.

    *post*, a ``bool``. If ``True``, the default, POST method will be used.

    *bootstrap_address*, a ``str``, the IP address to use to bypass the
    system's DNS resolver.

    *verify*, a ``bool`` or ``str``, passed through to requests. ``True``,
    the default, verifies the server's TLS certificate, ``False`` disables
    verification, and a ``str`` is a path to a certificate file or
    directory to verify against.

    Raises ``NoDOH`` if the requests module is unavailable, ``ValueError``
    if the server answers with a non-2xx HTTP status, and ``BadResponse``
    if the reply is not a response to *q*.

    Returns a ``dns.message.Message`` whose ``time`` attribute is the
    elapsed time in seconds as a ``float``, like the other query methods.
    """

    if not have_doh:
        raise NoDOH  # pragma: no cover

    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(where, port,
                                                        source, source_port,
                                                        False)
    transport_adapter = None
    headers = {
        "accept": "application/dns-message"
    }
    try:
        where_af = dns.inet.af_for_address(where)
        if where_af == socket.AF_INET:
            url = 'https://{}:{}{}'.format(where, port, path)
        elif where_af == socket.AF_INET6:
            url = 'https://[{}]:{}{}'.format(where, port, path)
    except ValueError:
        # *where* is not an address, so treat it as a URL.
        if bootstrap_address is not None:
            split_url = urllib.parse.urlsplit(where)
            headers['Host'] = split_url.hostname
            url = where.replace(split_url.hostname, bootstrap_address)
            transport_adapter = HostHeaderSSLAdapter()
        else:
            url = where
    if source is not None:
        # set source port and source address
        transport_adapter = SourceAddressAdapter(source)

    with contextlib.ExitStack() as stack:
        if not session:
            session = stack.enter_context(requests.sessions.Session())

        if transport_adapter:
            session.mount(url, transport_adapter)

        # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
        # GET and POST examples
        if post:
            headers.update({
                "content-type": "application/dns-message",
                "content-length": str(len(wire))
            })
            response = session.post(url, headers=headers, data=wire,
                                    timeout=timeout, verify=verify)
        else:
            wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
            response = session.get(url, headers=headers,
                                   timeout=timeout, verify=verify,
                                   params={"dns": wire})

    # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
    # status codes
    if response.status_code < 200 or response.status_code > 299:
        raise ValueError('{} responded with status code {}'
                         '\nResponse body: {}'.format(where,
                                                      response.status_code,
                                                      response.content))
    # Validate TSIG against the MAC of the query we just signed (q.mac),
    # as udp(), tcp() and tls() do.
    r = dns.message.from_wire(response.content,
                              keyring=q.keyring,
                              request_mac=q.mac,
                              one_rr_per_rrset=one_rr_per_rrset,
                              ignore_trailing=ignore_trailing)
    # response.elapsed is a datetime.timedelta; store float seconds so that
    # r.time has the same type as for the other transports.
    r.time = response.elapsed.total_seconds()
    if not q.is_response(r):
        raise BadResponse
    return r


def send_udp(sock, what, destination, expiration=None):
    """Send a DNS message to the specified UDP socket.

    *sock*, a ``socket``.

    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where to send the query.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised. If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
    """

    if isinstance(what, dns.message.Message):
        what = what.to_wire()
    _wait_for_writable(sock, expiration)
    sent_time = time.time()
    n = sock.sendto(what, destination)
    return (n, sent_time)


def receive_udp(sock, destination=None, expiration=None,
                ignore_unexpected=False, one_rr_per_rrset=False,
                keyring=None, request_mac=b'', ignore_trailing=False,
                raise_on_truncation=False):
    """Read a DNS message from a UDP socket.

    *sock*, a ``socket``.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where the message is expected to arrive from.
    When receiving a response, this would be where the associated query was
    sent.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised. If ``None``, no timeout will
    occur.

    *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
    unexpected sources.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *request_mac*, a ``bytes``, the MAC of the request (for TSIG).

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
    the TC bit is set.

    Raises if the message is malformed, if network errors occur, or if
    there is a timeout.

    If *destination* is not ``None``, returns a ``(dns.message.Message, float)``
    tuple of the received message and the received time.

    If *destination* is ``None``, returns a
    ``(dns.message.Message, float, tuple)``
    tuple of the received message, the received time, and the address where
    the message arrived from.
    """

    # Keep reading until a datagram arrives from an acceptable source;
    # datagrams from other sources are skipped when ignore_unexpected is set.
    while True:
        _wait_for_readable(sock, expiration)
        (wire, from_address) = sock.recvfrom(65535)
        if _matches_destination(sock.family, from_address, destination,
                                ignore_unexpected):
            break
    received_time = time.time()
    r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
                              one_rr_per_rrset=one_rr_per_rrset,
                              ignore_trailing=ignore_trailing,
                              raise_on_truncation=raise_on_truncation)
    if destination:
        return (r, received_time)
    else:
        return (r, received_time, from_address)


def udp(q, where, timeout=None, port=53, source=None, source_port=0,
        ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
        raise_on_truncation=False, sock=None):
    """Return the response obtained after sending a query via UDP.

    *q*, a ``dns.message.Message``, the query to send

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
    query times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to. The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
    unexpected sources.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
    the TC bit is set.

    *sock*, a ``socket.socket``, or ``None``, the socket to use for the
    query. If ``None``, the default, a socket is created. Note that
    if a socket is provided, it must be a nonblocking datagram socket,
    and the *source* and *source_port* are ignored.

    Returns a ``dns.message.Message``.
    """

    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(where, port,
                                                        source, source_port)
    (begin_time, expiration) = _compute_times(timeout)
    with contextlib.ExitStack() as stack:
        if sock:
            s = sock
        else:
            s = stack.enter_context(_make_socket(af, socket.SOCK_DGRAM, source))
        send_udp(s, wire, destination, expiration)
        (r, received_time) = receive_udp(s, destination, expiration,
                                         ignore_unexpected, one_rr_per_rrset,
                                         q.keyring, q.mac, ignore_trailing,
                                         raise_on_truncation)
        r.time = received_time - begin_time
        if not q.is_response(r):
            raise BadResponse
        return r


def udp_with_fallback(q, where, timeout=None, port=53, source=None,
                      source_port=0, ignore_unexpected=False,
                      one_rr_per_rrset=False, ignore_trailing=False,
                      udp_sock=None, tcp_sock=None):
    """Return the response to the query, trying UDP first and falling back
    to TCP if UDP results in a truncated response.

    *q*, a ``dns.message.Message``, the query to send

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
    query times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to. The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
    unexpected sources.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
    UDP query. If ``None``, the default, a socket is created. Note that
    if a socket is provided, it must be a nonblocking datagram socket,
    and the *source* and *source_port* are ignored for the UDP query.

    *tcp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
    TCP query. If ``None``, the default, a socket is created. Note that
    if a socket is provided, it must be a nonblocking connected stream
    socket, and *where*, *source* and *source_port* are ignored for the TCP
    query.

    Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
    if and only if TCP was used.
    """
    try:
        # raise_on_truncation=True turns a TC-bit reply into Truncated.
        response = udp(q, where, timeout, port, source, source_port,
                       ignore_unexpected, one_rr_per_rrset,
                       ignore_trailing, True, udp_sock)
        return (response, False)
    except dns.message.Truncated:
        response = tcp(q, where, timeout, port, source, source_port,
                       one_rr_per_rrset, ignore_trailing, tcp_sock)
        return (response, True)


def _net_read(sock, count, expiration):
    """Read the specified number of bytes from sock.  Keep trying until we
    either get the desired amount, or we hit EOF.
    A Timeout exception will be raised if the operation is not completed
    by the expiration time.  Raises ``EOFError`` if the peer closes the
    connection first.
    """
    buf = bytearray()
    while count > 0:
        _wait_for_readable(sock, expiration)
        try:
            n = sock.recv(count)
        except ssl.SSLWantReadError:  # pragma: no cover
            continue
        except ssl.SSLWantWriteError:  # pragma: no cover
            _wait_for_writable(sock, expiration)
            continue
        if n == b'':
            raise EOFError
        count = count - len(n)
        # Accumulate in a bytearray to avoid quadratic bytes concatenation.
        buf += n
    return bytes(buf)


def _net_write(sock, data, expiration):
    """Write the specified data to the socket.
    A Timeout exception will be raised if the operation is not completed
    by the expiration time.
    """
    current = 0
    l = len(data)
    while current < l:
        _wait_for_writable(sock, expiration)
        try:
            current += sock.send(data[current:])
        except ssl.SSLWantReadError:  # pragma: no cover
            _wait_for_readable(sock, expiration)
            continue
        except ssl.SSLWantWriteError:  # pragma: no cover
            continue


def send_tcp(sock, what, expiration=None):
    """Send a DNS message to the specified TCP socket.

    *sock*, a ``socket``.

    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised. If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
    """

    if isinstance(what, dns.message.Message):
        what = what.to_wire()
    l = len(what)
    # copying the wire into tcpmsg is inefficient, but lets us
    # avoid writev() or doing a short write that would get pushed
    # onto the net
    tcpmsg = struct.pack("!H", l) + what
    _wait_for_writable(sock, expiration)
    sent_time = time.time()
    _net_write(sock, tcpmsg, expiration)
    return (len(tcpmsg), sent_time)


def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
                keyring=None, request_mac=b'', ignore_trailing=False):
    """Read a DNS message from a TCP socket.

    *sock*, a ``socket``.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised. If ``None``, no timeout will
    occur.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *request_mac*, a ``bytes``, the MAC of the request (for TSIG).

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    Raises if the message is malformed, if network errors occur, or if
    there is a timeout.

    Returns a ``(dns.message.Message, float)`` tuple of the received message
    and the received time.
    """

    # Messages are framed by a 2-byte big-endian length prefix.
    ldata = _net_read(sock, 2, expiration)
    (l,) = struct.unpack("!H", ldata)
    wire = _net_read(sock, l, expiration)
    received_time = time.time()
    r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
                              one_rr_per_rrset=one_rr_per_rrset,
                              ignore_trailing=ignore_trailing)
    return (r, received_time)


def _connect(s, address, expiration):
    # Connect a nonblocking socket, waiting (until expiration) for an
    # in-progress connect to finish; raises OSError if it fails.
    err = s.connect_ex(address)
    if err == 0:
        return
    if err in (errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY):
        _wait_for_writable(s, expiration)
        err = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
    if err != 0:
        raise OSError(err, os.strerror(err))


def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
        one_rr_per_rrset=False, ignore_trailing=False, sock=None):
    """Return the response obtained after sending a query via TCP.

    *q*, a ``dns.message.Message``, the query to send

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
    query times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to. The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *sock*, a ``socket.socket``, or ``None``, the socket to use for the
    query. If ``None``, the default, a socket is created. Note that
    if a socket is provided, it must be a nonblocking connected stream
    socket, and *where*, *port*, *source* and *source_port* are ignored.

    Returns a ``dns.message.Message``.
    """

    wire = q.to_wire()
    (begin_time, expiration) = _compute_times(timeout)
    with contextlib.ExitStack() as stack:
        if sock:
            #
            # Verify that the socket is connected, as if it's not connected,
            # it's not writable, and the polling in send_tcp() will time out or
            # hang forever.
            sock.getpeername()
            s = sock
        else:
            (af, destination, source) = _destination_and_source(where, port,
                                                                source,
                                                                source_port)
            s = stack.enter_context(_make_socket(af, socket.SOCK_STREAM,
                                                 source))
            _connect(s, destination, expiration)
        send_tcp(s, wire, expiration)
        (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
                                         q.keyring, q.mac, ignore_trailing)
        r.time = received_time - begin_time
        if not q.is_response(r):
            raise BadResponse
        return r


def _tls_handshake(s, expiration):
    # Drive a nonblocking TLS handshake to completion, waiting for the
    # socket to become readable/writable as the ssl layer requests.
    while True:
        try:
            s.do_handshake()
            return
        except ssl.SSLWantReadError:
            _wait_for_readable(s, expiration)
        except ssl.SSLWantWriteError:  # pragma: no cover
            _wait_for_writable(s, expiration)


def tls(q, where, timeout=None, port=853, source=None, source_port=0,
        one_rr_per_rrset=False, ignore_trailing=False, sock=None,
        ssl_context=None, server_hostname=None):
    """Return the response obtained after sending a query via TLS.

    *q*, a ``dns.message.Message``, the query to send

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
    query times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to. The default is 853.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for
    the query. If ``None``, the default, a socket is created. Note
    that if a socket is provided, it must be a nonblocking connected
    SSL stream socket, and *where*, *port*, *source*, *source_port*,
    and *ssl_context* are ignored.

    *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
    a TLS connection. If ``None``, the default, creates one with the default
    configuration.

    *server_hostname*, a ``str`` containing the server's hostname. The
    default is ``None``, which means that no hostname is known, and if an
    SSL context is created, hostname checking will be disabled.

    Returns a ``dns.message.Message``.

    """

    if sock:
        #
        # If a socket was provided, there's no special TLS handling needed.
        #
        return tcp(q, where, timeout, port, source, source_port,
                   one_rr_per_rrset, ignore_trailing, sock)

    wire = q.to_wire()
    (begin_time, expiration) = _compute_times(timeout)
    (af, destination, source) = _destination_and_source(where, port,
                                                        source, source_port)
    if ssl_context is None and not sock:
        ssl_context = ssl.create_default_context()
        if server_hostname is None:
            ssl_context.check_hostname = False

    with _make_socket(af, socket.SOCK_STREAM, source, ssl_context=ssl_context,
                      server_hostname=server_hostname) as s:
        _connect(s, destination, expiration)
        _tls_handshake(s, expiration)
        send_tcp(s, wire, expiration)
        (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
                                         q.keyring, q.mac, ignore_trailing)
        r.time = received_time - begin_time
        if not q.is_response(r):
            raise BadResponse
        return r


def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
        timeout=None, port=53, keyring=None, keyname=None, relativize=True,
        lifetime=None, source=None, source_port=0, serial=0,
        use_udp=False, keyalgorithm=dns.tsig.default_algorithm):
    """Return a generator for the responses to a zone transfer.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *zone*, a ``dns.name.Name`` or ``str``, the name of the zone to transfer.

    *rdtype*, an ``int`` or ``str``, the type of zone transfer. The
    default is ``dns.rdatatype.AXFR``. ``dns.rdatatype.IXFR`` can be
    used to do an incremental transfer instead.

    *rdclass*, an ``int`` or ``str``, the class of the zone transfer.
    The default is ``dns.rdataclass.IN``.

    *timeout*, a ``float``, the number of seconds to wait for each
    response message. If None, the default, wait forever.

    *port*, an ``int``, the port to send the message to. The default is 53.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *keyname*, a ``dns.name.Name`` or ``str``, the name of the TSIG
    key to use.

    *relativize*, a ``bool``. If ``True``, all names in the zone will be
    relativized to the zone origin. It is essential that the
    relativize setting matches the one specified to
    ``dns.zone.from_xfr()`` if using this generator to make a zone.

    *lifetime*, a ``float``, the total number of seconds to spend
    doing the transfer. If ``None``, the default, then there is no
    limit on the time the transfer may take.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *serial*, an ``int``, the SOA serial number to use as the base for
    an IXFR diff sequence (only meaningful if *rdtype* is
    ``dns.rdatatype.IXFR``).

    *use_udp*, a ``bool``. If ``True``, use UDP (only meaningful for IXFR).

    *keyalgorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use.

    Raises on errors, and so does the generator.

    Returns a generator of ``dns.message.Message`` objects.
    """

    if isinstance(zone, str):
        zone = dns.name.from_text(zone)
    rdtype = dns.rdatatype.RdataType.make(rdtype)
    q = dns.message.make_query(zone, rdtype, rdclass)
    if rdtype == dns.rdatatype.IXFR:
        rrset = dns.rrset.from_text(zone, 0, 'IN', 'SOA',
                                    '. . %u 0 0 0 0' % serial)
        q.authority.append(rrset)
    if keyring is not None:
        q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(where, port,
                                                        source, source_port)
    if use_udp and rdtype != dns.rdatatype.IXFR:
        raise ValueError('cannot do a UDP AXFR')
    sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
    with _make_socket(af, sock_type, source) as s:
        (_, expiration) = _compute_times(lifetime)
        _connect(s, destination, expiration)
        l = len(wire)
        if use_udp:
            _wait_for_writable(s, expiration)
            s.send(wire)
        else:
            tcpmsg = struct.pack("!H", l) + wire
            _net_write(s, tcpmsg, expiration)
        done = False
        delete_mode = True
        expecting_SOA = False
        soa_rrset = None
        if relativize:
            origin = zone
            oname = dns.name.empty
        else:
            origin = None
            oname = zone
        tsig_ctx = None
        while not done:
            # Per-message deadline: *timeout* from now, capped by the
            # overall *lifetime* expiration.
            (_, mexpiration) = _compute_times(timeout)
            if mexpiration is None or \
               (expiration is not None and mexpiration > expiration):
                mexpiration = expiration
            if use_udp:
                # Use the per-message deadline here too (previously the
                # overall expiration was used, so *timeout* was ignored).
                _wait_for_readable(s, mexpiration)
                (wire, from_address) = s.recvfrom(65535)
            else:
                ldata = _net_read(s, 2, mexpiration)
                (l,) = struct.unpack("!H", ldata)
                wire = _net_read(s, l, mexpiration)
            is_ixfr = (rdtype == dns.rdatatype.IXFR)
            r = dns.message.from_wire(wire, keyring=q.keyring,
                                      request_mac=q.mac, xfr=True,
                                      origin=origin, tsig_ctx=tsig_ctx,
                                      multi=True, one_rr_per_rrset=is_ixfr)
            rcode = r.rcode()
            if rcode != dns.rcode.NOERROR:
                raise TransferError(rcode)
            tsig_ctx = r.tsig_ctx
            answer_index = 0
            if soa_rrset is None:
                if not r.answer or r.answer[0].name != oname:
                    raise dns.exception.FormError(
                        "No answer or RRset not for qname")
                rrset = r.answer[0]
                if rrset.rdtype != dns.rdatatype.SOA:
                    raise dns.exception.FormError("first RRset is not an SOA")
                answer_index = 1
                soa_rrset = rrset.copy()
                if rdtype == dns.rdatatype.IXFR:
                    if dns.serial.Serial(soa_rrset[0].serial) <= serial:
                        #
                        # We're already up-to-date.
                        #
                        done = True
                    else:
                        expecting_SOA = True
            #
            # Process SOAs in the answer section (other than the initial
            # SOA in the first message).
            #
            for rrset in r.answer[answer_index:]:
                if done:
                    raise dns.exception.FormError("answers after final SOA")
                if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
                    if expecting_SOA:
                        if rrset[0].serial != serial:
                            raise dns.exception.FormError(
                                "IXFR base serial mismatch")
                        expecting_SOA = False
                    elif rdtype == dns.rdatatype.IXFR:
                        delete_mode = not delete_mode
                    #
                    # If this SOA RRset is equal to the first we saw then we're
                    # finished. If this is an IXFR we also check that we're
                    # seeing the record in the expected part of the response.
                    #
                    if rrset == soa_rrset and \
                            (rdtype == dns.rdatatype.AXFR or
                             (rdtype == dns.rdatatype.IXFR and delete_mode)):
                        done = True
                elif expecting_SOA:
                    #
                    # We made an IXFR request and are expecting another
                    # SOA RR, but saw something else, so this must be an
                    # AXFR response.
                    #
                    rdtype = dns.rdatatype.AXFR
                    expecting_SOA = False
            if done and q.keyring and not r.had_tsig:
                raise dns.exception.FormError("missing TSIG")
            yield r