# -*- coding: utf-8 -*-
"""
Tests for treq's body-consuming helpers: ``collect``, ``content``,
``json_content`` and ``text_content``.
"""
import mock

from twisted.python.failure import Failure

from twisted.trial.unittest import TestCase
from twisted.web.http_headers import Headers
from twisted.web.client import ResponseDone, ResponseFailed
from twisted.web.http import PotentialDataLoss

from treq import collect, content, json_content, text_content
from treq.client import _BufferedResponse


class ContentTests(TestCase):
    """
    Exercise treq's content helpers against a ``_BufferedResponse`` that wraps
    a mock response. Body bytes are pushed by hand into whichever protocol the
    code under test passes to ``deliverBody``.
    """

    def setUp(self):
        wrapped = mock.Mock()
        self.protocol = None

        def capture_protocol(protocol):
            # Keep the protocol handed to deliverBody so each test can feed
            # body bytes into it and terminate the body itself.
            self.protocol = protocol

        wrapped.deliverBody.side_effect = capture_protocol
        self.response = _BufferedResponse(wrapped)

    def _feed(self, *chunks):
        """
        Pass each of ``chunks`` to the captured protocol's ``dataReceived``,
        in the order given.
        """
        for chunk in chunks:
            self.protocol.dataReceived(chunk)

    def _finish(self, reason=None):
        """
        End the body by calling ``connectionLost`` on the captured protocol.

        :param reason: exception instance to wrap in a ``Failure``; when
            omitted, ``ResponseDone()`` is used (a clean end of body).
        """
        if reason is None:
            reason = ResponseDone()
        self.protocol.connectionLost(Failure(reason))

    def test_collect(self):
        """
        ``collect`` hands every body chunk, in arrival order, to the supplied
        callable and fires with ``None`` once the body ends cleanly.
        """
        received = []
        deferred = collect(self.response, received.append)

        self._feed(b'{', b'"msg": "hell', b'o"}')
        self._finish()

        self.assertEqual(self.successResultOf(deferred), None)
        self.assertEqual(received, [b'{', b'"msg": "hell', b'o"}'])

    def test_collect_failure(self):
        """
        When the body ends with ``ResponseFailed``, the ``collect`` Deferred
        fails with it; the bytes that arrived beforehand were still passed to
        the callable.
        """
        received = []
        deferred = collect(self.response, received.append)

        self._feed(b'foo')
        self._finish(ResponseFailed("test failure"))

        self.failureResultOf(deferred, ResponseFailed)
        self.assertEqual(received, [b'foo'])

    def test_collect_failure_potential_data_loss(self):
        """
        A body that ends with ``PotentialDataLoss`` is treated as a success.
        """
        received = []
        deferred = collect(self.response, received.append)

        self._feed(b'foo')
        self._finish(PotentialDataLoss())

        self.assertEqual(self.successResultOf(deferred), None)
        self.assertEqual(received, [b'foo'])

    def test_collect_0_length(self):
        """
        With ``length`` set to 0, ``collect`` fires with ``None`` and never
        invokes the callable.
        """
        self.response.length = 0

        def unexpected(chunk):
            self.fail("Unexpectedly called with: {0}".format(chunk))

        deferred = collect(self.response, unexpected)

        self.assertEqual(self.successResultOf(deferred), None)

    def test_content(self):
        """
        ``content`` fires with all body chunks joined together.
        """
        deferred = content(self.response)

        self._feed(b'foo', b'bar')
        self._finish()

        self.assertEqual(self.successResultOf(deferred), b'foobar')

    def test_content_cached(self):
        """
        After the body has been read once, a further ``content`` call yields
        a distinct Deferred with the same bytes and does not call
        ``deliverBody`` on the wrapped response again.
        """
        first = content(self.response)

        self._feed(b'foo', b'bar')
        self._finish()

        self.assertEqual(self.successResultOf(first), b'foobar')

        def refuse_delivery(protocol):
            self.fail("deliverBody unexpectedly called.")

        self.response.original.deliverBody.side_effect = refuse_delivery

        second = content(self.response)

        self.assertEqual(self.successResultOf(second), b'foobar')
        self.assertNotIdentical(first, second)

    def test_content_multiple_waiters(self):
        """
        Two ``content`` calls made before any data arrives get separate
        Deferreds, and both fire with the body.
        """
        waiters = [content(self.response), content(self.response)]

        self._feed(b'foo')
        self._finish()

        for waiter in waiters:
            self.assertEqual(self.successResultOf(waiter), b'foo')
        self.assertNotIdentical(*waiters)

    def test_json_content(self):
        """
        ``json_content`` fires with the body parsed as JSON.
        """
        self.response.headers = Headers()
        deferred = json_content(self.response)

        self._feed(b'{"msg":"hello!"}')
        self._finish()

        self.assertEqual(self.successResultOf(deferred), {"msg": "hello!"})

    def test_json_content_unicode(self):
        """
        UTF-8 encoded JSON containing non-ASCII characters is decoded
        correctly when no charset is declared.

        RFC 7159, section 8.1: "JSON text SHALL be encoded in UTF-8, UTF-16,
        or UTF-32. The default encoding is UTF-8"
        """
        self.response.headers = Headers()
        deferred = json_content(self.response)

        self._feed(u'{"msg":"hëlló!"}'.encode('utf-8'))
        self._finish()

        self.assertEqual(self.successResultOf(deferred), {u'msg': u'hëlló!'})

    def test_json_content_utf16(self):
        """
        JSON is decoded using the charset named in the ``Content-Type``
        header, here a quoted ``UTF-16LE``.
        """
        self.response.headers = Headers({
            b'Content-Type': [b"application/json; charset='UTF-16LE'"],
        })
        deferred = json_content(self.response)

        self._feed(u'{"msg":"hëlló!"}'.encode('UTF-16LE'))
        self._finish()

        self.assertEqual(self.successResultOf(deferred), {u'msg': u'hëlló!'})

    def test_text_content(self):
        """
        ``text_content`` decodes the body with the charset given in the
        ``Content-Type`` header.
        """
        self.response.headers = Headers(
            {b'Content-Type': [b'text/plain; charset=utf-8']})
        deferred = text_content(self.response)

        self._feed(b'\xe2\x98\x83')
        self._finish()

        self.assertEqual(self.successResultOf(deferred), u'\u2603')

    def test_text_content_default_encoding_no_param(self):
        """
        A ``text/plain`` Content-Type without a charset parameter decodes the
        single byte 0xA1 to U+00A1 (a Latin-1 style default, not UTF-8).
        """
        self.response.headers = Headers(
            {b'Content-Type': [b'text/plain']})
        deferred = text_content(self.response)

        self._feed(b'\xa1')
        self._finish()

        self.assertEqual(self.successResultOf(deferred), u'\xa1')

    def test_text_content_default_encoding_no_header(self):
        """
        With no ``Content-Type`` header at all, the single byte 0xA1 decodes
        to U+00A1 (a Latin-1 style default, not UTF-8).
        """
        self.response.headers = Headers()
        deferred = text_content(self.response)

        self._feed(b'\xa1')
        self._finish()

        self.assertEqual(self.successResultOf(deferred), u'\xa1')

    def test_content_application_json_default_encoding(self):
        """
        ``text_content`` decodes an ``application/json`` body without a
        charset parameter as UTF-8.
        """
        self.response.headers = Headers(
            {b'Content-Type': [b'application/json']})
        deferred = text_content(self.response)

        self._feed(b'gr\xc3\xbcn')
        self._finish()

        self.assertEqual(self.successResultOf(deferred), u'grün')

    def test_text_content_unicode_headers(self):
        """
        Header parsing copes with non-ASCII header names and values.
        """
        raw_headers = {
            b'Content-Type': [
                u'text/plain; charset="UTF-16BE"; u=ᛃ'.encode('utf-8')],
            u'Coördination'.encode('iso-8859-1'): [
                u'koʊˌɔrdɪˈneɪʃən'.encode('utf-8')],
        }
        self.response.headers = Headers(raw_headers)
        deferred = text_content(self.response)

        self._feed(u'ᚠᚡ'.encode('UTF-16BE'))
        self._finish()

        self.assertEqual(self.successResultOf(deferred), u'ᚠᚡ')