1# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
2
3"""DNS Versioned Zones."""
4
import collections
try:
    import threading as _threading
except ImportError:  # pragma: no cover
    import dummy_threading as _threading    # type: ignore

import dns.exception
import dns.immutable
import dns.name
import dns.node
import dns.rdataclass
import dns.rdatatype
import dns.rdata
import dns.rdataset
import dns.rdtypes.ANY.SOA
import dns.transaction
import dns.zone
21
22
class UseTransaction(dns.exception.DNSException):
    """To alter a versioned zone, use a transaction."""
    # Raised by the versioned Zone's methods that would mutate it directly
    # (delete_node, delete_rdataset, replace_rdataset, and the lookups when
    # called with create=True).
    # NOTE(review): DNSException appears to use the class docstring as the
    # default exception message -- treat the docstring text as runtime
    # behavior and confirm before editing it.
25
26
class Version:
    """A version of a versioned zone's contents.

    *zone* is the owning zone and *id* this version's id.  *origin* is an
    optional origin specific to this version; ``None`` means "use the
    zone's origin".  ``nodes`` maps node names to nodes, with names stored
    relative to the origin when the zone is relativized.
    """

    def __init__(self, zone, id, origin=None):
        self.zone = zone
        self.id = id
        self.nodes = {}
        # A version may know its origin before the zone does: a write
        # transaction can set it, while the zone's origin is only filled in
        # when the version is committed.
        self.origin = origin

    def _validate_name(self, name):
        """Return *name* in the form used as a key of ``nodes``.

        Relative names are returned unchanged.  Absolute names must be a
        subdomain of the origin and are relativized to it when the zone is
        relativized.

        Raises ``KeyError`` if *name* is absolute and no origin is known,
        or if it is not a subdomain of the origin.
        """
        if not name.is_absolute():
            return name
        # Use one origin for both the check and the relativization; fall
        # back to the zone's origin when this version has none of its own.
        origin = self.origin
        if origin is None:
            origin = self.zone.origin
        if origin is None:
            raise KeyError("no zone origin is defined")
        if not name.is_subdomain(origin):
            raise KeyError("name is not a subdomain of the zone origin")
        if self.zone.relativize:
            name = name.relativize(origin)
        return name

    def get_node(self, name):
        """Return the node for *name*, or ``None`` if there is none."""
        name = self._validate_name(name)
        return self.nodes.get(name)

    def get_rdataset(self, name, rdtype, covers):
        """Return the rdataset of the zone's class with the given *rdtype*
        and *covers* at *name*, or ``None`` if *name* has no node.
        """
        node = self.get_node(name)
        if node is None:
            return None
        return node.get_rdataset(self.zone.rdclass, rdtype, covers)

    def items(self):
        """Return the ``(name, node)`` pairs of this version."""
        return self.nodes.items()  # pylint: disable=dict-items-not-iterating
53
54
class WritableVersion(Version):
    """A version being built by a write transaction.

    Unless *replacement* is true, the version starts out sharing all of the
    zone's current nodes; a node is copied only when this version first
    modifies it (see ``_maybe_cow``).  ``changed`` records every name whose
    node was created, copied, or deleted in this version.
    """

    def __init__(self, zone, replacement=False):
        # The new id is derived from the newest committed version, so no
        # commit may happen concurrently.  Zone.writer() constructs this
        # (via Transaction._setup_version()) after releasing
        # zone._version_lock; that is safe only because there is at most
        # one write transaction, and only it can commit.
        if len(zone._versions) > 0:
            new_id = zone._versions[-1].id + 1
        else:
            new_id = 1
        super().__init__(zone, new_id)
        if not replacement:
            # We copy the map, because that gives us a simple and thread-safe
            # way of doing versions, and we have a garbage collector to help
            # us.  We only make new node objects if we actually change the
            # node.
            self.nodes.update(zone.nodes)
        # We have to copy the zone origin as it may be None in the first
        # version, and we don't want to mutate the zone until we commit.
        self.origin = zone.origin
        self.changed = set()

    def _maybe_cow(self, name):
        """Return a node for *name* that belongs to this version.

        The first time *name* is touched in this version, a new node stamped
        with this version's id is created (holding the previous node's
        rdatasets, if there was one) and *name* is recorded in ``changed``.
        """
        name = self._validate_name(name)
        node = self.nodes.get(name)
        if node is None or node.id != self.id:
            new_node = self.zone.node_factory()
            new_node.id = self.id
            if node is not None:
                # moo!  copy on write!
                new_node.rdatasets.extend(node.rdatasets)
            self.nodes[name] = new_node
            self.changed.add(name)
            return new_node
        else:
            return node

    def delete_node(self, name):
        """Remove the node at *name*, if any, from this version."""
        name = self._validate_name(name)
        if name in self.nodes:
            del self.nodes[name]
            self.changed.add(name)

    def put_rdataset(self, name, rdataset):
        """Store *rdataset* at *name*, replacing any matching rdataset."""
        node = self._maybe_cow(name)
        node.replace_rdataset(rdataset)

    def delete_rdataset(self, name, rdtype, covers):
        """Delete the matching rdataset at *name*, and the node if it ends
        up empty.
        """
        # Validate up front so the key used to remove an emptied node is
        # the same one _maybe_cow() stored it under (an absolute name in a
        # relativized zone would otherwise not match).
        name = self._validate_name(name)
        node = self._maybe_cow(name)
        node.delete_rdataset(self.zone.rdclass, rdtype, covers)
        if len(node) == 0:
            del self.nodes[name]
