1from collections import deque
2from dataclasses import dataclass
3from types import TracebackType
4from typing import Deque, Optional, Tuple, Type
5from warnings import warn
6
7from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled
8from ._compat import DeprecatedAwaitable
9from ._eventloop import get_asynclib
10from ._exceptions import BusyResourceError, WouldBlock
11from ._tasks import CancelScope
12from ._testing import TaskInfo, get_current_task
13
14
@dataclass(frozen=True)
class EventStatistics:
    """
    Immutable snapshot of the state of an :class:`~.Event`.

    :ivar int tasks_waiting: how many tasks are currently blocked in :meth:`~.Event.wait`
    """

    tasks_waiting: int  # count of tasks blocked on the event
22
23
@dataclass(frozen=True)
class CapacityLimiterStatistics:
    """
    Immutable snapshot of the state of a :class:`~.CapacityLimiter`.

    :ivar int borrowed_tokens: how many tokens are borrowed right now
    :ivar float total_tokens: the total number of tokens the limiter hands out
    :ivar tuple borrowers: the tasks (or other objects) that currently hold a token from this
        limiter
    :ivar int tasks_waiting: how many tasks are blocked in :meth:`~.CapacityLimiter.acquire` or
        :meth:`~.CapacityLimiter.acquire_on_behalf_of`
    """

    # Field order matters: instances may be built positionally
    borrowed_tokens: int
    total_tokens: float
    borrowers: Tuple[object, ...]
    tasks_waiting: int
39
40
@dataclass(frozen=True)
class LockStatistics:
    """
    Immutable snapshot of the state of a :class:`~.Lock`.

    :ivar bool locked: whether the lock is held at the moment
    :ivar ~anyio.TaskInfo owner: the task holding the lock, or ``None`` when no task holds it
    :ivar int tasks_waiting: how many tasks are blocked in :meth:`~.Lock.acquire`
    """

    # Field order matters: instances may be built positionally
    locked: bool
    owner: Optional[TaskInfo]
    tasks_waiting: int
53
54
@dataclass(frozen=True)
class ConditionStatistics:
    """
    Immutable snapshot of the state of a :class:`~.Condition`.

    :ivar int tasks_waiting: how many tasks are blocked in :meth:`~.Condition.wait`
    :ivar ~anyio.LockStatistics lock_statistics: snapshot of the condition's underlying
        :class:`~.Lock`
    """

    # Field order matters: instances may be built positionally
    tasks_waiting: int
    lock_statistics: LockStatistics
64
65
@dataclass(frozen=True)
class SemaphoreStatistics:
    """
    Immutable snapshot of the state of a :class:`~.Semaphore`.

    :ivar int tasks_waiting: how many tasks are blocked in :meth:`~.Semaphore.acquire`
    """

    tasks_waiting: int  # count of tasks blocked on the semaphore
73
74
class Event:
    """
    A one-shot flag that tasks can wait on until some task sets it.

    Instantiating this class returns the ``Event`` implementation of the async backend reported
    by ``get_asynclib()``; the methods defined here only describe the interface.
    """

    def __new__(cls) -> 'Event':
        backend = get_asynclib()
        return backend.Event()

    def set(self) -> DeprecatedAwaitable:
        """Set the flag and notify every task waiting on it."""
        raise NotImplementedError

    def is_set(self) -> bool:
        """Tell whether the flag has been set (``True``) or not (``False``)."""
        raise NotImplementedError

    async def wait(self) -> None:
        """
        Block until the flag is set.

        Returns right away if the flag was already set when called.

        """
        raise NotImplementedError

    def statistics(self) -> EventStatistics:
        """Return a snapshot of this event's current state."""
        raise NotImplementedError
