1#!/usr/bin/env python2.7
2# Copyright 2015 gRPC authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Starts a local DNS server for use in tests"""
16
17import argparse
18import sys
19import yaml
20import signal
21import os
22import threading
23import time
24
25import twisted
26import twisted.internet
27import twisted.internet.reactor
28import twisted.internet.threads
29import twisted.internet.defer
30import twisted.internet.protocol
31import twisted.names
32import twisted.names.client
33import twisted.names.dns
34import twisted.names.server
35from twisted.names import client, server, common, authority, dns
36import argparse
37import platform
38
# A-record that is always served alongside the configured records, used as a
# server health check. The name intentionally has no trailing '.', matching
# how the other record names are stored once their trailing '.' is stripped
# (the form twisted's authority lookup expects here).
_SERVER_HEALTH_CHECK_RECORD_NAME = ('health-check-local-dns-server-is-alive.'
                                    'resolver-tests.grpctestingexp')
_SERVER_HEALTH_CHECK_RECORD_DATA = '123.123.123.123'
41
42
class NoFileAuthority(authority.FileAuthority):
    """A FileAuthority whose zone is supplied in memory instead of a file.

    FileAuthority's own initializer is bypassed; only the common
    ResolverBase initialization runs, after which the SOA and record table
    are installed directly.

    Args:
      soa: a (zone_name, dns.Record_SOA) pair describing the zone.
      records: dict mapping a record name to a list of dns.Record_* objects.
    """

    def __init__(self, soa, records):
        # Intentionally not calling FileAuthority.__init__ (no file to load).
        common.ResolverBase.__init__(self)
        self.soa, self.records = soa, records
50
51
def start_local_dns_server(args):
    """Build an in-memory DNS zone from the YAML config and serve it.

    Each record in ``args.records_config_path`` is registered under
    ``<name>.<resolver_tests_common_zone_name>`` with the trailing '.'
    stripped. The optional ``args.add_a_record`` entry and the server
    health-check A record are added on top. The resulting zone is then
    served over both TCP and UDP on ``args.port``. This call blocks in
    ``twisted.internet.reactor.run()``.

    Args:
      args: argparse namespace providing ``records_config_path``, ``port``
        and ``add_a_record``.

    Raises:
      ValueError: if a joined record name does not end with '.' (typically
        because the configured zone name lacks its trailing '.').
    """
    all_records = {}

    def _push_record(name, r):
        print('pushing record: |%s|' % name)
        all_records.setdefault(name, []).append(r)

    def _maybe_split_up_txt_data(name, txt_data, r_ttl):
        # A single TXT character-string holds at most 255 bytes
        # (RFC 1035 section 3.3), so longer data is split into consecutive
        # chunks of at most 255 characters.
        txt_data_list = [
            txt_data[start:start + 255]
            for start in range(0, len(txt_data), 255)
        ]
        _push_record(name, dns.Record_TXT(*txt_data_list, ttl=r_ttl))

    with open(args.records_config_path) as config:
        # The config is plain data. safe_load never constructs arbitrary
        # Python objects. yaml.load() without an explicit Loader is
        # deprecated (PyYAML 5.1) and rejected outright by PyYAML 6.
        test_records_config = yaml.safe_load(config)
    common_zone_name = test_records_config['resolver_tests_common_zone_name']
    for group in test_records_config['resolver_component_tests']:
        for name, records in group['records'].items():
            for record in records:
                r_type = record['type']
                r_data = record['data']
                r_ttl = int(record['TTL'])
                record_full_name = '%s.%s' % (name, common_zone_name)
                # Raise rather than assert: under `python -O` an assert is
                # stripped and the slice below would silently drop a real
                # character of the name.
                if not record_full_name.endswith('.'):
                    raise ValueError('record name %r must end with "."' %
                                     record_full_name)
                # Records are keyed without the trailing '.'.
                record_full_name = record_full_name[:-1]
                if r_type == 'A':
                    _push_record(record_full_name,
                                 dns.Record_A(r_data, ttl=r_ttl))
                elif r_type == 'AAAA':
                    _push_record(record_full_name,
                                 dns.Record_AAAA(r_data, ttl=r_ttl))
                elif r_type == 'SRV':
                    # Data format: "<priority> <weight> <port> <target>".
                    p, w, port, target = r_data.split(' ')
                    target_full_name = '%s.%s' % (target, common_zone_name)
                    _push_record(
                        record_full_name,
                        dns.Record_SRV(int(p), int(w), int(port),
                                       target_full_name, ttl=r_ttl))
                elif r_type == 'TXT':
                    _maybe_split_up_txt_data(record_full_name, r_data, r_ttl)
                # NOTE(review): any other record type is silently ignored.
    # Add an optional IPv4 record if specified.
    if args.add_a_record:
        extra_host, extra_host_ipv4 = args.add_a_record.split(':')
        _push_record(extra_host, dns.Record_A(extra_host_ipv4, ttl=0))
    # Server health check record.
    _push_record(_SERVER_HEALTH_CHECK_RECORD_NAME,
                 dns.Record_A(_SERVER_HEALTH_CHECK_RECORD_DATA, ttl=0))
    soa_record = dns.Record_SOA(mname=common_zone_name)
    test_domain_com = NoFileAuthority(
        soa=(common_zone_name, soa_record),
        records=all_records,
    )
    # Named so it does not shadow the module-level `server` import.
    server_factory = twisted.names.server.DNSServerFactory(
        authorities=[test_domain_com], verbose=2)
    server_factory.noisy = 2
    twisted.internet.reactor.listenTCP(args.port, server_factory)
    dns_proto = twisted.names.dns.DNSDatagramProtocol(server_factory)
    dns_proto.noisy = 2
    twisted.internet.reactor.listenUDP(args.port, dns_proto)
    print('starting local dns server on 127.0.0.1:%s' % args.port)
    print('starting twisted.internet.reactor')
    twisted.internet.reactor.suggestThreadPoolSize(1)
    twisted.internet.reactor.run()
