# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

"""DNS Versioned Zones."""

import collections
try:
    import threading as _threading
except ImportError:  # pragma: no cover
    import dummy_threading as _threading  # type: ignore

import dns.exception
import dns.immutable
import dns.name
import dns.node
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.rdata
import dns.rdtypes.ANY.SOA
import dns.transaction
import dns.zone


class UseTransaction(dns.exception.DNSException):
    """To alter a versioned zone, use a transaction."""


class Version:
    """A snapshot of the contents of a versioned zone.

    *zone* is the owning zone and *id* is this version's numeric id.
    The version's *origin* starts out as the zone's origin; subclasses may
    override it.
    """

    def __init__(self, zone, id):
        self.zone = zone
        self.id = id
        self.nodes = {}
        # Every kind of version carries an origin so that _validate_name()
        # works for all of them (it relativizes against self.origin).
        self.origin = zone.origin

    def _validate_name(self, name):
        """Return *name* in the form used as a key in ``self.nodes``.

        Relative names are returned unchanged.  Absolute names must be
        subdomains of the version's origin and are relativized when the
        zone relativizes names.

        Raises ``KeyError`` if *name* is absolute and either no origin is
        known yet or *name* is not a subdomain of the origin.
        """
        if name.is_absolute():
            origin = self.origin
            if origin is None:
                raise KeyError("no origin to validate an absolute name")
            if not name.is_subdomain(origin):
                raise KeyError("name is not a subdomain of the zone origin")
            if self.zone.relativize:
                name = name.relativize(origin)
        return name

    def get_node(self, name):
        """Return the node for *name*, or ``None`` if there is none."""
        name = self._validate_name(name)
        return self.nodes.get(name)

    def get_rdataset(self, name, rdtype, covers):
        """Return the matching rdataset at *name*, or ``None``."""
        node = self.get_node(name)
        if node is None:
            return None
        return node.get_rdataset(self.zone.rdclass, rdtype, covers)

    def items(self):
        """Return the (name, node) items of this version."""
        return self.nodes.items()  # pylint: disable=dict-items-not-iterating


class WritableVersion(Version):
    """A mutable version, built and modified by a write transaction.

    If *replacement* is false, the version starts as a shallow copy of the
    zone's current nodes; otherwise it starts empty.
    """

    def __init__(self, zone, replacement=False):
        # Only the zone's single active writer creates one of these.
        # NOTE(review): the lock is named zone._version_lock, but
        # Transaction._setup_version() builds this after Zone.writer() has
        # released it -- confirm that single-writer exclusivity is the
        # intended guarantee.
        if len(zone._versions) > 0:
            id = zone._versions[-1].id + 1
        else:
            id = 1
        super().__init__(zone, id)
        if not replacement:
            # We copy the map, because that gives us a simple and thread-safe
            # way of doing versions, and we have a garbage collector to help
            # us.  We only make new node objects if we actually change the
            # node.
            self.nodes.update(zone.nodes)
        # We have to copy the zone origin as it may be None in the first
        # version, and we don't want to mutate the zone until we commit.
        self.origin = zone.origin
        # Relativized names touched by this version.
        self.changed = set()

    def _maybe_cow(self, name):
        """Return a node for *name* that belongs to this version.

        A node not created by this version (its id differs) is copied into
        a fresh node first, so earlier versions sharing it are unaffected.
        """
        name = self._validate_name(name)
        node = self.nodes.get(name)
        if node is None or node.id != self.id:
            new_node = self.zone.node_factory()
            new_node.id = self.id
            if node is not None:
                # moo!  copy on write!
                new_node.rdatasets.extend(node.rdatasets)
            self.nodes[name] = new_node
            self.changed.add(name)
            return new_node
        else:
            return node

    def delete_node(self, name):
        """Remove the node at *name*, if present."""
        name = self._validate_name(name)
        if name in self.nodes:
            del self.nodes[name]
            self.changed.add(name)

    def put_rdataset(self, name, rdataset):
        """Store *rdataset* at *name*, replacing any matching rdataset."""
        node = self._maybe_cow(name)
        node.replace_rdataset(rdataset)

    def delete_rdataset(self, name, rdtype, covers):
        """Delete the matching rdataset at *name*.

        The node is removed when it becomes empty.
        """
        # Validate up front so the ``del`` below uses the same (possibly
        # relativized) key that _maybe_cow() stored the node under.
        name = self._validate_name(name)
        node = self._maybe_cow(name)
        node.delete_rdataset(self.zone.rdclass, rdtype, covers)
        if len(node) == 0:
            del self.nodes[name]


