1""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine 2 Copyright 2003 Paul Scott-Murphy, 2014 William McBrine 3 4 This module provides a framework for the use of DNS Service Discovery 5 using IP multicast. 6 7 This library is free software; you can redistribute it and/or 8 modify it under the terms of the GNU Lesser General Public 9 License as published by the Free Software Foundation; either 10 version 2.1 of the License, or (at your option) any later version. 11 12 This library is distributed in the hope that it will be useful, 13 but WITHOUT ANY WARRANTY; without even the implied warranty of 14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 15 Lesser General Public License for more details. 16 17 You should have received a copy of the GNU Lesser General Public 18 License along with this library; if not, write to the Free Software 19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 20 USA 21""" 22 23import itertools 24import random 25from collections import deque 26from typing import Dict, Iterable, List, NamedTuple, Optional, Set, TYPE_CHECKING, Tuple, Union, cast 27 28from ._cache import DNSCache, _UniqueRecordsType 29from ._dns import DNSAddress, DNSNsec, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord 30from ._history import QuestionHistory 31from ._logger import log 32from ._protocol.incoming import DNSIncoming 33from ._protocol.outgoing import DNSOutgoing 34from ._services.info import ServiceInfo 35from ._services.registry import ServiceRegistry 36from ._updates import RecordUpdate, RecordUpdateListener 37from ._utils.time import current_time_millis, millis_to_seconds 38from .const import ( 39 _CLASS_IN, 40 _CLASS_UNIQUE, 41 _DNS_OTHER_TTL, 42 _DNS_PTR_MIN_TTL, 43 _FLAGS_AA, 44 _FLAGS_QR_RESPONSE, 45 _ONE_SECOND, 46 _SERVICE_TYPE_ENUMERATION_NAME, 47 _TYPE_A, 48 _TYPE_AAAA, 49 _TYPE_ANY, 50 _TYPE_NSEC, 51 _TYPE_PTR, 52 _TYPE_SRV, 53 _TYPE_TXT, 54) 55 56if TYPE_CHECKING: 57 from ._core import Zeroconf 58 59 
# Maps each answer record to the set of records that should be sent
# alongside it in the additional section of the response.
_AnswerWithAdditionalsType = Dict[DNSRecord, Set[DNSRecord]]

# Inclusive (min, max) bounds, in milliseconds, of the random delay applied
# before a multicast answer group is sent (passed to random.randint in
# MulticastOutgoingQueue.async_add). RFC 6762 section 6 describes a random
# 20-120 ms delay for responses that may collide with other responders.
_MULTICAST_DELAY_RANDOM_INTERVAL = (20, 120)
# Address record types a host can publish; any type missing from a host is
# denied with an NSEC record.
_ADDRESS_RECORD_TYPES = {_TYPE_A, _TYPE_AAAA}
# When the (first) incoming message holds exactly one question of one of
# these types, multicast answers are sent immediately instead of being
# aggregated (see _QueryResponse.add_mcast_question_response).
_RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES}


class QuestionAnswers(NamedTuple):
    """Answers to a batch of queries, partitioned by how they should be sent.

    Each field maps an answer record to its additional records. Instances
    are produced by _QueryResponse.answers().
    """

    # Answers intended for a unicast reply.
    ucast: _AnswerWithAdditionalsType
    # Answers to multicast without an aggregation delay.
    mcast_now: _AnswerWithAdditionalsType
    # Answers whose multicast may be delayed and aggregated.
    mcast_aggregate: _AnswerWithAdditionalsType
    # Answers whose cache entry was created less than one second ago
    # (RFC 6762 section 14 flood protection).
    mcast_aggregate_last_second: _AnswerWithAdditionalsType


class AnswerGroup(NamedTuple):
    """A group of answers scheduled to be sent at the same time.

    Times are in milliseconds and are compared against current_time_millis()
    by MulticastOutgoingQueue.async_ready.
    """

    send_after: float  # Must be sent after this time
    send_before: float  # Must be sent before this time
    answers: _AnswerWithAdditionalsType


def _message_is_probe(msg: DNSIncoming) -> bool:
    """Return True if msg carries authority records, i.e. looks like a probe.

    Probing queries place the records being claimed in the authority
    section (RFC 6762 section 8.2), so any authority record marks the
    message as a probe.
    """
    return msg.num_authorities > 0


def construct_nsec_record(name: str, types: List[int], now: float) -> DNSNsec:
    """Construct an NSEC record for name and a list of dns types.

    The record is created in class IN with the _CLASS_UNIQUE bit set and a
    TTL of _DNS_OTHER_TTL. name is passed both as the owner name and
    (presumably) as the next-domain name, as RFC 6762 section 6.1 requires
    for mDNS negative responses. Within this module, types is always the set
    of address types a host does NOT have.

    This function should only be used for SRV/A/AAAA records
    which have a TTL of _DNS_OTHER_TTL
    """
    return DNSNsec(name, _TYPE_NSEC, _CLASS_IN | _CLASS_UNIQUE, _DNS_OTHER_TTL, name, types, created=now)


def construct_outgoing_multicast_answers(answers: _AnswerWithAdditionalsType) -> DNSOutgoing:
    """Build a multicast response packet holding answers and their additionals.

    The packet is flagged as an authoritative response (QR | AA).
    """
    out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=True)
    _add_answers_additionals(out, answers)
    return out


def construct_outgoing_unicast_answers(
    answers: _AnswerWithAdditionalsType, ucast_source: bool, questions: List[DNSQuestion], id_: int
) -> DNSOutgoing:
    """Build a unicast response packet holding answers and their additionals.

    answers maps each answer record to its additional records. When
    ucast_source is True (a legacy unicast querier), the original questions
    are echoed back in the response. id_ is the message id passed through to
    DNSOutgoing. The packet is flagged as an authoritative response (QR | AA).
    """
    out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False, id_=id_)
    # Adding the questions back when the source is legacy unicast behavior
    if ucast_source:
        for question in questions:
            out.add_question(question)
    _add_answers_additionals(out, answers)
    return out


def _add_answers_additionals(out: DNSOutgoing, answers: _AnswerWithAdditionalsType) -> None:
    """Write answers and their additionals into out.

    Every answer is added via add_answer_at_time(answer, 0). Each additional
    record is added at most once. It is skipped entirely when the same record
    is already being sent as an answer.
    """
    # Find additionals and suppress any additionals that are already in answers
    sending: Set[DNSRecord] = set(answers.keys())
    # Answers are sorted to group names together to increase the chance
    # that similar names will end up in the same packet and can reduce the
    # overall size of the outgoing response via name compression
    for answer, additionals in sorted(answers.items(), key=lambda kv: kv[0].name):
        # A time of 0 is passed; presumably this sends the TTL unadjusted --
        # confirm in DNSOutgoing.add_answer_at_time.
        out.add_answer_at_time(answer, 0)
        for additional in additionals:
            if additional not in sending:
                out.add_additional_answer(additional)
                sending.add(additional)


def sanitize_incoming_record(record: DNSRecord) -> None:
    """Protect zeroconf from records that can cause denial of service.

    We enforce a minimum TTL for PTR records to avoid
    ServiceBrowsers generating excessive refresh queries.
    Apple uses a 15s minimum TTL, however we do not have the same
    level of rate limiting and safeguards so we use 1/4 of the recommended value.

    The record is modified in place. Records with a TTL of 0 (goodbyes) and
    non-PTR records are left untouched.
    """
    if record.ttl and record.ttl < _DNS_PTR_MIN_TTL and isinstance(record, DNSPointer):
        log.debug(
            "Increasing effective ttl of %s to minimum of %s to protect against excessive refreshes.",
            record,
            _DNS_PTR_MIN_TTL,
        )
        # Keep the original creation time; only the TTL is raised.
        record.set_created_ttl(record.created, _DNS_PTR_MIN_TTL)


class _QueryResponse:
    """Collect and classify the answers to a batch of incoming queries.

    Answer records are sorted into four sets: unicast, multicast now,
    multicast aggregate, and multicast aggregate last second. answers()
    returns them together with each record's additionals.
    """

    def __init__(self, cache: DNSCache, msgs: List[DNSIncoming]) -> None:
        """Build a query response.

        msgs must be non-empty (msgs[0] is read). The first message supplies
        the reference time and the question list consulted by
        add_mcast_question_response. The whole batch is treated as a probe
        if any message in it is a probe.
        """
        self._is_probe = any(_message_is_probe(msg) for msg in msgs)
        self._msg = msgs[0]
        self._now = self._msg.now
        self._cache = cache
        # answer record -> additional records, shared by all four sets below
        self._additionals: _AnswerWithAdditionalsType = {}
        self._ucast: Set[DNSRecord] = set()
        self._mcast_now: Set[DNSRecord] = set()
        self._mcast_aggregate: Set[DNSRecord] = set()
        self._mcast_aggregate_last_second: Set[DNSRecord] = set()

    def add_qu_question_response(self, answers: _AnswerWithAdditionalsType) -> None:
        """Generate a response to a multicast QU query.

        For a probe, every answer goes into the unicast set. It is also
        multicast immediately unless it was multicast recently.

        Otherwise, an answer is multicast immediately if it was not multicast
        recently, and sent by unicast if it was (RFC 6762 section 5.4).
        """
        for record, additionals in answers.items():
            self._additionals[record] = additionals
            if self._is_probe:
                self._ucast.add(record)
            if not self._has_mcast_within_one_quarter_ttl(record):
                self._mcast_now.add(record)
            elif not self._is_probe:
                self._ucast.add(record)

    def add_ucast_question_response(self, answers: _AnswerWithAdditionalsType) -> None:
        """Generate a response to a unicast query; every answer goes into the unicast set."""
        self._additionals.update(answers)
        self._ucast.update(answers.keys())

    def add_mcast_question_response(self, answers: _AnswerWithAdditionalsType) -> None:
        """Generate a response to a multicast query.

        Each answer is placed into exactly one multicast set:

        - when the batch is a probe, answers are sent immediately;
        - answers whose cached entry was created within the last second go
          to the "last second" aggregate set (flood protection);
        - if the first message holds a single question whose type is in
          _RESPOND_IMMEDIATE_TYPES, answers are sent immediately;
        - everything else is aggregated.
        """
        self._additionals.update(answers)
        for answer in answers:
            if self._is_probe:
                self._mcast_now.add(answer)
                continue

            if self._has_mcast_record_in_last_second(answer):
                self._mcast_aggregate_last_second.add(answer)
            elif len(self._msg.questions) == 1 and self._msg.questions[0].type in _RESPOND_IMMEDIATE_TYPES:
                self._mcast_now.add(answer)
            else:
                self._mcast_aggregate.add(answer)

    def _generate_answers_with_additionals(self, rrset: Set[DNSRecord]) -> _AnswerWithAdditionalsType:
        """Create answers with additionals from an rrset.

        Every record in rrset must have been registered by one of the add_*
        methods, which always record its additionals; otherwise KeyError.
        """
        return {record: self._additionals[record] for record in rrset}

    def answers(
        self,
    ) -> QuestionAnswers:
        """Return answer sets that will be queued, grouped by delivery method."""
        return QuestionAnswers(
            self._generate_answers_with_additionals(self._ucast),
            self._generate_answers_with_additionals(self._mcast_now),
            self._generate_answers_with_additionals(self._mcast_aggregate),
            self._generate_answers_with_additionals(self._mcast_aggregate_last_second),
        )

    def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool:
        """Check to see if a record has been mcasted recently.

        https://datatracker.ietf.org/doc/html/rfc6762#section-5.4
        When receiving a question with the unicast-response bit set, a
        responder SHOULD usually respond with a unicast packet directed back
        to the querier. However, if the responder has not multicast that
        record recently (within one quarter of its TTL), then the responder
        SHOULD instead multicast the response so as to keep all the peer
        caches up to date

        Returns True when the cache holds a matching entry and that entry's
        is_recent(now) is true. The quarter-TTL rule itself is assumed to be
        implemented by is_recent, which is not visible here -- confirm there.
        """
        # cast() only satisfies the type checker; the record is used as-is.
        maybe_entry = self._cache.async_get_unique(cast(_UniqueRecordsType, record))
        return bool(maybe_entry and maybe_entry.is_recent(self._now))

    def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool:
        """Check if an answer was seen in the last second.

        Protect the network against excessive packet flooding
        https://datatracker.ietf.org/doc/html/rfc6762#section-14

        Returns True when the cache holds a matching entry whose creation
        time is less than _ONE_SECOND before this response's reference time.
        """
        maybe_entry = self._cache.async_get_unique(cast(_UniqueRecordsType, record))
        return bool(maybe_entry and self._now - maybe_entry.created < _ONE_SECOND)


def _get_address_and_nsec_records(service: ServiceInfo, now: float) -> Set[DNSRecord]:
    """Build a set of address records and NSEC records for non-present record types.

    Returns all of the service's address records. If the service lacks an A
    or AAAA address, the set also contains one NSEC record for service.server
    listing the missing types.
    """
    seen_types: Set[int] = set()
    records: Set[DNSRecord] = set()
    for dns_address in service.dns_addresses(created=now):
        seen_types.add(dns_address.type)
        records.add(dns_address)
    missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
    if missing_types:
        records.add(construct_nsec_record(service.server, list(missing_types), now))
    return records


class QueryHandler:
    """Answer incoming questions from the services in the ServiceRegistry."""

    def __init__(self, registry: ServiceRegistry, cache: DNSCache, question_history: QuestionHistory) -> None:
        """Init the query handler.

        registry supplies the records to answer with. cache is consulted by
        _QueryResponse to decide how answers are sent. question_history
        records the multicast questions seen.
        """
        self.registry = registry
        self.cache = cache
        self.question_history = question_history

    def _add_service_type_enumeration_query_answers(
        self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
    ) -> None:
        """Provide an answer to a service type enumeration query.

        Adds one PTR record (from the enumeration name to the type) per type
        returned by registry.async_get_types(), with no additionals, unless
        known_answers suppresses it.

        https://datatracker.ietf.org/doc/html/rfc6763#section-9
        """
        for stype in self.registry.async_get_types():
            dns_pointer = DNSPointer(
                _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, now
            )
            if not known_answers.suppresses(dns_pointer):
                answer_set[dns_pointer] = set()

    def _add_pointer_answers(
        self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
    ) -> None:
        """Answer PTR/ANY question.

        Consider each service returned by registry.async_get_infos_type(name);
        name is presumably a service type. If a service's PTR record is not
        suppressed by known_answers, add it with that service's SRV, TXT,
        address and NSEC records as additionals.
        """
        for service in self.registry.async_get_infos_type(name):
            # Add recommended additional answers according to
            # https://tools.ietf.org/html/rfc6763#section-12.1.
            dns_pointer = service.dns_pointer(created=now)
            if known_answers.suppresses(dns_pointer):
                continue
            additionals: Set[DNSRecord] = {service.dns_service(created=now), service.dns_text(created=now)}
            additionals |= _get_address_and_nsec_records(service, now)
            answer_set[dns_pointer] = additionals

    def _add_address_answers(
        self,
        name: str,
        answer_set: _AnswerWithAdditionalsType,
        known_answers: DNSRRSet,
        now: float,
        type_: int,
    ) -> None:
        """Answer A/AAAA/ANY question.

        For each service returned by registry.async_get_infos_server(name):

        - addresses of exactly type_ that known_answers does not suppress
          become answers;
        - addresses of the other type become additionals, together with an
          NSEC record for any address type the service lacks;
        - if no address answer was produced and the service has no address of
          type_ at all, an NSEC record for the missing types is added as the
          answer instead.

        NOTE(review): for an ANY question no address record has type
        _TYPE_ANY, so this method adds no address answers (and no NSEC
        answer) in that case.
        """
        for service in self.registry.async_get_infos_server(name):
            answers: List[DNSAddress] = []
            additionals: Set[DNSRecord] = set()
            seen_types: Set[int] = set()
            for dns_address in service.dns_addresses(created=now):
                seen_types.add(dns_address.type)
                if dns_address.type != type_:
                    additionals.add(dns_address)
                elif not known_answers.suppresses(dns_address):
                    answers.append(dns_address)
            missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
            if answers:
                if missing_types:
                    additionals.add(construct_nsec_record(service.server, list(missing_types), now))
                # All answers for this service share the same additionals set.
                for answer in answers:
                    answer_set[answer] = additionals
            elif type_ in missing_types:
                answer_set[construct_nsec_record(service.server, list(missing_types), now)] = set()

    def _answer_question(
        self,
        question: DNSQuestion,
        known_answers: DNSRRSet,
        now: float,
    ) -> _AnswerWithAdditionalsType:
        """Build the answer set (answer -> additionals) for a single question.

        A PTR question for the service type enumeration name (compared in
        lower case) is answered only with the registered service types.
        Otherwise PTR, address, SRV and TXT answers are added according to
        question.type; an ANY question consults all of them. Records
        suppressed by known_answers are omitted.
        """
        answer_set: _AnswerWithAdditionalsType = {}

        if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME:
            self._add_service_type_enumeration_query_answers(answer_set, known_answers, now)
            return answer_set

        type_ = question.type

        if type_ in (_TYPE_PTR, _TYPE_ANY):
            self._add_pointer_answers(question.name, answer_set, known_answers, now)

        if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY):
            self._add_address_answers(question.name, answer_set, known_answers, now, type_)

        if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY):
            service = self.registry.async_get_info_name(question.name)  # type: ignore
            if service is not None:
                if type_ in (_TYPE_SRV, _TYPE_ANY):
                    # Add recommended additional answers according to
                    # https://tools.ietf.org/html/rfc6763#section-12.2.
                    dns_service = service.dns_service(created=now)
                    if not known_answers.suppresses(dns_service):
                        answer_set[dns_service] = _get_address_and_nsec_records(service, now)
                if type_ in (_TYPE_TXT, _TYPE_ANY):
                    dns_text = service.dns_text(created=now)
                    if not known_answers.suppresses(dns_text):
                        answer_set[dns_text] = set()

        return answer_set

    def async_response(  # pylint: disable=unused-argument
        self, msgs: List[DNSIncoming], ucast_source: bool
    ) -> QuestionAnswers:
        """Deal with incoming query packets. Provides a response if possible.

        Known answers are pooled from every non-probe message in msgs.
        Multicast (non-QU) questions are recorded in question_history along
        with those known answers. Each question is then routed as follows:

        - a QU question from a non-unicast source goes through
          add_qu_question_response only;
        - a question from a unicast source gets both a unicast and a
          multicast response;
        - any other question gets a multicast response.

        msgs must be non-empty.

        This function must be run in the event loop as it is not
        threadsafe.
        """
        known_answers = DNSRRSet(
            itertools.chain.from_iterable(msg.answers for msg in msgs if not _message_is_probe(msg))
        )
        query_res = _QueryResponse(self.cache, msgs)

        for msg in msgs:
            for question in msg.questions:
                if not question.unicast:
                    self.question_history.add_question_at_time(question, msg.now, set(known_answers.lookup))
                answer_set = self._answer_question(question, known_answers, msg.now)
                if not ucast_source and question.unicast:
                    query_res.add_qu_question_response(answer_set)
                    continue
                if ucast_source:
                    query_res.add_ucast_question_response(answer_set)
                # We always multicast as well, even if it's a unicast source.
                # add_mcast_question_response decides whether each answer is
                # sent now or aggregated. On this path the only recency check
                # is the one-second flood check; no quarter-TTL check is made.
                query_res.add_mcast_question_response(answer_set)

        return query_res.answers()


class RecordManager:
    """Process records into the cache and notify listeners."""

    def __init__(self, zeroconf: 'Zeroconf') -> None:
        """Init the record manager.

        zeroconf supplies the cache, and is passed to listeners and notified
        after each batch of updates.
        """
        self.zc = zeroconf
        self.cache = zeroconf.cache
        self.listeners: List[RecordUpdateListener] = []

    def async_updates(self, now: float, records: List[RecordUpdate]) -> None:
        """Used to notify listeners of new information that has updated
        a record.

        Calls async_update_records(zc, now, records) on every listener.

        This method must be called before the cache is updated.

        This method will be run in the event loop.
        """
        for listener in self.listeners:
            listener.async_update_records(self.zc, now, records)

    def async_updates_complete(self) -> None:
        """Used to notify listeners of new information that has updated
        a record.

        Calls async_update_records_complete() on every listener, then
        zc.async_notify_all().

        This method must be called after the cache is updated.

        This method will be run in the event loop.
        """
        for listener in self.listeners:
            listener.async_update_records_complete()
        self.zc.async_notify_all()

    def async_updates_from_response(self, msg: DNSIncoming) -> None:
        """Deal with incoming response packets. All answers
        are held in the cache, and listeners are notified.

        For each answer record:

        - a live record already in the cache has its TTL reset;
        - a live record not in the cache is queued for addition;
        - an expired record that is in the cache (likely a goodbye) is
          queued for removal.

        Listeners are notified of new, refreshed and goodbye records before
        the cache changes are applied. Unique (cache-flush) records also
        cause older cached records with the same name/type/class to be set
        to expire (RFC 6762 section 10.2).

        This function must be run in the event loop as it is not
        threadsafe.
        """
        updates: List[RecordUpdate] = []
        address_adds: List[DNSAddress] = []
        other_adds: List[DNSRecord] = []
        removes: Set[DNSRecord] = set()
        now = msg.now
        unique_types: Set[Tuple[str, int, int]] = set()

        for record in msg.answers:
            sanitize_incoming_record(record)

            if record.unique:  # https://tools.ietf.org/html/rfc6762#section-10.2
                unique_types.add((record.name, record.type, record.class_))

            maybe_entry = self.cache.async_get_unique(cast(_UniqueRecordsType, record))
            if not record.is_expired(now):
                if maybe_entry is not None:
                    # Already cached: refresh the existing entry in place.
                    maybe_entry.reset_ttl(record)
                else:
                    if isinstance(record, DNSAddress):
                        address_adds.append(record)
                    else:
                        other_adds.append(record)
                updates.append(RecordUpdate(record, maybe_entry))
            # This is likely a goodbye since the record is
            # expired and exists in the cache
            elif maybe_entry is not None:
                updates.append(RecordUpdate(record, maybe_entry))
                removes.add(record)

        if unique_types:
            self._async_mark_unique_cached_records_older_than_1s_to_expire(unique_types, msg.answers, now)

        if updates:
            self.async_updates(now, updates)
        # The cache adds must be processed AFTER we trigger
        # the updates since we compare existing data
        # with the new data and updating the cache
        # ahead of async_updates will cause listeners
        # to miss changes
        #
        # We must process address adds before non-addresses
        # otherwise a fetch of ServiceInfo may miss an address
        # because it thinks the cache is complete
        #
        # async_updates_complete is only called after the cache
        # has been updated, to ensure that any ServiceBrowser that
        # is going to call zc.get_service_info will see the cached
        # value, but ONLY after all the record updates have been
        # processed.
        if other_adds or address_adds:
            self.cache.async_add_records(itertools.chain(address_adds, other_adds))
        # Removes are processed last since
        # ServiceInfo could generate an un-needed query
        # because the data was not yet populated.
        if removes:
            self.cache.async_remove_records(removes)
        if updates:
            self.async_updates_complete()

    def _async_mark_unique_cached_records_older_than_1s_to_expire(
        self, unique_types: Set[Tuple[str, int, int]], answers: Iterable[DNSRecord], now: float
    ) -> None:
        """Set stale cached records to expire in one second.

        For each (name, type, class) in unique_types, reset every cached
        entry created more than _ONE_SECOND before now, and not present in
        answers, to a TTL of 1 starting at now.
        """
        # rfc6762#section-10.2 para 2
        # Since unique is set, all old records with that name, rrtype,
        # and rrclass that were received more than one second ago are declared
        # invalid, and marked to expire from the cache in one second.
        answers_rrset = DNSRRSet(answers)
        for name, type_, class_ in unique_types:
            for entry in self.cache.async_all_by_details(name, type_, class_):
                if (now - entry.created > _ONE_SECOND) and entry not in answers_rrset:
                    # Expire in 1s
                    entry.set_created_ttl(now, 1)

    def async_add_listener(
        self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]]
    ) -> None:
        """Adds a listener for a given question. The listener will have
        its async_update_records method called when information is available to
        answer the question(s).

        If question is given, matching unexpired records already in the
        cache are delivered to the listener immediately. A listener that does
        not inherit from RecordUpdateListener is still added; only an error
        is logged.

        This function is not threadsafe and must be called in the eventloop.
        """
        if not isinstance(listener, RecordUpdateListener):
            log.error(
                "listeners passed to async_add_listener must inherit from RecordUpdateListener;"
                " In the future this will fail"
            )

        self.listeners.append(listener)

        if question is None:
            return

        questions = [question] if isinstance(question, DNSQuestion) else question
        assert self.zc.loop is not None
        self._async_update_matching_records(listener, questions)

    def _async_update_matching_records(
        self, listener: RecordUpdateListener, questions: List[DNSQuestion]
    ) -> None:
        """Calls back any existing entries in the cache that answer the question.

        Each unexpired cached record with a question's name that the question
        is answered_by is delivered as RecordUpdate(record, None). The
        listener's completion hook and zc.async_notify_all() are only called
        when at least one record matched.

        This function must be run from the event loop.
        """
        now = current_time_millis()
        records: List[RecordUpdate] = []
        for question in questions:
            for record in self.cache.async_entries_with_name(question.name):
                if not record.is_expired(now) and question.answered_by(record):
                    records.append(RecordUpdate(record, None))

        if not records:
            return
        listener.async_update_records(self.zc, now, records)
        listener.async_update_records_complete()
        self.zc.async_notify_all()

    def async_remove_listener(self, listener: RecordUpdateListener) -> None:
        """Removes a listener.

        Removing a listener that was never added logs the ValueError instead
        of raising it.

        This function is not threadsafe and must be called in the eventloop.
        """
        try:
            self.listeners.remove(listener)
            self.zc.async_notify_all()
        except ValueError as e:
            log.exception('Failed to remove listener: %r', e)


class MulticastOutgoingQueue:
    """An outgoing queue used to aggregate multicast responses.

    Answers are held in AnswerGroups ordered by send time. async_ready merges
    every group that is due into a single multicast packet.
    """

    def __init__(self, zeroconf: 'Zeroconf', additional_delay: int, max_aggregation_delay: int) -> None:
        """Create the queue.

        additional_delay is added to every group's send window.
        max_aggregation_delay bounds how long a group may be held for
        aggregation. Both are in milliseconds: they are added to times
        compared against current_time_millis().
        """
        self.zc = zeroconf
        self.queue: deque = deque()
        # Additional delay is used to implement
        # Protect the network against excessive packet flooding
        # https://datatracker.ietf.org/doc/html/rfc6762#section-14
        self.additional_delay = additional_delay
        self.aggregation_delay = max_aggregation_delay

    def async_add(self, now: float, answers: _AnswerWithAdditionalsType) -> None:
        """Add a group of answers with additionals to the outgoing queue.

        now is the current time in milliseconds. The group's send_after is
        now plus a random delay drawn from _MULTICAST_DELAY_RANDOM_INTERVAL,
        plus additional_delay. Its send_before is now plus the aggregation
        delay plus additional_delay.

        If send_after is not later than the last queued group's send_after,
        the answers are merged into that group and keep its send times.
        async_ready is only scheduled here when the queue was empty; while
        groups remain queued, async_ready reschedules itself.
        """
        assert self.zc.loop is not None
        random_delay = random.randint(*_MULTICAST_DELAY_RANDOM_INTERVAL) + self.additional_delay
        send_after = now + random_delay
        send_before = now + self.aggregation_delay + self.additional_delay
        if len(self.queue):
            # If we calculate a random delay for the send after time
            # that is less than the last group scheduled to go out,
            # we instead add the answers to the last group as this
            # allows aggregating additional responses
            last_group = self.queue[-1]
            if send_after <= last_group.send_after:
                last_group.answers.update(answers)
                return
        else:
            self.zc.loop.call_later(millis_to_seconds(random_delay), self.async_ready)
        self.queue.append(AnswerGroup(send_after, send_before, answers))

    def _remove_answers_from_queue(self, answers: _AnswerWithAdditionalsType) -> None:
        """Remove a set of answers from the outgoing queue.

        The given records are dropped from every group still queued, so the
        same record is not sent twice. Groups emptied this way stay in the
        queue. Once dequeued they contribute nothing, and async_ready sends
        nothing for an empty merged answer set.
        """
        for pending in self.queue:
            for record in answers:
                pending.answers.pop(record, None)

    def async_ready(self) -> None:
        """Process anything in the queue that is ready.

        If more than one group is queued and the first group's send_before
        deadline has not passed, sending is deferred until that deadline so
        more answers can be aggregated. Otherwise:

        - every group whose send_after time has passed is dequeued and merged;
        - another call is scheduled for the next remaining group, if any;
        - the merged records are removed from groups still queued and sent
          as one multicast packet.
        """
        assert self.zc.loop is not None
        now = current_time_millis()

        if len(self.queue) > 1 and self.queue[0].send_before > now:
            # There is more than one answer in the queue,
            # delay until we have to send it (first answer group reaches send_before)
            self.zc.loop.call_later(millis_to_seconds(self.queue[0].send_before - now), self.async_ready)
            return

        answers: _AnswerWithAdditionalsType = {}
        # Add all groups that can be sent now
        while len(self.queue) and self.queue[0].send_after <= now:
            answers.update(self.queue.popleft().answers)

        if len(self.queue):
            # If there are still groups in the queue that are not ready to send
            # be sure we schedule them to go out later
            self.zc.loop.call_later(millis_to_seconds(self.queue[0].send_after - now), self.async_ready)

        if answers:
            # If we have the same answer scheduled to go out, remove them
            self._remove_answers_from_queue(answers)
            self.zc.async_send(construct_outgoing_multicast_answers(answers))