1# Copyright 2019 John Reese
2# Licensed under the MIT license
3
4"""
5Friendlier version of asyncio standard library.
6
7Provisional library.  Must be imported as `aioitertools.asyncio`.
8"""
9
10import asyncio
11import time
12from typing import Any, Awaitable, Dict, Iterable, List, Optional, Set, Tuple, cast
13
14from .builtins import maybe_await, iter as aiter
15from .types import AsyncIterator, MaybeAwaitable, AnyIterable, T
16
17
async def as_completed(
    aws: Iterable[Awaitable[T]],
    *,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    timeout: Optional[float] = None
) -> AsyncIterator[T]:
    """
    Run awaitables in `aws` concurrently, and yield results as they complete.

    Unlike `asyncio.as_completed`, this yields actual results, and does not require
    awaiting each item in the iterable.

    Every distinct awaitable in `aws` is scheduled up front with
    `asyncio.ensure_future` (on `loop`, when one is given).  `loop` is not passed
    on to `asyncio.wait`, and `asyncio.wait` is only ever handed futures, because
    newer Python releases reject both a `loop` argument and bare coroutines there.

    If `timeout` is a positive number of seconds and it elapses before every
    awaitable has finished, `asyncio.TimeoutError` is raised; awaitables that are
    still pending at that point are left running, not cancelled.  A `timeout` of
    `None`, zero, or a negative number means "no time limit".

    An exception raised by one of the awaitables propagates to the consumer at
    the point where its result would have been yielded.

    Example::

        async for value in as_completed(futures):
            ...  # use value immediately

    """
    done: Set[Awaitable[T]] = set()
    # De-duplicate before scheduling: the same awaitable object must only be
    # wrapped (and therefore awaited) once.
    pending: Set[Awaitable[T]] = {
        asyncio.ensure_future(aw, loop=loop) for aw in set(aws)
    }
    remaining: Optional[float] = None
    deadline: Optional[float] = None

    if timeout and timeout > 0:
        # Monotonic clock: the deadline must not move when the wall clock is
        # adjusted while we are waiting.
        deadline = time.monotonic() + timeout

    while pending:
        if deadline is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise asyncio.TimeoutError()

        # asyncio.wait returns Tuple[Set[Future], Set[Future]]; mypy will not
        # assign that to the existing Set[Awaitable] variables, so cast the
        # result to the type the rest of this function works with.
        done, pending = cast(
            Tuple[Set[Awaitable[T]], Set[Awaitable[T]]],
            await asyncio.wait(
                pending,
                timeout=remaining,
                return_when=asyncio.FIRST_COMPLETED,
            ),
        )

        # `done` is empty when the wait itself timed out; the next loop pass
        # then raises asyncio.TimeoutError from the deadline check above.
        for item in done:
            yield await item
70
71
async def gather(
    *args: Awaitable[T],
    loop: Optional[asyncio.AbstractEventLoop] = None,
    return_exceptions: bool = False,
    limit: int = -1
) -> List[Any]:
    """
    Like asyncio.gather but with a limit on concurrency.

    At most `limit` of the given awaitables are in flight at any time; the
    default of -1 means "no limit".  Results are returned in the order of
    `args`.  An awaitable that is passed more than once is only run once, and
    its result is copied to every position where it appeared.

    Note that all results are buffered.

    If `return_exceptions` is true, an exception raised by an awaitable is stored
    in the result list in place of a value; otherwise the first such exception
    seen propagates out of gather.

    If gather is cancelled all tasks that were internally created and still pending
    will be cancelled as well.

    `loop` is only used when scheduling the awaitables with
    `asyncio.ensure_future`; it is not passed to `asyncio.wait` or
    `asyncio.gather`, which no longer accept it on newer Python releases.

    Raises `ValueError` when there are awaitables to run but `limit` is neither
    -1 nor a positive number, since no awaitable could ever be started.

    Example::

        futures = [some_coro(i) for i in range(10)]

        results = await gather(*futures, limit=2)
    """
    total = len(args)

    # With such a limit the scheduling loop below could never start anything
    # and would spin forever without yielding to the event loop.
    if total and limit != -1 and limit < 1:
        raise ValueError("limit must be -1 (unlimited) or a positive integer")

    # For detecting input duplicates and reconciling them at the end
    input_map: Dict[Awaitable[T], List[int]] = {}
    # This is keyed on what we'll get back from asyncio.wait
    pos: Dict[asyncio.Future[T], int] = {}
    ret: List[Any] = [None] * total

    pending: Set[asyncio.Future[T]] = set()
    done: Set[asyncio.Future[T]] = set()

    next_arg = 0

    while True:
        while next_arg < total and (limit == -1 or len(pending) < limit):
            # We have to defer the creation of the Task as long as possible
            # because once we do, it starts executing, regardless of what we
            # have in the pending set.
            aw = args[next_arg]
            if aw in input_map:
                input_map[aw].append(next_arg)
            else:
                # We call ensure_future directly to ensure that we have a Task
                # because the return value of asyncio.wait will be an implicit
                # task otherwise, and we won't be able to know which input it
                # corresponds to.
                task: asyncio.Future[T] = asyncio.ensure_future(aw, loop=loop)
                pending.add(task)
                pos[task] = next_arg
                input_map[aw] = [next_arg]
            next_arg += 1

        # pending might be empty if the last items of args were dupes;
        # asyncio.wait([]) will raise an exception.
        if pending:
            try:
                done, pending = await asyncio.wait(
                    pending, return_when=asyncio.FIRST_COMPLETED
                )
                for x in done:
                    # Note: for a cancelled child, exception()/result() raise
                    # CancelledError, which is handled by the branch below.
                    exc = x.exception() if return_exceptions else None
                    if exc:
                        ret[pos[x]] = exc
                    else:
                        ret[pos[x]] = x.result()
            except asyncio.CancelledError:
                # Since we created these tasks we should cancel them
                for x in pending:
                    x.cancel()
                # we ensure that all tasks are cancelled before we raise
                await asyncio.gather(*pending, return_exceptions=True)
                raise

        if not pending and next_arg == total:
            break

    # Copy the result of each awaitable to the positions of its duplicates.
    for lst in input_map.values():
        for i in lst[1:]:
            ret[i] = ret[lst[0]]

    return ret
150
151
async def gather_iter(
    itr: AnyIterable[MaybeAwaitable[T]],
    loop: Optional[asyncio.AbstractEventLoop] = None,
    return_exceptions: bool = False,
    limit: int = -1,
) -> List[T]:
    """
    Gather the items of an iterable, rather than items passed as *args.

    `itr` is consumed completely before anything is gathered: each of its items
    is passed through `maybe_await`, and the collected results are handed to
    `gather` as positional arguments.  `loop`, `return_exceptions` and `limit`
    are forwarded to `gather` unchanged.

    Note that the iterable values don't have to be awaitable.
    """
    wrapped = []
    async for element in aiter(itr):
        wrapped.append(maybe_await(element))

    results = await gather(
        *wrapped,
        loop=loop,
        return_exceptions=return_exceptions,
        limit=limit,
    )
    return results
169