from collections import deque
from dataclasses import dataclass
from types import TracebackType
from typing import Deque, Optional, Tuple, Type
from warnings import warn

from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled
from ._compat import DeprecatedAwaitable
from ._eventloop import get_asynclib
from ._exceptions import BusyResourceError, WouldBlock
from ._tasks import CancelScope
from ._testing import TaskInfo, get_current_task


@dataclass(frozen=True)
class EventStatistics:
    """
    :ivar int tasks_waiting: number of tasks waiting on :meth:`~.Event.wait`
    """

    tasks_waiting: int


@dataclass(frozen=True)
class CapacityLimiterStatistics:
    """
    :ivar int borrowed_tokens: number of tokens currently borrowed by tasks
    :ivar float total_tokens: total number of available tokens
    :ivar tuple borrowers: tasks or other objects currently holding tokens borrowed from this
        limiter
    :ivar int tasks_waiting: number of tasks waiting on :meth:`~.CapacityLimiter.acquire` or
        :meth:`~.CapacityLimiter.acquire_on_behalf_of`
    """

    borrowed_tokens: int
    total_tokens: float
    borrowers: Tuple[object, ...]
    tasks_waiting: int


@dataclass(frozen=True)
class LockStatistics:
    """
    :ivar bool locked: flag indicating if this lock is locked or not
    :ivar ~anyio.TaskInfo owner: task currently holding the lock (or ``None`` if the lock is not
        held by any task)
    :ivar int tasks_waiting: number of tasks waiting on :meth:`~.Lock.acquire`
    """

    locked: bool
    owner: Optional[TaskInfo]
    tasks_waiting: int


@dataclass(frozen=True)
class ConditionStatistics:
    """
    :ivar int tasks_waiting: number of tasks blocked on :meth:`~.Condition.wait`
    :ivar ~anyio.LockStatistics lock_statistics: statistics of the underlying :class:`~.Lock`
    """

    tasks_waiting: int
    lock_statistics: LockStatistics


@dataclass(frozen=True)
class SemaphoreStatistics:
    """
    :ivar int tasks_waiting: number of tasks waiting on :meth:`~.Semaphore.acquire`
    """

    tasks_waiting: int


class Event:
    """
    An asynchronous one-shot flag that tasks can wait on.

    This class only defines the interface: instantiating it returns the event implementation
    provided by the currently active async backend (``get_asynclib().Event()``).
    """

    def __new__(cls) -> 'Event':
        return get_asynclib().Event()

    def set(self) -> DeprecatedAwaitable:
        """Set the flag, notifying all listeners."""
        raise NotImplementedError

    def is_set(self) -> bool:
        """Return ``True`` if the flag is set, ``False`` if not."""
        raise NotImplementedError

    async def wait(self) -> None:
        """
        Wait until the flag has been set.

        If the flag has already been set when this method is called, it returns immediately.

        """
        raise NotImplementedError

    def statistics(self) -> EventStatistics:
        """Return statistics about the current state of this event."""
        raise NotImplementedError


