1#!/usr/bin/env python
2# pylint: skip-file
3
4
5from __future__ import absolute_import, division, print_function
6from salt.ext.tornado import netutil
7from salt.ext.tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str
8from salt.ext.tornado import gen
9from salt.ext.tornado.http1connection import HTTP1Connection
10from salt.ext.tornado.httpserver import HTTPServer
11from salt.ext.tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
12from salt.ext.tornado.iostream import IOStream
13from salt.ext.tornado.log import gen_log
14from salt.ext.tornado.netutil import ssl_options_to_context
15from salt.ext.tornado.simple_httpclient import SimpleAsyncHTTPClient
16from salt.ext.tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test
17from salt.ext.tornado.test.util import unittest, skipOnTravis
18from salt.ext.tornado.web import Application, RequestHandler, asynchronous, stream_request_body
19from contextlib import closing
20import datetime
21import gzip
22import os
23import shutil
24import socket
25import ssl
26import sys
27import tempfile
28from io import BytesIO
29
30
def read_stream_body(stream, callback):
    """Read a complete HTTP response from ``stream``.

    When the response has been fully received, ``callback`` is invoked
    with a single ``(headers, body)`` tuple.
    """
    body_parts = []

    class _ResponseDelegate(HTTPMessageDelegate):
        def headers_received(self, start_line, headers):
            # Stash the headers so finish() can hand them to the callback.
            self.headers = headers

        def data_received(self, chunk):
            body_parts.append(chunk)

        def finish(self):
            callback((self.headers, b''.join(body_parts)))

    connection = HTTP1Connection(stream, True)
    connection.read_response(_ResponseDelegate())
47
48
class HandlerBaseTestCase(AsyncHTTPTestCase):
    """Base class for tests that exercise a single nested ``Handler`` at ``/``."""

    def get_app(self):
        # Subclasses define a nested class named ``Handler``.
        handler_cls = self.__class__.Handler
        return Application([('/', handler_cls)])

    def fetch_json(self, *args, **kwargs):
        """Fetch a URL and decode its body as JSON, raising on HTTP errors."""
        resp = self.fetch(*args, **kwargs)
        resp.rethrow()
        return json_decode(resp.body)
57
58
class HelloWorldRequestHandler(RequestHandler):
    """Trivial handler used by many tests; also verifies the request protocol."""

    def initialize(self, protocol="http"):
        # Protocol ("http" or "https") that this request is expected to use.
        self.expected_protocol = protocol

    def get(self):
        actual = self.request.protocol
        if actual != self.expected_protocol:
            raise Exception("unexpected protocol")
        self.finish("Hello world")

    def post(self):
        body_len = len(self.request.body)
        self.finish("Got %d bytes in POST" % body_len)
70
71
72# In pre-1.0 versions of openssl, SSLv23 clients always send SSLv2
73# ClientHello messages, which are rejected by SSLv3 and TLSv1
74# servers.  Note that while the OPENSSL_VERSION_INFO was formally
75# introduced in python3.2, it was present but undocumented in
76# python 2.7
# Decorator that skips a test unless OpenSSL >= 1.0 is available (see the
# explanation in the comment block above).
skipIfOldSSL = unittest.skipIf(
    getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0),
    "old version of ssl module and/or openssl")
80
81
class BaseSSLTest(AsyncHTTPSTestCase):
    """Shared application setup for the SSL protocol-variant tests below."""

    def get_app(self):
        handlers = [('/', HelloWorldRequestHandler, dict(protocol="https"))]
        return Application(handlers)
86
87
class SSLTestMixin(object):
    """Mixin with basic requests against an SSL server.

    Subclasses provide ``get_ssl_version`` to select the SSL protocol.
    """

    def get_ssl_options(self):
        options = dict(AsyncHTTPSTestCase.get_ssl_options())
        options['ssl_version'] = self.get_ssl_version()  # type: ignore
        return options

    def get_ssl_version(self):
        raise NotImplementedError()

    def test_ssl(self):
        self.assertEqual(self.fetch('/').body, b"Hello world")

    def test_large_post(self):
        response = self.fetch('/', method='POST', body='A' * 5000)
        self.assertEqual(response.body, b"Got 5000 bytes in POST")

    def test_non_ssl_request(self):
        # Make sure the server closes the connection when it gets a non-ssl
        # connection, rather than waiting for a timeout or otherwise
        # misbehaving.
        plaintext_url = self.get_url("/").replace('https:', 'http:')
        with ExpectLog(gen_log, '(SSL Error|uncaught exception)'):
            with ExpectLog(gen_log, 'Uncaught exception', required=False):
                self.http_client.fetch(plaintext_url,
                                       self.stop,
                                       request_timeout=3600,
                                       connect_timeout=3600)
                response = self.wait()
        self.assertEqual(response.code, 599)

    def test_error_logging(self):
        # No stack traces are logged for SSL errors.
        plaintext_url = self.get_url("/").replace("https:", "http:")
        with ExpectLog(gen_log, 'SSL Error') as expect_log:
            self.http_client.fetch(plaintext_url, self.stop)
            response = self.wait()
            self.assertEqual(response.code, 599)
        self.assertFalse(expect_log.logged_stack)
