1""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
2    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
3
4    This module provides a framework for the use of DNS Service Discovery
5    using IP multicast.
6
7    This library is free software; you can redistribute it and/or
8    modify it under the terms of the GNU Lesser General Public
9    License as published by the Free Software Foundation; either
10    version 2.1 of the License, or (at your option) any later version.
11
12    This library is distributed in the hope that it will be useful,
13    but WITHOUT ANY WARRANTY; without even the implied warranty of
14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15    Lesser General Public License for more details.
16
17    You should have received a copy of the GNU Lesser General Public
18    License along with this library; if not, write to the Free Software
19    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
20    USA
21"""
22
23import itertools
24import random
25from collections import deque
26from typing import Dict, Iterable, List, NamedTuple, Optional, Set, TYPE_CHECKING, Tuple, Union, cast
27
28from ._cache import DNSCache, _UniqueRecordsType
29from ._dns import DNSAddress, DNSNsec, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord
30from ._history import QuestionHistory
31from ._logger import log
32from ._protocol.incoming import DNSIncoming
33from ._protocol.outgoing import DNSOutgoing
34from ._services.info import ServiceInfo
35from ._services.registry import ServiceRegistry
36from ._updates import RecordUpdate, RecordUpdateListener
37from ._utils.time import current_time_millis, millis_to_seconds
38from .const import (
39    _CLASS_IN,
40    _CLASS_UNIQUE,
41    _DNS_OTHER_TTL,
42    _DNS_PTR_MIN_TTL,
43    _FLAGS_AA,
44    _FLAGS_QR_RESPONSE,
45    _ONE_SECOND,
46    _SERVICE_TYPE_ENUMERATION_NAME,
47    _TYPE_A,
48    _TYPE_AAAA,
49    _TYPE_ANY,
50    _TYPE_NSEC,
51    _TYPE_PTR,
52    _TYPE_SRV,
53    _TYPE_TXT,
54)
55
56if TYPE_CHECKING:
57    from ._core import Zeroconf
58
59
# Maps each answer record to the set of additional records that accompany it.
_AnswerWithAdditionalsType = Dict[DNSRecord, Set[DNSRecord]]

# Inclusive (low, high) bounds handed to random.randint when choosing the
# delay before an aggregated multicast response goes out.  The result is
# added to a millisecond timestamp.
_MULTICAST_DELAY_RANDOM_INTERVAL = (20, 120)
_ADDRESS_RECORD_TYPES = {_TYPE_A, _TYPE_AAAA}
# A message with a single question of one of these types is answered
# without waiting for aggregation.
_RESPOND_IMMEDIATE_TYPES = _ADDRESS_RECORD_TYPES | {_TYPE_NSEC, _TYPE_SRV}
65
66
class QuestionAnswers(NamedTuple):
    """The answer buckets produced for a batch of incoming questions.

    Every field maps an answer record to its set of additional records.
    """

    # Answers bucketed for a unicast reply.
    ucast: _AnswerWithAdditionalsType
    # Answers bucketed for multicast without aggregation.
    mcast_now: _AnswerWithAdditionalsType
    # Answers bucketed for aggregated multicast.
    mcast_aggregate: _AnswerWithAdditionalsType
    # Answers bucketed for aggregated multicast because the cache already
    # holds a matching record created less than one second earlier.
    mcast_aggregate_last_second: _AnswerWithAdditionalsType
72
73
class AnswerGroup(NamedTuple):
    """A group of answers scheduled to be sent at the same time.

    Instances are queued by MulticastOutgoingQueue; the two timestamps are
    compared against current_time_millis() when the queue is processed.
    """

    send_after: float  # Must be sent after this time
    send_before: float  # Must be sent before this time
    # Answer records mapped to their additional records; later additions
    # that land in the same group are merged into this dict in place.
    answers: _AnswerWithAdditionalsType
80
81
def _message_is_probe(msg: DNSIncoming) -> bool:
    """Return True when the message carries at least one authority record."""
    authority_count = msg.num_authorities
    return authority_count > 0
84
85
def construct_nsec_record(name: str, types: List[int], now: float) -> DNSNsec:
    """Build an NSEC record for ``name`` covering the given dns types.

    Only use this for SRV/A/AAAA records, which have a TTL
    of _DNS_OTHER_TTL.

    The record is created with the unique class bit set, ``name`` as both
    the owner and the next name, and ``now`` as its creation time.
    """
    unique_in_class = _CLASS_IN | _CLASS_UNIQUE
    return DNSNsec(
        name,
        _TYPE_NSEC,
        unique_in_class,
        _DNS_OTHER_TTL,
        name,
        types,
        created=now,
    )
93
94
def construct_outgoing_multicast_answers(answers: _AnswerWithAdditionalsType) -> DNSOutgoing:
    """Return a multicast response packet holding the answers and additionals."""
    response_flags = _FLAGS_QR_RESPONSE | _FLAGS_AA
    packet = DNSOutgoing(response_flags, multicast=True)
    _add_answers_additionals(packet, answers)
    return packet
100
101
def construct_outgoing_unicast_answers(
    answers: _AnswerWithAdditionalsType, ucast_source: bool, questions: List[DNSQuestion], id_: int
) -> DNSOutgoing:
    """Return a unicast response packet holding the answers and additionals.

    The packet reuses the query id ``id_``.  When ``ucast_source`` is true
    the questions are echoed back ahead of the answers.
    """
    response_flags = _FLAGS_QR_RESPONSE | _FLAGS_AA
    packet = DNSOutgoing(response_flags, multicast=False, id_=id_)
    # Legacy unicast behavior: the questions are added back to the response
    echoed_questions = questions if ucast_source else ()
    for question in echoed_questions:
        packet.add_question(question)
    _add_answers_additionals(packet, answers)
    return packet
113
114
def _add_answers_additionals(out: DNSOutgoing, answers: _AnswerWithAdditionalsType) -> None:
    """Write every answer and its not-yet-queued additionals into ``out``."""
    # Everything already going out; an additional that is also an answer
    # (or was already added as an additional) is not added again
    queued: Set[DNSRecord] = set(answers)
    # Sorting by name groups similar names together, which raises the
    # chance they share a packet and lets name compression shrink the
    # overall size of the outgoing response
    by_name = sorted(answers.items(), key=lambda item: item[0].name)
    for record, extras in by_name:
        out.add_answer_at_time(record, 0)
        for extra in extras:
            if extra in queued:
                continue
            out.add_additional_answer(extra)
            queued.add(extra)
127
128
def sanitize_incoming_record(record: DNSRecord) -> None:
    """Protect zeroconf from records that can cause denial of service.

    A minimum TTL is enforced for PTR records so that ServiceBrowsers do
    not generate excessive refresh queries.  Apple uses a 15s minimum TTL,
    however we do not have the same level of rate limiting and safeguards,
    so we use 1/4 of the recommended value.

    Records with a zero TTL, records at or above the minimum, and
    records that are not DNSPointer instances are left untouched.
    """
    ttl = record.ttl
    if not ttl:
        return
    if not ttl < _DNS_PTR_MIN_TTL:
        return
    if not isinstance(record, DNSPointer):
        return
    log.debug(
        "Increasing effective ttl of %s to minimum of %s to protect against excessive refreshes.",
        record,
        _DNS_PTR_MIN_TTL,
    )
    # Keep the creation time, only raise the ttl
    record.set_created_ttl(record.created, _DNS_PTR_MIN_TTL)
144
145
class _QueryResponse:
    """Sorts answers to a query into unicast and multicast buckets."""

    def __init__(self, cache: DNSCache, msgs: List[DNSIncoming]) -> None:
        """Start with empty buckets for the given incoming messages.

        The first message supplies the reference time and the question list
        used when deciding whether a multicast answer may go out immediately.
        """
        self._is_probe = any(map(_message_is_probe, msgs))
        first_msg = msgs[0]
        self._msg = first_msg
        self._now = first_msg.now
        self._cache = cache
        # Additionals for every answer seen so far, whichever bucket it is in
        self._additionals: _AnswerWithAdditionalsType = {}
        self._ucast: Set[DNSRecord] = set()
        self._mcast_now: Set[DNSRecord] = set()
        self._mcast_aggregate: Set[DNSRecord] = set()
        self._mcast_aggregate_last_second: Set[DNSRecord] = set()

    def add_qu_question_response(self, answers: _AnswerWithAdditionalsType) -> None:
        """Bucket the answers to a multicast QU query."""
        for record, additionals in answers.items():
            self._additionals[record] = additionals
            recently_mcast = self._has_mcast_within_one_quarter_ttl(record)
            if not recently_mcast:
                self._mcast_now.add(record)
            # Probes always get a unicast copy; other queries only when
            # the record does not need to be multicast
            if self._is_probe or recently_mcast:
                self._ucast.add(record)

    def add_ucast_question_response(self, answers: _AnswerWithAdditionalsType) -> None:
        """Bucket the answers to a unicast query."""
        self._additionals.update(answers)
        self._ucast.update(answers)

    def add_mcast_question_response(self, answers: _AnswerWithAdditionalsType) -> None:
        """Bucket the answers to a multicast query."""
        self._additionals.update(answers)
        if self._is_probe:
            # Answers to probes are never aggregated
            self._mcast_now.update(answers)
            return

        for answer in answers:
            if self._has_mcast_record_in_last_second(answer):
                bucket = self._mcast_aggregate_last_second
            else:
                questions = self._msg.questions
                lone_immediate_question = (
                    len(questions) == 1 and questions[0].type in _RESPOND_IMMEDIATE_TYPES
                )
                bucket = self._mcast_now if lone_immediate_question else self._mcast_aggregate
            bucket.add(answer)

    def _generate_answers_with_additionals(self, rrset: Set[DNSRecord]) -> _AnswerWithAdditionalsType:
        """Pair each record in the rrset with its stored additionals."""
        paired: _AnswerWithAdditionalsType = {}
        for record in rrset:
            paired[record] = self._additionals[record]
        return paired

    def answers(
        self,
    ) -> QuestionAnswers:
        """Return the answer sets that will be queued."""
        build = self._generate_answers_with_additionals
        return QuestionAnswers(
            ucast=build(self._ucast),
            mcast_now=build(self._mcast_now),
            mcast_aggregate=build(self._mcast_aggregate),
            mcast_aggregate_last_second=build(self._mcast_aggregate_last_second),
        )

    def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool:
        """Check whether the cache holds a recent copy of the record.

        https://datatracker.ietf.org/doc/html/rfc6762#section-5.4
        When receiving a question with the unicast-response bit set, a
        responder SHOULD usually respond with a unicast packet directed back
        to the querier.  However, if the responder has not multicast that
        record recently (within one quarter of its TTL), then the responder
        SHOULD instead multicast the response so as to keep all the peer
        caches up to date

        Recency is decided by the cached entry's own is_recent() check.
        """
        cached = self._cache.async_get_unique(cast(_UniqueRecordsType, record))
        if not cached:
            return False
        return bool(cached.is_recent(self._now))

    def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool:
        """Check whether the cached copy of the record is under a second old.

        Protect the network against excessive packet flooding
        https://datatracker.ietf.org/doc/html/rfc6762#section-14
        """
        cached = self._cache.async_get_unique(cast(_UniqueRecordsType, record))
        if not cached:
            return False
        age = self._now - cached.created
        return bool(age < _ONE_SECOND)
228
229
def _get_address_and_nsec_records(service: ServiceInfo, now: float) -> Set[DNSRecord]:
    """Return the service's address records plus an NSEC for absent address types.

    The NSEC record is only included when at least one of the A/AAAA
    types has no address record for the service.
    """
    address_records = list(service.dns_addresses(created=now))
    result: Set[DNSRecord] = set(address_records)
    present_types: Set[int] = {address.type for address in address_records}
    absent_types: Set[int] = _ADDRESS_RECORD_TYPES - present_types
    if absent_types:
        nsec = construct_nsec_record(service.server, list(absent_types), now)
        result.add(nsec)
    return result
241
242
class QueryHandler:
    """Answer incoming questions from the ServiceRegistry."""

    def __init__(self, registry: ServiceRegistry, cache: DNSCache, question_history: QuestionHistory) -> None:
        """Store the registry, cache and question history used to answer queries."""
        self.registry = registry
        self.cache = cache
        self.question_history = question_history

    def _add_service_type_enumeration_query_answers(
        self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
    ) -> None:
        """Add one PTR answer per registered service type.

        https://datatracker.ietf.org/doc/html/rfc6763#section-9

        Pointers suppressed by the known answers are left out; the
        answers added here carry no additionals.
        """
        for stype in self.registry.async_get_types():
            type_pointer = DNSPointer(
                _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, now
            )
            if known_answers.suppresses(type_pointer):
                continue
            answer_set[type_pointer] = set()

    def _add_pointer_answers(
        self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
    ) -> None:
        """Add PTR answers for every registered service of type ``name``."""
        for service in self.registry.async_get_infos_type(name):
            pointer = service.dns_pointer(created=now)
            if known_answers.suppresses(pointer):
                continue
            # Recommended additional answers according to
            # https://tools.ietf.org/html/rfc6763#section-12.1.
            srv_record = service.dns_service(created=now)
            txt_record = service.dns_text(created=now)
            extras: Set[DNSRecord] = {srv_record, txt_record}.union(
                _get_address_and_nsec_records(service, now)
            )
            answer_set[pointer] = extras

    def _add_address_answers(
        self,
        name: str,
        answer_set: _AnswerWithAdditionalsType,
        known_answers: DNSRRSet,
        now: float,
        type_: int,
    ) -> None:
        """Add A/AAAA answers for every registered service on server ``name``.

        Addresses of the requested type become answers unless the known
        answers suppress them; addresses of the other type ride along as
        additionals.  When nothing is answered and the requested type is
        absent for the service, an NSEC record is added as the answer.
        """
        for service in self.registry.async_get_infos_server(name):
            matching: List[DNSAddress] = []
            extras: Set[DNSRecord] = set()
            present_types: Set[int] = set()
            for address in service.dns_addresses(created=now):
                address_type = address.type
                present_types.add(address_type)
                if address_type == type_:
                    if not known_answers.suppresses(address):
                        matching.append(address)
                else:
                    extras.add(address)
            absent_types: Set[int] = _ADDRESS_RECORD_TYPES - present_types
            if not matching:
                if type_ in absent_types:
                    nsec = construct_nsec_record(service.server, list(absent_types), now)
                    answer_set[nsec] = set()
                continue
            if absent_types:
                extras.add(construct_nsec_record(service.server, list(absent_types), now))
            # Every matching address shares the same additionals set
            for address in matching:
                answer_set[address] = extras

    def _answer_question(
        self,
        question: DNSQuestion,
        known_answers: DNSRRSet,
        now: float,
    ) -> _AnswerWithAdditionalsType:
        """Build the answers, with their additionals, for a single question."""
        found: _AnswerWithAdditionalsType = {}
        qtype = question.type

        # Service type enumeration is handled on its own and nothing else is added
        if qtype == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME:
            self._add_service_type_enumeration_query_answers(found, known_answers, now)
            return found

        wants_any = qtype == _TYPE_ANY

        if wants_any or qtype == _TYPE_PTR:
            self._add_pointer_answers(question.name, found, known_answers, now)

        if wants_any or qtype in (_TYPE_A, _TYPE_AAAA):
            self._add_address_answers(question.name, found, known_answers, now, qtype)

        wants_srv = wants_any or qtype == _TYPE_SRV
        wants_txt = wants_any or qtype == _TYPE_TXT
        if not (wants_srv or wants_txt):
            return found

        service = self.registry.async_get_info_name(question.name)  # type: ignore
        if service is None:
            return found

        if wants_srv:
            # Recommended additional answers according to
            # https://tools.ietf.org/html/rfc6763#section-12.2.
            srv_record = service.dns_service(created=now)
            if not known_answers.suppresses(srv_record):
                found[srv_record] = _get_address_and_nsec_records(service, now)
        if wants_txt:
            txt_record = service.dns_text(created=now)
            if not known_answers.suppresses(txt_record):
                found[txt_record] = set()

        return found

    def async_response(  # pylint: disable=unused-argument
        self, msgs: List[DNSIncoming], ucast_source: bool
    ) -> QuestionAnswers:
        """Answer the questions in the incoming query packets where possible.

        This function must be run in the event loop as it is not
        threadsafe.

        Known answers are collected from the non-probe messages only.
        """
        non_probe_answers = itertools.chain.from_iterable(
            msg.answers for msg in msgs if not _message_is_probe(msg)
        )
        known_answers = DNSRRSet(non_probe_answers)
        query_res = _QueryResponse(self.cache, msgs)

        for msg in msgs:
            for question in msg.questions:
                wants_unicast = question.unicast
                if not wants_unicast:
                    self.question_history.add_question_at_time(question, msg.now, set(known_answers.lookup))
                answer_set = self._answer_question(question, known_answers, msg.now)
                if ucast_source:
                    query_res.add_ucast_question_response(answer_set)
                elif wants_unicast:
                    # QU question from a multicast source: bucketed on its own
                    query_res.add_qu_question_response(answer_set)
                    continue
                # We always multicast as well even if it is a unicast
                # source as long as we haven't done it recently (75% of ttl)
                query_res.add_mcast_question_response(answer_set)

        return query_res.answers()
372
373
class RecordManager:
    """Process records into the cache and notify listeners.

    Incoming answers are compared with the cache, listeners are told about
    the changes, and only afterwards is the cache itself modified.
    """

    def __init__(self, zeroconf: 'Zeroconf') -> None:
        """Init the record manager.

        Keeps a reference to the Zeroconf instance and to its cache, and
        starts with an empty list of listeners.
        """
        self.zc = zeroconf
        self.cache = zeroconf.cache
        self.listeners: List[RecordUpdateListener] = []

    def async_updates(self, now: float, records: List[RecordUpdate]) -> None:
        """Notify every registered listener of a batch of record updates.

        This method must be called before the cache is updated.

        This method will be run in the event loop.
        """
        for listener in self.listeners:
            listener.async_update_records(self.zc, now, records)

    def async_updates_complete(self) -> None:
        """Tell every registered listener that the current batch of updates
        is finished, then notify all waiters on the Zeroconf instance.

        This method must be called after the cache is updated.

        This method will be run in the event loop.
        """
        for listener in self.listeners:
            listener.async_update_records_complete()
        self.zc.async_notify_all()

    def async_updates_from_response(self, msg: DNSIncoming) -> None:
        """Deal with incoming response packets.  All answers
        are held in the cache, and listeners are notified.

        This function must be run in the event loop as it is not
        threadsafe.
        """
        updates: List[RecordUpdate] = []
        address_adds: List[DNSAddress] = []
        other_adds: List[DNSRecord] = []
        removes: Set[DNSRecord] = set()
        now = msg.now
        # (name, type, class) of every answer that has the unique flag set
        unique_types: Set[Tuple[str, int, int]] = set()

        for record in msg.answers:
            # May raise the effective ttl of the record in place
            sanitize_incoming_record(record)

            if record.unique:  # https://tools.ietf.org/html/rfc6762#section-10.2
                unique_types.add((record.name, record.type, record.class_))

            maybe_entry = self.cache.async_get_unique(cast(_UniqueRecordsType, record))
            if not record.is_expired(now):
                if maybe_entry is not None:
                    # Already cached: refresh the cached entry instead of re-adding
                    maybe_entry.reset_ttl(record)
                else:
                    if isinstance(record, DNSAddress):
                        address_adds.append(record)
                    else:
                        other_adds.append(record)
                updates.append(RecordUpdate(record, maybe_entry))
            # This is likely a goodbye since the record is
            # expired and exists in the cache
            elif maybe_entry is not None:
                updates.append(RecordUpdate(record, maybe_entry))
                removes.add(record)

        if unique_types:
            self._async_mark_unique_cached_records_older_than_1s_to_expire(unique_types, msg.answers, now)

        if updates:
            self.async_updates(now, updates)
        # The cache adds must be processed AFTER we trigger
        # the updates since we compare existing data
        # with the new data and updating the cache
        # ahead of update_record will cause listeners
        # to miss changes
        #
        # We must process address adds before non-addresses
        # otherwise a fetch of ServiceInfo may miss an address
        # because it thinks the cache is complete
        #
        # The cache is only modified once every listener has been
        # notified, so that any ServiceBrowser that is going to call
        # zc.get_service_info will see the cached value
        # but ONLY after all the record updates have been
        # processed.
        # NOTE(review): this comment used to mention a context manager;
        # none is used in this method.
        if other_adds or address_adds:
            self.cache.async_add_records(itertools.chain(address_adds, other_adds))
        # Removes are processed last since
        # ServiceInfo could generate an un-needed query
        # because the data was not yet populated.
        if removes:
            self.cache.async_remove_records(removes)
        if updates:
            self.async_updates_complete()

    def _async_mark_unique_cached_records_older_than_1s_to_expire(
        self, unique_types: Set[Tuple[str, int, int]], answers: Iterable[DNSRecord], now: float
    ) -> None:
        """Shorten the life of stale cached records superseded by unique answers.

        For every (name, type, class) in unique_types, cached entries that
        were created more than one second before now and are not among the
        incoming answers are re-stamped to expire one second from now.
        """
        # rfc6762#section-10.2 para 2
        # Since unique is set, all old records with that name, rrtype,
        # and rrclass that were received more than one second ago are declared
        # invalid, and marked to expire from the cache in one second.
        answers_rrset = DNSRRSet(answers)
        for name, type_, class_ in unique_types:
            for entry in self.cache.async_all_by_details(name, type_, class_):
                if (now - entry.created > _ONE_SECOND) and entry not in answers_rrset:
                    # Expire in 1s
                    entry.set_created_ttl(now, 1)

    def async_add_listener(
        self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]]
    ) -> None:
        """Adds a listener for a given question.  The listener will have
        its async_update_records method called with any unexpired cache
        entries that already answer the question(s).

        A listener that does not inherit from RecordUpdateListener is
        still added; an error is logged for it.  When question is None the
        listener is added without replaying anything from the cache.

        This function is not threadsafe and must be called in the eventloop.
        """
        if not isinstance(listener, RecordUpdateListener):
            log.error(
                "listeners passed to async_add_listener must inherit from RecordUpdateListener;"
                " In the future this will fail"
            )

        self.listeners.append(listener)

        if question is None:
            return

        # Accept either a single question or a list of them
        questions = [question] if isinstance(question, DNSQuestion) else question
        assert self.zc.loop is not None
        self._async_update_matching_records(listener, questions)

    def _async_update_matching_records(
        self, listener: RecordUpdateListener, questions: List[DNSQuestion]
    ) -> None:
        """Calls back any existing entries in the cache that answer the question.

        Only unexpired entries are sent, each as a RecordUpdate without an
        old record.  Nothing is called when no entry matches.

        This function must be run from the event loop.
        """
        now = current_time_millis()
        records: List[RecordUpdate] = []
        for question in questions:
            for record in self.cache.async_entries_with_name(question.name):
                if not record.is_expired(now) and question.answered_by(record):
                    records.append(RecordUpdate(record, None))

        if not records:
            return
        listener.async_update_records(self.zc, now, records)
        listener.async_update_records_complete()
        self.zc.async_notify_all()

    def async_remove_listener(self, listener: RecordUpdateListener) -> None:
        """Removes a listener.

        A listener that was never added is logged and otherwise ignored.

        This function is not threadsafe and must be called in the eventloop.
        """
        try:
            self.listeners.remove(listener)
            self.zc.async_notify_all()
        except ValueError as e:
            log.exception('Failed to remove listener: %r', e)
540
541
class MulticastOutgoingQueue:
    """A queue that aggregates multicast responses before they are sent."""

    def __init__(self, zeroconf: 'Zeroconf', additional_delay: int, max_aggregation_delay: int) -> None:
        """Create an empty queue bound to the given Zeroconf instance."""
        self.zc = zeroconf
        self.queue: deque = deque()
        # The additional delay is what implements
        # "protect the network against excessive packet flooding"
        # https://datatracker.ietf.org/doc/html/rfc6762#section-14
        self.additional_delay = additional_delay
        self.aggregation_delay = max_aggregation_delay

    def async_add(self, now: float, answers: _AnswerWithAdditionalsType) -> None:
        """Queue a group of answers with additionals for a later multicast."""
        loop = self.zc.loop
        assert loop is not None
        low, high = _MULTICAST_DELAY_RANDOM_INTERVAL
        random_delay = random.randint(low, high) + self.additional_delay
        earliest = now + random_delay
        latest = now + self.aggregation_delay + self.additional_delay
        pending = self.queue
        if not pending:
            # First group in the queue: arrange for the queue to be processed
            loop.call_later(millis_to_seconds(random_delay), self.async_ready)
        else:
            # When the randomly chosen send-after time is not later than
            # that of the last scheduled group, merge into that group
            # instead so that additional responses are aggregated
            newest = pending[-1]
            if earliest <= newest.send_after:
                newest.answers.update(answers)
                return
        pending.append(AnswerGroup(earliest, latest, answers))

    def _remove_answers_from_queue(self, answers: _AnswerWithAdditionalsType) -> None:
        """Drop the given answers from every group still waiting in the queue."""
        for group in self.queue:
            queued_answers = group.answers
            for record in answers:
                queued_answers.pop(record, None)

    def async_ready(self) -> None:
        """Send whatever is due and reschedule for whatever is not."""
        loop = self.zc.loop
        assert loop is not None
        now = current_time_millis()
        pending = self.queue

        if len(pending) > 1:
            # With more than one group queued, hold off until the first
            # group reaches its send_before time
            deadline = pending[0].send_before
            if deadline > now:
                loop.call_later(millis_to_seconds(deadline - now), self.async_ready)
                return

        ready: _AnswerWithAdditionalsType = {}
        # Merge every group that is allowed to go out now
        while pending and pending[0].send_after <= now:
            ready.update(pending.popleft().answers)

        if pending:
            # Groups that are not ready yet still need a later wake-up
            loop.call_later(millis_to_seconds(pending[0].send_after - now), self.async_ready)

        if not ready:
            return
        # Anything being sent now need not be sent again by a later group
        self._remove_answers_from_queue(ready)
        self.zc.async_send(construct_outgoing_multicast_answers(ready))
604