1# -*- coding: utf-8 -*-
2from __future__ import unicode_literals
3
4from six import PY3
5from six import StringIO
6from tests.compat import unittest
7from webob import Request, Response
8
9import warnings
10import mock
11
12from webtest import TestApp
13from webtest.compat import to_bytes
14from webtest.lint import check_headers
15from webtest.lint import check_content_type
16from webtest.lint import check_environ
17from webtest.lint import IteratorWrapper
18from webtest.lint import WriteWrapper
19from webtest.lint import ErrorWrapper
20from webtest.lint import InputWrapper
21from webtest.lint import to_string
22from webtest.lint import middleware
23from webtest.lint import _assert_latin1_str
24
25from six import BytesIO
26
27
def application(environ, start_response):
    """WSGI test app whose response body comes from ``wsgi.input``.

    The request path picks the input method that is exercised: ``/read``,
    ``/read_line``, ``/read_lines`` or ``/close``.  Any other path leaves
    the response body untouched.
    """
    request = Request(environ)
    response = Response()
    # Keep a reference to the input object *before* ``request.body`` is
    # evaluated, exactly as the statement order has always been.
    # NOTE(review): presumably this is webtest.lint's InputWrapper, whose
    # ``.input`` attribute is the wrapped stream -- confirm.
    wsgi_input = environ['wsgi.input']
    body_size = len(request.body)
    # Rewind the wrapped stream so the reads below start at its beginning.
    wsgi_input.input.seek(0)

    # Lambdas keep every action lazy: only the selected one touches the
    # input object.
    actions = {
        '/read': lambda: wsgi_input.read(body_size),
        '/read_line': lambda: wsgi_input.readline(body_size),
        '/read_lines': lambda: b'-'.join(wsgi_input.readlines(body_size)),
        '/close': lambda: wsgi_input.close(),
    }
    action = actions.get(request.path_info)
    if action is not None:
        response.body = action()
    return response(environ, start_response)
43
44
class TestLatin1Assertion(unittest.TestCase):
    """Tests for ``_assert_latin1_str``."""

    def test_valid_type(self):
        # Native string: text on Python 3, latin1-encoded bytes on Python 2.
        value = "useful-inførmation-5"
        if not PY3:
            value = value.encode("latin1")
        # Use a TestCase assertion rather than a bare ``assert`` so the
        # check still runs under ``python -O`` and reports both values on
        # failure, like the other tests in this module.
        self.assertEqual(value, _assert_latin1_str(value, "fail"))

    def test_invalid_type(self):
        # The non-native string type (bytes on Python 3, unicode on
        # Python 2) is expected to trigger an AssertionError.
        value = b"useful-information-5"
        if not PY3:
            value = value.decode("utf8")
        self.assertRaises(AssertionError, _assert_latin1_str, value, "fail")
58
59
class TestToString(unittest.TestCase):
    """Tests for ``to_string``."""

    def test_to_string(self):
        # Text input first, then bytes input; both must come back as 'foo'.
        for raw in ('foo', b'foo'):
            self.assertEqual(to_string(raw), 'foo')
65
66
class TestMiddleware(unittest.TestCase):
    """Checks on how the ``middleware`` wrapper validates its call."""

    def _assert_raises_regex(self, exception, pattern):
        # ``assertRaisesRegexp`` is a deprecated alias that Python 3.12
        # removed.  Prefer the current ``assertRaisesRegex`` and fall back
        # to the old spelling only where the new one does not exist
        # (Python 2.7's unittest).
        check = getattr(self, 'assertRaisesRegex', None)
        if check is None:
            check = self.assertRaisesRegexp
        return check(exception, pattern)

    def test_lint_too_few_args(self):
        # Calling the wrapped app with zero or one positional argument
        # must be rejected.
        linter = middleware(application)
        with self._assert_raises_regex(AssertionError,
                                       "Two arguments required"):
            linter()
        with self._assert_raises_regex(AssertionError,
                                       "Two arguments required"):
            linter({})

    def test_lint_no_keyword_args(self):
        # Keyword arguments must be rejected.
        linter = middleware(application)
        with self._assert_raises_regex(AssertionError,
                                       "No keyword arguments allowed"):
            linter({}, 'foo', baz='baz')

    # TODO: test start_response_wrapper

    @mock.patch.multiple('webtest.lint',
                         check_environ=lambda x: True,  # don't block too early
                         InputWrapper=lambda x: True)
    def test_lint_iterator_returned(self):
        # With the environ check and the input wrapper stubbed out, the
        # call gets as far as inspecting the application's return value.
        linter = middleware(lambda x, y: None)  # None is not an iterator
        msg = "The application must return an iterator, if only an empty list"
        with self._assert_raises_regex(AssertionError, msg):
            linter({'wsgi.input': 'foo', 'wsgi.errors': 'foo'}, 'foo')
