1"""matchpart is used to compare two DNS messages using a single criterion"""
2
from typing import (  # noqa
    Any, Hashable, Sequence, Tuple, Union)

import dns.edns
import dns.flags
import dns.name
import dns.rcode
import dns.rdatatype
import dns.set
9
MismatchValue = Union[str, Sequence[Any]]


class DataMismatch(Exception):
    """Raised by the matchers when an expected value differs from the received one.

    ``exp_val`` holds the expected value and ``got_val`` the value actually
    observed.  Instances compare equal and hash by this pair of values.
    """

    def __init__(self, exp_val, got_val):
        # Forward both values to Exception so ``self.args`` is populated.
        # BaseException.__reduce__ rebuilds an instance as ``cls(*self.args)``;
        # with empty args, unpickling failed with a missing-argument TypeError.
        super().__init__(exp_val, got_val)
        self.exp_val = exp_val
        self.got_val = got_val

    @staticmethod
    def format_value(value: MismatchValue) -> str:
        """Render *value* for messages: lists are space-joined, anything else str()."""
        if isinstance(value, list):
            return ' '.join([str(val) for val in value])
        else:
            return str(value)

    def __str__(self) -> str:
        return 'expected "{}" got "{}"'.format(
            self.format_value(self.exp_val),
            self.format_value(self.got_val))

    def __eq__(self, other):
        return (isinstance(other, DataMismatch)
                and self.exp_val == other.exp_val
                and self.got_val == other.got_val)

    def __ne__(self, other):
        return not self.__eq__(other)

    @property
    def key(self) -> Tuple[Hashable, Hashable]:
        """Hashable form of ``(exp_val, got_val)``.

        Lists and ``dns.set.Set`` instances are converted (recursively) into
        tuples; all other values are used as-is.
        """
        def make_hashable(value):
            if isinstance(value, (list, dns.set.Set)):
                value = (make_hashable(item) for item in value)
                value = tuple(value)
            return value

        return (make_hashable(self.exp_val), make_hashable(self.got_val))

    def __hash__(self) -> int:
        return hash(self.key)
51
52
def compare_val(exp, got):
    """Compare two arbitrary objects.

    Returns True when ``exp != got`` is false; otherwise raises
    DataMismatch(exp, got).
    """
    differs = exp != got
    if not differs:
        return True
    raise DataMismatch(exp, got)
58
59
def compare_rrs(expected, got):
    """Compare two lists of RRsets as multisets (order is ignored).

    Every RRset in *expected* must be matched by a distinct equal RRset in
    *got*, and *got* must contain nothing left over.  The previous two-way
    membership test wrongly accepted e.g. ``[a, a, b]`` vs ``[a, b, b]``.

    Returns True on match; raises DataMismatch(expected, got) otherwise.
    """
    unmatched = list(got)
    for rr in expected:
        try:
            # list.remove uses the same identity-or-equality test as ``in``.
            unmatched.remove(rr)
        except ValueError:
            raise DataMismatch(expected, got) from None
    if unmatched:
        raise DataMismatch(expected, got)
    return True
71
72
def compare_rrs_types(exp_val, got_val, skip_rrsigs):
    """Check that two RRset lists contain the same set of RR types.

    Only types are compared, not record data; multiplicity is ignored.
    RRSIG RRsets are keyed by the type they cover and sort after plain
    types.  With *skip_rrsigs* true, RRsets of type RRSIG are ignored.

    Returns True on match.  Raises DataMismatch with sorted tuples of
    textual type names (e.g. ``'RRSIG(A)'``) when the sets differ.
    """
    def rr_ordering_key(rrset):
        if rrset.covers:
            return rrset.covers, 1  # RRSIGs go to the end of RRtype list
        return rrset.rdtype, 0

    def key_to_text(rrtype, rrsig):
        if rrsig:
            return 'RRSIG(%s)' % dns.rdatatype.to_text(rrtype)
        return dns.rdatatype.to_text(rrtype)

    if skip_rrsigs:
        exp_val = (rrset for rrset in exp_val
                   if rrset.rdtype != dns.rdatatype.RRSIG)
        got_val = (rrset for rrset in got_val
                   if rrset.rdtype != dns.rdatatype.RRSIG)

    exp_types = frozenset(rr_ordering_key(rrset) for rrset in exp_val)
    got_types = frozenset(rr_ordering_key(rrset) for rrset in got_val)
    if exp_types != got_types:
        raise DataMismatch(tuple(key_to_text(*i) for i in sorted(exp_types)),
                           tuple(key_to_text(*i) for i in sorted(got_types)))
    # Return True like the other comparators instead of falling off the end.
    return True
99
100
def check_question(question):
    """Reject QUESTION sections holding more than one record.

    The question matchers only ever inspect ``question[0]``.  The old
    ``> 2`` test let two-record sections through, contradicting the message.

    Raises NotImplementedError if *question* has more than one entry.
    """
    if len(question) > 1:
        raise NotImplementedError("More than one record in QUESTION SECTION.")
104
105
def match_opcode(exp, got):
    """Compare the header opcode of the expected and received messages."""
    expected_opcode = exp.opcode()
    return compare_val(expected_opcode, got.opcode())
109
110
def _match_question_field(exp, got, extract):
    """Compare one field of the single QUESTION record of two messages.

    Both question sections are first validated with ``check_question``.
    ``extract`` maps a question RRset to the value being compared.

    Returns True when both sections are empty or the extracted values are
    equal.  Raises DataMismatch when only one section is empty (the empty
    side is reported as ``"<empty question>"``) or when the values differ.
    """
    check_question(exp.question)
    check_question(got.question)
    if not exp.question and not got.question:
        return True
    if not exp.question:
        raise DataMismatch("<empty question>", extract(got.question[0]))
    if not got.question:
        raise DataMismatch(extract(exp.question[0]), "<empty question>")
    return compare_val(extract(exp.question[0]),
                       extract(got.question[0]))


