1""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
2    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
3
4    This module provides a framework for the use of DNS Service Discovery
5    using IP multicast.
6
7    This library is free software; you can redistribute it and/or
8    modify it under the terms of the GNU Lesser General Public
9    License as published by the Free Software Foundation; either
10    version 2.1 of the License, or (at your option) any later version.
11
12    This library is distributed in the hope that it will be useful,
13    but WITHOUT ANY WARRANTY; without even the implied warranty of
14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15    Lesser General Public License for more details.
16
17    You should have received a copy of the GNU Lesser General Public
18    License along with this library; if not, write to the Free Software
19    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
20    USA
21"""
22
23import itertools
24from typing import Dict, Iterable, Iterator, List, Optional, Union, cast
25
26from ._dns import (
27    DNSAddress,
28    DNSEntry,
29    DNSHinfo,
30    DNSPointer,
31    DNSRecord,
32    DNSService,
33    DNSText,
34    dns_entry_matches,
35)
36from ._utils.time import current_time_millis
37from .const import _TYPE_PTR
38
# Record classes for which ``DNSCache.get`` does a direct hash lookup
# (``cache[key].get(entry)``) instead of a linear scan -- presumably because
# equality on these types identifies a single cached record.
_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService)
# Static-typing counterpart of _UNIQUE_RECORD_TYPES; used to annotate the
# argument of ``DNSCache.async_get_unique``.
_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService]
# Shape of ``DNSCache.cache`` and ``DNSCache.service_cache``: an index string
# (a record key, or a server name for service_cache) mapped to a dict whose
# keys and values are the cached records themselves.
_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]]
42
43
def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None:
    """Drop *entry* from the bucket ``cache[key]``.

    The bucket itself is deleted once it no longer holds any records, so
    empty buckets never linger in *cache*.  ``KeyError`` propagates if
    either *key* or *entry* is absent.

    Must be called from the event loop.
    """
    bucket = cache[key]
    del bucket[entry]
    if not bucket:
        # Last record under this key is gone; remove the empty bucket too.
        del cache[key]
52
53
class DNSCache:
    """A cache of DNS entries.

    Two indexes are maintained together by ``_async_add`` and
    ``_async_remove``:

    * ``cache`` maps a record's ``key`` to a dict mapping each cached record
      to itself.
    * ``service_cache`` maps the lowercased ``server`` of every cached
      ``DNSService`` to a dict of the same shape.

    When a record equal to an already cached one is added, ``dict`` keeps
    the *original key object* and only replaces the value.  The current
    record is therefore always the dict *value*, so every reader below
    iterates ``.values()`` rather than the keys.
    """

    def __init__(self) -> None:
        self.cache: _DNSRecordCacheType = {}
        self.service_cache: _DNSRecordCacheType = {}

    # Functions prefixed with async_ are NOT threadsafe and must
    # be run in the event loop.

    def _async_add(self, entry: DNSRecord) -> None:
        """Add a single entry, replacing any cached entry equal to it.

        This function must be run from the event loop.
        """
        # Previously storage of records was implemented as a list
        # instead of a dict. Since DNSRecords are now hashable, the
        # implementation uses a dict to ensure that adding a new record to
        # the cache replaces any existing records that are __eq__ to it,
        # which removes the risk that accessing the cache from the wrong
        # direction would return the old incorrect entry.
        self.cache.setdefault(entry.key, {})[entry] = entry
        if isinstance(entry, DNSService):
            # Index by the lowercased server name: both service_cache readers
            # lowercase their argument before the lookup, so a mixed-case key
            # stored verbatim would never be found.
            self.service_cache.setdefault(entry.server.lower(), {})[entry] = entry

    def async_add_records(self, entries: Iterable[DNSRecord]) -> None:
        """Add multiple records.

        This function must be run from the event loop.
        """
        for entry in entries:
            self._async_add(entry)

    def _async_remove(self, entry: DNSRecord) -> None:
        """Remove a single entry from both indexes.

        Raises KeyError if the entry is not cached.

        This function must be run from the event loop.
        """
        if isinstance(entry, DNSService):
            # Must mirror the lowercased key used by _async_add.
            _remove_key(self.service_cache, entry.server.lower(), entry)
        _remove_key(self.cache, entry.key, entry)

    def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
        """Remove multiple records.

        This function must be run from the event loop.
        """
        for entry in entries:
            self._async_remove(entry)

    def async_expire(self, now: float) -> List[DNSRecord]:
        """Purge expired entries from the cache.

        :param now: the time passed to each record's ``is_expired``.
        :returns: the records that were removed.

        This function must be run from the event loop.
        """
        # Inspect the stored values (the current records), not the keys,
        # which may be older equal objects that were since replaced.
        expired = [
            record
            for record in itertools.chain.from_iterable(records.values() for records in self.cache.values())
            if record.is_expired(now)
        ]
        # Removal happens after the scan so no dict is mutated while iterated.
        self.async_remove_records(expired)
        return expired

    def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:
        """Gets a unique entry by key.  Will return None if there is no
        matching entry.

        This function is not threadsafe and must be called from
        the event loop.
        """
        return self.cache.get(entry.key, {}).get(entry)

    def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[DNSRecord]:
        """Yield all entries matching name, type and class.

        The generator walks the live dict, so it must be fully consumed
        before the cache is modified.

        This function is not threadsafe and must be called from
        the event loop.
        """
        key = name.lower()
        for entry in self.cache.get(key, {}).values():
            if dns_entry_matches(entry, key, type_, class_):
                yield entry

    def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]:
        """Returns the live dict of entries whose key matches the name.

        The dict is the cache's own storage (not a copy); its values are the
        current records.

        This function is not threadsafe and must be called from
        the event loop.
        """
        return self.cache.get(name.lower(), {})

    def async_entries_with_server(self, name: str) -> Dict[DNSRecord, DNSRecord]:
        """Returns the live dict of service entries whose server matches the name.

        The dict is the cache's own storage (not a copy); its values are the
        current records.

        This function is not threadsafe and must be called from
        the event loop.
        """
        return self.service_cache.get(name.lower(), {})

    # The below functions are threadsafe and do not need to be run in the
    # event loop; however, they copy the cached containers before iterating,
    # so they are significantly less efficient.

    def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
        """Gets an entry by key.  Will return None if there is no
        matching entry.

        For the record types in ``_UNIQUE_RECORD_TYPES`` this is a direct
        hash lookup; otherwise the records stored under ``entry.key`` are
        scanned from the last inserted to the first.
        """
        if isinstance(entry, _UNIQUE_RECORD_TYPES):
            return self.cache.get(entry.key, {}).get(entry)
        for cached_entry in reversed(list(self.cache.get(entry.key, {}).values())):
            # Call the probe's own __eq__ rather than using ``==``: when the
            # cached record's class is a subclass of the probe's class that
            # overrides __eq__, ``==`` would try the cached record's
            # reflected __eq__ first instead of the probe's comparison.
            if entry.__eq__(cached_entry):
                return cached_entry
        return None

    def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]:
        """Gets a single matching entry by details. Returns None if no entries match.

        Calling this function is not recommended as it will only
        return one record even if there are multiple entries.

        For example if there are multiple A or AAAA addresses this
        function will return the one that comes last in the cache's
        insertion order (re-adding an equal record keeps its original
        position), which may not be the one you expect.

        Use get_all_by_details instead.
        """
        key = name.lower()
        for cached_entry in reversed(list(self.cache.get(key, {}).values())):
            if dns_entry_matches(cached_entry, key, type_, class_):
                return cached_entry
        return None

    def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]:
        """Gets all matching entries by details."""
        key = name.lower()
        # Copy first so the comprehension does not iterate the live dict.
        snapshot = list(self.cache.get(key, {}).values())
        return [entry for entry in snapshot if dns_entry_matches(entry, key, type_, class_)]

    def entries_with_server(self, server: str) -> List[DNSRecord]:
        """Returns a list of service entries whose server matches the name."""
        return list(self.service_cache.get(server.lower(), {}).values())

    def entries_with_name(self, name: str) -> List[DNSRecord]:
        """Returns a list of entries whose key matches the name."""
        return list(self.cache.get(name.lower(), {}).values())

    def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]:
        """Return the last-inserted unexpired PTR record for *name* whose
        alias equals *alias*, or None if there is none.
        """
        now = current_time_millis()
        for record in reversed(self.entries_with_name(name)):
            if (
                record.type == _TYPE_PTR
                and not record.is_expired(now)
                and cast(DNSPointer, record).alias == alias
            ):
                return record
        return None

    def names(self) -> List[str]:
        """Return a copy of the list of current cache names."""
        return list(self.cache)
210