import inspect
import warnings
from json import dumps as json_dumps
from typing import (
    Any,
    AsyncIterable,
    AsyncIterator,
    Dict,
    Iterable,
    Iterator,
    Optional,
    Tuple,
    Union,
)
from urllib.parse import urlencode

from ._exceptions import StreamClosed, StreamConsumed
from ._multipart import MultipartStream
from ._types import (
    AsyncByteStream,
    RequestContent,
    RequestData,
    RequestFiles,
    ResponseContent,
    SyncByteStream,
)
from ._utils import peek_filelike_length, primitive_value_to_str
27
28
class ByteStream(AsyncByteStream, SyncByteStream):
    """
    In-memory stream over a single ``bytes`` payload.

    Usable from both sync and async code. Either form of iteration yields the
    stored payload as exactly one chunk (even when the payload is empty), and
    every new iteration yields it again.
    """

    def __init__(self, stream: bytes) -> None:
        self._stream = stream

    def __iter__(self) -> Iterator[bytes]:
        yield from (self._stream,)

    async def __aiter__(self) -> AsyncIterator[bytes]:
        for chunk in (self._stream,):
            yield chunk
38
39
class IteratorByteStream(SyncByteStream):
    """
    Wraps a synchronous iterable of byte chunks as a stream.

    File-like objects (anything exposing a ``read`` method) are consumed with
    ``read(CHUNK_SIZE)`` calls instead of being iterated. Iterating a binary
    file yields newline-delimited "lines", so a large file with few newlines
    would otherwise be loaded as one huge chunk, and a newline-dense file
    would be sent as many tiny chunks.

    If the wrapped object is a generator it can only be consumed once; a
    second iteration raises ``StreamConsumed``. Any other iterable is simply
    iterated (or read) again, so whether a second pass produces data depends
    on the object itself.
    """

    # Maximum number of bytes requested per ``read()`` call on file-like input.
    CHUNK_SIZE = 65_536

    def __init__(self, stream: Iterable[bytes]) -> None:
        self._stream = stream
        self._is_stream_consumed = False
        self._is_generator = inspect.isgenerator(stream)

    def __iter__(self) -> Iterator[bytes]:
        if self._is_stream_consumed and self._is_generator:
            raise StreamConsumed()

        self._is_stream_consumed = True
        if hasattr(self._stream, "read"):
            # File-like input: read fixed-size chunks until an empty
            # (or otherwise falsy) result signals end of data.
            read = self._stream.read  # type: ignore
            chunk = read(self.CHUNK_SIZE)
            while chunk:
                yield chunk
                chunk = read(self.CHUNK_SIZE)
        else:
            for part in self._stream:
                yield part
53
54
class AsyncIteratorByteStream(AsyncByteStream):
    """
    Wraps an asynchronous iterable of byte chunks as a stream.

    If the wrapped object is an async generator it can only be consumed once;
    a second ``async for`` raises ``StreamConsumed``. Any other async iterable
    is simply iterated again, so whether a second pass produces data depends
    on the object itself.
    """

    def __init__(self, stream: AsyncIterable[bytes]):
        self._stream = stream
        self._is_stream_consumed = False
        self._is_generator = inspect.isasyncgen(stream)

    async def __aiter__(self) -> AsyncIterator[bytes]:
        single_use = self._is_generator
        if single_use and self._is_stream_consumed:
            raise StreamConsumed()

        self._is_stream_consumed = True
        async for chunk in self._stream:
            yield chunk
68
69
class UnattachedStream(AsyncByteStream, SyncByteStream):
    """
    If a request or response is serialized using pickle, then it is no longer
    attached to a stream for I/O purposes. Any stream operations should result
    in `httpx.StreamClosed`.
    """

    def __iter__(self) -> Iterator[bytes]:
        # Not a generator (there is no `yield`), so StreamClosed is raised as
        # soon as iteration starts, i.e. on the `iter(...)` call itself.
        raise StreamClosed()

    async def __aiter__(self) -> AsyncIterator[bytes]:
        # The unreachable `yield` below turns this into an async generator, so
        # it honours the async-iterator protocol; StreamClosed is then raised
        # on the first `__anext__`, i.e. the first step of an `async for`.
        raise StreamClosed()
        yield b""  # pragma: nocover
