1# -*- coding: utf-8 -*-
2
3import argparse
4import binascii
5import gzip
6import logging
7import os
8import subprocess
9import tempfile
10import unittest
11
12import dns.name, dns.rdatatype, dns.rrset, dns.zone
13
14from dnsviz.commands.probe import ZoneFileToServe, ArgHelper, DomainListArgHelper, StandardRecursiveQueryCD, WILDCARD_EXPLICIT_DELEGATION, AnalysisInputError, CustomQueryMixin
15from dnsviz import transport
16from dnsviz.resolver import Resolver
17from dnsviz.ipaddr import IPAddr
18
# Fixture locations, all resolved relative to this test module: zone files
# live under ``zone/`` and a gzipped JSON fixture lives under ``data/``.
DATA_DIR = os.path.dirname(__file__)
_ZONE_DIR = os.path.join(DATA_DIR, 'zone')
EXAMPLE_COM_ZONE = os.path.join(_ZONE_DIR, 'example.com.zone')
EXAMPLE_COM_DELEGATION = os.path.join(_ZONE_DIR, 'example.com.zone-delegation')
EXAMPLE_AUTHORITATIVE = os.path.join(DATA_DIR, 'data', 'example-authoritative.json.gz')
23
24class DNSVizProbeOptionsTestCase(unittest.TestCase):
25    def setUp(self):
26        self.tm = transport.DNSQueryTransportManager()
27        self.resolver = Resolver.from_file('/etc/resolv.conf', StandardRecursiveQueryCD, transport_manager=self.tm)
28        self.helper = DomainListArgHelper(self.resolver)
29        self.logger = logging.getLogger()
30        for handler in self.logger.handlers:
31            self.logger.removeHandler(handler)
32        self.logger.addHandler(logging.NullHandler())
33        try:
34            ArgHelper.bindable_ip('::1')
35        except argparse.ArgumentTypeError:
36            self.use_ipv6 = False
37        else:
38            self.use_ipv6 = True
39        self.first_port = ZoneFileToServe._next_free_port
40        self.custom_query_mixin_edns_options_orig = CustomQueryMixin.edns_options[:]
41
42    def tearDown(self):
43        CustomQueryMixin.edns_options = self.custom_query_mixin_edns_options_orig[:]
44        if self.tm is not None:
45            self.tm.close()
46
47    def test_authoritative_option(self):
48        arg1 = 'example.com+:ns1.example.com=192.0.2.1:1234,ns1.example.com=[2001:db8::1],' + \
49                'ns1.example.com=192.0.2.2,ns2.example.com=[2001:db8::2],a.root-servers.net,192.0.2.3'
50
51        arg1_with_spaces = ' example.com+ : ns1.example.com = [192.0.2.1]:1234 , ns1.example.com = [2001:db8::1], ' + \
52                'ns1.example.com = [192.0.2.2] , ns2.example.com = [2001:db8::2] , a.root-servers.net , 192.0.2.3 '
53
54        arg2 = 'example.com:ns1.example.com=192.0.2.1'
55
56        arg3 = 'example.com:%s' % EXAMPLE_COM_ZONE
57
58        arg4 = 'example.com+:%s' % EXAMPLE_COM_ZONE
59
60        delegation_mapping1 = {
61                (dns.name.from_text('example.com'), dns.rdatatype.NS):
62                        dns.rrset.from_text_list(dns.name.from_text('example.com'), 0, dns.rdataclass.IN, dns.rdatatype.NS,
63                            ['ns1.example.com', 'ns2.example.com', 'a.root-servers.net', 'ns1._dnsviz.example.com']),
64                (dns.name.from_text('ns1.example.com'), dns.rdatatype.A):
65                        dns.rrset.from_text_list(dns.name.from_text('ns1.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.A,
66                            ['192.0.2.1', '192.0.2.2']),
67                (dns.name.from_text('ns1.example.com'), dns.rdatatype.AAAA):
68                        dns.rrset.from_text_list(dns.name.from_text('ns1.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.AAAA,
69                            ['2001:db8::1']),
70                (dns.name.from_text('ns1._dnsviz.example.com'), dns.rdatatype.A):
71                        dns.rrset.from_text_list(dns.name.from_text('ns1._dnsviz.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.A,
72                            ['192.0.2.3']),
73                (dns.name.from_text('ns2.example.com'), dns.rdatatype.AAAA):
74                        dns.rrset.from_text_list(dns.name.from_text('ns2.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.AAAA,
75                            ['2001:db8::2']),
76                (dns.name.from_text('a.root-servers.net'), dns.rdatatype.A):
77                        dns.rrset.from_text_list(dns.name.from_text('a.root-servers.net'), 0, dns.rdataclass.IN, dns.rdatatype.A,
78                            ['198.41.0.4']),
79                (dns.name.from_text('a.root-servers.net'), dns.rdatatype.AAAA):
80                        dns.rrset.from_text_list(dns.name.from_text('a.root-servers.net'), 0, dns.rdataclass.IN, dns.rdatatype.AAAA,
81                            ['2001:503:ba3e::2:30'])
82                }
83        stop_at1 = True
84        odd_ports1 = { (dns.name.from_text('example.com'), IPAddr('192.0.2.1')): 1234 }
85        zone_filename1 = None
86
87        delegation_mapping2 = {
88                (dns.name.from_text('example.com'), dns.rdatatype.NS):
89                        dns.rrset.from_text_list(dns.name.from_text('example.com'), 0, dns.rdataclass.IN, dns.rdatatype.NS,
90                            ['ns1.example.com']),
91                (dns.name.from_text('ns1.example.com'), dns.rdatatype.A):
92                        dns.rrset.from_text_list(dns.name.from_text('ns1.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.A,
93                            ['192.0.2.1'])
94                }
95        stop_at2 = False
96        odd_ports2 = {}
97        zone_filename2 = None
98
99        delegation_mapping3 = {
100                (dns.name.from_text('example.com'), dns.rdatatype.NS):
101                        dns.rrset.from_text_list(dns.name.from_text('example.com'), 0, dns.rdataclass.IN, dns.rdatatype.NS,
102                            []),
103                }
104        stop_at3 = False
105        odd_ports3 = {}
106        zone_filename3 = EXAMPLE_COM_ZONE
107
108        delegation_mapping4 = {
109                (dns.name.from_text('example.com'), dns.rdatatype.NS):
110                        dns.rrset.from_text_list(dns.name.from_text('example.com'), 0, dns.rdataclass.IN, dns.rdatatype.NS,
111                            []),
112                }
113        stop_at4 = True
114        odd_ports4 = {}
115        zone_filename4 = EXAMPLE_COM_ZONE
116
117        obj = self.helper.authoritative_name_server_mappings(arg1)
118        self.assertEqual(obj.domain, dns.name.from_text('example.com'))
119        self.assertEqual(obj.delegation_mapping, delegation_mapping1)
120        self.assertEqual(obj.stop_at, stop_at1)
121        self.assertEqual(obj.odd_ports, odd_ports1)
122        self.assertEqual(obj.filename, zone_filename1)
123
124        obj = self.helper.authoritative_name_server_mappings(arg1_with_spaces)
125        self.assertEqual(obj.domain, dns.name.from_text('example.com'))
126        self.assertEqual(obj.delegation_mapping, delegation_mapping1)
127        self.assertEqual(obj.stop_at, stop_at1)
128        self.assertEqual(obj.odd_ports, odd_ports1)
129        self.assertEqual(obj.filename, zone_filename1)
130
131        obj = self.helper.authoritative_name_server_mappings(arg2)
132        self.assertEqual(obj.domain, dns.name.from_text('example.com'))
133        self.assertEqual(obj.delegation_mapping, delegation_mapping2)
134        self.assertEqual(obj.stop_at, stop_at2)
135        self.assertEqual(obj.odd_ports, odd_ports2)
136        self.assertEqual(obj.filename, zone_filename2)
137
138        obj = self.helper.authoritative_name_server_mappings(arg3)
139        self.assertEqual(obj.domain, dns.name.from_text('example.com'))
140        self.assertEqual(obj.delegation_mapping, delegation_mapping3)
141        self.assertEqual(obj.stop_at, stop_at3)
142        self.assertEqual(obj.odd_ports, odd_ports3)
143        self.assertEqual(obj.filename, zone_filename3)
144
145        obj = self.helper.authoritative_name_server_mappings(arg4)
146        self.assertEqual(obj.domain, dns.name.from_text('example.com'))
147        self.assertEqual(obj.delegation_mapping, delegation_mapping4)
148        self.assertEqual(obj.stop_at, stop_at4)
149        self.assertEqual(obj.odd_ports, odd_ports4)
150        self.assertEqual(obj.filename, zone_filename4)
151
152    def test_authoritative_errors(self):
153        # no mapping
154        arg = 'example.com'
155        with self.assertRaises(argparse.ArgumentTypeError):
156            self.helper.authoritative_name_server_mappings(arg)
157
158        # bad domain name
159        arg = 'example.com:ns1..foo.com'
160        with self.assertRaises(argparse.ArgumentTypeError):
161            self.helper.authoritative_name_server_mappings(arg)
162
163        # bad IPv4 address
164        arg = 'example.com:ns1.foo.com=192'
165        with self.assertRaises(argparse.ArgumentTypeError):
166            self.helper.authoritative_name_server_mappings(arg)
167
168        # Bad IPv6 address
169        arg = 'example.com:ns1.foo.com=2001:db8'
170        with self.assertRaises(argparse.ArgumentTypeError):
171            self.helper.authoritative_name_server_mappings(arg)
172
173        # IPv6 address needs brackets (IP valid even with port stripped)
174        arg = 'example.com:ns1.foo.com=2001:db8::1:3'
175        with self.assertRaises(argparse.ArgumentTypeError):
176            self.helper.authoritative_name_server_mappings(arg)
177
178        # IPv6 address needs brackets (IP invalid with port stripped)
179        arg = 'example.com:ns1.foo.com=2001:db8::3'
180        with self.assertRaises(argparse.ArgumentTypeError):
181            self.helper.authoritative_name_server_mappings(arg)
182
183        # Name does not resolve properly
184        arg = 'example.com:ns1.does-not-exist-foo-bar-baz-123-abc-dnsviz.net'
185        with self.assertRaises(argparse.ArgumentTypeError):
186            self.helper.authoritative_name_server_mappings(arg)
187
188    def test_delegation_option(self):
189        arg1 = 'example.com:ns1.example.com=192.0.2.1:1234,ns1.example.com=[2001:db8::1],' + \
190                'ns1.example.com=192.0.2.2,ns2.example.com=[2001:db8::2]'
191
192        arg1_with_spaces = ' example.com : ns1.example.com = [192.0.2.1]:1234 , ns1.example.com = [2001:db8::1], ' + \
193                'ns1.example.com = [192.0.2.2] , ns2.example.com = [2001:db8::2] '
194
195        arg2 = 'example.com:%s' % EXAMPLE_COM_DELEGATION
196
197        delegation_mapping1 = {
198                (dns.name.from_text('example.com'), dns.rdatatype.NS):
199                        dns.rrset.from_text_list(dns.name.from_text('example.com'), 0, dns.rdataclass.IN, dns.rdatatype.NS,
200                            ['ns1.example.com', 'ns2.example.com']),
201                (dns.name.from_text('ns1.example.com'), dns.rdatatype.A):
202                        dns.rrset.from_text_list(dns.name.from_text('ns1.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.A,
203                            ['192.0.2.1', '192.0.2.2']),
204                (dns.name.from_text('ns1.example.com'), dns.rdatatype.AAAA):
205                        dns.rrset.from_text_list(dns.name.from_text('ns1.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.AAAA,
206                            ['2001:db8::1']),
207                (dns.name.from_text('ns2.example.com'), dns.rdatatype.AAAA):
208                        dns.rrset.from_text_list(dns.name.from_text('ns2.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.AAAA,
209                            ['2001:db8::2']),
210                }
211        stop_at1 = False
212        odd_ports1 = { (dns.name.from_text('example.com'), IPAddr('192.0.2.1')): 1234 }
213
214        delegation_mapping2 = {
215                (dns.name.from_text('example.com'), dns.rdatatype.NS):
216                        dns.rrset.from_text_list(dns.name.from_text('example.com'), 0, dns.rdataclass.IN, dns.rdatatype.NS,
217                            ['ns1.example.com']),
218                (dns.name.from_text('ns1.example.com'), dns.rdatatype.A):
219                        dns.rrset.from_text_list(dns.name.from_text('ns1.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.A,
220                            ['127.0.0.1'])
221                }
222        stop_at2 = False
223        odd_ports2 = {}
224
225        obj = self.helper.delegation_name_server_mappings(arg1)
226        self.assertEqual(obj.domain, dns.name.from_text('example.com'))
227        self.assertEqual(obj.delegation_mapping, delegation_mapping1)
228        self.assertEqual(obj.stop_at, stop_at1)
229        self.assertEqual(obj.odd_ports, odd_ports1)
230
231        obj = self.helper.delegation_name_server_mappings(arg1_with_spaces)
232        self.assertEqual(obj.domain, dns.name.from_text('example.com'))
233        self.assertEqual(obj.delegation_mapping, delegation_mapping1)
234        self.assertEqual(obj.stop_at, stop_at1)
235        self.assertEqual(obj.odd_ports, odd_ports1)
236
237        obj = self.helper.delegation_name_server_mappings(arg2)
238        self.assertEqual(obj.domain, dns.name.from_text('example.com'))
239        self.assertEqual(obj.delegation_mapping, delegation_mapping2)
240        self.assertEqual(obj.stop_at, stop_at2)
241        self.assertEqual(obj.odd_ports, odd_ports2)
242
243    def test_delegation_errors(self):
244        # all the authoritative error tests as well
245
246        # requires name=addr mapping
247        arg = 'example.com:ns1.example.com'
248        with self.assertRaises(argparse.ArgumentTypeError):
249            self.helper.delegation_name_server_mappings(arg)
250
251        # requires name=addr mapping
252        arg = 'example.com:192.0.2.1'
253        with self.assertRaises(argparse.ArgumentTypeError):
254            self.helper.delegation_name_server_mappings(arg)
255
256        # doesn't allow +
257        arg = 'example.com+:ns1.example.com=192.0.2.1'
258        with self.assertRaises(argparse.ArgumentTypeError):
259            self.helper.delegation_name_server_mappings(arg)
260
261        # can't do this for root domain
262        arg = '.:ns1.example.com=192.0.2.1'
263        with self.assertRaises(argparse.ArgumentTypeError):
264            self.helper.delegation_name_server_mappings(arg)
265
266    def test_recursive_option(self):
267        arg1 = 'ns1.example.com=192.0.2.1:1234,ns1.example.com=[2001:db8::1],' + \
268                'ns1.example.com=192.0.2.2,ns2.example.com=[2001:db8::2],a.root-servers.net'
269
270        arg1_with_spaces = ' ns1.example.com = [192.0.2.1]:1234 , ns1.example.com = [2001:db8::1], ' + \
271                'ns1.example.com = [192.0.2.2] , ns2.example.com = [2001:db8::2] , a.root-servers.net '
272
273        delegation_mapping1 = {
274                (WILDCARD_EXPLICIT_DELEGATION, dns.rdatatype.NS):
275                        dns.rrset.from_text_list(WILDCARD_EXPLICIT_DELEGATION, 0, dns.rdataclass.IN, dns.rdatatype.NS,
276                            ['ns1.example.com', 'ns2.example.com', 'a.root-servers.net']),
277                (dns.name.from_text('ns1.example.com'), dns.rdatatype.A):
278                        dns.rrset.from_text_list(dns.name.from_text('ns1.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.A,
279                            ['192.0.2.1', '192.0.2.2']),
280                (dns.name.from_text('ns1.example.com'), dns.rdatatype.AAAA):
281                        dns.rrset.from_text_list(dns.name.from_text('ns1.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.AAAA,
282                            ['2001:db8::1']),
283                (dns.name.from_text('ns2.example.com'), dns.rdatatype.AAAA):
284                        dns.rrset.from_text_list(dns.name.from_text('ns2.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.AAAA,
285                            ['2001:db8::2']),
286                (dns.name.from_text('a.root-servers.net'), dns.rdatatype.A):
287                        dns.rrset.from_text_list(dns.name.from_text('a.root-servers.net'), 0, dns.rdataclass.IN, dns.rdatatype.A,
288                            ['198.41.0.4']),
289                (dns.name.from_text('a.root-servers.net'), dns.rdatatype.AAAA):
290                        dns.rrset.from_text_list(dns.name.from_text('a.root-servers.net'), 0, dns.rdataclass.IN, dns.rdatatype.AAAA,
291                            ['2001:503:ba3e::2:30'])
292                }
293        stop_at1 = False
294        odd_ports1 = { (WILDCARD_EXPLICIT_DELEGATION, IPAddr('192.0.2.1')): 1234 }
295
296        obj = self.helper.recursive_servers_for_domain(arg1)
297        self.assertEqual(obj.domain, WILDCARD_EXPLICIT_DELEGATION)
298        self.assertEqual(obj.delegation_mapping, delegation_mapping1)
299        self.assertEqual(obj.stop_at, stop_at1)
300        self.assertEqual(obj.odd_ports, odd_ports1)
301
302        obj = self.helper.recursive_servers_for_domain(arg1_with_spaces)
303        self.assertEqual(obj.domain, WILDCARD_EXPLICIT_DELEGATION)
304        self.assertEqual(obj.delegation_mapping, delegation_mapping1)
305        self.assertEqual(obj.stop_at, stop_at1)
306        self.assertEqual(obj.odd_ports, odd_ports1)
307
308    def test_recursive_errors(self):
309        # all the authoritative error tests as well
310
311        # doesn't accept file
312        arg = EXAMPLE_COM_DELEGATION
313        with self.assertRaises(argparse.ArgumentTypeError):
314            self.helper.recursive_servers_for_domain(arg)
315
316    def test_ds_option(self):
317        arg1 = 'example.com:34983 10 1 EC358CFAAEC12266EF5ACFC1FEAF2CAFF083C418,' + \
318            '34983 10 2 608D3B089D79D554A1947BD10BEC0A5B1BDBE67B4E60E34B1432ED00 33F24B49'
319
320        delegation_mapping1 = {
321                (dns.name.from_text('example.com'), dns.rdatatype.DS):
322                        dns.rrset.from_text_list(dns.name.from_text('example.com'), 0, dns.rdataclass.IN, dns.rdatatype.DS,
323                            ['34983 10 1 EC358CFAAEC12266EF5ACFC1FEAF2CAFF083C418',
324                                '34983 10 2 608D3B089D79D554A1947BD10BEC0A5B1BDBE67B4E60E34B1432ED00 33F24B49'])
325                }
326
327        arg1_with_spaces = ' example.com : 34983 10 1 EC358CFAAEC12266EF5ACFC1FEAF2CAFF083C418, ' + \
328            ' 34983 10 2 608D3B089D79D554A1947BD10BEC0A5B1BDBE67B4E60E34B1432ED00 33F24B49 '
329
330        arg2 = 'example.com:%s' % EXAMPLE_COM_DELEGATION
331
332        delegation_mapping2 = {
333                (dns.name.from_text('example.com'), dns.rdatatype.DS):
334                        dns.rrset.from_text_list(dns.name.from_text('example.com'), 0, dns.rdataclass.IN, dns.rdatatype.DS,
335                            ['34983 10 1 EC358CFAAEC12266EF5ACFC1FEAF2CAFF083C418',
336                                '34983 10 2 608D3B089D79D554A1947BD10BEC0A5B1BDBE67B4E60E34B1432ED00 33F24B49'])
337                }
338
339
340        obj = self.helper.ds_for_domain(arg1)
341        self.assertEqual(obj.domain, dns.name.from_text('example.com'))
342        self.assertEqual(obj.delegation_mapping, delegation_mapping1)
343
344        obj = self.helper.ds_for_domain(arg1_with_spaces)
345        self.assertEqual(obj.domain, dns.name.from_text('example.com'))
346        self.assertEqual(obj.delegation_mapping, delegation_mapping1)
347
348        obj = self.helper.ds_for_domain(arg2)
349        self.assertEqual(obj.domain, dns.name.from_text('example.com'))
350        self.assertEqual(obj.delegation_mapping, delegation_mapping2)
351
352    def test_ds_error(self):
353        # bad DS record
354        arg = 'example.com:blah'
355        with self.assertRaises(argparse.ArgumentTypeError):
356            obj = self.helper.ds_for_domain(arg)
357
358    def test_positive_int(self):
359        self.assertEqual(ArgHelper.positive_int('1'), 1)
360        self.assertEqual(ArgHelper.positive_int('2'), 2)
361
362        # zero
363        with self.assertRaises(argparse.ArgumentTypeError):
364            ArgHelper.positive_int('0')
365
366        # negative
367        with self.assertRaises(argparse.ArgumentTypeError):
368            ArgHelper.positive_int('-1')
369
370    def test_bindable_ip(self):
371        self.assertEqual(ArgHelper.bindable_ip('127.0.0.1'), IPAddr('127.0.0.1'))
372        if self.use_ipv6:
373            self.assertEqual(ArgHelper.bindable_ip('::1'), IPAddr('::1'))
374
375        # invalid IPv4 address
376        with self.assertRaises(argparse.ArgumentTypeError):
377            ArgHelper.bindable_ip('192.')
378
379        # invalid IPv6 address
380        with self.assertRaises(argparse.ArgumentTypeError):
381            ArgHelper.bindable_ip('2001:')
382
383        # invalid IPv4 to bind to
384        with self.assertRaises(argparse.ArgumentTypeError):
385            ArgHelper.bindable_ip('192.0.2.1')
386
387        # invalid IPv6 to bind to
388        with self.assertRaises(argparse.ArgumentTypeError):
389            ArgHelper.bindable_ip('2001:db8::1')
390
391    def test_valid_url(self):
392        url1 = 'http://www.example.com/foo'
393        url2 = 'https://www.example.com/foo'
394        url3 = 'ws:///path/to/file'
395        url4 = 'ssh://user@example.com/foo'
396
397        self.assertEqual(ArgHelper.valid_url(url1), url1)
398        self.assertEqual(ArgHelper.valid_url(url2), url2)
399        self.assertEqual(ArgHelper.valid_url(url3), url3)
400        self.assertEqual(ArgHelper.valid_url(url4), url4)
401
402        # invalid schema
403        with self.assertRaises(argparse.ArgumentTypeError):
404            ArgHelper.valid_url('ftp://www.example.com/foo')
405
406        # ws with hostname
407        with self.assertRaises(argparse.ArgumentTypeError):
408            ArgHelper.valid_url('ws://www.example.com/foo')
409
410    def test_rrtype_list(self):
411        arg1 = 'A,AAAA,MX,CNAME'
412        arg1_with_spaces = ' A , AAAA , MX , CNAME '
413        arg2 = 'A'
414        arg3 = 'A,BLAH'
415        arg4_empty = ''
416        arg4_empty_spaces = ' '
417
418        type_list1 = [dns.rdatatype.A, dns.rdatatype.AAAA, dns.rdatatype.MX, dns.rdatatype.CNAME]
419        type_list2 = [dns.rdatatype.A]
420        empty_list = []
421
422        self.assertEqual(ArgHelper.comma_separated_dns_types(arg1), type_list1)
423        self.assertEqual(ArgHelper.comma_separated_dns_types(arg1_with_spaces), type_list1)
424        self.assertEqual(ArgHelper.comma_separated_dns_types(arg4_empty), empty_list)
425        self.assertEqual(ArgHelper.comma_separated_dns_types(arg4_empty_spaces), empty_list)
426
427        # invalid schema
428        with self.assertRaises(argparse.ArgumentTypeError):
429            ArgHelper.comma_separated_dns_types(arg3)
430
431    def test_valid_domain_name(self):
432        arg1 = '.'
433        arg2 = 'www.example.com'
434        arg3 = 'www..example.com'
435
436        self.assertEqual(ArgHelper.valid_domain_name(arg1), dns.name.from_text(arg1))
437        self.assertEqual(ArgHelper.valid_domain_name(arg2), dns.name.from_text(arg2))
438
439        # invalid domain name
440        with self.assertRaises(argparse.ArgumentTypeError):
441            ArgHelper.valid_domain_name(arg3)
442
443    def test_nsid_option(self):
444        self.assertEqual(ArgHelper.nsid_option(), dns.edns.GenericOption(3, b''))
445
446    def test_ecs_option(self):
447        arg1 = '192.0.2.0'
448        arg2 = '192.0.2.0/25'
449        arg3 = '192.0.2.255/25'
450        arg4 = '192.0.2.0/24'
451        arg5 = '2001:db8::'
452        arg6 = '2001:db8::/121'
453        arg7 = '2001:db8::ff/121'
454        arg8 = '2001:db8::/120'
455
456
457        ecs_option1 = dns.edns.GenericOption(8, binascii.unhexlify('00012000c0000200'))
458        ecs_option2 = dns.edns.GenericOption(8, binascii.unhexlify('00011900c0000200'))
459        ecs_option3 = dns.edns.GenericOption(8, binascii.unhexlify('00011900c0000280'))
460        ecs_option4 = dns.edns.GenericOption(8, binascii.unhexlify('00011800c00002'))
461        ecs_option5 = dns.edns.GenericOption(8, binascii.unhexlify('0002800020010db8000000000000000000000000'))
462        ecs_option6 = dns.edns.GenericOption(8, binascii.unhexlify('0002790020010db8000000000000000000000000'))
463        ecs_option7 = dns.edns.GenericOption(8, binascii.unhexlify('0002790020010db8000000000000000000000080'))
464        ecs_option8 = dns.edns.GenericOption(8, binascii.unhexlify('0002780020010db80000000000000000000000'))
465
466        self.assertEqual(ArgHelper.ecs_option(arg1), ecs_option1)
467        self.assertEqual(ArgHelper.ecs_option(arg2), ecs_option2)
468        self.assertEqual(ArgHelper.ecs_option(arg3), ecs_option3)
469        self.assertEqual(ArgHelper.ecs_option(arg4), ecs_option4)
470        self.assertEqual(ArgHelper.ecs_option(arg5), ecs_option5)
471        self.assertEqual(ArgHelper.ecs_option(arg6), ecs_option6)
472        self.assertEqual(ArgHelper.ecs_option(arg7), ecs_option7)
473        self.assertEqual(ArgHelper.ecs_option(arg8), ecs_option8)
474
475        # invalid IP address
476        with self.assertRaises(argparse.ArgumentTypeError):
477            ArgHelper.ecs_option('192')
478
479        # invalid length
480        with self.assertRaises(argparse.ArgumentTypeError):
481            ArgHelper.ecs_option('192.0.2.0/foo')
482
483        # invalid length
484        with self.assertRaises(argparse.ArgumentTypeError):
485            ArgHelper.ecs_option('192.0.2.0/33')
486
487        # invalid length
488        with self.assertRaises(argparse.ArgumentTypeError):
489            ArgHelper.ecs_option('2001:db8::/129')
490
491    def test_cookie_option(self):
492        arg1 = '0102030405060708'
493        arg2 = ''
494
495        cookie_option1 = dns.edns.GenericOption(10, binascii.unhexlify('0102030405060708'))
496        cookie_option2 = None
497
498        self.assertEqual(ArgHelper.dns_cookie_option(arg1), cookie_option1)
499        self.assertEqual(ArgHelper.dns_cookie_option(arg2), None)
500
501        self.assertIsInstance(ArgHelper.dns_cookie_rand(), dns.edns.GenericOption)
502
503        # too short
504        with self.assertRaises(argparse.ArgumentTypeError):
505            ArgHelper.dns_cookie_option('01')
506
507        # too long
508        with self.assertRaises(argparse.ArgumentTypeError):
509            ArgHelper.dns_cookie_option('010203040506070809')
510
511        # non-hexadecimal
512        with self.assertRaises(argparse.ArgumentTypeError):
513            ArgHelper.dns_cookie_option('010203040506070h')
514
515    def test_delegation_aggregation(self):
516        args1 = ['-A', '-N', 'example.com:ns1.example.com=192.0.2.1,ns1.example.com=[2001:db8::1]',
517                        '-N', 'example.com:ns1.example.com=192.0.2.4',
518                        '-N', 'example.com:ns2.example.com=192.0.2.2',
519                        '-N', 'example.com:ns3.example.com=192.0.2.3']
520        args2 = ['-A', '-N', 'example.com:ns1.example.com=192.0.2.1',
521                        '-D', 'example.com:34983 10 1 EC358CFAAEC12266EF5ACFC1FEAF2CAFF083C418',
522                        '-D', 'example.com:34983 10 2 608D3B089D79D554A1947BD10BEC0A5B1BDBE67B4E60E34B1432ED00 33F24B49']
523        args3 = ['-A', '-N', 'example.com:ns1.example.com=192.0.2.1',
524                        '-N', 'example1.com:ns1.example1.com=192.0.2.2']
525        args4 = ['-A', '-N', 'example.com:ns1.example.com=192.0.2.1',
526                        '-N', 'example.net:ns1.example.net=192.0.2.2']
527
528        explicit_delegations1 = {
529                (dns.name.from_text('com'), dns.rdatatype.NS):
530                        dns.rrset.from_text_list(dns.name.from_text('com'), 0, dns.rdataclass.IN, dns.rdatatype.NS,
531                            ['localhost']),
532                }
533        explicit_delegations2 = {
534                (dns.name.from_text('com'), dns.rdatatype.NS):
535                        dns.rrset.from_text_list(dns.name.from_text('com'), 0, dns.rdataclass.IN, dns.rdatatype.NS,
536                            ['localhost']),
537                }
538        explicit_delegations3 = {
539                (dns.name.from_text('com'), dns.rdatatype.NS):
540                        dns.rrset.from_text_list(dns.name.from_text('com'), 0, dns.rdataclass.IN, dns.rdatatype.NS,
541                            ['localhost']),
542                }
543        explicit_delegations4 = {
544                (dns.name.from_text('com'), dns.rdatatype.NS):
545                        dns.rrset.from_text_list(dns.name.from_text('com'), 0, dns.rdataclass.IN, dns.rdatatype.NS,
546                            ['localhost']),
547                (dns.name.from_text('net'), dns.rdatatype.NS):
548                        dns.rrset.from_text_list(dns.name.from_text('net'), 0, dns.rdataclass.IN, dns.rdatatype.NS,
549                            ['localhost']),
550                }
551
552        for ex in (explicit_delegations1, explicit_delegations2, explicit_delegations3, explicit_delegations4):
553            if self.use_ipv6:
554                    ex[(dns.name.from_text('localhost'), dns.rdatatype.AAAA)] = \
555                            dns.rrset.from_text_list(dns.name.from_text('localhost'), 0, dns.rdataclass.IN, dns.rdatatype.AAAA,
556                                ['::1'])
557                    loopback_ip = IPAddr('::1')
558            else:
559                    ex[(dns.name.from_text('localhost'), dns.rdatatype.A)] = \
560                            dns.rrset.from_text_list(dns.name.from_text('localhost'), 0, dns.rdataclass.IN, dns.rdatatype.A,
561                                ['127.0.0.1'])
562                    loopback_ip = IPAddr('127.0.0.1')
563
564        odd_ports1 = { (dns.name.from_text('com'), loopback_ip): self.first_port }
565        odd_ports2 = { (dns.name.from_text('com'), loopback_ip): self.first_port }
566        odd_ports3 = { (dns.name.from_text('com'), loopback_ip): self.first_port }
567        odd_ports4 = {
568                (dns.name.from_text('com'), loopback_ip): self.first_port,
569                (dns.name.from_text('net'), loopback_ip): self.first_port + 1,
570                }
571
572        if self.use_ipv6:
573            rdata = b'AAAA ::1'
574        else:
575            rdata = b'A 127.0.0.1'
576
577        zone_contents1 = b'''@ 600 IN SOA localhost. root.localhost. 1 1800 900 86400 600
578@ 600 IN NS @
579@ 600 IN ''' + rdata + \
580b'''
581example 0 IN NS ns1.example
582example 0 IN NS ns2.example
583example 0 IN NS ns3.example
584ns1.example 0 IN A 192.0.2.1
585ns1.example 0 IN A 192.0.2.4
586ns1.example 0 IN AAAA 2001:db8::1
587ns2.example 0 IN A 192.0.2.2
588ns3.example 0 IN A 192.0.2.3
589'''
590        zone_contents2 = b'''@ 600 IN SOA localhost. root.localhost. 1 1800 900 86400 600
591@ 600 IN NS @
592@ 600 IN ''' + rdata + \
593b'''
594example 0 IN DS 34983 10 1 ec358cfaaec12266ef5acfc1feaf2caff083c418
595example 0 IN DS 34983 10 2 608d3b089d79d554a1947bd10bec0a5b1bdbe67b4e60e34b1432ed0033f24b49
596example 0 IN NS ns1.example
597ns1.example 0 IN A 192.0.2.1
598'''
599
600        ZoneFileToServe._next_free_port = self.first_port
601
602        arghelper1 = ArgHelper(self.resolver, self.logger)
603        arghelper1.build_parser('probe')
604        arghelper1.parse_args(args1)
605        arghelper1.aggregate_delegation_info()
606        zone_to_serve = arghelper1._zones_to_serve[0]
607        zone_obj = dns.zone.from_file(zone_to_serve.filename, dns.name.from_text('com'))
608        zone_obj_other = dns.zone.from_text(zone_contents1, dns.name.from_text('com'))
609        self.assertEqual(zone_obj, zone_obj_other)
610        self.assertEqual(arghelper1.explicit_delegations, explicit_delegations1)
611        self.assertEqual(arghelper1.odd_ports, odd_ports1)
612
613        ZoneFileToServe._next_free_port = self.first_port
614
615        arghelper2 = ArgHelper(self.resolver, self.logger)
616        arghelper2.build_parser('probe')
617        arghelper2.parse_args(args2)
618        arghelper2.aggregate_delegation_info()
619        zone_to_serve = arghelper2._zones_to_serve[0]
620        zone_obj = dns.zone.from_file(zone_to_serve.filename, dns.name.from_text('com'))
621        zone_obj_other = dns.zone.from_text(zone_contents2, dns.name.from_text('com'))
622        self.assertEqual(zone_obj, zone_obj_other)
623        self.assertEqual(arghelper2.explicit_delegations, explicit_delegations2)
624        self.assertEqual(arghelper2.odd_ports, odd_ports2)
625
626        ZoneFileToServe._next_free_port = self.first_port
627
628        arghelper3 = ArgHelper(self.resolver, self.logger)
629        arghelper3.build_parser('probe')
630        arghelper3.parse_args(args3)
631        arghelper3.aggregate_delegation_info()
632        self.assertEqual(arghelper3.explicit_delegations, explicit_delegations3)
633        self.assertEqual(arghelper3.odd_ports, odd_ports3)
634
635        ZoneFileToServe._next_free_port = self.first_port
636
637        arghelper4 = ArgHelper(self.resolver, self.logger)
638        arghelper4.build_parser('probe')
639        arghelper4.parse_args(args4)
640        arghelper4.aggregate_delegation_info()
641        self.assertEqual(arghelper4.explicit_delegations, explicit_delegations4)
642        self.assertEqual(arghelper4.odd_ports, odd_ports4)
643
644    def test_delegation_authoritative_aggregation(self):
645        args1 = ['-A', '-N', 'example.com:ns1.example.com=192.0.2.1,ns1.example.com=[2001:db8::1]',
646                '-x', 'foo.com:ns1.foo.com=192.0.2.3:50503']
647
648        explicit_delegations1 = {
649                (dns.name.from_text('com'), dns.rdatatype.NS):
650                        dns.rrset.from_text_list(dns.name.from_text('com'), 0, dns.rdataclass.IN, dns.rdatatype.NS,
651                            ['localhost']),
652                (dns.name.from_text('foo.com'), dns.rdatatype.NS):
653                        dns.rrset.from_text_list(dns.name.from_text('foo.com'), 0, dns.rdataclass.IN, dns.rdatatype.NS,
654                            ['ns1.foo.com']),
655                (dns.name.from_text('ns1.foo.com'), dns.rdatatype.A):
656                        dns.rrset.from_text_list(dns.name.from_text('ns1.foo.com'), 0, dns.rdataclass.IN, dns.rdatatype.A,
657                            ['192.0.2.3']),
658                }
659
660        for ex in (explicit_delegations1,):
661            if self.use_ipv6:
662                    ex[(dns.name.from_text('localhost'), dns.rdatatype.AAAA)] = \
663                            dns.rrset.from_text_list(dns.name.from_text('localhost'), 0, dns.rdataclass.IN, dns.rdatatype.AAAA,
664                                ['::1'])
665                    loopback_ip = IPAddr('::1')
666            else:
667                    ex[(dns.name.from_text('localhost'), dns.rdatatype.A)] = \
668                            dns.rrset.from_text_list(dns.name.from_text('localhost'), 0, dns.rdataclass.IN, dns.rdatatype.A,
669                                ['127.0.0.1'])
670                    loopback_ip = IPAddr('127.0.0.1')
671
672        odd_ports1 = { (dns.name.from_text('com'), loopback_ip): self.first_port,
673                            (dns.name.from_text('foo.com'), IPAddr('192.0.2.3')): 50503,
674                    }
675
676        ZoneFileToServe._next_free_port = self.first_port
677
678        arghelper1 = ArgHelper(self.resolver, self.logger)
679        arghelper1.build_parser('probe')
680        arghelper1.parse_args(args1)
681        arghelper1.aggregate_delegation_info()
682        self.assertEqual(arghelper1.explicit_delegations, explicit_delegations1)
683        self.assertEqual(arghelper1.odd_ports, odd_ports1)
684
685    def test_delegation_authoritative_aggregation_errors(self):
686        args1 = ['-A', '-N', 'example.com:ns1.example.com=192.0.2.1,ns1.example.com=[2001:db8::1]',
687                '-x', 'com:ns1.foo.com=192.0.2.3']
688
689        arghelper1 = ArgHelper(self.resolver, self.logger)
690        arghelper1.build_parser('probe')
691        arghelper1.parse_args(args1)
692
693        # com is specified with -x but example.com is specified with -N
694        with self.assertRaises(argparse.ArgumentTypeError):
695            arghelper1.aggregate_delegation_info()
696
697    def test_recursive_aggregation(self):
698        args1 = ['-s', 'ns1.example.com=192.0.2.1,ns1.example.com=[2001:db8::1]',
699                        '-s', 'ns1.example.com=192.0.2.4,a.root-servers.net']
700
701        explicit_delegations1 = {
702                (WILDCARD_EXPLICIT_DELEGATION, dns.rdatatype.NS):
703                        dns.rrset.from_text_list(WILDCARD_EXPLICIT_DELEGATION, 0, dns.rdataclass.IN, dns.rdatatype.NS,
704                            ['ns1.example.com', 'a.root-servers.net']),
705                (dns.name.from_text('ns1.example.com'), dns.rdatatype.A):
706                        dns.rrset.from_text_list(dns.name.from_text('ns1.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.A,
707                            ['192.0.2.1', '192.0.2.4']),
708                (dns.name.from_text('ns1.example.com'), dns.rdatatype.AAAA):
709                        dns.rrset.from_text_list(dns.name.from_text('ns1.example.com'), 0, dns.rdataclass.IN, dns.rdatatype.AAAA,
710                            ['2001:db8::1']),
711                (dns.name.from_text('a.root-servers.net'), dns.rdatatype.A):
712                        dns.rrset.from_text_list(dns.name.from_text('a.root-servers.net'), 0, dns.rdataclass.IN, dns.rdatatype.A,
713                            ['198.41.0.4']),
714                (dns.name.from_text('a.root-servers.net'), dns.rdatatype.AAAA):
715                        dns.rrset.from_text_list(dns.name.from_text('a.root-servers.net'), 0, dns.rdataclass.IN, dns.rdatatype.AAAA,
716                            ['2001:503:ba3e::2:30'])
717                }
718
719        odd_ports1 = {}
720
721        arghelper1 = ArgHelper(self.resolver, self.logger)
722        arghelper1.build_parser('probe')
723        arghelper1.parse_args(args1)
724        arghelper1.aggregate_delegation_info()
725        self.assertEqual(arghelper1.explicit_delegations, explicit_delegations1)
726        self.assertEqual(arghelper1.odd_ports, odd_ports1)
727
728    def test_option_combination_errors(self):
729
730        # Names, input file, or names file required
731        args = []
732        arghelper = ArgHelper(self.resolver, self.logger)
733        arghelper.build_parser('probe')
734        arghelper.parse_args(args)
735        with self.assertRaises(argparse.ArgumentTypeError):
736            arghelper.check_args()
737
738        # Names file and command-line domain names are mutually exclusive
739        args = ['-f', '/dev/null', 'example.com']
740        arghelper = ArgHelper(self.resolver, self.logger)
741        arghelper.build_parser('probe')
742        arghelper.parse_args(args)
743        with self.assertRaises(argparse.ArgumentTypeError):
744            arghelper.check_args()
745        arghelper.args.names_file.close()
746
747        # Authoritative analysis and recursive servers
748        args = ['-A', '-s', '192.0.2.1', 'example.com']
749        arghelper = ArgHelper(self.resolver, self.logger)
750        arghelper.build_parser('probe')
751        arghelper.parse_args(args)
752        with self.assertRaises(argparse.ArgumentTypeError):
753            arghelper.check_args()
754
755        # Authoritative servers with recursive analysis
756        args = ['-x', 'example.com:ns1.example.com=192.0.2.1', 'example.com']
757        arghelper = ArgHelper(self.resolver, self.logger)
758        arghelper.build_parser('probe')
759        arghelper.parse_args(args)
760        with self.assertRaises(argparse.ArgumentTypeError):
761            arghelper.check_args()
762
763        # Delegation information with recursive analysis
764        args = ['-N', 'example.com:ns1.example.com=192.0.2.1', 'example.com']
765        arghelper = ArgHelper(self.resolver, self.logger)
766        arghelper.build_parser('probe')
767        arghelper.parse_args(args)
768        with self.assertRaises(argparse.ArgumentTypeError):
769            arghelper.check_args()
770
771        # Delegation information with recursive analysis
772        args = [ '-D', 'example.com:34983 10 1 EC358CFAAEC12266EF5ACFC1FEAF2CAFF083C418', 'example.com']
773        arghelper = ArgHelper(self.resolver, self.logger)
774        arghelper.build_parser('probe')
775        arghelper.parse_args(args)
776        with self.assertRaises(argparse.ArgumentTypeError):
777            arghelper.check_args()
778
779    def test_ceiling(self):
780        args = ['-a', 'com', 'example.com']
781        arghelper = ArgHelper(self.resolver, self.logger)
782        arghelper.build_parser('probe')
783        arghelper.parse_args(args)
784        arghelper.set_kwargs()
785        self.assertEqual(arghelper.ceiling, dns.name.from_text('com'))
786
787        args = ['example.com']
788        arghelper = ArgHelper(self.resolver, self.logger)
789        arghelper.build_parser('probe')
790        arghelper.parse_args(args)
791        arghelper.set_kwargs()
792        self.assertEqual(arghelper.ceiling, dns.name.root)
793
794        args = ['-A', 'example.com']
795        arghelper = ArgHelper(self.resolver, self.logger)
796        arghelper.build_parser('probe')
797        arghelper.parse_args(args)
798        arghelper.set_kwargs()
799        self.assertIsNone(arghelper.ceiling)
800
801    def test_ip4_ipv6(self):
802        args = []
803        arghelper = ArgHelper(self.resolver, self.logger)
804        arghelper.build_parser('probe')
805        arghelper.parse_args(args)
806        arghelper.set_kwargs()
807        self.assertEqual(arghelper.try_ipv4, True)
808        self.assertEqual(arghelper.try_ipv6, True)
809
810        args = ['-4', '-6']
811        arghelper = ArgHelper(self.resolver, self.logger)
812        arghelper.build_parser('probe')
813        arghelper.parse_args(args)
814        arghelper.set_kwargs()
815        self.assertEqual(arghelper.try_ipv4, True)
816        self.assertEqual(arghelper.try_ipv6, True)
817
818        args = ['-4']
819        arghelper = ArgHelper(self.resolver, self.logger)
820        arghelper.build_parser('probe')
821        arghelper.parse_args(args)
822        arghelper.set_kwargs()
823        self.assertEqual(arghelper.try_ipv4, True)
824        self.assertEqual(arghelper.try_ipv6, False)
825
826        args = ['-6']
827        arghelper = ArgHelper(self.resolver, self.logger)
828        arghelper.build_parser('probe')
829        arghelper.parse_args(args)
830        arghelper.set_kwargs()
831        self.assertEqual(arghelper.try_ipv4, False)
832        self.assertEqual(arghelper.try_ipv6, True)
833
834    def test_client_ip(self):
835        args = []
836        arghelper = ArgHelper(self.resolver, self.logger)
837        arghelper.build_parser('probe')
838        arghelper.parse_args(args)
839        arghelper.set_kwargs()
840        self.assertIsNone(arghelper.client_ipv4)
841        self.assertIsNone(arghelper.client_ipv6)
842
843        args = ['-b', '127.0.0.1']
844        if self.use_ipv6:
845            args.extend(['-b', '::1'])
846        arghelper = ArgHelper(self.resolver, self.logger)
847        arghelper.build_parser('probe')
848        arghelper.parse_args(args)
849        arghelper.set_kwargs()
850        self.assertEqual(arghelper.client_ipv4, IPAddr('127.0.0.1'))
851        if self.use_ipv6:
852            self.assertEqual(arghelper.client_ipv6, IPAddr('::1'))
853
854    def test_th_factories(self):
855        args = ['example.com']
856        arghelper = ArgHelper(self.resolver, self.logger)
857        arghelper.build_parser('probe')
858        arghelper.parse_args(args)
859        arghelper.set_kwargs()
860        self.assertIsNone(arghelper.th_factories)
861
862        args = ['-u', 'http://example.com/', 'example.com']
863        arghelper = ArgHelper(self.resolver, self.logger)
864        arghelper.build_parser('probe')
865        arghelper.parse_args(args)
866        arghelper.set_kwargs()
867        self.assertIsInstance(arghelper.th_factories[0], transport.DNSQueryTransportHandlerHTTPFactory)
868
869        args = ['-u', 'ws:///dev/null', 'example.com']
870        arghelper = ArgHelper(self.resolver, self.logger)
871        arghelper.build_parser('probe')
872        arghelper.parse_args(args)
873        arghelper.set_kwargs()
874        self.assertIsInstance(arghelper.th_factories[0], transport.DNSQueryTransportHandlerWebSocketServerFactory)
875
876        args = ['-u', 'ssh://example.com/', 'example.com']
877        arghelper = ArgHelper(self.resolver, self.logger)
878        arghelper.build_parser('probe')
879        arghelper.parse_args(args)
880        arghelper.set_kwargs()
881        self.assertIsInstance(arghelper.th_factories[0], transport.DNSQueryTransportHandlerRemoteCmdFactory)
882
883    def test_edns_options(self):
884        CustomQueryMixin.edns_options = self.custom_query_mixin_edns_options_orig[:]
885
886        # None
887        args = ['-c', '', 'example.com']
888        arghelper = ArgHelper(self.resolver, self.logger)
889        arghelper.build_parser('probe')
890        arghelper.parse_args(args)
891        arghelper.set_kwargs()
892        self.assertEqual(len(CustomQueryMixin.edns_options), 0)
893
894        CustomQueryMixin.edns_options = self.custom_query_mixin_edns_options_orig[:]
895
896        # Only DNS cookie
897        args = ['example.com']
898        arghelper = ArgHelper(self.resolver, self.logger)
899        arghelper.build_parser('probe')
900        arghelper.parse_args(args)
901        arghelper.set_kwargs()
902        self.assertEqual(set([o.otype for o in CustomQueryMixin.edns_options]), set([10]))
903
904        CustomQueryMixin.edns_options = self.custom_query_mixin_edns_options_orig[:]
905
906        # All EDNS options
907        args = ['-n', '-e', '192.0.2.0/24', 'example.com']
908        arghelper = ArgHelper(self.resolver, self.logger)
909        arghelper.build_parser('probe')
910        arghelper.parse_args(args)
911        arghelper.set_kwargs()
912        self.assertEqual(set([o.otype for o in CustomQueryMixin.edns_options]), set([3, 8, 10]))
913
914        CustomQueryMixin.edns_options = self.custom_query_mixin_edns_options_orig[:]
915
916    def test_ingest_input(self):
917        with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as example_bad_json:
918            example_bad_json.write(b'{')
919
920        with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as example_no_version:
921            example_no_version.write(b'{}')
922
923        with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as example_invalid_version_1:
924            example_invalid_version_1.write(b'{ "_meta._dnsviz.": { "version": 1.11 } }')
925
926        with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as example_invalid_version_2:
927            example_invalid_version_2.write(b'{ "_meta._dnsviz.": { "version": 5.0 } }')
928
929        with gzip.open(EXAMPLE_AUTHORITATIVE, 'rb') as example_auth_in:
930            with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as example_auth_out:
931                example_auth_out.write(example_auth_in.read())
932
933        try:
934            args = ['-r', example_auth_out.name]
935            arghelper = ArgHelper(self.resolver, self.logger)
936            arghelper.build_parser('probe')
937            arghelper.parse_args(args)
938            arghelper.ingest_input()
939
940            # Bad json
941            args = ['-r', example_bad_json.name]
942            arghelper = ArgHelper(self.resolver, self.logger)
943            arghelper.build_parser('probe')
944            arghelper.parse_args(args)
945            with self.assertRaises(AnalysisInputError):
946                arghelper.ingest_input()
947
948            # No version
949            args = ['-r', example_no_version.name]
950            arghelper = ArgHelper(self.resolver, self.logger)
951            arghelper.build_parser('probe')
952            arghelper.parse_args(args)
953            with self.assertRaises(AnalysisInputError):
954                arghelper.ingest_input()
955
956            # Invalid version
957            args = ['-r', example_invalid_version_1.name]
958            arghelper = ArgHelper(self.resolver, self.logger)
959            arghelper.build_parser('probe')
960            arghelper.parse_args(args)
961            with self.assertRaises(AnalysisInputError):
962                arghelper.ingest_input()
963
964            # Invalid version
965            args = ['-r', example_invalid_version_2.name]
966            arghelper = ArgHelper(self.resolver, self.logger)
967            arghelper.build_parser('probe')
968            arghelper.parse_args(args)
969            with self.assertRaises(AnalysisInputError):
970                arghelper.ingest_input()
971
972        finally:
973            for tmpfile in (example_auth_out, example_bad_json, example_no_version, \
974                    example_invalid_version_1, example_invalid_version_2):
975                os.remove(tmpfile.name)
976
977    def test_ingest_names(self):
978        args = ['example.com', 'example.net']
979        arghelper = ArgHelper(self.resolver, self.logger)
980        arghelper.build_parser('probe')
981        arghelper.parse_args(args)
982        arghelper.ingest_names()
983        self.assertEqual(list(arghelper.names), [dns.name.from_text('example.com'), dns.name.from_text('example.net')])
984
985        unicode_name = 'テスト'
986
987        args = [unicode_name]
988        arghelper = ArgHelper(self.resolver, self.logger)
989        arghelper.build_parser('probe')
990        arghelper.parse_args(args)
991        arghelper.ingest_names()
992        self.assertEqual(list(arghelper.names), [dns.name.from_text('xn--zckzah.')])
993
994        with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as names_file:
995            names_file.write('example.com\nexample.net\n'.encode('utf-8'))
996
997        with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as names_file_unicode:
998            try:
999                names_file_unicode.write(('%s\n' % (unicode_name)).encode('utf-8'))
1000            # python3/python2 dual compatibility
1001            except UnicodeDecodeError:
1002                names_file_unicode.write(('%s\n' % (unicode_name)))
1003
1004        with tempfile.NamedTemporaryFile('wb', prefix='dnsviz', delete=False) as example_names_only:
1005            example_names_only.write(b'{ "_meta._dnsviz.": { "version": 1.2, "names": [ "example.com.", "example.net.", "example.org." ] } }')
1006
1007        try:
1008            args = ['-f', names_file.name]
1009            arghelper = ArgHelper(self.resolver, self.logger)
1010            arghelper.build_parser('probe')
1011            arghelper.parse_args(args)
1012            arghelper.ingest_names()
1013            self.assertEqual(list(arghelper.names), [dns.name.from_text('example.com'), dns.name.from_text('example.net')])
1014
1015            args = ['-f', names_file_unicode.name]
1016            arghelper = ArgHelper(self.resolver, self.logger)
1017            arghelper.build_parser('probe')
1018            arghelper.parse_args(args)
1019            arghelper.ingest_names()
1020            self.assertEqual(list(arghelper.names), [dns.name.from_text('xn--zckzah.')])
1021
1022            args = ['-r', example_names_only.name]
1023            arghelper = ArgHelper(self.resolver, self.logger)
1024            arghelper.build_parser('probe')
1025            arghelper.parse_args(args)
1026            arghelper.ingest_input()
1027            arghelper.ingest_names()
1028            self.assertEqual(list(arghelper.names), [dns.name.from_text('example.com'), dns.name.from_text('example.net'), dns.name.from_text('example.org')])
1029
1030            args = ['-r', example_names_only.name, 'example.com']
1031            arghelper = ArgHelper(self.resolver, self.logger)
1032            arghelper.build_parser('probe')
1033            arghelper.parse_args(args)
1034            arghelper.ingest_input()
1035            arghelper.ingest_names()
1036            self.assertEqual(list(arghelper.names), [dns.name.from_text('example.com')])
1037        finally:
1038            for tmpfile in (names_file, names_file_unicode, example_names_only):
1039                os.remove(tmpfile.name)
1040
# Allow running this module directly as a script, in addition to collection
# by a unittest-compatible test runner.
if __name__ == '__main__':
    unittest.main()
1043