1from unittest import mock
2from urllib.parse import urlencode
3
4import marshmallow as ma
5import pytest
6import tornado.concurrent
7import tornado.http1connection
8import tornado.httpserver
9import tornado.httputil
10import tornado.ioloop
11import tornado.web
12from tornado.testing import AsyncHTTPTestCase
13from webargs import fields, missing
14from webargs.core import json, parse_json
15from webargs.tornadoparser import (
16    WebArgsTornadoMultiDictProxy,
17    parser,
18    use_args,
19    use_kwargs,
20)
21
22
# Shared string literals; ``value`` is the expected parsed string in the
# ``TestParse`` assertions.
name, value = "name", "value"
25
26
class AuthorSchema(ma.Schema):
    """Schema with a validated ``name`` string and a ``works`` string list."""

    # ``name`` is only valid when it has at least three characters.
    name = fields.Str(
        missing="World",
        validate=lambda candidate: len(candidate) >= 3,
    )
    works = fields.List(fields.Str())


# Module-wide schema instance shared by the tests in this file.
author_schema = AuthorSchema()
33
34
def test_tornado_multidictproxy():
    """Check ``WebArgsTornadoMultiDictProxy.get`` for present and absent keys."""
    # Each case is (wrapped dict, key to look up, expected result).
    cases = (
        ({"name": "Sophocles"}, "name", "Sophocles"),
        ({"name": "Sophocles"}, "works", missing),
        ({"works": ["Antigone", "Oedipus Rex"]}, "works", ["Antigone", "Oedipus Rex"]),
        ({"works": ["Antigone", "Oedipus at Colonus"]}, "name", missing),
    )
    for raw_data, key, wanted in cases:
        wrapped = WebArgsTornadoMultiDictProxy(raw_data, author_schema)
        looked_up = wrapped.get(key)
        assert looked_up == wanted
44
45
class TestQueryArgs:
    """Tests for ``parser.load_querystring`` on GET requests."""

    def test_it_should_get_single_values(self):
        """A single query pair is returned as a scalar."""
        request = make_get_request([("name", "Aeschylus")])
        loaded = parser.load_querystring(request, author_schema)
        assert loaded["name"] == "Aeschylus"

    def test_it_should_get_multiple_values(self):
        """A repeated key for a list field is returned as a list."""
        pairs = [("works", "Agamemnon"), ("works", "Nereids")]
        loaded = parser.load_querystring(make_get_request(pairs), author_schema)
        assert loaded["works"] == ["Agamemnon", "Nereids"]

    def test_it_should_return_missing_if_not_present(self):
        """An empty query string yields ``missing`` for both fields."""
        request = make_get_request([])
        loaded = parser.load_querystring(request, author_schema)
        for field_name in ("name", "works"):
            assert loaded[field_name] is missing
64        assert result["works"] is missing
65
66
class TestFormArgs:
    """Tests for ``parser.load_form`` on form-encoded requests."""

    def test_it_should_get_single_values(self):
        """A single form pair is returned as a scalar."""
        request = make_form_request([("name", "Aristophanes")])
        loaded = parser.load_form(request, author_schema)
        assert loaded["name"] == "Aristophanes"

    def test_it_should_get_multiple_values(self):
        """A repeated key for a list field is returned as a list."""
        pairs = [("works", "The Wasps"), ("works", "The Frogs")]
        loaded = parser.load_form(make_form_request(pairs), author_schema)
        assert loaded["works"] == ["The Wasps", "The Frogs"]

    def test_it_should_return_missing_if_not_present(self):
        """An empty form body yields ``missing`` for both fields."""
        request = make_form_request([])
        loaded = parser.load_form(request, author_schema)
        for field_name in ("name", "works"):
            assert loaded[field_name] is missing