104
105
@dns.immutable.immutable
class ImmutableVersion(Version):
    """A read-only version made from a `WritableVersion` being committed.

    *version* is the writable version.  Its changed nodes are replaced by
    `ImmutableNode` copies (in *version*'s own node map), and the map is
    then wrapped in an immutable dict.
    """

    def __init__(self, version):
        # Start from an empty version with the committed version's id; the
        # node map is replaced below with an immutable dict.
        super().__init__(version.zone, version.id)
        # Keep the origin the writable version used: the zone's own origin
        # may still be None when its first version is committed.
        self.origin = version.origin
        # Make changed nodes immutable.  Unchanged nodes are shared with the
        # previous version.
        for name in version.changed:
            node = version.nodes.get(name)
            # It might not exist if we deleted it in the version.  Test for
            # None explicitly: nodes define __len__, so an empty node is
            # falsy but must still be frozen.
            if node is not None:
                version.nodes[name] = ImmutableNode(node)
        self.nodes = dns.immutable.Dict(version.nodes, True)
122
123
124# A node with a version id.
125
class Node(dns.node.Node):
    """A ``dns.node.Node`` tagged with the id of the version that owns it."""

    __slots__ = ['id']

    def __init__(self):
        super(Node, self).__init__()
        # Zero is only a placeholder: the version that creates or copies
        # this node overwrites it with its own id.
        self.id = 0
133
134
@dns.immutable.immutable
class ImmutableNode(Node):
    """A read-only copy of a `Node`, used by committed versions.

    The rdatasets are wrapped in ``dns.rdataset.ImmutableRdataset`` and kept
    in a tuple.  Every mutating method, and any lookup with ``create=True``,
    raises ``TypeError``.
    """

    # ``id`` is already a slot of Node.  Redeclaring a slot that a base
    # class defines makes the base slot inaccessible and wastes space, so
    # no new slots are declared here.
    __slots__ = []

    def __init__(self, node):
        super().__init__()
        self.id = node.id
        self.rdatasets = tuple(
            dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets
        )

    def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
                      create=False):
        if create:
            raise TypeError("immutable")
        return super().find_rdataset(rdclass, rdtype, covers, False)

    def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
                     create=False):
        if create:
            raise TypeError("immutable")
        return super().get_rdataset(rdclass, rdtype, covers, False)

    def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
        raise TypeError("immutable")

    def replace_rdataset(self, replacement):
        raise TypeError("immutable")