@dns.immutable.immutable
class ImmutableVersion(Version):
    """A frozen version made from a committed ``WritableVersion``."""

    def __init__(self, version):
        # Version.__init__ gives us an empty nodes dict.  It is replaced
        # below with an immutable view of the writable version's nodes.
        super().__init__(version.zone, version.id)
        self.origin = version.origin
        # Make changed nodes immutable.  This rewrites entries of
        # version.nodes in place; the writable version is being committed
        # (see Transaction._end_transaction) and is not modified afterwards.
        for name in version.changed:
            node = version.nodes.get(name)
            # It might not exist if we deleted it in the version.  Compare
            # with None: Node defines __len__, so an empty node is falsy.
            if node is not None:
                version.nodes[name] = ImmutableNode(node)
        self.nodes = dns.immutable.Dict(version.nodes, True)


# A node with a version id.

class Node(dns.node.Node):
    """A ``dns.node.Node`` tagged with the id of the version owning it."""

    __slots__ = ['id']

    def __init__(self):
        super().__init__()
        # A proper id will get set by the Version
        self.id = 0


@dns.immutable.immutable
class ImmutableNode(Node):
    """A read-only copy of *node* whose rdatasets are immutable."""

    __slots__ = ['id']

    def __init__(self, node):
        super().__init__()
        self.id = node.id
        self.rdatasets = tuple(
            [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
        )

    def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
                      create=False):
        if create:
            raise TypeError("immutable")
        return super().find_rdataset(rdclass, rdtype, covers, False)

    def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
                     create=False):
        if create:
            raise TypeError("immutable")
        return super().get_rdataset(rdclass, rdtype, covers, False)

    def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
        raise TypeError("immutable")

    def replace_rdataset(self, replacement):
        raise TypeError("immutable")


