1# -*- coding: utf-8 -*- 2from __future__ import unicode_literals 3 4from six import PY3 5from six import StringIO 6from tests.compat import unittest 7from webob import Request, Response 8 9import warnings 10import mock 11 12from webtest import TestApp 13from webtest.compat import to_bytes 14from webtest.lint import check_headers 15from webtest.lint import check_content_type 16from webtest.lint import check_environ 17from webtest.lint import IteratorWrapper 18from webtest.lint import WriteWrapper 19from webtest.lint import ErrorWrapper 20from webtest.lint import InputWrapper 21from webtest.lint import to_string 22from webtest.lint import middleware 23from webtest.lint import _assert_latin1_str 24 25from six import BytesIO 26 27 28def application(environ, start_response): 29 req = Request(environ) 30 resp = Response() 31 env_input = environ['wsgi.input'] 32 len_body = len(req.body) 33 env_input.input.seek(0) 34 if req.path_info == '/read': 35 resp.body = env_input.read(len_body) 36 elif req.path_info == '/read_line': 37 resp.body = env_input.readline(len_body) 38 elif req.path_info == '/read_lines': 39 resp.body = b'-'.join(env_input.readlines(len_body)) 40 elif req.path_info == '/close': 41 resp.body = env_input.close() 42 return resp(environ, start_response) 43 44 45class TestLatin1Assertion(unittest.TestCase): 46 47 def test_valid_type(self): 48 value = "useful-inførmation-5" 49 if not PY3: 50 value = value.encode("latin1") 51 assert value == _assert_latin1_str(value, "fail") 52 53 def test_invalid_type(self): 54 value = b"useful-information-5" 55 if not PY3: 56 value = value.decode("utf8") 57 self.assertRaises(AssertionError, _assert_latin1_str, value, "fail") 58 59 60class TestToString(unittest.TestCase): 61 62 def test_to_string(self): 63 self.assertEqual(to_string('foo'), 'foo') 64 self.assertEqual(to_string(b'foo'), 'foo') 65 66 67class TestMiddleware(unittest.TestCase): 68 69 def test_lint_too_few_args(self): 70 linter = middleware(application) 71 with self.assertRaisesRegexp(AssertionError, "Two arguments required"): 72 linter() 73 with self.assertRaisesRegexp(AssertionError, "Two arguments required"): 74 linter({}) 75 76 def test_lint_no_keyword_args(self): 77 linter = middleware(application) 78 with self.assertRaisesRegexp(AssertionError, "No keyword arguments " 79 "allowed"): 80 linter({}, 'foo', baz='baz') 81 82 # TODO: test start_response_wrapper 83 84 @mock.patch.multiple('webtest.lint', 85 check_environ=lambda x: True, # don't block too early 86 InputWrapper=lambda x: True) 87 def test_lint_iterator_returned(self): 88 linter = middleware(lambda x, y: None) # None is not an iterator 89 msg = "The application must return an iterator, if only an empty list" 90 with self.assertRaisesRegexp(AssertionError, msg): 91 linter({'wsgi.input': 'foo', 'wsgi.errors': 'foo'}, 'foo') 92 93 94class TestInputWrapper(unittest.TestCase): 95 def test_read(self): 96 app = TestApp(application) 97 resp = app.post('/read', 'hello') 98 self.assertEqual(resp.body, b'hello') 99 100 def test_readline(self): 101 app = TestApp(application) 102 resp = app.post('/read_line', 'hello\n') 103 self.assertEqual(resp.body, b'hello\n') 104 105 def test_readlines(self): 106 app = TestApp(application) 107 resp = app.post('/read_lines', 'hello\nt\n') 108 self.assertEqual(resp.body, b'hello\n-t\n') 109 110 def test_close(self): 111 input_wrapper = InputWrapper(None) 112 self.assertRaises(AssertionError, input_wrapper.close) 113 114 def test_iter(self): 115 data = to_bytes("A line\nAnother line\nA final line\n") 116 input_wrapper = InputWrapper(BytesIO(data)) 117 self.assertEquals(to_bytes("").join(input_wrapper), data, '') 118 119 def test_seek(self): 120 data = to_bytes("A line\nAnother line\nA final line\n") 121 input_wrapper = InputWrapper(BytesIO(data)) 122 input_wrapper.seek(0) 123 self.assertEquals(to_bytes("").join(input_wrapper), data, '') 124 125 126class TestMiddleware2(unittest.TestCase): 127 def test_exc_info(self): 128 def application_exc_info(environ, start_response): 129 body = to_bytes('body stuff') 130 headers = [ 131 ('Content-Type', 'text/plain; charset=utf-8'), 132 ('Content-Length', str(len(body)))] 133 # PEP 3333 requires native strings: 134 headers = [(str(k), str(v)) for k, v in headers] 135 start_response(to_bytes('200 OK'), headers, ('stuff',)) 136 return [body] 137 138 app = TestApp(application_exc_info) 139 app.get('/') 140 # don't know what to assert here... a bit cheating, just covers code 141 142 143class TestCheckContentType(unittest.TestCase): 144 def test_no_content(self): 145 status = "204 No Content" 146 headers = [ 147 ('Content-Type', 'text/plain; charset=utf-8'), 148 ('Content-Length', '4') 149 ] 150 self.assertRaises(AssertionError, check_content_type, status, headers) 151 152 def test_no_content_type(self): 153 status = "200 OK" 154 headers = [ 155 ('Content-Length', '4') 156 ] 157 self.assertRaises(AssertionError, check_content_type, status, headers) 158 159 160class TestCheckHeaders(unittest.TestCase): 161 162 @unittest.skipIf(PY3, 'unicode is str in Python3') 163 def test_header_unicode_name(self): 164 headers = [(u'X-Price', str('100'))] 165 self.assertRaises(AssertionError, check_headers, headers) 166 167 @unittest.skipIf(PY3, 'unicode is str in Python3') 168 def test_header_unicode_value(self): 169 headers = [(str('X-Price'), u'100')] 170 self.assertRaises(AssertionError, check_headers, headers) 171 172 @unittest.skipIf(not PY3, 'bytes is str in Python2') 173 def test_header_bytes_name(self): 174 headers = [(b'X-Price', '100')] 175 self.assertRaises(AssertionError, check_headers, headers) 176 177 @unittest.skipIf(not PY3, 'bytes is str in Python2') 178 def test_header_bytes_value(self): 179 headers = [('X-Price', b'100')] 180 self.assertRaises(AssertionError, check_headers, headers) 181 182 def test_header_non_latin1_value(self): 183 headers = [(str('X-Price'), '100€')] 184 self.assertRaises(AssertionError, check_headers, headers) 185 186 def test_header_non_latin1_name(self): 187 headers = [('X-€', str('foo'))] 188 self.assertRaises(AssertionError, check_headers, headers) 189 190 191class TestCheckEnviron(unittest.TestCase): 192 def test_no_query_string(self): 193 environ = { 194 'REQUEST_METHOD': str('GET'), 195 'SERVER_NAME': str('localhost'), 196 'SERVER_PORT': str('80'), 197 'wsgi.version': (1, 0, 1), 198 'wsgi.input': StringIO('test'), 199 'wsgi.errors': StringIO(), 200 'wsgi.multithread': None, 201 'wsgi.multiprocess': None, 202 'wsgi.run_once': None, 203 'wsgi.url_scheme': 'http', 204 'PATH_INFO': str('/'), 205 } 206 with warnings.catch_warnings(record=True) as w: 207 warnings.simplefilter("always") 208 check_environ(environ) 209 self.assertEqual(len(w), 1, "We should have only one warning") 210 self.assertTrue( 211 "QUERY_STRING" in str(w[-1].message), 212 "The warning message should say something about QUERY_STRING") 213 214 def test_no_valid_request(self): 215 environ = { 216 'REQUEST_METHOD': str('PROPFIND'), 217 'SERVER_NAME': str('localhost'), 218 'SERVER_PORT': str('80'), 219 'wsgi.version': (1, 0, 1), 220 'wsgi.input': StringIO('test'), 221 'wsgi.errors': StringIO(), 222 'wsgi.multithread': None, 223 'wsgi.multiprocess': None, 224 'wsgi.run_once': None, 225 'wsgi.url_scheme': 'http', 226 'PATH_INFO': str('/'), 227 'QUERY_STRING': str(''), 228 } 229 with warnings.catch_warnings(record=True) as w: 230 warnings.simplefilter("always") 231 check_environ(environ) 232 self.assertEqual(len(w), 1, "We should have only one warning") 233 self.assertTrue( 234 "REQUEST_METHOD" in str(w[-1].message), 235 "The warning message should say something " 236 "about REQUEST_METHOD") 237 238 def test_handles_native_strings_in_variables(self): 239 # "native string" means unicode in py3, but bytes in py2 240 path = '/umläut' 241 if not PY3: 242 path = path.encode('utf-8') 243 environ = { 244 'REQUEST_METHOD': str('GET'), 245 'SERVER_NAME': str('localhost'), 246 'SERVER_PORT': str('80'), 247 'wsgi.version': (1, 0, 1), 248 'wsgi.input': StringIO('test'), 249 'wsgi.errors': StringIO(), 250 'wsgi.multithread': None, 251 'wsgi.multiprocess': None, 252 'wsgi.run_once': None, 253 'wsgi.url_scheme': 'http', 254 'PATH_INFO': path, 255 'QUERY_STRING': str(''), 256 } 257 with warnings.catch_warnings(record=True) as w: 258 warnings.simplefilter("always") 259 check_environ(environ) 260 self.assertEqual(0, len(w), "We should have no warning") 261 262 263class TestIteratorWrapper(unittest.TestCase): 264 def test_close(self): 265 class MockIterator(object): 266 267 def __init__(self): 268 self.closed = False 269 270 def __iter__(self): 271 return self 272 273 def __next__(self): 274 return None 275 276 next = __next__ 277 278 def close(self): 279 self.closed = True 280 281 mock = MockIterator() 282 wrapper = IteratorWrapper(mock, None) 283 wrapper.close() 284 285 self.assertTrue(mock.closed, "Original iterator has not been closed") 286 287 288class TestWriteWrapper(unittest.TestCase): 289 def test_wrong_type(self): 290 write_wrapper = WriteWrapper(None) 291 self.assertRaises(AssertionError, write_wrapper, 'not a binary') 292 293 def test_normal(self): 294 class MockWriter(object): 295 def __init__(self): 296 self.written = [] 297 298 def __call__(self, s): 299 self.written.append(s) 300 301 data = to_bytes('foo') 302 mock = MockWriter() 303 write_wrapper = WriteWrapper(mock) 304 write_wrapper(data) 305 self.assertEqual( 306 mock.written, [data], 307 "WriterWrapper should call original writer when data is binary " 308 "type") 309 310 311class TestErrorWrapper(unittest.TestCase): 312 313 def test_dont_close(self): 314 error_wrapper = ErrorWrapper(None) 315 self.assertRaises(AssertionError, error_wrapper.close) 316 317 class FakeError(object): 318 def __init__(self): 319 self.written = [] 320 self.flushed = False 321 322 def write(self, s): 323 self.written.append(s) 324 325 def writelines(self, lines): 326 for line in lines: 327 self.write(line) 328 329 def flush(self): 330 self.flushed = True 331 332 def test_writelines(self): 333 fake_error = self.FakeError() 334 error_wrapper = ErrorWrapper(fake_error) 335 data = [to_bytes('a line'), to_bytes('another line')] 336 error_wrapper.writelines(data) 337 self.assertEqual(fake_error.written, data, 338 "ErrorWrapper should call original writer") 339 340 def test_flush(self): 341 fake_error = self.FakeError() 342 error_wrapper = ErrorWrapper(fake_error) 343 error_wrapper.flush() 344 self.assertTrue( 345 fake_error.flushed, 346 "ErrorWrapper should have called original wsgi_errors's flush") 347