"""matchpart is used to compare two DNS messages using a single criterion"""

from typing import (  # noqa
    Any, Hashable, Sequence, Tuple, Union)

import dns.edns
import dns.flags
import dns.name
import dns.rcode
import dns.rdatatype
import dns.set

MismatchValue = Union[str, Sequence[Any]]


class DataMismatch(Exception):
    """Signal that an expected value differs from the value actually received.

    Two instances are equal (and hash equally) when both their expected and
    received values are equal, so identical mismatches can be collected in a
    set or used as dictionary keys.
    """

    def __init__(self, exp_val, got_val):
        # Forward the values to Exception so ``self.args`` is populated.
        # Exception pickling re-creates the instance as ``cls(*self.args)``;
        # with empty args that call fails because both parameters are required.
        super().__init__(exp_val, got_val)
        self.exp_val = exp_val
        self.got_val = got_val

    @staticmethod
    def format_value(value: MismatchValue) -> str:
        """Render *value* for messages; lists and tuples become space-separated."""
        # Tuples are flattened too because compare_rrs_types() reports RR type
        # mismatches as tuples of type names.
        if isinstance(value, (list, tuple)):
            return ' '.join(str(val) for val in value)
        return str(value)

    def __str__(self) -> str:
        return 'expected "{}" got "{}"'.format(
            self.format_value(self.exp_val),
            self.format_value(self.got_val))

    def __eq__(self, other):
        return (isinstance(other, DataMismatch)
                and self.exp_val == other.exp_val
                and self.got_val == other.got_val)

    def __ne__(self, other):
        return not self.__eq__(other)

    @property
    def key(self) -> Tuple[Hashable, Hashable]:
        """Hashable form of ``(exp_val, got_val)``.

        Lists and ``dns.set.Set`` instances are converted (recursively) into
        tuples; any other value is used as-is.
        """
        def make_hashable(value):
            if isinstance(value, (list, dns.set.Set)):
                return tuple(make_hashable(item) for item in value)
            return value

        return (make_hashable(self.exp_val), make_hashable(self.got_val))

    def __hash__(self) -> int:
        return hash(self.key)


def compare_val(exp, got):
    """Compare arbitrary objects, raise DataMismatch if they differ.

    Returns True when ``exp == got``.
    """
    if exp != got:
        raise DataMismatch(exp, got)
    return True


def compare_rrs(expected, got):
    """Compare lists of RR sets, raise DataMismatch if they differ.

    Both lists must have the same length and every RR set of one list must be
    present in the other (order is ignored). The raised exception carries the
    complete lists. Returns True on match.
    """
    # Cheap length check first; every failure raises the same exception.
    if len(expected) != len(got):
        raise DataMismatch(expected, got)
    for rr in expected:
        if rr not in got:
            raise DataMismatch(expected, got)
    for rr in got:
        if rr not in expected:
            raise DataMismatch(expected, got)
    return True


def compare_rrs_types(exp_val, got_val, skip_rrsigs):
    """Require the sets of RR types in both sections to match.

    An RR set with a non-zero ``covers`` value (an RRSIG set) is identified by
    the type it covers and reported as ``RRSIG(<type>)``. When *skip_rrsigs*
    is true, RRSIG RR sets are ignored entirely.

    Raises DataMismatch with sorted tuples of type names when the sets
    differ; returns True otherwise.
    """
    def rr_ordering_key(rrset):
        # The second item makes RRSIG(X) sort right after X itself.
        if rrset.covers:
            return rrset.covers, 1
        return rrset.rdtype, 0

    def key_to_text(rrtype, rrsig):
        if not rrsig:
            return dns.rdatatype.to_text(rrtype)
        return 'RRSIG(%s)' % dns.rdatatype.to_text(rrtype)

    if skip_rrsigs:
        exp_val = (rrset for rrset in exp_val
                   if rrset.rdtype != dns.rdatatype.RRSIG)
        got_val = (rrset for rrset in got_val
                   if rrset.rdtype != dns.rdatatype.RRSIG)

    exp_types = frozenset(rr_ordering_key(rrset) for rrset in exp_val)
    got_types = frozenset(rr_ordering_key(rrset) for rrset in got_val)
    if exp_types != got_types:
        exp_types = tuple(key_to_text(*i) for i in sorted(exp_types))
        got_types = tuple(key_to_text(*i) for i in sorted(got_types))
        raise DataMismatch(exp_types, got_types)
    return True


def check_question(question):
    """Raise NotImplementedError if the QUESTION section has more than one record."""
    if len(question) > 1:
        raise NotImplementedError("More than one record in QUESTION SECTION.")