99
100
class Lock:
    """
    A task-owned mutual exclusion lock.

    Only the task that acquired the lock may release it, and a task trying to acquire a lock it
    already holds gets a :exc:`RuntimeError`. When the lock is released while other tasks are
    waiting, ownership is handed directly to the task that has waited the longest.
    """

    # The task currently holding the lock, or None if the lock is free
    _owner_task: Optional[TaskInfo] = None

    def __init__(self) -> None:
        # FIFO queue of (waiting task, event used to wake it up) pairs
        self._waiters: Deque[Tuple[TaskInfo, Event]] = deque()

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(self, exc_type: Optional[Type[BaseException]],
                        exc_val: Optional[BaseException],
                        exc_tb: Optional[TracebackType]) -> None:
        self.release()

    async def acquire(self) -> None:
        """
        Acquire the lock, waiting if necessary for it to become available.

        :raises RuntimeError: if the current task already holds this lock

        """
        await checkpoint_if_cancelled()
        try:
            self.acquire_nowait()
        except WouldBlock:
            pass
        else:
            # The lock was free and is now ours. Still yield to the event loop (shielded from
            # cancellation), but don't leave the lock held if that checkpoint raises anyway.
            try:
                await cancel_shielded_checkpoint()
            except BaseException:
                self.release()
                raise

            return

        # The lock is held by another task: queue up and wait for ownership to be handed over.
        # This happens outside the ``except WouldBlock`` handler so that exceptions raised while
        # waiting (e.g. cancellation) are not implicitly chained to the WouldBlock.
        task = get_current_task()
        event = Event()
        token = task, event
        self._waiters.append(token)
        try:
            await event.wait()
        except BaseException:
            if not event.is_set():
                # Still queued: withdraw from the queue
                self._waiters.remove(token)
            elif self._owner_task == task:
                # Ownership was already handed to us before the exception arrived; pass it on
                self.release()

            raise

        # release() assigns ownership to us before setting the event
        assert self._owner_task == task

    def acquire_nowait(self) -> None:
        """
        Acquire the lock, without blocking.

        :raises ~WouldBlock: if the operation would block
        :raises RuntimeError: if the current task already holds this lock

        """
        task = get_current_task()
        if self._owner_task == task:
            raise RuntimeError('Attempted to acquire an already held Lock')

        if self._owner_task is not None:
            raise WouldBlock

        self._owner_task = task

    def release(self) -> DeprecatedAwaitable:
        """
        Release the lock.

        If other tasks are waiting, ownership passes to the one that has waited the longest.

        :raises RuntimeError: if the current task is not holding this lock

        """
        if self._owner_task != get_current_task():
            raise RuntimeError('The current task is not holding this lock')

        if self._waiters:
            # Hand ownership over directly, so acquire_nowait() from another task can't barge in
            # before the woken waiter gets to run
            self._owner_task, event = self._waiters.popleft()
            event.set()
        else:
            self._owner_task = None

        return DeprecatedAwaitable(self.release)

    def locked(self) -> bool:
        """Return True if the lock is currently held."""
        return self._owner_task is not None

    def statistics(self) -> LockStatistics:
        """
        Return statistics about the current state of this lock.

        .. versionadded:: 3.0
        """
        return LockStatistics(self.locked(), self._owner_task, len(self._waiters))
