1# The code below is mostly my own but based on the interfaces of the
2# curio library by David Beazley.  I'm considering switching to using
# curio.  In the meantime, this is an attempt to provide a similar
4# clean, pure-async interface and move away from direct
5# framework-specific dependencies.  As asyncio differs in its design
6# it is not possible to provide identical semantics.
7#
8# The curio library is distributed under the following licence:
9#
10# Copyright (C) 2015-2017
11# David Beazley (Dabeaz LLC)
12# All rights reserved.
13#
14# Redistribution and use in source and binary forms, with or without
15# modification, are permitted provided that the following conditions are
16# met:
17#
18# * Redistributions of source code must retain the above copyright notice,
19#   this list of conditions and the following disclaimer.
20# * Redistributions in binary form must reproduce the above copyright notice,
21#   this list of conditions and the following disclaimer in the documentation
22#   and/or other materials provided with the distribution.
23# * Neither the name of the David Beazley or Dabeaz LLC may be used to
24#   endorse or promote products derived from this software without
25#   specific prior written permission.
26#
27# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
28# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
29# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
30# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
31# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
32# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
33# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
34# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
35# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
36# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
37# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
38
39from asyncio import (
40    CancelledError, get_event_loop, Queue, Event, Lock, Semaphore, sleep
41)
42from collections import deque
43from contextlib import suppress
44from functools import partial
45import logging
46import sys
47
48from aiorpcx.util import instantiate_coroutine, check_task
49
50
# Public API: the asyncio primitives re-exported from the import above
# plus the curio-style task, task-group and timeout helpers defined in
# this module.  TimeoutAfter and the _-prefixed helpers are internal.
__all__ = (
    'Queue', 'Event', 'Lock', 'Semaphore', 'sleep', 'CancelledError',
    'run_in_thread', 'spawn', 'spawn_sync', 'TaskGroup', 'NoRemainingTasksError',
    'TaskTimeout', 'TimeoutCancellationError', 'UncaughtTimeoutError',
    'timeout_after', 'timeout_at', 'ignore_after', 'ignore_at',
)
57
58
# asyncio.current_task() exists from Python 3.7; before that the same
# functionality is the Task.current_task classmethod.
if sys.version_info < (3, 7):
    from asyncio import Task
    current_task = Task.current_task
else:
    from asyncio import current_task
64
65
async def run_in_thread(func, *args):
    '''Call func(*args) in the event loop's default executor and return
    its result once the call has finished, without blocking the loop.'''
    loop = get_event_loop()
    result = await loop.run_in_executor(None, func, *args)
    return result
69
70
async def spawn(coro, *args, loop=None, report_crash=True):
    '''Coroutine wrapper around spawn_sync(): schedule coro (together with
    args) as a new Task on loop and return that Task.

    See spawn_sync() for the meaning of loop and report_crash.
    '''
    task = spawn_sync(coro, *args, report_crash=report_crash, loop=loop)
    return task
73
74
def spawn_sync(coro, *args, loop=None, report_crash=True):
    '''Create and return a Task for coro on loop (get_event_loop() when
    loop is not supplied).

    coro and args are combined by aiorpcx.util.instantiate_coroutine --
    presumably coro may be a coroutine function or a coroutine object;
    verify against that helper.  When report_crash is true,
    check_task is attached as a done-callback with the logging module
    as its first argument so that task failures get reported.
    '''
    coroutine = instantiate_coroutine(coro, args)
    target_loop = loop or get_event_loop()
    task = target_loop.create_task(coroutine)
    if not report_crash:
        return task
    task.add_done_callback(partial(check_task, logging))
    return task
82
83
class NoRemainingTasksError(RuntimeError):
    '''Raised by TaskGroup.next_result() when the group has no finished
    or pending tasks left.'''
