# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import collections
import inspect
import itertools
import logging
from contextlib import contextmanager
from typing import (
    Any,
    Awaitable,
    Callable,
    Collection,
    Dict,
    Generic,
    Hashable,
    Iterable,
    Iterator,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
    cast,
    overload,
)

import attr
from typing_extensions import ContextManager

from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime
from twisted.python.failure import Failure

from synapse.logging.context import (
    PreserveLoggingContext,
    make_deferred_yieldable,
    run_in_background,
)
from synapse.util import Clock, unwrapFirstError

logger = logging.getLogger(__name__)

_T = TypeVar("_T")


class AbstractObservableDeferred(Generic[_T], metaclass=abc.ABCMeta):
    """Abstract base class defining the consumer interface of ObservableDeferred"""

    __slots__ = ()

    @abc.abstractmethod
    def observe(self) -> "defer.Deferred[_T]":
        """Add a new observer for this ObservableDeferred

        This returns a brand new deferred that is resolved when the underlying
        deferred is resolved. Interacting with the returned deferred does not
        affect the underlying deferred.

        Note that the returned Deferred doesn't follow the Synapse logcontext rules -
        you will probably want to `make_deferred_yieldable` it.
        """
        ...


class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
    """Wraps a deferred object so that we can add observer deferreds. These
    observer deferreds do not affect the callback chain of the original
    deferred.

    If consumeErrors is true errors will be captured from the origin deferred.

    Cancelling or otherwise resolving an observer will not affect the original
    ObservableDeferred.

    NB that it does not attempt to do anything with logcontexts; in general
    you should probably make_deferred_yieldable the deferreds
    returned by `observe`, and ensure that the original deferred runs its
    callbacks in the sentinel logcontext.
    """

    __slots__ = ["_deferred", "_observers", "_result"]

    def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
        """
        Args:
            deferred: The deferred to wrap.
            consumeErrors: If true, a failure of `deferred` is swallowed after
                being passed to the observers; otherwise it is passed on down
                the callback chain of `deferred`.
        """
        # `__setattr__` is overridden below to proxy to the wrapped deferred, so
        # our own slots have to be populated via `object.__setattr__`.
        object.__setattr__(self, "_deferred", deferred)
        # Either None (not yet resolved) or a `(succeeded, value_or_failure)` pair.
        object.__setattr__(self, "_result", None)
        object.__setattr__(self, "_observers", [])

        def callback(r: _T) -> _T:
            object.__setattr__(self, "_result", (True, r))

            # once we have set _result, no more entries will be added to _observers,
            # so it's safe to replace it with the empty tuple.
            observers = self._observers
            object.__setattr__(self, "_observers", ())

            for observer in observers:
                try:
                    observer.callback(r)
                except Exception as e:
                    logger.exception(
                        "%r threw an exception on .callback(%r), ignoring...",
                        observer,
                        r,
                        exc_info=e,
                    )
            return r

        def errback(f: Failure) -> Optional[Failure]:
            object.__setattr__(self, "_result", (False, f))

            # once we have set _result, no more entries will be added to _observers,
            # so it's safe to replace it with the empty tuple.
            observers = self._observers
            object.__setattr__(self, "_observers", ())

            for observer in observers:
                # This is a little bit of magic to correctly propagate stack
                # traces when we `await` on one of the observer deferreds.
                f.value.__failure__ = f  # type: ignore[union-attr]
                try:
                    observer.errback(f)
                except Exception as e:
                    logger.exception(
                        "%r threw an exception on .errback(%r), ignoring...",
                        observer,
                        f,
                        exc_info=e,
                    )

            if consumeErrors:
                return None
            else:
                return f

        deferred.addCallbacks(callback, errback)

    def observe(self) -> "defer.Deferred[_T]":
        """Observe the underlying deferred.

        This returns a brand new deferred that is resolved when the underlying
        deferred is resolved. Interacting with the returned deferred does not
        affect the underlying deferred.
        """
        if self._result is None:
            # Not resolved yet: hand out a fresh deferred that the
            # callback/errback installed in `__init__` will fire later.
            d: "defer.Deferred[_T]" = defer.Deferred()
            self._observers.append(d)
            return d

        # Already resolved: return an already-fired deferred with the result.
        success, res = self._result
        return defer.succeed(res) if success else defer.fail(res)

    def observers(self) -> "Collection[defer.Deferred[_T]]":
        """Returns the observer deferreds that have not yet been resolved."""
        return self._observers

    def has_called(self) -> bool:
        """Returns whether the underlying deferred has resolved (either way)."""
        return self._result is not None

    def has_succeeded(self) -> bool:
        """Returns whether the underlying deferred has resolved successfully."""
        return self._result is not None and self._result[0] is True

    def get_result(self) -> Union[_T, Failure]:
        """Returns the result of the underlying deferred.

        Returns:
            The success value, or the `Failure` if the deferred failed.

        Raises:
            ValueError: if the underlying deferred has not resolved yet.
        """
        if self._result is None:
            raise ValueError("%r has no result yet" % (self,))
        return self._result[1]

    def __getattr__(self, name: str) -> Any:
        # Anything that is not one of our slots is proxied to the wrapped deferred.
        return getattr(self._deferred, name)

    def __setattr__(self, name: str, value: Any) -> None:
        setattr(self._deferred, name, value)

    def __repr__(self) -> str:
        return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
            id(self),
            self._result,
            self._deferred,
        )