class Lock:
    """
    A task-owned mutual exclusion lock.

    Only the task that acquired the lock may release it. Blocked tasks are queued in arrival
    order and ownership is handed directly to the first of them on release.
    """

    _owner_task: Optional[TaskInfo] = None

    def __init__(self) -> None:
        # Tasks blocked in acquire(), oldest first, each paired with the event that release()
        # sets after making that task the new owner
        self._waiters: Deque[Tuple[TaskInfo, Event]] = deque()

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(self, exc_type: Optional[Type[BaseException]],
                        exc_val: Optional[BaseException],
                        exc_tb: Optional[TracebackType]) -> None:
        self.release()

    async def acquire(self) -> None:
        """Acquire the lock."""
        await checkpoint_if_cancelled()
        try:
            self.acquire_nowait()
        except WouldBlock:
            task = get_current_task()
            event = Event()
            token = task, event
            self._waiters.append(token)
            try:
                await event.wait()
            except BaseException:
                if not event.is_set():
                    # Still queued: withdraw from the queue
                    self._waiters.remove(token)
                elif self._owner_task == task:
                    # release() already made us the owner; pass the lock on so it isn't lost
                    self.release()

                raise

            assert self._owner_task == task
        else:
            # The lock was free and is now ours; still yield to the event loop, but through a
            # shielded checkpoint so a cancellation cannot interrupt us while holding the lock
            await cancel_shielded_checkpoint()

    def acquire_nowait(self) -> None:
        """
        Acquire the lock, without blocking.

        :raises ~WouldBlock: if the operation would block
        :raises RuntimeError: if the current task already holds the lock

        """
        task = get_current_task()
        if self._owner_task == task:
            raise RuntimeError('Attempted to acquire an already held Lock')

        if self._owner_task is not None:
            raise WouldBlock

        self._owner_task = task

    def release(self) -> DeprecatedAwaitable:
        """
        Release the lock.

        If other tasks are waiting, ownership is transferred to the longest-waiting one.

        :raises RuntimeError: if the current task is not holding this lock

        """
        if self._owner_task != get_current_task():
            raise RuntimeError('The current task is not holding this lock')

        if self._waiters:
            self._owner_task, event = self._waiters.popleft()
            event.set()
        else:
            self._owner_task = None

        return DeprecatedAwaitable(self.release)

    def locked(self) -> bool:
        """Return True if the lock is currently held."""
        return self._owner_task is not None

    def statistics(self) -> LockStatistics:
        """
        Return statistics about the current state of this lock.

        .. versionadded:: 3.0
        """
        return LockStatistics(self.locked(), self._owner_task, len(self._waiters))


class Condition:
    """
    A condition variable built on top of a :class:`Lock`.

    :param lock: the lock to use (a new :class:`Lock` is created if omitted)
    """

    _owner_task: Optional[TaskInfo] = None

    def __init__(self, lock: Optional[Lock] = None):
        self._lock = lock or Lock()
        # One event per task blocked in wait(), oldest first
        self._waiters: Deque[Event] = deque()

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(self, exc_type: Optional[Type[BaseException]],
                        exc_val: Optional[BaseException],
                        exc_tb: Optional[TracebackType]) -> None:
        self.release()

    def _check_acquired(self) -> None:
        if self._owner_task != get_current_task():
            raise RuntimeError('The current task is not holding the underlying lock')

    async def acquire(self) -> None:
        """Acquire the underlying lock."""
        await self._lock.acquire()
        self._owner_task = get_current_task()

    def acquire_nowait(self) -> None:
        """
        Acquire the underlying lock, without blocking.

        :raises ~WouldBlock: if the operation would block

        """
        self._lock.acquire_nowait()
        self._owner_task = get_current_task()

    def release(self) -> DeprecatedAwaitable:
        """
        Release the underlying lock.

        :raises RuntimeError: if the current task is not holding the underlying lock

        """
        self._lock.release()
        # Forget the owner so that notify()/notify_all() from this task are rejected until it
        # acquires the lock again
        self._owner_task = None
        return DeprecatedAwaitable(self.release)

    def locked(self) -> bool:
        """Return True if the underlying lock is held."""
        return self._lock.locked()

    def notify(self, n: int = 1) -> None:
        """
        Notify up to ``n`` listeners (fewer if fewer tasks are waiting).

        :raises RuntimeError: if the current task is not holding the underlying lock

        """
        self._check_acquired()
        for _ in range(n):
            try:
                event = self._waiters.popleft()
            except IndexError:
                break

            event.set()

    def notify_all(self) -> None:
        """
        Notify all the listeners.

        :raises RuntimeError: if the current task is not holding the underlying lock

        """
        self._check_acquired()
        for event in self._waiters:
            event.set()

        self._waiters.clear()

    async def wait(self) -> None:
        """
        Wait for a notification.

        The underlying lock is released while waiting and re-acquired (inside a shielded cancel
        scope) before this method returns or propagates an exception.

        :raises RuntimeError: if the current task is not holding the underlying lock

        """
        await checkpoint()
        event = Event()
        # Release before registering as a waiter: if this task does not hold the lock,
        # release() raises and no orphaned event is left in the queue. There is no checkpoint
        # between the two statements, so no notification can slip in between them.
        self.release()
        self._waiters.append(event)
        try:
            await event.wait()
        except BaseException:
            if not event.is_set():
                self._waiters.remove(event)

            raise
        finally:
            with CancelScope(shield=True):
                await self.acquire()

    def statistics(self) -> ConditionStatistics:
        """
        Return statistics about the current state of this condition.

        .. versionadded:: 3.0
        """
        return ConditionStatistics(len(self._waiters), self._lock.statistics())


