1# -*- coding: utf-8 -*-
2"""
3
4requests_toolbelt.multipart.decoder
5===================================
6
This holds all the implementation details of the MultipartDecoder.
8
9"""
10
11import sys
12import email.parser
13from .encoder import encode_with
14from requests.structures import CaseInsensitiveDict
15
16
17def _split_on_find(content, bound):
18    point = content.find(bound)
19    return content[:point], content[point + len(bound):]
20
21
class ImproperBodyPartContentException(Exception):
    """Raised when a body part has no CR-LF-CR-LF separating its header
    section from its content."""
24
25
class NonMultipartContentTypeException(Exception):
    """Raised when the Content-Type given to ``MultipartDecoder`` cannot be
    used to decode a multipart body (for example, when its major type is
    not ``multipart``)."""
28
29
def _header_parser(string, encoding):
    """Parse one part's raw header section into ``(name, value)`` pairs.

    :param string: the header lines of a single body part. On Python 3 it is
        decoded with *encoding* first (so it is expected to be bytes there).
        On any other major version it is handed to the parser unchanged.
    :param encoding: codec name used for that decoding. It is also passed to
        ``encode_with`` for every header name and value.
    :returns: a generator of ``(name, value)`` tuples, each element run
        through ``encode_with``.
    """
    on_py3 = sys.version_info[0] == 3
    text = string.decode(encoding) if on_py3 else string
    # Parsing happens now; only the per-item ``encode_with`` conversion is
    # deferred until the caller iterates the result.
    fields = email.parser.HeaderParser().parsestr(text).items()
    return ((encode_with(name, encoding), encode_with(value, encoding))
            for name, value in fields)
39
40
class BodyPart(object):
    """

    A ``Response``-like view of a single subpart of a multipart payload.
    Instances are normally produced by ``MultipartDecoder`` rather than
    built by hand.

    Mirroring ``Response``, it exposes ``headers`` (a
    ``CaseInsensitiveDict``), ``content`` (the raw bytes), ``text`` (those
    bytes decoded to unicode) and ``encoding`` (the codec used for ``text``).

    """

    def __init__(self, content, encoding):
        # Codec used both for header parsing and for ``text``.
        self.encoding = encoding
        # The first blank line (CR-LF-CR-LF) terminates the header section.
        head, sep, body = content.partition(b'\r\n\r\n')
        if not sep:
            raise ImproperBodyPartContentException(
                'content does not contain CR-LF-CR-LF'
            )
        self.content = body
        # A part may carry no headers at all (blank line right away).
        if head:
            parsed = _header_parser(head.lstrip(), encoding)
        else:
            parsed = {}
        self.headers = CaseInsensitiveDict(parsed)

    @property
    def text(self):
        """Return ``content`` decoded with ``encoding``."""
        codec = self.encoding
        return self.content.decode(codec)
72
73
class MultipartDecoder(object):
    """

    The ``MultipartDecoder`` object parses the multipart payload of
    a bytestring into a tuple of ``Response``-like ``BodyPart`` objects.

    The basic usage is::

        import requests
        from requests_toolbelt import MultipartDecoder

        response = requests.get(url)
        decoder = MultipartDecoder.from_response(response)
        for part in decoder.parts:
            print(part.headers['content-type'])

    If the multipart content is not from a response, basic usage is::

        from requests_toolbelt import MultipartDecoder

        decoder = MultipartDecoder(content, content_type)
        for part in decoder.parts:
            print(part.headers['content-type'])

    For both these usages, there is an optional ``encoding`` parameter. This is
    a string, which is the name of the unicode codec to use (default is
    ``'utf-8'``).

    :raises NonMultipartContentTypeException: if ``content_type`` is missing,
        is not a ``multipart/*`` type, or has no (or an empty) ``boundary``
        parameter.
    :raises ImproperBodyPartContentException: if a body part has no blank
        line (CR-LF-CR-LF) separating its headers from its content.

    """
    def __init__(self, content, content_type, encoding='utf-8'):
        #: Original Content-Type header
        self.content_type = content_type
        #: Response body encoding
        self.encoding = encoding
        #: Parsed parts of the multipart response body
        self.parts = tuple()
        self._find_boundary()
        self._parse_body(content)

    def _find_boundary(self):
        """Validate ``content_type`` and store its boundary in ``self.boundary``.

        The boundary is converted with ``encode_with`` using ``self.encoding``.
        If several ``boundary`` parameters are present, the last one wins.
        """
        # A missing header (``from_response`` passes None when the response
        # has no Content-Type) is treated as an empty, non-multipart type.
        ct_info = tuple(
            x.strip() for x in (self.content_type or '').split(';')
        )
        mimetype = ct_info[0]
        if mimetype.split('/')[0].lower() != 'multipart':
            raise NonMultipartContentTypeException(
                "Unexpected mimetype in content-type: '{0}'".format(mimetype)
            )
        boundary = None
        for item in ct_info[1:]:
            # ``partition`` leaves a parameter without '=' intact instead of
            # mangling it; stripping tolerates whitespace around '='.
            attr, _, value = item.partition('=')
            if attr.strip().lower() == 'boundary':
                boundary = value.strip().strip('"')
        if not boundary:
            # Without a boundary the body cannot be split; fail here with a
            # clear message rather than an AttributeError in _parse_body.
            raise NonMultipartContentTypeException(
                "No boundary parameter in content-type: '{0}'".format(
                    self.content_type
                )
            )
        self.boundary = encode_with(boundary, self.encoding)

    @staticmethod
    def _fix_first_part(part, boundary_marker):
        """Strip a leading *boundary_marker* from *part*, if present.

        The first piece produced by splitting on CRLF + marker still starts
        with the bare marker, because nothing precedes it.
        """
        if part.startswith(boundary_marker):
            return part[len(boundary_marker):]
        return part

    def _parse_body(self, content):
        """Split *content* on the boundary delimiter into ``self.parts``."""
        boundary = b''.join((b'--', self.boundary))

        def body_part(part):
            fixed = MultipartDecoder._fix_first_part(part, boundary)
            return BodyPart(fixed, self.encoding)

        def test_part(part):
            # Drop empty pieces and the closing "--" (optionally followed by
            # CRLF and an epilogue) that surround the real parts.
            return (part != b'' and
                    part != b'\r\n' and
                    part[:4] != b'--\r\n' and
                    part != b'--')

        parts = content.split(b''.join((b'\r\n', boundary)))
        if len(parts) > 1 and not parts[0].startswith(boundary):
            # RFC 2046 section 5.1.1: anything before the first delimiter is
            # a preamble to be ignored, not a body part.
            parts = parts[1:]
        self.parts = tuple(body_part(x) for x in parts if test_part(x))

    @classmethod
    def from_response(cls, response, encoding='utf-8'):
        """Build a decoder from a ``requests``-style response object.

        Uses ``response.content`` as the payload and the response's
        ``content-type`` header (``None`` when absent) as the content type.
        """
        content = response.content
        content_type = response.headers.get('content-type', None)
        return cls(content, content_type, encoding)
157