163
164
class Zone(dns.zone.Zone):
    """A zone whose contents can only be changed through transactions.

    Each committed write transaction that changes something adds a new
    version.  A read transaction sees one version for its whole lifetime.
    Old versions are discarded according to the pruning policy, but never
    while an open reader is using them or a later version.
    """

    __slots__ = ['_versions', '_version_lock', '_write_txn',
                 '_write_waiters', '_write_event', '_pruning_policy',
                 '_readers']

    node_factory = Node

    def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True,
                 pruning_policy=None):
        """Initialize a versioned zone object.

        *origin* is the origin of the zone.  It may be a ``dns.name.Name``,
        a ``str``, or ``None``.  If ``None``, then the zone's origin will
        be set by the first ``$ORIGIN`` line in a zone file.

        *rdclass*, an ``int``, the zone's rdata class; the default is class IN.

        *relativize*, a ``bool``, determines whether domain names are
        relativized to the zone's origin.  The default is ``True``.

        *pruning_policy*, a function taking the zone and a `Version` and
        returning a `bool`, or `None`.  Should the version be pruned?  If
        `None`, the default policy is used.  That policy retains only the
        most recent version, plus every version from the oldest one still
        in use by an open reader.
        """
        super().__init__(origin, rdclass, relativize)
        self._versions = collections.deque()
        self._version_lock = _threading.Lock()
        if pruning_policy is None:
            self._pruning_policy = self._default_pruning_policy
        else:
            self._pruning_policy = pruning_policy
        self._write_txn = None
        self._write_event = None
        self._write_waiters = collections.deque()
        self._readers = set()
        # Commit the initial, empty version; there is no transaction to end.
        self._commit_version_unlocked(None, WritableVersion(self), origin)

    def reader(self, id=None, serial=None):  # pylint: disable=arguments-differ
        """Begin a read-only transaction.

        By default the most recent version is read.  *id* selects a
        retained version by id, and *serial* selects the most recent
        retained version whose SOA serial matches.  At most one of them may
        be given.

        Raises ``ValueError`` if both are given, and ``KeyError`` if no
        retained version matches.
        """
        if id is not None and serial is not None:
            raise ValueError('cannot specify both id and serial')
        with self._version_lock:
            if id is not None:
                version = None
                for v in reversed(self._versions):
                    if v.id == id:
                        version = v
                        break
                if version is None:
                    raise KeyError('version not found')
            elif serial is not None:
                if self.relativize:
                    oname = dns.name.empty
                else:
                    oname = self.origin
                version = None
                for v in reversed(self._versions):
                    n = v.nodes.get(oname)
                    if n:
                        rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA)
                        if rds and rds[0].serial == serial:
                            version = v
                            break
                if version is None:
                    raise KeyError('serial not found')
            else:
                version = self._versions[-1]
            txn = Transaction(self, False, version)
            self._readers.add(txn)
            return txn

    def writer(self, replacement=False):
        """Begin a write transaction.

        Blocks until no other write transaction is active.  If
        *replacement* is true, the new version starts empty instead of
        sharing the current version's nodes.
        """
        event = None
        while True:
            with self._version_lock:
                # Checking event is self._write_event ensures that either
                # no one was waiting before we got lucky and found no write
                # txn, or we were the one who was waiting and got woken up.
                # This prevents "taking cuts" when creating a write txn.
                if self._write_txn is None and event is self._write_event:
                    # Creating the transaction defers version setup
                    # (i.e.  copying the nodes dictionary) until we
                    # give up the lock, so that we hold the lock as
                    # short a time as possible.  This is why we call
                    # _setup_version() below.
                    self._write_txn = Transaction(self, replacement)
                    # give up our exclusive right to make a Transaction
                    self._write_event = None
                    break
                # Someone else is writing already, so we will have to
                # wait, but we want to do the actual wait outside the
                # lock.
                event = _threading.Event()
                self._write_waiters.append(event)
            # wait (note we gave up the lock!)
            #
            # We only wake one sleeper at a time, so it's important
            # that no event waiter can exit this method (e.g. via
            # cancelation) without returning a transaction or waking
            # someone else up.
            #
            # This is not a problem with Threading module threads as
            # they cannot be canceled, but could be an issue with trio
            # or curio tasks when we do the async version of writer().
            # I.e. we'd need to do something like:
            #
            # try:
            #     event.wait()
            # except trio.Cancelled:
            #     with self._version_lock:
            #         self._maybe_wakeup_one_waiter_unlocked()
            #     raise
            #
            event.wait()
        # Do the deferred version setup.
        self._write_txn._setup_version()
        return self._write_txn

    def _maybe_wakeup_one_waiter_unlocked(self):
        # Hand the exclusive right to create the next write transaction to
        # the longest-waiting writer, if any.  Caller holds the lock.
        if len(self._write_waiters) > 0:
            self._write_event = self._write_waiters.popleft()
            self._write_event.set()

    # pylint: disable=unused-argument
    def _default_pruning_policy(self, zone, version):
        # Prune every version _prune_versions_unlocked() is allowed to.
        return True
    # pylint: enable=unused-argument

    def _prune_versions_unlocked(self):
        assert len(self._versions) > 0
        # Don't ever prune a version greater than or equal to one that
        # a reader has open.  This pins versions in memory while the
        # reader is open, and importantly lets the reader open a txn on
        # a successor version (e.g. if generating an IXFR).
        #
        # Note our definition of least_kept also ensures we do not try to
        # delete the greatest version.
        if len(self._readers) > 0:
            least_kept = min(txn.version.id for txn in self._readers)
        else:
            least_kept = self._versions[-1].id
        while self._versions[0].id < least_kept and \
              self._pruning_policy(self, self._versions[0]):
            self._versions.popleft()

    def set_max_versions(self, max_versions):
        """Set a pruning policy that retains up to the specified number
        of versions.

        ``None`` means keep every version.  Raises ``ValueError`` if
        *max_versions* is less than 1.
        """
        if max_versions is not None and max_versions < 1:
            raise ValueError('max versions must be at least 1')
        if max_versions is None:
            def policy(*_):
                return False
        else:
            def policy(zone, _):
                return len(zone._versions) > max_versions
        self.set_pruning_policy(policy)

    def set_pruning_policy(self, policy):
        """Set the pruning policy for the zone.

        The *policy* function takes the zone and a `Version`, and returns
        `True` if the version should be pruned and `False` otherwise.
        `None` may also be specified for policy, in which case the default
        policy is used.

        Pruning checking proceeds from the least version and the first
        time the function returns `False`, the checking stops.  I.e. the
        retained versions are always a consecutive sequence.
        """
        if policy is None:
            policy = self._default_pruning_policy
        with self._version_lock:
            self._pruning_policy = policy
            self._prune_versions_unlocked()

    def _end_read(self, txn):
        with self._version_lock:
            self._readers.remove(txn)
            self._prune_versions_unlocked()

    def _end_write_unlocked(self, txn):
        assert self._write_txn == txn
        self._write_txn = None
        self._maybe_wakeup_one_waiter_unlocked()

    def _end_write(self, txn):
        with self._version_lock:
            self._end_write_unlocked(txn)

    def _commit_version_unlocked(self, txn, version, origin):
        self._versions.append(version)
        self._prune_versions_unlocked()
        self.nodes = version.nodes
        if self.origin is None:
            self.origin = origin
        # txn can be None in __init__ when we make the empty version.
        if txn is not None:
            self._end_write_unlocked(txn)

    def _commit_version(self, txn, version, origin):
        with self._version_lock:
            self._commit_version_unlocked(txn, version, origin)

    def find_node(self, name, create=False):
        if create:
            raise UseTransaction
        return super().find_node(name)

    def delete_node(self, name):
        raise UseTransaction

    def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
                      create=False):
        if create:
            raise UseTransaction
        rdataset = super().find_rdataset(name, rdtype, covers)
        return dns.rdataset.ImmutableRdataset(rdataset)

    def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
                     create=False):
        if create:
            raise UseTransaction
        rdataset = super().get_rdataset(name, rdtype, covers)
        if rdataset is None:
            # get_rdataset() is the non-raising lookup; a missing rdataset
            # comes back as None and there is nothing to wrap.
            return None
        return dns.rdataset.ImmutableRdataset(rdataset)

    def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE):
        raise UseTransaction

    def replace_rdataset(self, name, replacement):
        raise UseTransaction
