# Copyright 2021 The Duet Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from typing import (
    AsyncIterable,
    AsyncIterator,
    Deque,
    Generic,
    Iterable,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import duet.futuretools as futuretools

T = TypeVar("T")

AnyIterable = Union[Iterable[T], AsyncIterable[T]]


async def aenumerate(iterable: AnyIterable[T], start: int = 0) -> AsyncIterator[Tuple[int, T]]:
    """Asynchronously yield ``(index, value)`` pairs, like the builtin ``enumerate``.

    Args:
        iterable: A sync or async iterable whose values are enumerated.
        start: The index paired with the first value.

    Yields:
        ``(index, value)`` tuples, with index counting up by one from ``start``.
    """
    i = start
    async for value in aiter(iterable):
        yield (i, value)
        i += 1


async def aiter(iterable: AnyIterable[T]) -> AsyncIterator[T]:
    """Adapt a sync or async iterable into an async iterator.

    Note: this module-level name shadows the ``aiter`` builtin added in
    Python 3.10. Unlike the builtin, it also accepts plain (sync) iterables.

    Args:
        iterable: Either a regular iterable or an async iterable.

    Yields:
        Each value produced by ``iterable``, in order.
    """
    # The sync protocol is checked first, so an object implementing both
    # protocols is consumed synchronously.
    if isinstance(iterable, Iterable):
        for value in iterable:
            yield value
    else:
        async for value in iterable:
            yield value


async def azip(*iterables: AnyIterable) -> AsyncIterator[Tuple]:
    """Asynchronously zip sync and/or async iterables, like the builtin ``zip``.

    Iteration stops as soon as any input is exhausted. Inputs earlier in the
    argument list may already have been advanced by one item at that point,
    just as with ``zip``.

    Args:
        *iterables: Any mix of sync and async iterables.

    Yields:
        Tuples holding one value from each input, in argument order.
    """
    # With no inputs the loop below would yield ``()`` forever; the builtin
    # ``zip()`` yields nothing in that case, so do the same.
    if not iterables:
        return
    iters = [aiter(iterable) for iterable in iterables]
    while True:
        values = []
        for it in iters:
            try:
                value = await it.__anext__()
            except StopAsyncIteration:
                return
            values.append(value)
        yield tuple(values)


class AsyncCollector(Generic[T]):
    """Allows async iteration over values dynamically added by the client.

    This class is useful for creating an asynchronous iterator that is "fed" by
    one process (the "producer") and iterated over by another process (the
    "consumer"). The producer calls `.add` repeatedly to add values to be
    iterated over, and then calls either `.done` or `.error` to stop the
    iteration or raise an error, respectively. The consumer can use `async for`
    or direct calls to `__anext__` to iterate over the produced values.

    Values already buffered are always delivered before the end of iteration
    (or the error) is reported.

    Only one pending waiter is tracked at a time, so a single consumer is
    assumed. With concurrent `__anext__` calls, a consumer blocked in an
    earlier call may never be woken.
    """

    def __init__(self):
        # Values added by the producer that the consumer has not taken yet.
        self._buffer: Deque[T] = collections.deque()
        # Future the consumer awaits while the buffer is empty. add/done/error
        # resolve it to wake the consumer.
        self._waiter: Optional[futuretools.AwaitableFuture[None]] = None
        # Set by done() or error(); no further values may be added afterwards.
        self._done: bool = False
        # Exception passed to error(), re-raised once the buffer drains.
        self._error: Optional[Exception] = None

    def add(self, value: T) -> None:
        """Append a value for the consumer and wake it if it is waiting.

        Raises:
            RuntimeError: If `done` or `error` was already called.
        """
        if self._done:
            raise RuntimeError("already done.")
        self._buffer.append(value)
        self._wake()

    def done(self) -> None:
        """Mark the end of iteration once buffered values have been consumed.

        Raises:
            RuntimeError: If `done` or `error` was already called.
        """
        if self._done:
            raise RuntimeError("already done.")
        self._done = True
        self._wake()

    def error(self, error: Exception) -> None:
        """End iteration by raising `error` once buffered values are consumed.

        Raises:
            RuntimeError: If `done` or `error` was already called.
        """
        if self._done:
            raise RuntimeError("already done.")
        self._done = True
        self._error = error
        self._wake()

    def _wake(self) -> None:
        """Resolve the pending waiter, if any, so a blocked consumer resumes."""
        if self._waiter is not None:
            self._waiter.try_set_result(None)

    def __aiter__(self) -> AsyncIterator[T]:
        return self

    async def __anext__(self) -> T:
        """Return the next value, waiting for the producer if none is buffered.

        Raises:
            StopAsyncIteration: After `done`, once the buffer is empty.
            Exception: The exception given to `error`, once the buffer is empty.
        """
        if not self._done and not self._buffer:
            self._waiter = futuretools.AwaitableFuture()
            try:
                await self._waiter
            finally:
                # Clear the waiter even if this await is cancelled, so no
                # stale future is left behind.
                self._waiter = None
        if self._buffer:
            return self._buffer.popleft()
        if self._error is not None:
            raise self._error
        raise StopAsyncIteration()