T = TypeVar("T")


def concurrently_execute(
    func: Callable[[T], Any], args: Iterable[T], limit: int
) -> defer.Deferred:
    """Executes the function with each argument concurrently while limiting
    the number of concurrent executions.

    Args:
        func: Function to execute, should return a deferred or coroutine.
        args: List of arguments to pass to func, each invocation of func
            gets a single argument.
        limit: Maximum number of concurrent executions.

    Returns:
        Deferred: Resolved when all function invocations have finished.
    """
    # A single iterator shared by all the workers: each worker pulls its next
    # argument from it once it has finished with the current one.
    it = iter(args)

    async def _concurrently_execute_inner(value: T) -> None:
        while True:
            await maybe_awaitable(func(value))
            # Only the `next()` call is guarded, so that a StopIteration
            # leaking out of `func` is not mistaken for "no more arguments".
            try:
                value = next(it)
            except StopIteration:
                return

    # We use `itertools.islice` to handle the case where the number of args is
    # less than the limit, avoiding needlessly spawning unnecessary background
    # tasks.
    return make_deferred_yieldable(
        defer.gatherResults(
            [
                run_in_background(_concurrently_execute_inner, value)
                for value in itertools.islice(it, limit)
            ],
            consumeErrors=True,
        )
    ).addErrback(unwrapFirstError)


def yieldable_gather_results(
    func: Callable, iter: Iterable, *args: Any, **kwargs: Any
) -> defer.Deferred:
    """Executes the function with each argument concurrently.

    Args:
        func: Function to execute that returns a Deferred
        iter: An iterable that yields items that get passed as the first
            argument to the function
        *args: Arguments to be passed to each call to func
        **kwargs: Keyword arguments to be passed to each call to func

    Returns
        Deferred[list]: Resolved when all functions have been invoked, or errors if
        one of the function calls fails.
    """
    return make_deferred_yieldable(
        defer.gatherResults(
            [run_in_background(func, item, *args, **kwargs) for item in iter],
            consumeErrors=True,
        )
    ).addErrback(unwrapFirstError)


T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")


@overload
def gather_results(
    deferredList: Tuple[()], consumeErrors: bool = ...
) -> "defer.Deferred[Tuple[()]]":
    ...


@overload
def gather_results(
    deferredList: Tuple["defer.Deferred[T1]"],
    consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1]]":
    ...


@overload
def gather_results(
    deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"],
    consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1, T2]]":
    ...


@overload
def gather_results(
    deferredList: Tuple[
        "defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]"
    ],
    consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1, T2, T3]]":
    ...