86
87
class TestJSONArgs:
    """Tests for ``parser.load_json``."""

    def test_it_should_get_single_values(self):
        """A scalar JSON member is returned unchanged."""
        request = make_json_request({"name": "Euripides"})
        loaded = parser.load_json(request, author_schema)
        assert loaded["name"] == "Euripides"

    def test_parsing_request_with_vendor_content_type(self):
        """A vendor-specific JSON content type is still parsed as JSON."""
        vendor_type = "application/vnd.api+json; charset=UTF-8"
        request = make_json_request({"name": "Euripides"}, content_type=vendor_type)
        loaded = parser.load_json(request, author_schema)
        assert loaded["name"] == "Euripides"

    def test_it_should_get_multiple_values(self):
        """A JSON array is returned as a list."""
        request = make_json_request({"works": ["Medea", "Electra"]})
        loaded = parser.load_json(request, author_schema)
        assert loaded["works"] == ["Medea", "Electra"]

    def test_it_should_get_multiple_nested_values(self):
        """A list of nested objects is returned intact."""

        class CustomSchema(ma.Schema):
            works = fields.List(
                fields.Nested({"author": fields.Str(), "workname": fields.Str()})
            )

        payload = {
            "works": [
                {"author": "Euripides", "workname": "Hecuba"},
                {"author": "Aristophanes", "workname": "The Birds"},
            ]
        }
        loaded = parser.load_json(make_json_request(payload), CustomSchema())
        assert loaded["works"] == [
            {"author": "Euripides", "workname": "Hecuba"},
            {"author": "Aristophanes", "workname": "The Birds"},
        ]

    def test_it_should_not_include_fieldnames_if_not_present(self):
        """An empty JSON object loads as an empty dict."""
        request = make_json_request({})
        loaded = parser.load_json(request, author_schema)
        assert loaded == {}

    def test_it_should_handle_type_error_on_load_json(self, loop):
        """A ``Future`` body yields ``missing`` rather than ``{}``."""
        # Unlike the valid-but-empty payload case, the expected result here is
        # ``missing`` and not ``{}``.
        # NOTE: `loop` is the pytest-aiohttp event loop fixture; an event loop
        # has to exist so that the future can be constructed.
        pending_body = tornado.concurrent.Future()
        request = make_request(
            body=pending_body, headers={"Content-Type": "application/json"}
        )
        assert parser.load_json(request, author_schema) is missing

    def test_it_should_handle_value_error_on_parse_json(self):
        """A request without a JSON payload loads as ``missing``."""
        # NOTE(review): the string is passed positionally, so it binds to
        # ``make_request``'s first parameter (``uri``) and the body stays
        # empty -- confirm whether ``body=`` was intended.
        request = make_request("this is json not")
        assert parser.load_json(request, author_schema) is missing
152
153
class TestHeadersArgs:
    """Tests for ``parser.load_headers``."""

    def test_it_should_get_single_values(self):
        """A single header value is returned as a scalar."""
        request = make_request(headers={"name": "Euphorion"})
        loaded = parser.load_headers(request, author_schema)
        assert loaded["name"] == "Euphorion"

    def test_it_should_get_multiple_values(self):
        """A list-valued header is returned as a list."""
        header_map = {"works": ["Prometheus Bound", "Prometheus Unbound"]}
        loaded = parser.load_headers(make_request(headers=header_map), author_schema)
        assert loaded["works"] == ["Prometheus Bound", "Prometheus Unbound"]

    def test_it_should_return_missing_if_not_present(self):
        """A request without headers yields ``missing`` for both fields."""
        loaded = parser.load_headers(make_request(), author_schema)
        for field_name in ("name", "works"):
            assert loaded[field_name] is missing
