1# Copyright 2014-2016 OpenMarket Ltd
2# Copyright 2018 New Vector Ltd
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import abc
17import collections
18import inspect
19import itertools
20import logging
21from contextlib import contextmanager
22from typing import (
23    Any,
24    Awaitable,
25    Callable,
26    Collection,
27    Dict,
28    Generic,
29    Hashable,
30    Iterable,
31    Iterator,
32    Optional,
33    Set,
34    Tuple,
35    TypeVar,
36    Union,
37    cast,
38    overload,
39)
40
41import attr
42from typing_extensions import ContextManager
43
44from twisted.internet import defer
45from twisted.internet.defer import CancelledError
46from twisted.internet.interfaces import IReactorTime
47from twisted.python.failure import Failure
48
49from synapse.logging.context import (
50    PreserveLoggingContext,
51    make_deferred_yieldable,
52    run_in_background,
53)
54from synapse.util import Clock, unwrapFirstError
55
logger = logging.getLogger(__name__)

# Type variable for the result type of a Deferred (used e.g. by
# `ObservableDeferred` and `timeout_deferred`).
_T = TypeVar("_T")
59
60
class AbstractObservableDeferred(Generic[_T], metaclass=abc.ABCMeta):
    """The interface through which consumers observe an `ObservableDeferred`."""

    __slots__ = ()

    @abc.abstractmethod
    def observe(self) -> "defer.Deferred[_T]":
        """Return a new Deferred which resolves when the underlying one does.

        Every call produces a brand new Deferred. Interacting with it (for
        example, cancelling it) does not affect the underlying Deferred.

        The returned Deferred does not follow the Synapse logcontext rules, so
        it will usually need to be passed through `make_deferred_yieldable`.
        """
78
79
class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
    """Wraps a deferred object so that we can add observer deferreds. These
    observer deferreds do not affect the callback chain of the original
    deferred.

    Successful results of the original deferred are passed on down its
    callback chain unchanged. Failures are passed on too, unless consumeErrors
    is true, in which case they are captured (swallowed) once every observer
    has been errbacked.

    Cancelling or otherwise resolving an observer will not affect the original
    ObservableDeferred.

    Attribute lookups which this object cannot satisfy itself, and *all*
    attribute assignments, are forwarded to the wrapped deferred.

    NB that it does not attempt to do anything with logcontexts; in general
    you should probably make_deferred_yieldable the deferreds
    returned by `observe`, and ensure that the original deferred runs its
    callbacks in the sentinel logcontext.
    """

    __slots__ = ["_deferred", "_observers", "_result"]

    def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
        # `__setattr__` forwards to the wrapped deferred, so our own slots have
        # to be written with `object.__setattr__`.
        object.__setattr__(self, "_deferred", deferred)
        # None until the wrapped deferred fires; then (True, result) on success
        # or (False, failure) on failure.
        object.__setattr__(self, "_result", None)
        # Observer deferreds still waiting for a result (see `observe`).
        object.__setattr__(self, "_observers", [])

        def callback(r: _T) -> _T:
            object.__setattr__(self, "_result", (True, r))

            # once we have set _result, no more entries will be added to _observers,
            # so it's safe to replace it with the empty tuple.
            observers = self._observers
            object.__setattr__(self, "_observers", ())

            for observer in observers:
                # An exception from one observer is logged, and must not stop
                # the remaining observers from being resolved.
                try:
                    observer.callback(r)
                except Exception as e:
                    logger.exception(
                        "%r threw an exception on .callback(%r), ignoring...",
                        observer,
                        r,
                        exc_info=e,
                    )
            # Pass the result on down the original deferred's chain unchanged.
            return r

        def errback(f: Failure) -> Optional[Failure]:
            object.__setattr__(self, "_result", (False, f))

            # once we have set _result, no more entries will be added to _observers,
            # so it's safe to replace it with the empty tuple.
            observers = self._observers
            object.__setattr__(self, "_observers", ())

            for observer in observers:
                # This is a little bit of magic to correctly propagate stack
                # traces when we `await` on one of the observer deferreds.
                f.value.__failure__ = f  # type: ignore[union-attr]
                try:
                    observer.errback(f)
                except Exception as e:
                    logger.exception(
                        "%r threw an exception on .errback(%r), ignoring...",
                        observer,
                        f,
                        exc_info=e,
                    )

            if consumeErrors:
                # Swallow the failure so that it goes no further down the
                # original deferred's callback chain.
                return None
            else:
                return f

        deferred.addCallbacks(callback, errback)

    def observe(self) -> "defer.Deferred[_T]":
        """Observe the underlying deferred.

        This returns a brand new deferred that is resolved when the underlying
        deferred is resolved. Interacting with the returned deferred does not
        affect the underlying deferred.

        If the underlying deferred has already fired, the returned deferred is
        already resolved with the same result (or failure).
        """
        if not self._result:
            d: "defer.Deferred[_T]" = defer.Deferred()
            self._observers.append(d)
            return d
        else:
            success, res = self._result
            return defer.succeed(res) if success else defer.fail(res)

    def observers(self) -> "Collection[defer.Deferred[_T]]":
        """Returns the observer deferreds which are still waiting for a result.

        Once the underlying deferred has fired, this is an empty tuple.
        """
        return self._observers

    def has_called(self) -> bool:
        """Returns whether the underlying deferred has fired (either way)."""
        return self._result is not None

    def has_succeeded(self) -> bool:
        """Returns whether the underlying deferred has fired with a success."""
        return self._result is not None and self._result[0] is True

    def get_result(self) -> Union[_T, Failure]:
        """Returns the result (or Failure) of the underlying deferred.

        Must only be called once `has_called()` is True: before that,
        `_result` is None and indexing it raises a TypeError.
        """
        return self._result[1]

    def __getattr__(self, name: str) -> Any:
        # Only invoked when normal lookup fails: forward to the wrapped deferred.
        return getattr(self._deferred, name)

    def __setattr__(self, name: str, value: Any) -> None:
        # All attribute assignments go to the wrapped deferred.
        setattr(self._deferred, name, value)

    def __repr__(self) -> str:
        return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
            id(self),
            self._result,
            self._deferred,
        )