def gather_results(  # type: ignore[misc]
    deferredList: Tuple["defer.Deferred[T1]", ...],
    consumeErrors: bool = False,
) -> "defer.Deferred[Tuple[T1, ...]]":
    """Combines a tuple of `Deferred`s into a single `Deferred`.

    Wraps `defer.gatherResults` to provide type annotations that support heterogeneous
    lists of `Deferred`s.

    Args:
        deferredList: The deferreds to wait for.
        consumeErrors: Passed straight through to `defer.gatherResults`.

    Returns:
        A deferred which resolves to a tuple of the individual results.
    """
    # The `type: ignore[misc]` above suppresses
    # "Overloaded function implementation cannot produce return type of signature 1/2/3"
    deferred = defer.gatherResults(deferredList, consumeErrors=consumeErrors)
    # `gatherResults` produces a list; convert it to match the annotated tuple.
    return deferred.addCallback(tuple)


@attr.s(slots=True)
class _LinearizerEntry:
    # The number of things executing.
    count = attr.ib(type=int)
    # Deferreds for the things blocked from executing.
    deferreds = attr.ib(type=collections.OrderedDict)


class Linearizer:
    """Limits concurrent access to resources based on a key. Useful to ensure
    only a few things happen at a time on a given resource.

    Example:

        with await limiter.queue("test_key"):
            # do some work.

    """

    def __init__(
        self,
        name: Optional[str] = None,
        max_count: int = 1,
        clock: Optional[Clock] = None,
    ):
        """
        Args:
            name: Name used to identify this linearizer in log lines. Defaults
                to `id(self)`.
            max_count: The maximum number of concurrent accesses
            clock: Clock used to yield to the reactor after a contended lock
                is acquired. Defaults to a `Clock` wrapping the global twisted
                reactor.
        """
        if name is None:
            self.name: Union[str, int] = id(self)
        else:
            self.name = name

        if not clock:
            from twisted.internet import reactor

            clock = Clock(cast(IReactorTime, reactor))
        self._clock = clock
        self.max_count = max_count

        # key_to_defer is a map from the key to a _LinearizerEntry.
        self.key_to_defer: Dict[Hashable, _LinearizerEntry] = {}

    def is_queued(self, key: Hashable) -> bool:
        """Checks whether there is a process queued up waiting"""
        entry = self.key_to_defer.get(key)
        if not entry:
            # No entry so nothing is waiting.
            return False

        # There are waiting deferreds only if the OrderedDict of deferreds is
        # non-empty.
        return bool(entry.deferreds)

    def queue(self, key: Hashable) -> defer.Deferred:
        """Queues up for the lock on `key`.

        Returns:
            A deferred which resolves, once the lock has been acquired, to a
            context manager which releases the lock on exit.
        """
        # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
        # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
        # propagated inside inlineCallbacks until Twisted 18.7)
        entry = self.key_to_defer.setdefault(
            key, _LinearizerEntry(0, collections.OrderedDict())
        )

        # If the number of things executing has reached the maximum
        # then add a deferred to the list of blocked items
        # When one of the things currently executing finishes it will callback
        # this item so that it can continue executing.
        if entry.count >= self.max_count:
            res = self._await_lock(key)
        else:
            logger.debug(
                "Acquired uncontended linearizer lock %r for key %r", self.name, key
            )
            entry.count += 1
            res = defer.succeed(None)

        # once we successfully get the lock, we need to return a context manager which
        # will release the lock.

        @contextmanager
        def _ctx_manager(_: None) -> Iterator[None]:
            try:
                yield
            finally:
                logger.debug("Releasing linearizer lock %r for key %r", self.name, key)

                # We've finished executing so check if there are any things
                # blocked waiting to execute and start one of them
                entry.count -= 1

                if entry.deferreds:
                    # Oldest waiter first.
                    (next_def, _) = entry.deferreds.popitem(last=False)

                    # we need to run the next thing in the sentinel context.
                    with PreserveLoggingContext():
                        next_def.callback(None)
                elif entry.count == 0:
                    # We were the last thing for this key: remove it from the
                    # map.
                    del self.key_to_defer[key]

        res.addCallback(_ctx_manager)
        return res

    def _await_lock(self, key: Hashable) -> defer.Deferred:
        """Helper for queue: adds a deferred to the queue

        Assumes that we've already checked that we've reached the limit of the number
        of lock-holders we allow. Creates a new deferred which is added to the list, and
        adds some management around cancellations.

        Returns the deferred, which will callback once we have secured the lock.

        """
        entry = self.key_to_defer[key]

        logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)

        new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
        # The OrderedDict is used as an ordered set; the value is unused.
        entry.deferreds[new_defer] = 1

        def cb(_r: None) -> "defer.Deferred[None]":
            logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
            entry.count += 1

            # if the code holding the lock completes synchronously, then it
            # will recursively run the next claimant on the list. That can
            # relatively rapidly lead to stack exhaustion. This is essentially
            # the same problem as http://twistedmatrix.com/trac/ticket/9304.
            #
            # In order to break the cycle, we add a cheeky sleep(0) here to
            # ensure that we fall back to the reactor between each iteration.
            #
            # (This needs to happen while we hold the lock, and the context manager's exit
            # code must be synchronous, so this is the only sensible place.)
            return self._clock.sleep(0)

        def eb(e: Failure) -> Failure:
            logger.info("defer %r got err %r", new_defer, e)
            # `e` is a Failure wrapping the exception, so we have to use
            # `Failure.check` here: `isinstance(e, CancelledError)` can never
            # be true.
            if e.check(CancelledError):
                logger.debug(
                    "Cancelling wait for linearizer lock %r for key %r", self.name, key
                )

            else:
                logger.warning(
                    "Unexpected exception waiting for linearizer lock %r for key %r",
                    self.name,
                    key,
                )

            # we just have to take ourselves back out of the queue.
            del entry.deferreds[new_defer]
            return e

        new_defer.addCallbacks(cb, eb)
        return new_defer


