1# -*- coding: utf-8 -*-
2import io
3import sys
4import unittest
5import mock
6import pytest
7import requests
8from requests_toolbelt.multipart.decoder import BodyPart
9from requests_toolbelt.multipart.decoder import (
10    ImproperBodyPartContentException
11)
12from requests_toolbelt.multipart.decoder import MultipartDecoder
13from requests_toolbelt.multipart.decoder import (
14    NonMultipartContentTypeException
15)
16from requests_toolbelt.multipart.encoder import encode_with
17from requests_toolbelt.multipart.encoder import MultipartEncoder
18
19
class TestBodyPart(unittest.TestCase):
    """Tests for :class:`BodyPart`: header/body splitting, content and text."""

    @staticmethod
    def u(content):
        """Return *content* as a text string.

        On Python 3 the argument is returned untouched. On any other major
        version every ``\\\\`` pair is expanded before the value is decoded
        with the ``unicode_escape`` codec.
        """
        if sys.version_info[0] != 3:
            escaped = content.replace(r'\\', r'\\\\')
            return unicode(escaped, 'unicode_escape')
        return content

    @staticmethod
    def bodypart_bytes_from_headers_and_values(headers, value, encoding):
        """Build the raw bytes of one body part.

        Each item of *headers* is joined with ``b': '`` into one header
        line. Header lines are separated by CRLF, followed by a blank line
        (CRLF CRLF) and the encoded *value*. Every string is encoded with
        *encoding* via ``encode_with``.
        """
        header_lines = [
            b': '.join(encode_with(piece, encoding) for piece in pair)
            for pair in headers
        ]
        header_block = b'\r\n'.join(header_lines)
        return header_block + b'\r\n\r\n' + encode_with(value, encoding)

    def setUp(self):
        make_raw = TestBodyPart.bodypart_bytes_from_headers_and_values
        self.header_1 = (TestBodyPart.u('Snowman'), TestBodyPart.u('☃'))
        self.value_1 = TestBodyPart.u('©')
        # part_1: one header, UTF-8 payload.
        utf8_raw = make_raw((self.header_1,), self.value_1, 'utf-8')
        self.part_1 = BodyPart(utf8_raw, 'utf-8')
        # part_2: no headers, UTF-16 payload of the same value.
        utf16_raw = make_raw([], self.value_1, 'utf-16')
        self.part_2 = BodyPart(utf16_raw, 'utf-16')

    def test_equality_content_should_be_equal(self):
        # Headers must not leak into .content: a header-less part with the
        # same value and encoding yields identical bytes.
        raw = TestBodyPart.bodypart_bytes_from_headers_and_values(
            [], self.value_1, 'utf-8'
        )
        headerless = BodyPart(raw, 'utf-8')
        assert self.part_1.content == headerless.content

    def test_equality_content_equals_bytes(self):
        expected = encode_with(self.value_1, 'utf-8')
        assert self.part_1.content == expected

    def test_equality_content_should_not_be_equal(self):
        # Same value, different encodings -> different raw bytes.
        assert self.part_1.content != self.part_2.content

    def test_equality_content_does_not_equal_bytes(self):
        latin1_bytes = encode_with(self.value_1, 'latin-1')
        assert self.part_1.content != latin1_bytes

    def test_changing_encoding_changes_text(self):
        before = self.part_2.text
        self.part_2.encoding = 'latin-1'
        after = self.part_2.text
        assert after != before

    def test_text_should_be_equal(self):
        # Decoding each part with its own encoding recovers the same text.
        assert self.part_1.text == self.part_2.text

    def test_no_headers(self):
        raw = b'\r\n\r\nNo headers\r\nTwo lines'
        part = BodyPart(raw, 'utf-8')
        assert len(part.headers) == 0
        assert part.content == b'No headers\r\nTwo lines'

    def test_no_crlf_crlf_in_content(self):
        # Without a CRLF CRLF separator the part cannot be split.
        raw = b'no CRLF CRLF here!\r\n'
        with pytest.raises(ImproperBodyPartContentException):
            BodyPart(raw, 'utf-8')