83
84
def encode_content(
    content: Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
    """
    Build headers and a stream for raw ``content``.

    * ``str`` / ``bytes``: ``str`` is encoded as UTF-8, and the body is sent
      as a single chunk with a ``Content-Length`` header. An empty body gets
      no headers at all.
    * Sync iterable of bytes: ``Content-Length`` when
      ``peek_filelike_length`` returns a length, otherwise
      ``Transfer-Encoding: chunked``.
    * Async iterable of bytes: always ``Transfer-Encoding: chunked``.

    Returns:
        A ``(headers, stream)`` two-tuple.

    Raises:
        TypeError: if ``content`` is none of the above. A ``dict`` is
            rejected explicitly even though it is iterable.
    """
    if isinstance(content, (bytes, str)):
        body = content.encode("utf-8") if isinstance(content, str) else content
        content_length = len(body)
        headers = {"Content-Length": str(content_length)} if body else {}
        return headers, ByteStream(body)

    # A dict is iterable (over its keys) and is an easy mistake to make,
    # e.g. passing form fields as `content=` instead of `data=`. Letting it
    # through would stream the keys as "chunks" and fail much later with an
    # obscure error, so send it to the TypeError below instead.
    elif isinstance(content, Iterable) and not isinstance(content, dict):
        content_length_or_none = peek_filelike_length(content)

        if content_length_or_none is None:
            headers = {"Transfer-Encoding": "chunked"}
        else:
            headers = {"Content-Length": str(content_length_or_none)}
        return headers, IteratorByteStream(content)  # type: ignore

    elif isinstance(content, AsyncIterable):
        headers = {"Transfer-Encoding": "chunked"}
        return headers, AsyncIteratorByteStream(content)

    raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
109
110
def encode_urlencoded_data(
    data: dict,
) -> Tuple[Dict[str, str], ByteStream]:
    """
    Encode ``data`` as an ``application/x-www-form-urlencoded`` body.

    List and tuple values expand into one ``key=value`` pair per item, in
    order. Every value is passed through ``primitive_value_to_str`` first.

    Returns:
        A ``(headers, stream)`` two-tuple with ``Content-Length`` and
        ``Content-Type`` headers.
    """
    pairs = [
        (key, primitive_value_to_str(item))
        for key, value in data.items()
        for item in (value if isinstance(value, (list, tuple)) else (value,))
    ]
    encoded = urlencode(pairs, doseq=True).encode("utf-8")
    headers = {
        "Content-Length": str(len(encoded)),
        "Content-Type": "application/x-www-form-urlencoded",
    }
    return headers, ByteStream(encoded)
125
126
def encode_multipart_data(
    data: dict, files: RequestFiles, boundary: Optional[bytes] = None
) -> Tuple[Dict[str, str], MultipartStream]:
    """
    Build headers and a multipart stream from form ``data`` and ``files``.

    Args:
        data: Plain form fields sent alongside the files.
        files: File fields to upload.
        boundary: Explicit multipart boundary. ``None`` is forwarded to
            ``MultipartStream`` unchanged (presumably it then picks one
            itself; that logic is not visible here).

    Returns:
        A ``(headers, stream)`` two-tuple, where ``headers`` is whatever
        ``MultipartStream.get_headers()`` reports.
    """
    multipart = MultipartStream(data=data, files=files, boundary=boundary)
    return multipart.get_headers(), multipart
133
134
def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]:
    """
    Encode ``text`` as a UTF-8 ``text/plain`` body.

    Returns:
        A ``(headers, stream)`` two-tuple with ``Content-Length`` and
        ``Content-Type`` headers.
    """
    payload = text.encode("utf-8")
    return (
        {
            "Content-Length": str(len(payload)),
            "Content-Type": "text/plain; charset=utf-8",
        },
        ByteStream(payload),
    )
141
142
def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]:
    """
    Encode ``html`` as a UTF-8 ``text/html`` body.

    Returns:
        A ``(headers, stream)`` two-tuple with ``Content-Length`` and
        ``Content-Type`` headers.
    """
    payload = html.encode("utf-8")
    return (
        {
            "Content-Length": str(len(payload)),
            "Content-Type": "text/html; charset=utf-8",
        },
        ByteStream(payload),
    )
149
150
def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]:
    """
    Serialize ``json`` with the stdlib ``json.dumps`` defaults and wrap it as
    an ``application/json`` body.

    Returns:
        A ``(headers, stream)`` two-tuple with ``Content-Length`` and
        ``Content-Type`` headers.
    """
    serialized = json_dumps(json)
    payload = serialized.encode("utf-8")
    return (
        {
            "Content-Length": str(len(payload)),
            "Content-Type": "application/json",
        },
        ByteStream(payload),
    )
157
158
def encode_request(
    content: Optional[RequestContent] = None,
    data: Optional[RequestData] = None,
    files: Optional[RequestFiles] = None,
    json: Any = None,
    boundary: Optional[bytes] = None,
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
    """
    Handles encoding the given `content`, `data`, `files`, and `json`,
    returning a two-tuple of (<headers>, <stream>).

    The first applicable option wins, in this order:

    1. Non-dict `data` (deprecated): treated as raw `content`, and a
       `DeprecationWarning` is emitted.
    2. `content` that is not `None`: raw bytes/text or a byte iterable.
    3. Truthy `files`: a multipart body, with `data` (if any) as form fields.
    4. Truthy `data`: a URL-encoded form body.
    5. `json` that is not `None`: a JSON body.

    If none apply, an empty body with no headers is returned. Apart from
    `data` being merged into a multipart body, arguments of lower priority
    than the winning one are ignored.

    `boundary` is only used for multipart bodies.
    """
    if data is not None and not isinstance(data, dict):
        # We prefer to separate `content=<bytes|str|byte iterator|bytes aiterator>`
        # for raw request content, and `data=<form data>` for url encoded or
        # multipart form content.
        #
        # However for compat with requests, we *do* still support
        # `data=<bytes...>` usages. We deal with that case here, treating it
        # as if `content=<...>` had been supplied instead.
        message = "Use 'content=<...>' to upload raw bytes/text content."
        warnings.warn(message, DeprecationWarning)
        return encode_content(data)

    if content is not None:
        return encode_content(content)
    elif files:
        return encode_multipart_data(data or {}, files, boundary)
    elif data:
        return encode_urlencoded_data(data)
    elif json is not None:
        return encode_json(json)

    return {}, ByteStream(b"")
192
193
def encode_response(
    content: Optional[ResponseContent] = None,
    text: Optional[str] = None,
    html: Optional[str] = None,
    json: Any = None,
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
    """
    Handles encoding the given `content`, `text`, `html`, or `json`,
    returning a two-tuple of (<headers>, <stream>).

    The first argument that is not `None` wins, in the order `content`,
    `text`, `html`, `json`; the others are ignored. If all are `None`, an
    empty body with no headers is returned.
    """
    if content is not None:
        return encode_content(content)
    elif text is not None:
        return encode_text(text)
    elif html is not None:
        return encode_html(html)
    elif json is not None:
        return encode_json(json)

    return {}, ByteStream(b"")
214