129
130# Python's SSL implementation differs significantly between versions.
131# For example, SSLv3 and TLSv1 throw an exception if you try to read
132# from the socket before the handshake is complete, but the default
133# of SSLv23 allows it.
134
135
class SSLv23Test(BaseSSLTest, SSLTestMixin):
    """Run the SSL test suite with the auto-negotiating SSLv23 protocol."""

    def get_ssl_version(self):
        return ssl.PROTOCOL_SSLv23
139
140
@skipIfOldSSL
class SSLv3Test(BaseSSLTest, SSLTestMixin):
    """Run the SSL test suite pinned to SSLv3."""

    def get_ssl_version(self):
        return ssl.PROTOCOL_SSLv3
145
146
@skipIfOldSSL
class TLSv1Test(BaseSSLTest, SSLTestMixin):
    """Run the SSL test suite pinned to TLSv1."""

    def get_ssl_version(self):
        return ssl.PROTOCOL_TLSv1
151
152
@unittest.skipIf(not hasattr(ssl, 'SSLContext'), 'ssl.SSLContext not present')
class SSLContextTest(BaseSSLTest, SSLTestMixin):
    """Run the SSL tests with an ``ssl.SSLContext`` instead of an option dict."""

    def get_ssl_options(self):
        base_options = AsyncHTTPSTestCase.get_ssl_options(self)
        context = ssl_options_to_context(base_options)
        assert isinstance(context, ssl.SSLContext)
        return context
160
161
class BadSSLOptionsTest(unittest.TestCase):
    """Constructing an HTTPServer with bad ssl_options must fail eagerly."""

    def test_missing_arguments(self):
        # ssl_options with a keyfile but no certfile is rejected immediately.
        application = Application()
        self.assertRaises(KeyError, HTTPServer, application, ssl_options={
            "keyfile": "/__missing__.crt",
        })

    def test_missing_key(self):
        """A missing SSL key should cause an immediate exception."""

        application = Application()
        module_dir = os.path.dirname(__file__)
        existing_certificate = os.path.join(module_dir, 'test.crt')
        existing_key = os.path.join(module_dir, 'test.key')

        # Nonexistent certificate file.  (Fixed typo: was "/__mising__.crt";
        # the path is intentionally absent either way.)
        self.assertRaises((ValueError, IOError),
                          HTTPServer, application, ssl_options={
                              "certfile": "/__missing__.crt",
        })
        # Existing certificate but nonexistent key file.
        self.assertRaises((ValueError, IOError),
                          HTTPServer, application, ssl_options={
                              "certfile": existing_certificate,
                              "keyfile": "/__missing__.key"
        })

        # This actually works because both files exist
        HTTPServer(application, ssl_options={
                   "certfile": existing_certificate,
                   "keyfile": existing_key,
                   })
192
193
class MultipartTestHandler(RequestHandler):
    """Echoes selected pieces of a multipart/form-data POST back as JSON."""

    def post(self):
        uploaded = self.request.files["files"][0]
        self.finish({
            "header": self.request.headers["X-Header-Encoding-Test"],
            "argument": self.get_argument("argument"),
            "filename": uploaded.filename,
            "filebody": _unicode(uploaded["body"]),
        })
