1import base64
2import binascii
3import json
4import re
5import uuid
6import warnings
7import zlib
8from collections import deque
9from types import TracebackType
10from typing import (
11    TYPE_CHECKING,
12    Any,
13    AsyncIterator,
14    Dict,
15    Iterator,
16    List,
17    Mapping,
18    Optional,
19    Sequence,
20    Tuple,
21    Type,
22    Union,
23)
24from urllib.parse import parse_qsl, unquote, urlencode
25
26from multidict import CIMultiDict, CIMultiDictProxy, MultiMapping
27
28from .hdrs import (
29    CONTENT_DISPOSITION,
30    CONTENT_ENCODING,
31    CONTENT_LENGTH,
32    CONTENT_TRANSFER_ENCODING,
33    CONTENT_TYPE,
34)
35from .helpers import CHAR, TOKEN, parse_mimetype, reify
36from .http import HeadersParser
37from .payload import (
38    JsonPayload,
39    LookupError,
40    Order,
41    Payload,
42    StringPayload,
43    get_payload,
44    payload_type,
45)
46from .streams import StreamReader
47
# Names exported by ``from <this module> import *``.
__all__ = (
    "MultipartReader", "MultipartWriter", "BodyPartReader",
    "BadContentDispositionHeader", "BadContentDispositionParam",
    "parse_content_disposition", "content_disposition_filename",
)
57
58
59if TYPE_CHECKING:  # pragma: no cover
60    from .client_reqrep import ClientResponse
61
62
class BadContentDispositionHeader(RuntimeWarning):
    """Warning issued when a whole Content-Disposition header is rejected."""
65
66
class BadContentDispositionParam(RuntimeWarning):
    """Warning issued when a single Content-Disposition parameter is skipped."""
69
70
def parse_content_disposition(
    header: Optional[str],
) -> Tuple[Optional[str], Dict[str, str]]:
    """Parse a Content-Disposition header value.

    Returns a ``(disposition_type, params)`` pair. The disposition type is
    lower-cased. Parameter names are lower-cased and stripped. Quoted values
    are unquoted and backslash-unescaped. Extended (``name*``) values of the
    form ``charset'lang'value`` are percent-decoded using *charset* (UTF-8
    when empty).

    A missing header yields ``(None, {})``. A malformed header also yields
    ``(None, {})`` and emits a :class:`BadContentDispositionHeader` warning.
    A malformed individual parameter is skipped with a
    :class:`BadContentDispositionParam` warning.
    """
    # The module-level name ``LookupError`` is a class imported from
    # .payload that shadows the builtin; codec lookups (unquote() with an
    # unknown charset) raise the builtin one.
    import builtins

    def is_token(string: str) -> bool:
        return bool(string) and TOKEN >= set(string)

    def is_quoted(string: str) -> bool:
        # Both an opening and a closing quote are required: an empty value
        # would otherwise raise IndexError, and a lone '"' is not a
        # quoted-string.
        return len(string) >= 2 and string[0] == string[-1] == '"'

    def is_rfc5987(string: str) -> bool:
        return is_token(string) and string.count("'") == 2

    def is_extended_param(string: str) -> bool:
        return string.endswith("*")

    def is_continuous_param(string: str) -> bool:
        # "name*N" or "name*N*" with a numeric N (RFC 2231 continuation).
        pos = string.find("*") + 1
        if not pos:
            return False
        substring = string[pos:-1] if string.endswith("*") else string[pos:]
        return substring.isdigit()

    def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str:
        # Drop the backslash of every quoted-pair ("\x" -> "x").
        return re.sub(f"\\\\([{chars}])", "\\1", text)

    if not header:
        return None, {}

    disptype, *parts = header.split(";")
    if not is_token(disptype):
        warnings.warn(BadContentDispositionHeader(header))
        return None, {}

    params = {}  # type: Dict[str, str]
    while parts:
        item = parts.pop(0)

        if "=" not in item:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        key, value = item.split("=", 1)
        key = key.lower().strip()
        value = value.lstrip()

        if key in params:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        if not is_token(key):
            warnings.warn(BadContentDispositionParam(item))
            continue

        elif is_continuous_param(key):
            # Continuation segments are stored raw (only unquoted); joining
            # and charset decoding happen in content_disposition_filename().
            if is_quoted(value):
                value = unescape(value[1:-1])
            elif not is_token(value):
                warnings.warn(BadContentDispositionParam(item))
                continue

        elif is_extended_param(key):
            if is_rfc5987(value):
                encoding, _, value = value.split("'", 2)
                encoding = encoding or "utf-8"
            else:
                warnings.warn(BadContentDispositionParam(item))
                continue

            try:
                value = unquote(value, encoding, "strict")
            except (UnicodeDecodeError, builtins.LookupError):  # pragma: nocover
                # Undecodable bytes, or an unknown charset name.
                warnings.warn(BadContentDispositionParam(item))
                continue

        else:
            failed = True
            if is_quoted(value):
                failed = False
                value = unescape(value[1:-1].lstrip("\\/"))
            elif is_token(value):
                failed = False
            elif parts:
                # maybe just ; in filename, in any case this is just
                # one case fix, for proper fix we need to redesign parser
                _value = "{};{}".format(value, parts[0])
                if is_quoted(_value):
                    parts.pop(0)
                    value = unescape(_value[1:-1].lstrip("\\/"))
                    failed = False

            if failed:
                warnings.warn(BadContentDispositionHeader(header))
                return None, {}

        params[key] = value

    return disptype.lower(), params