92
93
class TestInputWrapper(unittest.TestCase):
    """Tests for ``InputWrapper``, directly and through ``application``."""

    def test_read(self):
        app = TestApp(application)
        resp = app.post('/read', 'hello')
        self.assertEqual(resp.body, b'hello')

    def test_readline(self):
        app = TestApp(application)
        resp = app.post('/read_line', 'hello\n')
        self.assertEqual(resp.body, b'hello\n')

    def test_readlines(self):
        # ``application`` joins the lines it reads with b'-'.
        app = TestApp(application)
        resp = app.post('/read_lines', 'hello\nt\n')
        self.assertEqual(resp.body, b'hello\n-t\n')

    def test_close(self):
        # Closing the wrapper is expected to raise an AssertionError.
        input_wrapper = InputWrapper(None)
        self.assertRaises(AssertionError, input_wrapper.close)

    def test_iter(self):
        # Iterating over the wrapper must yield the complete payload.
        data = to_bytes("A line\nAnother line\nA final line\n")
        input_wrapper = InputWrapper(BytesIO(data))
        # ``assertEquals`` is a deprecated alias removed in Python 3.12;
        # the empty failure message it carried added nothing.
        self.assertEqual(to_bytes("").join(input_wrapper), data)

    def test_seek(self):
        # After seeking to the start, iteration must still yield the
        # complete payload.
        data = to_bytes("A line\nAnother line\nA final line\n")
        input_wrapper = InputWrapper(BytesIO(data))
        input_wrapper.seek(0)
        self.assertEqual(to_bytes("").join(input_wrapper), data)
124
125
class TestMiddleware2(unittest.TestCase):
    """Coverage for a ``start_response`` call that passes ``exc_info``."""

    def test_exc_info(self):
        def application_exc_info(environ, start_response):
            payload = to_bytes('body stuff')
            # PEP 3333 requires native strings:
            native_headers = [
                (str(name), str(value))
                for name, value in [
                    ('Content-Type', 'text/plain; charset=utf-8'),
                    ('Content-Length', str(len(payload))),
                ]
            ]
            # The third positional argument is the ``exc_info`` slot.
            start_response(to_bytes('200 OK'), native_headers, ('stuff',))
            return [payload]

        TestApp(application_exc_info).get('/')
        # don't know what to assert here... a bit cheating, just covers code
141
142
class TestCheckContentType(unittest.TestCase):
    """Tests for ``check_content_type``; both cases expect a rejection."""

    def test_no_content(self):
        # A "204 No Content" status combined with a Content-Type header.
        status = "204 No Content"
        headers = [
            ('Content-Type', 'text/plain; charset=utf-8'),
            ('Content-Length', '4'),
        ]
        with self.assertRaises(AssertionError):
            check_content_type(status, headers)

    def test_no_content_type(self):
        # A "200 OK" status without any Content-Type header.
        status = "200 OK"
        headers = [('Content-Length', '4')]
        with self.assertRaises(AssertionError):
            check_content_type(status, headers)
158
159
class TestCheckHeaders(unittest.TestCase):
    """Tests for ``check_headers``; every case expects an AssertionError."""

    def _assert_rejected(self, headers):
        # Shared assertion: ``check_headers`` must refuse *headers*.
        with self.assertRaises(AssertionError):
            check_headers(headers)

    @unittest.skipIf(PY3, 'unicode is str in Python3')
    def test_header_unicode_name(self):
        self._assert_rejected([(u'X-Price', str('100'))])

    @unittest.skipIf(PY3, 'unicode is str in Python3')
    def test_header_unicode_value(self):
        self._assert_rejected([(str('X-Price'), u'100')])

    @unittest.skipIf(not PY3, 'bytes is str in Python2')
    def test_header_bytes_name(self):
        self._assert_rejected([(b'X-Price', '100')])

    @unittest.skipIf(not PY3, 'bytes is str in Python2')
    def test_header_bytes_value(self):
        self._assert_rejected([('X-Price', b'100')])

    def test_header_non_latin1_value(self):
        # The euro sign is used as a character outside latin1.
        self._assert_rejected([(str('X-Price'), '100€')])

    def test_header_non_latin1_name(self):
        self._assert_rejected([('X-€', str('foo'))])