class Zone(dns.zone.Zone):
    """A zone that keeps a sequence of versions.

    All changes are made through write transactions from ``writer()``.
    Reads go through ``reader()`` or the read-only ``dns.zone.Zone`` API.
    """

    # The slot name must match the attribute actually used below
    # (self._version_lock).
    __slots__ = ['_versions', '_version_lock', '_write_txn',
                 '_write_waiters', '_write_event', '_pruning_policy',
                 '_readers']

    node_factory = Node

    def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True,
                 pruning_policy=None):
        """Initialize a versioned zone object.

        *origin* is the origin of the zone.  It may be a ``dns.name.Name``,
        a ``str``, or ``None``.  If ``None``, then the zone's origin will
        be set by the first ``$ORIGIN`` line in a zone file.

        *rdclass*, an ``int``, the zone's rdata class; the default is class IN.

        *relativize*, a ``bool``, determines whether domain names are
        relativized to the zone's origin.  The default is ``True``.

        *pruning_policy*, a function taking a zone and a ``Version`` and
        returning a ``bool``, or ``None``.  It answers the question "should
        this version be pruned?".  If ``None``, the default policy is used.
        The default policy keeps only the most recent version, plus any
        versions pinned by open readers.
        """
        super().__init__(origin, rdclass, relativize)
        self._versions = collections.deque()
        self._version_lock = _threading.Lock()
        if pruning_policy is None:
            self._pruning_policy = self._default_pruning_policy
        else:
            self._pruning_policy = pruning_policy
        self._write_txn = None
        self._write_event = None
        self._write_waiters = collections.deque()
        self._readers = set()
        self._commit_version_unlocked(None, WritableVersion(self), origin)

    def reader(self, id=None, serial=None):  # pylint: disable=arguments-differ
        """Return a read-only transaction.

        With *id*, it reads the version with that id.  With *serial*, it
        reads the newest version whose SOA serial matches.  With neither,
        it reads the current version.

        Raises ``ValueError`` if both are given.  Raises ``KeyError`` if no
        matching version is retained.
        """
        if id is not None and serial is not None:
            raise ValueError('cannot specify both id and serial')
        with self._version_lock:
            if id is not None:
                version = None
                for v in reversed(self._versions):
                    if v.id == id:
                        version = v
                        break
                if version is None:
                    raise KeyError('version not found')
            elif serial is not None:
                if self.relativize:
                    oname = dns.name.empty
                else:
                    oname = self.origin
                version = None
                for v in reversed(self._versions):
                    n = v.nodes.get(oname)
                    if n:
                        rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA)
                        if rds and rds[0].serial == serial:
                            version = v
                            break
                if version is None:
                    raise KeyError('serial not found')
            else:
                version = self._versions[-1]
            txn = Transaction(self, False, version)
            self._readers.add(txn)
            return txn

    def writer(self, replacement=False):
        """Return the zone's write transaction.

        If another write transaction is open, this blocks until it is this
        caller's turn.  Waiters are served in FIFO order.
        """
        event = None
        while True:
            with self._version_lock:
                # Checking event == self._write_event ensures that either
                # no one was waiting before we got lucky and found no write
                # txn, or we were the one who was waiting and got woken up.
                # This prevents "taking cuts" when creating a write txn.
                if self._write_txn is None and event == self._write_event:
                    # Creating the transaction defers version setup
                    # (i.e.  copying the nodes dictionary) until we
                    # give up the lock, so that we hold the lock as
                    # short a time as possible.  This is why we call
                    # _setup_version() below.
                    txn = Transaction(self, replacement)
                    self._write_txn = txn
                    # give up our exclusive right to make a Transaction
                    self._write_event = None
                    break
                # Someone else is writing already, so we will have to
                # wait, but we want to do the actual wait outside the
                # lock.
                event = _threading.Event()
                self._write_waiters.append(event)
            # wait (note we gave up the lock!)
            #
            # We only wake one sleeper at a time, so it's important
            # that no event waiter can exit this method (e.g. via
            # cancelation) without returning a transaction or waking
            # someone else up.
            #
            # This is not a problem with Threading module threads as
            # they cannot be canceled, but could be an issue with trio
            # or curio tasks when we do the async version of writer().
            # I.e. we'd need to do something like:
            #
            # try:
            #     event.wait()
            # except trio.Cancelled:
            #     with self._version_lock:
            #         self._maybe_wakeup_one_waiter_unlocked()
            #     raise
            #
            event.wait()
        # Do the deferred version setup.  If it fails we must release the
        # writer slot.  Otherwise _write_txn stays set, no waiter is ever
        # woken, and every later writer() call blocks forever.
        try:
            txn._setup_version()
        except BaseException:
            self._end_write(txn)
            raise
        return txn

    def _maybe_wakeup_one_waiter_unlocked(self):
        # Caller must hold self._version_lock.
        if len(self._write_waiters) > 0:
            self._write_event = self._write_waiters.popleft()
            self._write_event.set()

    # pylint: disable=unused-argument
    def _default_pruning_policy(self, zone, version):
        return True
    # pylint: enable=unused-argument

    def _prune_versions_unlocked(self):
        assert len(self._versions) > 0
        # Don't ever prune a version greater than or equal to one that
        # a reader has open.  This pins versions in memory while the
        # reader is open, and importantly lets the reader open a txn on
        # a successor version (e.g. if generating an IXFR).
        #
        # Note our definition of least_kept also ensures we do not try to
        # delete the greatest version.
        if len(self._readers) > 0:
            least_kept = min(txn.version.id for txn in self._readers)
        else:
            least_kept = self._versions[-1].id
        while self._versions[0].id < least_kept and \
                self._pruning_policy(self, self._versions[0]):
            self._versions.popleft()

    def set_max_versions(self, max_versions):
        """Set a pruning policy that retains up to the specified number
        of versions.

        ``None`` means versions are never pruned.  Raises ``ValueError`` if
        *max_versions* is less than 1.
        """
        if max_versions is not None and max_versions < 1:
            raise ValueError('max versions must be at least 1')
        if max_versions is None:
            def policy(*_):
                return False
        else:
            def policy(zone, _):
                return len(zone._versions) > max_versions
        self.set_pruning_policy(policy)

    def set_pruning_policy(self, policy):
        """Set the pruning policy for the zone.

        The *policy* function is called with the zone and a `Version`. It
        returns `True` if the version should be pruned, and `False`
        otherwise.  `None` may also be specified for policy, in which case
        the default policy is used.

        Pruning checking proceeds from the least version and the first
        time the function returns `False`, the checking stops.  I.e. the
        retained versions are always a consecutive sequence.
        """
        if policy is None:
            policy = self._default_pruning_policy
        with self._version_lock:
            self._pruning_policy = policy
            self._prune_versions_unlocked()

    def _end_read(self, txn):
        with self._version_lock:
            self._readers.remove(txn)
            self._prune_versions_unlocked()

    def _end_write_unlocked(self, txn):
        assert self._write_txn == txn
        self._write_txn = None
        self._maybe_wakeup_one_waiter_unlocked()

    def _end_write(self, txn):
        with self._version_lock:
            self._end_write_unlocked(txn)

    def _commit_version_unlocked(self, txn, version, origin):
        self._versions.append(version)
        self._prune_versions_unlocked()
        self.nodes = version.nodes
        if self.origin is None:
            self.origin = origin
        # txn can be None in __init__ when we make the empty version.
        if txn is not None:
            self._end_write_unlocked(txn)

    def _commit_version(self, txn, version, origin):
        with self._version_lock:
            self._commit_version_unlocked(txn, version, origin)

    def find_node(self, name, create=False):
        if create:
            raise UseTransaction
        return super().find_node(name)

    def delete_node(self, name):
        raise UseTransaction

    def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
                      create=False):
        if create:
            raise UseTransaction
        rdataset = super().find_rdataset(name, rdtype, covers)
        return dns.rdataset.ImmutableRdataset(rdataset)

    def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
                     create=False):
        """Return an immutable copy of the matching rdataset, or ``None``."""
        if create:
            raise UseTransaction
        rdataset = super().get_rdataset(name, rdtype, covers)
        # The base class returns None when nothing matches; wrapping None
        # in ImmutableRdataset would fail.
        if rdataset is None:
            return None
        return dns.rdataset.ImmutableRdataset(rdataset)

    def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE):
        raise UseTransaction

    def replace_rdataset(self, name, replacement):
        raise UseTransaction


