1#!/usr/bin/env python3
2#
3# Unix SMB/CIFS implementation.
4# Copyright (C) Volker Lendecke 2017
5#
6# This program is free software; you can redistribute it and/or modify
7# it under the terms of the GNU General Public License as published by
8# the Free Software Foundation; either version 3 of the License, or
9# (at your option) any later version.
10#
11# This program is distributed in the hope that it will be useful,
12# but WITHOUT ANY WARRANTY; without even the implied warranty of
13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14# GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License
17# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18#
19# Used by selftest to proxy DNS queries to the correct testenv DC.
20# See selftest/target/README for more details.
21# Based on the EchoServer example from python docs
22
23import threading
24import sys
25import select
26import socket
27import time
28from samba.dcerpc import dns
29import samba.ndr as ndr
30
31if sys.version_info[0] < 3:
32    import SocketServer
33    sserver = SocketServer
34else:
35    import socketserver
36    sserver = socketserver
37
# Seconds to wait for a reply from the server a query is forwarded to
# (passed to socket.settimeout()).  handle() also prints a debug line for
# any request that takes longer than a fifth of this value.
DNS_REQUEST_TIMEOUT = 10
39
40
class DnsHandler(sserver.BaseRequestHandler):
    """Proxy a single DNS query to the PDC of the matching testenv realm.

    For UDP servers socketserver passes ``self.request`` as a
    ``(data, socket)`` pair.  ``handle()`` parses the query, picks a
    destination with ``forwarder()`` and sends a reply (forwarded,
    NXDOMAIN or SERVFAIL) back to ``self.client_address``.  The server
    object is expected to carry a ``realm_to_ip_mappings`` dict
    (realm -> IPv4 address), which ``get_pdc_ipv4_addr()`` reads.
    """

    # Reverse lookup tables (numeric code -> constant name) built from the
    # DNS_QTYPE_* and DNS_RCODE_* constants of samba.dcerpc.dns; they are
    # only used to make debug output readable.
    dns_qtype_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_QTYPE_')}
    dns_rcode_strings = {v: k for k, v in vars(dns).items()
                         if k.startswith('DNS_RCODE_')}

    def dns_qtype_string(self, qtype):
        """Return a readable name for a query type code.

        Unknown codes are rendered as their number instead of raising
        KeyError, so logging can never prevent a reply from being sent.
        """
        return self.dns_qtype_strings.get(qtype, str(qtype))

    def dns_rcode_string(self, rcode):
        """Return a readable name for a response code.

        Unknown codes are rendered as their number instead of raising
        KeyError, so logging can never prevent a reply from being sent.
        """
        return self.dns_rcode_strings.get(rcode, str(rcode))

    def dns_transaction_udp(self, packet, host):
        """Send ``packet`` to ``host`` on UDP port 53 and return the reply.

        :param packet: the ``dns.name_packet`` query to forward.
        :param host: address of the server to forward to (an AF_INET
            socket is used, so this must be IPv4).
        :returns: the reply unpacked as a ``dns.name_packet``.
        :raises socket.error: if sending or receiving fails, including a
            timeout after DNS_REQUEST_TIMEOUT seconds.  The error is
            printed before it is re-raised.
        """
        send_packet = ndr.ndr_pack(packet)
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) as s:
                s.settimeout(DNS_REQUEST_TIMEOUT)
                s.connect((host, 53))
                s.sendall(send_packet, 0)
                recv_packet = s.recv(2048, 0)
        except socket.error as err:
            # Log the exception itself: socket.timeout carries errno None.
            print("Error sending to host %s for name %s: %s\n" %
                  (host, packet.questions[0].name, err))
            raise
        return ndr.ndr_unpack(dns.name_packet, recv_packet)

    def get_pdc_ipv4_addr(self, lookup_name):
        """Map a DNS name to the IPv4 address of its testenv realm's PDC.

        Realms are tried longest first, so a child realm wins over a parent
        realm that is also a suffix of the name.  Matching ignores case.
        Returns None if no realm matches.
        """
        realm_to_ip_mappings = self.server.realm_to_ip_mappings
        name = lookup_name.lower()

        for realm in sorted(realm_to_ip_mappings, key=len, reverse=True):
            if name.endswith(realm.lower()):
                # return the corresponding IP address for this realm's PDC
                return realm_to_ip_mappings[realm]

        return None

    def forwarder(self, name):
        """Decide how the query for ``name`` should be answered.

        Returns one of:
          'ignore'  - send no reply at all
          'fail'    - reply SERVFAIL
          'torture' - reply NXDOMAIN
          an address string from the realm mapping - forward the query there
          None      - no realm matched; reply NXDOMAIN
        """
        lname = name.lower()

        # check for special cases used by tests (e.g. dns_forwarder.py)
        if lname.endswith('an-address-that-will-not-resolve'):
            return 'ignore'
        if lname.endswith('dsfsdfs'):
            return 'fail'
        # Checks lname minus its last two characters, i.e. matches
        # TORTURE100, TORTURE101, ...
        if lname.endswith("torture1", 0, len(lname) - 2):
            return 'torture'
        if lname.endswith(('_none_.example.com',
                           'torturedom.samba.example.com')):
            return 'torture'

        # return the testenv PDC matching the realm being requested
        return self.get_pdc_ipv4_addr(lname)

    @staticmethod
    def _error_reply(query, rcode):
        """Turn ``query`` (in place) into a reply carrying ``rcode``."""
        query.operation |= dns.DNS_FLAG_REPLY
        query.operation |= dns.DNS_FLAG_RECURSION_AVAIL
        query.operation |= rcode
        return query

    def handle(self):
        """Answer one DNS query received by the hub.

        The query is forwarded or answered locally according to
        ``forwarder()``.  If no reply could be obtained, a SERVFAIL reply is
        sent.  Requests slower than DNS_REQUEST_TIMEOUT/5 seconds are logged.
        """
        start = time.monotonic()
        data, sock = self.request
        query = ndr.ndr_unpack(dns.name_packet, data)
        name = query.questions[0].name
        forwarder = self.forwarder(name)
        response = None

        if forwarder == 'ignore':
            # Special test case: deliberately send nothing back.
            return
        elif forwarder == 'fail':
            # Fall through to the SERVFAIL reply below.
            pass
        elif forwarder in ('torture', None):
            response = self._error_reply(query, dns.DNS_RCODE_NXDOMAIN)
        else:
            try:
                response = self.dns_transaction_udp(query, forwarder)
            except socket.error:
                # Already printed by dns_transaction_udp(); reply SERVFAIL
                # rather than leaving the client without any answer.
                response = None

        if response is None:
            response = self._error_reply(query, dns.DNS_RCODE_SERVFAIL)

        send_packet = ndr.ndr_pack(response)

        tdiff = time.monotonic() - start
        errcode = response.operation & dns.DNS_RCODE
        if tdiff > (DNS_REQUEST_TIMEOUT / 5):
            print("dns_hub: forwarder[%s] client[%s] name[%s][%s] %s response.operation[0x%x] tdiff[%s]\n" %
                (forwarder, self.client_address, name,
                 self.dns_qtype_string(query.questions[0].question_type),
                 self.dns_rcode_string(errcode), response.operation, tdiff))

        try:
            sock.sendto(send_packet, self.client_address)
        except socket.error as err:
            print("dns_hub: Error sending response to client[%s] for name[%s] tdiff[%s]: %s\n" %
                (self.client_address, name, tdiff, err))