86
87
class TaskGroup(object):
    '''A class representing a group of executing tasks. tasks is an
    optional set of existing tasks to put into the group. New tasks
    can later be added using the spawn() method below. wait specifies
    the policy used for waiting for tasks. See the join() method
    below. Each TaskGroup is an independent entity. Task groups do not
    form a hierarchy or any kind of relationship to other previously
    created task groups or tasks. Moreover, Tasks created by the top
    level spawn() function are not placed into any task group. To
    create a task in a group, it should be created using
    TaskGroup.spawn() or explicitly added using TaskGroup.add_task().

    wait must be one of the builtins all, any or object:
      all    -- join() waits for every task, but stops early and
                re-raises if a task fails with an exception;
      any    -- join() waits for the first task to finish, re-raising
                its exception if it failed;
      object -- join() waits for the first task that is cancelled or
                returns a non-None result, but stops early and
                re-raises if a task fails with an exception.
    When join() finishes, any tasks still running are cancelled.

    completed attribute: the first task that completed with a result
    in the group (not cancelled, no exception and, for wait=object, a
    non-None result).  Tasks that are already finished when added are
    taken into account too.
    '''

    def __init__(self, tasks=(), *, wait=all):
        if wait not in (any, all, object):
            raise ValueError('invalid wait argument')
        # Finished tasks not yet handed out by next_done(), oldest first.
        self._done = deque()
        # Tasks that have not finished yet.
        self._pending = set()
        self._wait = wait
        # Set whenever a task is appended to _done; next_done() waits on it.
        self._done_event = Event()
        self._logger = logging.getLogger(self.__class__.__name__)
        self._closed = False
        self.completed = None
        for task in tasks:
            self._add_task(task)

    def _add_task(self, task):
        '''Add an already existing task to the task group.

        Raises RuntimeError if the task already belongs to a group or if
        this group is closed.
        '''
        if hasattr(task, '_task_group'):
            raise RuntimeError('task is already part of a group')
        if self._closed:
            raise RuntimeError('task group is closed')
        task._task_group = self
        if task.done():
            # Treat it exactly like a task finishing now, so waiters in
            # next_done() are woken and completed is maintained.
            self._record_done(task)
        else:
            self._pending.add(task)
            task.add_done_callback(self._on_done)

    def _on_done(self, task):
        '''Done-callback for tasks that finish while in the group.'''
        task._task_group = None
        self._pending.remove(task)
        self._record_done(task)

    def _record_done(self, task):
        '''Queue a finished task for next_done(), wake any waiter and
        update the completed attribute.'''
        self._done.append(task)
        self._done_event.set()
        if self.completed is None and self._has_result(task):
            self.completed = task

    def _has_result(self, task):
        '''True if the finished task counts as "completed with a result"
        under this group's wait policy.'''
        if task.cancelled() or task.exception() is not None:
            return False
        return not (self._wait is object and task.result() is None)

    async def spawn(self, coro, *args):
        '''Create a new task that’s part of the group. Returns a Task
        instance.
        '''
        task = await spawn(coro, *args, report_crash=False)
        self._add_task(task)
        return task

    async def add_task(self, task):
        '''Add an already existing task to the task group.'''
        self._add_task(task)

    async def next_done(self):
        '''Returns the next completed task.  Returns None if no more tasks
        remain. A TaskGroup may also be used as an asynchronous iterator.
        '''
        # Loop rather than waiting once: when several coroutines call
        # next_done() concurrently, a single wake-up can be consumed by
        # another caller, leaving _done empty while tasks are pending.
        while not self._done and self._pending:
            self._done_event.clear()
            await self._done_event.wait()
        if self._done:
            return self._done.popleft()
        return None

    async def next_result(self):
        '''Returns the result of the next completed task. If the task failed
        with an exception, that exception is raised. A
        NoRemainingTasksError (a RuntimeError subclass) is raised if this
        is called when no remaining tasks are available.'''
        task = await self.next_done()
        if task is None:
            raise NoRemainingTasksError('no tasks remain')
        return task.result()

    async def join(self):
        '''Wait for tasks in the group to terminate according to the wait
        policy for the group.

        If the join() operation itself is cancelled, all remaining
        tasks in the group are also cancelled.

        If a TaskGroup is used as a context manager, the join() method
        is called on context-exit.

        Once join() returns, no more tasks may be added to the task
        group.  Tasks can be added while join() is running.
        '''
        def errored(task):
            return not task.cancelled() and task.exception() is not None

        try:
            if self._wait in (all, object):
                while True:
                    task = await self.next_done()
                    if task is None:
                        return
                    if errored(task):
                        break
                    if self._wait is object:
                        if task.cancelled() or task.result() is not None:
                            return
            else:  # any
                task = await self.next_done()
                if task is None or not errored(task):
                    return
        finally:
            # Cancel everything but don't wait as cancellation can be ignored and our
            # exception could be e.g. a timeout.
            await self.cancel_remaining(blocking=False)

        if errored(task):
            raise task.exception()

    async def cancel_remaining(self, blocking=True):
        '''Cancel all remaining tasks and close the group.  Wait for them
        to complete if blocking is True.'''
        self._closed = True
        for task in self._pending:
            task.cancel()

        if blocking and self._pending:
            unfinished = self._pending.copy()
            all_done = Event()

            def pop_task(task):
                unfinished.remove(task)
                if not unfinished:
                    all_done.set()

            # Done-callbacks are scheduled by the loop, never run inline,
            # so unfinished is not mutated while we iterate it.
            for task in unfinished:
                task.add_done_callback(pop_task)
            await all_done.wait()
        else:
            # So loop processes cancellations
            await sleep(0)

    def closed(self):
        return self._closed

    def __aiter__(self):
        return self

    async def __anext__(self):
        task = await self.next_done()
        if task is not None:
            return task
        raise StopAsyncIteration

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        if exc_type:
            await self.cancel_remaining()
        else:
            await self.join()