201
202
203# This test is also called from wsgi_test
class HTTPConnectionTest(AsyncHTTPTestCase):
    """Exercises raw request handling: multipart bodies, newline variants,
    and the Expect: 100-continue handshake."""

    def get_handlers(self):
        # Kept separate from get_app so wsgi_test can reuse the handler list.
        return [("/multipart", MultipartTestHandler),
                ("/hello", HelloWorldRequestHandler)]

    def get_app(self):
        return Application(self.get_handlers())

    def raw_fetch(self, headers, body, newline=b"\r\n"):
        """Send a hand-built request over a raw socket and return the
        response body.

        ``headers`` is a list of byte strings including the request line;
        a Content-Length header is appended automatically.
        """
        with closing(IOStream(socket.socket())) as stream:
            stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
            self.wait()
            stream.write(
                newline.join(headers +
                             [utf8("Content-Length: %d" % len(body))]) +
                newline + newline + body)
            read_stream_body(stream, self.stop)
            headers, body = self.wait()
            return body

    def test_multipart_form(self):
        # Encodings here are tricky:  Headers are latin1, bodies can be
        # anything (we use utf8 by default).
        response = self.raw_fetch([
            b"POST /multipart HTTP/1.0",
            b"Content-Type: multipart/form-data; boundary=1234567890",
            b"X-Header-encoding-test: \xe9",
        ],
            b"\r\n".join([
                b"Content-Disposition: form-data; name=argument",
                b"",
                u"\u00e1".encode("utf-8"),
                b"--1234567890",
                u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"),
                b"",
                u"\u00fa".encode("utf-8"),
                b"--1234567890--",
                b"",
            ]))
        data = json_decode(response)
        self.assertEqual(u"\u00e9", data["header"])
        self.assertEqual(u"\u00e1", data["argument"])
        self.assertEqual(u"\u00f3", data["filename"])
        self.assertEqual(u"\u00fa", data["filebody"])

    def test_newlines(self):
        # We support both CRLF and bare LF as line separators.
        for newline in (b"\r\n", b"\n"):
            response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"",
                                      newline=newline)
            self.assertEqual(response, b'Hello world')

    def test_100_continue(self):
        # Run through a 100-continue interaction by hand:
        # When given Expect: 100-continue, we get a 100 response after the
        # headers, and then the real response after the body.
        stream = IOStream(socket.socket(), io_loop=self.io_loop)
        stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop)
        self.wait()
        stream.write(b"\r\n".join([b"POST /hello HTTP/1.1",
                                   b"Content-Length: 1024",
                                   b"Expect: 100-continue",
                                   b"Connection: close",
                                   b"\r\n"]), callback=self.stop)
        self.wait()
        # The interim 100 response must arrive before we send the body.
        stream.read_until(b"\r\n\r\n", self.stop)
        data = self.wait()
        self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data)
        stream.write(b"a" * 1024)
        stream.read_until(b"\r\n", self.stop)
        first_line = self.wait()
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        stream.read_until(b"\r\n\r\n", self.stop)
        header_data = self.wait()
        headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
        stream.read_bytes(int(headers["Content-Length"]), self.stop)
        body = self.wait()
        self.assertEqual(body, b"Got 1024 bytes in POST")
        stream.close()
283
284
class EchoHandler(RequestHandler):
    """Writes the parsed request arguments back to the client as JSON."""

    def get(self):
        self._echo_arguments()

    def post(self):
        self._echo_arguments()

    def _echo_arguments(self):
        # Argument values arrive as bytes; convert before serializing.
        self.write(recursive_unicode(self.request.arguments))
291
292
class TypeCheckHandler(RequestHandler):
    """Records an error for each request attribute with an unexpected type.

    The response body is the (hopefully empty) dict of type errors.
    """

    def prepare(self):
        self.errors = {}
        expectations = [
            ('method', str),
            ('uri', str),
            ('version', str),
            ('remote_ip', str),
            ('protocol', str),
            ('host', str),
            ('path', str),
            ('query', str),
        ]
        for name, expected in expectations:
            self.check_type(name, getattr(self.request, name), expected)

        request = self.request
        self.check_type('header_key', list(request.headers.keys())[0], str)
        self.check_type('header_value', list(request.headers.values())[0], str)

        self.check_type('cookie_key', list(request.cookies.keys())[0], str)
        self.check_type('cookie_value',
                        list(request.cookies.values())[0].value, str)
        # secure cookies are not covered here

        self.check_type('arg_key', list(request.arguments.keys())[0], str)
        self.check_type('arg_value', list(request.arguments.values())[0][0],
                        bytes)

    def post(self):
        self.check_type('body', self.request.body, bytes)
        self.write(self.errors)

    def get(self):
        self.write(self.errors)

    def check_type(self, name, obj, expected_type):
        # Deliberately compares exact types rather than using isinstance().
        actual_type = type(obj)
        if expected_type != actual_type:
            self.errors[name] = "expected %s, got %s" % (expected_type,
                                                         actual_type)
331
332
class HTTPServerTest(AsyncHTTPTestCase):
    """End-to-end request-parsing tests over EchoHandler/TypeCheckHandler."""

    def get_app(self):
        return Application([
            ("/echo", EchoHandler),
            ("/typecheck", TypeCheckHandler),
            ("//doubleslash", EchoHandler),
        ])

    def test_query_string_encoding(self):
        # %C3%A9 is the utf-8 encoding of e-acute; it comes back decoded.
        data = json_decode(self.fetch("/echo?foo=%C3%A9").body)
        self.assertEqual(data, {u"foo": [u"\u00e9"]})

    def test_empty_query_string(self):
        data = json_decode(self.fetch("/echo?foo=&foo=").body)
        self.assertEqual(data, {u"foo": [u"", u""]})

    def test_empty_post_parameters(self):
        response = self.fetch("/echo", method="POST", body="foo=&bar=")
        self.assertEqual(json_decode(response.body),
                         {u"foo": [u""], u"bar": [u""]})

    def test_types(self):
        # An empty error dict means every attribute had the expected type.
        headers = {"Cookie": "foo=bar"}
        get_errors = json_decode(
            self.fetch("/typecheck?foo=bar", headers=headers).body)
        self.assertEqual(get_errors, {})

        post_errors = json_decode(
            self.fetch("/typecheck", method="POST", body="foo=bar",
                       headers=headers).body)
        self.assertEqual(post_errors, {})

    def test_double_slash(self):
        # Paths beginning with "//" must not be treated as protocol-relative
        # URLs (a bug in the old urlparse.urlsplit-based implementation).
        response = self.fetch("//doubleslash")
        self.assertEqual(200, response.code)
        self.assertEqual(json_decode(response.body), {})

    def test_malformed_body(self):
        # parse_qs is pretty forgiving, but it will fail on python 3
        # if the data is not utf8.  On python 2 parse_qs will work,
        # but then the recursive_unicode call in EchoHandler will
        # fail.
        if str is bytes:
            return
        with ExpectLog(gen_log, 'Invalid x-www-form-urlencoded body'):
            response = self.fetch(
                '/echo', method="POST",
                headers={'Content-Type': 'application/x-www-form-urlencoded'},
                body=b'\xe9')
        self.assertEqual(200, response.code)
        self.assertEqual(b'{}', response.body)
