""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
    Copyright 2003 Paul Scott-Murphy, 2014 William McBrine

    This module provides a framework for the use of DNS Service Discovery
    using IP multicast.

    This library is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 2.1 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
    Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public
    License along with this library; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
    USA
"""

import itertools
from typing import Dict, Iterable, Iterator, List, Optional, Union, cast

from ._dns import (
    DNSAddress,
    DNSEntry,
    DNSHinfo,
    DNSPointer,
    DNSRecord,
    DNSService,
    DNSText,
    dns_entry_matches,
)
from ._utils.time import current_time_millis
from .const import _TYPE_PTR

_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService)
_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService]
# Outer key: a name string; inner dict maps each record to itself so that
# an equal record can be looked up (and replaced) in constant time.
_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]]


def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None:
    """Remove an entry from a DNSRecord cache.

    The bucket stored under ``key`` is dropped as well once it becomes
    empty, so the cache never holds empty buckets.

    Raises KeyError if ``key`` is not in ``cache`` or ``entry`` is not in
    the bucket stored under ``key``.

    This function must be run in the event loop.
    """
    del cache[key][entry]
    if not cache[key]:
        del cache[key]


class DNSCache:
    """A cache of DNS entries."""

    def __init__(self) -> None:
        # All records, bucketed by ``entry.key``.
        self.cache: _DNSRecordCacheType = {}
        # DNSService records only, bucketed by lower-cased server name.
        self.service_cache: _DNSRecordCacheType = {}

    # Functions prefixed with async_ are NOT threadsafe and must
    # be run in the event loop.

    def _async_add(self, entry: DNSRecord) -> None:
        """Add an entry.

        This function must be run in the event loop.
        """
        # Previously storage of records was implemented as a list
        # instead of a dict. Since DNSRecords are now hashable, the
        # implementation uses a dict to ensure that adding a new record to
        # the cache replaces any existing records that are __eq__ to each
        # other, which removes the risk that accessing the cache from the
        # wrong direction would return the old incorrect entry.
        self.cache.setdefault(entry.key, {})[entry] = entry
        if isinstance(entry, DNSService):
            # The server lookups (async_entries_with_server and
            # entries_with_server) lower-case the requested name, so the
            # bucket key must be lower-cased here too; otherwise a service
            # whose server name contains upper-case characters could never
            # be found by server.
            self.service_cache.setdefault(entry.server.lower(), {})[entry] = entry

    def async_add_records(self, entries: Iterable[DNSRecord]) -> None:
        """Add multiple records.

        This function must be run in the event loop.
        """
        for entry in entries:
            self._async_add(entry)

    def _async_remove(self, entry: DNSRecord) -> None:
        """Remove an entry.

        Raises KeyError if the entry is not in the cache.

        This function must be run in the event loop.
        """
        if isinstance(entry, DNSService):
            # Must mirror the lower-cased bucket key used by _async_add.
            _remove_key(self.service_cache, entry.server.lower(), entry)
        _remove_key(self.cache, entry.key, entry)

    def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
        """Remove multiple records.

        This function must be run in the event loop.
        """
        for entry in entries:
            self._async_remove(entry)

    def async_expire(self, now: float) -> List[DNSRecord]:
        """Purge expired entries from the cache and return them.

        ``now`` is passed unchanged to each record's ``is_expired``.

        This function must be run in the event loop.
        """
        # Collect first, remove afterwards: removing while iterating would
        # mutate the dicts being walked.
        expired = [
            record
            for record in itertools.chain.from_iterable(self.cache.values())
            if record.is_expired(now)
        ]
        self.async_remove_records(expired)
        return expired

    def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:
        """Get a unique entry by key.

        Returns the cached record equal to ``entry``, or None if there is
        no matching entry.

        This function is not threadsafe and must be called from
        the event loop.
        """
        return self.cache.get(entry.key, {}).get(entry)

    def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[DNSRecord]:
        """Yield all entries matching the name, type and class.

        The name is lower-cased before the lookup.

        This function is not threadsafe and must be called from
        the event loop.
        """
        key = name.lower()
        for entry in self.cache.get(key, []):
            if dns_entry_matches(entry, key, type_, class_):
                yield entry

    def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]:
        """Return a dict of entries whose key matches the name.

        When the name is cached, the dict returned is the cache's own
        bucket, not a copy.

        This function is not threadsafe and must be called from
        the event loop.
        """
        return self.cache.get(name.lower(), {})

    def async_entries_with_server(self, name: str) -> Dict[DNSRecord, DNSRecord]:
        """Return a dict of service entries whose server matches the name.

        When the server is cached, the dict returned is the cache's own
        bucket, not a copy.

        This function is not threadsafe and must be called from
        the event loop.
        """
        return self.service_cache.get(name.lower(), {})

    # The below functions are threadsafe and do not need to be run in the
    # event loop, however they all make copies so they are significantly
    # less efficient.

    def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
        """Get an entry by key.

        Returns None if there is no matching entry.
        """
        if isinstance(entry, _UNIQUE_RECORD_TYPES):
            return self.cache.get(entry.key, {}).get(entry)
        # Walk a copy newest-first so the most recently added match wins.
        for cached_entry in reversed(list(self.cache.get(entry.key, []))):
            # NOTE(review): the explicit ``entry.__eq__`` call (rather than
            # ``==``) looks intentional. ``==`` may try the cached record's
            # own ``__eq__`` first when its type is a subclass of
            # ``type(entry)``, which could change the result. Keep as is
            # unless confirmed otherwise.
            if entry.__eq__(cached_entry):
                return cached_entry
        return None

    def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]:
        """Get the first matching entry by details. Returns None if no entries match.

        Calling this function is not recommended as it will only
        return one record even if there are multiple entries.

        For example if there are multiple A or AAAA addresses this
        function will return the last one that was added to the cache
        which may not be the one you expect.

        Use get_all_by_details instead.
        """
        key = name.lower()
        for cached_entry in reversed(list(self.cache.get(key, []))):
            if dns_entry_matches(cached_entry, key, type_, class_):
                return cached_entry
        return None

    def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]:
        """Get all matching entries by details."""
        key = name.lower()
        return [
            entry for entry in list(self.cache.get(key, [])) if dns_entry_matches(entry, key, type_, class_)
        ]

    def entries_with_server(self, server: str) -> List[DNSRecord]:
        """Return a list of service entries whose server matches the name."""
        return list(self.service_cache.get(server.lower(), []))

    def entries_with_name(self, name: str) -> List[DNSRecord]:
        """Return a list of entries whose key matches the name."""
        return list(self.cache.get(name.lower(), []))

    def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]:
        """Get the newest unexpired PTR record for name that points at alias.

        Returns None if there is no such record.
        """
        now = current_time_millis()
        for record in reversed(self.entries_with_name(name)):
            if (
                record.type == _TYPE_PTR
                and not record.is_expired(now)
                and cast(DNSPointer, record).alias == alias
            ):
                return record
        return None

    def names(self) -> List[str]:
        """Return a copy of the list of current cache names."""
        return list(self.cache)