95
96
class TestMultipartDecoder(unittest.TestCase):
    """Tests for :class:`MultipartDecoder` construction and part parsing."""

    def setUp(self):
        # Round-trip four simple form fields through MultipartEncoder so the
        # decoder is exercised on a body produced by this library itself.
        self.sample_1 = (
            ('field 1', 'value 1'),
            ('field 2', 'value 2'),
            ('field 3', 'value 3'),
            ('field 4', 'value 4'),
        )
        self.boundary = 'test boundary'
        self.encoded_1 = MultipartEncoder(self.sample_1, self.boundary)
        self.decoded_1 = MultipartDecoder(
            self.encoded_1.to_string(),
            self.encoded_1.content_type
        )

    @staticmethod
    def _sample_related_body():
        """Return a two-part body delimited by the boundary ``samp1``.

        The first part has two headers and a two-line body. The second part
        has no headers and a one-line body.
        """
        buf = io.BytesIO()
        buf.write(b'\r\n--samp1\r\n')
        buf.write(b'Header-1: Header-Value-1\r\n')
        buf.write(b'Header-2: Header-Value-2\r\n')
        buf.write(b'\r\n')
        buf.write(b'Body 1, Line 1\r\n')
        buf.write(b'Body 1, Line 2\r\n')
        buf.write(b'--samp1\r\n')
        buf.write(b'\r\n')
        buf.write(b'Body 2, Line 1\r\n')
        buf.write(b'--samp1--\r\n')
        return buf.getvalue()

    def _assert_decodes_sample_response(self, content_type):
        """Decode the sample body served as *content_type* and check it.

        :param content_type: value for the mocked response's
            ``content-type`` header; it must declare ``boundary="samp1"``.
        """
        response = mock.NonCallableMagicMock(spec=requests.Response)
        response.headers = {'content-type': content_type}
        response.content = self._sample_related_body()

        decoder = MultipartDecoder.from_response(response)

        assert decoder.content_type == content_type
        # Check the count first so a missing part fails with a clear
        # message instead of an IndexError/unpacking error below.
        assert len(decoder.parts) == 2
        first, second = decoder.parts
        assert first.content == b'Body 1, Line 1\r\nBody 1, Line 2'
        assert first.headers[b'Header-1'] == b'Header-Value-1'
        assert len(second.headers) == 0
        assert second.content == b'Body 2, Line 1'

    def test_non_multipart_response_fails(self):
        jpeg_response = mock.NonCallableMagicMock(spec=requests.Response)
        jpeg_response.headers = {'content-type': 'image/jpeg'}
        with pytest.raises(NonMultipartContentTypeException):
            MultipartDecoder.from_response(jpeg_response)

    def test_length_of_parts(self):
        assert len(self.sample_1) == len(self.decoded_1.parts)

    def test_content_of_parts(self):
        parts = self.decoded_1.parts
        # zip() stops at the shorter input; without this check a decoder
        # that dropped parts would pass vacuously.
        assert len(parts) == len(self.sample_1)
        for part, (_, value) in zip(parts, self.sample_1):
            assert part.content == encode_with(value, 'utf-8')

    def test_header_of_parts(self):
        parts = self.decoded_1.parts
        # See test_content_of_parts: guard against zip() truncation.
        assert len(parts) == len(self.sample_1)
        for part, (name, _) in zip(parts, self.sample_1):
            expected = encode_with(
                'form-data; name="{0}"'.format(name), 'utf-8'
            )
            assert part.headers[b'Content-Disposition'] == expected

    def test_from_response(self):
        self._assert_decodes_sample_response(
            'multipart/related; boundary="samp1"'
        )

    def test_from_responsecaplarge(self):
        # Same body, but the media type is capitalised; the decoder is
        # expected to accept it and parse identically.
        self._assert_decodes_sample_response(
            'Multipart/Related; boundary="samp1"'
        )
191
192