397
398
class Transaction(dns.transaction.Transaction):
    """A transaction on a versioned `Zone`.

    Passing an existing *version* makes this a read-only view of that
    version.  Without one, this is the zone's write transaction, and its
    `WritableVersion` is created later by ``_setup_version()``.
    """

    def __init__(self, zone, replacement, version=None):
        # Being handed an existing version is what makes this a reader.
        super().__init__(zone, replacement, version is not None)
        self.version = version

    @property
    def zone(self):
        """The versioned zone this transaction operates on."""
        return self.manager

    def _setup_version(self):
        # Deferred creation of the write transaction's version; it must not
        # exist yet.
        assert self.version is None
        self.version = WritableVersion(self.zone, self.replacement)

    def _get_rdataset(self, name, rdtype, covers):
        version = self.version
        return version.get_rdataset(name, rdtype, covers)

    def _put_rdataset(self, name, rdataset):
        assert not self.read_only
        version = self.version
        version.put_rdataset(name, rdataset)

    def _delete_name(self, name):
        assert not self.read_only
        version = self.version
        version.delete_node(name)

    def _delete_rdataset(self, name, rdtype, covers):
        assert not self.read_only
        version = self.version
        version.delete_rdataset(name, rdtype, covers)

    def _name_exists(self, name):
        node = self.version.get_node(name)
        return node is not None

    def _changed(self):
        # Readers never change anything, and their version may not even
        # have a ``changed`` set.
        return False if self.read_only else bool(self.version.changed)

    def _end_transaction(self, commit):
        zone = self.zone
        if self.read_only:
            zone._end_read(self)
            return
        version = self.version
        if commit and version.changed:
            zone._commit_version(self, ImmutableVersion(version),
                                 version.origin)
        else:
            # Rollback, or a commit with nothing to commit: just release
            # the write transaction.
            zone._end_write(self)

    def _set_origin(self, origin):
        version = self.version
        if version.origin is None:
            version.origin = origin

    def _iterate_rdatasets(self):
        for name, node in self.version.items():
            yield from ((name, rdataset) for rdataset in node)
456