191
192
# Type variable for the argument type of the function passed to
# `concurrently_execute` (and the items of its `args`).
T = TypeVar("T")
194
195
def concurrently_execute(
    func: Callable[[T], Any], args: Iterable[T], limit: int
) -> defer.Deferred:
    """Executes the function with each argument concurrently while limiting
    the number of concurrent executions.

    At most `limit` workers are started. Each worker repeatedly pulls the next
    argument from a shared iterator over `args` until it is exhausted.

    Args:
        func: Function to execute, should return a deferred or coroutine.
            (Plain return values are also accepted; see `maybe_awaitable`.)
        args: List of arguments to pass to func, each invocation of func
            gets a single argument.
        limit: Maximum number of concurrent executions. Note that a limit of 0
            starts no workers at all, so `func` is never called.

    Returns:
        Deferred: Resolved when all function invocations have finished. Fails
        if any invocation fails (the failure is passed through
        `unwrapFirstError`).
    """
    it = iter(args)

    async def _concurrently_execute_inner(value: T) -> None:
        while True:
            await maybe_awaitable(func(value))

            # Only the `next` call is guarded: a StopIteration escaping from
            # `func` itself must be surfaced as an error, not mistaken for
            # exhaustion of `args` (which would silently drop remaining work).
            try:
                value = next(it)
            except StopIteration:
                return

    # We use `itertools.islice` to handle the case where the number of args is
    # less than the limit, avoiding needlessly spawning unnecessary background
    # tasks.
    return make_deferred_yieldable(
        defer.gatherResults(
            [
                run_in_background(_concurrently_execute_inner, value)
                for value in itertools.islice(it, limit)
            ],
            consumeErrors=True,
        )
    ).addErrback(unwrapFirstError)
