1#!/usr/bin/env python3
2
3import itertools
4import os
5import sys
6from typing import Any, Callable, Iterable, Iterator, Optional, List, Union, Set  # noqa
7
8import pydnstest.augwrap
9import pydnstest.matchpart
10import pydnstest.scenario
11
# Anything a lint failure (RplintFail) can be attached to: an ENTRY, a STEP or a RANGE.
Element = Union["Entry", "Step", pydnstest.scenario.Range]

# RCODE mnemonics; used to tell rcodes apart from the other values in an ENTRY's reply set.
RCODES = {"NOERROR", "FORMERR", "SERVFAIL", "NXDOMAIN", "NOTIMP", "REFUSED", "YXDOMAIN", "YXRRSET",
          "NXRRSET", "NOTAUTH", "NOTZONE", "BADVERS", "BADSIG", "BADKEY", "BADTIME", "BADMODE",
          "BADNAME", "BADALG", "BADTRUNC", "BADCOOKIE"}
# DNS header flag mnemonics.
FLAGS = {"QR", "AA", "TC", "RD", "RA", "AD", "CD"}
# Names of the message sections an ENTRY may contain (used as "/entry/section/<name>" paths).
SECTIONS = {"question", "answer", "authority", "additional"}
19
20
class RplintError(ValueError):
    """Raised when an .rpl file fails one or more lint checks.

    The exception message lists every failure in ``fails``, each rendered with
    ``str()`` and terminated by a newline.
    """

    def __init__(self, fails):
        # Join once instead of growing the message with += inside a loop.
        super().__init__("".join(str(fail) + "\n" for fail in fails))
27
28
def get_line_number(file: str, char_number: int) -> int:
    """Map a character offset in ``file`` to the line number reported in lint output.

    Line lengths (newlines included) are accumulated in order. The first line whose
    cumulative length reaches ``char_number`` determines the result, which is that
    line's 0-based index plus 2.

    :param file: path of the text file to scan
    :param char_number: character offset, e.g. an Augeas node's ``char`` position
    :returns: the computed line number, or 0 if ``char_number`` lies beyond the file
    """
    pos = 0
    # A context manager guarantees the handle is closed, including on the early return.
    with open(file) as handle:
        for number, line in enumerate(handle):
            pos += len(line)
            if pos >= char_number:
                # NOTE(review): the +2 offset is preserved as-is; presumably it
                # compensates for how the Augeas char position is counted — confirm.
                return number + 2
    return 0
36
37
def is_empty(iterable: Iterator[Any]) -> bool:
    """Return True when ``iterable`` has no further items.

    When the iterator is not empty, this consumes one element from it.
    """
    missing = object()
    return next(iterable, missing) is missing
44
45
class Entry:
    """An ENTRY node of the scenario tree with its keyword sets and records."""

    def __init__(self, node: pydnstest.augwrap.AugeasNode) -> None:
        def value_set(path):
            # The set of ``value`` attributes of all nodes matching ``path``.
            return {child.value for child in node.match(path)}

        def node_list(path):
            # All nodes matching ``path``, materialized as a list.
            return list(node.match(path))

        self.match = value_set("/match")
        self.adjust = value_set("/adjust")
        # Records of the individual sections...
        self.answer = node_list("/section/answer/record")
        self.authority = node_list("/section/authority/record")
        self.additional = node_list("/section/additional/record")
        self.reply = value_set("/reply")
        # ...and every record regardless of its section.
        self.records = node_list("/section/*/record")
        self.node = node
56
57
class Step:
    """A STEP node of the scenario tree: the node, its type and its optional ENTRY."""

    def __init__(self, node: pydnstest.augwrap.AugeasNode) -> None:
        self.node = node
        self.type = node["/type"].value
        # If looking up or parsing the ENTRY raises KeyError, the step has no entry.
        try:
            parsed = Entry(node["/entry"])  # type: Optional[Entry]
        except KeyError:
            parsed = None
        self.entry = parsed