class Semaphore:
    """
    A counting semaphore.

    :param initial_value: the starting value (must be a non-negative integer)
    :param max_value: if set, :meth:`release` raises :exc:`ValueError` instead of raising the
        value above this number
    """

    def __init__(self, initial_value: int, *, max_value: Optional[int] = None):
        if not isinstance(initial_value, int):
            raise TypeError('initial_value must be an integer')
        if initial_value < 0:
            raise ValueError('initial_value must be >= 0')
        if max_value is not None:
            if not isinstance(max_value, int):
                raise TypeError('max_value must be an integer or None')
            if max_value < initial_value:
                raise ValueError('max_value must be equal to or higher than initial_value')

        self._value = initial_value
        self._max_value = max_value
        # One event per task blocked in acquire(), oldest first
        self._waiters: Deque[Event] = deque()

    async def __aenter__(self) -> 'Semaphore':
        await self.acquire()
        return self

    async def __aexit__(self, exc_type: Optional[Type[BaseException]],
                        exc_val: Optional[BaseException],
                        exc_tb: Optional[TracebackType]) -> None:
        self.release()

    async def acquire(self) -> None:
        """Decrement the semaphore value, blocking if necessary."""
        await checkpoint_if_cancelled()
        try:
            self.acquire_nowait()
        except WouldBlock:
            event = Event()
            self._waiters.append(event)
            try:
                await event.wait()
            except BaseException:
                if not event.is_set():
                    self._waiters.remove(event)
                else:
                    # release() already handed us a slot; give it back since we're bailing out
                    self.release()

                raise
        else:
            await cancel_shielded_checkpoint()

    def acquire_nowait(self) -> None:
        """
        Decrement the semaphore value, without blocking.

        :raises ~WouldBlock: if the operation would block

        """
        if self._value == 0:
            raise WouldBlock

        self._value -= 1

    def release(self) -> DeprecatedAwaitable:
        """
        Increment the semaphore value.

        If tasks are waiting, the slot is handed to the longest-waiting one instead.

        :raises ValueError: if the value is already at ``max_value``

        """
        if self._max_value is not None and self._value == self._max_value:
            raise ValueError('semaphore released too many times')

        if self._waiters:
            self._waiters.popleft().set()
        else:
            self._value += 1

        return DeprecatedAwaitable(self.release)

    @property
    def value(self) -> int:
        """The current value of the semaphore."""
        return self._value

    @property
    def max_value(self) -> Optional[int]:
        """The maximum value of the semaphore."""
        return self._max_value

    def statistics(self) -> SemaphoreStatistics:
        """
        Return statistics about the current state of this semaphore.

        .. versionadded:: 3.0
        """
        return SemaphoreStatistics(len(self._waiters))