233
234
def yieldable_gather_results(
    func: Callable, iter: Iterable, *args: Any, **kwargs: Any
) -> defer.Deferred:
    """Run `func` once for each item of `iter`, all concurrently.

    Each invocation is started with `run_in_background` as
    `func(item, *args, **kwargs)`.

    Args:
        func: Function to execute that returns a Deferred
        iter: An iterable that yields items that get passed as the first
            argument to the function
        *args: Arguments to be passed to each call to func
        **kwargs: Keyword arguments to be passed to each call to func

    Returns:
        Deferred[list]: Resolves with the list of results (in the order of
        `iter`) once every invocation has completed, or errbacks if one of the
        calls fails (the failure is passed through `unwrapFirstError`).
    """

    def _start(item: Any) -> defer.Deferred:
        return run_in_background(func, item, *args, **kwargs)

    started = [_start(item) for item in iter]
    combined = defer.gatherResults(started, consumeErrors=True)
    return make_deferred_yieldable(combined).addErrback(unwrapFirstError)
257
258
# Type variables for the individual result types of the `Deferred`s passed to
# `gather_results`, so that heterogeneous tuples keep their element types.
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
262
263
@overload
def gather_results(
    deferredList: Tuple[()], consumeErrors: bool = ...
) -> "defer.Deferred[Tuple[()]]":
    ...


@overload
def gather_results(
    deferredList: Tuple["defer.Deferred[T1]"],
    consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1]]":
    ...


@overload
def gather_results(
    deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"],
    consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1, T2]]":
    ...


@overload
def gather_results(
    deferredList: Tuple[
        "defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]"
    ],
    consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1, T2, T3]]":
    ...


def gather_results(  # type: ignore[misc]
    deferredList: Tuple["defer.Deferred[T1]", ...],
    consumeErrors: bool = False,
) -> "defer.Deferred[Tuple[T1, ...]]":
    """Combine a tuple of `Deferred`s into one `Deferred` of a tuple.

    A thin wrapper over `defer.gatherResults`. The overloads above give precise
    types for heterogeneous tuples of up to three `Deferred`s.
    """
    # The `type: ignore[misc]` on the signature silences mypy's
    # "Overloaded function implementation cannot produce return type of
    # signature 1/2/3".
    combined = defer.gatherResults(deferredList, consumeErrors=consumeErrors)
    # `gatherResults` produces a list; turn it into the declared tuple.
    return combined.addCallback(tuple)
310
311
@attr.s(slots=True, auto_attribs=True)
class _LinearizerEntry:
    """Bookkeeping kept by `Linearizer` for a single key."""

    # How many holders currently have the lock for this key.
    count: int
    # Deferreds of the callers blocked waiting for this key, keyed by the
    # deferred itself (the values are unused placeholders).
    deferreds: collections.OrderedDict