256
257
class TaskTimeout(CancelledError):
    '''Raised when the deadline of the innermost timeout block expires.

    secs holds the timeout, in seconds, that applied to the block.
    '''

    def __init__(self, secs):
        self.secs = secs

    def __str__(self):
        secs = self.args[0]
        return f'task timed out after {secs}s'
265
266
class TimeoutCancellationError(CancelledError):
    '''Raised out of a nested timeout block when the deadline that
    expired belongs to an enclosing timeout rather than to the block
    itself.  The argument is the deadline that fired.'''
269
270
class UncaughtTimeoutError(Exception):
    '''Raised by a timeout block when the deadline that fired is not one
    of the deadlines currently active in the task.'''
273
274
275def _set_new_deadline(task, deadline):
276    def timeout_task():
277        # Unfortunately task.cancel is all we can do with asyncio
278        task.cancel()
279        task._timed_out = deadline
280    task._deadline_handle = task._loop.call_at(deadline, timeout_task)
281
282
def _set_task_deadline(task, deadline):
    '''Push deadline onto task._deadlines, the stack of deadlines of the
    timeout blocks active in task, and make sure the armed cancellation
    handle targets the earliest of them.  Also resets task._timed_out.'''
    stack = getattr(task, '_deadlines', [])
    if not stack:
        _set_new_deadline(task, deadline)
    elif deadline < min(stack):
        # Only the earliest deadline is ever armed; replace the old handle.
        task._deadline_handle.cancel()
        _set_new_deadline(task, deadline)
    stack.append(deadline)
    task._deadlines = stack
    task._timed_out = None
294
295
def _unset_task_deadline(task):
    '''Pop the most recently pushed deadline from task._deadlines and
    re-arm the cancellation handle for the earliest remaining deadline.

    Returns a (timed_out_deadline, uncaught) pair.  timed_out_deadline is
    the value recorded in task._timed_out by the deadline callback (None
    if none has fired since it was last reset).  uncaught is True when
    that value is not among the deadlines active at the time of the call,
    including the one being popped (note: also True when
    timed_out_deadline is None).
    '''
    deadlines = task._deadlines
    timed_out_deadline = task._timed_out
    # Evaluated before the pop so the caller's own deadline counts as active.
    uncaught = timed_out_deadline not in deadlines
    task._deadline_handle.cancel()
    # Assumes strictly nested (LIFO) use: the caller's deadline is the last one pushed.
    deadlines.pop()
    if deadlines:
        _set_new_deadline(task, min(deadlines))
    return timed_out_deadline, uncaught
305
306
class TimeoutAfter(object):
    '''Asynchronous context manager that cancels the current task if the
    enclosed block is still running at a deadline.

    deadline is a number of seconds measured from entry or, if absolute
    is true, a time on the event loop's clock (loop.time()).  When this
    block's own deadline expires, TaskTimeout is raised, or, if ignore
    is true, the cancellation is swallowed.  In both cases the expired
    attribute is set to True.  Each entry computes a fresh deadline, so
    the same object may be entered more than once.
    '''

    def __init__(self, deadline, *, ignore=False, absolute=False):
        # Keep the caller's value separately so re-entering the manager
        # does not compound a relative timeout.
        self._timeout = deadline
        self._deadline = deadline
        self._ignore = ignore
        self._absolute = absolute
        self.expired = False

    async def __aenter__(self):
        task = current_task()
        loop_time = task._loop.time()
        if self._absolute:
            self._deadline = self._timeout
            self._secs = self._timeout - loop_time
        else:
            self._deadline = loop_time + self._timeout
            self._secs = self._timeout
        _set_task_deadline(task, self._deadline)
        self.expired = False
        self._task = task
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        timed_out_deadline, uncaught = _unset_task_deadline(self._task)
        # Only cancellation-type exceptions can result from a deadline;
        # a clean exit or any other exception propagates unchanged.
        if exc_type not in (CancelledError, TaskTimeout,
                            TimeoutCancellationError):
            return False
        if timed_out_deadline == self._deadline:
            # Our own deadline fired.
            self.expired = True
            # Python 3.11+ counts cancel() requests per task.  The request
            # made by our deadline is consumed here, so withdraw it (as
            # asyncio.timeout() does); otherwise a stale count confuses
            # asyncio.TaskGroup / asyncio.timeout later in this task.
            uncancel = getattr(self._task, 'uncancel', None)
            if uncancel is not None:
                uncancel()
            if self._ignore:
                return True
            raise TaskTimeout(self._secs) from None
        if timed_out_deadline is None:
            # Cancellation unrelated to any deadline.
            return False
        if uncaught:
            # The deadline that fired belongs to no active timeout block.
            raise UncaughtTimeoutError('uncaught timeout received')
        if exc_type is TimeoutCancellationError:
            return False
        # An enclosing block's deadline fired; let that block handle it.
        raise TimeoutCancellationError(timed_out_deadline) from None
