1from asyncio import CancelledError, Event, Task, ensure_future, wait
2from concurrent.futures import FIRST_COMPLETED
3from inspect import isasyncgen, isawaitable
4from typing import cast, Any, AsyncIterable, Callable, Optional, Set, Type, Union
5from types import TracebackType
6
# Public API of this module.
__all__ = ["MapAsyncIterator"]
8
9
10# noinspection PyAttributeOutsideInit
class MapAsyncIterator:
    """Map an AsyncIterable over a callback function.

    Given an AsyncIterable and a callback function, return an AsyncIterator which
    produces values mapped via calling the callback function.

    If a ``reject_callback`` is given, errors raised by the underlying iterator
    (other than ``StopAsyncIteration`` and ``GeneratorExit``) are passed to it and
    its result is produced instead of the error being propagated. Both callbacks
    may return awaitables, which are awaited before the value is produced.

    When the resulting AsyncIterator is closed, the underlying AsyncIterable will also
    be closed.
    """

    def __init__(
        self,
        iterable: AsyncIterable,
        callback: Callable,
        reject_callback: Optional[Callable] = None,
    ) -> None:
        """Wrap ``iterable`` so that each item is passed through ``callback``.

        :param iterable: the source of values; its ``__aiter__`` is called here.
        :param callback: called with every value produced by the source.
        :param reject_callback: optional; called with errors raised by the source.
        """
        self.iterator = iterable.__aiter__()
        self.callback = callback
        self.reject_callback = reject_callback
        # Set when the iterator is closed; also awaited by ``__anext__`` so that a
        # close can interrupt a pending fetch from the underlying iterator.
        self._close_event = Event()

    def __aiter__(self) -> "MapAsyncIterator":
        """Get the iterator object."""
        return self

    async def __anext__(self) -> Any:
        """Get the next value of the iterator.

        :raises StopAsyncIteration: when the source is exhausted or this iterator
            has been closed.
        """
        if self.is_closed:
            # A closed plain iterator just stops; an async generator is still
            # asked for its next value so that it can report its own state.
            if not isasyncgen(self.iterator):
                raise StopAsyncIteration
            value = await self.iterator.__anext__()
            result = self.callback(value)

        else:
            # Race the next underlying value against a close request.
            aclose = ensure_future(self._close_event.wait())
            anext = ensure_future(self.iterator.__anext__())

            try:
                pending: Set[Task] = (
                    await wait([aclose, anext], return_when=FIRST_COMPLETED)
                )[1]
            except CancelledError:
                # cancel underlying tasks and close
                aclose.cancel()
                anext.cancel()
                await self.aclose()
                raise  # re-raise the cancellation

            for task in pending:
                task.cancel()

            # A close request wins, even if a value also became available.
            if aclose.done():
                raise StopAsyncIteration

            error = anext.exception()
            if error:
                # End-of-iteration and generator shutdown are never "rejected".
                if not self.reject_callback or isinstance(
                    error, (StopAsyncIteration, GeneratorExit)
                ):
                    raise error
                result = self.reject_callback(error)
            else:
                value = anext.result()
                result = self.callback(value)

        return await result if isawaitable(result) else result

    async def athrow(
        self,
        type_: Union[BaseException, Type[BaseException]],
        value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Throw an exception into the asynchronous iterator.

        If the underlying iterator provides ``athrow`` (e.g. an async generator),
        the exception is forwarded to it. Otherwise this iterator is closed and the
        exception is raised directly. Nothing happens if the iterator is closed.

        :param type_: an exception class or an exception instance.
        :param value: optional exception instance (legacy three-argument form).
        :param traceback: optional traceback to attach to the raised exception.
        """
        if self.is_closed:
            return
        athrow = getattr(self.iterator, "athrow", None)
        if athrow:
            await athrow(type_, value, traceback)
            return
        await self.aclose()
        if value is None:
            if traceback is None:
                raise type_
            # ``type_`` may already be an exception instance; only instantiate it
            # when it is an exception class.
            value = (
                type_
                if isinstance(type_, BaseException)
                else cast(Type[BaseException], type_)()
            )
        if traceback is not None:
            value = value.with_traceback(traceback)
        raise value

    async def aclose(self) -> None:
        """Close the iterator.

        The underlying iterator is closed too when it provides ``aclose``. A
        ``RuntimeError`` raised while doing so is ignored (best effort), so this
        iterator is always marked as closed afterwards.
        """
        if not self.is_closed:
            aclose = getattr(self.iterator, "aclose", None)
            if aclose:
                try:
                    await aclose()
                except RuntimeError:
                    pass
            self.is_closed = True

    @property
    def is_closed(self) -> bool:
        """Check whether the iterator is closed."""
        return self._close_event.is_set()

    @is_closed.setter
    def is_closed(self, value: bool) -> None:
        """Mark the iterator as closed (``True``) or open again (``False``)."""
        if value:
            # Setting the event also wakes any ``__anext__`` waiting on it.
            self._close_event.set()
        else:
            self._close_event.clear()
126