387
388
class HTTPServerRawTest(AsyncHTTPTestCase):
    """Tests that write raw protocol bytes directly to the server socket."""

    def get_app(self):
        return Application([
            ('/echo', EchoHandler),
        ])

    def setUp(self):
        super(HTTPServerRawTest, self).setUp()
        # Each test gets a fresh raw connection to the server.
        self.stream = IOStream(socket.socket())
        self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def test_empty_request(self):
        # Closing without sending anything must not upset the server.
        self.stream.close()
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

    def test_malformed_first_line(self):
        with ExpectLog(gen_log, '.*Malformed HTTP request line'):
            self.stream.write(b'asdf\r\n\r\n')
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_malformed_headers(self):
        with ExpectLog(gen_log, '.*Malformed HTTP headers'):
            self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_chunked_request_body(self):
        # Chunked requests are not widely supported and we don't have a way
        # to generate them in AsyncHTTPClient, but HTTPServer will read them.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        read_stream_body(self.stream, self.stop)
        headers, response = self.wait()
        self.assertEqual(json_decode(response), {u'foo': [u'bar']})

    def test_chunked_request_uppercase(self):
        # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
        # case-insensitive.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: Chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        read_stream_body(self.stream, self.stop)
        headers, response = self.wait()
        self.assertEqual(json_decode(response), {u'foo': [u'bar']})

    def test_invalid_content_length(self):
        with ExpectLog(gen_log, '.*Only integer Content-Length is allowed'):
            self.stream.write(b"""\
POST /echo HTTP/1.1
Content-Length: foo

bar

""".replace(b"\n", b"\r\n"))
            self.stream.read_until_close(self.stop)
            self.wait()
475
476
class XHeaderTest(HandlerBaseTestCase):
    """Tests xheaders=True handling of X-Real-IP / X-Forwarded-* headers."""

    class Handler(RequestHandler):
        def get(self):
            self.write(dict(remote_ip=self.request.remote_ip,
                            remote_protocol=self.request.protocol))

    def get_httpserver_options(self):
        return dict(xheaders=True, trusted_downstream=['5.5.5.5'])

    def _remote_ip(self, headers=None):
        # Helper: fetch "/" (with optional headers) and return remote_ip.
        kwargs = {} if headers is None else {"headers": headers}
        return self.fetch_json("/", **kwargs)["remote_ip"]

    def _remote_protocol(self, headers=None):
        # Helper: fetch "/" (with optional headers) and return the protocol.
        kwargs = {} if headers is None else {"headers": headers}
        return self.fetch_json("/", **kwargs)["remote_protocol"]

    def test_ip_headers(self):
        self.assertEqual(self._remote_ip(), "127.0.0.1")

        # Well-formed X-Real-IP / X-Forwarded-For values are honored,
        # for both IPv4 and IPv6 addresses.
        self.assertEqual(self._remote_ip({"X-Real-IP": "4.4.4.4"}),
                         "4.4.4.4")
        self.assertEqual(
            self._remote_ip({"X-Forwarded-For": "127.0.0.1, 4.4.4.4"}),
            "4.4.4.4")
        self.assertEqual(
            self._remote_ip({"X-Real-IP": "2620:0:1cfe:face:b00c::3"}),
            "2620:0:1cfe:face:b00c::3")
        self.assertEqual(
            self._remote_ip(
                {"X-Forwarded-For": "::1, 2620:0:1cfe:face:b00c::3"}),
            "2620:0:1cfe:face:b00c::3")

        # Malformed values fall back to the real connection address.
        self.assertEqual(self._remote_ip({"X-Real-IP": "4.4.4.4<script>"}),
                         "127.0.0.1")
        self.assertEqual(
            self._remote_ip({"X-Forwarded-For": "4.4.4.4, 5.5.5.5<script>"}),
            "127.0.0.1")
        self.assertEqual(self._remote_ip({"X-Real-IP": "www.google.com"}),
                         "127.0.0.1")

    def test_trusted_downstream(self):
        # 5.5.5.5 is configured as a trusted proxy, so the chain skips past
        # it to the next hop.
        self.assertEqual(
            self._remote_ip(
                {"X-Forwarded-For": "127.0.0.1, 4.4.4.4, 5.5.5.5"}),
            "4.4.4.4")

    def test_scheme_headers(self):
        self.assertEqual(self._remote_protocol(), "http")

        self.assertEqual(self._remote_protocol({"X-Scheme": "https"}),
                         "https")
        self.assertEqual(
            self._remote_protocol({"X-Forwarded-Proto": "https"}),
            "https")
        # Unrecognized scheme values are ignored.
        self.assertEqual(
            self._remote_protocol({"X-Forwarded-Proto": "unknown"}),
            "http")