318
319
class Linearizer:
    """Limits concurrent access to resources based on a key. Useful to ensure
    only a few things happen at a time on a given resource.

    Example:

        with await limiter.queue("test_key"):
            # do some work.

    """

    def __init__(
        self,
        name: Optional[str] = None,
        max_count: int = 1,
        clock: Optional[Clock] = None,
    ):
        """
        Args:
            name: A name for this linearizer, used in log messages. Defaults to
                the `id()` of this object.
            max_count: The maximum number of concurrent accesses per key.
            clock: The clock to use. Defaults to a `Clock` wrapping the global
                reactor.
        """
        if name is None:
            self.name: Union[str, int] = id(self)
        else:
            self.name = name

        if not clock:
            from twisted.internet import reactor

            clock = Clock(cast(IReactorTime, reactor))
        self._clock = clock
        self.max_count = max_count

        # key_to_defer is a map from the key to a _LinearizerEntry.
        self.key_to_defer: Dict[Hashable, _LinearizerEntry] = {}

    def is_queued(self, key: Hashable) -> bool:
        """Checks whether there is a process queued up waiting"""
        entry = self.key_to_defer.get(key)
        if not entry:
            # No entry so nothing is waiting.
            return False

        # There are waiting deferreds only if the OrderedDict of deferreds is
        # non-empty.
        return bool(entry.deferreds)

    def queue(self, key: Hashable) -> defer.Deferred:
        """Acquires the lock for `key`, waiting if necessary.

        Returns:
            A Deferred which resolves to a context manager once the lock has
            been acquired; exiting the context manager releases the lock. If the
            Deferred fails instead (e.g. because it was cancelled), the lock is
            not held.
        """
        # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
        # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
        # propagated inside inlineCallbacks until Twisted 18.7)
        entry = self.key_to_defer.setdefault(
            key, _LinearizerEntry(0, collections.OrderedDict())
        )

        # If the number of things executing is greater than the maximum
        # then add a deferred to the list of blocked items
        # When one of the things currently executing finishes it will callback
        # this item so that it can continue executing.
        if entry.count >= self.max_count:
            res = self._await_lock(key)
        else:
            logger.debug(
                "Acquired uncontended linearizer lock %r for key %r", self.name, key
            )
            entry.count += 1
            res = defer.succeed(None)

        # once we successfully get the lock, we need to return a context manager which
        # will release the lock.

        @contextmanager
        def _ctx_manager(_: None) -> Iterator[None]:
            try:
                yield
            finally:
                self._release_lock(key, entry)

        res.addCallback(_ctx_manager)
        return res

    def _release_lock(self, key: Hashable, entry: _LinearizerEntry) -> None:
        """Gives back one slot of the lock for `key`, waking the next waiter.

        Args:
            key: The key the lock was held for.
            entry: The `_LinearizerEntry` for `key`, as used when acquiring.
        """
        logger.debug("Releasing linearizer lock %r for key %r", self.name, key)

        # We've finished executing so check if there are any things
        # blocked waiting to execute and start one of them
        entry.count -= 1

        if entry.deferreds:
            (next_def, _) = entry.deferreds.popitem(last=False)

            # we need to run the next thing in the sentinel context.
            with PreserveLoggingContext():
                next_def.callback(None)
        elif entry.count == 0:
            # We were the last thing for this key: remove it from the
            # map.
            del self.key_to_defer[key]

    def _await_lock(self, key: Hashable) -> defer.Deferred:
        """Helper for queue: adds a deferred to the queue

        Assumes that we've already checked that we've reached the limit of the number
        of lock-holders we allow. Creates a new deferred which is added to the list, and
        adds some management around cancellations.

        Returns the deferred, which will callback once we have secured the lock. If it
        fails instead (e.g. because it was cancelled), the lock is not held: if we were
        still queued we are removed from the queue, and if we had already been granted
        the lock it is released again.
        """
        entry = self.key_to_defer[key]

        logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)

        new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
        entry.deferreds[new_defer] = 1

        # Set once we have been handed the lock (and counted in `entry.count`), so
        # that a later failure knows it has to give the lock back.
        acquired = False

        def cb(_r: None) -> "defer.Deferred[None]":
            nonlocal acquired
            logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
            entry.count += 1
            acquired = True

            # if the code holding the lock completes synchronously, then it
            # will recursively run the next claimant on the list. That can
            # relatively rapidly lead to stack exhaustion. This is essentially
            # the same problem as http://twistedmatrix.com/trac/ticket/9304.
            #
            # In order to break the cycle, we add a cheeky sleep(0) here to
            # ensure that we fall back to the reactor between each iteration.
            #
            # (This needs to happen while we hold the lock, and the context manager's exit
            # code must be synchronous, so this is the only sensible place.)
            return self._clock.sleep(0)

        def eb(e: Failure) -> Failure:
            logger.info("defer %r got err %r", new_defer, e)
            # `e` is a Failure wrapping the exception, so `isinstance` would never
            # match; use `Failure.check` instead.
            if e.check(CancelledError):
                logger.debug(
                    "Cancelling wait for linearizer lock %r for key %r", self.name, key
                )

            else:
                logger.warning(
                    "Unexpected exception waiting for linearizer lock %r for key %r",
                    self.name,
                    key,
                )

            # we just have to take ourselves back out of the queue.
            del entry.deferreds[new_defer]
            return e

        def release_if_acquired(e: Failure) -> Failure:
            # A failure of the deferred returned by `cb` (e.g. we were cancelled
            # during the sleep(0)) does not go through `eb`, which is paired with
            # `cb`; it lands here instead. At that point we hold the lock, but the
            # caller will never get the context manager that would release it, so
            # release it now rather than leaking it forever.
            if acquired:
                self._release_lock(key, entry)
            return e

        new_defer.addCallbacks(cb, eb)
        new_defer.addErrback(release_if_acquired)
        return new_defer
