from asyncio import CancelledError, Event, Task, ensure_future, wait
from concurrent.futures import FIRST_COMPLETED
from inspect import isasyncgen, isawaitable
from typing import cast, Any, AsyncIterable, Callable, Optional, Set, Type, Union
from types import TracebackType

__all__ = ["MapAsyncIterator"]


# noinspection PyAttributeOutsideInit
class MapAsyncIterator:
    """Map an AsyncIterable over a callback function.

    Given an AsyncIterable and a callback function, return an AsyncIterator which
    produces values mapped via calling the callback function.

    When the resulting AsyncIterator is closed, the underlying AsyncIterable will also
    be closed.
    """

    def __init__(
        self,
        iterable: AsyncIterable,
        callback: Callable,
        reject_callback: Optional[Callable] = None,
    ) -> None:
        """Wrap the given iterable.

        Args:
            iterable: The source whose ``__aiter__()`` result is consumed.
            callback: Called with every value obtained from the source; its result
                (awaited if it is awaitable) is what ``__anext__`` returns.
            reject_callback: If given, called with exceptions raised by the source
                (other than ``StopAsyncIteration`` and ``GeneratorExit``); its
                result is returned in place of raising the exception.
        """
        self.iterator = iterable.__aiter__()
        self.callback = callback
        self.reject_callback = reject_callback
        # The closed state is kept in an Event so that a pending ``__anext__`` can
        # wait on it and be woken up as soon as the iterator gets closed.
        self._close_event = Event()

    def __aiter__(self) -> "MapAsyncIterator":
        """Get the iterator object."""
        return self

    async def __anext__(self) -> Any:
        """Get the next value of the iterator.

        Returns:
            The result of the callback for the next source value, or the result of
            the reject callback for an exception raised by the source. Awaitable
            results are awaited before being returned.

        Raises:
            StopAsyncIteration: If the iterator is or gets closed, or when the
                source is exhausted.
            CancelledError: If waiting for the next value was cancelled; the
                iterator is closed before the cancellation is re-raised.
        """
        if self.is_closed:
            if not isasyncgen(self.iterator):
                raise StopAsyncIteration
            # For a native async generator the decision is delegated to the
            # generator itself by asking it for the next value.
            value = await self.iterator.__anext__()
            result = self.callback(value)

        else:
            # Race the next source value against the close event so that closing
            # the iterator also ends a pending ``__anext__`` call.
            aclose = ensure_future(self._close_event.wait())
            anext = ensure_future(self.iterator.__anext__())

            try:
                pending: Set[Task] = (
                    await wait([aclose, anext], return_when=FIRST_COMPLETED)
                )[1]
            except CancelledError:
                # cancel underlying tasks and close
                aclose.cancel()
                anext.cancel()
                await self.aclose()
                raise  # re-raise the cancellation

            # Whichever future lost the race is no longer needed.
            for task in pending:
                task.cancel()

            if aclose.done():
                raise StopAsyncIteration

            error = anext.exception()
            if error:
                # Termination signals are never routed to the reject callback.
                if not self.reject_callback or isinstance(
                    error, (StopAsyncIteration, GeneratorExit)
                ):
                    raise error
                result = self.reject_callback(error)
            else:
                value = anext.result()
                result = self.callback(value)

        return await result if isawaitable(result) else result

    async def athrow(
        self,
        type_: Union[BaseException, Type[BaseException]],
        value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Throw an exception into the asynchronous iterator.

        If the iterator is already closed, this does nothing. If the underlying
        iterator has an ``athrow`` method, the exception is forwarded to it.
        Otherwise the iterator is closed and the exception is raised here.

        Args:
            type_: An exception instance or an exception class.
            value: An optional exception instance to raise.
            traceback: An optional traceback to attach to the raised exception.
        """
        if self.is_closed:
            return

        athrow = getattr(self.iterator, "athrow", None)
        if athrow:
            if isasyncgen(self.iterator):
                # Native async generators accept one to three arguments, and the
                # multi-argument form is deprecated as of Python 3.12, so only
                # the arguments that were actually supplied are forwarded.
                if traceback is not None:
                    await athrow(type_, value, traceback)
                elif value is not None:
                    await athrow(type_, value)
                else:
                    await athrow(type_)
            else:
                # The signature of a custom ``athrow`` is unknown, so keep
                # passing all three arguments positionally.
                await athrow(type_, value, traceback)
            return

        await self.aclose()
        if value is None:
            if traceback is None:
                raise type_
            # ``type_`` may already be an exception instance; only instantiate
            # it when an exception class was passed.
            value = (
                type_
                if isinstance(type_, BaseException)
                else cast(Type[BaseException], type_)()
            )
        if traceback is not None:
            value = value.with_traceback(traceback)
        raise value

    async def aclose(self) -> None:
        """Close the iterator.

        Calls ``aclose()`` on the underlying iterator if it has such a method and
        then marks this iterator as closed. Does nothing if already closed.
        """
        if not self.is_closed:
            aclose = getattr(self.iterator, "aclose", None)
            if aclose:
                try:
                    await aclose()
                except RuntimeError:
                    # Deliberately best-effort: a RuntimeError while closing the
                    # underlying iterator must not prevent marking this one as
                    # closed (presumably raised when an async generator is
                    # still running -- TODO confirm).
                    pass
            self.is_closed = True

    @property
    def is_closed(self) -> bool:
        """Check whether the iterator is closed."""
        return self._close_event.is_set()

    @is_closed.setter
    def is_closed(self, value: bool) -> None:
        """Mark the iterator as closed."""
        if value:
            self._close_event.set()
        else:
            self._close_event.clear()