548
549
class SSLXHeaderTest(AsyncHTTPSTestCase, HandlerBaseTestCase):
    """xheaders handling when the real connection is already SSL."""

    def get_app(self):
        return Application([('/', XHeaderTest.Handler)])

    def get_httpserver_options(self):
        options = super(SSLXHeaderTest, self).get_httpserver_options()
        options['xheaders'] = True
        return options

    def test_request_without_xprotocol(self):
        # Without an X-Scheme header the native protocol (https) is reported.
        self.assertEqual(self.fetch_json("/")["remote_protocol"], "https")

        # A recognized X-Scheme value overrides it...
        http_scheme = {"X-Scheme": "http"}
        self.assertEqual(
            self.fetch_json("/", headers=http_scheme)["remote_protocol"],
            "http")

        # ...but an unrecognized one is ignored.
        bad_scheme = {"X-Scheme": "unknown"}
        self.assertEqual(
            self.fetch_json("/", headers=bad_scheme)["remote_protocol"],
            "https")
569
570
class ManualProtocolTest(HandlerBaseTestCase):
    """The ``protocol`` server option overrides protocol autodetection."""

    class Handler(RequestHandler):
        def get(self):
            self.write(dict(protocol=self.request.protocol))

    def get_httpserver_options(self):
        return dict(protocol='https')

    def test_manual_protocol(self):
        reported = self.fetch_json('/')['protocol']
        self.assertEqual(reported, 'https')
581
582
@unittest.skipIf(not hasattr(socket, 'AF_UNIX') or sys.platform == 'cygwin',
                 "unix sockets not supported on this platform")