def match_qtype(exp, got):
    """Compare the QTYPE (``rdtype``) of the question in *exp* and *got*."""
    return _match_question_field(exp, got, lambda rrset: rrset.rdtype)


def match_qname(exp, got):
    """Compare the QNAME (``name``) of the question in *exp* and *got*."""
    return _match_question_field(exp, got, lambda rrset: rrset.name)


def match_qcase(exp, got):
    """Compare the raw QNAME labels of the question in *exp* and *got*.

    Unlike match_qname, this compares the ``labels`` tuples directly.
    dnspython ``Name`` equality ignores ASCII case, so this check is the one
    that detects letter-case differences.
    """
    return _match_question_field(exp, got, lambda rrset: rrset.name.labels)
148
149
def match_subdomain(exp, got):
    """Check that the received QNAME lies within the expected QNAME's domain.

    Matches when *exp* has no question.  A received message without a
    question is treated as asking for the root name.  The test is
    ``expected_name.is_superdomain(qname)``; in dnspython that also accepts
    equal names.

    Returns True on match.  Raises DataMismatch(expected_name, qname)
    otherwise.
    """
    if not exp.question:
        return True
    expected_name = exp.question[0].name
    qname = got.question[0].name if got.question else dns.name.root
    if expected_name.is_superdomain(qname):
        return True
    # Report the two names rather than the whole messages.  The message text
    # would be unreadable, and Message objects are not guaranteed to be
    # hashable, which DataMismatch.__hash__ needs.
    raise DataMismatch(expected_name, qname)
160
161
def match_flags(exp, got):
    """Compare the header flags of *exp* and *got* in their textual form."""
    expected_text = dns.flags.to_text(exp.flags)
    received_text = dns.flags.to_text(got.flags)
    return compare_val(expected_text, received_text)
165
166
def match_rcode(exp, got):
    """Compare the response codes of *exp* and *got* in their textual form."""
    expected_text = dns.rcode.to_text(exp.rcode())
    received_text = dns.rcode.to_text(got.rcode())
    return compare_val(expected_text, received_text)
170
171
def match_answer(exp, got):
    """Require the same RRsets in the ANSWER sections, in any order."""
    return compare_rrs(exp.answer, got.answer)


def match_answertypes(exp, got):
    """Compare only the RR types present in ANSWER, ignoring RRSIG RRsets."""
    return compare_rrs_types(exp.answer, got.answer, skip_rrsigs=True)


def match_answerrrsigs(exp, got):
    """Compare the RR types in ANSWER, including the types covered by RRSIGs."""
    return compare_rrs_types(exp.answer, got.answer, skip_rrsigs=False)


def match_authority(exp, got):
    """Require the same RRsets in the AUTHORITY sections, in any order."""
    return compare_rrs(exp.authority, got.authority)


def match_additional(exp, got):
    """Require the same RRsets in the ADDITIONAL sections, in any order."""
    return compare_rrs(exp.additional, got.additional)
195
196
def match_edns(exp, got):
    """Compare the EDNS version and payload size of *exp* and *got*.

    The ``edns`` attributes are checked first, then ``payload``.

    Returns True when both agree (previously None, unlike every other
    matcher).  Raises DataMismatch on the first differing attribute.
    """
    if got.edns != exp.edns:
        raise DataMismatch(exp.edns,
                           got.edns)
    if got.payload != exp.payload:
        raise DataMismatch(exp.payload,
                           got.payload)
    return True
204
205
def match_nsid(exp, got):
    """Compare the EDNS NSID option of the expected and received messages.

    Only the first option with ``otype == dns.edns.NSID`` in each message's
    ``options`` is considered.  Returns True when neither side has one, or
    when the two options compare equal.  Otherwise raises DataMismatch with
    the options' ``data``, using None for the side without an NSID option.
    """
    def first_nsid(options):
        # First NSID-typed option, or None when the list has none.
        return next((opt for opt in options if opt.otype == dns.edns.NSID),
                    None)

    wanted = first_nsid(exp.options)
    found = first_nsid(got.options)
    if found is None:
        if wanted:
            raise DataMismatch(wanted.data, None)
        return True
    if not wanted:
        raise DataMismatch(None, found.data)
    if found == wanted:
        return True
    raise DataMismatch(wanted.data, found.data)
224
225
# Maps each match criterion name accepted by match_part to its matcher.
MATCH = {
    "opcode": match_opcode,
    "qtype": match_qtype,
    "qname": match_qname,
    "qcase": match_qcase,
    "subdomain": match_subdomain,
    "flags": match_flags,
    "rcode": match_rcode,
    "answer": match_answer,
    "answertypes": match_answertypes,
    "answerrrsigs": match_answerrrsigs,
    "authority": match_authority,
    "additional": match_additional,
    "edns": match_edns,
    "nsid": match_nsid,
}
232
233
def match_part(exp, got, code):
    """Run the matcher registered under *code* in MATCH on *exp* and *got*.

    Returns whatever the matcher returns.  Matchers report a mismatch by
    raising DataMismatch, which propagates unchanged.

    Raises NotImplementedError if *code* is not a key of MATCH.
    """
    try:
        matcher = MATCH[code]
    except KeyError as ex:
        raise NotImplementedError('unknown match request "%s"' % code) from ex
    # Call outside the try: a KeyError raised inside a matcher is a genuine
    # error and must not be misreported as an unknown match request.
    return matcher(exp, got)
239