172
173
class TestFilesArgs:
    """Tests for ``parser.load_files``."""

    def test_it_should_get_single_values(self):
        """A single file entry is returned as a scalar."""
        request = make_files_request([("name", "Sappho")])
        loaded = parser.load_files(request, author_schema)
        assert loaded["name"] == "Sappho"

    def test_it_should_get_multiple_values(self):
        """A repeated key for a list field is returned as a list."""
        pairs = [("works", "Sappho 31"), ("works", "Ode to Aphrodite")]
        loaded = parser.load_files(make_files_request(pairs), author_schema)
        assert loaded["works"] == ["Sappho 31", "Ode to Aphrodite"]

    def test_it_should_return_missing_if_not_present(self):
        """A request without files yields ``missing`` for both fields."""
        request = make_files_request([])
        loaded = parser.load_files(request, author_schema)
        for field_name in ("name", "works"):
            assert loaded[field_name] is missing
193
194
class TestErrorHandler:
    """Tests for how the parser reports validation failures."""

    def test_it_should_raise_httperror_on_failed_validation(self):
        """A validator that always returns ``False`` triggers ``HTTPError``."""
        argmap = {"foo": fields.Field(validate=lambda _value: False)}
        request = make_json_request({"foo": 42})
        with pytest.raises(tornado.web.HTTPError):
            parser.parse(argmap, request)
200
201
class TestParse:
    """Tests for ``parser.parse`` across the supported request locations."""

    def test_it_should_parse_query_arguments(self):
        """Query pairs are parsed, with repeated keys collected into a list."""
        argmap = {"string": fields.Field(), "integer": fields.List(fields.Int())}
        pairs = [("string", "value"), ("integer", "1"), ("integer", "2")]

        result = parser.parse(argmap, make_get_request(pairs), location="query")

        assert result["integer"] == [1, 2]
        assert result["string"] == value

    def test_it_should_parse_form_arguments(self):
        """Form pairs are parsed, with repeated keys collected into a list."""
        argmap = {"string": fields.Field(), "integer": fields.List(fields.Int())}
        pairs = [("string", "value"), ("integer", "1"), ("integer", "2")]

        result = parser.parse(argmap, make_form_request(pairs), location="form")

        assert result["integer"] == [1, 2]
        assert result["string"] == value

    def test_it_should_parse_json_arguments(self):
        """A JSON body is parsed when no location is given."""
        argmap = {"string": fields.Str(), "integer": fields.List(fields.Int())}
        payload = {"string": "value", "integer": [1, 2]}

        result = parser.parse(argmap, make_json_request(payload))

        assert result["integer"] == [1, 2]
        assert result["string"] == value

    def test_it_should_raise_when_json_is_invalid(self):
        """A malformed JSON body raises a 400 ``HTTPError`` with a message."""
        argmap = {"foo": fields.Str()}
        # The trailing comma makes this body invalid JSON.
        request = make_request(
            body='{"foo": 42,}',
            headers={"Content-Type": "application/json"},
        )

        with pytest.raises(tornado.web.HTTPError) as excinfo:
            parser.parse(argmap, request)

        raised = excinfo.value
        assert raised.status_code == 400
        assert raised.messages == {"json": ["Invalid JSON body."]}

    def test_it_should_parse_header_arguments(self):
        """Header values are parsed from the ``headers`` location."""
        argmap = {"string": fields.Str(), "integer": fields.List(fields.Int())}
        header_map = {"string": "value", "integer": ["1", "2"]}

        result = parser.parse(
            argmap, make_request(headers=header_map), location="headers"
        )

        assert result["string"] == value
        assert result["integer"] == [1, 2]

    def test_it_should_parse_cookies_arguments(self):
        """Cookies are parsed; only the last ``integer`` cookie is expected."""
        argmap = {"string": fields.Str(), "integer": fields.List(fields.Int())}
        pairs = [("string", "value"), ("integer", "1"), ("integer", "2")]

        result = parser.parse(argmap, make_cookie_request(pairs), location="cookies")

        assert result["string"] == value
        assert result["integer"] == [2]

    def test_it_should_parse_files_arguments(self):
        """File entries are parsed from the ``files`` location."""
        argmap = {"string": fields.Str(), "integer": fields.List(fields.Int())}
        pairs = [("string", "value"), ("integer", "1"), ("integer", "2")]

        result = parser.parse(argmap, make_files_request(pairs), location="files")

        assert result["string"] == value
        assert result["integer"] == [1, 2]

    def test_it_should_parse_required_arguments(self):
        """A missing required field raises ``HTTPError`` with the message."""
        argmap = {"foo": fields.Field(required=True)}
        request = make_json_request({})

        with pytest.raises(
            tornado.web.HTTPError, match="Missing data for required field."
        ):
            parser.parse(argmap, request)

    def test_it_should_parse_multiple_arg_required(self):
        """A missing required list field raises ``HTTPError`` with the message."""
        argmap = {"foo": fields.List(fields.Int(), required=True)}
        request = make_json_request({})

        with pytest.raises(
            tornado.web.HTTPError, match="Missing data for required field."
        ):
            parser.parse(argmap, request)