def _match_question_attr(exp, got, extract):
    """Compare ``extract(question[0])`` of two messages.

    An empty QUESTION section on one side only is reported as the
    ``"<empty question>"`` placeholder; two empty sections match.
    """
    check_question(exp.question)
    check_question(got.question)
    if not exp.question and not got.question:
        return True
    if not exp.question:
        raise DataMismatch("<empty question>", extract(got.question[0]))
    if not got.question:
        raise DataMismatch(extract(exp.question[0]), "<empty question>")
    return compare_val(extract(exp.question[0]),
                       extract(got.question[0]))


def match_opcode(exp, got):
    """Match message opcodes."""
    return compare_val(exp.opcode(),
                       got.opcode())


def match_qtype(exp, got):
    """Match the RR type of the (single) question."""
    return _match_question_attr(exp, got, lambda question: question.rdtype)


def match_qname(exp, got):
    """Match the name of the (single) question."""
    return _match_question_attr(exp, got, lambda question: question.name)


def match_qcase(exp, got):
    """Match the raw labels of the question name (case-sensitive)."""
    return _match_question_attr(exp, got,
                                lambda question: question.name.labels)


def match_subdomain(exp, got):
    """Require the received QNAME to be at or below the expected QNAME.

    Passes when *exp* has no question. A missing question in *got* is
    treated as the root name. The mismatch carries the whole messages.
    """
    if not exp.question:
        return True
    if got.question:
        qname = got.question[0].name
    else:
        qname = dns.name.root
    if exp.question[0].name.is_superdomain(qname):
        return True
    raise DataMismatch(exp, got)


def match_flags(exp, got):
    """Match message header flags (compared in their text form)."""
    return compare_val(dns.flags.to_text(exp.flags),
                       dns.flags.to_text(got.flags))


def match_rcode(exp, got):
    """Match response codes (compared in their text form)."""
    return compare_val(dns.rcode.to_text(exp.rcode()),
                       dns.rcode.to_text(got.rcode()))


def match_answer(exp, got):
    """Match the ANSWER section RR sets."""
    return compare_rrs(exp.answer,
                       got.answer)


def match_answertypes(exp, got):
    """Match the set of RR types in ANSWER, ignoring RRSIGs."""
    return compare_rrs_types(exp.answer,
                             got.answer, skip_rrsigs=True)


def match_answerrrsigs(exp, got):
    """Match the set of RR types in ANSWER, including covered RRSIG types."""
    return compare_rrs_types(exp.answer,
                             got.answer, skip_rrsigs=False)


def match_authority(exp, got):
    """Match the AUTHORITY section RR sets."""
    return compare_rrs(exp.authority,
                       got.authority)


def match_additional(exp, got):
    """Match the ADDITIONAL section RR sets."""
    return compare_rrs(exp.additional,
                       got.additional)


def match_edns(exp, got):
    """Match the EDNS ``edns`` attribute first, then the ``payload`` size."""
    if got.edns != exp.edns:
        raise DataMismatch(exp.edns,
                           got.edns)
    if got.payload != exp.payload:
        raise DataMismatch(exp.payload,
                           got.payload)
    return True


def match_nsid(exp, got):
    """Match the EDNS NSID option of two messages.

    Only the first NSID option of each message is considered. The mismatch
    carries the options' ``data`` values, with None standing for "absent".
    """
    nsid_opt = None
    for opt in exp.options:
        if opt.otype == dns.edns.NSID:
            nsid_opt = opt
            break
    # Find matching NSID; the first NSID option in *got* decides the result.
    for opt in got.options:
        if opt.otype == dns.edns.NSID:
            if not nsid_opt:
                raise DataMismatch(None, opt.data)
            if opt == nsid_opt:
                return True
            raise DataMismatch(nsid_opt.data, opt.data)
    if nsid_opt:
        raise DataMismatch(nsid_opt.data, None)
    return True


MATCH = {"opcode": match_opcode, "qtype": match_qtype, "qname": match_qname, "qcase": match_qcase,
         "subdomain": match_subdomain, "flags": match_flags, "rcode": match_rcode,
         "answer": match_answer, "answertypes": match_answertypes,
         "answerrrsigs": match_answerrrsigs, "authority": match_authority,
         "additional": match_additional, "edns": match_edns,
         "nsid": match_nsid}


def match_part(exp, got, code):
    """Compare messages *exp* and *got* using the criterion named *code*.

    Raises NotImplementedError for an unknown *code*; any DataMismatch (or
    other error) raised by the selected matcher propagates unchanged.
    """
    # Only the lookup is guarded: a KeyError raised inside a matcher must not
    # be misreported as an unknown match request.
    try:
        matcher = MATCH[code]
    except KeyError as ex:
        raise NotImplementedError('unknown match request "%s"' % (code,)) from ex
    return matcher(exp, got)