469
470
class ReadWriteLock:
    """An async read write lock.

    Example:

        with await read_write_lock.read("test_key"):
            # do some work
    """

    # IMPLEMENTATION NOTES
    #
    # We track the most recent queued reader and writer deferreds (which get
    # resolved when they release the lock).
    #
    # Read: We know it's safe to acquire a read lock when the latest writer has
    # been resolved. The new reader is appended to the list of latest readers.
    #
    # Write: We know it's safe to acquire the write lock when both the latest
    # writers and readers have been resolved. The new writer replaces the latest
    # writer.
    #
    # Cancellation: the deferreds we wait on belong to other lock holders, so we
    # must never let a cancellation of our wait propagate into them (see
    # `_shield`). And if our wait fails (e.g. is cancelled), we must still
    # resolve our own deferred, or everyone queued behind us waits forever.

    def __init__(self) -> None:
        # Latest readers queued
        self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}

        # Latest writer queued
        self.key_to_current_writer: Dict[str, defer.Deferred] = {}

    @staticmethod
    def _shield(deferred: "defer.Deferred[_T]") -> "defer.Deferred[_T]":
        """Returns a new Deferred which follows the outcome of `deferred`, but
        whose cancellation is not propagated to `deferred`.

        Cancelling the returned Deferred errbacks it with `CancelledError` and
        leaves `deferred` untouched; its eventual result is then ignored.
        """
        shielded: "defer.Deferred[_T]" = defer.Deferred()
        deferred.chainDeferred(shielded)
        return shielded

    async def read(self, key: str) -> ContextManager:
        """Acquires a read lock for `key`, returning a context manager which
        releases it on exit."""
        new_defer: "defer.Deferred[None]" = defer.Deferred()

        curr_readers = self.key_to_current_readers.setdefault(key, set())
        curr_writer = self.key_to_current_writer.get(key, None)

        curr_readers.add(new_defer)

        def _release() -> None:
            with PreserveLoggingContext():
                new_defer.callback(None)
            self.key_to_current_readers.get(key, set()).discard(new_defer)

        # We wait for the latest writer to finish writing. We can safely ignore
        # any existing readers... as they're readers.
        if curr_writer:
            try:
                await make_deferred_yieldable(self._shield(curr_writer))
            except BaseException:
                # e.g. we were cancelled. Anyone queued behind a reader also waits
                # (directly or via a later writer) on `curr_writer`, so it is safe
                # to release our slot straight away.
                _release()
                raise

        @contextmanager
        def _ctx_manager() -> Iterator[None]:
            try:
                yield
            finally:
                _release()

        return _ctx_manager()

    async def write(self, key: str) -> ContextManager:
        """Acquires the write lock for `key`, returning a context manager which
        releases it on exit."""
        new_defer: "defer.Deferred[None]" = defer.Deferred()

        curr_readers = self.key_to_current_readers.get(key, set())
        curr_writer = self.key_to_current_writer.get(key, None)

        # We wait on all latest readers and writer.
        to_wait_on = list(curr_readers)
        if curr_writer:
            to_wait_on.append(curr_writer)

        # We can clear the list of current readers since the new writer waits
        # for them to finish.
        curr_readers.clear()
        self.key_to_current_writer[key] = new_defer

        def _release() -> None:
            with PreserveLoggingContext():
                new_defer.callback(None)
            # The entry may already be gone: a writer queued behind us can run to
            # completion inside the `callback` above and pop it itself.
            if self.key_to_current_writer.get(key) is new_defer:
                self.key_to_current_writer.pop(key)

        def _release_when_done(_: Any) -> None:
            _release()

        to_wait_on_defer = defer.gatherResults(to_wait_on)
        try:
            await make_deferred_yieldable(self._shield(to_wait_on_defer))
        except BaseException:
            # We never got the lock. Anyone queued behind us only waits on
            # `new_defer`, not on our predecessors, so `new_defer` must not fire
            # until those predecessors have finished.
            to_wait_on_defer.addBoth(_release_when_done)
            raise

        @contextmanager
        def _ctx_manager() -> Iterator[None]:
            try:
                yield
            finally:
                _release()

        return _ctx_manager()