298
299
class TestUseArgs:
    """Tests for the ``use_args`` and ``use_kwargs`` decorators."""

    def test_it_should_pass_parsed_as_first_argument(self):
        """``use_args`` hands the parsed dict over as one positional argument."""

        class Handler:
            request = make_json_request({"key": "value"})

            @use_args({"key": fields.Field()})
            def get(self, *args, **kwargs):
                parsed = args[0]
                assert parsed == {"key": "value"}
                assert kwargs == {}
                return True

        assert Handler().get() is True

    def test_it_should_pass_parsed_as_kwargs_arguments(self):
        """``use_kwargs`` hands the parsed values over as keyword arguments."""

        class Handler:
            request = make_json_request({"key": "value"})

            @use_kwargs({"key": fields.Field()})
            def get(self, *args, **kwargs):
                assert args == ()
                assert kwargs == {"key": "value"}
                return True

        assert Handler().get() is True

    def test_it_should_be_validate_arguments_when_validator_is_passed(self):
        """A failing ``validate`` callable makes the call raise ``HTTPError``."""

        class Handler:
            # 41 does not satisfy the ``> 42`` validator below.
            request = make_json_request({"foo": 41})

            @use_kwargs({"foo": fields.Int()}, validate=lambda args: args["foo"] > 42)
            def get(self, args):
                return True

        instance = Handler()
        with pytest.raises(tornado.web.HTTPError):
            instance.get()
342
343
def make_uri(args):
    """Return the ``/test`` path with ``args`` URL-encoded as its query string."""
    query_string = urlencode(args)
    return "/test?" + query_string
346
347
def make_form_body(args):
    """Return ``args`` URL-encoded for use as a form request body."""
    encoded = urlencode(args)
    return encoded
350
351
def make_json_body(args):
    """Return ``args`` serialized to a JSON string."""
    serialized = json.dumps(args)
    return serialized
354
355
def make_get_request(args):
    """Build a request whose URI carries ``args`` as its query string."""
    target = make_uri(args)
    return make_request(uri=target)
358
359
def make_form_request(args):
    """Build a request with ``args`` as a URL-encoded form body."""
    encoded_body = make_form_body(args)
    form_headers = {"Content-Type": "application/x-www-form-urlencoded"}
    return make_request(body=encoded_body, headers=form_headers)
365
366
def make_json_request(args, content_type="application/json; charset=UTF-8"):
    """Build a request with ``args`` as a JSON body.

    ``content_type`` is sent as the ``Content-Type`` header.
    """
    json_body = make_json_body(args)
    json_headers = {"Content-Type": content_type}
    return make_request(body=json_body, headers=json_headers)
371
372
def make_cookie_request(args):
    """Build a request whose ``Cookie`` header carries the given pairs.

    ``args`` is an iterable of ``(name, value)`` string pairs; each pair is
    rendered as ``name=value``.
    """
    # RFC 6265 separates cookie-pairs with "; " (semicolon, then a space).
    cookie_header = "; ".join("=".join(pair) for pair in args)
    return make_request(headers={"Cookie": cookie_header})
