1# -*- coding: utf-8 -*-
2import mock
3
4from twisted.python.failure import Failure
5
6from twisted.trial.unittest import TestCase
7from twisted.web.http_headers import Headers
8from twisted.web.client import ResponseDone, ResponseFailed
9from twisted.web.http import PotentialDataLoss
10
11from treq import collect, content, json_content, text_content
12from treq.client import _BufferedResponse
13
14
class ContentTests(TestCase):
    """
    Tests for treq's body-reading helpers (``collect``, ``content``,
    ``json_content`` and ``text_content``), driven through a
    ``_BufferedResponse`` that wraps a mock response.
    """

    def setUp(self):
        """
        Wrap a mock response in ``_BufferedResponse``.

        The protocol passed to the mock's ``deliverBody`` is stored in
        ``self.protocol``. Tests use it to feed body bytes and to end the
        body by hand.
        """
        original = mock.Mock()
        self.protocol = None

        def deliverBody(protocol):
            self.protocol = protocol

        original.deliverBody.side_effect = deliverBody
        self.response = _BufferedResponse(original)

    def test_collect(self):
        """
        ``collect`` passes each received chunk to the collector unchanged and
        in order. It fires with ``None`` once the body ends with
        ``ResponseDone``.
        """
        data = []

        d = collect(self.response, data.append)

        self.protocol.dataReceived(b'{')
        self.protocol.dataReceived(b'"msg": "hell')
        self.protocol.dataReceived(b'o"}')

        self.protocol.connectionLost(Failure(ResponseDone()))

        self.assertIsNone(self.successResultOf(d))

        self.assertEqual(data, [b'{', b'"msg": "hell', b'o"}'])

    def test_collect_failure(self):
        """
        If the body ends with ``ResponseFailed``, the Deferred from
        ``collect`` fails with it. Chunks received before the failure have
        already been passed to the collector.
        """
        data = []

        d = collect(self.response, data.append)

        self.protocol.dataReceived(b'foo')

        self.protocol.connectionLost(Failure(ResponseFailed("test failure")))

        self.failureResultOf(d, ResponseFailed)

        self.assertEqual(data, [b'foo'])

    def test_collect_failure_potential_data_loss(self):
        """
        PotentialDataLoss failures are treated as success.
        """
        data = []

        d = collect(self.response, data.append)

        self.protocol.dataReceived(b'foo')

        self.protocol.connectionLost(Failure(PotentialDataLoss()))

        self.assertIsNone(self.successResultOf(d))

        self.assertEqual(data, [b'foo'])

    def test_collect_0_length(self):
        """
        When the response's ``length`` is 0, ``collect`` fires with ``None``
        straight away. It never calls the collector and never asks the
        underlying response to deliver its body.
        """
        self.response.length = 0

        d = collect(
            self.response,
            lambda chunk: self.fail(
                "Unexpectedly called with: {0}".format(chunk)))

        self.assertIsNone(self.successResultOf(d))
        # Pin the short-circuit itself. Without this check, a collect that
        # requested the body and simply never received data would also pass.
        self.assertFalse(self.response.original.deliverBody.called)

    def test_content(self):
        """
        ``content`` fires with all received chunks concatenated.
        """
        d = content(self.response)

        self.protocol.dataReceived(b'foo')
        self.protocol.dataReceived(b'bar')
        self.protocol.connectionLost(Failure(ResponseDone()))

        self.assertEqual(self.successResultOf(d), b'foobar')

    def test_content_cached(self):
        """
        After the body has been read once, later ``content`` calls are served
        from the buffer. The wrapped response's ``deliverBody`` is not called
        again, and each call gets its own Deferred.
        """
        d1 = content(self.response)

        self.protocol.dataReceived(b'foo')
        self.protocol.dataReceived(b'bar')
        self.protocol.connectionLost(Failure(ResponseDone()))

        self.assertEqual(self.successResultOf(d1), b'foobar')

        def _fail_deliverBody(protocol):
            self.fail("deliverBody unexpectedly called.")

        # Any further attempt to read from the network fails the test.
        self.response.original.deliverBody.side_effect = _fail_deliverBody

        d2 = content(self.response)

        self.assertEqual(self.successResultOf(d2), b'foobar')

        self.assertIsNot(d1, d2)

    def test_content_multiple_waiters(self):
        """
        Several ``content`` calls made before the body arrives each get their
        own Deferred, and all of them fire with the same body.
        """
        d1 = content(self.response)
        d2 = content(self.response)

        self.protocol.dataReceived(b'foo')
        self.protocol.connectionLost(Failure(ResponseDone()))

        self.assertEqual(self.successResultOf(d1), b'foo')
        self.assertEqual(self.successResultOf(d2), b'foo')

        self.assertIsNot(d1, d2)

    def test_json_content(self):
        """
        ``json_content`` parses the body as JSON when there is no
        Content-Type header.
        """
        self.response.headers = Headers()
        d = json_content(self.response)

        self.protocol.dataReceived(b'{"msg":"hello!"}')
        self.protocol.connectionLost(Failure(ResponseDone()))

        self.assertEqual(self.successResultOf(d), {"msg": "hello!"})

    def test_json_content_unicode(self):
        """
        When Unicode JSON content is received, the JSON text should be
        correctly decoded.
        RFC7159 (8.1): "JSON text SHALL be encoded in UTF-8, UTF-16, or UTF-32.
        The default encoding is UTF-8"
        """
        self.response.headers = Headers()
        d = json_content(self.response)

        self.protocol.dataReceived(u'{"msg":"hëlló!"}'.encode('utf-8'))
        self.protocol.connectionLost(Failure(ResponseDone()))

        self.assertEqual(self.successResultOf(d), {u'msg': u'hëlló!'})

    def test_json_content_utf16(self):
        """
        JSON received is decoded according to the charset given in the
        Content-Type header.
        """
        # The charset value is deliberately single-quoted to exercise
        # parameter parsing.
        self.response.headers = Headers({
            b'Content-Type': [b"application/json; charset='UTF-16LE'"],
        })
        d = json_content(self.response)

        self.protocol.dataReceived(u'{"msg":"hëlló!"}'.encode('UTF-16LE'))
        self.protocol.connectionLost(Failure(ResponseDone()))

        self.assertEqual(self.successResultOf(d), {u'msg': u'hëlló!'})

    def test_text_content(self):
        """
        ``text_content`` decodes the body using the Content-Type charset.
        Here UTF-8 bytes decode to a snowman.
        """
        self.response.headers = Headers(
            {b'Content-Type': [b'text/plain; charset=utf-8']})

        d = text_content(self.response)

        self.protocol.dataReceived(b'\xe2\x98\x83')
        self.protocol.connectionLost(Failure(ResponseDone()))

        self.assertEqual(self.successResultOf(d), u'\u2603')

    def test_text_content_default_encoding_no_param(self):
        """
        A ``text/plain`` Content-Type with no charset parameter falls back to
        the default encoding. The lone byte 0xA1 decodes to U+00A1, which is
        consistent with an ISO-8859-1 default.
        """
        self.response.headers = Headers(
            {b'Content-Type': [b'text/plain']})

        d = text_content(self.response)

        self.protocol.dataReceived(b'\xa1')
        self.protocol.connectionLost(Failure(ResponseDone()))

        self.assertEqual(self.successResultOf(d), u'\xa1')

    def test_text_content_default_encoding_no_header(self):
        """
        With no Content-Type header at all, the body is decoded with the
        same default encoding as the no-charset case: 0xA1 becomes U+00A1.
        """
        self.response.headers = Headers()

        d = text_content(self.response)

        self.protocol.dataReceived(b'\xa1')
        self.protocol.connectionLost(Failure(ResponseDone()))

        self.assertEqual(self.successResultOf(d), u'\xa1')

    def test_content_application_json_default_encoding(self):
        """
        An ``application/json`` Content-Type with no charset parameter is
        decoded as UTF-8 by ``text_content``.
        """
        self.response.headers = Headers(
            {b'Content-Type': [b'application/json']})

        d = text_content(self.response)

        self.protocol.dataReceived(b'gr\xc3\xbcn')
        self.protocol.connectionLost(Failure(ResponseDone()))

        self.assertEqual(self.successResultOf(d), u'grün')

    def test_text_content_unicode_headers(self):
        """
        Header parsing is robust against unicode header names and values.
        """
        self.response.headers = Headers({
            b'Content-Type': [
                u'text/plain; charset="UTF-16BE"; u=ᛃ'.encode('utf-8')],
            u'Coördination'.encode('iso-8859-1'): [
                u'koʊˌɔrdɪˈneɪʃən'.encode('utf-8')],
        })

        d = text_content(self.response)

        self.protocol.dataReceived(u'ᚠᚡ'.encode('UTF-16BE'))
        self.protocol.connectionLost(Failure(ResponseDone()))

        self.assertEqual(self.successResultOf(d), u'ᚠᚡ')
219