1import asyncio
2import collections
3import warnings
4from typing import Awaitable, Callable, Generic, List, Optional, Tuple, TypeVar
5
6from .base_protocol import BaseProtocol
7from .helpers import BaseTimerContext, set_exception, set_result
8from .log import internal_logger
9
10try:  # pragma: no cover
11    from typing import Deque
12except ImportError:
13    from typing_extensions import Deque
14
# Names that make up this module's public API.
__all__ = (
    "EMPTY_PAYLOAD",
    "EofStream",
    "StreamReader",
    "DataQueue",
    "FlowControlDataQueue",
)

# Type of the items carried by DataQueue / AsyncStreamIterator.
_T = TypeVar("_T")
24
25
class EofStream(Exception):
    """Raised internally to signal that the end of the stream was reached."""
28
29
class AsyncStreamIterator(Generic[_T]):
    """Adapt a zero-argument read coroutine into an async iterator.

    Iteration terminates when the read function raises EofStream or
    returns ``b""`` (the empty-bytes end-of-stream marker).
    """

    def __init__(self, read_func: Callable[[], Awaitable[_T]]) -> None:
        self.read_func = read_func

    def __aiter__(self) -> "AsyncStreamIterator[_T]":
        return self

    async def __anext__(self) -> _T:
        try:
            value = await self.read_func()
        except EofStream:
            raise StopAsyncIteration
        if value == b"":
            raise StopAsyncIteration
        return value
45
46
class ChunkTupleAsyncStreamIterator:
    """Async iterator over ``(data, end_of_http_chunk)`` tuples.

    Yields whatever StreamReader.readchunk() produces and stops when the
    ``(b"", False)`` end-of-stream sentinel is seen.
    """

    def __init__(self, stream: "StreamReader") -> None:
        self._stream = stream

    def __aiter__(self) -> "ChunkTupleAsyncStreamIterator":
        return self

    async def __anext__(self) -> Tuple[bytes, bool]:
        chunk = await self._stream.readchunk()
        if chunk == (b"", False):
            raise StopAsyncIteration
        return chunk
59
60
class AsyncStreamReaderMixin:
    """Mixin that equips a stream reader with ``async for`` support."""

    def __aiter__(self) -> AsyncStreamIterator[bytes]:
        # Default iteration goes line by line.
        return AsyncStreamIterator(self.readline)  # type: ignore

    def iter_chunked(self, n: int) -> AsyncStreamIterator[bytes]:
        """Return an async iterator yielding pieces of at most *n* bytes."""
        return AsyncStreamIterator(lambda: self.read(n))  # type: ignore

    def iter_any(self) -> AsyncStreamIterator[bytes]:
        """Return an async iterator yielding data as soon as it is received."""
        return AsyncStreamIterator(self.readany)  # type: ignore

    def iter_chunks(self) -> ChunkTupleAsyncStreamIterator:
        """Return an async iterator over ``(bytes, bool)`` tuples.

        Each tuple is produced by StreamReader.readchunk(); the bool flags
        whether the data ends on an HTTP chunk boundary.
        """
        return ChunkTupleAsyncStreamIterator(self)  # type: ignore