375
376
def make_files_request(args):
    """Build a request whose ``files`` mapping is assembled from ``args``.

    ``args`` is an iterable of ``(key, value)`` pairs. Values are grouped
    per key into a list; a list value is spliced in element by element.
    """
    grouped = {}
    for field, content in args:
        bucket = grouped.setdefault(field, [])
        if isinstance(content, list):
            bucket.extend(content)
        else:
            bucket.append(content)
    return make_request(files=grouped)
387
388
def make_request(uri=None, body=None, headers=None, files=None):
    """Build a ``tornado.httputil.HTTPServerRequest`` on a mock connection.

    ``uri`` and ``body`` default to empty strings. The method is ``POST``
    when ``body`` is truthy and ``GET`` otherwise. After construction the
    body is passed to ``tornado.httputil.parse_body_arguments`` with the
    request's ``body_arguments`` and ``files`` as the targets.
    """
    if uri is None:
        uri = ""
    if body is None:
        body = ""

    if body:
        http_method = "POST"
    else:
        http_method = "GET"

    # Need to make a mock connection right now because Tornado 4.0 requires a
    # remote_ip in the context attribute. 4.1 addresses this, and this
    # will be unnecessary once it is released
    # https://github.com/tornadoweb/tornado/issues/1118
    connection = mock.Mock(spec=tornado.http1connection.HTTP1Connection)
    connection.context = mock.Mock()
    connection.remote_ip = None

    content_type = ""
    if headers:
        content_type = headers.get("Content-Type", "")

    request = tornado.httputil.HTTPServerRequest(
        method=http_method,
        uri=uri,
        body=body,
        headers=headers,
        files=files,
        connection=connection,
    )

    # Only objects with an ``encode`` method are converted to bytes; anything
    # else is handed over unchanged.
    if hasattr(body, "encode"):
        raw_body = body.encode("latin-1")
    else:
        raw_body = body

    tornado.httputil.parse_body_arguments(
        content_type=content_type,
        body=raw_body,
        arguments=request.body_arguments,
        files=request.files,
    )

    return request
418
419
class EchoHandler(tornado.web.RequestHandler):
    """Write the arguments parsed from the query string back to the client."""

    ARGS = {
        "name": fields.Str(),
    }

    @use_args(ARGS, location="query")
    def get(self, args):
        # ``args`` is supplied by the ``use_args`` decorator.
        self.write(args)
426
427
class EchoFormHandler(tornado.web.RequestHandler):
    """Write the arguments parsed from the form body back to the client."""

    ARGS = {
        "name": fields.Str(),
    }

    @use_args(ARGS, location="form")
    def post(self, args):
        # ``args`` is supplied by the ``use_args`` decorator.
        self.write(args)
434
435
class EchoJSONHandler(tornado.web.RequestHandler):
    """Write the arguments parsed from the parser's default location back."""

    ARGS = {
        "name": fields.Str(),
    }

    @use_args(ARGS)
    def post(self, args):
        # ``args`` is supplied by the ``use_args`` decorator.
        self.write(args)
442
443
class EchoWithParamHandler(tornado.web.RequestHandler):
    """Write the parsed query arguments back; the path parameter is unused."""

    ARGS = {
        "name": fields.Str(),
    }

    @use_args(ARGS, location="query")
    def get(self, id, args):
        # ``id`` is accepted but not used; only the parsed ``args`` is echoed.
        self.write(args)