class CapacityLimiter:
    """
    Interface of a token-based capacity limiter.

    Instantiating this class returns the implementation provided by the currently active async
    backend (``get_asynclib().CapacityLimiter(total_tokens)``).
    """

    def __new__(cls, total_tokens: float) -> 'CapacityLimiter':
        return get_asynclib().CapacityLimiter(total_tokens)

    async def __aenter__(self) -> None:
        raise NotImplementedError

    async def __aexit__(self, exc_type: Optional[Type[BaseException]],
                        exc_val: Optional[BaseException],
                        exc_tb: Optional[TracebackType]) -> Optional[bool]:
        raise NotImplementedError

    @property
    def total_tokens(self) -> float:
        """
        The total number of tokens available for borrowing.

        This is a read-write property. If the total number of tokens is increased, the
        proportionate number of tasks waiting on this limiter will be granted their tokens.

        .. versionchanged:: 3.0
            The property is now writable.

        """
        raise NotImplementedError

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        raise NotImplementedError

    async def set_total_tokens(self, value: float) -> None:
        """
        Set the total number of tokens.

        Deprecated: assign to :attr:`total_tokens` directly instead.

        """
        warn('CapacityLimiter.set_total_tokens has been deprecated. Set the value of the '
             '"total_tokens" attribute directly.', DeprecationWarning)
        self.total_tokens = value

    @property
    def borrowed_tokens(self) -> int:
        """The number of tokens that have currently been borrowed."""
        raise NotImplementedError

    @property
    def available_tokens(self) -> float:
        """The number of tokens currently available to be borrowed"""
        raise NotImplementedError

    def acquire_nowait(self) -> DeprecatedAwaitable:
        """
        Acquire a token for the current task without waiting for one to become available.

        :raises ~anyio.WouldBlock: if there are no tokens available for borrowing

        """
        raise NotImplementedError

    def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
        """
        Acquire a token without waiting for one to become available.

        :param borrower: the entity borrowing a token
        :raises ~anyio.WouldBlock: if there are no tokens available for borrowing

        """
        raise NotImplementedError

    async def acquire(self) -> None:
        """
        Acquire a token for the current task, waiting if necessary for one to become available.

        """
        raise NotImplementedError

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """
        Acquire a token, waiting if necessary for one to become available.

        :param borrower: the entity borrowing a token

        """
        raise NotImplementedError

    def release(self) -> None:
        """
        Release the token held by the current task.

        :raises RuntimeError: if the current task has not borrowed a token from this limiter.

        """
        raise NotImplementedError

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Release the token held by the given borrower.

        :raises RuntimeError: if the borrower has not borrowed a token from this limiter.

        """
        raise NotImplementedError

    def statistics(self) -> CapacityLimiterStatistics:
        """
        Return statistics about the current state of this limiter.

        .. versionadded:: 3.0

        """
        raise NotImplementedError


def create_lock() -> Lock:
    """
    Create an asynchronous lock.

    :return: a lock object

    .. deprecated:: 3.0
       Use :class:`~Lock` directly.

    """
    warn('create_lock() is deprecated -- use Lock() directly', DeprecationWarning)
    return Lock()


def create_condition(lock: Optional[Lock] = None) -> Condition:
    """
    Create an asynchronous condition.

    :param lock: the lock to base the condition object on
    :return: a condition object

    .. deprecated:: 3.0
       Use :class:`~Condition` directly.

    """
    warn('create_condition() is deprecated -- use Condition() directly', DeprecationWarning)
    return Condition(lock=lock)


def create_event() -> Event:
    """
    Create an asynchronous event object.

    :return: an event object

    .. deprecated:: 3.0
       Use :class:`~Event` directly.

    """
    warn('create_event() is deprecated -- use Event() directly', DeprecationWarning)
    return get_asynclib().Event()


def create_semaphore(value: int, *, max_value: Optional[int] = None) -> Semaphore:
    """
    Create an asynchronous semaphore.

    :param value: the semaphore's initial value
    :param max_value: if set, makes this a "bounded" semaphore that raises :exc:`ValueError` if the
        semaphore's value would exceed this number
    :return: a semaphore object

    .. deprecated:: 3.0
       Use :class:`~Semaphore` directly.

    """
    warn('create_semaphore() is deprecated -- use Semaphore() directly', DeprecationWarning)
    return Semaphore(value, max_value=max_value)


def create_capacity_limiter(total_tokens: float) -> CapacityLimiter:
    """
    Create a capacity limiter.

    :param total_tokens: the total number of tokens available for borrowing (can be an integer or
        :data:`math.inf`)
    :return: a capacity limiter object

    .. deprecated:: 3.0
       Use :class:`~CapacityLimiter` directly.

    """
    warn('create_capacity_limiter() is deprecated -- use CapacityLimiter() directly',
         DeprecationWarning)
    return get_asynclib().CapacityLimiter(total_tokens)


class ResourceGuard:
    """
    Context manager that raises :exc:`BusyResourceError` when entered while already entered.

    :param action: description of the guarded action, passed to :exc:`BusyResourceError`
    """

    __slots__ = 'action', '_guarded'

    def __init__(self, action: str):
        self.action = action
        self._guarded = False

    def __enter__(self) -> None:
        if self._guarded:
            raise BusyResourceError(self.action)

        self._guarded = True

    def __exit__(self, exc_type: Optional[Type[BaseException]],
                 exc_val: Optional[BaseException],
                 exc_tb: Optional[TracebackType]) -> Optional[bool]:
        self._guarded = False
        return None