88
89
class StreamReader(AsyncStreamReaderMixin):
    """An enhancement of asyncio.StreamReader.

    Supports asynchronous iteration by line, chunk or as available::

        async for line in reader:
            ...
        async for chunk in reader.iter_chunked(1024):
            ...
        async for slice in reader.iter_any():
            ...

    """

    # Total bytes ever fed via feed_data(); monotonically increasing,
    # never decremented when data is consumed.
    total_bytes = 0

    def __init__(
        self,
        protocol: BaseProtocol,
        limit: int,
        *,
        timer: Optional[BaseTimerContext] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None
    ) -> None:
        self._protocol = protocol
        # Flow-control watermarks: reading is paused above _high_water
        # and resumed once buffered size drops below _low_water.
        self._low_water = limit
        self._high_water = limit * 2
        if loop is None:
            loop = asyncio.get_event_loop()
        self._loop = loop
        self._size = 0  # bytes currently buffered and not yet read
        self._cursor = 0  # logical offset of data already handed out
        self._http_chunk_splits = None  # type: Optional[List[int]]
        self._buffer = collections.deque()  # type: Deque[bytes]
        self._buffer_offset = 0  # consumed prefix length of _buffer[0]
        self._eof = False
        self._waiter = None  # type: Optional[asyncio.Future[None]]
        self._eof_waiter = None  # type: Optional[asyncio.Future[None]]
        self._exception = None  # type: Optional[BaseException]
        self._timer = timer
        self._eof_callbacks = []  # type: List[Callable[[], None]]

    def __repr__(self) -> str:
        info = [self.__class__.__name__]
        if self._size:
            info.append("%d bytes" % self._size)
        if self._eof:
            info.append("eof")
        if self._low_water != 2 ** 16:  # default limit
            info.append("low=%d high=%d" % (self._low_water, self._high_water))
        if self._waiter:
            info.append("w=%r" % self._waiter)
        if self._exception:
            info.append("e=%r" % self._exception)
        return "<%s>" % " ".join(info)

    def get_read_buffer_limits(self) -> Tuple[int, int]:
        """Return the (low, high) flow-control watermarks in bytes."""
        return (self._low_water, self._high_water)

    def exception(self) -> Optional[BaseException]:
        """Return the stored exception, if any."""
        return self._exception

    def set_exception(self, exc: BaseException) -> None:
        """Store *exc*; wake pending read/eof waiters with it and drop
        any registered EOF callbacks."""
        self._exception = exc
        self._eof_callbacks.clear()

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            set_exception(waiter, exc)

        waiter = self._eof_waiter
        if waiter is not None:
            self._eof_waiter = None
            set_exception(waiter, exc)

    def on_eof(self, callback: Callable[[], None]) -> None:
        """Register *callback* to run at EOF.

        Runs the callback immediately when the stream is already at EOF;
        callback exceptions are logged, never propagated.
        """
        if self._eof:
            try:
                callback()
            except Exception:
                internal_logger.exception("Exception in eof callback")
        else:
            self._eof_callbacks.append(callback)

    def feed_eof(self) -> None:
        """Mark the stream finished, wake all waiters and fire EOF callbacks."""
        self._eof = True

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            set_result(waiter, None)

        waiter = self._eof_waiter
        if waiter is not None:
            self._eof_waiter = None
            set_result(waiter, None)

        for cb in self._eof_callbacks:
            try:
                cb()
            except Exception:
                internal_logger.exception("Exception in eof callback")

        self._eof_callbacks.clear()

    def is_eof(self) -> bool:
        """Return True if 'feed_eof' was called."""
        return self._eof

    def at_eof(self) -> bool:
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    async def wait_eof(self) -> None:
        """Block until feed_eof() is called (return at once if already EOF)."""
        if self._eof:
            return

        assert self._eof_waiter is None
        self._eof_waiter = self._loop.create_future()
        try:
            await self._eof_waiter
        finally:
            self._eof_waiter = None

    def unread_data(self, data: bytes) -> None:
        """rollback reading some data from stream, inserting it to buffer head."""
        warnings.warn(
            "unread_data() is deprecated "
            "and will be removed in future releases (#3260)",
            DeprecationWarning,
            stacklevel=2,
        )
        if not data:
            return

        # Materialize the partially-consumed head chunk before prepending,
        # so _buffer_offset stays valid for the new head.
        if self._buffer_offset:
            self._buffer[0] = self._buffer[0][self._buffer_offset :]
            self._buffer_offset = 0
        self._size += len(data)
        self._cursor -= len(data)
        self._buffer.appendleft(data)
        self._eof_counter = 0

    # TODO: size is ignored, remove the param later
    def feed_data(self, data: bytes, size: int = 0) -> None:
        """Append *data* to the buffer, wake the pending reader (if any)
        and pause the transport once the high watermark is exceeded."""
        assert not self._eof, "feed_data after feed_eof"

        if not data:
            return

        self._size += len(data)
        self._buffer.append(data)
        self.total_bytes += len(data)

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            set_result(waiter, None)

        if self._size > self._high_water and not self._protocol._reading_paused:
            self._protocol.pause_reading()

    def begin_http_chunk_receiving(self) -> None:
        """Start tracking HTTP chunk boundaries (chunked transfer encoding).

        Must be called before any data is fed; raises RuntimeError otherwise.
        """
        if self._http_chunk_splits is None:
            if self.total_bytes:
                raise RuntimeError(
                    "Called begin_http_chunk_receiving when "
                    "some data was already fed"
                )
            self._http_chunk_splits = []

    def end_http_chunk_receiving(self) -> None:
        """Record the end of the current HTTP chunk and wake readchunk()."""
        if self._http_chunk_splits is None:
            raise RuntimeError(
                "Called end_chunk_receiving without calling "
                "begin_chunk_receiving first"
            )

        # self._http_chunk_splits contains logical byte offsets from start of
        # the body transfer. Each offset is the offset of the end of a chunk.
        # "Logical" means bytes, accessible for a user.
        # If no chunks containing logical data were received, current position
        # is definitely zero.
        pos = self._http_chunk_splits[-1] if self._http_chunk_splits else 0

        if self.total_bytes == pos:
            # We should not add empty chunks here. So we check for that.
            # Note, when chunked + gzip is used, we can receive a chunk
            # of compressed data, but that data may not be enough for gzip FSM
            # to yield any uncompressed data. That's why current position may
            # not change after receiving a chunk.
            return

        self._http_chunk_splits.append(self.total_bytes)

        # wake up readchunk when end of http chunk received
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            set_result(waiter, None)

    async def _wait(self, func_name: str) -> None:
        """Suspend until feed_data()/feed_eof() wakes us up.

        Only one read coroutine may wait at a time.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not possible to know
        # which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError(
                "%s() called while another coroutine is "
                "already waiting for incoming data" % func_name
            )

        waiter = self._waiter = self._loop.create_future()
        try:
            if self._timer:
                with self._timer:
                    await waiter
            else:
                await waiter
        finally:
            self._waiter = None

    async def readline(self) -> bytes:
        """Read one line, where "line" is a sequence of bytes ending with
        ``b"\\n"`` (or everything remaining at EOF).

        Raises ValueError when the line exceeds the high watermark.
        """
        if self._exception is not None:
            raise self._exception

        line = []
        line_size = 0
        not_enough = True

        while not_enough:
            while self._buffer and not_enough:
                offset = self._buffer_offset
                ichar = self._buffer[0].find(b"\n", offset) + 1
                # Read from current offset to found b'\n' or to the end.
                data = self._read_nowait_chunk(ichar - offset if ichar else -1)
                line.append(data)
                line_size += len(data)
                if ichar:
                    not_enough = False

                if line_size > self._high_water:
                    raise ValueError("Line is too long")

            if self._eof:
                break

            if not_enough:
                await self._wait("readline")

        return b"".join(line)

    async def read(self, n: int = -1) -> bytes:
        """Read up to *n* bytes; with n < 0 read everything until EOF."""
        if self._exception is not None:
            raise self._exception

        # migration problem; with DataQueue you have to catch
        # EofStream exception, so common way is to run payload.read() inside
        # infinite loop. what can cause real infinite loop with StreamReader
        # lets keep this code one major release.
        if __debug__:
            if self._eof and not self._buffer:
                self._eof_counter = getattr(self, "_eof_counter", 0) + 1
                if self._eof_counter > 5:
                    internal_logger.warning(
                        "Multiple access to StreamReader in eof state, "
                        "might be infinite loop.",
                        stack_info=True,
                    )

        if not n:
            return b""

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.readany() until EOF.
            blocks = []
            while True:
                block = await self.readany()
                if not block:
                    break
                blocks.append(block)
            return b"".join(blocks)

        # TODO: should be `if` instead of `while`
        # because waiter maybe triggered on chunk end,
        # without feeding any data
        while not self._buffer and not self._eof:
            await self._wait("read")

        return self._read_nowait(n)

    async def readany(self) -> bytes:
        """Read whatever data is buffered, waiting for some if necessary."""
        if self._exception is not None:
            raise self._exception

        # TODO: should be `if` instead of `while`
        # because waiter maybe triggered on chunk end,
        # without feeding any data
        while not self._buffer and not self._eof:
            await self._wait("readany")

        return self._read_nowait(-1)

    async def readchunk(self) -> Tuple[bytes, bool]:
        """Returns a tuple of (data, end_of_http_chunk). When chunked transfer
        encoding is used, end_of_http_chunk is a boolean indicating if the end
        of the data corresponds to the end of a HTTP chunk , otherwise it is
        always False.
        """
        while True:
            if self._exception is not None:
                raise self._exception

            while self._http_chunk_splits:
                pos = self._http_chunk_splits.pop(0)
                if pos == self._cursor:
                    return (b"", True)
                if pos > self._cursor:
                    return (self._read_nowait(pos - self._cursor), True)
                internal_logger.warning(
                    "Skipping HTTP chunk end due to data "
                    "consumption beyond chunk boundary"
                )

            if self._buffer:
                return (self._read_nowait_chunk(-1), False)
                # return (self._read_nowait(-1), False)

            if self._eof:
                # Special case for signifying EOF.
                # (b'', True) is not a final return value actually.
                return (b"", False)

            await self._wait("readchunk")

    async def readexactly(self, n: int) -> bytes:
        """Read exactly *n* bytes.

        Raises asyncio.IncompleteReadError when EOF arrives first.
        """
        if self._exception is not None:
            raise self._exception

        blocks = []  # type: List[bytes]
        while n > 0:
            block = await self.read(n)
            if not block:
                partial = b"".join(blocks)
                raise asyncio.IncompleteReadError(partial, len(partial) + n)
            blocks.append(block)
            n -= len(block)

        return b"".join(blocks)

    def read_nowait(self, n: int = -1) -> bytes:
        """Synchronously read up to *n* buffered bytes (all with n == -1)."""
        # default was changed to be consistent with .read(-1)
        #
        # I believe the most users don't know about the method and
        # they are not affected.
        if self._exception is not None:
            raise self._exception

        if self._waiter and not self._waiter.done():
            raise RuntimeError(
                "Called while some coroutine is waiting for incoming data."
            )

        return self._read_nowait(n)

    def _read_nowait_chunk(self, n: int) -> bytes:
        """Pop up to *n* bytes from the head buffer chunk (all with n == -1),
        update counters and resume the transport below the low watermark."""
        first_buffer = self._buffer[0]
        offset = self._buffer_offset
        if n != -1 and len(first_buffer) - offset > n:
            data = first_buffer[offset : offset + n]
            self._buffer_offset += n

        elif offset:
            self._buffer.popleft()
            data = first_buffer[offset:]
            self._buffer_offset = 0

        else:
            data = self._buffer.popleft()

        self._size -= len(data)
        self._cursor += len(data)

        chunk_splits = self._http_chunk_splits
        # Prevent memory leak: drop useless chunk splits
        while chunk_splits and chunk_splits[0] < self._cursor:
            chunk_splits.pop(0)

        if self._size < self._low_water and self._protocol._reading_paused:
            self._protocol.resume_reading()
        return data

    def _read_nowait(self, n: int) -> bytes:
        """Read not more than n bytes, or whole buffer if n == -1."""
        chunks = []

        while self._buffer:
            chunk = self._read_nowait_chunk(n)
            chunks.append(chunk)
            if n != -1:
                n -= len(chunk)
                if n == 0:
                    break

        return b"".join(chunks) if chunks else b""
498
499
class EmptyStreamReader(AsyncStreamReaderMixin):
    """A stream reader that is permanently at EOF and yields no data."""

    def exception(self) -> Optional[BaseException]:
        return None

    def set_exception(self, exc: BaseException) -> None:
        # No state to poison; the exception is dropped.
        pass

    def on_eof(self, callback: Callable[[], None]) -> None:
        # Already at EOF, so the callback fires immediately.
        try:
            callback()
        except Exception:
            internal_logger.exception("Exception in eof callback")

    def feed_eof(self) -> None:
        pass

    def is_eof(self) -> bool:
        return True

    def at_eof(self) -> bool:
        return True

    async def wait_eof(self) -> None:
        return

    def feed_data(self, data: bytes, n: int = 0) -> None:
        pass

    async def readline(self) -> bytes:
        return b""

    async def read(self, n: int = -1) -> bytes:
        return b""

    async def readany(self) -> bytes:
        return b""

    async def readchunk(self) -> Tuple[bytes, bool]:
        # NOTE(review): StreamReader signals EOF with (b"", False) and
        # iter_chunks() stops only on that sentinel — confirm (b"", True)
        # is the intended value here.
        return (b"", True)

    async def readexactly(self, n: int) -> bytes:
        raise asyncio.IncompleteReadError(b"", n)

    def read_nowait(self) -> bytes:
        return b""
545
546
# Shared singleton used wherever a payload object is required but the
# message has no body.
EMPTY_PAYLOAD = EmptyStreamReader()
548
549
class DataQueue(Generic[_T]):
    """DataQueue is a general-purpose blocking queue with one reader."""

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._eof = False
        self._waiter = None  # type: Optional[asyncio.Future[None]]
        self._exception = None  # type: Optional[BaseException]
        self._size = 0
        self._buffer = collections.deque()  # type: Deque[Tuple[_T, int]]

    def __len__(self) -> int:
        return len(self._buffer)

    def is_eof(self) -> bool:
        """True once feed_eof() has been called."""
        return self._eof

    def at_eof(self) -> bool:
        """True when EOF was fed and the buffer is fully drained."""
        return self._eof and not self._buffer

    def exception(self) -> Optional[BaseException]:
        return self._exception

    def set_exception(self, exc: BaseException) -> None:
        """Mark the queue broken: future reads raise *exc*."""
        self._eof = True
        self._exception = exc

        pending = self._waiter
        if pending is not None:
            self._waiter = None
            set_exception(pending, exc)

    def feed_data(self, data: _T, size: int = 0) -> None:
        """Append one item (accounted as *size* bytes) and wake the reader."""
        self._buffer.append((data, size))
        self._size += size

        pending = self._waiter
        if pending is not None:
            self._waiter = None
            set_result(pending, None)

    def feed_eof(self) -> None:
        """Signal that no more items will arrive."""
        self._eof = True

        pending = self._waiter
        if pending is not None:
            self._waiter = None
            set_result(pending, None)

    async def read(self) -> _T:
        """Return the next item; after EOF raise the stored exception
        or EofStream."""
        if not self._buffer and not self._eof:
            assert not self._waiter
            self._waiter = self._loop.create_future()
            try:
                await self._waiter
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._waiter = None
                raise

        if not self._buffer:
            if self._exception is not None:
                raise self._exception
            raise EofStream

        data, size = self._buffer.popleft()
        self._size -= size
        return data

    def __aiter__(self) -> "AsyncStreamIterator[_T]":
        return AsyncStreamIterator(self.read)
621
622
class FlowControlDataQueue(DataQueue[_T]):
    """FlowControlDataQueue resumes and pauses an underlying stream.

    It is a destination for parsed data."""

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        super().__init__(loop=loop)

        self._protocol = protocol
        # Pause the transport once buffered size exceeds twice the limit.
        self._limit = limit * 2

    def feed_data(self, data: _T, size: int = 0) -> None:
        """Queue parsed data, pausing the protocol when over the limit."""
        super().feed_data(data, size)

        if not self._protocol._reading_paused and self._size > self._limit:
            self._protocol.pause_reading()

    async def read(self) -> _T:
        """Read one item, resuming the protocol once back under the limit."""
        try:
            return await super().read()
        finally:
            if self._protocol._reading_paused and self._size < self._limit:
                self._protocol.resume_reading()
648