126
127
def _quit_on_signal(signum, _frame):
    """Signal handler: stop the Twisted reactor and exit with status 0.

    The signal can arrive before the reactor has been started (for example
    while the records config is still being loaded) or after it was already
    stopped. In those cases reactor.stop() raises ReactorNotRunning. Without
    handling, that exception would escape this handler and replace the
    intended clean exit with a traceback.

    Args:
      signum: number of the received signal.
      _frame: interrupted stack frame (unused).
    """
    from twisted.internet.error import ReactorNotRunning
    print('Received SIGNAL %d. Quitting with exit code 0' % signum)
    try:
        twisted.internet.reactor.stop()
    except ReactorNotRunning:
        # Nothing to stop (not started yet, or already stopping).
        pass
    sys.stdout.flush()
    sys.exit(0)
133
134
def flush_stdout_loop():
    """Flush stdout once a second, then SIGTERM this process after ~10 min.

    Each iteration flushes stdout and sleeps for one second. After 600
    iterations the process sends SIGTERM to itself, so a server left
    running by a short-lived test cannot linger indefinitely.
    """
    poll_interval_secs = 1
    # Prevent zombies. Tests that use this server are short-lived.
    iteration_budget = 60 * 10
    for _ in range(iteration_budget):
        sys.stdout.flush()
        time.sleep(poll_interval_secs)
    print('Process timeout reached, or cancelled. Exitting 0.')
    os.kill(os.getpid(), signal.SIGTERM)
146
147
def main():
    """Parse flags, install signal handlers and run the local DNS server.

    Blocks until the Twisted reactor stops.
    """
    argp = argparse.ArgumentParser(
        description='Local DNS Server for resolver tests')
    # Both flags are mandatory: with no value the server previously crashed
    # later with an obscure TypeError (open(None) / listening on port None).
    argp.add_argument('-p',
                      '--port',
                      required=True,
                      type=int,
                      help='Port for DNS server to listen on for TCP and UDP.')
    argp.add_argument(
        '-r',
        '--records_config_path',
        required=True,
        type=str,
        help='Path to the resolver_test_record_groups.yaml file.')
    argp.add_argument(
        '--add_a_record',
        default=None,
        type=str,
        help=('Add an A record via the command line. Useful for when we '
              'need to serve a one-off A record that is under a '
              'different domain than the rest of the records configured in '
              '--records_config_path (which all need to be under the '
              'same domain). Format: <name>:<ipv4 address>'))
    args = argp.parse_args()
    signal.signal(signal.SIGTERM, _quit_on_signal)
    signal.signal(signal.SIGINT, _quit_on_signal)
    output_flush_thread = threading.Thread(target=flush_stdout_loop)
    # Daemon so it never keeps the interpreter alive on its own. The
    # attribute form replaces Thread.setDaemon(), deprecated in Python 3.10.
    output_flush_thread.daemon = True
    output_flush_thread.start()
    start_local_dns_server(args)
180
181
if __name__ == '__main__':
    # Script entry point: parse flags and serve until the reactor stops.
    main()
184