345
346
async def _timeout_after_func(seconds, absolute, coro, args):
    '''Await the coroutine built from coro and args inside a TimeoutAfter
    scope (relative seconds, or absolute loop time when absolute is
    true) and return its result.'''
    awaitable = instantiate_coroutine(coro, args)
    scope = TimeoutAfter(seconds, absolute=absolute)
    async with scope:
        return await awaitable
351
352
def timeout_after(seconds, coro=None, *args):
    '''Execute the specified coroutine and return its result. However,
    issue a cancellation request to the calling task after seconds
    have elapsed.  When this happens, a TaskTimeout exception is
    raised.  If coro is None, the result of this function serves
    as an asynchronous context manager that applies a timeout to a
    block of statements.  When coro is given, the return value is an
    awaitable and must be awaited.

    timeout_after() may be composed with other timeout_after()
    operations (i.e., nested timeouts).  If an outer timeout expires
    first, then TimeoutCancellationError is raised instead of
    TaskTimeout.  If an inner timeout expires but the resulting
    cancellation is not handled within the inner timeout block, an
    UncaughtTimeoutError is raised in the outer timeout.
    '''
    if coro:
        return _timeout_after_func(seconds, False, coro, args)

    return TimeoutAfter(seconds)
373
374
def timeout_at(clock, coro=None, *args):
    '''Execute the specified coroutine and return its result. However,
    issue a cancellation request to the calling task once the event
    loop's clock (loop.time()) reaches the absolute time clock.  When
    this happens, a TaskTimeout exception is raised.  If coro is None,
    the result of this function serves as an asynchronous context
    manager that applies the deadline to a block of statements.  When
    coro is given, the return value is an awaitable and must be
    awaited.

    timeout_at() may be composed with other timeout operations (i.e.,
    nested timeouts).  If an outer timeout expires first, then
    TimeoutCancellationError is raised instead of TaskTimeout.  If an
    inner timeout expires but the resulting cancellation is not handled
    within the inner timeout block, an UncaughtTimeoutError is raised
    in the outer timeout.
    '''
    if coro:
        return _timeout_after_func(clock, True, coro, args)

    return TimeoutAfter(clock, absolute=True)
395
396
async def _ignore_after_func(seconds, absolute, coro, args, timeout_result):
    '''Await the coroutine built from coro and args inside an ignoring
    TimeoutAfter scope.  Returns the coroutine's result, or
    timeout_result if the scope's own deadline expired first.'''
    awaitable = instantiate_coroutine(coro, args)
    scope = TimeoutAfter(seconds, absolute=absolute, ignore=True)
    async with scope:
        return await awaitable
    # Reached only when the scope swallowed its own timeout.
    return timeout_result
403
404
def ignore_after(seconds, coro=None, *args, timeout_result=None):
    '''Run coro with a timeout of seconds, but do not raise on expiry.

    With coro, returns an awaitable that yields the coroutine's result,
    or timeout_result if the timeout expires first (the task is sent a
    cancellation request at that point).

    Without coro, returns an asynchronous context manager applying the
    timeout to a block of statements.  Its expired attribute is set to
    True if time ran out.

    Like the other timeout helpers, ignore_after() can be nested with
    them.  TimeoutCancellationError and UncaughtTimeoutError may be
    raised under the same rules as for timeout_after().
    '''
    if not coro:
        return TimeoutAfter(seconds, ignore=True)
    return _ignore_after_func(seconds, False, coro, args, timeout_result)
425
426
def ignore_at(clock, coro=None, *args, timeout_result=None):
    '''Like ignore_after(), but the deadline is the absolute event-loop
    time clock (as returned by loop.time()) rather than a delay.

    With coro, returns an awaitable that yields the coroutine's result,
    or timeout_result if the deadline passes first.  Without coro,
    returns an asynchronous context manager whose expired attribute is
    set to True if the deadline passed.
    '''
    if not coro:
        return TimeoutAfter(clock, absolute=True, ignore=True)
    return _ignore_after_func(clock, True, coro, args, timeout_result)
436