class UnixSocketTest(AsyncTestCase):
    """HTTPServers can listen on Unix sockets too.

    Why would you want to do this?  Nginx can proxy to backends listening
    on unix sockets, for one thing (and managing a namespace for unix
    sockets can be easier than managing a bunch of TCP port numbers).

    Unfortunately, there's no way to specify a unix socket in a url for
    an HTTP client, so we have to test this by hand.
    """
    def setUp(self):
        super(UnixSocketTest, self).setUp()
        # Bind the server to a socket file in a private temp directory.
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        sock = netutil.bind_unix_socket(self.sockfile)
        app = Application([("/hello", HelloWorldRequestHandler)])
        self.server = HTTPServer(app, io_loop=self.io_loop)
        self.server.add_socket(sock)
        # Connect a raw client stream for hand-written requests.
        self.stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
        self.stream.connect(self.sockfile, self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        # A simple request/response round trip over the unix socket.
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        self.stream.read_until(b"\r\n", self.stop)
        response = self.wait()
        self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
        self.stream.read_until(b"\r\n\r\n", self.stop)
        headers = HTTPHeaders.parse(self.wait().decode('latin1'))
        self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
        body = self.wait()
        self.assertEqual(body, b"Hello world")

    def test_unix_socket_bad_request(self):
        # Unix sockets don't have remote addresses so they just return an
        # empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            self.stream.read_until_close(self.stop)
            response = self.wait()
        self.assertEqual(response, b"")
632
633
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish('Hello world')

            def post(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write(''.join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            @asynchronous
            def get(self):
                self.flush()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish('closed')

        return Application([('/', HelloHandler),
                            ('/large', LargeHandler),
                            ('/finish_on_close', FinishOnCloseHandler)])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        # Default protocol version for requests; individual tests override.
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    def connect(self):
        # Open a raw connection to the test server.
        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
        self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def read_headers(self):
        # Read the status line and headers; asserts a 200 response.
        self.stream.read_until(b'\r\n', self.stop)
        first_line = self.wait()
        self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line)
        self.stream.read_until(b'\r\n\r\n', self.stop)
        header_bytes = self.wait()
        headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
        return headers

    def read_response(self):
        # Read a complete response and assert the standard hello body.
        self.headers = self.read_headers()
        self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
        body = self.wait()
        self.assertEqual(b'Hello world', body)

    def close(self):
        # Close the client side of the connection.
        self.stream.close()
        del self.stream

    def test_two_requests(self):
        # Two sequential requests on one connection both succeed.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.close()

    def test_request_close(self):
        # Connection: close makes the server close after one response.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        # Plain HTTP/1.0: the server closes the connection after responding.
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.assertTrue('Connection' not in self.headers)
        self.close()

    def test_http10_keepalive(self):
        # HTTP/1.0 with an explicit keep-alive header reuses the connection.
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_http10_keepalive_extra_crlf(self):
        # A stray extra CRLF after the request must not break keep-alive.
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_pipelined_requests(self):
        # Two requests written back-to-back are both answered in order.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        # Closing mid-pipeline (after the first response) must be handled.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        # only read once
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        # Client disconnects while a large response is still streaming.
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        # Close while the handler is still waiting; its on_connection_close
        # callback calls finish() after the connection is gone.
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()

    def test_keepalive_chunked(self):
        # A zero-length chunked request body still allows keep-alive.
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'POST / HTTP/1.0\r\n'
                          b'Connection: keep-alive\r\n'
                          b'Transfer-Encoding: chunked\r\n'
                          b'\r\n'
                          b'0\r\n'
                          b'\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()
802
803
class GzipBaseTest(object):
    """Shared helpers for the gzip request-body tests."""

    def get_app(self):
        return Application([('/', EchoHandler)])

    def post_gzip(self, body):
        """POST ``body`` gzip-compressed with a Content-Encoding: gzip header."""
        buf = BytesIO()
        with gzip.GzipFile(mode='w', fileobj=buf) as gzip_file:
            gzip_file.write(utf8(body))
        compressed_body = buf.getvalue()
        return self.fetch('/', method='POST', body=compressed_body,
                          headers={'Content-Encoding': 'gzip'})

    def test_uncompressed(self):
        # Plain bodies work regardless of the gzip server option.
        response = self.fetch('/', method='POST', body='foo=bar')
        self.assertEquals(json_decode(response.body), {u'foo': [u'bar']})
820
821
class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
    """Gzip request bodies with server-side decompression enabled."""

    def get_httpserver_options(self):
        return dict(decompress_request=True)

    def test_gzip(self):
        response = self.post_gzip('foo=bar')
        # assertEqual, not the deprecated assertEquals alias (removed in
        # Python 3.12).
        self.assertEqual(json_decode(response.body), {u'foo': [u'bar']})
829
830
class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase):
    """Gzip request bodies when server-side decompression is NOT enabled."""

    def test_gzip_unsupported(self):
        # Gzip support is opt-in; without it the server fails to parse
        # the body (but parsing form bodies is currently just a log message,
        # not a fatal error).
        with ExpectLog(gen_log, "Unsupported Content-Encoding"):
            response = self.post_gzip('foo=bar')
        # assertEqual, not the deprecated assertEquals alias (removed in
        # Python 3.12).
        self.assertEqual(json_decode(response.body), {})
839
840
class StreamingChunkSizeTest(AsyncHTTPTestCase):
    """Verifies that request bodies reach the application in pieces no
    larger than the server's configured chunk_size, both for plain and
    gzip-compressed uploads."""

    # 50 characters long, and repetitive so it can be compressed.
    BODY = b'01234567890123456789012345678901234567890123456789'
    CHUNK_SIZE = 16

    def get_http_client(self):
        # body_producer doesn't work on curl_httpclient, so override the
        # configured AsyncHTTPClient implementation.
        return SimpleAsyncHTTPClient(io_loop=self.io_loop)

    def get_httpserver_options(self):
        return {'chunk_size': self.CHUNK_SIZE, 'decompress_request': True}

    class MessageDelegate(HTTPMessageDelegate):
        """Records the size of every body chunk received and echoes the
        list of sizes back to the client as JSON."""

        def __init__(self, connection):
            self.connection = connection

        def headers_received(self, start_line, headers):
            self.chunk_lengths = []

        def data_received(self, chunk):
            self.chunk_lengths.append(len(chunk))

        def finish(self):
            response_body = utf8(json_encode(self.chunk_lengths))
            start_line = ResponseStartLine('HTTP/1.1', 200, 'OK')
            headers = HTTPHeaders({'Content-Length': str(len(response_body))})
            self.connection.write_headers(start_line, headers)
            self.connection.write(response_body)
            self.connection.finish()

    def get_app(self):
        class App(HTTPServerConnectionDelegate):
            def start_request(self, server_conn, request_conn):
                return StreamingChunkSizeTest.MessageDelegate(request_conn)
        return App()

    def fetch_chunk_sizes(self, **kwargs):
        """POST to the echo app and return the chunk sizes the server saw,
        asserting basic invariants on every chunk."""
        response = self.fetch('/', method='POST', **kwargs)
        response.rethrow()
        chunks = json_decode(response.body)
        self.assertEqual(len(self.BODY), sum(chunks))
        for chunk_size in chunks:
            self.assertLessEqual(chunk_size, self.CHUNK_SIZE,
                                 'oversized chunk: ' + str(chunks))
            self.assertGreater(chunk_size, 0,
                               'empty chunk: ' + str(chunks))
        return chunks

    def compress(self, body):
        """Gzip-compress `body`, requiring that compression actually helped."""
        buf = BytesIO()
        with gzip.GzipFile(mode='w', fileobj=buf) as gzfile:
            gzfile.write(body)
        compressed = buf.getvalue()
        if len(compressed) >= len(body):
            raise Exception("body did not shrink when compressed")
        return compressed

    def test_regular_body(self):
        chunks = self.fetch_chunk_sizes(body=self.BODY)
        # Without compression we know exactly what to expect.
        self.assertEqual([16, 16, 16, 2], chunks)

    def test_compressed_body(self):
        self.fetch_chunk_sizes(body=self.compress(self.BODY),
                               headers={'Content-Encoding': 'gzip'})
        # Compression creates irregular boundaries so the assertions
        # in fetch_chunk_sizes are as specific as we can get.

    def test_chunked_body(self):
        def body_producer(write):
            write(self.BODY[:20])
            write(self.BODY[20:])
        chunks = self.fetch_chunk_sizes(body_producer=body_producer)
        # HTTP chunk boundaries translate to application-visible breaks
        self.assertEqual([16, 4, 16, 14], chunks)

    def test_chunked_compressed(self):
        compressed = self.compress(self.BODY)
        self.assertGreater(len(compressed), 20)

        def body_producer(write):
            write(compressed[:20])
            write(compressed[20:])
        self.fetch_chunk_sizes(body_producer=body_producer,
                               headers={'Content-Encoding': 'gzip'})
928
929
class MaxHeaderSizeTest(AsyncHTTPTestCase):
    """Tests the server's max_header_size limit."""

    def get_app(self):
        return Application([('/', HelloWorldRequestHandler)])

    def get_httpserver_options(self):
        return {'max_header_size': 1024}

    def test_small_headers(self):
        # Headers comfortably under the limit pass straight through.
        response = self.fetch("/", headers={'X-Filler': 'a' * 100})
        response.rethrow()
        self.assertEqual(response.body, b"Hello world")

    def test_large_headers(self):
        with ExpectLog(gen_log, "Unsatisfiable read", required=False):
            response = self.fetch("/", headers={'X-Filler': 'a' * 1000})
        # 431 is "Request Header Fields Too Large", defined in RFC
        # 6585. However, many implementations just close the
        # connection in this case, resulting in a 599.
        self.assertIn(response.code, (431, 599))
949
950
@skipOnTravis
class IdleTimeoutTest(AsyncHTTPTestCase):
    """Tests that keep-alive connections idle past
    idle_connection_timeout are closed by the server."""

    def get_app(self):
        return Application([('/', HelloWorldRequestHandler)])

    def get_httpserver_options(self):
        return {'idle_connection_timeout': 0.1}

    def setUp(self):
        super(IdleTimeoutTest, self).setUp()
        # Every stream opened by connect(); all closed in tearDown.
        self.streams = []

    def tearDown(self):
        super(IdleTimeoutTest, self).tearDown()
        for stream in self.streams:
            stream.close()

    def connect(self):
        """Open a raw client connection to the test server."""
        stream = IOStream(socket.socket())
        stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()
        self.streams.append(stream)
        return stream

    def test_unused_connection(self):
        # Even a connection that never sends a request gets timed out.
        stream = self.connect()
        stream.set_close_callback(self.stop)
        self.wait()

    def test_idle_after_use(self):
        stream = self.connect()
        stream.set_close_callback(lambda: self.stop("closed"))

        # Use the connection twice to make sure keep-alives are working
        for _ in range(2):
            stream.write(b"GET / HTTP/1.1\r\n\r\n")
            stream.read_until(b"\r\n\r\n", self.stop)
            self.wait()
            stream.read_bytes(11, self.stop)
            body = self.wait()
            self.assertEqual(body, b"Hello world")

        # Now let the timeout trigger and close the connection.
        result = self.wait()
        self.assertEqual(result, "closed")
996
997
class BodyLimitsTest(AsyncHTTPTestCase):
    # Tests the server's body_timeout and max_body_size options, and the
    # per-request overrides set_body_timeout()/set_max_body_size().
    def get_app(self):
        class BufferedHandler(RequestHandler):
            # Non-streaming handler: the entire body is buffered before
            # put() runs, so it is subject to the server-wide max_body_size.
            def put(self):
                self.write(str(len(self.request.body)))

        @stream_request_body
        class StreamingHandler(RequestHandler):
            def initialize(self):
                self.bytes_read = 0

            def prepare(self):
                # Query arguments let individual tests override the
                # server-wide limits on a per-request basis.
                if 'expected_size' in self.request.arguments:
                    self.request.connection.set_max_body_size(
                        int(self.get_argument('expected_size')))
                if 'body_timeout' in self.request.arguments:
                    self.request.connection.set_body_timeout(
                        float(self.get_argument('body_timeout')))

            def data_received(self, data):
                # Count bytes instead of buffering them.
                self.bytes_read += len(data)

            def put(self):
                self.write(str(self.bytes_read))

        return Application([('/buffered', BufferedHandler),
                            ('/streaming', StreamingHandler)])

    def get_httpserver_options(self):
        # Generous default timeout; tests exercising timeouts override it
        # per-request via ?body_timeout=.
        return dict(body_timeout=3600, max_body_size=4096)

    def get_http_client(self):
        # body_producer doesn't work on curl_httpclient, so override the
        # configured AsyncHTTPClient implementation.
        return SimpleAsyncHTTPClient(io_loop=self.io_loop)

    def test_small_body(self):
        # A body exactly at max_body_size is accepted by both handlers.
        response = self.fetch('/buffered', method='PUT', body=b'a' * 4096)
        self.assertEqual(response.body, b'4096')
        response = self.fetch('/streaming', method='PUT', body=b'a' * 4096)
        self.assertEqual(response.body, b'4096')

    def test_large_body_buffered(self):
        with ExpectLog(gen_log, '.*Content-Length too long'):
            response = self.fetch('/buffered', method='PUT', body=b'a' * 10240)
        # The server closes the connection, which surfaces as a 599.
        self.assertEqual(response.code, 599)

    def test_large_body_buffered_chunked(self):
        # With a body_producer the client sends Transfer-Encoding: chunked,
        # so the limit is enforced as the chunks arrive.
        with ExpectLog(gen_log, '.*chunked body too large'):
            response = self.fetch('/buffered', method='PUT',
                                  body_producer=lambda write: write(b'a' * 10240))
        self.assertEqual(response.code, 599)

    def test_large_body_streaming(self):
        with ExpectLog(gen_log, '.*Content-Length too long'):
            response = self.fetch('/streaming', method='PUT', body=b'a' * 10240)
        self.assertEqual(response.code, 599)

    def test_large_body_streaming_chunked(self):
        with ExpectLog(gen_log, '.*chunked body too large'):
            response = self.fetch('/streaming', method='PUT',
                                  body_producer=lambda write: write(b'a' * 10240))
        self.assertEqual(response.code, 599)

    def test_large_body_streaming_override(self):
        # set_max_body_size raises the limit for this single request.
        response = self.fetch('/streaming?expected_size=10240', method='PUT',
                              body=b'a' * 10240)
        self.assertEqual(response.body, b'10240')

    def test_large_body_streaming_chunked_override(self):
        response = self.fetch('/streaming?expected_size=10240', method='PUT',
                              body_producer=lambda write: write(b'a' * 10240))
        self.assertEqual(response.body, b'10240')

    @gen_test
    def test_timeout(self):
        stream = IOStream(socket.socket())
        try:
            yield stream.connect(('127.0.0.1', self.get_http_port()))
            # Use a raw stream because AsyncHTTPClient won't let us read a
            # response without finishing a body.
            # Headers promise 42 bytes but the body is never sent, so the
            # 0.1s body timeout fires and the server drops the connection.
            stream.write(b'PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n'
                         b'Content-Length: 42\r\n\r\n')
            with ExpectLog(gen_log, 'Timeout reading body'):
                response = yield stream.read_until_close()
            self.assertEqual(response, b'')
        finally:
            stream.close()

    @gen_test
    def test_body_size_override_reset(self):
        # The max_body_size override is reset between requests.
        stream = IOStream(socket.socket())
        try:
            yield stream.connect(('127.0.0.1', self.get_http_port()))
            # Use a raw stream so we can make sure it's all on one connection.
            stream.write(b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
                         b'Content-Length: 10240\r\n\r\n')
            stream.write(b'a' * 10240)
            headers, response = yield gen.Task(read_stream_body, stream)
            self.assertEqual(response, b'10240')
            # Without the ?expected_size parameter, we get the old default value
            stream.write(b'PUT /streaming HTTP/1.1\r\n'
                         b'Content-Length: 10240\r\n\r\n')
            with ExpectLog(gen_log, '.*Content-Length too long'):
                data = yield stream.read_until_close()
            self.assertEqual(data, b'')
        finally:
            stream.close()
1107
1108
class LegacyInterfaceTest(AsyncHTTPTestCase):
    """Tests the pre-4.0 request_callback interface, where a plain
    function receives the request and writes a raw HTTP response."""

    def get_app(self):
        # The old request_callback interface does not implement the
        # delegate interface, and writes its response via request.write
        # instead of request.connection.write_headers.
        def handle_request(request):
            self.http1 = request.version.startswith("HTTP/1.")
            if not self.http1:
                # This test will be skipped if we're using HTTP/2,
                # so just close it out cleanly using the modern interface.
                request.connection.write_headers(
                    ResponseStartLine('', 200, 'OK'),
                    HTTPHeaders())
                request.connection.finish()
                return
            message = b"Hello world"
            status_line = "HTTP/1.1 200 OK\r\n"
            headers = "Content-Length: %d\r\n\r\n" % len(message)
            request.write(utf8(status_line + headers))
            request.write(message)
            request.finish()
        return handle_request

    def test_legacy_interface(self):
        response = self.fetch('/')
        if not self.http1:
            self.skipTest("requires HTTP/1.x")
        self.assertEqual(response.body, b"Hello world")
1136