66
67
class RplintFail:
    """One lint failure: the file, the offending element and its line, plus optional detail.

    ``check`` starts as None; ``RplintTest.run_checks`` sets it to the check function
    that produced the failure, and ``__str__`` reads its name and docstring.
    """

    def __init__(self, test: "RplintTest",
                 element: Optional[Element] = None,
                 etc: str = "") -> None:
        self.path = test.path
        self.element = element  # type: Optional[Element]
        # Without an element there is no node position, so offset 0 is used.
        offset = 0 if element is None else element.node.char
        self.line = get_line_number(self.path, offset)
        self.etc = etc
        self.check = None  # type: Optional[Callable[[RplintTest], List[RplintFail]]]

    def __str__(self):
        # "<file>:<line> <check name>: <check docstring>", plus " (<etc>)" when etc is set.
        filename = os.path.basename(self.path)
        name, doc = self.check.__name__, self.check.__doc__
        if not self.etc:
            return "{}:{} {}: {}".format(filename, self.line, name, doc)
        return "{}:{} {}: {} ({})".format(filename, self.line, name, doc, self.etc)
84
85
class RplintTest:
    """Lint state for a single .rpl scenario file.

    The file is loaded through Augeas with the ``Deckard`` lens (its tree feeds the
    entries, steps and ranges) and through ``pydnstest.scenario.parse_file`` (which
    supplies the configuration). ``checks`` lists the functions ``run_checks`` applies.
    """

    def __init__(self, path: str) -> None:
        real_path = os.path.realpath(path)
        aug = pydnstest.augwrap.AugeasWrapper(confpath=real_path,
                                              lens='Deckard',
                                              loadpath=os.path.join(os.path.dirname(__file__),
                                                                    'pydnstest'))
        self.node = aug.tree
        self.name = os.path.basename(path)
        self.path = path

        _, self.config = pydnstest.scenario.parse_file(real_path)
        self.range_entries = [Entry(node) for node in self.node.match("/scenario/range/entry")]
        self.steps = [Step(node) for node in self.node.match("/scenario/step")]
        self.step_entries = [step.entry for step in self.steps if step.entry is not None]
        self.entries = self.range_entries + self.step_entries

        self.ranges = [pydnstest.scenario.Range(n) for n in self.node.match("/scenario/range")]

        # None until run_checks() has been called.
        self.fails = None  # type: Optional[List[RplintFail]]
        self.checks = [
            entry_more_than_one_rcode,
            entry_no_qname_qtype_copy_query,
            # Commented out for now until we implement selective turning off of checks
            # entry_ns_in_authority,
            range_overlapping_ips,
            range_shadowing_match_rules,
            step_check_answer_no_match,
            step_query_match,
            step_query_sections,
            step_query_qr,
            step_section_unchecked,
            step_unchecked_match,
            step_unchecked_rcode,
            scenario_ad_or_rrsig_no_ta,
            scenario_timestamp,
            config_trust_anchor_trailing_period_missing,
            step_duplicate_id,
        ]

    def run_checks(self) -> bool:
        """Run every function in ``self.checks`` and collect the failures they report.

        Each failure gets its ``check`` attribute set to the function that reported it.

        :returns: True iff no check reported any failure
        """
        self.fails = []
        for check in self.checks:
            fails = check(self)
            for fail in fails:
                fail.check = check
            self.fails += fails
        return not self.fails

    def print_fails(self) -> None:
        """Print each failure collected by the last ``run_checks`` call.

        :raises RuntimeError: if ``run_checks`` has not been called yet
        """
        if self.fails is None:
            raise RuntimeError("Maybe you should run some test first…")
        for fail in self.fails:
            print(fail)
143
144
def config_trust_anchor_trailing_period_missing(test: RplintTest) -> List[RplintFail]:
    """Trust-anchor option in configuration contains domain without trailing period"""
    # The docstring above is printed as this check's failure message.
    # Every offending trust-anchor is reported, not just the first one.
    fails = []
    for conf in test.config:
        if conf[0] != "trust-anchor":
            continue
        # The domain is the first whitespace-separated token of the value. An empty
        # value has no domain at all; report it instead of failing with IndexError.
        tokens = conf[1].split()
        if not tokens or not tokens[0].endswith("."):
            fails.append(RplintFail(test, etc=conf[1]))
    return fails
152
153
def scenario_timestamp(test: RplintTest) -> List[RplintFail]:
    """RRSIG record present in test but no val-override-date or val-override-timestamp in config"""
    # The docstring above is printed as this check's failure message.
    # Consult the config first: with an override present nothing is reported, so
    # there is no point building failures (each RplintFail re-reads the .rpl file).
    if any(k[0] in ("val-override-date", "val-override-timestamp") for k in test.config):
        return []
    # One failure per entry carrying at least one RRSIG record.
    return [RplintFail(test, entry) for entry in test.entries
            if any(record["/type"].value == "RRSIG" for record in entry.records)]