189
190
class TestCheckEnviron(unittest.TestCase):
    """Tests for the warnings issued by ``check_environ``."""

    @staticmethod
    def _make_environ(method, path, query_string=None):
        # Build a fresh environ for one test.  ``QUERY_STRING`` is only
        # present when a value is supplied.
        environ = {
            'REQUEST_METHOD': method,
            'SERVER_NAME': str('localhost'),
            'SERVER_PORT': str('80'),
            'wsgi.version': (1, 0, 1),
            'wsgi.input': StringIO('test'),
            'wsgi.errors': StringIO(),
            'wsgi.multithread': None,
            'wsgi.multiprocess': None,
            'wsgi.run_once': None,
            'wsgi.url_scheme': 'http',
            'PATH_INFO': path,
        }
        if query_string is not None:
            environ['QUERY_STRING'] = query_string
        return environ

    @staticmethod
    def _warnings_from(environ):
        # Run ``check_environ`` with every warning recorded and hand the
        # recorded list back to the caller.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            check_environ(environ)
        return caught

    def test_no_query_string(self):
        environ = self._make_environ(str('GET'), str('/'))
        caught = self._warnings_from(environ)
        self.assertEqual(len(caught), 1, "We should have only one warning")
        self.assertTrue(
            "QUERY_STRING" in str(caught[-1].message),
            "The warning message should say something about QUERY_STRING")

    def test_no_valid_request(self):
        environ = self._make_environ(str('PROPFIND'), str('/'), str(''))
        caught = self._warnings_from(environ)
        self.assertEqual(len(caught), 1, "We should have only one warning")
        self.assertTrue(
            "REQUEST_METHOD" in str(caught[-1].message),
            "The warning message should say something "
            "about REQUEST_METHOD")

    def test_handles_native_strings_in_variables(self):
        # "native string" means unicode in py3, but bytes in py2
        path = '/umläut'
        if not PY3:
            path = path.encode('utf-8')
        environ = self._make_environ(str('GET'), path, str(''))
        caught = self._warnings_from(environ)
        self.assertEqual(0, len(caught), "We should have no warning")
261
262
class TestIteratorWrapper(unittest.TestCase):
    """Tests for ``IteratorWrapper``."""

    def test_close(self):
        class MockIterator(object):
            # Minimal iterator that records whether ``close`` was called.

            def __init__(self):
                self.closed = False

            def __iter__(self):
                return self

            def next(self):
                return None

            __next__ = next

            def close(self):
                self.closed = True

        original = MockIterator()
        IteratorWrapper(original, None).close()

        self.assertTrue(original.closed,
                        "Original iterator has not been closed")
286
287
class TestWriteWrapper(unittest.TestCase):
    """Tests for ``WriteWrapper``."""

    def test_wrong_type(self):
        # Text (non-binary) data is expected to be refused.
        wrapper = WriteWrapper(None)
        with self.assertRaises(AssertionError):
            wrapper('not a binary')

    def test_normal(self):
        written = []

        def record(chunk):
            # Stand-in for the original writer: remember each chunk.
            written.append(chunk)

        payload = to_bytes('foo')
        wrapper = WriteWrapper(record)
        wrapper(payload)
        self.assertEqual(
            written, [payload],
            "WriterWrapper should call original writer when data is binary "
            "type")
309
310
class TestErrorWrapper(unittest.TestCase):
    """Tests for ``ErrorWrapper``."""

    class FakeError(object):
        # Stand-in error stream recording writes and flushes.

        def __init__(self):
            self.flushed = False
            self.written = []

        def write(self, s):
            self.written.append(s)

        def writelines(self, lines):
            for chunk in lines:
                self.write(chunk)

        def flush(self):
            self.flushed = True

    def test_dont_close(self):
        # Closing the wrapper is expected to raise an AssertionError.
        wrapper = ErrorWrapper(None)
        with self.assertRaises(AssertionError):
            wrapper.close()

    def test_writelines(self):
        stream = self.FakeError()
        lines = [to_bytes('a line'), to_bytes('another line')]
        ErrorWrapper(stream).writelines(lines)
        self.assertEqual(stream.written, lines,
                         "ErrorWrapper should call original writer")

    def test_flush(self):
        stream = self.FakeError()
        ErrorWrapper(stream).flush()
        self.assertTrue(
            stream.flushed,
            "ErrorWrapper should have called original wsgi_errors's flush")
347