179
180
class Condition:
    """
    A condition variable built on top of a :class:`Lock`.

    A task holding the underlying lock can :meth:`wait` until another task holding the same lock
    calls :meth:`notify` or :meth:`notify_all`.

    :param lock: the lock to base the condition on (a new :class:`Lock` is created if omitted)

    """

    # Task that acquired the underlying lock through this condition (None when not held)
    _owner_task: Optional[TaskInfo] = None

    def __init__(self, lock: Optional[Lock] = None):
        self._lock = lock if lock is not None else Lock()
        # FIFO queue of events, one per task blocked in wait()
        self._waiters: Deque[Event] = deque()

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(self, exc_type: Optional[Type[BaseException]],
                        exc_val: Optional[BaseException],
                        exc_tb: Optional[TracebackType]) -> None:
        self.release()

    def _check_acquired(self) -> None:
        # Only the task that acquired the lock through this condition may notify waiters
        if self._owner_task != get_current_task():
            raise RuntimeError('The current task is not holding the underlying lock')

    async def acquire(self) -> None:
        """Acquire the underlying lock."""
        await self._lock.acquire()
        self._owner_task = get_current_task()

    def acquire_nowait(self) -> None:
        """
        Acquire the underlying lock, without blocking.

        :raises ~WouldBlock: if the operation would block

        """
        self._lock.acquire_nowait()
        self._owner_task = get_current_task()

    def release(self) -> DeprecatedAwaitable:
        """Release the underlying lock."""
        self._lock.release()
        # Forget the owner so that notify()/notify_all() after release() are rejected
        self._owner_task = None
        return DeprecatedAwaitable(self.release)

    def locked(self) -> bool:
        """Return True if the lock is set."""
        return self._lock.locked()

    def notify(self, n: int = 1) -> None:
        """
        Wake up at most ``n`` waiting tasks, in the order they started waiting.

        :param n: maximum number of waiting tasks to notify
        :raises RuntimeError: if the current task is not holding the underlying lock

        """
        self._check_acquired()
        for _ in range(n):
            try:
                event = self._waiters.popleft()
            except IndexError:
                break

            event.set()

    def notify_all(self) -> None:
        """
        Wake up every waiting task.

        :raises RuntimeError: if the current task is not holding the underlying lock

        """
        self._check_acquired()
        for event in self._waiters:
            event.set()

        self._waiters.clear()

    async def wait(self) -> None:
        """
        Wait for a notification.

        The underlying lock is released while waiting and re-acquired (shielded from
        cancellation) before this method returns or raises.

        """
        await checkpoint()
        # Release before queueing: if the caller doesn't hold the lock, release() raises and no
        # stale event is left behind in the waiter queue. There is no await between the two
        # statements, so no notification can be missed.
        self.release()
        event = Event()
        self._waiters.append(event)
        try:
            await event.wait()
        except BaseException:
            if not event.is_set():
                self._waiters.remove(event)

            raise
        finally:
            with CancelScope(shield=True):
                await self.acquire()

    def statistics(self) -> ConditionStatistics:
        """
        Return statistics about the current state of this condition.

        .. versionadded:: 3.0
        """
        return ConditionStatistics(len(self._waiters), self._lock.statistics())
267
268
class Semaphore:
    """
    A counting semaphore.

    :param initial_value: the starting value (how many acquisitions can succeed without a
        release); must be an integer >= 0
    :param max_value: if given, :meth:`release` raises :exc:`ValueError` when the value is
        already equal to this number

    """

    def __init__(self, initial_value: int, *, max_value: Optional[int] = None):
        if not isinstance(initial_value, int):
            raise TypeError('initial_value must be an integer')
        if initial_value < 0:
            raise ValueError('initial_value must be >= 0')
        if max_value is not None:
            if not isinstance(max_value, int):
                raise TypeError('max_value must be an integer or None')
            if max_value < initial_value:
                raise ValueError('max_value must be equal to or higher than initial_value')

        self._value = initial_value
        self._max_value = max_value
        # FIFO queue of events, one per task blocked in acquire()
        self._waiters: Deque[Event] = deque()

    async def __aenter__(self) -> 'Semaphore':
        await self.acquire()
        return self

    async def __aexit__(self, exc_type: Optional[Type[BaseException]],
                        exc_val: Optional[BaseException],
                        exc_tb: Optional[TracebackType]) -> None:
        self.release()

    async def acquire(self) -> None:
        """Decrement the semaphore value, blocking if necessary."""
        await checkpoint_if_cancelled()
        try:
            self.acquire_nowait()
        except WouldBlock:
            pass
        else:
            # Got a slot immediately. Still yield to the event loop (shielded from
            # cancellation), but give the slot back if that checkpoint raises anyway.
            try:
                await cancel_shielded_checkpoint()
            except BaseException:
                self.release()
                raise

            return

        # No slot available: queue up outside the ``except WouldBlock`` handler so exceptions
        # raised while waiting are not implicitly chained to the WouldBlock.
        event = Event()
        self._waiters.append(event)
        try:
            await event.wait()
        except BaseException:
            if not event.is_set():
                self._waiters.remove(event)
            else:
                # release() already handed us a slot before the exception arrived; return it
                self.release()

            raise

    def acquire_nowait(self) -> None:
        """
        Decrement the semaphore value, without blocking.

        :raises ~WouldBlock: if the operation would block

        """
        if self._value == 0:
            raise WouldBlock

        self._value -= 1

    def release(self) -> DeprecatedAwaitable:
        """
        Increment the semaphore value.

        If tasks are waiting, the slot is handed directly to the one that has waited the
        longest instead of incrementing the value.

        :raises ValueError: if ``max_value`` is set and the value is already equal to it

        """
        if self._max_value is not None and self._value == self._max_value:
            raise ValueError('semaphore released too many times')

        if self._waiters:
            self._waiters.popleft().set()
        else:
            self._value += 1

        return DeprecatedAwaitable(self.release)

    @property
    def value(self) -> int:
        """The current value of the semaphore."""
        return self._value

    @property
    def max_value(self) -> Optional[int]:
        """The maximum value of the semaphore."""
        return self._max_value

    def statistics(self) -> SemaphoreStatistics:
        """
        Return statistics about the current state of this semaphore.

        .. versionadded:: 3.0
        """
        return SemaphoreStatistics(len(self._waiters))