168
169
def content_disposition_filename(
    params: Mapping[str, str], name: str = "filename"
) -> Optional[str]:
    """Extract parameter *name* from parsed Content-Disposition *params*.

    Lookup order:

    1. ``name*`` (extended value, already decoded by the parser);
    2. plain ``name``;
    3. RFC 2231 continuations ``name*0``, ``name*1``, ... (each optionally
       suffixed with ``*``), joined in numeric order up to the first gap.
       If the joined value has the ``charset'lang'value`` shape, it is
       percent-decoded with *charset* (UTF-8 when empty).

    Returns None when *params* is empty or holds no matching parameter.
    """
    name_suf = "%s*" % name
    if not params:
        return None
    if name_suf in params:
        return params[name_suf]
    if name in params:
        return params[name]

    # Map each continuation index, as written ("0", "12", ...), to its value.
    # Walking keys in sorted order and keeping the first hit preserves the
    # preference for "name*N" over "name*N*" when both are present.
    segments = {}  # type: Dict[str, str]
    for key, value in sorted(params.items()):
        if not key.startswith(name_suf):
            continue
        tail = key[len(name_suf):]
        if tail.endswith("*"):
            tail = tail[:-1]
        segments.setdefault(tail, value)

    # Indexes must be visited numerically (0, 1, ..., 9, 10, 11): a plain
    # string sort puts "10" before "2" and silently truncates long names.
    parts = []
    num = 0
    while str(num) in segments:
        parts.append(segments[str(num)])
        num += 1
    if not parts:
        return None

    value = "".join(parts)
    # Only treat the value as charset'lang'data when both delimiters are
    # present; a lone apostrophe is ordinary text (e.g. "John's.txt").
    if value.count("'") >= 2:
        encoding, _, value = value.split("'", 2)
        encoding = encoding or "utf-8"
        return unquote(value, encoding, "strict")
    return value
201
202
class MultipartResponseWrapper:
    """Couple a :class:`MultipartReader` with the response feeding it.

    Iterating the wrapper yields the parts produced by ``stream``. As soon as
    the stream reports that it is at EOF, the response is released.
    """

    def __init__(self, resp: "ClientResponse", stream: "MultipartReader") -> None:
        self.resp = resp
        self.stream = stream

    def __aiter__(self) -> "MultipartResponseWrapper":
        return self

    async def __anext__(self) -> Union["MultipartReader", "BodyPartReader"]:
        item = await self.next()
        if item is not None:
            return item
        raise StopAsyncIteration

    def at_eof(self) -> bool:
        """Return True once the response content stream reports EOF."""
        content = self.resp.content
        return content.at_eof()

    async def next(self) -> Optional[Union["MultipartReader", "BodyPartReader"]]:
        """Return the next part from the wrapped reader, or None when done.

        The response is released right after the reader reaches its end.
        """
        part = await self.stream.next()
        finished = self.stream.at_eof()
        if finished:
            await self.release()
        return part

    async def release(self) -> None:
        """Release the wrapped response by awaiting its ``release()``."""
        await self.resp.release()