166
167
def entry_no_qname_qtype_copy_query(test: RplintTest) -> List[RplintFail]:
    """ENTRY without qname and qtype in MATCH and without copy_query in ADJUST"""
    # Only RANGE entries are checked. An entry is fine if it matches "question",
    # matches both "qname" and "qtype", or has "copy_query" in ADJUST.
    return [RplintFail(test, entry) for entry in test.range_entries
            if not ("question" in entry.match
                    or ("qname" in entry.match and "qtype" in entry.match)
                    or "copy_query" in entry.adjust)]
177
178
def entry_ns_in_authority(test: RplintTest) -> List[RplintFail]:
    """ENTRY has authority section with NS records, consider using MATCH subdomain"""
    # The docstring above is printed as this check's failure message.
    # Each offending RANGE entry is reported once, however many NS records its
    # authority section holds.
    return [RplintFail(test, entry) for entry in test.range_entries
            if "subdomain" not in entry.match
            and any(record["/type"].value == "NS" for record in entry.authority)]
188
189
def entry_more_than_one_rcode(test: RplintTest) -> List[RplintFail]:
    """ENTRY has more than one rcode in REPLY"""
    # The docstring above is printed as this check's failure message; the rcodes
    # counted here come from the entry's REPLY values (entry.reply), not MATCH.
    return [RplintFail(test, entry) for entry in test.entries
            if len(RCODES & entry.reply) > 1]
197
198
def scenario_ad_or_rrsig_no_ta(test: RplintTest) -> List[RplintFail]:
    """AD or RRSIG present in test but no trust-anchor present in config"""
    # The docstring above is printed as this check's failure message.
    # A configured trust-anchor satisfies the check, so look at the config first and
    # skip building failures (each RplintFail re-reads the .rpl file) when present.
    if any(k[0] == "trust-anchor" for k in test.config):
        return []
    # One failure per entry that mentions AD (in REPLY or MATCH) or carries an RRSIG.
    return [RplintFail(test, entry) for entry in test.entries
            if "AD" in entry.reply or "AD" in entry.match
            or any(record["/type"].value == "RRSIG" for record in entry.records)]
215
216
def step_query_match(test: RplintTest) -> List[RplintFail]:
    """STEP QUERY has a MATCH rule"""
    # Flags QUERY steps whose ENTRY declares any MATCH value.
    fails = []
    for step in test.steps:
        if step.type != "QUERY" or not step.entry:
            continue
        if step.entry.match:
            fails.append(RplintFail(test, step))
    return fails
221
222
def step_query_sections(test: RplintTest) -> List[RplintFail]:
    """STEP QUERY has some records in sections other than QUESTION"""
    # Flags QUERY steps whose ENTRY has any answer, authority or additional records.
    fails = []
    for step in test.steps:
        entry = step.entry
        if step.type != "QUERY" or not entry:
            continue
        if entry.answer or entry.authority or entry.additional:
            fails.append(RplintFail(test, step))
    return fails
227
228
def step_query_qr(test: RplintTest) -> List[RplintFail]:
    """STEP QUERY specified QR=1 flag (i.e. message is an answer)"""
    # Membership in an empty reply set is already False, so no separate
    # emptiness test is needed.
    return [RplintFail(test, step) for step in test.steps
            if step.type == "QUERY" and step.entry and "QR" in step.entry.reply]
233
234
def step_check_answer_no_match(test: RplintTest) -> List[RplintFail]:
    """ENTRY in STEP CHECK_ANSWER has no MATCH rule"""
    # CHECK_ANSWER steps that actually have an ENTRY, in step order.
    checked = (step for step in test.steps if step.type == "CHECK_ANSWER" and step.entry)
    return [RplintFail(test, step) for step in checked if not step.entry.match]
239
240
def step_unchecked_rcode(test: RplintTest) -> List[RplintFail]:
    """ENTRY specifies rcode but STEP MATCH does not check for it."""
    fails = []
    for step in test.steps:
        entry = step.entry
        # Only CHECK_ANSWER entries that do not already match "all" are relevant.
        if step.type != "CHECK_ANSWER" or not entry or "all" in entry.match:
            continue
        # The reply values include an rcode, yet "rcode" is not matched.
        if entry.reply & RCODES and "rcode" not in entry.match:
            fails.append(RplintFail(test, entry))
    return fails
