1"""Synchronization primitives."""
2
# Public API of this module: the names exported by a star-import.
__all__ = ('Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore')
4
5import collections
6
7from . import exceptions
8from . import mixins
9
10
class _ContextManagerMixin:
    """Give a primitive ``async with`` support via its acquire()/release().

    Subclasses must provide an awaitable ``acquire()`` and a plain
    ``release()``.  Exceptions raised inside the ``async with`` body are
    never suppressed.
    """

    async def __aenter__(self):
        await self.acquire()
        # Nothing useful to bind with "async with ... as x", so x is None
        # (the implicit return value).

    async def __aexit__(self, exc_type, exc, tb):
        # Implicitly returns None, so any in-flight exception propagates.
        self.release()
20
21
class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Primitive lock objects.

    A primitive lock is a synchronization primitive that is not owned by a
    particular coroutine while it is locked.  It is always in one of two
    states, 'locked' or 'unlocked', and it starts out unlocked.

    It has two basic methods, acquire() and release().  If the lock is
    unlocked, acquire() locks it and returns immediately.  If the lock is
    locked, acquire() blocks until another coroutine calls release() to
    unlock it; acquire() then locks it again and returns.  release() must
    only be called while the lock is locked; it unlocks it and returns
    immediately.  Releasing an unlocked lock raises RuntimeError.

    When several coroutines are blocked in acquire() waiting for the lock,
    each release() lets exactly one of them proceed: the one that started
    waiting first.

    acquire() is a coroutine and should be called with 'await'.

    Locks also support the asynchronous context management protocol, so the
    'async with lock' statement should be preferred.

    Usage:

        lock = Lock()
        ...
        await lock.acquire()
        try:
            ...
        finally:
            lock.release()

    Context manager usage:

        lock = Lock()
        ...
        async with lock:
            ...

    Lock objects can be tested for locking state:

        if not lock.locked():
            await lock.acquire()
        else:
            # lock is acquired
            ...

    """

    def __init__(self, *, loop=mixins._marker):
        super().__init__(loop=loop)
        # FIFO of futures, one per blocked acquire(); created lazily the
        # first time a caller actually has to wait.
        self._waiters = None
        self._locked = False

    def __repr__(self):
        base = super().__repr__()
        state = 'locked' if self._locked else 'unlocked'
        if self._waiters:
            state = f'{state}, waiters:{len(self._waiters)}'
        return f'<{base[1:-1]} [{state}]>'

    def locked(self):
        """Return True if lock is acquired."""
        return self._locked

    async def acquire(self):
        """Acquire the lock, waiting until it is free if necessary.

        Returns True once the caller holds the lock.
        """
        # Grab the lock immediately only when it is free and nobody is queued
        # ahead of us.  A queue consisting solely of cancelled waiters does
        # not count: those coroutines are about to give up anyway.
        is_free = not self._locked
        if is_free and (self._waiters is None
                        or all(w.cancelled() for w in self._waiters)):
            self._locked = True
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        waiter = self._get_loop().create_future()
        self._waiters.append(waiter)

        try:
            try:
                await waiter
            finally:
                # Leave the queue *before* the CancelledError handler below
                # runs, so _wake_up_first() cannot select our own future.
                self._waiters.remove(waiter)
        except exceptions.CancelledError:
            # release() may already have chosen us.  If the lock is still
            # free, pass the wake-up on to the new head of the queue.
            if not self._locked:
                self._wake_up_first()
            raise

        self._locked = True
        return True

    def release(self):
        """Release the lock.

        Unlocks a locked lock and returns.  If other coroutines are blocked
        waiting for the lock, exactly one of them is allowed to proceed.

        Raises RuntimeError when called on an unlocked lock.

        There is no return value.
        """
        if not self._locked:
            raise RuntimeError('Lock is not acquired.')
        self._locked = False
        self._wake_up_first()

    def _wake_up_first(self):
        """Resolve the future at the head of the queue, unless it is done."""
        if not self._waiters:
            return
        head = self._waiters[0]
        # A head that is already done will still run acquire(): it either
        # takes the lock or, if it was cancelled while the lock is free,
        # calls back into this method to wake the next waiter.
        if not head.done():
            head.set_result(True)
155
156
class Event(mixins._LoopBoundMixin):
    """Asynchronous equivalent to threading.Event.

    An event manages an internal flag, initially false.  set() makes the
    flag true and clear() makes it false again; wait() blocks until the
    flag is true.
    """

    def __init__(self, *, loop=mixins._marker):
        super().__init__(loop=loop)
        # One future per coroutine currently blocked in wait().
        self._waiters = collections.deque()
        self._value = False

    def __repr__(self):
        base = super().__repr__()
        state = 'set' if self._value else 'unset'
        if self._waiters:
            state = f'{state}, waiters:{len(self._waiters)}'
        return f'<{base[1:-1]} [{state}]>'

    def is_set(self):
        """Return True if and only if the internal flag is true."""
        return self._value

    def set(self):
        """Set the internal flag to true and wake every waiting coroutine.

        Coroutines that call wait() while the flag is true do not block.
        """
        if self._value:
            # Already set: every current waiter was woken when the flag
            # became true, and wait() does not queue while it stays true.
            return
        self._value = True
        for waiter in self._waiters:
            if not waiter.done():
                waiter.set_result(True)

    def clear(self):
        """Reset the internal flag to false.

        Afterwards, coroutines calling wait() block until set() is called
        again.
        """
        self._value = False

    async def wait(self):
        """Block until the internal flag is true, then return True.

        Returns immediately if the flag is already true on entry.
        """
        if self._value:
            return True

        waiter = self._get_loop().create_future()
        self._waiters.append(waiter)
        try:
            await waiter
        finally:
            self._waiters.remove(waiter)
        return True
217
218
class Condition(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Asynchronous equivalent to threading.Condition.

    This class implements condition variable objects. A condition variable
    allows one or more coroutines to wait until they are notified by another
    coroutine.

    If *lock* is given it is used as the underlying lock; otherwise a new
    Lock object is created and used.
    """

    def __init__(self, lock=None, *, loop=mixins._marker):
        super().__init__(loop=loop)
        if lock is None:
            lock = Lock()

        self._lock = lock
        # Export the lock's locked(), acquire() and release() methods.
        self.locked = lock.locked
        self.acquire = lock.acquire
        self.release = lock.release

        # One future per coroutine currently blocked in wait(), oldest first.
        self._waiters = collections.deque()

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else 'unlocked'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    async def wait(self):
        """Wait until notified.

        If the calling coroutine has not acquired the lock when this
        method is called, a RuntimeError is raised.

        This method releases the underlying lock, and then blocks
        until it is awakened by a notify() or notify_all() call for
        the same condition variable in another coroutine.  Once
        awakened, it re-acquires the lock and returns True.

        The lock is re-acquired even if the wait is cancelled; in that
        case the CancelledError is re-raised afterwards.
        """
        if not self.locked():
            raise RuntimeError('cannot wait on un-acquired lock')

        fut = None
        self.release()
        try:
            fut = self._get_loop().create_future()
            self._waiters.append(fut)
            try:
                await fut
                return True
            finally:
                self._waiters.remove(fut)

        except BaseException:
            # notify() may already have chosen this waiter (our future holds
            # a result) before the exception -- typically a cancellation --
            # was delivered.  Hand that notification to another waiter so it
            # is not silently lost.  The lock is not held here, so the
            # lock-checking notify() cannot be used.
            if fut is not None and fut.done() and not fut.cancelled():
                self._notify(1)
            raise

        finally:
            # Must reacquire lock even if wait is cancelled.  Remember the
            # most recent CancelledError instance so that its cancel message
            # survives instead of being replaced by a bare CancelledError.
            cancel_exc = None
            while True:
                try:
                    await self.acquire()
                    break
                except exceptions.CancelledError as exc:
                    cancel_exc = exc

            if cancel_exc is not None:
                try:
                    raise cancel_exc
                finally:
                    # Break the frame <-> exception reference cycle.
                    cancel_exc = None

    async def wait_for(self, predicate):
        """Wait until a predicate becomes true.

        The predicate should be a callable whose result will be
        interpreted as a boolean value.  The final predicate value is
        the return value.
        """
        result = predicate()
        while not result:
            await self.wait()
            result = predicate()
        return result

    def notify(self, n=1):
        """By default, wake up one coroutine waiting on this condition, if any.
        If the calling coroutine has not acquired the lock when this method
        is called, a RuntimeError is raised.

        This method wakes up at most n of the coroutines waiting for the
        condition variable; it is a no-op if no coroutines are waiting.

        Note: an awakened coroutine does not actually return from its
        wait() call until it can reacquire the lock. Since notify() does
        not release the lock, its caller should.
        """
        if not self.locked():
            raise RuntimeError('cannot notify on un-acquired lock')
        self._notify(n)

    def _notify(self, n):
        """Resolve up to n pending waiter futures, oldest first.

        Unlike notify(), this does not check the lock, so wait() can use it
        to pass on a notification while it does not hold the lock.
        """
        idx = 0
        for fut in self._waiters:
            if idx >= n:
                break

            # Skip futures that are already resolved or cancelled; they do
            # not consume a notification.
            if not fut.done():
                idx += 1
                fut.set_result(False)

    def notify_all(self):
        """Wake up all coroutines waiting on this condition.

        This method acts like notify(), but wakes up all waiting coroutines
        instead of one.  If the calling coroutine has not acquired the lock
        when this method is called, a RuntimeError is raised.
        """
        self.notify(len(self._waiters))
