1# Copyright 2021 The Duet Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import collections
16from typing import (
17    AsyncIterable,
18    AsyncIterator,
19    Deque,
20    Generic,
21    Iterable,
22    Optional,
23    Tuple,
24    TypeVar,
25    Union,
26)
27
28import duet.futuretools as futuretools
29
# Element type shared by the iteration helpers in this module.
T = TypeVar("T")

# An iterable that may be either synchronous or asynchronous. `aiter`,
# `aenumerate`, and `azip` accept either kind via this alias.
AnyIterable = Union[Iterable[T], AsyncIterable[T]]
33
34
async def aenumerate(iterable: AnyIterable[T], start: int = 0) -> AsyncIterator[Tuple[int, T]]:
    """Asynchronously enumerate a sync or async iterable.

    This is the async counterpart of the builtin `enumerate`.

    Args:
        iterable: The values to enumerate; may be sync or async.
        start: Index paired with the first value.

    Yields:
        `(index, value)` tuples, with the index counting up by one from `start`.
    """
    counter = start
    source = aiter(iterable)
    async for element in source:
        yield counter, element
        # Advance only after the consumer resumes us, exactly like `enumerate`.
        counter += 1
40
41
async def aiter(iterable: AnyIterable[T]) -> AsyncIterator[T]:
    """Adapt a sync or async iterable into an async iterator.

    Objects that pass `isinstance(iterable, Iterable)` are iterated
    synchronously, even if they are also async-iterable. Anything else is
    consumed with `async for`.

    NOTE: this shadows the builtin `aiter` (Python 3.10+), which only accepts
    async iterables; this version also accepts ordinary iterables.

    Args:
        iterable: The sync or async iterable to adapt.

    Yields:
        Each value produced by `iterable`, in order.
    """
    if not isinstance(iterable, Iterable):
        async for item in iterable:
            yield item
        return
    for item in iterable:
        yield item
49
50
async def azip(*iterables: AnyIterable) -> AsyncIterator[Tuple]:
    """Asynchronously zip sync and/or async iterables, like the builtin `zip`.

    Each round pulls one value from every input, in argument order, and yields
    them together as a tuple. Iteration stops as soon as any input is
    exhausted. Values already pulled from earlier inputs in that final round
    are discarded, as with the builtin `zip`.

    Args:
        *iterables: Any mix of sync and async iterables.

    Yields:
        Tuples with one element per input iterable. With no inputs, nothing
        is yielded.
    """
    if not iterables:
        # With no inputs there is nothing that can ever be exhausted, so the
        # loop below would yield () forever. Builtin zip() is empty here.
        return
    iters = [aiter(iterable) for iterable in iterables]
    while True:
        values = []
        for it in iters:
            try:
                value = await it.__anext__()
            except StopAsyncIteration:
                return
            values.append(value)
        yield tuple(values)
62
63
class AsyncCollector(Generic[T]):
    """Allows async iteration over values dynamically added by the client.

    This class is useful for creating an asynchronous iterator that is "fed" by
    one process (the "producer") and iterated over by another process (the
    "consumer"). The producer calls `.add` repeatedly to add values to be
    iterated over, and then calls either `.done` or `.error` to stop the
    iteration or raise an error, respectively. The consumer can use `async for`
    or direct calls to `__anext__` to iterate over the produced values.

    Values added before `.done` or `.error` are always delivered to the
    consumer before iteration stops or the error is raised.

    Only a single waiting consumer is supported. A blocked `__anext__` stores
    its future in one slot, and `.add`/`.done`/`.error` resolve only that
    slot. A second concurrent waiter would therefore replace the first, which
    would then not be woken.
    """

    def __init__(self):
        # Values added by the producer but not yet returned to the consumer (FIFO).
        self._buffer: Deque[T] = collections.deque()
        # Future the consumer is awaiting while the buffer is empty; None otherwise.
        self._waiter: Optional[futuretools.AwaitableFuture[None]] = None
        # Set by `done` or `error`; no more values may be added afterwards.
        self._done: bool = False
        # Exception passed to `error`, raised once the buffer is drained.
        self._error: Optional[Exception] = None

    def add(self, value: T) -> None:
        """Queue `value` to be yielded to the consumer.

        Raises:
            RuntimeError: If `done` or `error` has already been called.
        """
        self._check_not_done()
        self._buffer.append(value)
        self._wake_waiter()

    def done(self) -> None:
        """Finish iteration once all previously added values are consumed.

        Raises:
            RuntimeError: If `done` or `error` has already been called.
        """
        self._check_not_done()
        self._done = True
        self._wake_waiter()

    def error(self, error: Exception) -> None:
        """Raise `error` in the consumer once previously added values are consumed.

        Raises:
            RuntimeError: If `done` or `error` has already been called.
        """
        self._check_not_done()
        self._done = True
        self._error = error
        self._wake_waiter()

    def _check_not_done(self) -> None:
        """Raise RuntimeError if the producer has already finished."""
        if self._done:
            raise RuntimeError("already done.")

    def _wake_waiter(self) -> None:
        """Resolve the future a blocked consumer is awaiting, if there is one."""
        if self._waiter is not None:
            self._waiter.try_set_result(None)

    def __aiter__(self) -> AsyncIterator[T]:
        """Return self; the collector is its own async iterator."""
        return self

    async def __anext__(self) -> T:
        """Return the next value, waiting for the producer if none is buffered.

        Raises:
            Exception: The exception passed to `error`, once the buffer is
                drained. The same object is re-raised on every later call.
            StopAsyncIteration: After `done`, once the buffer is drained.
        """
        if not self._done and not self._buffer:
            waiter = futuretools.AwaitableFuture()
            self._waiter = waiter
            try:
                await waiter
            finally:
                # Clear the slot even if the await is interrupted, so producers
                # never resolve a stale future nobody is waiting on.
                self._waiter = None
        if self._buffer:
            return self._buffer.popleft()
        # `is not None` rather than truthiness: an exception type defining
        # __bool__/__len__ could otherwise be silently dropped here.
        if self._error is not None:
            raise self._error
        raise StopAsyncIteration()
116