class ReadWriteLock:
    """An async read write lock.

    Example:

        with await read_write_lock.read("test_key"):
            # do some work
    """

    # IMPLEMENTATION NOTES
    #
    # We track the most recent queued reader and writer deferreds (which get
    # resolved when they release the lock).
    #
    # Read: We know it's safe to acquire a read lock when the latest writer has
    # been resolved. The new reader is appended to the list of latest readers.
    #
    # Write: We know it's safe to acquire the write lock when both the latest
    # writers and readers have been resolved. The new writer replaces the latest
    # writer.

    def __init__(self) -> None:
        # Latest readers queued
        self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}

        # Latest writer queued
        self.key_to_current_writer: Dict[str, defer.Deferred] = {}

    async def read(self, key: str) -> ContextManager:
        """Acquires a read lock on `key`.

        Returns:
            A context manager which releases the read lock on exit.
        """
        # Resolved when this reader releases the lock.
        new_defer: "defer.Deferred[None]" = defer.Deferred()

        curr_readers = self.key_to_current_readers.setdefault(key, set())
        curr_writer = self.key_to_current_writer.get(key, None)

        curr_readers.add(new_defer)

        # We wait for the latest writer to finish writing. We can safely ignore
        # any existing readers... as they're readers.
        if curr_writer:
            await make_deferred_yieldable(curr_writer)

        @contextmanager
        def _ctx_manager() -> Iterator[None]:
            try:
                yield
            finally:
                with PreserveLoggingContext():
                    new_defer.callback(None)
                self.key_to_current_readers.get(key, set()).discard(new_defer)

        return _ctx_manager()

    async def write(self, key: str) -> ContextManager:
        """Acquires the write lock on `key`.

        Returns:
            A context manager which releases the write lock on exit.
        """
        # Resolved when this writer releases the lock.
        new_defer: "defer.Deferred[None]" = defer.Deferred()

        curr_readers = self.key_to_current_readers.get(key, set())
        curr_writer = self.key_to_current_writer.get(key, None)

        # We wait on all latest readers and writer.
        to_wait_on = list(curr_readers)
        if curr_writer:
            to_wait_on.append(curr_writer)

        # We can clear the list of current readers since the new writer waits
        # for them to finish.
        curr_readers.clear()
        self.key_to_current_writer[key] = new_defer

        await make_deferred_yieldable(defer.gatherResults(to_wait_on))

        @contextmanager
        def _ctx_manager() -> Iterator[None]:
            try:
                yield
            finally:
                with PreserveLoggingContext():
                    new_defer.callback(None)
                # The entry for `key` may already be gone: if another writer
                # was waiting for us and it ran to completion entirely within
                # the `new_defer.callback()` call above, it will have popped
                # the entry itself. Hence `.get()` rather than indexing.
                if self.key_to_current_writer.get(key) is new_defer:
                    self.key_to_current_writer.pop(key)

        return _ctx_manager()