330
331
class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
    """A Semaphore implementation.

    A semaphore manages an internal counter which is decremented by each
    acquire() call and incremented by each release() call. The counter
    can never go below zero; when acquire() finds that it is zero, it blocks,
    waiting until some other coroutine calls release().

    Blocked coroutines are served in FIFO order.  While earlier waiters are
    still pending, a new acquire() call queues behind them even if the
    counter is momentarily above zero.

    Semaphores also support the asynchronous context management protocol.

    The optional argument gives the initial value for the internal
    counter; it defaults to 1. If the value given is less than 0,
    ValueError is raised.
    """

    def __init__(self, value=1, *, loop=mixins._marker):
        super().__init__(loop=loop)
        if value < 0:
            raise ValueError("Semaphore initial value must be >= 0")
        self._value = value
        # One future per coroutine blocked in acquire(), oldest first.  Each
        # waiter removes its own future when it resumes.
        self._waiters = collections.deque()

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else f'unlocked, value:{self._value}'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    def _wake_up_next(self):
        """Wake up the oldest waiter whose future is not done yet.

        The counter is decremented on the woken waiter's behalf.  This
        reserves the unit for it, so a newcomer cannot take it before the
        woken coroutine resumes.

        Returns True if a waiter was woken, False if there was none.
        """
        for fut in self._waiters:
            if not fut.done():
                self._value -= 1
                fut.set_result(True)
                return True
        return False

    def locked(self):
        """Returns True if semaphore cannot be acquired immediately."""
        # Either no units are left, or FIFO order requires letting
        # non-cancelled earlier waiters run first.
        return self._value == 0 or any(
            not w.cancelled() for w in self._waiters)

    async def acquire(self):
        """Acquire a semaphore.

        If the internal counter is larger than zero on entry and no other
        coroutine is waiting, decrement it by one and return True
        immediately.  Otherwise, wait in FIFO order until release() hands
        this coroutine a unit, and then return True.
        """
        if not self.locked():
            self._value -= 1
            return True

        fut = self._get_loop().create_future()
        self._waiters.append(fut)
        try:
            try:
                await fut
            finally:
                self._waiters.remove(fut)
        except exceptions.CancelledError:
            if fut.done() and not fut.cancelled():
                # We were handed a unit by _wake_up_next() but will not
                # acquire after all: give it back so someone else can use it.
                self._value += 1
            raise
        finally:
            # Units may be available, e.g. returned by a cancelled waiter or
            # released while newcomers queued up for FIFO reasons.  Hand out
            # as many as possible.
            while self._value > 0:
                if not self._wake_up_next():
                    break  # Nobody left to wake up.
        return True

    def release(self):
        """Release a semaphore, incrementing the internal counter by one.

        If another coroutine is waiting for the counter to become larger
        than zero, wake up the oldest such coroutine.
        """
        self._value += 1
        self._wake_up_next()
402
403
class BoundedSemaphore(Semaphore):
    """A semaphore whose counter may never exceed its initial value.

    release() raises ValueError if it would push the counter above the
    value the semaphore was created with.
    """

    def __init__(self, value=1, *, loop=mixins._marker):
        # Record the ceiling first; Semaphore.__init__ then validates *value*.
        self._bound_value = value
        super().__init__(value, loop=loop)

    def release(self):
        """Release the semaphore, refusing to exceed the initial value."""
        over_limit = self._value >= self._bound_value
        if over_limit:
            raise ValueError('BoundedSemaphore released too many times')
        super().release()
419