class Transaction(dns.transaction.Transaction):
    """A transaction on a versioned zone.

    If *version* is given, this is a read-only transaction on that version.
    Otherwise it is a write transaction, whose ``WritableVersion`` is
    created later by ``_setup_version()``.
    """

    def __init__(self, zone, replacement, version=None):
        read_only = version is not None
        super().__init__(zone, replacement, read_only)
        self.version = version

    @property
    def zone(self):
        return self.manager

    def _setup_version(self):
        assert self.version is None
        self.version = WritableVersion(self.zone, self.replacement)

    def _get_rdataset(self, name, rdtype, covers):
        return self.version.get_rdataset(name, rdtype, covers)

    def _put_rdataset(self, name, rdataset):
        assert not self.read_only
        self.version.put_rdataset(name, rdataset)

    def _delete_name(self, name):
        assert not self.read_only
        self.version.delete_node(name)

    def _delete_rdataset(self, name, rdtype, covers):
        assert not self.read_only
        self.version.delete_rdataset(name, rdtype, covers)

    def _name_exists(self, name):
        return self.version.get_node(name) is not None

    def _changed(self):
        if self.read_only:
            return False
        else:
            return len(self.version.changed) > 0

    def _end_transaction(self, commit):
        if self.read_only:
            self.zone._end_read(self)
        elif commit and len(self.version.changed) > 0:
            self.zone._commit_version(self, ImmutableVersion(self.version),
                                      self.version.origin)
        else:
            # rollback
            self.zone._end_write(self)

    def _set_origin(self, origin):
        if self.version.origin is None:
            self.version.origin = origin

    def _iterate_rdatasets(self):
        for (name, node) in self.version.items():
            for rdataset in node:
                yield (name, rdataset)