R = TypeVar("R")


def timeout_deferred(
    deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime
) -> "defer.Deferred[_T]":
    """The in built twisted `Deferred.addTimeout` fails to time out deferreds
    that have a canceller that throws exceptions. This method creates a new
    deferred that wraps and times out the given deferred, correctly handling
    the case where the given deferred's canceller throws.

    (See https://twistedmatrix.com/trac/ticket/9534)

    NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred.

    NOTE: the TimeoutError raised by the resultant deferred is
    twisted.internet.defer.TimeoutError, which is *different* to the built-in
    TimeoutError, as well as various other TimeoutErrors you might have imported.

    Args:
        deferred: The Deferred to potentially timeout.
        timeout: Timeout in seconds
        reactor: The twisted reactor to use


    Returns:
        A new Deferred, which will errback with defer.TimeoutError on timeout.
    """
    new_d: "defer.Deferred[_T]" = defer.Deferred()

    # Set once the timeout has fired, so that a subsequent CancelledError can be
    # attributed to the timeout.
    timed_out = False

    def time_it_out() -> None:
        nonlocal timed_out
        timed_out = True

        try:
            deferred.cancel()
        except Exception:  # if we throw any exception it'll break time outs
            logger.exception("Canceller failed during timeout")

        # the cancel() call should have set off a chain of errbacks which
        # will have errbacked new_d, but in case it hasn't, errback it now.

        if not new_d.called:
            new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,)))

    delayed_call = reactor.callLater(timeout, time_it_out)

    def convert_cancelled(value: Failure) -> Failure:
        # if the original deferred was cancelled, and our timeout has fired, then
        # the reason it was cancelled was due to our timeout. Turn the CancelledError
        # into a TimeoutError.
        if timed_out and value.check(CancelledError):
            raise defer.TimeoutError("Timed out after %gs" % (timeout,))
        return value

    deferred.addErrback(convert_cancelled)

    def cancel_timeout(result: _T) -> _T:
        # the original deferred has resolved: stop the pending timeout call, if
        # it has not already fired
        if delayed_call.active():
            delayed_call.cancel()
        return result

    deferred.addBoth(cancel_timeout)

    # Relay the outcome to `new_d`, unless the timeout has already resolved it.
    def success_cb(val: _T) -> None:
        if not new_d.called:
            new_d.callback(val)

    def failure_cb(val: Failure) -> None:
        if not new_d.called:
            new_d.errback(val)

    deferred.addCallbacks(success_cb, failure_cb)

    return new_d


# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DoneAwaitable:  # should be: Generic[R]
    """Simple awaitable that returns the provided value."""

    value: Any  # should be: R

    def __await__(self) -> Any:
        return self

    def __iter__(self) -> "DoneAwaitable":
        return self

    def __next__(self) -> None:
        # Raising StopIteration immediately completes the `await`, with
        # `self.value` as its result.
        raise StopIteration(self.value)


def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
    """Convert a value to an awaitable if not already an awaitable."""
    if inspect.isawaitable(value):
        assert isinstance(value, Awaitable)
        return value

    return DoneAwaitable(value)