355
356
class CapacityLimiter:
    """
    Limits how many tasks (or other borrowers) can hold one of a fixed number of tokens.

    Instantiating this class returns the ``CapacityLimiter`` implementation of the async backend
    reported by ``get_asynclib()``; the methods defined here only describe the interface.
    """

    def __new__(cls, total_tokens: float) -> 'CapacityLimiter':
        return get_asynclib().CapacityLimiter(total_tokens)

    async def __aenter__(self) -> None:
        raise NotImplementedError

    async def __aexit__(self, exc_type: Optional[Type[BaseException]],
                        exc_val: Optional[BaseException],
                        exc_tb: Optional[TracebackType]) -> Optional[bool]:
        raise NotImplementedError

    @property
    def total_tokens(self) -> float:
        """
        The total number of tokens available for borrowing.

        This is a read-write property. If the total number of tokens is increased, the
        proportionate number of tasks waiting on this limiter will be granted their tokens.

        .. versionchanged:: 3.0
            The property is now writable.

        """
        raise NotImplementedError

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        raise NotImplementedError

    async def set_total_tokens(self, value: float) -> None:
        """
        Set the total number of tokens.

        .. deprecated:: 3.0
           Assign to :attr:`total_tokens` instead.

        """
        # stacklevel=2 attributes the warning to the caller rather than to this module
        warn('CapacityLimiter.set_total_tokens has been deprecated. Set the value of the '
             '"total_tokens" attribute directly.', DeprecationWarning, stacklevel=2)
        self.total_tokens = value

    @property
    def borrowed_tokens(self) -> int:
        """The number of tokens that have currently been borrowed."""
        raise NotImplementedError

    @property
    def available_tokens(self) -> float:
        """The number of tokens currently available to be borrowed"""
        raise NotImplementedError

    def acquire_nowait(self) -> DeprecatedAwaitable:
        """
        Acquire a token for the current task without waiting for one to become available.

        :raises ~anyio.WouldBlock: if there are no tokens available for borrowing

        """
        raise NotImplementedError

    def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
        """
        Acquire a token without waiting for one to become available.

        :param borrower: the entity borrowing a token
        :raises ~anyio.WouldBlock: if there are no tokens available for borrowing

        """
        raise NotImplementedError

    async def acquire(self) -> None:
        """
        Acquire a token for the current task, waiting if necessary for one to become available.

        """
        raise NotImplementedError

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """
        Acquire a token, waiting if necessary for one to become available.

        :param borrower: the entity borrowing a token

        """
        raise NotImplementedError

    def release(self) -> None:
        """
        Release the token held by the current task.

        :raises RuntimeError: if the current task has not borrowed a token from this limiter.

        """
        raise NotImplementedError

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Release the token held by the given borrower.

        :param borrower: the entity that borrowed the token
        :raises RuntimeError: if the borrower has not borrowed a token from this limiter.

        """
        raise NotImplementedError

    def statistics(self) -> CapacityLimiterStatistics:
        """
        Return statistics about the current state of this limiter.

        .. versionadded:: 3.0

        """
        raise NotImplementedError
462
463
def create_lock() -> Lock:
    """
    Create an asynchronous lock.

    :return: a lock object

    .. deprecated:: 3.0
       Use :class:`~Lock` directly.

    """
    # stacklevel=2 attributes the warning to the caller, so default warning filters show it
    warn('create_lock() is deprecated -- use Lock() directly', DeprecationWarning, stacklevel=2)
    return Lock()
476
477
def create_condition(lock: Optional[Lock] = None) -> Condition:
    """
    Create an asynchronous condition.

    :param lock: the lock to base the condition object on
    :return: a condition object

    .. deprecated:: 3.0
       Use :class:`~Condition` directly.

    """
    # stacklevel=2 attributes the warning to the caller, so default warning filters show it
    warn('create_condition() is deprecated -- use Condition() directly', DeprecationWarning,
         stacklevel=2)
    return Condition(lock=lock)
491
492
def create_event() -> Event:
    """
    Create an asynchronous event object.

    :return: an event object

    .. deprecated:: 3.0
       Use :class:`~Event` directly.

    """
    # stacklevel=2 attributes the warning to the caller, so default warning filters show it
    warn('create_event() is deprecated -- use Event() directly', DeprecationWarning,
         stacklevel=2)
    return get_asynclib().Event()
505
506
def create_semaphore(value: int, *, max_value: Optional[int] = None) -> Semaphore:
    """
    Create an asynchronous semaphore.

    :param value: the semaphore's initial value
    :param max_value: if set, makes this a "bounded" semaphore that raises :exc:`ValueError` if the
        semaphore's value would exceed this number
    :return: a semaphore object

    .. deprecated:: 3.0
       Use :class:`~Semaphore` directly.

    """
    # stacklevel=2 attributes the warning to the caller, so default warning filters show it
    warn('create_semaphore() is deprecated -- use Semaphore() directly', DeprecationWarning,
         stacklevel=2)
    return Semaphore(value, max_value=max_value)
522
523
def create_capacity_limiter(total_tokens: float) -> CapacityLimiter:
    """
    Create a capacity limiter.

    :param total_tokens: the total number of tokens available for borrowing (can be an integer or
        :data:`math.inf`)
    :return: a capacity limiter object

    .. deprecated:: 3.0
       Use :class:`~CapacityLimiter` directly.

    """
    # stacklevel=2 attributes the warning to the caller, so default warning filters show it
    warn('create_capacity_limiter() is deprecated -- use CapacityLimiter() directly',
         DeprecationWarning, stacklevel=2)
    return get_asynclib().CapacityLimiter(total_tokens)
539
540
class ResourceGuard:
    """
    Synchronous context manager that rejects overlapping use of a resource.

    Entering the guard while it is already entered raises :exc:`BusyResourceError` carrying the
    configured action description.

    :param action: description of the guarded action, passed to :exc:`BusyResourceError`
    """

    __slots__ = ('action', '_guarded')

    def __init__(self, action: str):
        self.action = action
        self._guarded = False  # True while execution is inside the guarded block

    def __enter__(self) -> None:
        if not self._guarded:
            self._guarded = True
            return

        raise BusyResourceError(self.action)

    def __exit__(self, exc_type: Optional[Type[BaseException]],
                 exc_val: Optional[BaseException],
                 exc_tb: Optional[TracebackType]) -> Optional[bool]:
        # Always drop the guard; returning None lets any in-flight exception propagate
        self._guarded = False
        return None
559