249
250
def step_unchecked_match(test: RplintTest) -> List[RplintFail]:
    """ENTRY specifies flags but MATCH does not check for them"""
    fails = []
    for step in test.steps:
        entry = step.entry
        # Only CHECK_ANSWER entries that do not already match "all" are relevant.
        if step.type != "CHECK_ANSWER" or not entry or "all" in entry.match:
            continue
        # Everything in the reply set that is not an rcode is treated as a flag.
        flags = entry.reply - RCODES
        if flags and "flags" not in entry.match:
            fails.append(RplintFail(test, entry, str(flags)))
    return fails
261
262
def step_section_unchecked(test: RplintTest) -> List[RplintFail]:
    """ENTRY has non-empty sections but MATCH does not check for all of them"""
    fails = []
    for step in test.steps:
        if step.type != "CHECK_ANSWER" or not step.entry or "all" in step.entry.match:
            continue
        # SECTIONS is a set of strings whose iteration order changes between interpreter
        # runs (string hash randomization); sort it so failures come out in a stable order.
        for section in sorted(SECTIONS):
            # Cheap membership test first; only query the Augeas tree when it matters.
            if section in step.entry.match:
                continue
            if not is_empty(step.node.match("/entry/section/" + section + "/*")):
                fails.append(RplintFail(test, step.entry, section))
    return fails
273
274
def range_overlapping_ips(test: RplintTest) -> List[RplintFail]:
    """RANGE has common IPs with some previous overlapping RANGE"""
    fails = []
    for earlier, later in itertools.combinations(test.ranges, 2):
        # Using each range's (a, b) as its bounds: they overlap when the smaller
        # upper bound is not below the larger lower bound.
        bounds_overlap = min(earlier.b, later.b) >= max(earlier.a, later.a)
        if not bounds_overlap or not earlier.addresses & later.addresses:
            continue
        info = "previous range on line %d" % get_line_number(test.path, earlier.node.char)
        fails.append(RplintFail(test, later, info))
    return fails
285
286
def range_shadowing_match_rules(test: RplintTest) -> List[RplintFail]:
    """ENTRY has no effect since one of previous entries has the same or broader match rules"""
    fails = []
    for r in test.ranges:
        for e1, e2 in itertools.combinations(r.stored, 2):
            # e1.match raises ValueError when e1 does not match e2's message;
            # then e1 cannot shadow e2.
            try:
                e1.match(e2.message)
            except ValueError:
                continue
            # Not shadowing when e1 matches on more fields than e2
            # (presumably match_fields are sets, making ">" a strict-superset test).
            if e1.match_fields > e2.match_fields:
                continue
            # Not shadowing when only e2 matches on "subdomain".
            if "subdomain" not in e1.match_fields and "subdomain" in e2.match_fields:
                continue
            # Resolve the line number (a full file read) only for genuine failures.
            info = "previous entry on line %d" % get_line_number(test.path, e1.node.char)
            fails.append(RplintFail(test, e2, info))
    return fails
304
305
def step_duplicate_id(test: RplintTest) -> List[RplintFail]:
    """STEP has the same ID as one of previous ones"""
    fails = []
    seen = set()  # type: Set[int]
    for step in test.steps:
        step_id = step.node.value
        if step_id not in seen:
            # First occurrence of this ID: remember it and move on.
            seen.add(step_id)
            continue
        fails.append(RplintFail(test, step))
    return fails
316
317
318# TODO: This will make sense after we fix how we handle defaults in deckard.aug and scenario.py
319# We might just not use defaults altogether as testbound does
320# if "copy_id" not in adjust:
321#    entry_error(test, entry, "copy_id should be in ADJUST")
322
def test_run_rplint(rpl_path: str) -> None:
    """Lint ``rpl_path``; raise RplintError listing the failures if any check fails."""
    linted = RplintTest(rpl_path)
    if not linted.run_checks():
        raise RplintError(linted.fails)
328
329
def main():
    """Command-line entry point: lint the single .rpl file given as the first argument.

    Exits with status 2 on bad usage or a non-file path, 1 when any check fails,
    and 0 otherwise.
    """
    if len(sys.argv) < 2:
        print("usage: %s <path to rpl file>" % sys.argv[0])
        sys.exit(2)
    test_path = sys.argv[1]
    if not os.path.isfile(test_path):
        print("rplint.py works on single file only.")
        print("Use rplint.sh with --scenarios=<directory with rpls> to run on rpls.")
        sys.exit(2)
    print("Linting %s" % test_path)
    linted = RplintTest(test_path)
    passed = linted.run_checks()
    linted.print_fails()
    sys.exit(0 if passed else 1)
348
349
# Run the linter when this file is executed directly as a script.
if __name__ == '__main__':
    main()
352