153
154
class server_thread(threading.Thread):
    """Thread whose only job is to run ``server.serve_forever()``.

    ``run()`` returns after some other thread calls ``server.shutdown()``.
    A marker line is printed once the serve loop has exited.
    """

    def __init__(self, server):
        super().__init__()
        self.server = server

    def run(self):
        srv = self.server
        srv.serve_forever()
        print("dns_hub: after serve_forever()")
163
164
def main():
    """Run the DNS hub until stdin becomes ready or the timeout expires.

    Usage: dns_hub.py TIMEOUT HOST MAPPING
      TIMEOUT  seconds to keep serving before shutting down
      HOST     address to listen on (UDP port 53)
      MAPPING  comma-separated realm=ip pairs,
               e.g. "a.example.com=10.0.0.1,b.example.com=10.0.0.2"
    """
    if len(sys.argv) < 4:
        print("Usage: dns_hub.py TIMEOUT HOST MAPPING")
        sys.exit(1)

    # poll() takes milliseconds
    timeout = int(sys.argv[1]) * 1000
    timeout = min(timeout, 2**31 - 1)  # poll with 32-bit int can't take more
    host = sys.argv[2]

    server = sserver.UDPServer((host, 53), DnsHandler)

    # we pass in the realm-to-IP mappings as a comma-separated key=value
    # string. Convert this back into a dictionary that the DnsHandler can use.
    # Split on the first '=' only and skip empty entries (e.g. trailing comma).
    realm_mapping = dict(kv.split('=', 1)
                         for kv in sys.argv[3].split(',') if kv)
    server.realm_to_ip_mappings = realm_mapping

    print("dns_hub will proxy DNS requests for the following realms:")
    for realm, ip in server.realm_to_ip_mappings.items():
        print("  {0} ==> {1}".format(realm, ip))

    t = server_thread(server)
    t.start()
    # Block until stdin is readable/hung up (presumably the parent process
    # going away -- confirm with the selftest harness) or the timeout expires.
    p = select.poll()
    stdin = sys.stdin.fileno()
    p.register(stdin, select.POLLIN)
    p.poll(timeout)
    print("dns_hub: after poll()")
    server.shutdown()
    t.join()
    # Release the listening socket; shutdown() only stops the serve loop.
    server.server_close()
    print("dns_hub: before exit()")
    sys.exit(0)
196
# Only start the hub when run as a script, not when imported.
if __name__ == "__main__":
    main()
198