246
247
class BodyPartReader:
    """Multipart reader for a single body part.

    The part's bytes are read from the shared *content* stream. When the
    part's headers carry a Content-Length, exactly that many bytes are read.
    Otherwise the stream is scanned for CRLF followed by the boundary.
    """

    # Default number of bytes requested per read_chunk() call.
    chunk_size = 8192

    def __init__(
        self, boundary: bytes, headers: "CIMultiDictProxy[str]", content: StreamReader
    ) -> None:
        self.headers = headers
        self._boundary = boundary
        self._content = content
        self._at_eof = False
        length = self.headers.get(CONTENT_LENGTH, None)
        self._length = int(length) if length is not None else None
        self._read_bytes = 0
        # TODO: typing.Deque is not supported by Python 3.5
        # Lines read ahead by readline(); the boundary line ends up here and
        # is handed back to the parent reader.
        self._unread = deque()  # type: Any
        # Chunk held back by _read_chunk_from_stream() so that a boundary
        # straddling two reads can still be found.
        self._prev_chunk = None  # type: Optional[bytes]
        # Number of stream reads that observed the content stream at EOF.
        self._content_eof = 0
        # NOTE(review): presumably backing storage for the @reify properties.
        self._cache = {}  # type: Dict[str, Any]

    def __aiter__(self) -> AsyncIterator["BodyPartReader"]:
        return self  # type: ignore

    async def __anext__(self) -> bytes:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    async def next(self) -> Optional[bytes]:
        """Return the rest of the part's data, or None if nothing is left."""
        item = await self.read()
        if not item:
            return None
        return item

    async def read(self, *, decode: bool = False) -> bytes:
        """Read the remaining body part data.

        decode: when True, pass the collected data through decode(), which
                applies Content-Transfer-Encoding and Content-Encoding if
                those headers are present.

        Returns b"" if the part has already been fully read.
        """
        if self._at_eof:
            return b""
        data = bytearray()
        while not self._at_eof:
            data.extend(await self.read_chunk(self.chunk_size))
        if decode:
            return self.decode(data)
        return data

    async def read_chunk(self, size: int = chunk_size) -> bytes:
        """Read the next chunk of at most *size* bytes of the body part.

        size: chunk size

        When the end of the part is reached, the CRLF that terminates it is
        consumed from the stream. A missing CRLF fails the assertion below.
        """
        if self._at_eof:
            return b""
        if self._length:
            chunk = await self._read_chunk_from_length(size)
        else:
            chunk = await self._read_chunk_from_stream(size)

        self._read_bytes += len(chunk)
        if self._read_bytes == self._length:
            self._at_eof = True
        if self._at_eof:
            clrf = await self._content.readline()
            assert (
                b"\r\n" == clrf
            ), "reader did not read all the data or it is malformed"
        return chunk

    async def _read_chunk_from_length(self, size: int) -> bytes:
        # Reads body part content chunk of the specified size.
        # The body part must have a Content-Length header with proper value.
        assert self._length is not None, "Content-Length required for chunked read"
        chunk_size = min(size, self._length - self._read_bytes)
        chunk = await self._content.read(chunk_size)
        if self._content.at_eof():
            # The stream is exhausted before Content-Length bytes arrived (or
            # exactly at it): stop here. Otherwise read() and release() would
            # keep getting b"" and spin forever; read_chunk() now reports the
            # truncation through its CRLF assertion.
            self._at_eof = True
        return chunk

    async def _read_chunk_from_stream(self, size: int) -> bytes:
        # Reads content chunk of body part with unknown length.
        # The Content-Length header for body part is not necessary.
        # Output lags one chunk behind the stream: the previous chunk is only
        # returned once the following one has been checked for the boundary.
        assert (
            size >= len(self._boundary) + 2
        ), "Chunk size must be greater or equal than boundary length + 2"
        first_chunk = self._prev_chunk is None
        if first_chunk:
            self._prev_chunk = await self._content.read(size)

        chunk = await self._content.read(size)
        self._content_eof += int(self._content.at_eof())
        assert self._content_eof < 3, "Reading after EOF"
        assert self._prev_chunk is not None
        window = self._prev_chunk + chunk
        sub = b"\r\n" + self._boundary
        if first_chunk:
            idx = window.find(sub)
        else:
            # Earlier bytes of the previous chunk were already searched.
            idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
        if idx >= 0:
            # pushing boundary back to content
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=DeprecationWarning)
                self._content.unread_data(window[idx:])
            if size > idx:
                # The delimiter may start inside the held-back chunk.
                self._prev_chunk = self._prev_chunk[:idx]
            chunk = window[len(self._prev_chunk) : idx]
            if not chunk:
                self._at_eof = True
        result = self._prev_chunk
        self._prev_chunk = chunk
        return result

    async def readline(self) -> bytes:
        """Read the body part line by line.

        Returns b"" once the boundary line is reached; that line is kept in
        self._unread. The CRLF preceding the boundary is stripped from the
        last data line.
        """
        if self._at_eof:
            return b""

        if self._unread:
            line = self._unread.popleft()
        else:
            line = await self._content.readline()

        if line.startswith(self._boundary):
            # the very last boundary may not come with \r\n,
            # so set single rules for everyone
            sline = line.rstrip(b"\r\n")
            boundary = self._boundary
            last_boundary = self._boundary + b"--"
            # ensure that we read exactly the boundary, not something alike
            if sline == boundary or sline == last_boundary:
                self._at_eof = True
                self._unread.append(line)
                return b""
        else:
            # Peek one line ahead to know whether this is the last data line.
            next_line = await self._content.readline()
            if next_line.startswith(self._boundary):
                line = line[:-2]  # strip CRLF but only once
            self._unread.append(next_line)

        return line

    async def release(self) -> None:
        """Like read(), but discards all remaining data."""
        if self._at_eof:
            return
        while not self._at_eof:
            await self.read_chunk(self.chunk_size)

    async def text(self, *, encoding: Optional[str] = None) -> str:
        """Like read(), but decodes the data to str.

        Uses *encoding*, else the Content-Type charset, else UTF-8.
        """
        data = await self.read(decode=True)
        # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm # NOQA
        # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send # NOQA
        encoding = encoding or self.get_charset(default="utf-8")
        return data.decode(encoding)

    async def json(self, *, encoding: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Like read(), but parses the data as JSON (None for an empty part)."""
        data = await self.read(decode=True)
        if not data:
            return None
        encoding = encoding or self.get_charset(default="utf-8")
        return json.loads(data.decode(encoding))

    async def form(self, *, encoding: Optional[str] = None) -> List[Tuple[str, str]]:
        """Like read(), but parses the data as form-urlencoded pairs.

        Blank values are kept; an empty part yields [].
        """
        data = await self.read(decode=True)
        if not data:
            return []
        if encoding is not None:
            real_encoding = encoding
        else:
            real_encoding = self.get_charset(default="utf-8")
        return parse_qsl(
            data.rstrip().decode(real_encoding),
            keep_blank_values=True,
            encoding=real_encoding,
        )

    def at_eof(self) -> bool:
        """Return True if the end of this part was reached, else False."""
        return self._at_eof

    def decode(self, data: bytes) -> bytes:
        """Decode *data* per the part's transfer and content encodings.

        Content-Transfer-Encoding is undone first, then Content-Encoding.
        Both steps operate on the complete *data* in one shot.
        """
        if CONTENT_TRANSFER_ENCODING in self.headers:
            data = self._decode_content_transfer(data)
        if CONTENT_ENCODING in self.headers:
            return self._decode_content(data)
        return data

    def _decode_content(self, data: bytes) -> bytes:
        # "deflate" is decoded as a raw deflate stream (negative wbits),
        # "gzip" with the gzip container (16 + MAX_WBITS).
        encoding = self.headers.get(CONTENT_ENCODING, "").lower()

        if encoding == "deflate":
            return zlib.decompress(data, -zlib.MAX_WBITS)
        elif encoding == "gzip":
            return zlib.decompress(data, 16 + zlib.MAX_WBITS)
        elif encoding == "identity":
            return data
        else:
            raise RuntimeError(f"unknown content encoding: {encoding}")

    def _decode_content_transfer(self, data: bytes) -> bytes:
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()

        if encoding == "base64":
            return base64.b64decode(data)
        elif encoding == "quoted-printable":
            return binascii.a2b_qp(data)
        elif encoding in ("binary", "8bit", "7bit"):
            return data
        else:
            raise RuntimeError(
                "unknown content transfer encoding: {}" "".format(encoding)
            )

    def get_charset(self, default: str) -> str:
        """Return the charset parameter of Content-Type, or *default*."""
        ctype = self.headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)
        return mimetype.parameters.get("charset", default)

    @reify
    def name(self) -> Optional[str]:
        """Return the ``name`` from Content-Disposition, or None.

        None is returned if the parameter is missing or the header is
        malformed.
        """

        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "name")

    @reify
    def filename(self) -> Optional[str]:
        """Return the ``filename`` from Content-Disposition, or None.

        None is returned if the parameter is missing or the header is
        malformed.
        """
        _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, "filename")
495
496
@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
    """Payload that re-emits the decoded content of a :class:`BodyPartReader`.

    The part's ``name`` and ``filename`` (when present) are copied into an
    ``attachment`` Content-Disposition header.
    """

    def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
        super().__init__(value, *args, **kwargs)

        params = {}  # type: Dict[str, str]
        if value.name is not None:
            params["name"] = value.name
        if value.filename is not None:
            params["filename"] = value.filename

        if params:
            self.set_content_disposition("attachment", True, **params)

    async def write(self, writer: Any) -> None:
        """Write the part's content to *writer*, decoded per its headers.

        ``BodyPartReader.decode`` runs base64/quoted-printable and zlib
        decoding on whatever bytes it is given in one shot. Arbitrary 64 KiB
        slices of an encoded body cannot be decoded independently. When the
        part declares a Content-Encoding or Content-Transfer-Encoding, the
        whole body is therefore buffered and decoded once. Otherwise chunks
        are streamed through as they arrive.
        """
        field = self._value
        chunk_size = 2 ** 16
        headers = field.headers
        if CONTENT_ENCODING in headers or CONTENT_TRANSFER_ENCODING in headers:
            buf = bytearray()
            chunk = await field.read_chunk(size=chunk_size)
            while chunk:
                buf.extend(chunk)
                chunk = await field.read_chunk(size=chunk_size)
            if buf:
                await writer.write(field.decode(buf))
            return

        chunk = await field.read_chunk(size=chunk_size)
        while chunk:
            await writer.write(field.decode(chunk))
            chunk = await field.read_chunk(size=chunk_size)
517
518
class MultipartReader:
    """Multipart body reader.

    Each call to next() (or each iteration step) yields a reader for one body
    part: a nested multipart reader for multipart/* parts, otherwise a
    ``part_reader_cls`` instance.
    """

    #: Response wrapper, used when multipart readers constructs from response.
    response_wrapper_cls = MultipartResponseWrapper
    #: Multipart reader class, used to handle multipart/* body parts.
    #: None points to type(self)
    multipart_reader_cls = None
    #: Body part reader class for non multipart/* content types.
    part_reader_cls = BodyPartReader

    def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
        self.headers = headers
        # Delimiter as it appears in the body: b"--" + boundary parameter.
        self._boundary = ("--" + self._get_boundary()).encode()
        self._content = content
        self._last_part = (
            None
        )  # type: Optional[Union['MultipartReader', BodyPartReader]]
        self._at_eof = False
        self._at_bof = True
        # Lines pushed back for re-reading; _readline() pops them LIFO.
        self._unread = []  # type: List[bytes]

    def __aiter__(
        self,
    ) -> AsyncIterator["BodyPartReader"]:
        return self  # type: ignore

    async def __anext__(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    @classmethod
    def from_response(
        cls,
        response: "ClientResponse",
    ) -> MultipartResponseWrapper:
        """Constructs reader instance from HTTP response.

        :param response: :class:`~aiohttp.client.ClientResponse` instance
        """
        obj = cls.response_wrapper_cls(
            response, cls(response.headers, response.content)
        )
        return obj

    def at_eof(self) -> bool:
        """Return True if the final boundary was reached, else False."""
        return self._at_eof

    async def next(
        self,
    ) -> Optional[Union["MultipartReader", BodyPartReader]]:
        """Return a reader for the next body part, or None after the last one.

        The previously returned part is drained first if it was not fully
        read.
        """
        # So, if we're at BOF, we need to skip till the boundary.
        if self._at_eof:
            return None
        await self._maybe_release_last_part()
        if self._at_bof:
            await self._read_until_first_boundary()
            self._at_bof = False
        else:
            await self._read_boundary()
        if self._at_eof:  # we just read the last boundary, nothing to do there
            return None
        self._last_part = await self.fetch_next_part()
        return self._last_part

    async def release(self) -> None:
        """Reads all the body parts to the void till the final boundary."""
        while not self._at_eof:
            item = await self.next()
            if item is None:
                break
            await item.release()

    async def fetch_next_part(
        self,
    ) -> Union["MultipartReader", BodyPartReader]:
        """Returns the next body part reader."""
        headers = await self._read_headers()
        return self._get_part_reader(headers)

    def _get_part_reader(
        self,
        headers: "CIMultiDictProxy[str]",
    ) -> Union["MultipartReader", BodyPartReader]:
        """Dispatches the response by the `Content-Type` header, returning
        suitable reader instance.

        :param dict headers: Response headers
        """
        ctype = headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)

        if mimetype.type == "multipart":
            if self.multipart_reader_cls is None:
                return type(self)(headers, self._content)
            return self.multipart_reader_cls(headers, self._content)
        else:
            return self.part_reader_cls(self._boundary, headers, self._content)

    def _get_boundary(self) -> str:
        # Raises ValueError when the boundary parameter is missing or longer
        # than 70 characters.
        mimetype = parse_mimetype(self.headers[CONTENT_TYPE])

        assert mimetype.type == "multipart", "multipart/* content type expected"

        if "boundary" not in mimetype.parameters:
            raise ValueError(
                "boundary missed for Content-Type: %s" % self.headers[CONTENT_TYPE]
            )

        boundary = mimetype.parameters["boundary"]
        if len(boundary) > 70:
            raise ValueError("boundary %r is too long (70 chars max)" % boundary)

        return boundary

    async def _readline(self) -> bytes:
        # Pushed-back lines take precedence over the stream.
        if self._unread:
            return self._unread.pop()
        return await self._content.readline()

    async def _read_until_first_boundary(self) -> None:
        # Skip the preamble up to the first delimiter line.
        while True:
            chunk = await self._readline()
            if chunk == b"":
                raise ValueError(
                    "Could not find starting boundary %r" % (self._boundary)
                )
            chunk = chunk.rstrip()
            if chunk == self._boundary:
                return
            elif chunk == self._boundary + b"--":
                self._at_eof = True
                return

    async def _read_boundary(self) -> None:
        chunk = (await self._readline()).rstrip()
        if chunk == self._boundary:
            pass
        elif chunk == self._boundary + b"--":
            self._at_eof = True
            epilogue = await self._readline()
            next_line = await self._readline()

            # the epilogue is expected and then either the end of input or the
            # parent multipart boundary, if the parent boundary is found then
            # it should be marked as unread and handed to the parent for
            # processing
            if next_line[:2] == b"--":
                self._unread.append(next_line)
            # otherwise the request is likely missing an epilogue and both
            # lines should be passed to the parent for processing
            # (this handles the old behavior gracefully)
            else:
                self._unread.extend([next_line, epilogue])
        else:
            raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}")

    async def _read_headers(self) -> "CIMultiDictProxy[str]":
        # Collect header lines up to the first empty line and parse them.
        lines = [b""]
        while True:
            # Go through _readline() rather than the raw stream: a nested
            # reader may have pushed back a line of this part's headers (see
            # the "missing epilogue" branch of _read_boundary), and reading
            # the stream directly would skip it.
            chunk = await self._readline()
            chunk = chunk.strip()
            lines.append(chunk)
            if not chunk:
                break
        parser = HeadersParser()
        headers, raw_headers = parser.parse_headers(lines)
        return headers

    async def _maybe_release_last_part(self) -> None:
        """Ensures that the last read body part is read completely."""
        if self._last_part is not None:
            if not self._last_part.at_eof():
                await self._last_part.release()
            # Take over lines the part read ahead (e.g. the boundary line).
            self._unread.extend(self._last_part._unread)
            self._last_part = None
703
704
# One entry of MultipartWriter._parts: (payload, content encoding,
# content transfer encoding).
# NOTE(review): MultipartWriter.append_payload stores None in the two encoding
# slots when no transformation applies, so Optional[str] would be the accurate
# annotation for them.
_Part = Tuple[Payload, str, str]
706
707
class MultipartWriter(Payload):
    """Multipart body writer.

    Parts are added with the append*() methods. write() emits each part as
    ``--boundary`` CRLF, the part headers, the body and CRLF, then a closing
    ``--boundary--`` line.
    """

    def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None:
        boundary = boundary if boundary is not None else uuid.uuid4().hex
        # The underlying Payload API demands a str (utf-8), not bytes,
        # so we need to ensure we don't lose anything during conversion.
        # As a result, the boundary must be ASCII only, whether it was
        # supplied by the caller or generated above.
        try:
            self._boundary = boundary.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("boundary should contain ASCII only chars") from None
        ctype = f"multipart/{subtype}; boundary={self._boundary_value}"

        super().__init__(None, content_type=ctype)

        # (payload, content encoding, transfer encoding); the encoding slots
        # hold None when no transformation is applied.
        self._parts = []  # type: List[_Part]

    def __enter__(self) -> "MultipartWriter":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        pass

    def __iter__(self) -> Iterator[_Part]:
        return iter(self._parts)

    def __len__(self) -> int:
        return len(self._parts)

    def __bool__(self) -> bool:
        # Truthy even with no parts (__len__ alone would make it falsy).
        return True

    _valid_tchar_regex = re.compile(br"\A[!#$%&'*+\-.^_`|~\w]+\Z")
    _invalid_qdtext_char_regex = re.compile(br"[\x00-\x08\x0A-\x1F\x7F]")

    @property
    def _boundary_value(self) -> str:
        """Wrap boundary parameter value in quotes, if necessary.

        Reads self._boundary and returns a str. Raises ValueError if the
        boundary contains characters not allowed in a quoted-string.
        """
        # Refer to RFCs 7231, 7230, 5234.
        #
        # parameter      = token "=" ( token / quoted-string )
        # token          = 1*tchar
        # quoted-string  = DQUOTE *( qdtext / quoted-pair ) DQUOTE
        # qdtext         = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
        # obs-text       = %x80-FF
        # quoted-pair    = "\" ( HTAB / SP / VCHAR / obs-text )
        # tchar          = "!" / "#" / "$" / "%" / "&" / "'" / "*"
        #                  / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
        #                  / DIGIT / ALPHA
        #                  ; any VCHAR, except delimiters
        # VCHAR           = %x21-7E
        value = self._boundary
        if re.match(self._valid_tchar_regex, value):
            return value.decode("ascii")  # cannot fail

        if re.search(self._invalid_qdtext_char_regex, value):
            raise ValueError("boundary value contains invalid characters")

        # escape %x5C and %x22
        quoted_value_content = value.replace(b"\\", b"\\\\")
        quoted_value_content = quoted_value_content.replace(b'"', b'\\"')

        return '"' + quoted_value_content.decode("ascii") + '"'

    @property
    def boundary(self) -> str:
        return self._boundary.decode("ascii")

    def append(self, obj: Any, headers: Optional[MultiMapping[str]] = None) -> Payload:
        """Add *obj* as a new part, wrapping it in a Payload if needed.

        Raises TypeError when no payload type is registered for *obj*.
        """
        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Payload):
            obj.headers.update(headers)
            return self.append_payload(obj)
        else:
            try:
                payload = get_payload(obj, headers=headers)
            except LookupError:
                raise TypeError("Cannot create payload from %r" % obj)
            else:
                return self.append_payload(payload)

    def append_payload(self, payload: Payload) -> Payload:
        """Adds a new body part to multipart writer.

        Validates the part's Content-Encoding and Content-Transfer-Encoding
        (RuntimeError on unknown values) and sets Content-Length when the
        part is sent untransformed and its size is known.
        """
        # compression
        encoding = payload.headers.get(
            CONTENT_ENCODING,
            "",
        ).lower()  # type: Optional[str]
        if encoding and encoding not in ("deflate", "gzip", "identity"):
            raise RuntimeError(f"unknown content encoding: {encoding}")
        if encoding == "identity":
            encoding = None

        # te encoding
        te_encoding = payload.headers.get(
            CONTENT_TRANSFER_ENCODING,
            "",
        ).lower()  # type: Optional[str]
        if te_encoding not in (
            "",
            "base64",
            "quoted-printable",
            "binary",
            "8bit",
            "7bit",
        ):
            raise RuntimeError(f"unknown content transfer encoding: {te_encoding}")
        # "binary", "8bit" and "7bit" all mean the data is sent as-is; the
        # reader side (BodyPartReader) treats the same three values as
        # identity.
        if te_encoding in ("binary", "8bit", "7bit"):
            te_encoding = None

        # size
        size = payload.size
        if size is not None and not (encoding or te_encoding):
            payload.headers[CONTENT_LENGTH] = str(size)

        self._parts.append((payload, encoding, te_encoding))  # type: ignore
        return payload

    def append_json(
        self, obj: Any, headers: Optional[MultiMapping[str]] = None
    ) -> Payload:
        """Helper to append JSON part."""
        if headers is None:
            headers = CIMultiDict()

        return self.append_payload(JsonPayload(obj, headers=headers))

    def append_form(
        self,
        obj: Union[Sequence[Tuple[str, str]], Mapping[str, str]],
        headers: Optional[MultiMapping[str]] = None,
    ) -> Payload:
        """Helper to append form urlencoded part."""
        assert isinstance(obj, (Sequence, Mapping))

        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Mapping):
            obj = list(obj.items())
        data = urlencode(obj, doseq=True)

        return self.append_payload(
            StringPayload(
                data, headers=headers, content_type="application/x-www-form-urlencoded"
            )
        )

    @property
    def size(self) -> Optional[int]:
        """Size of the serialized body in bytes, or None if unknown.

        None is returned when any part is compressed or transfer-encoded, or
        has an unknown size.
        """
        total = 0
        for part, encoding, te_encoding in self._parts:
            if encoding or te_encoding or part.size is None:
                return None

            total += int(
                2  # b"--"
                + len(self._boundary)
                + 2  # b"\r\n" after the boundary
                + part.size
                + len(part._binary_headers)
                + 2  # b"\r\n" after the body
            )

        total += 2 + len(self._boundary) + 4  # b'--'+self._boundary+b'--\r\n'
        return total

    async def write(self, writer: Any, close_boundary: bool = True) -> None:
        """Write body.

        close_boundary: when False, the final ``--boundary--`` line is
                        omitted.
        """
        for part, encoding, te_encoding in self._parts:
            await writer.write(b"--" + self._boundary + b"\r\n")
            await writer.write(part._binary_headers)

            if encoding or te_encoding:
                w = MultipartPayloadWriter(writer)
                if encoding:
                    w.enable_compression(encoding)
                if te_encoding:
                    w.enable_encoding(te_encoding)
                await part.write(w)  # type: ignore
                await w.write_eof()
            else:
                await part.write(writer)

            await writer.write(b"\r\n")

        if close_boundary:
            await writer.write(b"--" + self._boundary + b"--\r\n")
905
906
class MultipartPayloadWriter:
    """Writer adapter that compresses and/or transfer-encodes part data.

    Data passed to write() is first compressed (if enable_compression() was
    called), then base64 or quoted-printable encoded (if enable_encoding()
    selected one), and finally forwarded to the wrapped writer.
    """

    def __init__(self, writer: Any) -> None:
        self._writer = writer
        # Selected transfer encoding, or None for pass-through.
        self._encoding = None  # type: Optional[str]
        # zlib compressor, or None when compression is off or finished.
        self._compress = None  # type: Any
        # base64 only: bytes not yet forming a complete 3-byte group.
        self._encoding_buffer = None  # type: Optional[bytearray]

    def enable_encoding(self, encoding: str) -> None:
        """Select the transfer encoding; unsupported names are ignored."""
        if encoding == "quoted-printable":
            self._encoding = "quoted-printable"
        elif encoding == "base64":
            self._encoding = encoding
            self._encoding_buffer = bytearray()

    def enable_compression(self, encoding: str = "deflate") -> None:
        """Enable compression: gzip container for "gzip", raw deflate otherwise."""
        if encoding == "gzip":
            wbits = 16 + zlib.MAX_WBITS
        else:
            wbits = -zlib.MAX_WBITS
        self._compress = zlib.compressobj(wbits=wbits)

    async def write_eof(self) -> None:
        """Flush the compressor and any pending base64 bytes."""
        compressor = self._compress
        if compressor is not None:
            tail = compressor.flush()
            if tail:
                self._compress = None
                await self.write(tail)

        pending = self._encoding_buffer
        if self._encoding == "base64" and pending:
            await self._writer.write(base64.b64encode(pending))

    async def write(self, chunk: bytes) -> None:
        """Compress and encode *chunk*, then pass it to the wrapped writer."""
        if self._compress is not None and chunk:
            chunk = self._compress.compress(chunk)
            if not chunk:
                # The compressor is still buffering; nothing to emit yet.
                return

        mode = self._encoding
        if mode == "base64":
            buf = self._encoding_buffer
            assert buf is not None
            buf.extend(chunk)
            if not buf:
                return
            # Encode only whole 3-byte groups so no padding appears mid-stream.
            cut = len(buf) - len(buf) % 3
            ready = buf[:cut]
            self._encoding_buffer = buf[cut:]
            if ready:
                await self._writer.write(base64.b64encode(ready))
        elif mode == "quoted-printable":
            await self._writer.write(binascii.b2a_qp(chunk))
        else:
            await self._writer.write(chunk)
958