450
451
# Application that routes each echo URL pattern to its handler class.
echo_app = tornado.web.Application(
    handlers=[
        (r"/echo", EchoHandler),
        (r"/echo_form", EchoFormHandler),
        (r"/echo_json", EchoJSONHandler),
        (r"/echo_with_param/(\d+)", EchoWithParamHandler),
    ]
)
460
461
class TestApp(AsyncHTTPTestCase):
    """HTTP-level tests that run against ``echo_app``."""

    def get_app(self):
        return echo_app

    def test_post(self):
        # A supplied name is echoed back.
        response = self.fetch(
            "/echo_json",
            method="POST",
            headers={"Content-Type": "application/json"},
            body=json.dumps({"name": "Steve"}),
        )
        echoed = parse_json(response.body)
        assert echoed["name"] == "Steve"

        # An empty JSON object produces a response without a name.
        response = self.fetch(
            "/echo_json",
            method="POST",
            headers={"Content-Type": "application/json"},
            body=json.dumps({}),
        )
        echoed = parse_json(response.body)
        assert "name" not in echoed

    def test_get_with_no_json_body(self):
        # GET with a JSON content type but no body at all.
        response = self.fetch(
            "/echo",
            method="GET",
            headers={"Content-Type": "application/json"},
        )
        echoed = parse_json(response.body)
        assert "name" not in echoed

    def test_get_path_param(self):
        # The URL has both a path parameter (42) and a query argument.
        response = self.fetch(
            "/echo_with_param/42?name=Steve",
            method="GET",
            headers={"Content-Type": "application/json"},
        )
        echoed = parse_json(response.body)
        assert echoed == {"name": "Steve"}
499
500
class ValidateHandler(tornado.web.RequestHandler):
    """Handler whose ``name`` argument is required on both POST and GET."""

    ARGS = {
        "name": fields.Str(required=True),
    }

    @use_args(ARGS)
    def post(self, args):
        # Echo the parsed arguments back to the client.
        self.write(args)

    @use_kwargs(ARGS, location="query")
    def get(self, name):
        # ``name`` is accepted but not used; a fixed payload is written.
        self.write({"status": "success"})
511
512
def always_fail(val):
    """Validator that rejects every value by raising ``ma.ValidationError``."""
    failure = ma.ValidationError("something went wrong")
    raise failure
515
516
class AlwaysFailHandler(tornado.web.RequestHandler):
    """Handler whose ``name`` field is validated by ``always_fail``."""

    ARGS = {
        "name": fields.Str(validate=always_fail),
    }

    @use_args(ARGS)
    def post(self, args):
        # ``args`` is supplied by the ``use_args`` decorator.
        self.write(args)
523
524
# Application that routes the validation URL patterns to their handlers.
validate_app = tornado.web.Application(
    [
        (r"/echo", ValidateHandler),
        (r"/alwaysfail", AlwaysFailHandler),
    ]
)
528
529
class TestValidateApp(AsyncHTTPTestCase):
    """HTTP-level tests that run against ``validate_app``."""

    def get_app(self):
        return validate_app

    def test_required_field_provided(self):
        # The required field is present, so it is echoed back.
        response = self.fetch(
            "/echo",
            method="POST",
            headers={"Content-Type": "application/json"},
            body=json.dumps({"name": "johnny"}),
        )
        echoed = parse_json(response.body)
        assert echoed["name"] == "johnny"

    def test_missing_required_field_throws_422(self):
        # The body has no ``name`` member, only an unrelated one.
        response = self.fetch(
            "/echo",
            method="POST",
            headers={"Content-Type": "application/json"},
            body=json.dumps({"occupation": "pizza"}),
        )
        status = response.code
        assert status == 422

    def test_user_validator_returns_422_by_default(self):
        # The field validator on this route raises for every value.
        response = self.fetch(
            "/alwaysfail",
            method="POST",
            headers={"Content-Type": "application/json"},
            body=json.dumps({"name": "Steve"}),
        )
        status = response.code
        assert status == 422

    def test_use_kwargs_with_error(self):
        # GET without the required ``name`` query argument.
        response = self.fetch("/echo", method="GET")
        status = response.code
        assert status == 422
565
566
if __name__ == "__main__":
    # Manual run: bind the echo application to port 8888 and start the loop.
    echo_app.listen(8888)
    # ``IOLoop.instance()`` is a deprecated alias of ``IOLoop.current()``.
    tornado.ioloop.IOLoop.current().start()
570