552
553
# Type variable for the value produced by a (possibly) awaitable value; see
# `maybe_awaitable`.
R = TypeVar("R")
555
556
def timeout_deferred(
    deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime
) -> "defer.Deferred[_T]":
    """Wrap `deferred` in a new Deferred which fails if `deferred` has not
    resolved within `timeout` seconds.

    Twisted's own `Deferred.addTimeout` fails to time out deferreds whose
    canceller throws an exception. Here a throwing canceller is logged, and the
    new deferred is errbacked regardless.

    (See https://twistedmatrix.com/trac/ticket/9534)

    NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred.

    NOTE: the TimeoutError raised by the resultant deferred is
    twisted.internet.defer.TimeoutError, which is *different* to the built-in
    TimeoutError, as well as various other TimeoutErrors you might have imported.

    NOTE: callbacks are added to `deferred` itself; once its outcome has been
    handed to the new deferred, `deferred`'s own result becomes None.

    Args:
        deferred: The Deferred to potentially timeout.
        timeout: Timeout in seconds
        reactor: The twisted reactor to use

    Returns:
        A new Deferred, which will errback with defer.TimeoutError on timeout,
        and otherwise mirrors the result or failure of `deferred`.
    """
    new_d: "defer.Deferred[_T]" = defer.Deferred()

    # Becomes True when the timer fires, so that the CancelledError our own
    # cancel() provokes can be told apart from any other cancellation.
    timed_out = False

    def _make_timeout_error() -> defer.TimeoutError:
        return defer.TimeoutError("Timed out after %gs" % (timeout,))

    def time_it_out() -> None:
        nonlocal timed_out
        timed_out = True

        try:
            deferred.cancel()
        except Exception:  # if we throw any exception it'll break time outs
            logger.exception("Canceller failed during timeout")

        # Normally cancel() has already errbacked new_d through the chain below;
        # if the canceller misbehaved, errback it ourselves.
        if not new_d.called:
            new_d.errback(_make_timeout_error())

    delayed_call = reactor.callLater(timeout, time_it_out)

    def convert_cancelled(failure: Failure) -> Failure:
        # A CancelledError seen after our timer fired was caused by us, so it
        # is reported as a timeout instead.
        if timed_out and failure.check(CancelledError):
            raise _make_timeout_error()
        return failure

    def cancel_timeout(result: Union[_T, Failure]) -> Union[_T, Failure]:
        # `deferred` has resolved one way or another: disarm the timer.
        if delayed_call.active():
            delayed_call.cancel()
        return result

    def forward_success(val: _T) -> None:
        if not new_d.called:
            new_d.callback(val)

    def forward_failure(failure: Failure) -> None:
        if not new_d.called:
            new_d.errback(failure)

    deferred.addErrback(convert_cancelled)
    deferred.addBoth(cancel_timeout)
    deferred.addCallbacks(forward_success, forward_failure)

    return new_d
631
632
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DoneAwaitable:  # should be: Generic[R]
    """An awaitable which completes immediately, evaluating to `value`."""

    value: Any  # should be: R

    def __await__(self) -> Any:
        # The object acts as its own iterator; see `__next__`.
        return iter(self)

    def __iter__(self) -> "DoneAwaitable":
        return self

    def __next__(self) -> None:
        # Ending iteration with StopIteration(value) is how an iterator hands
        # its result back to the `await` expression.
        finished = StopIteration(self.value)
        raise finished
649
650
def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
    """Convert a value to an awaitable if not already an awaitable.

    Args:
        value: Either something which can be awaited (per `inspect.isawaitable`)
            or a plain value.

    Returns:
        `value` itself if it is awaitable; otherwise a `DoneAwaitable` which
        evaluates to `value` immediately when awaited.
    """
    if inspect.isawaitable(value):
        # `inspect.isawaitable` also accepts generator-based coroutines (from
        # `types.coroutine`), which are *not* instances of
        # `collections.abc.Awaitable`, so an `isinstance` assertion would wrongly
        # fail for them. Narrow the type for mypy with a cast instead.
        return cast(Awaitable[R], value)

    return DoneAwaitable(value)
658